save schema embedding update

This commit is contained in:
robin
2023-10-04 15:46:58 +08:00
parent f9d52b08a2
commit 11f9a1894c
5 changed files with 195 additions and 8 deletions

View File

@ -45,6 +45,7 @@ import ai.chat2db.server.web.api.controller.ai.request.ChatRequest;
import ai.chat2db.server.web.api.controller.ai.rest.client.RestAIClient;
import ai.chat2db.server.web.api.http.GatewayClientService;
import ai.chat2db.server.web.api.http.model.TableSchema;
import ai.chat2db.server.web.api.http.request.TableSchemaRequest;
import ai.chat2db.server.web.api.http.response.TableSchemaResponse;
import ai.chat2db.server.web.api.util.ApplicationContextUtil;
import ai.chat2db.server.web.api.controller.ai.openai.client.OpenAIClient;
@ -528,7 +529,15 @@ public class ChatController {
contentVector.add(response.getData().get(0).getEmbedding());
// search embedding
DataResult<TableSchemaResponse> result = gatewayClientService.schemaVectorSearch(contentVector, queryRequest.getDataSourceId());
TableSchemaRequest tableSchemaRequest = new TableSchemaRequest();
tableSchemaRequest.setSchemaVector(contentVector);
tableSchemaRequest.setDataSourceId(queryRequest.getDataSourceId());
String databaseName = StringUtils.isNotBlank(queryRequest.getDatabaseName()) ? queryRequest.getDatabaseName() : queryRequest.getSchemaName();
if (Objects.isNull(databaseName)) {
databaseName = "";
}
tableSchemaRequest.setDatabaseName(databaseName);
DataResult<TableSchemaResponse> result = gatewayClientService.schemaVectorSearch(tableSchemaRequest);
List<String> schemas = Lists.newArrayList();
if (CollectionUtils.isNotEmpty(result.getData().getTableSchemas())) {

View File

@ -0,0 +1,175 @@
package ai.chat2db.server.web.api.controller.ai;
import ai.chat2db.server.domain.api.param.ShowCreateTableParam;
import ai.chat2db.server.domain.api.param.TablePageQueryParam;
import ai.chat2db.server.domain.api.param.TableSelector;
import ai.chat2db.server.domain.api.service.TableService;
import ai.chat2db.server.tools.base.wrapper.result.ActionResult;
import ai.chat2db.server.tools.base.wrapper.result.DataResult;
import ai.chat2db.server.tools.base.wrapper.result.PageResult;
import ai.chat2db.server.tools.common.exception.ParamBusinessException;
import ai.chat2db.server.web.api.aspect.ConnectionInfoAspect;
import ai.chat2db.server.web.api.controller.ai.fastchat.embeddings.FastChatEmbeddingResponse;
import ai.chat2db.server.web.api.controller.rdb.converter.RdbWebConverter;
import ai.chat2db.server.web.api.controller.rdb.request.TableBriefQueryRequest;
import ai.chat2db.server.web.api.http.GatewayClientService;
import ai.chat2db.server.web.api.http.request.TableSchemaRequest;
import ai.chat2db.spi.model.Table;
import com.google.common.collect.Lists;
import jakarta.annotation.Resource;
import jakarta.validation.Valid;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.CrossOrigin;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import java.io.IOException;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
/**
* @author moji
*/
@RestController
@ConnectionInfoAspect
@RequestMapping("/api/ai/embedding")
@Slf4j
public class EmbeddingController extends ChatController {
@Resource
private GatewayClientService gatewayClientService;
@Autowired
private RdbWebConverter rdbWebConverter;
@Autowired
private TableService tableService;
/**
* save knowledge embeddings from pdf file
*
* @param request
* @return
* @throws IOException
*/
@PostMapping("/datasource")
@CrossOrigin
public ActionResult embeddings(@Valid TableBriefQueryRequest request)
throws Exception {
// query tables
request.setPageSize(1000);
TablePageQueryParam queryParam = rdbWebConverter.tablePageRequest2param(request);
TableSelector tableSelector = new TableSelector();
tableSelector.setColumnList(false);
tableSelector.setIndexList(false);
PageResult<Table> tableDTOPageResult = tableService.pageQuery(queryParam, tableSelector);
List<Table> tables = tableDTOPageResult.getData();
if (CollectionUtils.isEmpty(tables)) {
return ActionResult.isSuccess();
}
String tableName = tables.get(0).getName();
String tableSchema = queryTableDdl(tableName, request);
if (StringUtils.isBlank(tableSchema)) {
throw new ParamBusinessException("tableSchema is empty");
}
// save first table embedding
TableSchemaRequest tableSchemaRequest = new TableSchemaRequest();
tableSchemaRequest.setDataSourceId(request.getDataSourceId());
tableSchemaRequest.setDeleteBeforeInsert(true);
String databaseName = StringUtils.isNotBlank(request.getDatabaseName()) ? request.getDatabaseName() : request.getSchemaName();
if (Objects.isNull(databaseName)) {
databaseName = "";
}
tableSchemaRequest.setDatabaseName(databaseName);
saveTableEmbedding(tableSchema, tableSchemaRequest);
// save other table embedding
tableSchemaRequest.setDeleteBeforeInsert(false);
for (int i = 1; i < tables.size(); i++) {
tableName = tables.get(i).getName();
tableSchema = queryTableDdl(tableName, request);
if (StringUtils.isBlank(tableSchema)) {
continue;
}
saveTableEmbedding(tableSchema, tableSchemaRequest);
}
// query all the tables
Long totalTableCount = tableDTOPageResult.getTotal();
Integer pageSize = queryParam.getPageSize();
if (pageSize < totalTableCount) {
for (int i = 2; i < totalTableCount/pageSize + 1; i++) {
queryParam.setPageNo(i);
tableDTOPageResult = tableService.pageQuery(queryParam, tableSelector);
tables = tableDTOPageResult.getData();
for (Table table : tables) {
tableName = table.getName();
tableSchema = queryTableDdl(tableName, request);
if (StringUtils.isBlank(tableSchema)) {
continue;
}
saveTableEmbedding(tableSchema, tableSchemaRequest);
}
}
}
return ActionResult.isSuccess();
}
/**
* save table embedding
*
* @param tableSchema
* @param tableSchemaRequest
* @throws Exception
*/
private void saveTableEmbedding(String tableSchema, TableSchemaRequest tableSchemaRequest) throws Exception{
List<String> schemaList = Lists.newArrayList(tableSchema);
tableSchemaRequest.setSchemaList(schemaList);
List<List<BigDecimal>> contentVector = new ArrayList<>();
for(String str : schemaList){
// request embedding
FastChatEmbeddingResponse response = distributeAIEmbedding(str);
if(response == null){
continue;
}
contentVector.add(response.getData().get(0).getEmbedding());
}
tableSchemaRequest.setSchemaVector(contentVector);
// save table embedding
gatewayClientService.schemaVectorSave(tableSchemaRequest);
}
/**
* query table schema
*
* @param tableName
* @param request
* @return
*/
private String queryTableDdl(String tableName, TableBriefQueryRequest request) {
ShowCreateTableParam param = new ShowCreateTableParam();
param.setTableName(tableName);
param.setDataSourceId(request.getDataSourceId());
param.setDatabaseName(request.getDatabaseName());
param.setSchemaName(request.getSchemaName());
DataResult<String> tableSchema = tableService.showCreateTable(param);
return tableSchema.getData();
}
}

View File

@ -102,7 +102,9 @@ public class KnowledgeController extends ChatController {
contentVector.add(response.getData().get(0).getEmbedding());
// search embedding
DataResult<KnowledgeResponse> result = gatewayClientService.knowledgeVectorSearch(contentVector);
KnowledgeRequest knowledgeRequest = new KnowledgeRequest();
knowledgeRequest.setContentVector(contentVector);
DataResult<KnowledgeResponse> result = gatewayClientService.knowledgeVectorSearch(knowledgeRequest);
queryRequest.setPromptType(PromptType.TEXT_GENERATION.getCode());
String prompt = queryRequest.getMessage();
if (CollectionUtils.isNotEmpty(result.getData().getKnowledgeList())) {

View File

@ -81,15 +81,14 @@ public interface GatewayClientService {
* @return
*/
@Get("/api/milvus/knowledge/search")
DataResult<KnowledgeResponse> knowledgeVectorSearch(List<List<BigDecimal>> searchVectors);
DataResult<KnowledgeResponse> knowledgeVectorSearch(KnowledgeRequest searchVectors);
/**
* save table schema vector
*
* @param searchVectors
* @param datasourceId
* @param request
* @return
*/
@Get("/api/milvus/schema/search")
DataResult<TableSchemaResponse> schemaVectorSearch(List<List<BigDecimal>> searchVectors, Long datasourceId);
DataResult<TableSchemaResponse> schemaVectorSearch(TableSchemaRequest request);
}

View File

@ -16,9 +16,11 @@ public class TableSchemaRequest {
private Long dataSourceId;
private List<List<BigDecimal>> contentVector;
private String databaseName;
private List<String> sentenceList;
private List<java.util.List<BigDecimal>> schemaVector;
private List<String> schemaList;
private Boolean deleteBeforeInsert;
}