From 27b81756c2bb7ef49e883480260ef948c42468c2 Mon Sep 17 00:00:00 2001 From: robin <850379744@qq.com> Date: Sun, 15 Oct 2023 19:37:21 +0800 Subject: [PATCH] embedding --- .../domain/api/param/TableVectorParam.java | 41 +++ .../domain/api/service/TableService.java | 16 + .../domain/core/converter/TableConverter.java | 17 ++ .../domain/core/impl/TableServiceImpl.java | 40 ++- .../db/migration/V2_1_3__TableVector.sql | 12 + .../tools/base/enums/WhiteListTypeEnum.java | 30 ++ .../chat2db-server-web-api/pom.xml | 4 + .../web/api/controller/ai/ChatController.java | 26 +- .../controller/ai/EmbeddingController.java | 74 ++++- .../chat2db/client/Chat2DBAIStreamClient.java | 276 ++++++++++++++++++ .../ai/chat2db/client/Chat2dbAIClient.java | 17 +- .../ai/fastchat/client/FastChatAIClient.java | 2 +- .../client/FastChatAIStreamClient.java | 9 + .../api/controller/rdb/RdbDdlController.java | 49 ++-- .../rdb/converter/RdbWebConverter.java | 8 + .../web/api/http/GatewayClientService.java | 13 +- .../api/http/request/WhiteListRequest.java | 25 ++ 17 files changed, 603 insertions(+), 56 deletions(-) create mode 100644 chat2db-server/chat2db-server-domain/chat2db-server-domain-api/src/main/java/ai/chat2db/server/domain/api/param/TableVectorParam.java create mode 100644 chat2db-server/chat2db-server-domain/chat2db-server-domain-core/src/main/java/ai/chat2db/server/domain/core/converter/TableConverter.java create mode 100644 chat2db-server/chat2db-server-start/src/main/resources/db/migration/V2_1_3__TableVector.sql create mode 100644 chat2db-server/chat2db-server-tools/chat2db-server-tools-base/src/main/java/ai/chat2db/server/tools/base/enums/WhiteListTypeEnum.java create mode 100644 chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/chat2db/client/Chat2DBAIStreamClient.java create mode 100644 chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/http/request/WhiteListRequest.java diff --git a/chat2db-server/chat2db-server-domain/chat2db-server-domain-api/src/main/java/ai/chat2db/server/domain/api/param/TableVectorParam.java b/chat2db-server/chat2db-server-domain/chat2db-server-domain-api/src/main/java/ai/chat2db/server/domain/api/param/TableVectorParam.java new file mode 100644 index 00000000..1aaeae81 --- /dev/null +++ b/chat2db-server/chat2db-server-domain/chat2db-server-domain-api/src/main/java/ai/chat2db/server/domain/api/param/TableVectorParam.java @@ -0,0 +1,41 @@ +package ai.chat2db.server.domain.api.param; + + +import jakarta.validation.constraints.NotNull; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +@NoArgsConstructor +@AllArgsConstructor +public class TableVectorParam { + + /** + * api key + */ + @NotNull + private String apiKey; + + /** + * 数据源连接ID + */ + private Long dataSourceId; + + /** + * 数据库名称 + */ + private String database; + + /** + * schema名称 + */ + private String schema; + + /** + * 向量保存状态 + */ + private String status; +} diff --git a/chat2db-server/chat2db-server-domain/chat2db-server-domain-api/src/main/java/ai/chat2db/server/domain/api/service/TableService.java b/chat2db-server/chat2db-server-domain/chat2db-server-domain-api/src/main/java/ai/chat2db/server/domain/api/service/TableService.java index fb59edce..93f245af 100644 --- a/chat2db-server/chat2db-server-domain/chat2db-server-domain-api/src/main/java/ai/chat2db/server/domain/api/service/TableService.java +++ b/chat2db-server/chat2db-server-domain/chat2db-server-domain-api/src/main/java/ai/chat2db/server/domain/api/service/TableService.java @@ -113,4 +113,20 @@ public interface TableService { * @return */ TableMeta queryTableMeta(TypeQueryParam param); + + /** + * save table vector + * + * @param param + * @return + */ + ActionResult saveTableVector(TableVectorParam param); + + /** + * check if table vector saved status + * + * @param param + * @return + */ + DataResult checkTableVector(TableVectorParam param); } diff --git a/chat2db-server/chat2db-server-domain/chat2db-server-domain-core/src/main/java/ai/chat2db/server/domain/core/converter/TableConverter.java b/chat2db-server/chat2db-server-domain/chat2db-server-domain-core/src/main/java/ai/chat2db/server/domain/core/converter/TableConverter.java new file mode 100644 index 00000000..cc2b629c --- /dev/null +++ b/chat2db-server/chat2db-server-domain/chat2db-server-domain-core/src/main/java/ai/chat2db/server/domain/core/converter/TableConverter.java @@ -0,0 +1,17 @@ +package ai.chat2db.server.domain.core.converter; + +import ai.chat2db.server.domain.api.param.TableVectorParam; +import ai.chat2db.server.domain.repository.entity.TableVectorMappingDO; +import org.mapstruct.Mapper; + +@Mapper(componentModel = "spring") +public abstract class TableConverter { + + /** + * TableVectorParam to TableVectorMappingDO + * + * @param param + * @return + */ + public abstract TableVectorMappingDO toTableVectorMappingDO(TableVectorParam param); +} diff --git a/chat2db-server/chat2db-server-domain/chat2db-server-domain-core/src/main/java/ai/chat2db/server/domain/core/impl/TableServiceImpl.java b/chat2db-server/chat2db-server-domain/chat2db-server-domain-core/src/main/java/ai/chat2db/server/domain/core/impl/TableServiceImpl.java index 402dde19..4a0a62fc 100644 --- a/chat2db-server/chat2db-server-domain/chat2db-server-domain-core/src/main/java/ai/chat2db/server/domain/core/impl/TableServiceImpl.java +++ b/chat2db-server/chat2db-server-domain/chat2db-server-domain-core/src/main/java/ai/chat2db/server/domain/core/impl/TableServiceImpl.java @@ -6,20 +6,21 @@ import java.sql.SQLException; import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.function.Function; import java.util.stream.Collectors; +import ai.chat2db.server.domain.api.enums.TableVectorEnum; import ai.chat2db.server.domain.api.param.*; import ai.chat2db.server.domain.api.service.PinService; import ai.chat2db.server.domain.api.service.TableService; import ai.chat2db.server.domain.core.cache.CacheManage; import ai.chat2db.server.domain.core.converter.PinTableConverter; -import ai.chat2db.server.domain.repository.entity.TableCacheDO; -import ai.chat2db.server.domain.repository.entity.TableCacheVersionDO; -import ai.chat2db.server.domain.repository.entity.TeamDO; -import ai.chat2db.server.domain.repository.entity.TeamUserDO; +import ai.chat2db.server.domain.core.converter.TableConverter; +import ai.chat2db.server.domain.repository.entity.*; import ai.chat2db.server.domain.repository.mapper.TableCacheMapper; import ai.chat2db.server.domain.repository.mapper.TableCacheVersionMapper; +import ai.chat2db.server.domain.repository.mapper.TableVectorMappingMapper; import ai.chat2db.server.tools.base.wrapper.result.ActionResult; import ai.chat2db.server.tools.base.wrapper.result.DataResult; import ai.chat2db.server.tools.base.wrapper.result.ListResult; @@ -63,9 +64,15 @@ public class TableServiceImpl implements TableService { @Autowired private TableCacheMapper tableCacheMapper; + @Autowired + private TableConverter tableConverter; + @Autowired private TableCacheVersionMapper tableCacheVersionMapper; + @Autowired + private TableVectorMappingMapper mappingMapper; + @Override public DataResult showCreateTable(ShowCreateTableParam param) { MetaData metaSchema = Chat2DBContext.getMetaData(); @@ -382,4 +389,29 @@ public class TableServiceImpl implements TableService { MetaData metaSchema = Chat2DBContext.getMetaData(); return metaSchema.getTableMeta(null, null, null); } + + @Override + public ActionResult saveTableVector(TableVectorParam param) { + if (checkTableVector(param).getData()) { + return ActionResult.isSuccess(); + } + TableVectorMappingDO mappingDO = tableConverter.toTableVectorMappingDO(param); + mappingDO.setStatus(TableVectorEnum.SAVED.getCode()); + mappingMapper.insert(mappingDO); + return ActionResult.isSuccess(); + } + + @Override + public DataResult checkTableVector(TableVectorParam param) { + LambdaQueryWrapper queryWrapper = new LambdaQueryWrapper(); + queryWrapper.eq(TableVectorMappingDO::getApiKey, param.getApiKey()); + queryWrapper.eq(TableVectorMappingDO::getDataSourceId, param.getDataSourceId()); + queryWrapper.eq(TableVectorMappingDO::getDatabase, param.getDatabase()); + queryWrapper.eq(TableVectorMappingDO::getSchema, param.getSchema()); + TableVectorMappingDO mappingDO = mappingMapper.selectOne(queryWrapper); + if (Objects.nonNull(mappingDO) && TableVectorEnum.SAVED.getCode().equals(mappingDO.getStatus())) { + return DataResult.of(true); + } + return DataResult.of(false); + } } diff --git a/chat2db-server/chat2db-server-start/src/main/resources/db/migration/V2_1_3__TableVector.sql b/chat2db-server/chat2db-server-start/src/main/resources/db/migration/V2_1_3__TableVector.sql new file mode 100644 index 00000000..ec61f576 --- /dev/null +++ b/chat2db-server/chat2db-server-start/src/main/resources/db/migration/V2_1_3__TableVector.sql @@ -0,0 +1,12 @@ +CREATE TABLE IF NOT EXISTS `table_vector_mapping` ( + `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT COMMENT '主键', + `api_key` varchar(128) DEFAULT NULL COMMENT 'api key', + `data_source_id` bigint(20) unsigned DEFAULT NULL COMMENT '数据源连接ID', + `database` text DEFAULT NULL COMMENT '数据库名称', + `schema` text DEFAULT NULL COMMENT 'schema名称', + `status` varchar(4) DEFAULT NULL COMMENT '向量保存状态', + PRIMARY KEY (`id`) + ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='milvus映射表保存记录' +; + +create INDEX idx_api_key on table_vector_mapping(api_key) ; diff --git a/chat2db-server/chat2db-server-tools/chat2db-server-tools-base/src/main/java/ai/chat2db/server/tools/base/enums/WhiteListTypeEnum.java b/chat2db-server/chat2db-server-tools/chat2db-server-tools-base/src/main/java/ai/chat2db/server/tools/base/enums/WhiteListTypeEnum.java new file mode 100644 index 00000000..4290040a --- /dev/null +++ b/chat2db-server/chat2db-server-tools/chat2db-server-tools-base/src/main/java/ai/chat2db/server/tools/base/enums/WhiteListTypeEnum.java @@ -0,0 +1,30 @@ +package ai.chat2db.server.tools.base.enums; + +import lombok.Getter; + +/** + * @author moji + * @version WhiteListTypeEnum.java, v 0.1 2022年09月25日 16:57 moji Exp $ + * @date 2022/09/25 + */ +@Getter +public enum WhiteListTypeEnum implements BaseEnum { + + /** + * 向量接口 + */ + VECTOR("VECTOR"), + + ; + + final String description; + + WhiteListTypeEnum(String description) { + this.description = description; + } + + @Override + public String getCode() { + return this.name(); + } +} diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/pom.xml b/chat2db-server/chat2db-server-web/chat2db-server-web-api/pom.xml index 0ec584a7..9d41c0e0 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/pom.xml +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/pom.xml @@ -78,6 +78,10 @@ org.apache.pdfbox pdfbox + + ai.chat2db + chat2db-spi + diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java index 701fb29a..e2590a32 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java @@ -106,9 +106,6 @@ public class ChatController { @Resource private GatewayClientService gatewayClientService; - @Getter - private OpenAiApi openAiApi; - /** * chat的超时时间 */ @@ -316,7 +313,7 @@ public class ChatController { buildSseEmitter(sseEmitter, uid); OpenAIEventSourceListener openAIEventSourceListener = new OpenAIEventSourceListener(sseEmitter); - Chat2dbAIClient.getInstance().streamChatCompletion(messages, openAIEventSourceListener); + Chat2dbAIClient.getInstance().streamCompletions(messages, openAIEventSourceListener); LocalCache.CACHE.put(uid, JSONUtil.toJsonStr(messages), LocalCache.TIMEOUT); return sseEmitter; } @@ -564,21 +561,18 @@ public class ChatController { public FastChatEmbeddingResponse distributeAIEmbedding(String input) { ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class); Config config = configService.find(RestAIClient.AI_SQL_SOURCE).getData(); - String aiSqlSource = AiSqlSourceEnum.CHAT2DBAI.getCode(); - if (Objects.nonNull(config)) { - aiSqlSource = config.getContent(); + String aiSqlSource = config.getContent(); + if (Objects.isNull(aiSqlSource)) { + return null; } AiSqlSourceEnum aiSqlSourceEnum = AiSqlSourceEnum.getByName(aiSqlSource); switch (Objects.requireNonNull(aiSqlSourceEnum)) { - case AZUREAI : - case OPENAI: case CHAT2DBAI: - return embeddingWithOpenAi(input); - case RESTAI : + return embeddingWithChat2dbAi(input); case FASTCHATAI: return embeddingWithFastChatAi(input); } - return embeddingWithFastChatAi(input); + return null; } /** @@ -599,11 +593,9 @@ public class ChatController { * @param input * @return */ - private FastChatEmbeddingResponse embeddingWithOpenAi(String input) { - Embedding embedding = Embedding.builder().input(input).build(); - Single embeddings = this.openAiApi.embeddings(embedding); - EmbeddingResponse embeddingResponse = embeddings.blockingGet(); - return chatConverter.response2response(embeddingResponse); + private FastChatEmbeddingResponse embeddingWithChat2dbAi(String input) { + FastChatEmbeddingResponse embeddings = Chat2dbAIClient.getInstance().embeddings(input); + return embeddings; } } diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/EmbeddingController.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/EmbeddingController.java index 83303d12..7ea9413e 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/EmbeddingController.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/EmbeddingController.java @@ -1,20 +1,29 @@ package ai.chat2db.server.web.api.controller.ai; +import ai.chat2db.server.domain.api.enums.AiSqlSourceEnum; +import ai.chat2db.server.domain.api.model.Config; import ai.chat2db.server.domain.api.param.ShowCreateTableParam; import ai.chat2db.server.domain.api.param.TablePageQueryParam; import ai.chat2db.server.domain.api.param.TableSelector; +import ai.chat2db.server.domain.api.param.TableVectorParam; +import ai.chat2db.server.domain.api.service.ConfigService; import ai.chat2db.server.domain.api.service.TableService; +import ai.chat2db.server.tools.base.enums.WhiteListTypeEnum; import ai.chat2db.server.tools.base.wrapper.result.ActionResult; import ai.chat2db.server.tools.base.wrapper.result.DataResult; import ai.chat2db.server.tools.base.wrapper.result.PageResult; import ai.chat2db.server.tools.common.exception.ParamBusinessException; import ai.chat2db.server.web.api.aspect.ConnectionInfoAspect; +import ai.chat2db.server.web.api.controller.ai.chat2db.client.Chat2dbAIClient; import ai.chat2db.server.web.api.controller.ai.fastchat.embeddings.FastChatEmbeddingResponse; +import ai.chat2db.server.web.api.controller.ai.rest.client.RestAIClient; import ai.chat2db.server.web.api.controller.rdb.converter.RdbWebConverter; import ai.chat2db.server.web.api.controller.rdb.request.TableBriefQueryRequest; import ai.chat2db.server.web.api.controller.rdb.request.TableMilvusQueryRequest; import ai.chat2db.server.web.api.http.GatewayClientService; import ai.chat2db.server.web.api.http.request.TableSchemaRequest; +import ai.chat2db.server.web.api.http.request.WhiteListRequest; +import ai.chat2db.server.web.api.util.ApplicationContextUtil; import ai.chat2db.spi.model.Table; import com.google.common.collect.Lists; import jakarta.annotation.Resource; @@ -23,10 +32,7 @@ import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections4.CollectionUtils; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.web.bind.annotation.CrossOrigin; -import org.springframework.web.bind.annotation.PostMapping; -import org.springframework.web.bind.annotation.RequestMapping; -import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.bind.annotation.*; import java.io.IOException; import java.math.BigDecimal; @@ -53,6 +59,15 @@ public class EmbeddingController extends ChatController { @Autowired private TableService tableService; + /** + * check if in white list + */ + @GetMapping("/white/check") + public DataResult checkInWhite(WhiteListRequest request) { + request.setWhiteType(WhiteListTypeEnum.VECTOR.getCode()); + return gatewayClientService.checkInWhite(request); + } + /** * save datasource embeddings * @@ -66,6 +81,7 @@ public class EmbeddingController extends ChatController { throws Exception { // query tables + request.setPageNo(1); request.setPageSize(1000); TablePageQueryParam queryParam = rdbWebConverter.tablePageRequest2param(request); TableSelector tableSelector = new TableSelector(); @@ -128,6 +144,51 @@ public class EmbeddingController extends ChatController { return ActionResult.isSuccess(); } + /** + * sync table vector + * + * @param param + */ + public void syncTableVector(TableBriefQueryRequest param) throws Exception { + TableVectorParam vectorParam = rdbWebConverter.param2param(param); + if (Objects.isNull(vectorParam.getDataSourceId())) { + return; + } + if (StringUtils.isBlank(vectorParam.getDatabase()) && StringUtils.isBlank(vectorParam.getSchema())) { + return; + } + DataResult result = tableService.checkTableVector(vectorParam); + if (result.getData()) { + return; + } + + ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class); + Config config = configService.find(RestAIClient.AI_SQL_SOURCE).getData(); + String aiSqlSource = AiSqlSourceEnum.CHAT2DBAI.getCode(); + // only sync for chat2db ai + if (Objects.isNull(config) || !aiSqlSource.equals(config.getContent())) { + return; + } + Config keyConfig = configService.find(Chat2dbAIClient.CHAT2DB_OPENAI_KEY).getData(); + if (Objects.isNull(keyConfig) || StringUtils.isBlank(keyConfig.getContent())) { + return; + } + + TableMilvusQueryRequest request = rdbWebConverter.request2request(param); + String apiKey = keyConfig.getContent(); + request.setApikey(apiKey); + + // check if in white list + boolean res = gatewayClientService.checkInWhite(new WhiteListRequest(apiKey, WhiteListTypeEnum.VECTOR.getCode())).getData(); + if (!res) { + return; + } + + embeddings(request); + + tableService.saveTableVector(vectorParam); + } + /** * save table embedding * @@ -144,10 +205,13 @@ public class EmbeddingController extends ChatController { // request embedding FastChatEmbeddingResponse response = distributeAIEmbedding(str); if(response == null){ - continue; + throw new ParamBusinessException(); } contentVector.add(response.getData().get(0).getEmbedding()); } + if (CollectionUtils.isEmpty(contentVector)) { + throw new ParamBusinessException(); + } tableSchemaRequest.setSchemaVector(contentVector); // save table embedding diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/chat2db/client/Chat2DBAIStreamClient.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/chat2db/client/Chat2DBAIStreamClient.java new file mode 100644 index 00000000..cf6ff6f3 --- /dev/null +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/chat2db/client/Chat2DBAIStreamClient.java @@ -0,0 +1,276 @@ +package ai.chat2db.server.web.api.controller.ai.chat2db.client; + +import ai.chat2db.server.tools.common.exception.ParamBusinessException; +import ai.chat2db.server.web.api.controller.ai.fastchat.client.FastChatOpenAiApi; +import ai.chat2db.server.web.api.controller.ai.fastchat.embeddings.FastChatEmbedding; +import ai.chat2db.server.web.api.controller.ai.fastchat.embeddings.FastChatEmbeddingResponse; +import cn.hutool.http.ContentType; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.Lists; +import com.unfbx.chatgpt.entity.chat.ChatCompletion; +import com.unfbx.chatgpt.entity.chat.Message; +import com.unfbx.chatgpt.interceptor.HeaderAuthorizationInterceptor; +import lombok.Getter; +import lombok.extern.slf4j.Slf4j; +import okhttp3.*; +import okhttp3.sse.EventSource; +import okhttp3.sse.EventSourceListener; +import okhttp3.sse.EventSources; +import org.apache.commons.collections4.CollectionUtils; +import org.apache.commons.lang3.StringUtils; +import org.jetbrains.annotations.NotNull; +import retrofit2.Retrofit; +import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory; +import retrofit2.converter.jackson.JacksonConverterFactory; + +import java.util.List; +import java.util.Objects; +import java.util.concurrent.TimeUnit; + +/** + * Fast Chat Aligned Client + * + * @author moji + */ +@Slf4j +public class Chat2DBAIStreamClient { + + /** + * apikey + */ + @Getter + @NotNull + private String apiKey; + + /** + * apiHost + */ + @Getter + @NotNull + private String apiHost; + + /** + * model + */ + @Getter + private String model; + + /** + * embeddingModel + */ + @Getter + private String embeddingModel; + + /** + * okHttpClient + */ + @Getter + private OkHttpClient okHttpClient; + + @Getter + private FastChatOpenAiApi fastChatOpenAiApi; + + + /** + * @param builder + */ + private Chat2DBAIStreamClient(Builder builder) { + this.apiKey = builder.apiKey; + this.apiHost = builder.apiHost; + if (!apiHost.endsWith("/")){ + apiHost = apiHost + "/"; + } + this.model = builder.model; + this.embeddingModel = builder.embeddingModel; + if (Objects.isNull(builder.okHttpClient)) { + builder.okHttpClient = this.okHttpClient(); + } + okHttpClient = builder.okHttpClient; + this.fastChatOpenAiApi = new Retrofit.Builder() + .baseUrl(apiHost) + .client(okHttpClient) + .addCallAdapterFactory(RxJava2CallAdapterFactory.create()) + .addConverterFactory(JacksonConverterFactory.create()) + .build().create(FastChatOpenAiApi.class); + } + + /** + * okhttpclient + */ + private OkHttpClient okHttpClient() { + OkHttpClient okHttpClient = new OkHttpClient + .Builder() + .addInterceptor(new HeaderAuthorizationInterceptor(Lists.newArrayList(this.apiKey))) + .connectTimeout(50, TimeUnit.SECONDS) + .writeTimeout(50, TimeUnit.SECONDS) + .readTimeout(50, TimeUnit.SECONDS) + .build(); + return okHttpClient; + } + + /** + * 构造 + * + * @return + */ + public static Chat2DBAIStreamClient.Builder builder() { + return new Chat2DBAIStreamClient.Builder(); + } + + /** + * builder + */ + public static final class Builder { + private String apiKey; + + private String apiHost; + + private String model; + + private String embeddingModel; + + /** + * OkhttpClient + */ + private OkHttpClient okHttpClient; + + public Builder() { + } + + public Chat2DBAIStreamClient.Builder apiKey(String apiKeyValue) { + this.apiKey = apiKeyValue; + return this; + } + + /** + * @param apiHostValue + * @return + */ + public Chat2DBAIStreamClient.Builder apiHost(String apiHostValue) { + this.apiHost = apiHostValue; + return this; + } + + /** + * @param modelValue + * @return + */ + public Chat2DBAIStreamClient.Builder model(String modelValue) { + this.model = modelValue; + return this; + } + + public Chat2DBAIStreamClient.Builder embeddingModel(String embeddingModelValue) { + this.embeddingModel = embeddingModelValue; + return this; + } + + public Chat2DBAIStreamClient.Builder okHttpClient(OkHttpClient val) { + this.okHttpClient = val; + return this; + } + + public Chat2DBAIStreamClient build() { + return new Chat2DBAIStreamClient(this); + } + + } + + /** + * 问答接口 stream 形式 + * + * @param chatMessages + * @param eventSourceListener + */ + public void streamCompletions(List chatMessages, EventSourceListener eventSourceListener) { + if (CollectionUtils.isEmpty(chatMessages)) { + log.error("param error:Chat Prompt cannot be empty"); + throw new ParamBusinessException("prompt"); + } + if (Objects.isNull(eventSourceListener)) { + log.error("param error:ChatEventSourceListener cannot be empty"); + throw new ParamBusinessException(); + } + log.info("Chat AI, prompt:{}", chatMessages.get(chatMessages.size() - 1).getContent()); + try { + ChatCompletion chatCompletion = ChatCompletion.builder() + .messages(chatMessages) + .stream(true) + .build(); + + EventSource.Factory factory = EventSources.createFactory(this.okHttpClient); + ObjectMapper mapper = new ObjectMapper(); + String requestBody = mapper.writeValueAsString(chatCompletion); + Request request = new Request.Builder() + .url(this.apiHost + "v1/chat/completions") + .post(RequestBody.create(MediaType.parse(ContentType.JSON.getValue()), requestBody)) + .build(); + //创建事件 + EventSource eventSource = factory.newEventSource(request, eventSourceListener); + log.info("finish invoking chat ai"); + } catch (Exception e) { + log.error("chat ai error", e); + eventSourceListener.onFailure(null, e, null); + throw new ParamBusinessException(); + } + } + + /** + * Creates an embedding vector representing the input text. + * + * @param input + * @return EmbeddingResponse + */ + public FastChatEmbeddingResponse embeddings(String input) { + FastChatEmbedding embedding = FastChatEmbedding.builder().input(input).build(); + if (StringUtils.isNotBlank(this.embeddingModel)) { + embedding.setModel(this.embeddingModel); + } + return this.embeddings(embedding); + } + + /** + * Creates an embedding vector representing the input text. + * + * @param embedding + * @return EmbeddingResponse + */ + public FastChatEmbeddingResponse embeddings(FastChatEmbedding embedding) { + try { + ObjectMapper mapper = new ObjectMapper(); + String requestBody = mapper.writeValueAsString(embedding); + Request request = new Request.Builder() + .url(this.apiHost + "v1/embeddings") + .post(RequestBody.create(MediaType.parse(ContentType.JSON.getValue()), requestBody)) + .build(); + + FastChatEmbeddingResponse chatEmbeddingResponse = null; + Response response = this.okHttpClient.newCall(request).execute(); + StringBuilder body = new StringBuilder(); + if (response.isSuccessful()) { + ResponseBody responseBody = response.body(); + if (responseBody != null) { + // 获取响应体的输入流 + java.io.InputStream inputStream = responseBody.byteStream(); + java.io.BufferedReader reader = new java.io.BufferedReader(new java.io.InputStreamReader(inputStream)); + + String line; + while ((line = reader.readLine()) != null) { + // 在这里处理每行响应内容 + body.append(line); + } + + // 关闭流 + reader.close(); + inputStream.close(); + } + chatEmbeddingResponse = mapper.readValue(body.toString(), FastChatEmbeddingResponse.class); + } + log.info("finish invoking chat embedding"); + return chatEmbeddingResponse; + } catch (Exception e) { + log.error("chat ai error", e); + throw new ParamBusinessException(); + } + } +} diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/chat2db/client/Chat2dbAIClient.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/chat2db/client/Chat2dbAIClient.java index 3a60b9e5..c73a4704 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/chat2db/client/Chat2dbAIClient.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/chat2db/client/Chat2dbAIClient.java @@ -24,11 +24,15 @@ public class Chat2dbAIClient { */ public static final String CHAT2DB_OPENAI_HOST = "chat2db.apiHost"; + /** + * FASTCHAT OPENAI embedding model + */ + public static final String CHAT2DB_EMBEDDING_MODEL= "fastchat.embedding.model"; - private static OpenAiStreamClient CHAT2DB_AI_STREAM_CLIENT; - private static String apiKey; - public static OpenAiStreamClient getInstance() { + private static Chat2DBAIStreamClient CHAT2DB_AI_STREAM_CLIENT; + + public static Chat2DBAIStreamClient getInstance() { if (CHAT2DB_AI_STREAM_CLIENT != null) { return CHAT2DB_AI_STREAM_CLIENT; } else { @@ -36,7 +40,7 @@ public class Chat2dbAIClient { } } - private static OpenAiStreamClient singleton() { + private static Chat2DBAIStreamClient singleton() { if (CHAT2DB_AI_STREAM_CLIENT == null) { synchronized (Chat2dbAIClient.class) { if (CHAT2DB_AI_STREAM_CLIENT == null) { @@ -65,9 +69,8 @@ public class Chat2dbAIClient { apikey = ApplicationContextUtil.getProperty(CHAT2DB_OPENAI_KEY); } log.info("refresh chat2db apikey:{}", maskApiKey(apikey)); - CHAT2DB_AI_STREAM_CLIENT = OpenAiStreamClient.builder().apiHost(apiHost).apiKey( - Lists.newArrayList(apikey)).build(); - apiKey = apikey; + CHAT2DB_AI_STREAM_CLIENT = Chat2DBAIStreamClient.builder().apiHost(apiHost).apiKey( + apikey).build(); } private static String maskApiKey(String input) { diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/fastchat/client/FastChatAIClient.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/fastchat/client/FastChatAIClient.java index f027b8ec..4a642c39 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/fastchat/client/FastChatAIClient.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/fastchat/client/FastChatAIClient.java @@ -32,7 +32,7 @@ public class FastChatAIClient { /** * FASTCHAT OPENAI embedding model */ - public static final String FASTCHAT_embedding_MODEL= "fastchat.embedding.model"; + public static final String FASTCHAT_EMBEDDING_MODEL = "fastchat.embedding.model"; private static FastChatAIStreamClient FASTCHAT_AI_CLIENT; diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/fastchat/client/FastChatAIStreamClient.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/fastchat/client/FastChatAIStreamClient.java index 6a08ba11..fd371d37 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/fastchat/client/FastChatAIStreamClient.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/fastchat/client/FastChatAIStreamClient.java @@ -21,6 +21,9 @@ import okhttp3.sse.EventSources; import org.apache.commons.collections4.CollectionUtils; import org.apache.commons.lang3.StringUtils; import org.jetbrains.annotations.NotNull; +import retrofit2.Retrofit; +import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory; +import retrofit2.converter.jackson.JacksonConverterFactory; import java.util.List; import java.util.Objects; @@ -82,6 +85,12 @@ public class FastChatAIStreamClient { builder.okHttpClient = this.okHttpClient(); } okHttpClient = builder.okHttpClient; + this.fastChatOpenAiApi = new Retrofit.Builder() + .baseUrl(apiHost) + .client(okHttpClient) + .addCallAdapterFactory(RxJava2CallAdapterFactory.create()) + .addConverterFactory(JacksonConverterFactory.create()) + .build().create(FastChatOpenAiApi.class); } /** diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/rdb/RdbDdlController.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/rdb/RdbDdlController.java index 624531f4..d88275da 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/rdb/RdbDdlController.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/rdb/RdbDdlController.java @@ -1,40 +1,33 @@ package ai.chat2db.server.web.api.controller.rdb; -import java.util.List; - import ai.chat2db.server.domain.api.param.*; import ai.chat2db.server.domain.api.param.datasource.DatabaseOperationParam; import ai.chat2db.server.domain.api.service.DatabaseService; import ai.chat2db.server.domain.api.service.DlTemplateService; import ai.chat2db.server.domain.api.service.TableService; -import ai.chat2db.server.web.api.controller.rdb.vo.*; -import ai.chat2db.spi.model.*; import ai.chat2db.server.tools.base.wrapper.result.ActionResult; import ai.chat2db.server.tools.base.wrapper.result.DataResult; import ai.chat2db.server.tools.base.wrapper.result.ListResult; import ai.chat2db.server.tools.base.wrapper.result.PageResult; import ai.chat2db.server.tools.base.wrapper.result.web.WebPageResult; import ai.chat2db.server.web.api.aspect.ConnectionInfoAspect; +import ai.chat2db.server.web.api.controller.ai.EmbeddingController; import ai.chat2db.server.web.api.controller.data.source.request.DataSourceBaseRequest; import ai.chat2db.server.web.api.controller.rdb.converter.RdbWebConverter; -import ai.chat2db.server.web.api.controller.rdb.request.DdlExportRequest; -import ai.chat2db.server.web.api.controller.rdb.request.TableBriefQueryRequest; -import ai.chat2db.server.web.api.controller.rdb.request.TableCreateDdlQueryRequest; -import ai.chat2db.server.web.api.controller.rdb.request.TableDeleteRequest; -import ai.chat2db.server.web.api.controller.rdb.request.TableDetailQueryRequest; -import ai.chat2db.server.web.api.controller.rdb.request.TableModifySqlRequest; -import ai.chat2db.server.web.api.controller.rdb.request.TableUpdateDdlQueryRequest; -import ai.chat2db.server.web.api.controller.rdb.request.UpdateDatabaseRequest; -import ai.chat2db.server.web.api.controller.rdb.request.UpdateSchemaRequest; - +import ai.chat2db.server.web.api.controller.rdb.request.*; +import ai.chat2db.server.web.api.controller.rdb.vo.*; +import ai.chat2db.spi.model.*; +import ai.chat2db.spi.sql.Chat2DBContext; +import ai.chat2db.spi.sql.ConnectInfo; import com.google.common.collect.Lists; import jakarta.validation.Valid; +import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.web.bind.annotation.GetMapping; -import org.springframework.web.bind.annotation.PostMapping; -import org.springframework.web.bind.annotation.RequestBody; -import org.springframework.web.bind.annotation.RequestMapping; -import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.bind.annotation.*; + +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; /** * mysql表运维类 @@ -46,8 +39,9 @@ import org.springframework.web.bind.annotation.RestController; @ConnectionInfoAspect @RequestMapping("/api/rdb/ddl") @RestController +@Slf4j @Deprecated -public class RdbDdlController { +public class RdbDdlController extends EmbeddingController { @Autowired private TableService tableService; @@ -61,6 +55,8 @@ public class RdbDdlController { @Autowired private DatabaseService databaseService; + public static ExecutorService singleThreadExecutor = Executors.newSingleThreadExecutor(); + /** * 查询当前DB下的表列表 * @@ -76,6 +72,19 @@ public class RdbDdlController { PageResult tableDTOPageResult = tableService.pageQuery(queryParam, tableSelector); List tableVOS = rdbWebConverter.tableDto2vo(tableDTOPageResult.getData()); + + ConnectInfo connectInfo = Chat2DBContext.getConnectInfo(); + singleThreadExecutor.submit(() -> { + try { + Chat2DBContext.putContext(connectInfo); + syncTableVector(request); + } catch (Exception e) { + log.error("sync table vector error", e); + } finally { + Chat2DBContext.removeContext(); + } + log.info("sync table vector finish"); + }); return WebPageResult.of(tableVOS, tableDTOPageResult.getTotal(), request.getPageNo(), request.getPageSize()); } diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/rdb/converter/RdbWebConverter.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/rdb/converter/RdbWebConverter.java index 9dd847fd..078fd828 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/rdb/converter/RdbWebConverter.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/rdb/converter/RdbWebConverter.java @@ -224,4 +224,12 @@ public abstract class RdbWebConverter { public abstract UpdateSelectResultParam request2param(SelectResultUpdateRequest request); + + public abstract TableMilvusQueryRequest request2request(TableBriefQueryRequest request); + + @Mappings({ + @Mapping(source = "databaseName", target = "database"), + @Mapping(source = "schemaName", target = "schema"), + }) + public abstract TableVectorParam param2param(TableBriefQueryRequest request); } diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/http/GatewayClientService.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/http/GatewayClientService.java index 70e8f6d8..e296c47c 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/http/GatewayClientService.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/http/GatewayClientService.java @@ -4,11 +4,10 @@ import ai.chat2db.server.tools.base.wrapper.result.ActionResult; import ai.chat2db.server.tools.base.wrapper.result.DataResult; import ai.chat2db.server.web.api.http.request.KnowledgeRequest; import ai.chat2db.server.web.api.http.request.TableSchemaRequest; +import ai.chat2db.server.web.api.http.request.WhiteListRequest; import ai.chat2db.server.web.api.http.response.*; import com.dtflys.forest.annotation.*; -import java.math.BigDecimal; -import java.util.List; /** * Gateway 的http 服务 @@ -91,4 +90,14 @@ public interface GatewayClientService { */ @Get("/api/client/milvus/schema/search") DataResult schemaVectorSearch(TableSchemaRequest request); + + /** + * check in white list + * + * @param whiteListRequest + * @return + */ + @Get("/api/client/whitelist/check") + DataResult checkInWhite(WhiteListRequest whiteListRequest); + } diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/http/request/WhiteListRequest.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/http/request/WhiteListRequest.java new file mode 100644 index 00000000..93d419b0 --- /dev/null +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/http/request/WhiteListRequest.java @@ -0,0 +1,25 @@ +package ai.chat2db.server.web.api.http.request; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; +import lombok.experimental.SuperBuilder; + + +@Data +@SuperBuilder +@NoArgsConstructor +@AllArgsConstructor +public class WhiteListRequest { + + /** + * api key + */ + private String apiKey; + + /** + * 白名单类型,如向量 + * @see ai.chat2db.server.tools.base.enums.WhiteListTypeEnum + */ + private String whiteType; +}