mirror of
https://github.com/CodePhiliaX/Chat2DB.git
synced 2025-07-29 10:43:06 +08:00
upgrade chat2db
This commit is contained in:
@ -20,6 +20,7 @@ import ai.chat2db.server.web.api.controller.ai.azure.model.AzureChatRole;
|
||||
import ai.chat2db.server.web.api.controller.ai.baichuan.client.BaichuanAIClient;
|
||||
import ai.chat2db.server.web.api.controller.ai.baichuan.listener.BaichuanChatAIEventSourceListener;
|
||||
import ai.chat2db.server.web.api.controller.ai.chat2db.client.Chat2dbAIClient;
|
||||
import ai.chat2db.server.web.api.controller.ai.chat2db.listener.Chat2dbAIEventSourceListener;
|
||||
import ai.chat2db.server.web.api.controller.ai.claude.client.ClaudeAIClient;
|
||||
import ai.chat2db.server.web.api.controller.ai.claude.listener.ClaudeAIEventSourceListener;
|
||||
import ai.chat2db.server.web.api.controller.ai.claude.model.ClaudeChatCompletionsOptions;
|
||||
@ -321,7 +322,7 @@ public class ChatController {
|
||||
messages.add(currentMessage);
|
||||
buildSseEmitter(sseEmitter, uid);
|
||||
|
||||
OpenAIEventSourceListener openAIEventSourceListener = new OpenAIEventSourceListener(sseEmitter);
|
||||
Chat2dbAIEventSourceListener openAIEventSourceListener = new Chat2dbAIEventSourceListener(sseEmitter);
|
||||
Chat2dbAIClient.getInstance().streamCompletions(messages, openAIEventSourceListener);
|
||||
LocalCache.CACHE.put(uid, JSONUtil.toJsonStr(messages), LocalCache.TIMEOUT);
|
||||
return sseEmitter;
|
||||
|
@ -3,17 +3,13 @@ package ai.chat2db.server.web.api.controller.ai.baichuan.client;
|
||||
import ai.chat2db.server.tools.common.exception.ParamBusinessException;
|
||||
import ai.chat2db.server.web.api.controller.ai.baichuan.interceptor.BaichuanHeaderAuthorizationInterceptor;
|
||||
import ai.chat2db.server.web.api.controller.ai.baichuan.model.BaichuanChatCompletionsOptions;
|
||||
import ai.chat2db.server.web.api.controller.ai.fastchat.interceptor.FastChatHeaderAuthorizationInterceptor;
|
||||
import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatCompletionsOptions;
|
||||
import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatMessage;
|
||||
import cn.hutool.http.ContentType;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import lombok.Getter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import okhttp3.*;
|
||||
import okhttp3.sse.EventSource;
|
||||
import okhttp3.sse.EventSourceListener;
|
||||
import okhttp3.sse.EventSources;
|
||||
import okio.BufferedSource;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
|
@ -1,13 +1,17 @@
|
||||
package ai.chat2db.server.web.api.controller.ai.chat2db.client;
|
||||
|
||||
import ai.chat2db.server.domain.api.enums.AiSqlSourceEnum;
|
||||
import ai.chat2db.server.domain.api.model.Config;
|
||||
import ai.chat2db.server.domain.api.service.ConfigService;
|
||||
import ai.chat2db.server.tools.base.wrapper.result.DataResult;
|
||||
import ai.chat2db.server.tools.common.exception.ParamBusinessException;
|
||||
import ai.chat2db.server.web.api.controller.ai.chat2db.interceptor.Chat2dbHeaderAuthorizationInterceptor;
|
||||
import ai.chat2db.server.web.api.controller.ai.fastchat.client.FastChatOpenAiApi;
|
||||
import ai.chat2db.server.web.api.controller.ai.fastchat.embeddings.FastChatEmbedding;
|
||||
import ai.chat2db.server.web.api.controller.ai.fastchat.embeddings.FastChatEmbeddingResponse;
|
||||
import ai.chat2db.server.web.api.util.ApplicationContextUtil;
|
||||
import cn.hutool.http.ContentType;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.unfbx.chatgpt.entity.chat.ChatCompletion;
|
||||
import com.unfbx.chatgpt.entity.chat.Message;
|
||||
import com.unfbx.chatgpt.interceptor.HeaderAuthorizationInterceptor;
|
||||
@ -198,6 +202,20 @@ public class Chat2DBAIStreamClient {
|
||||
.messages(chatMessages)
|
||||
.stream(true)
|
||||
.build();
|
||||
ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
|
||||
DataResult<Config> chat2dbModel = configService.find(Chat2dbAIClient.CHAT2DB_OPENAI_MODEL);
|
||||
String model = Objects.nonNull(chat2dbModel.getData()) && StringUtils.isNotBlank(chat2dbModel.getData().getContent()) ? chat2dbModel.getData().getContent() : AiSqlSourceEnum.OPENAI.getCode();
|
||||
AiSqlSourceEnum aiSqlSourceEnum = AiSqlSourceEnum.getByName(model);
|
||||
switch (aiSqlSourceEnum) {
|
||||
case BAICHUANAI:
|
||||
chatCompletion = ChatCompletion.builder().messages(chatMessages).model("Baichuan2-53B").build();
|
||||
break;
|
||||
case ZHIPUAI:
|
||||
chatCompletion = ChatCompletion.builder().messages(chatMessages).model("chatglm_turbo").build();
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
EventSource.Factory factory = EventSources.createFactory(this.okHttpClient);
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
|
@ -0,0 +1,160 @@
|
||||
package ai.chat2db.server.web.api.controller.ai.chat2db.listener;
|
||||
|
||||
import ai.chat2db.server.domain.api.enums.AiSqlSourceEnum;
|
||||
import ai.chat2db.server.domain.api.model.Config;
|
||||
import ai.chat2db.server.domain.api.service.ConfigService;
|
||||
import ai.chat2db.server.tools.base.wrapper.result.DataResult;
|
||||
import ai.chat2db.server.web.api.controller.ai.baichuan.model.BaichuanChatCompletions;
|
||||
import ai.chat2db.server.web.api.controller.ai.baichuan.model.BaichuanChatMessage;
|
||||
import ai.chat2db.server.web.api.controller.ai.chat2db.client.Chat2dbAIClient;
|
||||
import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatMessage;
|
||||
import ai.chat2db.server.web.api.controller.ai.response.ChatCompletionResponse;
|
||||
import ai.chat2db.server.web.api.controller.ai.zhipu.model.ZhipuChatCompletions;
|
||||
import ai.chat2db.server.web.api.util.ApplicationContextUtil;
|
||||
import com.fasterxml.jackson.databind.DeserializationFeature;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.unfbx.chatgpt.entity.chat.Message;
|
||||
import lombok.SneakyThrows;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import okhttp3.Response;
|
||||
import okhttp3.ResponseBody;
|
||||
import okhttp3.sse.EventSource;
|
||||
import okhttp3.sse.EventSourceListener;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* 描述:Chat2dbAIEventSourceListener
|
||||
*
|
||||
* @author https:www.unfbx.com
|
||||
* @date 2023-02-22
|
||||
*/
|
||||
@Slf4j
|
||||
public class Chat2dbAIEventSourceListener extends EventSourceListener {
|
||||
|
||||
private SseEmitter sseEmitter;
|
||||
|
||||
public Chat2dbAIEventSourceListener(SseEmitter sseEmitter) {
|
||||
this.sseEmitter = sseEmitter;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
public void onOpen(EventSource eventSource, Response response) {
|
||||
log.info("Chat2db AI 建立sse连接...");
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@SneakyThrows
|
||||
@Override
|
||||
public void onEvent(EventSource eventSource, String id, String type, String data) {
|
||||
log.info("Chat2db AI 返回数据:{}", data);
|
||||
if (data.equals("[DONE]")) {
|
||||
log.info("Chat2db AI 返回数据结束了");
|
||||
sseEmitter.send(SseEmitter.event()
|
||||
.id("[DONE]")
|
||||
.data("[DONE]")
|
||||
.reconnectTime(3000));
|
||||
sseEmitter.complete();
|
||||
return;
|
||||
}
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
|
||||
ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
|
||||
DataResult<Config> chat2dbModel = configService.find(Chat2dbAIClient.CHAT2DB_OPENAI_MODEL);
|
||||
String model = Objects.nonNull(chat2dbModel.getData()) && StringUtils.isNotBlank(chat2dbModel.getData().getContent()) ? chat2dbModel.getData().getContent() : AiSqlSourceEnum.OPENAI.getCode();
|
||||
AiSqlSourceEnum aiSqlSourceEnum = AiSqlSourceEnum.getByName(model);
|
||||
String text = "";
|
||||
String completionId = null;
|
||||
// 读取Json
|
||||
switch (aiSqlSourceEnum) {
|
||||
case BAICHUANAI:
|
||||
BaichuanChatCompletions chatCompletions = mapper.readValue(data, BaichuanChatCompletions.class);
|
||||
for (BaichuanChatMessage message : chatCompletions.getData().getMessages()) {
|
||||
if (message != null) {
|
||||
if (message.getContent() != null) {
|
||||
text = message.getContent();
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
case ZHIPUAI:
|
||||
ZhipuChatCompletions zhipuChatCompletions = mapper.readValue(data, ZhipuChatCompletions.class);
|
||||
text = zhipuChatCompletions.getData();
|
||||
if (Objects.isNull(text)) {
|
||||
for (FastChatMessage message : zhipuChatCompletions.getBody().getChoices()) {
|
||||
if (message != null && message.getContent() != null) {
|
||||
text = message.getContent();
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
default:
|
||||
ChatCompletionResponse completionResponse = mapper.readValue(data, ChatCompletionResponse.class);
|
||||
text = completionResponse.getChoices().get(0).getDelta() == null
|
||||
? completionResponse.getChoices().get(0).getText()
|
||||
: completionResponse.getChoices().get(0).getDelta().getContent();
|
||||
completionId = completionResponse.getId();
|
||||
break;
|
||||
}
|
||||
Message message = new Message();
|
||||
if (text != null) {
|
||||
message.setContent(text);
|
||||
sseEmitter.send(SseEmitter.event()
|
||||
.id(completionId)
|
||||
.data(message)
|
||||
.reconnectTime(3000));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onClosed(EventSource eventSource) {
|
||||
sseEmitter.complete();
|
||||
log.info("Chat2db AI 关闭sse连接...");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onFailure(EventSource eventSource, Throwable t, Response response) {
|
||||
try {
|
||||
if (Objects.isNull(response)) {
|
||||
String message = t.getMessage();
|
||||
Message sseMessage = new Message();
|
||||
sseMessage.setContent(message);
|
||||
sseEmitter.send(SseEmitter.event()
|
||||
.id("[ERROR]")
|
||||
.data(sseMessage));
|
||||
sseEmitter.send(SseEmitter.event()
|
||||
.id("[DONE]")
|
||||
.data("[DONE]"));
|
||||
sseEmitter.complete();
|
||||
return;
|
||||
}
|
||||
ResponseBody body = response.body();
|
||||
String bodyString = null;
|
||||
if (Objects.nonNull(body)) {
|
||||
bodyString = body.string();
|
||||
log.error("Chat2db AI sse连接异常data:{}", bodyString, t);
|
||||
} else {
|
||||
log.error("Chat2db AI sse连接异常data:{}", response, t);
|
||||
}
|
||||
eventSource.cancel();
|
||||
Message message = new Message();
|
||||
message.setContent("Chat2db AI Error:" + bodyString);
|
||||
sseEmitter.send(SseEmitter.event()
|
||||
.id("[ERROR]")
|
||||
.data(message));
|
||||
sseEmitter.send(SseEmitter.event()
|
||||
.id("[DONE]")
|
||||
.data("[DONE]"));
|
||||
sseEmitter.complete();
|
||||
} catch (Exception exception) {
|
||||
log.error("发送数据异常:", exception);
|
||||
}
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user