From 2beccd5cc4aa1599396c32848a1e1140baf15def Mon Sep 17 00:00:00 2001 From: robin <850379744@qq.com> Date: Tue, 17 Oct 2023 14:39:39 +0800 Subject: [PATCH] embedding query --- .../web/api/controller/ai/ChatController.java | 43 ++++++++++++------- 1 file changed, 28 insertions(+), 15 deletions(-) diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java index 32fc93b9..ae3fecc3 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java @@ -13,6 +13,7 @@ import java.util.stream.Collectors; import ai.chat2db.server.domain.api.enums.AiSqlSourceEnum; import ai.chat2db.server.domain.api.model.Config; import ai.chat2db.server.domain.api.model.DataSource; +import ai.chat2db.server.domain.api.param.ShowCreateTableParam; import ai.chat2db.server.domain.api.param.TableQueryParam; import ai.chat2db.server.domain.api.service.ConfigService; import ai.chat2db.server.domain.api.service.DataSourceService; @@ -444,22 +445,39 @@ public class ChatController { * @param tableNames * @return */ - private Map> buildTableColumn(TableQueryParam tableQueryParam, + private String buildTableColumn(TableQueryParam tableQueryParam, List tableNames) { if (CollectionUtils.isEmpty(tableNames)) { - return Maps.newHashMap(); + return ""; } - List tableColumns = Lists.newArrayList(); + List schemaContent = Lists.newArrayList(); try { - tableColumns = tableService.queryColumns(tableQueryParam); + schemaContent = tableNames.stream().map(tableName -> { + tableQueryParam.setTableName(tableName); + return queryTableDdl(tableName, tableQueryParam); + }).collect(Collectors.toList()); } catch (Exception exception) { log.error("query table error, do nothing"); } - if (CollectionUtils.isEmpty(tableColumns)) { - return Maps.newHashMap(); - } - return tableColumns.stream().filter(tableColumn -> tableNames.contains(tableColumn.getTableName())).collect( - Collectors.groupingBy(TableColumn::getTableName, Collectors.toList())); + + return JSON.toJSONString(schemaContent); + } + + /** + * query table schema + * + * @param tableName + * @param request + * @return + */ + private String queryTableDdl(String tableName, TableQueryParam request) { + ShowCreateTableParam param = new ShowCreateTableParam(); + param.setTableName(tableName); + param.setDataSourceId(request.getDataSourceId()); + param.setDatabaseName(request.getDatabaseName()); + param.setSchemaName(request.getSchemaName()); + DataResult tableSchema = tableService.showCreateTable(param); + return tableSchema.getData(); } /** @@ -478,12 +496,7 @@ public class ChatController { String properties = ""; if (CollectionUtils.isNotEmpty(queryRequest.getTableNames())) { TableQueryParam queryParam = chatConverter.chat2tableQuery(queryRequest); - Map> tableColumns = buildTableColumn(queryParam, queryRequest.getTableNames()); - List tableSchemas = tableColumns.entrySet().stream().map( - entry -> String.format("%s(%s)", entry.getKey(), - entry.getValue().stream().map(TableColumn::getName).collect( - Collectors.joining(", ")))).collect(Collectors.toList()); - properties = String.join("\n#", tableSchemas); + properties = buildTableColumn(queryParam, queryRequest.getTableNames()); } else { properties = queryDatabaseSchema(queryRequest); }