diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/chat2db/client/Chat2DBAIStreamClient.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/chat2db/client/Chat2DBAIStreamClient.java index 1f63d4ab..34436ca3 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/chat2db/client/Chat2DBAIStreamClient.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/chat2db/client/Chat2DBAIStreamClient.java @@ -196,6 +196,7 @@ public class Chat2DBAIStreamClient { try { ChatCompletion chatCompletion = ChatCompletion.builder() .messages(chatMessages) + .model(this.model) .stream(true) .build(); diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/listener/OpenAIEventSourceListener.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/listener/OpenAIEventSourceListener.java index 2f0bcef0..2ae383eb 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/listener/OpenAIEventSourceListener.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/listener/OpenAIEventSourceListener.java @@ -2,8 +2,19 @@ package ai.chat2db.server.web.api.controller.ai.openai.listener; import java.util.Objects; +import ai.chat2db.server.domain.api.enums.AiSqlSourceEnum; +import ai.chat2db.server.domain.api.model.Config; +import ai.chat2db.server.domain.api.service.ConfigService; +import ai.chat2db.server.tools.base.wrapper.result.DataResult; +import ai.chat2db.server.web.api.controller.ai.baichuan.model.BaichuanChatCompletions; +import ai.chat2db.server.web.api.controller.ai.baichuan.model.BaichuanChatMessage; +import ai.chat2db.server.web.api.controller.ai.chat2db.client.Chat2dbAIClient; +import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatMessage; import ai.chat2db.server.web.api.controller.ai.response.ChatCompletionResponse; +import ai.chat2db.server.web.api.controller.ai.zhipu.model.ZhipuChatCompletions; +import ai.chat2db.server.web.api.util.ApplicationContextUtil; +import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.ObjectMapper; import com.unfbx.chatgpt.entity.chat.Message; import lombok.SneakyThrows; @@ -54,16 +65,50 @@ public class OpenAIEventSourceListener extends EventSourceListener { return; } ObjectMapper mapper = new ObjectMapper(); + mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); + ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class); + DataResult chat2dbModel = configService.find(Chat2dbAIClient.CHAT2DB_OPENAI_MODEL); + String model = Objects.nonNull(chat2dbModel.getData()) ? chat2dbModel.getData().getContent() : AiSqlSourceEnum.OPENAI.getCode(); + AiSqlSourceEnum aiSqlSourceEnum = AiSqlSourceEnum.getByName(model); + String text = ""; + String completionId = null; // 读取Json - ChatCompletionResponse completionResponse = mapper.readValue(data, ChatCompletionResponse.class); - String text = completionResponse.getChoices().get(0).getDelta() == null - ? completionResponse.getChoices().get(0).getText() - : completionResponse.getChoices().get(0).getDelta().getContent(); + switch (aiSqlSourceEnum) { + case BAICHUANAI: + BaichuanChatCompletions chatCompletions = mapper.readValue(data, BaichuanChatCompletions.class); + for (BaichuanChatMessage message : chatCompletions.getData().getMessages()) { + if (message != null) { + if (message.getContent() != null) { + text = message.getContent(); + } + } + } + break; + case ZHIPUAI: + ZhipuChatCompletions zhipuChatCompletions = mapper.readValue(data, ZhipuChatCompletions.class); + text = zhipuChatCompletions.getData(); + if (Objects.isNull(text)) { + for (FastChatMessage message : zhipuChatCompletions.getBody().getChoices()) { + if (message != null && message.getContent() != null) { + text = message.getContent(); + } + } + } + break; + default: + ChatCompletionResponse completionResponse = mapper.readValue(data, ChatCompletionResponse.class); + text = completionResponse.getChoices().get(0).getDelta() == null + ? completionResponse.getChoices().get(0).getText() + : completionResponse.getChoices().get(0).getDelta().getContent(); + completionId = completionResponse.getId(); + break; + } + Message message = new Message(); if (text != null) { message.setContent(text); sseEmitter.send(SseEmitter.event() - .id(completionResponse.getId()) + .id(completionId) .data(message) .reconnectTime(3000)); }