ai result update

This commit is contained in:
robin
2023-11-03 18:51:11 +08:00
parent dd2d198d3f
commit d1337f7671
2 changed files with 6 additions and 52 deletions

View File

@ -196,26 +196,11 @@ public class Chat2DBAIStreamClient {
log.error("param errorChatEventSourceListener cannot be empty"); log.error("param errorChatEventSourceListener cannot be empty");
throw new ParamBusinessException(); throw new ParamBusinessException();
} }
log.info("Chat AI, prompt:{}", chatMessages.get(chatMessages.size() - 1).getContent());
try { try {
ChatCompletion chatCompletion = ChatCompletion.builder() ChatCompletion chatCompletion = ChatCompletion.builder()
.messages(chatMessages) .messages(chatMessages)
.stream(true) .stream(true)
.build(); .build();
ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
DataResult<Config> chat2dbModel = configService.find(Chat2dbAIClient.CHAT2DB_OPENAI_MODEL);
String model = Objects.nonNull(chat2dbModel.getData()) && StringUtils.isNotBlank(chat2dbModel.getData().getContent()) ? chat2dbModel.getData().getContent() : AiSqlSourceEnum.OPENAI.getCode();
AiSqlSourceEnum aiSqlSourceEnum = AiSqlSourceEnum.getByName(model);
switch (aiSqlSourceEnum) {
case BAICHUANAI:
chatCompletion = ChatCompletion.builder().messages(chatMessages).model("Baichuan2-53B").build();
break;
case ZHIPUAI:
chatCompletion = ChatCompletion.builder().messages(chatMessages).model("chatglm_turbo").build();
break;
default:
break;
}
EventSource.Factory factory = EventSources.createFactory(this.okHttpClient); EventSource.Factory factory = EventSources.createFactory(this.okHttpClient);
ObjectMapper mapper = new ObjectMapper(); ObjectMapper mapper = new ObjectMapper();

View File

@ -66,43 +66,12 @@ public class Chat2dbAIEventSourceListener extends EventSourceListener {
} }
ObjectMapper mapper = new ObjectMapper(); ObjectMapper mapper = new ObjectMapper();
mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class); ChatCompletionResponse completionResponse = mapper.readValue(data, ChatCompletionResponse.class);
DataResult<Config> chat2dbModel = configService.find(Chat2dbAIClient.CHAT2DB_OPENAI_MODEL); String text = completionResponse.getChoices().get(0).getDelta() == null
String model = Objects.nonNull(chat2dbModel.getData()) && StringUtils.isNotBlank(chat2dbModel.getData().getContent()) ? chat2dbModel.getData().getContent() : AiSqlSourceEnum.OPENAI.getCode(); ? completionResponse.getChoices().get(0).getText()
AiSqlSourceEnum aiSqlSourceEnum = AiSqlSourceEnum.getByName(model); : completionResponse.getChoices().get(0).getDelta().getContent();
String text = ""; String completionId = completionResponse.getId();
String completionId = null;
// 读取Json
switch (aiSqlSourceEnum) {
case BAICHUANAI:
BaichuanChatCompletions chatCompletions = mapper.readValue(data, BaichuanChatCompletions.class);
for (BaichuanChatMessage message : chatCompletions.getData().getMessages()) {
if (message != null) {
if (message.getContent() != null) {
text = message.getContent();
}
}
}
break;
case ZHIPUAI:
ZhipuChatCompletions zhipuChatCompletions = mapper.readValue(data, ZhipuChatCompletions.class);
text = zhipuChatCompletions.getData();
if (Objects.isNull(text)) {
for (FastChatMessage message : zhipuChatCompletions.getBody().getChoices()) {
if (message != null && message.getContent() != null) {
text = message.getContent();
}
}
}
break;
default:
ChatCompletionResponse completionResponse = mapper.readValue(data, ChatCompletionResponse.class);
text = completionResponse.getChoices().get(0).getDelta() == null
? completionResponse.getChoices().get(0).getText()
: completionResponse.getChoices().get(0).getDelta().getContent();
completionId = completionResponse.getId();
break;
}
Message message = new Message(); Message message = new Message();
if (text != null) { if (text != null) {
message.setContent(text); message.setContent(text);