diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java index b10c1a83..b31ad8ec 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java @@ -208,7 +208,7 @@ public class ChatController { throw new ParamBusinessException("message"); } - return distributeAI(queryRequest.getMessage(), sseEmitter, uid); + return distributeAISql(queryRequest, sseEmitter, uid); } /** @@ -225,12 +225,30 @@ public class ChatController { return chatWithOpenAi(msg, sseEmitter, uid); case RESTAI : return chatWithRestAi(msg, sseEmitter); - case AZUREAI : - return chatWithAzureAi(msg, sseEmitter, uid); } return chatWithOpenAi(msg, sseEmitter, uid); } + /** + * distribute with different AI + * + * @return + */ + private SseEmitter distributeAISql(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException { + ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class); + Config config = configService.find(RestAIClient.AI_SQL_SOURCE).getData(); + AiSqlSourceEnum aiSqlSourceEnum = AiSqlSourceEnum.getByName(config.getContent()); + switch (Objects.requireNonNull(aiSqlSourceEnum)) { + case OPENAI : + return chatWithOpenAiSql(queryRequest, sseEmitter, uid); + case RESTAI : + return chatWithRestAi(queryRequest.getMessage(), sseEmitter); + case AZUREAI : + return chatWithAzureAi(queryRequest, sseEmitter, uid); + } + return chatWithOpenAiSql(queryRequest, sseEmitter, uid); + } + /** * 使用自定义AI接口进行聊天 * @@ -309,13 +327,19 @@ public class ChatController { /** * chat with azure openai * - * @param msg + * @param queryRequest * @param sseEmitter * @param uid * @return * @throws IOException */ - private SseEmitter chatWithAzureAi(String msg, SseEmitter sseEmitter, String uid) throws IOException { + private SseEmitter chatWithAzureAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException { + String prompt = buildPrompt(queryRequest); + if (prompt.length() / TOKEN_CONVERT_CHAR_LENGTH > MAX_PROMPT_LENGTH) { + log.error("提示语超出最大长度:{},输入长度:{}, 请重新输入", MAX_PROMPT_LENGTH, + prompt.length() / TOKEN_CONVERT_CHAR_LENGTH); + throw new ParamBusinessException(); + } String messageContext = (String)LocalCache.CACHE.get(uid); List messages = new ArrayList<>(); if (StrUtil.isNotBlank(messageContext)) { @@ -324,7 +348,7 @@ public class ChatController { messages = messages.subList(1, contextLength); } } - ChatMessage currentMessage = new ChatMessage(ChatRole.USER).setContent(msg); + ChatMessage currentMessage = new ChatMessage(ChatRole.USER).setContent(prompt); messages.add(currentMessage); sseEmitter.send(SseEmitter.event().id(uid).name("sseEmitter connected!!!!").data(LocalDateTime.now()).reconnectTime(3000));