This commit is contained in:
moji
2023-07-01 14:05:02 +08:00
parent 1f21ace507
commit 36defcbfcf

View File

@ -208,7 +208,7 @@ public class ChatController {
throw new ParamBusinessException("message");
}
return distributeAI(queryRequest.getMessage(), sseEmitter, uid);
return distributeAISql(queryRequest, sseEmitter, uid);
}
/**
@ -225,12 +225,30 @@ public class ChatController {
return chatWithOpenAi(msg, sseEmitter, uid);
case RESTAI :
return chatWithRestAi(msg, sseEmitter);
case AZUREAI :
return chatWithAzureAi(msg, sseEmitter, uid);
}
return chatWithOpenAi(msg, sseEmitter, uid);
}
/**
* distribute with different AI
*
* @return
*/
private SseEmitter distributeAISql(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException {
ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
Config config = configService.find(RestAIClient.AI_SQL_SOURCE).getData();
AiSqlSourceEnum aiSqlSourceEnum = AiSqlSourceEnum.getByName(config.getContent());
switch (Objects.requireNonNull(aiSqlSourceEnum)) {
case OPENAI :
return chatWithOpenAiSql(queryRequest, sseEmitter, uid);
case RESTAI :
return chatWithRestAi(queryRequest.getMessage(), sseEmitter);
case AZUREAI :
return chatWithAzureAi(queryRequest, sseEmitter, uid);
}
return chatWithOpenAiSql(queryRequest, sseEmitter, uid);
}
/**
* 使用自定义AI接口进行聊天
*
@ -309,13 +327,19 @@ public class ChatController {
/**
* chat with azure openai
*
* @param msg
* @param queryRequest
* @param sseEmitter
* @param uid
* @return
* @throws IOException
*/
private SseEmitter chatWithAzureAi(String msg, SseEmitter sseEmitter, String uid) throws IOException {
private SseEmitter chatWithAzureAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException {
String prompt = buildPrompt(queryRequest);
if (prompt.length() / TOKEN_CONVERT_CHAR_LENGTH > MAX_PROMPT_LENGTH) {
log.error("提示语超出最大长度:{},输入长度:{}, 请重新输入", MAX_PROMPT_LENGTH,
prompt.length() / TOKEN_CONVERT_CHAR_LENGTH);
throw new ParamBusinessException();
}
String messageContext = (String)LocalCache.CACHE.get(uid);
List<ChatMessage> messages = new ArrayList<>();
if (StrUtil.isNotBlank(messageContext)) {
@ -324,7 +348,7 @@ public class ChatController {
messages = messages.subList(1, contextLength);
}
}
ChatMessage currentMessage = new ChatMessage(ChatRole.USER).setContent(msg);
ChatMessage currentMessage = new ChatMessage(ChatRole.USER).setContent(prompt);
messages.add(currentMessage);
sseEmitter.send(SseEmitter.event().id(uid).name("sseEmitter connected").data(LocalDateTime.now()).reconnectTime(3000));