upgrade chat2db

This commit is contained in:
robin
2023-11-03 15:28:11 +08:00
parent bfd53615e8
commit 4c976ba732
4 changed files with 181 additions and 6 deletions

View File

@ -20,6 +20,7 @@ import ai.chat2db.server.web.api.controller.ai.azure.model.AzureChatRole;
import ai.chat2db.server.web.api.controller.ai.baichuan.client.BaichuanAIClient;
import ai.chat2db.server.web.api.controller.ai.baichuan.listener.BaichuanChatAIEventSourceListener;
import ai.chat2db.server.web.api.controller.ai.chat2db.client.Chat2dbAIClient;
import ai.chat2db.server.web.api.controller.ai.chat2db.listener.Chat2dbAIEventSourceListener;
import ai.chat2db.server.web.api.controller.ai.claude.client.ClaudeAIClient;
import ai.chat2db.server.web.api.controller.ai.claude.listener.ClaudeAIEventSourceListener;
import ai.chat2db.server.web.api.controller.ai.claude.model.ClaudeChatCompletionsOptions;
@ -321,7 +322,7 @@ public class ChatController {
messages.add(currentMessage);
buildSseEmitter(sseEmitter, uid);
OpenAIEventSourceListener openAIEventSourceListener = new OpenAIEventSourceListener(sseEmitter);
Chat2dbAIEventSourceListener openAIEventSourceListener = new Chat2dbAIEventSourceListener(sseEmitter);
Chat2dbAIClient.getInstance().streamCompletions(messages, openAIEventSourceListener);
LocalCache.CACHE.put(uid, JSONUtil.toJsonStr(messages), LocalCache.TIMEOUT);
return sseEmitter;

View File

@ -3,17 +3,13 @@ package ai.chat2db.server.web.api.controller.ai.baichuan.client;
import ai.chat2db.server.tools.common.exception.ParamBusinessException;
import ai.chat2db.server.web.api.controller.ai.baichuan.interceptor.BaichuanHeaderAuthorizationInterceptor;
import ai.chat2db.server.web.api.controller.ai.baichuan.model.BaichuanChatCompletionsOptions;
import ai.chat2db.server.web.api.controller.ai.fastchat.interceptor.FastChatHeaderAuthorizationInterceptor;
import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatCompletionsOptions;
import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatMessage;
import cn.hutool.http.ContentType;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import okhttp3.sse.EventSources;
import okio.BufferedSource;
import org.apache.commons.collections4.CollectionUtils;
import org.jetbrains.annotations.NotNull;

View File

@ -1,13 +1,17 @@
package ai.chat2db.server.web.api.controller.ai.chat2db.client;
import ai.chat2db.server.domain.api.enums.AiSqlSourceEnum;
import ai.chat2db.server.domain.api.model.Config;
import ai.chat2db.server.domain.api.service.ConfigService;
import ai.chat2db.server.tools.base.wrapper.result.DataResult;
import ai.chat2db.server.tools.common.exception.ParamBusinessException;
import ai.chat2db.server.web.api.controller.ai.chat2db.interceptor.Chat2dbHeaderAuthorizationInterceptor;
import ai.chat2db.server.web.api.controller.ai.fastchat.client.FastChatOpenAiApi;
import ai.chat2db.server.web.api.controller.ai.fastchat.embeddings.FastChatEmbedding;
import ai.chat2db.server.web.api.controller.ai.fastchat.embeddings.FastChatEmbeddingResponse;
import ai.chat2db.server.web.api.util.ApplicationContextUtil;
import cn.hutool.http.ContentType;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.Lists;
import com.unfbx.chatgpt.entity.chat.ChatCompletion;
import com.unfbx.chatgpt.entity.chat.Message;
import com.unfbx.chatgpt.interceptor.HeaderAuthorizationInterceptor;
@ -198,6 +202,20 @@ public class Chat2DBAIStreamClient {
.messages(chatMessages)
.stream(true)
.build();
ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
DataResult<Config> chat2dbModel = configService.find(Chat2dbAIClient.CHAT2DB_OPENAI_MODEL);
String model = Objects.nonNull(chat2dbModel.getData()) && StringUtils.isNotBlank(chat2dbModel.getData().getContent()) ? chat2dbModel.getData().getContent() : AiSqlSourceEnum.OPENAI.getCode();
AiSqlSourceEnum aiSqlSourceEnum = AiSqlSourceEnum.getByName(model);
switch (aiSqlSourceEnum) {
case BAICHUANAI:
chatCompletion = ChatCompletion.builder().messages(chatMessages).model("Baichuan2-53B").build();
break;
case ZHIPUAI:
chatCompletion = ChatCompletion.builder().messages(chatMessages).model("chatglm_turbo").build();
break;
default:
break;
}
EventSource.Factory factory = EventSources.createFactory(this.okHttpClient);
ObjectMapper mapper = new ObjectMapper();

View File

@ -0,0 +1,160 @@
package ai.chat2db.server.web.api.controller.ai.chat2db.listener;
import ai.chat2db.server.domain.api.enums.AiSqlSourceEnum;
import ai.chat2db.server.domain.api.model.Config;
import ai.chat2db.server.domain.api.service.ConfigService;
import ai.chat2db.server.tools.base.wrapper.result.DataResult;
import ai.chat2db.server.web.api.controller.ai.baichuan.model.BaichuanChatCompletions;
import ai.chat2db.server.web.api.controller.ai.baichuan.model.BaichuanChatMessage;
import ai.chat2db.server.web.api.controller.ai.chat2db.client.Chat2dbAIClient;
import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatMessage;
import ai.chat2db.server.web.api.controller.ai.response.ChatCompletionResponse;
import ai.chat2db.server.web.api.controller.ai.zhipu.model.ZhipuChatCompletions;
import ai.chat2db.server.web.api.util.ApplicationContextUtil;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.unfbx.chatgpt.entity.chat.Message;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Response;
import okhttp3.ResponseBody;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import org.apache.commons.lang3.StringUtils;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.util.Objects;
/**
* 描述Chat2dbAIEventSourceListener
*
* @author https:www.unfbx.com
* @date 2023-02-22
*/
@Slf4j
public class Chat2dbAIEventSourceListener extends EventSourceListener {
private SseEmitter sseEmitter;
public Chat2dbAIEventSourceListener(SseEmitter sseEmitter) {
this.sseEmitter = sseEmitter;
}
/**
* {@inheritDoc}
*/
@Override
public void onOpen(EventSource eventSource, Response response) {
log.info("Chat2db AI 建立sse连接...");
}
/**
* {@inheritDoc}
*/
@SneakyThrows
@Override
public void onEvent(EventSource eventSource, String id, String type, String data) {
log.info("Chat2db AI 返回数据:{}", data);
if (data.equals("[DONE]")) {
log.info("Chat2db AI 返回数据结束了");
sseEmitter.send(SseEmitter.event()
.id("[DONE]")
.data("[DONE]")
.reconnectTime(3000));
sseEmitter.complete();
return;
}
ObjectMapper mapper = new ObjectMapper();
mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
DataResult<Config> chat2dbModel = configService.find(Chat2dbAIClient.CHAT2DB_OPENAI_MODEL);
String model = Objects.nonNull(chat2dbModel.getData()) && StringUtils.isNotBlank(chat2dbModel.getData().getContent()) ? chat2dbModel.getData().getContent() : AiSqlSourceEnum.OPENAI.getCode();
AiSqlSourceEnum aiSqlSourceEnum = AiSqlSourceEnum.getByName(model);
String text = "";
String completionId = null;
// 读取Json
switch (aiSqlSourceEnum) {
case BAICHUANAI:
BaichuanChatCompletions chatCompletions = mapper.readValue(data, BaichuanChatCompletions.class);
for (BaichuanChatMessage message : chatCompletions.getData().getMessages()) {
if (message != null) {
if (message.getContent() != null) {
text = message.getContent();
}
}
}
break;
case ZHIPUAI:
ZhipuChatCompletions zhipuChatCompletions = mapper.readValue(data, ZhipuChatCompletions.class);
text = zhipuChatCompletions.getData();
if (Objects.isNull(text)) {
for (FastChatMessage message : zhipuChatCompletions.getBody().getChoices()) {
if (message != null && message.getContent() != null) {
text = message.getContent();
}
}
}
break;
default:
ChatCompletionResponse completionResponse = mapper.readValue(data, ChatCompletionResponse.class);
text = completionResponse.getChoices().get(0).getDelta() == null
? completionResponse.getChoices().get(0).getText()
: completionResponse.getChoices().get(0).getDelta().getContent();
completionId = completionResponse.getId();
break;
}
Message message = new Message();
if (text != null) {
message.setContent(text);
sseEmitter.send(SseEmitter.event()
.id(completionId)
.data(message)
.reconnectTime(3000));
}
}
@Override
public void onClosed(EventSource eventSource) {
sseEmitter.complete();
log.info("Chat2db AI 关闭sse连接...");
}
@Override
public void onFailure(EventSource eventSource, Throwable t, Response response) {
try {
if (Objects.isNull(response)) {
String message = t.getMessage();
Message sseMessage = new Message();
sseMessage.setContent(message);
sseEmitter.send(SseEmitter.event()
.id("[ERROR]")
.data(sseMessage));
sseEmitter.send(SseEmitter.event()
.id("[DONE]")
.data("[DONE]"));
sseEmitter.complete();
return;
}
ResponseBody body = response.body();
String bodyString = null;
if (Objects.nonNull(body)) {
bodyString = body.string();
log.error("Chat2db AI sse连接异常data{}", bodyString, t);
} else {
log.error("Chat2db AI sse连接异常data{}", response, t);
}
eventSource.cancel();
Message message = new Message();
message.setContent("Chat2db AI Error" + bodyString);
sseEmitter.send(SseEmitter.event()
.id("[ERROR]")
.data(message));
sseEmitter.send(SseEmitter.event()
.id("[DONE]")
.data("[DONE]"));
sseEmitter.complete();
} catch (Exception exception) {
log.error("发送数据异常:", exception);
}
}
}