From fbf4014d3ba312202694f8726659d0e92eb52fb9 Mon Sep 17 00:00:00 2001 From: robinji0 <850379744@qq.com> Date: Tue, 31 Oct 2023 15:59:40 +0800 Subject: [PATCH] add baichuan support --- .../server/domain/api/model/AIConfig.java | 5 + .../web/api/controller/ai/ChatController.java | 50 +++- .../ai/baichuan/client/BaichuanAIClient.java | 90 ++++++++ .../client/BaichuanAIStreamClient.java | 214 ++++++++++++++++++ ...aichuanHeaderAuthorizationInterceptor.java | 84 +++++++ .../BaichuanChatAIEventSourceListener.java | 136 +++++++++++ .../model/BaichuanChatCompletions.java | 52 +++++ .../model/BaichuanChatCompletionsUsage.java | 53 +++++ .../ai/baichuan/model/BaichuanChatData.java | 39 ++++ .../baichuan/model/BaichuanChatMessage.java | 30 +++ .../controller/config/ConfigController.java | 36 ++- .../config/request/AIConfigCreateRequest.java | 5 + 12 files changed, 781 insertions(+), 13 deletions(-) create mode 100644 chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/baichuan/client/BaichuanAIClient.java create mode 100644 chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/baichuan/client/BaichuanAIStreamClient.java create mode 100644 chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/baichuan/interceptor/BaichuanHeaderAuthorizationInterceptor.java create mode 100644 chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/baichuan/listener/BaichuanChatAIEventSourceListener.java create mode 100644 chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/baichuan/model/BaichuanChatCompletions.java create mode 100644 chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/baichuan/model/BaichuanChatCompletionsUsage.java create mode 100644 chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/baichuan/model/BaichuanChatData.java create mode 100644 chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/baichuan/model/BaichuanChatMessage.java diff --git a/chat2db-server/chat2db-server-domain/chat2db-server-domain-api/src/main/java/ai/chat2db/server/domain/api/model/AIConfig.java b/chat2db-server/chat2db-server-domain/chat2db-server-domain-api/src/main/java/ai/chat2db/server/domain/api/model/AIConfig.java index 06bde9f9..48bf0d61 100644 --- a/chat2db-server/chat2db-server-domain/chat2db-server-domain-api/src/main/java/ai/chat2db/server/domain/api/model/AIConfig.java +++ b/chat2db-server/chat2db-server-domain/chat2db-server-domain-api/src/main/java/ai/chat2db/server/domain/api/model/AIConfig.java @@ -17,6 +17,11 @@ public class AIConfig { */ private String apiKey = ""; + /** + * SECRETKEY + */ + private String secretKey = ""; + /** * APIHOST */ diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java index 1de6d595..59d88067 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java @@ -17,6 +17,8 @@ import ai.chat2db.server.web.api.controller.ai.azure.client.AzureOpenAIClient; import ai.chat2db.server.web.api.controller.ai.azure.listener.AzureOpenAIEventSourceListener; import ai.chat2db.server.web.api.controller.ai.azure.model.AzureChatMessage; import ai.chat2db.server.web.api.controller.ai.azure.model.AzureChatRole; +import ai.chat2db.server.web.api.controller.ai.baichuan.client.BaichuanAIClient; +import ai.chat2db.server.web.api.controller.ai.baichuan.listener.BaichuanChatAIEventSourceListener; import ai.chat2db.server.web.api.controller.ai.chat2db.client.Chat2dbAIClient; import ai.chat2db.server.web.api.controller.ai.claude.client.ClaudeAIClient; import ai.chat2db.server.web.api.controller.ai.claude.listener.ClaudeAIEventSourceListener; @@ -361,16 +363,7 @@ public class ChatController { */ private SseEmitter chatWithFastChatAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException { String prompt = buildPrompt(queryRequest); - List messages = (List)LocalCache.CACHE.get(uid); - if (CollectionUtils.isNotEmpty(messages)) { - if (messages.size() >= contextLength) { - messages = messages.subList(1, contextLength); - } - } else { - messages = Lists.newArrayList(); - } - FastChatMessage currentMessage = new FastChatMessage(FastChatRole.USER).setContent(prompt); - messages.add(currentMessage); + List messages = getFastChatMessage(uid, prompt); buildSseEmitter(sseEmitter, uid); @@ -381,7 +374,7 @@ public class ChatController { } /** - * chat with fast chat openai + * chat with baichuan chat openai * * @param queryRequest * @param sseEmitter @@ -389,8 +382,26 @@ public class ChatController { * @return * @throws IOException */ - private SseEmitter chatWithWenxinAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException { + private SseEmitter chatWithBaichuanAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException { String prompt = buildPrompt(queryRequest); + List messages = getFastChatMessage(uid, prompt); + + buildSseEmitter(sseEmitter, uid); + + BaichuanChatAIEventSourceListener sourceListener = new BaichuanChatAIEventSourceListener(sseEmitter); + BaichuanAIClient.getInstance().streamCompletions(messages, sourceListener); + LocalCache.CACHE.put(uid, messages, LocalCache.TIMEOUT); + return sseEmitter; + } + + /** + * get fast chat message + * + * @param uid + * @param prompt + * @return + */ + private List getFastChatMessage(String uid, String prompt) { List messages = (List)LocalCache.CACHE.get(uid); if (CollectionUtils.isNotEmpty(messages)) { if (messages.size() >= contextLength) { @@ -401,6 +412,21 @@ public class ChatController { } FastChatMessage currentMessage = new FastChatMessage(FastChatRole.USER).setContent(prompt); messages.add(currentMessage); + return messages; + } + + /** + * chat with wenxin chat openai + * + * @param queryRequest + * @param sseEmitter + * @param uid + * @return + * @throws IOException + */ + private SseEmitter chatWithWenxinAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException { + String prompt = buildPrompt(queryRequest); + List messages = getFastChatMessage(uid, prompt); buildSseEmitter(sseEmitter, uid); diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/baichuan/client/BaichuanAIClient.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/baichuan/client/BaichuanAIClient.java new file mode 100644 index 00000000..ed7a3a18 --- /dev/null +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/baichuan/client/BaichuanAIClient.java @@ -0,0 +1,90 @@ + +package ai.chat2db.server.web.api.controller.ai.baichuan.client; + +import ai.chat2db.server.domain.api.model.Config; +import ai.chat2db.server.domain.api.service.ConfigService; +import ai.chat2db.server.web.api.util.ApplicationContextUtil; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; + +/** + * @author moji + * @date 23/09/26 + */ +@Slf4j +public class BaichuanAIClient { + + /** + * BAICHUAN OPENAI KEY + */ + public static final String BAICHUAN_API_KEY = "baichuan.chatgpt.apiKey"; + + /** + * BAICHUAN OPENAI SECRET KEY + */ + public static final String BAICHUAN_SECRET_KEY = "baichuan.chatgpt.secretKey"; + + /** + * BAICHUAN OPENAI HOST + */ + public static final String BAICHUAN_HOST = "baichuan.host"; + + /** + * BAICHUAN OPENAI model + */ + public static final String BAICHUAN_MODEL= "baichuan.model"; + + /** + * BAICHUAN OPENAI embedding model + */ + public static final String BAICHUAN_EMBEDDING_MODEL = "baichuan.embedding.model"; + + private static BaichuanAIStreamClient BAICHUAN_AI_CLIENT; + + + public static BaichuanAIStreamClient getInstance() { + if (BAICHUAN_AI_CLIENT != null) { + return BAICHUAN_AI_CLIENT; + } else { + return singleton(); + } + } + + private static BaichuanAIStreamClient singleton() { + if (BAICHUAN_AI_CLIENT == null) { + synchronized (BaichuanAIClient.class) { + if (BAICHUAN_AI_CLIENT == null) { + refresh(); + } + } + } + return BAICHUAN_AI_CLIENT; + } + + public static void refresh() { + String apiKey = ""; + String apiHost = ""; + String model = ""; + String secretKey = ""; + ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class); + Config apiHostConfig = configService.find(BAICHUAN_HOST).getData(); + if (apiHostConfig != null && StringUtils.isNotBlank(apiHostConfig.getContent())) { + apiHost = apiHostConfig.getContent(); + } + Config config = configService.find(BAICHUAN_API_KEY).getData(); + if (config != null && StringUtils.isNotBlank(config.getContent())) { + apiKey = config.getContent(); + } + Config secretConfig = configService.find(BAICHUAN_SECRET_KEY).getData(); + if (secretConfig != null && StringUtils.isNotBlank(secretConfig.getContent())) { + secretKey = secretConfig.getContent(); + } + Config deployConfig = configService.find(BAICHUAN_MODEL).getData(); + if (deployConfig != null && StringUtils.isNotBlank(deployConfig.getContent())) { + model = deployConfig.getContent(); + } + BAICHUAN_AI_CLIENT = BaichuanAIStreamClient.builder().apiKey(apiKey).apiHost(apiHost).model(model) + .build(); + } + +} diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/baichuan/client/BaichuanAIStreamClient.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/baichuan/client/BaichuanAIStreamClient.java new file mode 100644 index 00000000..cef3b59d --- /dev/null +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/baichuan/client/BaichuanAIStreamClient.java @@ -0,0 +1,214 @@ +package ai.chat2db.server.web.api.controller.ai.baichuan.client; + +import ai.chat2db.server.tools.common.exception.ParamBusinessException; +import ai.chat2db.server.web.api.controller.ai.baichuan.interceptor.BaichuanHeaderAuthorizationInterceptor; +import ai.chat2db.server.web.api.controller.ai.fastchat.interceptor.FastChatHeaderAuthorizationInterceptor; +import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatCompletionsOptions; +import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatMessage; +import cn.hutool.http.ContentType; +import com.fasterxml.jackson.databind.ObjectMapper; +import lombok.Getter; +import lombok.extern.slf4j.Slf4j; +import okhttp3.MediaType; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.RequestBody; +import okhttp3.sse.EventSource; +import okhttp3.sse.EventSourceListener; +import okhttp3.sse.EventSources; +import org.apache.commons.collections4.CollectionUtils; +import org.jetbrains.annotations.NotNull; + +import java.util.List; +import java.util.Objects; +import java.util.concurrent.TimeUnit; + +/** + * Fast Chat Aligned Client + * + * @author moji + */ +@Slf4j +public class BaichuanAIStreamClient { + + /** + * apikey + */ + @Getter + @NotNull + private String apiKey; + + @Getter + @NotNull + private String secretKey; + + /** + * apiHost + */ + @Getter + @NotNull + private String apiHost; + + /** + * model + */ + @Getter + private String model; + + /** + * embeddingModel + */ + @Getter + private String embeddingModel; + + /** + * okHttpClient + */ + @Getter + private OkHttpClient okHttpClient; + + + /** + * @param builder + */ + private BaichuanAIStreamClient(Builder builder) { + this.apiKey = builder.apiKey; + this.apiHost = builder.apiHost; + this.model = builder.model; + this.secretKey = builder.secretKey; + this.embeddingModel = builder.embeddingModel; + if (Objects.isNull(builder.okHttpClient)) { + builder.okHttpClient = this.okHttpClient(); + } + okHttpClient = builder.okHttpClient; + } + + /** + * okhttpclient + */ + private OkHttpClient okHttpClient() { + OkHttpClient okHttpClient = new OkHttpClient + .Builder() + .addInterceptor(new BaichuanHeaderAuthorizationInterceptor(this.apiKey, this.secretKey)) + .connectTimeout(10, TimeUnit.SECONDS) + .writeTimeout(50, TimeUnit.SECONDS) + .readTimeout(50, TimeUnit.SECONDS) + .build(); + return okHttpClient; + } + + /** + * 构造 + * + * @return + */ + public static BaichuanAIStreamClient.Builder builder() { + return new BaichuanAIStreamClient.Builder(); + } + + /** + * builder + */ + public static final class Builder { + private String apiKey; + + private String secretKey; + + private String apiHost; + + private String model; + + private String embeddingModel; + + /** + * OkhttpClient + */ + private OkHttpClient okHttpClient; + + public Builder() { + } + + public BaichuanAIStreamClient.Builder apiKey(String apiKeyValue) { + this.apiKey = apiKeyValue; + return this; + } + + public BaichuanAIStreamClient.Builder secretKey(String secretKey) { + this.secretKey = secretKey; + return this; + } + + /** + * @param apiHostValue + * @return + */ + public BaichuanAIStreamClient.Builder apiHost(String apiHostValue) { + this.apiHost = apiHostValue; + return this; + } + + /** + * @param modelValue + * @return + */ + public BaichuanAIStreamClient.Builder model(String modelValue) { + this.model = modelValue; + return this; + } + + public BaichuanAIStreamClient.Builder embeddingModel(String embeddingModelValue) { + this.embeddingModel = embeddingModelValue; + return this; + } + + public BaichuanAIStreamClient.Builder okHttpClient(OkHttpClient val) { + this.okHttpClient = val; + return this; + } + + public BaichuanAIStreamClient build() { + return new BaichuanAIStreamClient(this); + } + + } + + /** + * 问答接口 stream 形式 + * + * @param chatMessages + * @param eventSourceListener + */ + public void streamCompletions(List chatMessages, EventSourceListener eventSourceListener) { + if (CollectionUtils.isEmpty(chatMessages)) { + log.error("param error:Baichuan Chat Prompt cannot be empty"); + throw new ParamBusinessException("prompt"); + } + if (Objects.isNull(eventSourceListener)) { + log.error("param error:Baichuan ChatEventSourceListener cannot be empty"); + throw new ParamBusinessException(); + } + log.info("Baichuan AI, prompt:{}", chatMessages.get(chatMessages.size() - 1).getContent()); + try { + + FastChatCompletionsOptions chatCompletionsOptions = new FastChatCompletionsOptions(chatMessages); + chatCompletionsOptions.setStream(true); + chatCompletionsOptions.setModel(this.model); + + EventSource.Factory factory = EventSources.createFactory(this.okHttpClient); + ObjectMapper mapper = new ObjectMapper(); + String requestBody = mapper.writeValueAsString(chatCompletionsOptions); + Request request = new Request.Builder() + .url(apiHost) + .post(RequestBody.create(MediaType.parse(ContentType.JSON.getValue()), requestBody)) + .build(); + //创建事件 + EventSource eventSource = factory.newEventSource(request, eventSourceListener); + log.info("finish invoking baichuan ai"); + } catch (Exception e) { + log.error("baichuan ai error", e); + eventSourceListener.onFailure(null, e, null); + throw new ParamBusinessException(); + } + } + +} diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/baichuan/interceptor/BaichuanHeaderAuthorizationInterceptor.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/baichuan/interceptor/BaichuanHeaderAuthorizationInterceptor.java new file mode 100644 index 00000000..01054a72 --- /dev/null +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/baichuan/interceptor/BaichuanHeaderAuthorizationInterceptor.java @@ -0,0 +1,84 @@ +package ai.chat2db.server.web.api.controller.ai.baichuan.interceptor; + +import cn.hutool.http.ContentType; +import cn.hutool.http.Header; +import lombok.Getter; +import lombok.extern.slf4j.Slf4j; +import okhttp3.*; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.util.Base64; + +/** + * header apikey + * + * @author grt + * @since 2023-03-23 + */ +@Slf4j +@Getter +public class BaichuanHeaderAuthorizationInterceptor implements Interceptor { + + private String apiKey; + + private String secretKey; + + public BaichuanHeaderAuthorizationInterceptor(String apiKey, String secretKey) { + this.apiKey = apiKey; + this.secretKey = secretKey; + } + + + @Override + public Response intercept(Chain chain) throws IOException { + Request originalRequest = chain.request(); + + // 获取当前的时间戳(UTC标准时间戳) + long timestamp = System.currentTimeMillis() / 1000; + + // 构造 HTTP-Body,这里需要根据实际情况构造你的请求体 + // 这里示例构造一个空的请求体 + RequestBody requestBody = RequestBody.create("", MediaType.parse("text/plain")); + + // 计算 X-BC-Signature + String signature = calculateSignature(secretKey, requestBody, timestamp); + + // 创建新的请求,并添加自定义请求头 + Request newRequest = originalRequest.newBuilder() + .addHeader(Header.AUTHORIZATION.getValue(), "Bearer " + apiKey) + .addHeader(Header.CONTENT_TYPE.getValue(), ContentType.JSON.getValue()) + .addHeader("X-BC-Sign-Algo", "MD5") + .addHeader("X-BC-Timestamp", String.valueOf(timestamp)) + .addHeader("X-BC-Signature", signature) + .method(originalRequest.method(), originalRequest.body()) + .build(); + + return chain.proceed(newRequest); + } + + private String calculateSignature(String secretKey, RequestBody body, long timestamp) { + try { + String requestBody = bodyToString(body); + String rawSignature = secretKey + requestBody + timestamp; + + // 使用 MD5 计算签名 + MessageDigest md = MessageDigest.getInstance("MD5"); + byte[] mdBytes = md.digest(rawSignature.getBytes(StandardCharsets.UTF_8)); + + // 将 MD5 字节数组转换为 Base64 编码的字符串 + return Base64.getEncoder().encodeToString(mdBytes); + } catch (IOException | NoSuchAlgorithmException e) { + log.error("baichuan secret key md5 error", e); + return ""; + } + } + + private String bodyToString(RequestBody body) throws IOException { + // 将 RequestBody 转换为字符串 + return body == null ? "" : body.toString(); + } +} + diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/baichuan/listener/BaichuanChatAIEventSourceListener.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/baichuan/listener/BaichuanChatAIEventSourceListener.java new file mode 100644 index 00000000..d9e8e3fd --- /dev/null +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/baichuan/listener/BaichuanChatAIEventSourceListener.java @@ -0,0 +1,136 @@ +package ai.chat2db.server.web.api.controller.ai.baichuan.listener; + +import ai.chat2db.server.web.api.controller.ai.baichuan.model.BaichuanChatCompletions; +import ai.chat2db.server.web.api.controller.ai.baichuan.model.BaichuanChatMessage; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.unfbx.chatgpt.entity.chat.Message; +import lombok.SneakyThrows; +import lombok.extern.slf4j.Slf4j; +import okhttp3.Response; +import okhttp3.ResponseBody; +import okhttp3.sse.EventSource; +import okhttp3.sse.EventSourceListener; +import org.apache.commons.lang3.StringUtils; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; + +import java.io.IOException; +import java.util.Objects; + +/** + * 描述:OpenAIEventSourceListener + * + * @author https:www.unfbx.com + * @date 2023-02-22 + */ +@Slf4j +public class BaichuanChatAIEventSourceListener extends EventSourceListener { + + private SseEmitter sseEmitter; + + private ObjectMapper mapper = new ObjectMapper().disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES); + + public BaichuanChatAIEventSourceListener(SseEmitter sseEmitter) { + this.sseEmitter = sseEmitter; + } + + /** + * {@inheritDoc} + */ + @Override + public void onOpen(EventSource eventSource, Response response) { + log.info("Baichuan Chat Sse connecting..."); + } + + /** + * {@inheritDoc} + */ + @SneakyThrows + @Override + public void onEvent(EventSource eventSource, String id, String type, String data) { + log.info("Baichuan Chat AI response data:{}", data); + if (data.equals("[DONE]")) { + log.info("Baichuan Chat AI closed"); + sseEmitter.send(SseEmitter.event() + .id("[DONE]") + .data("[DONE]") + .reconnectTime(3000)); + sseEmitter.complete(); + return; + } + + BaichuanChatCompletions chatCompletions = mapper.readValue(data, BaichuanChatCompletions.class); + String text = ""; + log.info("code={} msg={}", chatCompletions.getCode(), chatCompletions.getMsg()); + for (BaichuanChatMessage message : chatCompletions.getData().getMessages()) { + if (message != null) { + log.info("message: {}, Chat Role: {}", message.getContent(), message.getRole()); + if (message.getContent() != null) { + text = message.getContent(); + } + } + } + + Message message = new Message(); + message.setContent(text); + sseEmitter.send(SseEmitter.event() + .id(null) + .data(message) + .reconnectTime(3000)); + } + + @Override + public void onClosed(EventSource eventSource) { + try { + sseEmitter.send(SseEmitter.event() + .id("[DONE]") + .data("[DONE]")); + } catch (IOException e) { + throw new RuntimeException(e); + } + sseEmitter.complete(); + log.info("FastChatAI close sse connection..."); + } + + @Override + public void onFailure(EventSource eventSource, Throwable t, Response response) { + try { + if (Objects.isNull(response)) { + String message = t.getMessage(); + Message sseMessage = new Message(); + sseMessage.setContent(message); + sseEmitter.send(SseEmitter.event() + .id("[ERROR]") + .data(sseMessage)); + sseEmitter.send(SseEmitter.event() + .id("[DONE]") + .data("[DONE]")); + sseEmitter.complete(); + return; + } + ResponseBody body = response.body(); + String bodyString = Objects.nonNull(t) ? t.getMessage() : ""; + if (Objects.nonNull(body)) { + bodyString = body.string(); + if (StringUtils.isBlank(bodyString) && Objects.nonNull(t)) { + bodyString = t.getMessage(); + } + log.error("Baichuan Chat AI sse response:{}", bodyString); + } else { + log.error("Baichuan Chat AI sse response:{},error:{}", response, t); + } + eventSource.cancel(); + Message message = new Message(); + message.setContent("Baichuan Chat AI error:" + bodyString); + sseEmitter.send(SseEmitter.event() + .id("[ERROR]") + .data(message)); + sseEmitter.send(SseEmitter.event() + .id("[DONE]") + .data("[DONE]")); + sseEmitter.complete(); + } catch (Exception exception) { + log.error("Baichuan Chat AI send data error:", exception); + } + } +} diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/baichuan/model/BaichuanChatCompletions.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/baichuan/model/BaichuanChatCompletions.java new file mode 100644 index 00000000..f11a1b25 --- /dev/null +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/baichuan/model/BaichuanChatCompletions.java @@ -0,0 +1,52 @@ +package ai.chat2db.server.web.api.controller.ai.baichuan.model; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.Data; + +import java.util.List; + +@Data +public class BaichuanChatCompletions { + + /* + * A unique identifier associated with this chat completions response. + */ + private String msg; + + private int code; + + /* + * The collection of completions choices associated with this completions response. + * Generally, `n` choices are generated per provided prompt with a default value of 1. + * Token limits and other settings may limit the number of choices generated. + */ + @JsonProperty(value = "data") + private BaichuanChatData data; + + /* + * Usage information for tokens processed and generated as part of this completions operation. + */ + private BaichuanChatCompletionsUsage usage; + + /** + * Creates an instance of ChatCompletions class. + * + * @param msg the id value to set. + * @param code the created value to set. + * @param choices the choices value to set. + * @param usage the usage value to set. + */ + @JsonCreator + private BaichuanChatCompletions( + @JsonProperty(value = "msg") String msg, + @JsonProperty(value = "code") int code, + @JsonProperty(value = "data") BaichuanChatData choices, + @JsonProperty(value = "usage") BaichuanChatCompletionsUsage usage) { + this.msg = msg; + this.code = code; + this.data = choices; + this.usage = usage; + } + +} diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/baichuan/model/BaichuanChatCompletionsUsage.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/baichuan/model/BaichuanChatCompletionsUsage.java new file mode 100644 index 00000000..37fc1ec6 --- /dev/null +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/baichuan/model/BaichuanChatCompletionsUsage.java @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// Code generated by Microsoft (R) AutoRest Code Generator. +package ai.chat2db.server.web.api.controller.ai.baichuan.model; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.Data; +import lombok.NoArgsConstructor; + +/** + * Representation of the token counts processed for a completions request. Counts consider all tokens across prompts, + * choices, choice alternates, best_of generations, and other consumers. + */ +@Data +@NoArgsConstructor +public final class BaichuanChatCompletionsUsage { + + /* + * The number of tokens generated across all completions emissions. + */ + @JsonProperty(value = "answer_tokens") + private int answerTokens; + + /* + * The number of tokens in the provided prompts for the completions request. + */ + @JsonProperty(value = "prompt_tokens") + private int promptTokens; + + /* + * The total number of tokens processed for the completions request and response. + */ + @JsonProperty(value = "total_tokens") + private int totalTokens; + + /** + * Creates an instance of CompletionsUsage class. + * + * @param completionTokens the completionTokens value to set. + * @param promptTokens the promptTokens value to set. + * @param totalTokens the totalTokens value to set. + */ + @JsonCreator + private BaichuanChatCompletionsUsage( + @JsonProperty(value = "answer_tokens") int completionTokens, + @JsonProperty(value = "prompt_tokens") int promptTokens, + @JsonProperty(value = "total_tokens") int totalTokens) { + this.answerTokens = completionTokens; + this.promptTokens = promptTokens; + this.totalTokens = totalTokens; + } +} diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/baichuan/model/BaichuanChatData.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/baichuan/model/BaichuanChatData.java new file mode 100644 index 00000000..04b0aa5d --- /dev/null +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/baichuan/model/BaichuanChatData.java @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// Code generated by Microsoft (R) AutoRest Code Generator. +package ai.chat2db.server.web.api.controller.ai.baichuan.model; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.Data; + +import java.util.List; + +/** + * The representation of a single prompt completion as part of an overall completions request. Generally, `n` choices + * are generated per provided prompt with a default value of 1. Token limits and other settings may limit the number of + * choices generated. + */ +@Data +public final class BaichuanChatData { + + /* + * The log probabilities model for tokens associated with this completions choice. + */ + @JsonProperty(value = "messages") + private List messages; + + + + /** + * Creates an instance of Choice class. + * + * @param message the message value to set + */ + @JsonCreator + private BaichuanChatData( + @JsonProperty(value = "messages") List message) { + this.messages = message; + } + +} diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/baichuan/model/BaichuanChatMessage.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/baichuan/model/BaichuanChatMessage.java new file mode 100644 index 00000000..6dbde185 --- /dev/null +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/baichuan/model/BaichuanChatMessage.java @@ -0,0 +1,30 @@ +package ai.chat2db.server.web.api.controller.ai.baichuan.model; + +import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatRole; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.Data; + +@Data +public class BaichuanChatMessage { + + + /* + * The role associated with this message payload. + */ + @JsonProperty(value = "role") + private FastChatRole role; + + /* + * The text associated with this message payload. + */ + @JsonProperty(value = "content") + private String content; + + /* + * Reason for finishing + */ + @JsonProperty(value = "finish_reason") + private String finishReason; + +} diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/config/ConfigController.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/config/ConfigController.java index ca92e964..78958d17 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/config/ConfigController.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/config/ConfigController.java @@ -13,6 +13,7 @@ import ai.chat2db.server.tools.base.wrapper.result.ActionResult; import ai.chat2db.server.tools.base.wrapper.result.DataResult; import ai.chat2db.server.web.api.aspect.ConnectionInfoAspect; import ai.chat2db.server.web.api.controller.ai.azure.client.AzureOpenAIClient; +import ai.chat2db.server.web.api.controller.ai.baichuan.client.BaichuanAIClient; import ai.chat2db.server.web.api.controller.ai.chat2db.client.Chat2dbAIClient; import ai.chat2db.server.web.api.controller.ai.fastchat.client.FastChatAIClient; import ai.chat2db.server.web.api.controller.ai.rest.client.RestAIClient; @@ -91,6 +92,8 @@ public class ConfigController { case WENXINAI: saveWenxinAIConfig(request); break; + case BAICHUANAI: + saveBaichuanAIConfig(request); } return ActionResult.isSuccess(); } @@ -186,7 +189,7 @@ public class ConfigController { } /** - * save common fast chat ai config + * save common wenxin chat ai config * * @param request */ @@ -200,6 +203,27 @@ public class ConfigController { WenxinAIClient.refresh(); } + /** + * save common fast chat ai config + * + * @param request + */ + private void saveBaichuanAIConfig(AIConfigCreateRequest request) { + SystemConfigParam apikeyParam = SystemConfigParam.builder().code(BaichuanAIClient.BAICHUAN_API_KEY) + .content(request.getApiKey()).build(); + configService.createOrUpdate(apikeyParam); + SystemConfigParam secretKeyParam = SystemConfigParam.builder().code(BaichuanAIClient.BAICHUAN_SECRET_KEY) + .content(request.getSecretKey()).build(); + configService.createOrUpdate(secretKeyParam); + SystemConfigParam apiHostParam = SystemConfigParam.builder().code(BaichuanAIClient.BAICHUAN_HOST) + .content(request.getApiHost()).build(); + configService.createOrUpdate(apiHostParam); + SystemConfigParam modelParam = SystemConfigParam.builder().code(BaichuanAIClient.BAICHUAN_MODEL) + .content(request.getModel()).build(); + configService.createOrUpdate(modelParam); + FastChatAIClient.refresh(); + } + @GetMapping("/system_config/{code}") public DataResult getSystemConfig(@PathVariable("code") String code) { DataResult result = configService.find(code); @@ -276,6 +300,16 @@ public class ConfigController { config.setApiKey(Objects.nonNull(wenxinAccessToken.getData()) ? wenxinAccessToken.getData().getContent() : ""); config.setApiHost(Objects.nonNull(wenxinApiHost.getData()) ? wenxinApiHost.getData().getContent() : ""); break; + case BAICHUANAI: + DataResult baichuanApiKey = configService.find(BaichuanAIClient.BAICHUAN_API_KEY); + DataResult baichuanSecretKey = configService.find(BaichuanAIClient.BAICHUAN_SECRET_KEY); + DataResult baichuanApiHost = configService.find(BaichuanAIClient.BAICHUAN_HOST); + DataResult baichuanModel = configService.find(BaichuanAIClient.BAICHUAN_MODEL); + config.setApiKey(Objects.nonNull(baichuanApiKey.getData()) ? baichuanApiKey.getData().getContent() : ""); + config.setSecretKey(Objects.nonNull(baichuanSecretKey.getData()) ? baichuanSecretKey.getData().getContent() : ""); + config.setApiHost(Objects.nonNull(baichuanApiHost.getData()) ? baichuanApiHost.getData().getContent() : ""); + config.setModel(Objects.nonNull(baichuanModel.getData()) ? baichuanModel.getData().getContent() : ""); + break; default: break; } diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/config/request/AIConfigCreateRequest.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/config/request/AIConfigCreateRequest.java index daa6b3e4..188a9b43 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/config/request/AIConfigCreateRequest.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/config/request/AIConfigCreateRequest.java @@ -17,6 +17,11 @@ public class AIConfigCreateRequest { */ private String apiKey; + /** + * SECRETKEY + */ + private String secretKey; + /** * APIHOST */