embedding query

This commit is contained in:
robin
2023-10-17 14:39:39 +08:00
parent 64c6fda770
commit 2beccd5cc4

View File

@ -13,6 +13,7 @@ import java.util.stream.Collectors;
import ai.chat2db.server.domain.api.enums.AiSqlSourceEnum;
import ai.chat2db.server.domain.api.model.Config;
import ai.chat2db.server.domain.api.model.DataSource;
import ai.chat2db.server.domain.api.param.ShowCreateTableParam;
import ai.chat2db.server.domain.api.param.TableQueryParam;
import ai.chat2db.server.domain.api.service.ConfigService;
import ai.chat2db.server.domain.api.service.DataSourceService;
@ -444,22 +445,39 @@ public class ChatController {
* @param tableNames
* @return
*/
private Map<String, List<TableColumn>> buildTableColumn(TableQueryParam tableQueryParam,
private String buildTableColumn(TableQueryParam tableQueryParam,
List<String> tableNames) {
if (CollectionUtils.isEmpty(tableNames)) {
return Maps.newHashMap();
return "";
}
List<TableColumn> tableColumns = Lists.newArrayList();
List<String> schemaContent = Lists.newArrayList();
try {
tableColumns = tableService.queryColumns(tableQueryParam);
schemaContent = tableNames.stream().map(tableName -> {
tableQueryParam.setTableName(tableName);
return queryTableDdl(tableName, tableQueryParam);
}).collect(Collectors.toList());
} catch (Exception exception) {
log.error("query table error, do nothing");
}
if (CollectionUtils.isEmpty(tableColumns)) {
return Maps.newHashMap();
}
return tableColumns.stream().filter(tableColumn -> tableNames.contains(tableColumn.getTableName())).collect(
Collectors.groupingBy(TableColumn::getTableName, Collectors.toList()));
return JSON.toJSONString(schemaContent);
}
/**
* query table schema
*
* @param tableName
* @param request
* @return
*/
private String queryTableDdl(String tableName, TableQueryParam request) {
ShowCreateTableParam param = new ShowCreateTableParam();
param.setTableName(tableName);
param.setDataSourceId(request.getDataSourceId());
param.setDatabaseName(request.getDatabaseName());
param.setSchemaName(request.getSchemaName());
DataResult<String> tableSchema = tableService.showCreateTable(param);
return tableSchema.getData();
}
/**
@ -478,12 +496,7 @@ public class ChatController {
String properties = "";
if (CollectionUtils.isNotEmpty(queryRequest.getTableNames())) {
TableQueryParam queryParam = chatConverter.chat2tableQuery(queryRequest);
Map<String, List<TableColumn>> tableColumns = buildTableColumn(queryParam, queryRequest.getTableNames());
List<String> tableSchemas = tableColumns.entrySet().stream().map(
entry -> String.format("%s(%s)", entry.getKey(),
entry.getValue().stream().map(TableColumn::getName).collect(
Collectors.joining(", ")))).collect(Collectors.toList());
properties = String.join("\n#", tableSchemas);
properties = buildTableColumn(queryParam, queryRequest.getTableNames());
} else {
properties = queryDatabaseSchema(queryRequest);
}