vector update

This commit is contained in:
robin
2023-10-22 20:43:13 +08:00
parent cecf152c04
commit 07638b6211
10 changed files with 281 additions and 4 deletions

View File

@ -45,8 +45,11 @@ import ai.chat2db.server.web.api.controller.ai.request.ChatQueryRequest;
import ai.chat2db.server.web.api.controller.ai.request.ChatRequest;
import ai.chat2db.server.web.api.controller.ai.rest.client.RestAIClient;
import ai.chat2db.server.web.api.http.GatewayClientService;
import ai.chat2db.server.web.api.http.model.EsTableSchema;
import ai.chat2db.server.web.api.http.model.TableSchema;
import ai.chat2db.server.web.api.http.request.EsTableSchemaRequest;
import ai.chat2db.server.web.api.http.request.TableSchemaRequest;
import ai.chat2db.server.web.api.http.response.EsTableSchemaResponse;
import ai.chat2db.server.web.api.http.response.TableSchemaResponse;
import ai.chat2db.server.web.api.util.ApplicationContextUtil;
import ai.chat2db.server.web.api.controller.ai.openai.client.OpenAIClient;
@ -498,7 +501,7 @@ public class ChatController {
TableQueryParam queryParam = chatConverter.chat2tableQuery(queryRequest);
properties = buildTableColumn(queryParam, queryRequest.getTableNames());
} else {
properties = queryDatabaseSchema(queryRequest);
properties = querySchemaByEs(queryRequest);
}
String prompt = queryRequest.getMessage();
String promptType = StringUtils.isBlank(queryRequest.getPromptType()) ? PromptType.NL_2_SQL.getCode()
@ -578,6 +581,41 @@ public class ChatController {
}
}
/**
* query database schema
*
* @param queryRequest
* @return
* @throws IOException
*/
public String querySchemaByEs(ChatQueryRequest queryRequest) {
// search embedding
EsTableSchemaRequest tableSchemaRequest = new EsTableSchemaRequest();
tableSchemaRequest.setSearchKey(queryRequest.getMessage());
tableSchemaRequest.setDataSourceId(queryRequest.getDataSourceId());
tableSchemaRequest.setDatabaseName(queryRequest.getDatabaseName());
tableSchemaRequest.setSchemaName(queryRequest.getSchemaName());
ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
Config keyConfig = configService.find(Chat2dbAIClient.CHAT2DB_OPENAI_KEY).getData();
if (Objects.isNull(keyConfig) || StringUtils.isBlank(keyConfig.getContent())) {
return "";
}
tableSchemaRequest.setApiKey(keyConfig.getContent());
try {
DataResult<EsTableSchemaResponse> result = gatewayClientService.schemaEsSearch(tableSchemaRequest);
List<String> schemas = Lists.newArrayList();
if (Objects.nonNull(result.getData()) && CollectionUtils.isNotEmpty(result.getData().getTableSchemas())) {
for(EsTableSchema data: result.getData().getTableSchemas()){
schemas.add(data.getTableSchemaContent());
}
}
return JSON.toJSONString(schemas);
} catch (Exception exception) {
log.error("query es table error, do nothing");
return "";
}
}
/**
* distribute embedding with different AI
*

View File

@ -21,6 +21,7 @@ import ai.chat2db.server.web.api.controller.rdb.converter.RdbWebConverter;
import ai.chat2db.server.web.api.controller.rdb.request.TableBriefQueryRequest;
import ai.chat2db.server.web.api.controller.rdb.request.TableMilvusQueryRequest;
import ai.chat2db.server.web.api.http.GatewayClientService;
import ai.chat2db.server.web.api.http.request.EsTableSchemaRequest;
import ai.chat2db.server.web.api.http.request.TableSchemaRequest;
import ai.chat2db.server.web.api.http.request.WhiteListRequest;
import ai.chat2db.server.web.api.util.ApplicationContextUtil;
@ -147,6 +148,81 @@ public class EmbeddingController extends ChatController {
return ActionResult.isSuccess();
}
/**
* save datasource schema
*
* @param request
* @return
* @throws IOException
*/
@PostMapping("/datasource/es")
@CrossOrigin
public ActionResult es(@Valid EsTableSchemaRequest request)
throws Exception {
// query tables
TablePageQueryParam queryParam = rdbWebConverter.schemaReq2page(request);
TableSelector tableSelector = new TableSelector();
tableSelector.setColumnList(false);
tableSelector.setIndexList(false);
queryParam.setPageNo(1);
queryParam.setPageSize(1000);
PageResult<Table> tableDTOPageResult = tableService.pageQuery(queryParam, tableSelector);
List<Table> tables = tableDTOPageResult.getData();
if (CollectionUtils.isEmpty(tables)) {
return ActionResult.isSuccess();
}
String tableName = tables.get(0).getName();
String tableSchema = queryTableDdlByEs(tableName, request);
request.setTableName(tableName);
request.setTableSchemaContent(tableSchema);
if (StringUtils.isBlank(tableSchema)) {
throw new ParamBusinessException("tableSchema is empty");
}
// save first table embedding
request.setDeleteBeforeInsert(true);
saveTableEs(request);
// save other table embedding
request.setDeleteBeforeInsert(false);
for (int i = 1; i < tables.size(); i++) {
tableName = tables.get(i).getName();
tableSchema = queryTableDdlByEs(tableName, request);
if (StringUtils.isBlank(tableSchema)) {
continue;
}
request.setTableName(tableName);
request.setTableSchemaContent(tableSchema);
saveTableEs(request);
}
// query all the tables
Long totalTableCount = tableDTOPageResult.getTotal();
Integer pageSize = queryParam.getPageSize();
if (pageSize < totalTableCount) {
for (int i = 2; i < totalTableCount/pageSize + 1; i++) {
queryParam.setPageNo(i);
tableDTOPageResult = tableService.pageQuery(queryParam, tableSelector);
tables = tableDTOPageResult.getData();
for (Table table : tables) {
tableName = table.getName();
tableSchema = queryTableDdlByEs(tableName, request);
if (StringUtils.isBlank(tableSchema)) {
continue;
}
request.setTableName(tableName);
request.setTableSchemaContent(tableSchema);
saveTableEs(request);
}
}
}
return ActionResult.isSuccess();
}
/**
* sync table vector
*
@ -223,6 +299,56 @@ public class EmbeddingController extends ChatController {
gatewayClientService.schemaVectorSave(tableSchemaRequest);
}
/**
* sync table vector
*
* @param param
*/
public void syncTableEs(TableBriefQueryRequest param) throws Exception {
EsTableSchemaRequest esParam = rdbWebConverter.req2req(param);
if (Objects.isNull(esParam.getDataSourceId())) {
return;
}
if (StringUtils.isBlank(esParam.getDatabaseName()) && StringUtils.isBlank(esParam.getSchemaName())) {
return;
}
ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
Config config = configService.find(RestAIClient.AI_SQL_SOURCE).getData();
String aiSqlSource = AiSqlSourceEnum.CHAT2DBAI.getCode();
// only sync for chat2db ai
if (Objects.isNull(config) || !aiSqlSource.equals(config.getContent())) {
return;
}
Config keyConfig = configService.find(Chat2dbAIClient.CHAT2DB_OPENAI_KEY).getData();
if (Objects.isNull(keyConfig) || StringUtils.isBlank(keyConfig.getContent())) {
return;
}
String apiKey = keyConfig.getContent();
TableVectorParam vectorParam = rdbWebConverter.param2param(param);
vectorParam.setApiKey(apiKey);
DataResult<Boolean> result = tableService.checkTableVector(vectorParam);
if (result.getData()) {
return;
}
esParam.setApiKey(apiKey);
es(esParam);
tableService.saveTableVector(vectorParam);
}
/**
* save table schema
*
* @param tableSchemaRequest
* @throws Exception
*/
private void saveTableEs(EsTableSchemaRequest tableSchemaRequest) throws Exception{
// save table es
gatewayClientService.schemaEsSave(tableSchemaRequest);
}
/**
* query table schema
*
@ -240,4 +366,21 @@ public class EmbeddingController extends ChatController {
return tableSchema.getData();
}
/**
* query table schema
*
* @param tableName
* @param request
* @return
*/
private String queryTableDdlByEs(String tableName, EsTableSchemaRequest request) {
ShowCreateTableParam param = new ShowCreateTableParam();
param.setTableName(tableName);
param.setDataSourceId(request.getDataSourceId());
param.setDatabaseName(request.getDatabaseName());
param.setSchemaName(request.getSchemaName());
DataResult<String> tableSchema = tableService.showCreateTable(param);
return tableSchema.getData();
}
}

View File

@ -64,7 +64,7 @@ public class TextGenerationController extends ChatController {
// query database schema info
String databaseType = queryDatabaseType(queryRequest);
String schemas = queryDatabaseSchema(queryRequest);
String schemas = querySchemaByEs(queryRequest);
if (StringUtils.isNotBlank(schemas)) {
databaseType = String.format(", given a %s database schema", databaseType);
schemas = String.format("This query will run on a database whose schema is represented in this string:\n$s", schemas);

View File

@ -77,7 +77,7 @@ public class RdbDdlController extends EmbeddingController {
singleThreadExecutor.submit(() -> {
try {
Chat2DBContext.putContext(connectInfo);
syncTableVector(request);
syncTableEs(request);
} catch (Exception e) {
log.error("sync table vector error", e);
} finally {

View File

@ -68,7 +68,7 @@ public class TableController extends EmbeddingController {
singleThreadExecutor.submit(() -> {
try {
Chat2DBContext.putContext(connectInfo);
syncTableVector(request);
syncTableEs(request);
} catch (Exception e) {
log.error("sync table vector error", e);
} finally {

View File

@ -12,6 +12,7 @@ import ai.chat2db.server.web.api.controller.rdb.vo.MetaSchemaVO;
import ai.chat2db.server.web.api.controller.rdb.vo.SchemaVO;
import ai.chat2db.server.web.api.controller.rdb.vo.SqlVO;
import ai.chat2db.server.web.api.controller.rdb.vo.TableVO;
import ai.chat2db.server.web.api.http.request.EsTableSchemaRequest;
import ai.chat2db.spi.model.Database;
import ai.chat2db.spi.model.ExecuteResult;
import ai.chat2db.spi.model.MetaSchema;
@ -232,4 +233,8 @@ public abstract class RdbWebConverter {
@Mapping(source = "schemaName", target = "schema"),
})
public abstract TableVectorParam param2param(TableBriefQueryRequest request);
public abstract EsTableSchemaRequest req2req(TableBriefQueryRequest request);
public abstract TablePageQueryParam schemaReq2page(EsTableSchemaRequest request);
}

View File

@ -2,6 +2,7 @@ package ai.chat2db.server.web.api.http;
import ai.chat2db.server.tools.base.wrapper.result.ActionResult;
import ai.chat2db.server.tools.base.wrapper.result.DataResult;
import ai.chat2db.server.web.api.http.request.EsTableSchemaRequest;
import ai.chat2db.server.web.api.http.request.KnowledgeRequest;
import ai.chat2db.server.web.api.http.request.TableSchemaRequest;
import ai.chat2db.server.web.api.http.request.WhiteListRequest;
@ -73,6 +74,15 @@ public interface GatewayClientService {
@Post(url = "/api/client/milvus/schema/save", contentType = "application/json")
ActionResult schemaVectorSave(@Body TableSchemaRequest request);
/**
* save table schema vector
*
* @param request
* @return
*/
@Post(url = "/api/client/es/schema/save", contentType = "application/json")
ActionResult schemaEsSave(@Body EsTableSchemaRequest request);
/**
* save knowledge vector
*
@ -91,6 +101,15 @@ public interface GatewayClientService {
@Post(url = "/api/client/milvus/schema/search", contentType = "application/json")
DataResult<TableSchemaResponse> schemaVectorSearch(@Body TableSchemaRequest request);
/**
* save table schema vector
*
* @param request
* @return
*/
@Post(url = "/api/client/es/schema/search", contentType = "application/json")
DataResult<EsTableSchemaResponse> schemaEsSearch(@Body EsTableSchemaRequest request);
/**
* check in white list
*

View File

@ -0,0 +1,25 @@
package ai.chat2db.server.web.api.http.model;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.experimental.SuperBuilder;
@Data
@SuperBuilder
@NoArgsConstructor
@AllArgsConstructor
public class EsTableSchema {
private String dataSourceId;
private String databaseName;
private String apiKey;
private String schemaName;
private String tableName;
private String tableSchemaContent;
}

View File

@ -0,0 +1,29 @@
package ai.chat2db.server.web.api.http.request;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.experimental.SuperBuilder;
@Data
@SuperBuilder
@NoArgsConstructor
@AllArgsConstructor
public class EsTableSchemaRequest {
private Long dataSourceId;
private String databaseName;
private String apiKey;
private String schemaName;
private String tableName;
private String tableSchemaContent;
private String searchKey;
private Boolean deleteBeforeInsert;
}

View File

@ -0,0 +1,18 @@
package ai.chat2db.server.web.api.http.response;
import ai.chat2db.server.web.api.http.model.EsTableSchema;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.experimental.SuperBuilder;
import java.util.List;
@Data
@SuperBuilder
@NoArgsConstructor
@AllArgsConstructor
public class EsTableSchemaResponse {
private List<EsTableSchema> tableSchemas;
}