diff --git a/chat2db-client/src/components/SearchResult/TableBox/index.tsx b/chat2db-client/src/components/SearchResult/TableBox/index.tsx
index f893eec3..7109f55e 100644
--- a/chat2db-client/src/components/SearchResult/TableBox/index.tsx
+++ b/chat2db-client/src/components/SearchResult/TableBox/index.tsx
@@ -234,7 +234,7 @@ export default function TableBox(props: ITableProps) {
return (
<>
-
{bottomStatus}
+ {bottomStatus}
>
);
} else {
diff --git a/chat2db-client/src/hooks/useTheme.ts b/chat2db-client/src/hooks/useTheme.ts
index 50e441a1..9993a757 100644
--- a/chat2db-client/src/hooks/useTheme.ts
+++ b/chat2db-client/src/hooks/useTheme.ts
@@ -6,13 +6,25 @@ import { ThemeType, PrimaryColorType } from '@/constants';
import { getPrimaryColor, getTheme, setPrimaryColor, setTheme } from '@/utils/localStorage';
const initialTheme = () => {
- let backgroundColor = getTheme() || ThemeType.Dark;
+ const localStorageTheme = getTheme();
+ const localStoragePrimaryColor = getPrimaryColor();
- let primaryColor = getPrimaryColor() || PrimaryColorType.Golden_Purple;
+ // 判断localStorage的theme在不在ThemeType中, 如果存在就用localStorageTheme
+ let backgroundColor = ThemeType.Light
+ if (Object.values(ThemeType).includes(localStorageTheme)) {
+ backgroundColor = localStorageTheme;
+ }
+
+ let primaryColor = PrimaryColorType.Golden_Purple
+ if (Object.values(PrimaryColorType).includes(localStoragePrimaryColor)) {
+ primaryColor = localStoragePrimaryColor;
+ }
if (backgroundColor === ThemeType.FollowOs) {
backgroundColor = getOsTheme();
}
+ document.documentElement.setAttribute('theme', backgroundColor);
+ document.documentElement.setAttribute('primary-color', primaryColor);
return {
backgroundColor,
primaryColor,
diff --git a/chat2db-client/src/layouts/index.tsx b/chat2db-client/src/layouts/index.tsx
index a22f0097..ea8ea78a 100644
--- a/chat2db-client/src/layouts/index.tsx
+++ b/chat2db-client/src/layouts/index.tsx
@@ -98,7 +98,6 @@ function AppContainer() {
// 初始化app
function collectInitApp() {
monitorOsTheme();
- initTheme();
initLang();
setInitEnd(true);
}
@@ -119,19 +118,6 @@ function AppContainer() {
};
}
- // 初始化主题
- function initTheme() {
- let theme = getTheme();
- if (theme === ThemeType.FollowOs) {
- theme =
- (window.matchMedia && window.matchMedia('(prefers-color-scheme: dark)').matches
- ? ThemeType.Dark
- : ThemeType.Light) || ThemeType.Dark;
- }
- document.documentElement.setAttribute('theme', theme);
- document.documentElement.setAttribute('primary-color', getPrimaryColor());
- }
-
// 初始化语言
function initLang() {
if (!getLang()) {
diff --git a/chat2db-client/src/utils/localStorage.ts b/chat2db-client/src/utils/localStorage.ts
index 27979e75..3f449970 100644
--- a/chat2db-client/src/utils/localStorage.ts
+++ b/chat2db-client/src/utils/localStorage.ts
@@ -10,8 +10,8 @@ export function setLang(lang: LangType) {
}
export function getTheme(): ThemeType {
- const themeColor:any = localStorage.getItem('theme') as ThemeType
- if(themeColor){
+ const themeColor: any = localStorage.getItem('theme') as ThemeType
+ if (themeColor) {
return themeColor
}
localStorage.setItem('theme', ThemeType.Light)
@@ -25,7 +25,7 @@ export function setTheme(theme: ThemeType) {
export function getPrimaryColor(): PrimaryColorType {
const primaryColor = localStorage.getItem('primary-color') as PrimaryColorType
- if(primaryColor){
+ if (primaryColor) {
return primaryColor
}
localStorage.setItem('primary-color', PrimaryColorType.Golden_Purple)
diff --git a/chat2db-server/chat2db-server-domain/chat2db-server-domain-api/src/main/java/ai/chat2db/server/domain/api/enums/AiSqlSourceEnum.java b/chat2db-server/chat2db-server-domain/chat2db-server-domain-api/src/main/java/ai/chat2db/server/domain/api/enums/AiSqlSourceEnum.java
index d475f83e..522d5764 100644
--- a/chat2db-server/chat2db-server-domain/chat2db-server-domain-api/src/main/java/ai/chat2db/server/domain/api/enums/AiSqlSourceEnum.java
+++ b/chat2db-server/chat2db-server-domain/chat2db-server-domain-api/src/main/java/ai/chat2db/server/domain/api/enums/AiSqlSourceEnum.java
@@ -32,6 +32,11 @@ public enum AiSqlSourceEnum implements BaseEnum {
*/
CHAT2DBAI("CHAT2DB OPENAI"),
+ /**
+ * CLAUDE AI
+ */
+ CLAUDEAI("CLAUDE AI"),
+
;
final String description;
diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java
index 012f561f..634ac703 100644
--- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java
+++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java
@@ -24,10 +24,14 @@ import ai.chat2db.server.web.api.controller.ai.azure.client.AzureOpenAIClient;
import ai.chat2db.server.web.api.controller.ai.azure.models.AzureChatMessage;
import ai.chat2db.server.web.api.controller.ai.azure.models.AzureChatRole;
import ai.chat2db.server.web.api.controller.ai.chat2db.client.Chat2dbAIClient;
+import ai.chat2db.server.web.api.controller.ai.claude.client.ClaudeAIClient;
+import ai.chat2db.server.web.api.controller.ai.claude.model.ClaudeChatCompletionsOptions;
+import ai.chat2db.server.web.api.controller.ai.claude.model.ClaudeChatMessage;
import ai.chat2db.server.web.api.controller.ai.config.LocalCache;
import ai.chat2db.server.web.api.controller.ai.converter.ChatConverter;
import ai.chat2db.server.web.api.controller.ai.enums.PromptType;
import ai.chat2db.server.web.api.controller.ai.listener.AzureOpenAIEventSourceListener;
+import ai.chat2db.server.web.api.controller.ai.listener.ClaudeAIEventSourceListener;
import ai.chat2db.server.web.api.controller.ai.listener.OpenAIEventSourceListener;
import ai.chat2db.server.web.api.controller.ai.listener.RestAIEventSourceListener;
import ai.chat2db.server.web.api.controller.ai.request.ChatQueryRequest;
@@ -212,6 +216,8 @@ public class ChatController {
return chatWithRestAi(queryRequest, sseEmitter);
case AZUREAI :
return chatWithAzureAi(queryRequest, sseEmitter, uid);
+ case CLAUDEAI:
+ return chatWithClaudeAi(queryRequest, sseEmitter, uid);
}
return chatWithOpenAi(queryRequest, sseEmitter, uid);
}
@@ -326,6 +332,31 @@ public class ChatController {
return sseEmitter;
}
+
+ /**
+ * chat with claude ai
+ *
+ * @param queryRequest
+ * @param sseEmitter
+ * @param uid
+ * @return
+ * @throws IOException
+ */
+ private SseEmitter chatWithClaudeAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException {
+ String prompt = buildPrompt(queryRequest);
+ ClaudeChatMessage claudeChatMessage = new ClaudeChatMessage();
+ claudeChatMessage.setText(prompt);
+ ClaudeChatCompletionsOptions chatCompletionsOptions = new ClaudeChatCompletionsOptions();
+ chatCompletionsOptions.setPrompt(prompt);
+ claudeChatMessage.setCompletion(chatCompletionsOptions);
+
+ buildSseEmitter(sseEmitter, uid);
+
+ ClaudeAIEventSourceListener sourceListener = new ClaudeAIEventSourceListener(sseEmitter);
+ ClaudeAIClient.getInstance().streamCompletions(claudeChatMessage, sourceListener);
+ return sseEmitter;
+ }
+
/**
* construct sseEmitter
*
diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/client/ClaudeAIClient.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/client/ClaudeAIClient.java
new file mode 100644
index 00000000..77dd97f2
--- /dev/null
+++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/client/ClaudeAIClient.java
@@ -0,0 +1,87 @@
+
+package ai.chat2db.server.web.api.controller.ai.claude.client;
+
+import ai.chat2db.server.domain.api.model.Config;
+import ai.chat2db.server.domain.api.service.ConfigService;
+import ai.chat2db.server.web.api.util.ApplicationContextUtil;
+import lombok.extern.slf4j.Slf4j;
+import org.apache.commons.lang3.StringUtils;
+
+/**
+ * @author jipengfei
+ * @version : OpenAIClient.java
+ */
+@Slf4j
+public class ClaudeAIClient {
+
+ public static final String CLAUDE_SESSION_KEY = "claude.sessionKey";
+
+ public static final String CLAUDE_API_HOST = "claude.apiHost";
+
+ public static final String CLAUDE_ORG_ID = "claude.orgId";
+
+ public static final String CLAUDE_USER_ID = "claude.userId";
+
+
+ private static ClaudeAiStreamClient CLAUDE_AI_STREAM_CLIENT;
+ private static String apiKey;
+
+ public static ClaudeAiStreamClient getInstance() {
+ if (CLAUDE_AI_STREAM_CLIENT != null) {
+ return CLAUDE_AI_STREAM_CLIENT;
+ } else {
+ return singleton();
+ }
+ }
+
+ private static ClaudeAiStreamClient singleton() {
+ if (CLAUDE_AI_STREAM_CLIENT == null) {
+ synchronized (ClaudeAIClient.class) {
+ if (CLAUDE_AI_STREAM_CLIENT == null) {
+ refresh();
+ }
+ }
+ }
+ return CLAUDE_AI_STREAM_CLIENT;
+ }
+
+ public static void refresh() {
+ String apikey = "";
+ String orgId = "";
+ String userId = "";
+ String apiHost = "https://claude.ai/api/append_message";
+ ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
+ Config apiHostConfig = configService.find(CLAUDE_API_HOST).getData();
+ if (apiHostConfig != null) {
+ apiHost = apiHostConfig.getContent();
+ }
+ Config config = configService.find(CLAUDE_SESSION_KEY).getData();
+ if (config != null) {
+ apikey = config.getContent();
+ }
+ Config orgConfig = configService.find(CLAUDE_ORG_ID).getData();
+ if (orgConfig != null) {
+ orgId = orgConfig.getContent();
+ }
+ Config userConfig = configService.find(CLAUDE_USER_ID).getData();
+ if (userConfig != null) {
+ userId = userConfig.getContent();
+ }
+ log.info("refresh claude sessionKey:{}", maskApiKey(apikey));
+ CLAUDE_AI_STREAM_CLIENT = ClaudeAiStreamClient.builder().apiHost(apiHost)
+ .sessionKey(apikey).orgId(orgId).userId(userId).build();
+ apiKey = apikey;
+ }
+
+ private static String maskApiKey(String input) {
+ if (StringUtils.isBlank(input)) {
+ return input;
+ }
+
+ StringBuilder maskedString = new StringBuilder(input);
+ for (int i = input.length() / 4; i < input.length() / 2; i++) {
+ maskedString.setCharAt(i, '*');
+ }
+ return maskedString.toString();
+ }
+}
diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/client/ClaudeAiStreamClient.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/client/ClaudeAiStreamClient.java
new file mode 100644
index 00000000..d3607902
--- /dev/null
+++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/client/ClaudeAiStreamClient.java
@@ -0,0 +1,186 @@
+package ai.chat2db.server.web.api.controller.ai.claude.client;
+
+import ai.chat2db.server.tools.common.exception.ParamBusinessException;
+import ai.chat2db.server.web.api.controller.ai.claude.interceptor.ClaudeHeaderAuthorizationInterceptor;
+import ai.chat2db.server.web.api.controller.ai.claude.model.ClaudeChatMessage;
+import cn.hutool.http.ContentType;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import lombok.Getter;
+import lombok.extern.slf4j.Slf4j;
+import okhttp3.MediaType;
+import okhttp3.OkHttpClient;
+import okhttp3.Request;
+import okhttp3.RequestBody;
+import okhttp3.sse.EventSource;
+import okhttp3.sse.EventSourceListener;
+import okhttp3.sse.EventSources;
+import org.jetbrains.annotations.NotNull;
+
+import java.util.Objects;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * 自定义AI接口client
+ *
+ * @author moji
+ */
+@Slf4j
+public class ClaudeAiStreamClient {
+
+ /**
+ * apikey
+ */
+ @Getter
+ @NotNull
+ private String sessionKey;
+
+ /**
+ * endpoint
+ */
+ @Getter
+ @NotNull
+ private String orgId;
+
+ /**
+ * deployId
+ */
+ @Getter
+ private String apiHost;
+
+ @Getter
+ private String userId;
+
+ /**
+ * okHttpClient
+ */
+ @Getter
+ private OkHttpClient okHttpClient;
+
+
+ /**
+ * @param builder
+ */
+ private ClaudeAiStreamClient(Builder builder) {
+ this.sessionKey = builder.sessionKey;
+ this.orgId = builder.orgId;
+ this.apiHost = builder.apiHost;
+ this.userId = builder.userId;
+ if (Objects.isNull(builder.okHttpClient)) {
+ builder.okHttpClient = this.okHttpClient();
+ }
+ okHttpClient = builder.okHttpClient;
+ }
+
+ /**
+ * okhttpclient
+ */
+ private OkHttpClient okHttpClient() {
+ OkHttpClient okHttpClient = new OkHttpClient
+ .Builder()
+ .addInterceptor(new ClaudeHeaderAuthorizationInterceptor(this.sessionKey, this.orgId))
+ .connectTimeout(10, TimeUnit.SECONDS)
+ .writeTimeout(50, TimeUnit.SECONDS)
+ .readTimeout(50, TimeUnit.SECONDS)
+ .build();
+ return okHttpClient;
+ }
+
+ /**
+ * 构造
+ *
+ * @return
+ */
+ public static ClaudeAiStreamClient.Builder builder() {
+ return new ClaudeAiStreamClient.Builder();
+ }
+
+ public static final class Builder {
+ private String sessionKey;
+
+ private String orgId;
+
+ private String apiHost;
+
+ private String userId;
+
+ /**
+ * 自定义OkhttpClient
+ */
+ private OkHttpClient okHttpClient;
+
+ public Builder() {
+ }
+
+ public ClaudeAiStreamClient.Builder sessionKey(String sessionKey) {
+ this.sessionKey = sessionKey;
+ return this;
+ }
+
+ /**
+ * @param apiHost
+ * @return
+ */
+ public ClaudeAiStreamClient.Builder apiHost(String apiHost) {
+ this.apiHost = apiHost;
+ return this;
+ }
+
+ /**
+ * @param orgId
+ * @return
+ */
+ public ClaudeAiStreamClient.Builder orgId(String orgId) {
+ this.orgId = orgId;
+ return this;
+ }
+
+ public ClaudeAiStreamClient.Builder userId(String userId) {
+ this.userId = userId;
+ return this;
+ }
+
+ public ClaudeAiStreamClient.Builder okHttpClient(OkHttpClient val) {
+ this.okHttpClient = val;
+ return this;
+ }
+
+ public ClaudeAiStreamClient build() {
+ return new ClaudeAiStreamClient(this);
+ }
+
+ }
+
+ /**
+ * chat
+ *
+ * @param claudeChatMessage
+ * @param eventSourceListener
+ */
+ public void streamCompletions(ClaudeChatMessage claudeChatMessage, EventSourceListener eventSourceListener) {
+ if (Objects.isNull(eventSourceListener)) {
+ log.error("param error:AzureEventSourceListener cannot be empty");
+ throw new ParamBusinessException();
+ }
+ log.info("Claude AI, prompt:{}", claudeChatMessage.getText());
+ try {
+ claudeChatMessage.setOrganization_uuid(this.orgId);
+ claudeChatMessage.setConversation_uuid(this.userId);
+ EventSource.Factory factory = EventSources.createFactory(this.okHttpClient);
+ ObjectMapper mapper = new ObjectMapper();
+ String requestBody = mapper.writeValueAsString(claudeChatMessage);
+
+ Request request = new Request.Builder()
+ .url(this.apiHost)
+ .post(RequestBody.create(MediaType.parse(ContentType.JSON.getValue()), requestBody))
+ .build();
+ //创建事件
+ EventSource eventSource = factory.newEventSource(request, eventSourceListener);
+ log.info("finish invoking claude ai");
+ } catch (Exception e) {
+ log.error("claude ai error", e);
+ eventSourceListener.onFailure(null, e, null);
+ throw new ParamBusinessException();
+ }
+ }
+
+}
diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/interceptor/ClaudeHeaderAuthorizationInterceptor.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/interceptor/ClaudeHeaderAuthorizationInterceptor.java
new file mode 100644
index 00000000..351bb6c5
--- /dev/null
+++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/interceptor/ClaudeHeaderAuthorizationInterceptor.java
@@ -0,0 +1,40 @@
+package ai.chat2db.server.web.api.controller.ai.claude.interceptor;
+
+import cn.hutool.http.ContentType;
+import cn.hutool.http.Header;
+import lombok.Getter;
+import okhttp3.Interceptor;
+import okhttp3.Request;
+import okhttp3.Response;
+
+import java.io.IOException;
+
+/**
+ * 描述:请求增加header apikey
+ *
+ * @author grt
+ * @since 2023-03-23
+ */
+@Getter
+public class ClaudeHeaderAuthorizationInterceptor implements Interceptor {
+
+ private String sessionKey;
+
+ private String orgId;
+
+ public ClaudeHeaderAuthorizationInterceptor(String sessionKey, String orgId) {
+ this.orgId = orgId;
+ this.sessionKey = sessionKey;
+ }
+
+ @Override
+ public Response intercept(Chain chain) throws IOException {
+ Request original = chain.request();
+ Request request = original.newBuilder()
+ .header("Cookie", "sessionKey=" + sessionKey)
+ .header(Header.CONTENT_TYPE.getValue(), ContentType.JSON.getValue())
+ .method(original.method(), original.body())
+ .build();
+ return chain.proceed(request);
+ }
+}
diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/model/ClaudeChatCompletionsOptions.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/model/ClaudeChatCompletionsOptions.java
new file mode 100644
index 00000000..cc796124
--- /dev/null
+++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/model/ClaudeChatCompletionsOptions.java
@@ -0,0 +1,38 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+// Code generated by Microsoft (R) AutoRest Code Generator.
+package ai.chat2db.server.web.api.controller.ai.claude.model;
+
+import lombok.Data;
+
+@Data
+public final class ClaudeChatCompletionsOptions {
+
+ private Boolean incremental = true;
+
+ private String model = "claude-2";
+
+ private String prompt;
+
+ private String timezone = "Asia/Shanghai";
+
+ private Boolean stream = true;
+
+ public Boolean isStream() {
+ return this.stream;
+ }
+
+ public ClaudeChatCompletionsOptions setStream(Boolean stream) {
+ this.stream = stream;
+ return this;
+ }
+
+ public String getModel() {
+ return this.model;
+ }
+
+ public ClaudeChatCompletionsOptions setModel(String model) {
+ this.model = model;
+ return this;
+ }
+}
diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/model/ClaudeChatMessage.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/model/ClaudeChatMessage.java
new file mode 100644
index 00000000..5b9fe07b
--- /dev/null
+++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/model/ClaudeChatMessage.java
@@ -0,0 +1,15 @@
+package ai.chat2db.server.web.api.controller.ai.claude.model;
+
+import lombok.Data;
+
+@Data
+public class ClaudeChatMessage {
+
+ private String conversation_uuid;
+
+ private String organization_uuid;
+
+ private String text;
+
+ private ClaudeChatCompletionsOptions completion;
+}
diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/model/ClaudeCompletionResponse.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/model/ClaudeCompletionResponse.java
new file mode 100644
index 00000000..627078c9
--- /dev/null
+++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/model/ClaudeCompletionResponse.java
@@ -0,0 +1,25 @@
+
+package ai.chat2db.server.web.api.controller.ai.claude.model;
+
+import com.unfbx.chatgpt.entity.common.Usage;
+import lombok.Data;
+
+import java.io.Serial;
+import java.io.Serializable;
+
+/**
+ * @author moji
+ * @version : ClaudeCompletionResponse.java
+ */
+@Data
+public class ClaudeCompletionResponse implements Serializable {
+ @Serial
+ private static final long serialVersionUID = 4968922211204353592L;
+ private String log_id;
+ private String stop_reason;
+ private String stop;
+ private String model;
+ private String completion;
+ private Usage usage;
+ private ClaudeMessageLimit messageLimit;
+}
\ No newline at end of file
diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/model/ClaudeMessageLimit.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/model/ClaudeMessageLimit.java
new file mode 100644
index 00000000..f6b0c4e8
--- /dev/null
+++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/model/ClaudeMessageLimit.java
@@ -0,0 +1,9 @@
+package ai.chat2db.server.web.api.controller.ai.claude.model;
+
+import lombok.Data;
+
+@Data
+public class ClaudeMessageLimit {
+
+ private String type;
+}
diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/listener/ClaudeAIEventSourceListener.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/listener/ClaudeAIEventSourceListener.java
new file mode 100644
index 00000000..c6401867
--- /dev/null
+++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/listener/ClaudeAIEventSourceListener.java
@@ -0,0 +1,112 @@
+package ai.chat2db.server.web.api.controller.ai.listener;
+
+import ai.chat2db.server.web.api.controller.ai.claude.model.ClaudeCompletionResponse;
+import com.fasterxml.jackson.databind.DeserializationFeature;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.unfbx.chatgpt.entity.chat.Message;
+import lombok.SneakyThrows;
+import lombok.extern.slf4j.Slf4j;
+import okhttp3.Response;
+import okhttp3.ResponseBody;
+import okhttp3.sse.EventSource;
+import okhttp3.sse.EventSourceListener;
+import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
+
+import java.util.Objects;
+
+/**
+ * ClaudeAIEventSourceListener
+ */
+@Slf4j
+public class ClaudeAIEventSourceListener extends EventSourceListener {
+
+ private SseEmitter sseEmitter;
+
+ private ObjectMapper mapper = new ObjectMapper().disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES);
+
+ public ClaudeAIEventSourceListener(SseEmitter sseEmitter) {
+ this.sseEmitter = sseEmitter;
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public void onOpen(EventSource eventSource, Response response) {
+ log.info("ClaudeAIEventSourceListener...");
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @SneakyThrows
+ @Override
+ public void onEvent(EventSource eventSource, String id, String type, String data) {
+ log.info("Claude AI data:{}", data);
+ if (data.equals("[DONE]")) {
+ log.info("Claude AI end");
+ sseEmitter.send(SseEmitter.event()
+ .id("[DONE]")
+ .data("[DONE]")
+ .reconnectTime(3000));
+ sseEmitter.complete();
+ return;
+ }
+ // 读取Json
+ ClaudeCompletionResponse completionResponse = mapper.readValue(data, ClaudeCompletionResponse.class);
+ String text = completionResponse.getCompletion();
+ Message message = new Message();
+ if (text != null) {
+ message.setContent(text);
+ sseEmitter.send(SseEmitter.event()
+ .id(null)
+ .data(message)
+ .reconnectTime(3000));
+ }
+ }
+
+ @Override
+ public void onClosed(EventSource eventSource) {
+ sseEmitter.complete();
+ log.info("Claude AI closed...");
+ }
+
+ @Override
+ public void onFailure(EventSource eventSource, Throwable t, Response response) {
+ try {
+ if (Objects.isNull(response)) {
+ String message = t.getMessage();
+ Message sseMessage = new Message();
+ sseMessage.setContent(message);
+ sseEmitter.send(SseEmitter.event()
+ .id("[ERROR]")
+ .data(sseMessage));
+ sseEmitter.send(SseEmitter.event()
+ .id("[DONE]")
+ .data("[DONE]"));
+ sseEmitter.complete();
+ return;
+ }
+ ResponseBody body = response.body();
+ String bodyString = null;
+ if (Objects.nonNull(body)) {
+ bodyString = body.string();
+ log.error("Claude sse error:{}", bodyString, t);
+ } else {
+ log.error("Claude sse body error:{}", response, t);
+ }
+ eventSource.cancel();
+ Message message = new Message();
+ message.setContent("Claude sse error:" + bodyString);
+ sseEmitter.send(SseEmitter.event()
+ .id("[ERROR]")
+ .data(message));
+ sseEmitter.send(SseEmitter.event()
+ .id("[DONE]")
+ .data("[DONE]"));
+ sseEmitter.complete();
+ } catch (Exception exception) {
+ log.error("发送数据异常:", exception);
+ }
+ }
+}