ansj segment

This commit is contained in:
robin
2023-10-26 15:49:04 +08:00
parent 68b48717e4
commit 0c1da52ba7
4 changed files with 24 additions and 37 deletions

View File

@ -491,13 +491,7 @@ public class ChatController {
TableQueryParam queryParam = chatConverter.chat2tableQuery(queryRequest); TableQueryParam queryParam = chatConverter.chat2tableQuery(queryRequest);
properties = buildTableColumn(queryParam, queryRequest.getTableNames()); properties = buildTableColumn(queryParam, queryRequest.getTableNames());
} else { } else {
String apiKey = getApiKey(); properties = mappingDatabaseSchema(queryRequest);
if (StringUtils.isNotBlank(apiKey)) {
boolean res = gatewayClientService.checkInWhite(new WhiteListRequest(apiKey, WhiteListTypeEnum.VECTOR.getCode())).getData();
if (res) {
properties = queryDatabaseSchema(queryRequest);
}
}
} }
String prompt = queryRequest.getMessage(); String prompt = queryRequest.getMessage();
String promptType = StringUtils.isBlank(queryRequest.getPromptType()) ? PromptType.NL_2_SQL.getCode() String promptType = StringUtils.isBlank(queryRequest.getPromptType()) ? PromptType.NL_2_SQL.getCode()
@ -525,7 +519,7 @@ public class ChatController {
* *
* @return * @return
*/ */
private String getApiKey() { public String getApiKey() {
ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class); ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
Config config = configService.find(RestAIClient.AI_SQL_SOURCE).getData(); Config config = configService.find(RestAIClient.AI_SQL_SOURCE).getData();
String aiSqlSource = AiSqlSourceEnum.CHAT2DBAI.getCode(); String aiSqlSource = AiSqlSourceEnum.CHAT2DBAI.getCode();
@ -556,6 +550,19 @@ public class ChatController {
return dataSourceType; return dataSourceType;
} }
public String mappingDatabaseSchema(ChatQueryRequest queryRequest) {
String properties = "";
String apiKey = getApiKey();
if (StringUtils.isNotBlank(apiKey)) {
boolean res = gatewayClientService.checkInWhite(new WhiteListRequest(apiKey, WhiteListTypeEnum.VECTOR.getCode())).getData();
if (res) {
properties = queryDatabaseSchema(queryRequest) + querySchemaByEs(queryRequest);
}
} else {
properties = querySchemaByEs(queryRequest);
}
return properties;
}
/** /**
* query database schema * query database schema
@ -629,7 +636,9 @@ public class ChatController {
schemas.add(data.getTableSchemaContent()); schemas.add(data.getTableSchemaContent());
} }
} }
return JSON.toJSONString(schemas); String res = JSON.toJSONString(schemas);
log.info("search es result:{}", res);
return res;
} catch (Exception exception) { } catch (Exception exception) {
log.error("query es table error, do nothing"); log.error("query es table error, do nothing");
return ""; return "";

View File

@ -242,20 +242,12 @@ public class EmbeddingController extends ChatController {
return; return;
} }
ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class); String apiKey = getApiKey();
Config config = configService.find(RestAIClient.AI_SQL_SOURCE).getData(); if (StringUtils.isBlank(apiKey)) {
String aiSqlSource = AiSqlSourceEnum.CHAT2DBAI.getCode();
// only sync for chat2db ai
if (Objects.isNull(config) || !aiSqlSource.equals(config.getContent())) {
return;
}
Config keyConfig = configService.find(Chat2dbAIClient.CHAT2DB_OPENAI_KEY).getData();
if (Objects.isNull(keyConfig) || StringUtils.isBlank(keyConfig.getContent())) {
return; return;
} }
TableMilvusQueryRequest request = rdbWebConverter.request2request(param); TableMilvusQueryRequest request = rdbWebConverter.request2request(param);
String apiKey = keyConfig.getContent();
request.setApikey(apiKey); request.setApikey(apiKey);
vectorParam.setApiKey(apiKey); vectorParam.setApiKey(apiKey);
@ -318,29 +310,13 @@ public class EmbeddingController extends ChatController {
return; return;
} }
ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class); String apiKey = getApiKey();
Config config = configService.find(RestAIClient.AI_SQL_SOURCE).getData(); if (StringUtils.isBlank(apiKey)) {
String aiSqlSource = AiSqlSourceEnum.CHAT2DBAI.getCode();
// only sync for chat2db ai
if (Objects.isNull(config) || !aiSqlSource.equals(config.getContent())) {
return;
}
Config keyConfig = configService.find(Chat2dbAIClient.CHAT2DB_OPENAI_KEY).getData();
if (Objects.isNull(keyConfig) || StringUtils.isBlank(keyConfig.getContent())) {
return;
}
String apiKey = keyConfig.getContent();
TableVectorParam vectorParam = rdbWebConverter.param2param(param);
vectorParam.setApiKey(apiKey);
DataResult<Boolean> result = tableService.checkTableVector(vectorParam);
if (result.getData()) {
return; return;
} }
esParam.setApiKey(apiKey); esParam.setApiKey(apiKey);
es(esParam); es(esParam);
tableService.saveTableVector(vectorParam);
} }
/** /**

View File

@ -78,6 +78,7 @@ public class RdbDdlController extends EmbeddingController {
try { try {
Chat2DBContext.putContext(connectInfo); Chat2DBContext.putContext(connectInfo);
syncTableVector(request); syncTableVector(request);
syncTableEs(request);
} catch (Exception e) { } catch (Exception e) {
log.error("sync table vector error", e); log.error("sync table vector error", e);
} finally { } finally {

View File

@ -69,6 +69,7 @@ public class TableController extends EmbeddingController {
try { try {
Chat2DBContext.putContext(connectInfo); Chat2DBContext.putContext(connectInfo);
syncTableVector(request); syncTableVector(request);
syncTableEs(request);
} catch (Exception e) { } catch (Exception e) {
log.error("sync table vector error", e); log.error("sync table vector error", e);
} finally { } finally {