diff --git a/chat2db-client/src/components/SearchResult/TableBox/index.tsx b/chat2db-client/src/components/SearchResult/TableBox/index.tsx index f893eec3..7109f55e 100644 --- a/chat2db-client/src/components/SearchResult/TableBox/index.tsx +++ b/chat2db-client/src/components/SearchResult/TableBox/index.tsx @@ -234,7 +234,7 @@ export default function TableBox(props: ITableProps) { return ( <> -
{bottomStatus}
+
{bottomStatus}
); } else { diff --git a/chat2db-client/src/hooks/useTheme.ts b/chat2db-client/src/hooks/useTheme.ts index 50e441a1..9993a757 100644 --- a/chat2db-client/src/hooks/useTheme.ts +++ b/chat2db-client/src/hooks/useTheme.ts @@ -6,13 +6,25 @@ import { ThemeType, PrimaryColorType } from '@/constants'; import { getPrimaryColor, getTheme, setPrimaryColor, setTheme } from '@/utils/localStorage'; const initialTheme = () => { - let backgroundColor = getTheme() || ThemeType.Dark; + const localStorageTheme = getTheme(); + const localStoragePrimaryColor = getPrimaryColor(); - let primaryColor = getPrimaryColor() || PrimaryColorType.Golden_Purple; + // 判断localStorage的theme在不在ThemeType中, 如果存在就用localStorageTheme + let backgroundColor = ThemeType.Light + if (Object.values(ThemeType).includes(localStorageTheme)) { + backgroundColor = localStorageTheme; + } + + let primaryColor = PrimaryColorType.Golden_Purple + if (Object.values(PrimaryColorType).includes(localStoragePrimaryColor)) { + primaryColor = localStoragePrimaryColor; + } if (backgroundColor === ThemeType.FollowOs) { backgroundColor = getOsTheme(); } + document.documentElement.setAttribute('theme', backgroundColor); + document.documentElement.setAttribute('primary-color', primaryColor); return { backgroundColor, primaryColor, diff --git a/chat2db-client/src/layouts/index.tsx b/chat2db-client/src/layouts/index.tsx index a22f0097..ea8ea78a 100644 --- a/chat2db-client/src/layouts/index.tsx +++ b/chat2db-client/src/layouts/index.tsx @@ -98,7 +98,6 @@ function AppContainer() { // 初始化app function collectInitApp() { monitorOsTheme(); - initTheme(); initLang(); setInitEnd(true); } @@ -119,19 +118,6 @@ function AppContainer() { }; } - // 初始化主题 - function initTheme() { - let theme = getTheme(); - if (theme === ThemeType.FollowOs) { - theme = - (window.matchMedia && window.matchMedia('(prefers-color-scheme: dark)').matches - ? ThemeType.Dark - : ThemeType.Light) || ThemeType.Dark; - } - document.documentElement.setAttribute('theme', theme); - document.documentElement.setAttribute('primary-color', getPrimaryColor()); - } - // 初始化语言 function initLang() { if (!getLang()) { diff --git a/chat2db-client/src/utils/localStorage.ts b/chat2db-client/src/utils/localStorage.ts index 27979e75..3f449970 100644 --- a/chat2db-client/src/utils/localStorage.ts +++ b/chat2db-client/src/utils/localStorage.ts @@ -10,8 +10,8 @@ export function setLang(lang: LangType) { } export function getTheme(): ThemeType { - const themeColor:any = localStorage.getItem('theme') as ThemeType - if(themeColor){ + const themeColor: any = localStorage.getItem('theme') as ThemeType + if (themeColor) { return themeColor } localStorage.setItem('theme', ThemeType.Light) @@ -25,7 +25,7 @@ export function setTheme(theme: ThemeType) { export function getPrimaryColor(): PrimaryColorType { const primaryColor = localStorage.getItem('primary-color') as PrimaryColorType - if(primaryColor){ + if (primaryColor) { return primaryColor } localStorage.setItem('primary-color', PrimaryColorType.Golden_Purple) diff --git a/chat2db-server/chat2db-server-domain/chat2db-server-domain-api/src/main/java/ai/chat2db/server/domain/api/enums/AiSqlSourceEnum.java b/chat2db-server/chat2db-server-domain/chat2db-server-domain-api/src/main/java/ai/chat2db/server/domain/api/enums/AiSqlSourceEnum.java index d475f83e..522d5764 100644 --- a/chat2db-server/chat2db-server-domain/chat2db-server-domain-api/src/main/java/ai/chat2db/server/domain/api/enums/AiSqlSourceEnum.java +++ b/chat2db-server/chat2db-server-domain/chat2db-server-domain-api/src/main/java/ai/chat2db/server/domain/api/enums/AiSqlSourceEnum.java @@ -32,6 +32,11 @@ public enum AiSqlSourceEnum implements BaseEnum { */ CHAT2DBAI("CHAT2DB OPENAI"), + /** + * CLAUDE AI + */ + CLAUDEAI("CLAUDE AI"), + ; final String description; diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java index 012f561f..634ac703 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java @@ -24,10 +24,14 @@ import ai.chat2db.server.web.api.controller.ai.azure.client.AzureOpenAIClient; import ai.chat2db.server.web.api.controller.ai.azure.models.AzureChatMessage; import ai.chat2db.server.web.api.controller.ai.azure.models.AzureChatRole; import ai.chat2db.server.web.api.controller.ai.chat2db.client.Chat2dbAIClient; +import ai.chat2db.server.web.api.controller.ai.claude.client.ClaudeAIClient; +import ai.chat2db.server.web.api.controller.ai.claude.model.ClaudeChatCompletionsOptions; +import ai.chat2db.server.web.api.controller.ai.claude.model.ClaudeChatMessage; import ai.chat2db.server.web.api.controller.ai.config.LocalCache; import ai.chat2db.server.web.api.controller.ai.converter.ChatConverter; import ai.chat2db.server.web.api.controller.ai.enums.PromptType; import ai.chat2db.server.web.api.controller.ai.listener.AzureOpenAIEventSourceListener; +import ai.chat2db.server.web.api.controller.ai.listener.ClaudeAIEventSourceListener; import ai.chat2db.server.web.api.controller.ai.listener.OpenAIEventSourceListener; import ai.chat2db.server.web.api.controller.ai.listener.RestAIEventSourceListener; import ai.chat2db.server.web.api.controller.ai.request.ChatQueryRequest; @@ -212,6 +216,8 @@ public class ChatController { return chatWithRestAi(queryRequest, sseEmitter); case AZUREAI : return chatWithAzureAi(queryRequest, sseEmitter, uid); + case CLAUDEAI: + return chatWithClaudeAi(queryRequest, sseEmitter, uid); } return chatWithOpenAi(queryRequest, sseEmitter, uid); } @@ -326,6 +332,31 @@ public class ChatController { return sseEmitter; } + + /** + * chat with claude ai + * + * @param queryRequest + * @param sseEmitter + * @param uid + * @return + * @throws IOException + */ + private SseEmitter chatWithClaudeAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException { + String prompt = buildPrompt(queryRequest); + ClaudeChatMessage claudeChatMessage = new ClaudeChatMessage(); + claudeChatMessage.setText(prompt); + ClaudeChatCompletionsOptions chatCompletionsOptions = new ClaudeChatCompletionsOptions(); + chatCompletionsOptions.setPrompt(prompt); + claudeChatMessage.setCompletion(chatCompletionsOptions); + + buildSseEmitter(sseEmitter, uid); + + ClaudeAIEventSourceListener sourceListener = new ClaudeAIEventSourceListener(sseEmitter); + ClaudeAIClient.getInstance().streamCompletions(claudeChatMessage, sourceListener); + return sseEmitter; + } + /** * construct sseEmitter * diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/client/ClaudeAIClient.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/client/ClaudeAIClient.java new file mode 100644 index 00000000..77dd97f2 --- /dev/null +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/client/ClaudeAIClient.java @@ -0,0 +1,87 @@ + +package ai.chat2db.server.web.api.controller.ai.claude.client; + +import ai.chat2db.server.domain.api.model.Config; +import ai.chat2db.server.domain.api.service.ConfigService; +import ai.chat2db.server.web.api.util.ApplicationContextUtil; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; + +/** + * @author jipengfei + * @version : OpenAIClient.java + */ +@Slf4j +public class ClaudeAIClient { + + public static final String CLAUDE_SESSION_KEY = "claude.sessionKey"; + + public static final String CLAUDE_API_HOST = "claude.apiHost"; + + public static final String CLAUDE_ORG_ID = "claude.orgId"; + + public static final String CLAUDE_USER_ID = "claude.userId"; + + + private static ClaudeAiStreamClient CLAUDE_AI_STREAM_CLIENT; + private static String apiKey; + + public static ClaudeAiStreamClient getInstance() { + if (CLAUDE_AI_STREAM_CLIENT != null) { + return CLAUDE_AI_STREAM_CLIENT; + } else { + return singleton(); + } + } + + private static ClaudeAiStreamClient singleton() { + if (CLAUDE_AI_STREAM_CLIENT == null) { + synchronized (ClaudeAIClient.class) { + if (CLAUDE_AI_STREAM_CLIENT == null) { + refresh(); + } + } + } + return CLAUDE_AI_STREAM_CLIENT; + } + + public static void refresh() { + String apikey = ""; + String orgId = ""; + String userId = ""; + String apiHost = "https://claude.ai/api/append_message"; + ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class); + Config apiHostConfig = configService.find(CLAUDE_API_HOST).getData(); + if (apiHostConfig != null) { + apiHost = apiHostConfig.getContent(); + } + Config config = configService.find(CLAUDE_SESSION_KEY).getData(); + if (config != null) { + apikey = config.getContent(); + } + Config orgConfig = configService.find(CLAUDE_ORG_ID).getData(); + if (orgConfig != null) { + orgId = orgConfig.getContent(); + } + Config userConfig = configService.find(CLAUDE_USER_ID).getData(); + if (userConfig != null) { + userId = userConfig.getContent(); + } + log.info("refresh claude sessionKey:{}", maskApiKey(apikey)); + CLAUDE_AI_STREAM_CLIENT = ClaudeAiStreamClient.builder().apiHost(apiHost) + .sessionKey(apikey).orgId(orgId).userId(userId).build(); + apiKey = apikey; + } + + private static String maskApiKey(String input) { + if (StringUtils.isBlank(input)) { + return input; + } + + StringBuilder maskedString = new StringBuilder(input); + for (int i = input.length() / 4; i < input.length() / 2; i++) { + maskedString.setCharAt(i, '*'); + } + return maskedString.toString(); + } +} diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/client/ClaudeAiStreamClient.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/client/ClaudeAiStreamClient.java new file mode 100644 index 00000000..d3607902 --- /dev/null +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/client/ClaudeAiStreamClient.java @@ -0,0 +1,186 @@ +package ai.chat2db.server.web.api.controller.ai.claude.client; + +import ai.chat2db.server.tools.common.exception.ParamBusinessException; +import ai.chat2db.server.web.api.controller.ai.claude.interceptor.ClaudeHeaderAuthorizationInterceptor; +import ai.chat2db.server.web.api.controller.ai.claude.model.ClaudeChatMessage; +import cn.hutool.http.ContentType; +import com.fasterxml.jackson.databind.ObjectMapper; +import lombok.Getter; +import lombok.extern.slf4j.Slf4j; +import okhttp3.MediaType; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.RequestBody; +import okhttp3.sse.EventSource; +import okhttp3.sse.EventSourceListener; +import okhttp3.sse.EventSources; +import org.jetbrains.annotations.NotNull; + +import java.util.Objects; +import java.util.concurrent.TimeUnit; + +/** + * 自定义AI接口client + * + * @author moji + */ +@Slf4j +public class ClaudeAiStreamClient { + + /** + * apikey + */ + @Getter + @NotNull + private String sessionKey; + + /** + * endpoint + */ + @Getter + @NotNull + private String orgId; + + /** + * deployId + */ + @Getter + private String apiHost; + + @Getter + private String userId; + + /** + * okHttpClient + */ + @Getter + private OkHttpClient okHttpClient; + + + /** + * @param builder + */ + private ClaudeAiStreamClient(Builder builder) { + this.sessionKey = builder.sessionKey; + this.orgId = builder.orgId; + this.apiHost = builder.apiHost; + this.userId = builder.userId; + if (Objects.isNull(builder.okHttpClient)) { + builder.okHttpClient = this.okHttpClient(); + } + okHttpClient = builder.okHttpClient; + } + + /** + * okhttpclient + */ + private OkHttpClient okHttpClient() { + OkHttpClient okHttpClient = new OkHttpClient + .Builder() + .addInterceptor(new ClaudeHeaderAuthorizationInterceptor(this.sessionKey, this.orgId)) + .connectTimeout(10, TimeUnit.SECONDS) + .writeTimeout(50, TimeUnit.SECONDS) + .readTimeout(50, TimeUnit.SECONDS) + .build(); + return okHttpClient; + } + + /** + * 构造 + * + * @return + */ + public static ClaudeAiStreamClient.Builder builder() { + return new ClaudeAiStreamClient.Builder(); + } + + public static final class Builder { + private String sessionKey; + + private String orgId; + + private String apiHost; + + private String userId; + + /** + * 自定义OkhttpClient + */ + private OkHttpClient okHttpClient; + + public Builder() { + } + + public ClaudeAiStreamClient.Builder sessionKey(String sessionKey) { + this.sessionKey = sessionKey; + return this; + } + + /** + * @param apiHost + * @return + */ + public ClaudeAiStreamClient.Builder apiHost(String apiHost) { + this.apiHost = apiHost; + return this; + } + + /** + * @param orgId + * @return + */ + public ClaudeAiStreamClient.Builder orgId(String orgId) { + this.orgId = orgId; + return this; + } + + public ClaudeAiStreamClient.Builder userId(String userId) { + this.userId = userId; + return this; + } + + public ClaudeAiStreamClient.Builder okHttpClient(OkHttpClient val) { + this.okHttpClient = val; + return this; + } + + public ClaudeAiStreamClient build() { + return new ClaudeAiStreamClient(this); + } + + } + + /** + * chat + * + * @param claudeChatMessage + * @param eventSourceListener + */ + public void streamCompletions(ClaudeChatMessage claudeChatMessage, EventSourceListener eventSourceListener) { + if (Objects.isNull(eventSourceListener)) { + log.error("param error:AzureEventSourceListener cannot be empty"); + throw new ParamBusinessException(); + } + log.info("Claude AI, prompt:{}", claudeChatMessage.getText()); + try { + claudeChatMessage.setOrganization_uuid(this.orgId); + claudeChatMessage.setConversation_uuid(this.userId); + EventSource.Factory factory = EventSources.createFactory(this.okHttpClient); + ObjectMapper mapper = new ObjectMapper(); + String requestBody = mapper.writeValueAsString(claudeChatMessage); + + Request request = new Request.Builder() + .url(this.apiHost) + .post(RequestBody.create(MediaType.parse(ContentType.JSON.getValue()), requestBody)) + .build(); + //创建事件 + EventSource eventSource = factory.newEventSource(request, eventSourceListener); + log.info("finish invoking claude ai"); + } catch (Exception e) { + log.error("claude ai error", e); + eventSourceListener.onFailure(null, e, null); + throw new ParamBusinessException(); + } + } + +} diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/interceptor/ClaudeHeaderAuthorizationInterceptor.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/interceptor/ClaudeHeaderAuthorizationInterceptor.java new file mode 100644 index 00000000..351bb6c5 --- /dev/null +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/interceptor/ClaudeHeaderAuthorizationInterceptor.java @@ -0,0 +1,40 @@ +package ai.chat2db.server.web.api.controller.ai.claude.interceptor; + +import cn.hutool.http.ContentType; +import cn.hutool.http.Header; +import lombok.Getter; +import okhttp3.Interceptor; +import okhttp3.Request; +import okhttp3.Response; + +import java.io.IOException; + +/** + * 描述:请求增加header apikey + * + * @author grt + * @since 2023-03-23 + */ +@Getter +public class ClaudeHeaderAuthorizationInterceptor implements Interceptor { + + private String sessionKey; + + private String orgId; + + public ClaudeHeaderAuthorizationInterceptor(String sessionKey, String orgId) { + this.orgId = orgId; + this.sessionKey = sessionKey; + } + + @Override + public Response intercept(Chain chain) throws IOException { + Request original = chain.request(); + Request request = original.newBuilder() + .header("Cookie", "sessionKey=" + sessionKey) + .header(Header.CONTENT_TYPE.getValue(), ContentType.JSON.getValue()) + .method(original.method(), original.body()) + .build(); + return chain.proceed(request); + } +} diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/model/ClaudeChatCompletionsOptions.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/model/ClaudeChatCompletionsOptions.java new file mode 100644 index 00000000..cc796124 --- /dev/null +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/model/ClaudeChatCompletionsOptions.java @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// Code generated by Microsoft (R) AutoRest Code Generator. +package ai.chat2db.server.web.api.controller.ai.claude.model; + +import lombok.Data; + +@Data +public final class ClaudeChatCompletionsOptions { + + private Boolean incremental = true; + + private String model = "claude-2"; + + private String prompt; + + private String timezone = "Asia/Shanghai"; + + private Boolean stream = true; + + public Boolean isStream() { + return this.stream; + } + + public ClaudeChatCompletionsOptions setStream(Boolean stream) { + this.stream = stream; + return this; + } + + public String getModel() { + return this.model; + } + + public ClaudeChatCompletionsOptions setModel(String model) { + this.model = model; + return this; + } +} diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/model/ClaudeChatMessage.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/model/ClaudeChatMessage.java new file mode 100644 index 00000000..5b9fe07b --- /dev/null +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/model/ClaudeChatMessage.java @@ -0,0 +1,15 @@ +package ai.chat2db.server.web.api.controller.ai.claude.model; + +import lombok.Data; + +@Data +public class ClaudeChatMessage { + + private String conversation_uuid; + + private String organization_uuid; + + private String text; + + private ClaudeChatCompletionsOptions completion; +} diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/model/ClaudeCompletionResponse.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/model/ClaudeCompletionResponse.java new file mode 100644 index 00000000..627078c9 --- /dev/null +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/model/ClaudeCompletionResponse.java @@ -0,0 +1,25 @@ + +package ai.chat2db.server.web.api.controller.ai.claude.model; + +import com.unfbx.chatgpt.entity.common.Usage; +import lombok.Data; + +import java.io.Serial; +import java.io.Serializable; + +/** + * @author moji + * @version : ClaudeCompletionResponse.java + */ +@Data +public class ClaudeCompletionResponse implements Serializable { + @Serial + private static final long serialVersionUID = 4968922211204353592L; + private String log_id; + private String stop_reason; + private String stop; + private String model; + private String completion; + private Usage usage; + private ClaudeMessageLimit messageLimit; +} \ No newline at end of file diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/model/ClaudeMessageLimit.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/model/ClaudeMessageLimit.java new file mode 100644 index 00000000..f6b0c4e8 --- /dev/null +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/claude/model/ClaudeMessageLimit.java @@ -0,0 +1,9 @@ +package ai.chat2db.server.web.api.controller.ai.claude.model; + +import lombok.Data; + +@Data +public class ClaudeMessageLimit { + + private String type; +} diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/listener/ClaudeAIEventSourceListener.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/listener/ClaudeAIEventSourceListener.java new file mode 100644 index 00000000..c6401867 --- /dev/null +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/listener/ClaudeAIEventSourceListener.java @@ -0,0 +1,112 @@ +package ai.chat2db.server.web.api.controller.ai.listener; + +import ai.chat2db.server.web.api.controller.ai.claude.model.ClaudeCompletionResponse; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.unfbx.chatgpt.entity.chat.Message; +import lombok.SneakyThrows; +import lombok.extern.slf4j.Slf4j; +import okhttp3.Response; +import okhttp3.ResponseBody; +import okhttp3.sse.EventSource; +import okhttp3.sse.EventSourceListener; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; + +import java.util.Objects; + +/** + * ClaudeAIEventSourceListener + */ +@Slf4j +public class ClaudeAIEventSourceListener extends EventSourceListener { + + private SseEmitter sseEmitter; + + private ObjectMapper mapper = new ObjectMapper().disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES); + + public ClaudeAIEventSourceListener(SseEmitter sseEmitter) { + this.sseEmitter = sseEmitter; + } + + /** + * {@inheritDoc} + */ + @Override + public void onOpen(EventSource eventSource, Response response) { + log.info("ClaudeAIEventSourceListener..."); + } + + /** + * {@inheritDoc} + */ + @SneakyThrows + @Override + public void onEvent(EventSource eventSource, String id, String type, String data) { + log.info("Claude AI data:{}", data); + if (data.equals("[DONE]")) { + log.info("Claude AI end"); + sseEmitter.send(SseEmitter.event() + .id("[DONE]") + .data("[DONE]") + .reconnectTime(3000)); + sseEmitter.complete(); + return; + } + // 读取Json + ClaudeCompletionResponse completionResponse = mapper.readValue(data, ClaudeCompletionResponse.class); + String text = completionResponse.getCompletion(); + Message message = new Message(); + if (text != null) { + message.setContent(text); + sseEmitter.send(SseEmitter.event() + .id(null) + .data(message) + .reconnectTime(3000)); + } + } + + @Override + public void onClosed(EventSource eventSource) { + sseEmitter.complete(); + log.info("Claude AI closed..."); + } + + @Override + public void onFailure(EventSource eventSource, Throwable t, Response response) { + try { + if (Objects.isNull(response)) { + String message = t.getMessage(); + Message sseMessage = new Message(); + sseMessage.setContent(message); + sseEmitter.send(SseEmitter.event() + .id("[ERROR]") + .data(sseMessage)); + sseEmitter.send(SseEmitter.event() + .id("[DONE]") + .data("[DONE]")); + sseEmitter.complete(); + return; + } + ResponseBody body = response.body(); + String bodyString = null; + if (Objects.nonNull(body)) { + bodyString = body.string(); + log.error("Claude sse error:{}", bodyString, t); + } else { + log.error("Claude sse body error:{}", response, t); + } + eventSource.cancel(); + Message message = new Message(); + message.setContent("Claude sse error:" + bodyString); + sseEmitter.send(SseEmitter.event() + .id("[ERROR]") + .data(message)); + sseEmitter.send(SseEmitter.event() + .id("[DONE]") + .data("[DONE]")); + sseEmitter.complete(); + } catch (Exception exception) { + log.error("发送数据异常:", exception); + } + } +}