ansj segment

This commit is contained in:
robin
2023-10-26 15:49:04 +08:00
parent 68b48717e4
commit 0c1da52ba7
4 changed files with 24 additions and 37 deletions

View File

@ -491,13 +491,7 @@ public class ChatController {
TableQueryParam queryParam = chatConverter.chat2tableQuery(queryRequest);
properties = buildTableColumn(queryParam, queryRequest.getTableNames());
} else {
String apiKey = getApiKey();
if (StringUtils.isNotBlank(apiKey)) {
boolean res = gatewayClientService.checkInWhite(new WhiteListRequest(apiKey, WhiteListTypeEnum.VECTOR.getCode())).getData();
if (res) {
properties = queryDatabaseSchema(queryRequest);
}
}
properties = mappingDatabaseSchema(queryRequest);
}
String prompt = queryRequest.getMessage();
String promptType = StringUtils.isBlank(queryRequest.getPromptType()) ? PromptType.NL_2_SQL.getCode()
@ -525,7 +519,7 @@ public class ChatController {
*
* @return
*/
private String getApiKey() {
public String getApiKey() {
ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
Config config = configService.find(RestAIClient.AI_SQL_SOURCE).getData();
String aiSqlSource = AiSqlSourceEnum.CHAT2DBAI.getCode();
@ -556,6 +550,19 @@ public class ChatController {
return dataSourceType;
}
public String mappingDatabaseSchema(ChatQueryRequest queryRequest) {
String properties = "";
String apiKey = getApiKey();
if (StringUtils.isNotBlank(apiKey)) {
boolean res = gatewayClientService.checkInWhite(new WhiteListRequest(apiKey, WhiteListTypeEnum.VECTOR.getCode())).getData();
if (res) {
properties = queryDatabaseSchema(queryRequest) + querySchemaByEs(queryRequest);
}
} else {
properties = querySchemaByEs(queryRequest);
}
return properties;
}
/**
* query database schema
@ -629,7 +636,9 @@ public class ChatController {
schemas.add(data.getTableSchemaContent());
}
}
return JSON.toJSONString(schemas);
String res = JSON.toJSONString(schemas);
log.info("search es result:{}", res);
return res;
} catch (Exception exception) {
log.error("query es table error, do nothing");
return "";

View File

@ -242,20 +242,12 @@ public class EmbeddingController extends ChatController {
return;
}
ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
Config config = configService.find(RestAIClient.AI_SQL_SOURCE).getData();
String aiSqlSource = AiSqlSourceEnum.CHAT2DBAI.getCode();
// only sync for chat2db ai
if (Objects.isNull(config) || !aiSqlSource.equals(config.getContent())) {
return;
}
Config keyConfig = configService.find(Chat2dbAIClient.CHAT2DB_OPENAI_KEY).getData();
if (Objects.isNull(keyConfig) || StringUtils.isBlank(keyConfig.getContent())) {
String apiKey = getApiKey();
if (StringUtils.isBlank(apiKey)) {
return;
}
TableMilvusQueryRequest request = rdbWebConverter.request2request(param);
String apiKey = keyConfig.getContent();
request.setApikey(apiKey);
vectorParam.setApiKey(apiKey);
@ -318,29 +310,13 @@ public class EmbeddingController extends ChatController {
return;
}
ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
Config config = configService.find(RestAIClient.AI_SQL_SOURCE).getData();
String aiSqlSource = AiSqlSourceEnum.CHAT2DBAI.getCode();
// only sync for chat2db ai
if (Objects.isNull(config) || !aiSqlSource.equals(config.getContent())) {
return;
}
Config keyConfig = configService.find(Chat2dbAIClient.CHAT2DB_OPENAI_KEY).getData();
if (Objects.isNull(keyConfig) || StringUtils.isBlank(keyConfig.getContent())) {
return;
}
String apiKey = keyConfig.getContent();
TableVectorParam vectorParam = rdbWebConverter.param2param(param);
vectorParam.setApiKey(apiKey);
DataResult<Boolean> result = tableService.checkTableVector(vectorParam);
if (result.getData()) {
String apiKey = getApiKey();
if (StringUtils.isBlank(apiKey)) {
return;
}
esParam.setApiKey(apiKey);
es(esParam);
tableService.saveTableVector(vectorParam);
}
/**

View File

@ -78,6 +78,7 @@ public class RdbDdlController extends EmbeddingController {
try {
Chat2DBContext.putContext(connectInfo);
syncTableVector(request);
syncTableEs(request);
} catch (Exception e) {
log.error("sync table vector error", e);
} finally {

View File

@ -69,6 +69,7 @@ public class TableController extends EmbeddingController {
try {
Chat2DBContext.putContext(connectInfo);
syncTableVector(request);
syncTableEs(request);
} catch (Exception e) {
log.error("sync table vector error", e);
} finally {