This commit is contained in:
moji
2023-07-01 14:05:02 +08:00
parent 1f21ace507
commit 36defcbfcf

View File

@ -208,7 +208,7 @@ public class ChatController {
throw new ParamBusinessException("message"); throw new ParamBusinessException("message");
} }
return distributeAI(queryRequest.getMessage(), sseEmitter, uid); return distributeAISql(queryRequest, sseEmitter, uid);
} }
/** /**
@ -225,12 +225,30 @@ public class ChatController {
return chatWithOpenAi(msg, sseEmitter, uid); return chatWithOpenAi(msg, sseEmitter, uid);
case RESTAI : case RESTAI :
return chatWithRestAi(msg, sseEmitter); return chatWithRestAi(msg, sseEmitter);
case AZUREAI :
return chatWithAzureAi(msg, sseEmitter, uid);
} }
return chatWithOpenAi(msg, sseEmitter, uid); return chatWithOpenAi(msg, sseEmitter, uid);
} }
/**
* distribute with different AI
*
* @return
*/
private SseEmitter distributeAISql(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException {
ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
Config config = configService.find(RestAIClient.AI_SQL_SOURCE).getData();
AiSqlSourceEnum aiSqlSourceEnum = AiSqlSourceEnum.getByName(config.getContent());
switch (Objects.requireNonNull(aiSqlSourceEnum)) {
case OPENAI :
return chatWithOpenAiSql(queryRequest, sseEmitter, uid);
case RESTAI :
return chatWithRestAi(queryRequest.getMessage(), sseEmitter);
case AZUREAI :
return chatWithAzureAi(queryRequest, sseEmitter, uid);
}
return chatWithOpenAiSql(queryRequest, sseEmitter, uid);
}
/** /**
* 使用自定义AI接口进行聊天 * 使用自定义AI接口进行聊天
* *
@ -309,13 +327,19 @@ public class ChatController {
/** /**
* chat with azure openai * chat with azure openai
* *
* @param msg * @param queryRequest
* @param sseEmitter * @param sseEmitter
* @param uid * @param uid
* @return * @return
* @throws IOException * @throws IOException
*/ */
private SseEmitter chatWithAzureAi(String msg, SseEmitter sseEmitter, String uid) throws IOException { private SseEmitter chatWithAzureAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException {
String prompt = buildPrompt(queryRequest);
if (prompt.length() / TOKEN_CONVERT_CHAR_LENGTH > MAX_PROMPT_LENGTH) {
log.error("提示语超出最大长度:{},输入长度:{}, 请重新输入", MAX_PROMPT_LENGTH,
prompt.length() / TOKEN_CONVERT_CHAR_LENGTH);
throw new ParamBusinessException();
}
String messageContext = (String)LocalCache.CACHE.get(uid); String messageContext = (String)LocalCache.CACHE.get(uid);
List<ChatMessage> messages = new ArrayList<>(); List<ChatMessage> messages = new ArrayList<>();
if (StrUtil.isNotBlank(messageContext)) { if (StrUtil.isNotBlank(messageContext)) {
@ -324,7 +348,7 @@ public class ChatController {
messages = messages.subList(1, contextLength); messages = messages.subList(1, contextLength);
} }
} }
ChatMessage currentMessage = new ChatMessage(ChatRole.USER).setContent(msg); ChatMessage currentMessage = new ChatMessage(ChatRole.USER).setContent(prompt);
messages.add(currentMessage); messages.add(currentMessage);
sseEmitter.send(SseEmitter.event().id(uid).name("sseEmitter connected").data(LocalDateTime.now()).reconnectTime(3000)); sseEmitter.send(SseEmitter.event().id(uid).name("sseEmitter connected").data(LocalDateTime.now()).reconnectTime(3000));