feature: add dify-chat-ai controller type

This commit is contained in:
leven.chen
2023-12-22 17:02:42 +08:00
parent 04e100ca70
commit bea9a52f52
4 changed files with 114 additions and 66 deletions

View File

@ -27,6 +27,8 @@ import ai.chat2db.server.web.api.controller.ai.claude.model.ClaudeChatCompletion
import ai.chat2db.server.web.api.controller.ai.claude.model.ClaudeChatMessage;
import ai.chat2db.server.web.api.controller.ai.config.LocalCache;
import ai.chat2db.server.web.api.controller.ai.converter.ChatConverter;
import ai.chat2db.server.web.api.controller.ai.dify.client.DifyChatAIClient;
import ai.chat2db.server.web.api.controller.ai.dify.listener.DifyChatAIEventSourceListener;
import ai.chat2db.server.web.api.controller.ai.enums.PromptType;
import ai.chat2db.server.web.api.controller.ai.fastchat.client.FastChatAIClient;
import ai.chat2db.server.web.api.controller.ai.fastchat.embeddings.FastChatEmbeddingResponse;
@ -230,14 +232,14 @@ public class ChatController {
}
uid = aiSqlSourceEnum.getCode() + uid;
switch (Objects.requireNonNull(aiSqlSourceEnum)) {
case OPENAI :
case OPENAI:
return chatWithOpenAi(queryRequest, sseEmitter, uid);
case CHAT2DBAI:
return chatWithChat2dbAi(queryRequest, sseEmitter, uid);
case RESTAI :
case RESTAI:
case FASTCHATAI:
return chatWithFastChatAi(queryRequest, sseEmitter, uid);
case AZUREAI :
case AZUREAI:
return chatWithAzureAi(queryRequest, sseEmitter, uid);
case CLAUDEAI:
return chatWithClaudeAi(queryRequest, sseEmitter, uid);
@ -249,10 +251,34 @@ public class ChatController {
return chatWithTongyiChatAi(queryRequest, sseEmitter, uid);
case ZHIPUAI:
return chatWithZhipuChatAi(queryRequest, sseEmitter, uid);
case DIFYCHAT:
return chatWithDifyChatAi(queryRequest, sseEmitter, uid);
}
return chatWithOpenAi(queryRequest, sseEmitter, uid);
}
private SseEmitter chatWithDifyChatAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException {
String prompt = buildPrompt(queryRequest);
if (prompt.length() / TOKEN_CONVERT_CHAR_LENGTH > MAX_PROMPT_LENGTH) {
log.error("exceed max token length:{}input length:{}", MAX_PROMPT_LENGTH,
prompt.length() / TOKEN_CONVERT_CHAR_LENGTH);
throw new ParamBusinessException();
}
prompt = prompt.replaceAll("#", "");
log.info(prompt);
Message currentMessage = Message.builder().content(prompt).role(Message.Role.USER).build();
List<Message> messages = new ArrayList<>();
messages.add(currentMessage);
buildSseEmitter(sseEmitter, uid);
DifyChatAIEventSourceListener eventSourceListener = new DifyChatAIEventSourceListener(sseEmitter);
DifyChatAIClient.getInstance().streamCompletions(messages, eventSourceListener);
LocalCache.CACHE.put(uid, JSONUtil.toJsonStr(messages), LocalCache.TIMEOUT);
return sseEmitter;
}
/**
* 使用自定义AI接口进行聊天
*
@ -344,7 +370,7 @@ public class ChatController {
prompt.length() / TOKEN_CONVERT_CHAR_LENGTH);
throw new ParamBusinessException();
}
List<AzureChatMessage> messages = (List<AzureChatMessage>)LocalCache.CACHE.get(uid);
List<AzureChatMessage> messages = (List<AzureChatMessage>) LocalCache.CACHE.get(uid);
if (CollectionUtils.isNotEmpty(messages)) {
if (messages.size() >= contextLength) {
messages = messages.subList(1, contextLength);
@ -363,40 +389,6 @@ public class ChatController {
return sseEmitter;
}
/**
* chat with azure openai
*
* @param queryRequest
* @param sseEmitter
* @param uid
* @return
* @throws IOException
*/
private SseEmitter chatWithDifyChat(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException {
String prompt = buildPrompt(queryRequest);
if (prompt.length() / TOKEN_CONVERT_CHAR_LENGTH > MAX_PROMPT_LENGTH) {
log.error("提示语超出最大长度:{},输入长度:{}, 请重新输入", MAX_PROMPT_LENGTH,
prompt.length() / TOKEN_CONVERT_CHAR_LENGTH);
throw new ParamBusinessException();
}
List<AzureChatMessage> messages = (List<AzureChatMessage>)LocalCache.CACHE.get(uid);
if (CollectionUtils.isNotEmpty(messages)) {
if (messages.size() >= contextLength) {
messages = messages.subList(1, contextLength);
}
} else {
messages = Lists.newArrayList();
}
AzureChatMessage currentMessage = new AzureChatMessage(AzureChatRole.USER).setContent(prompt);
messages.add(currentMessage);
buildSseEmitter(sseEmitter, uid);
AzureOpenAIEventSourceListener sourceListener = new AzureOpenAIEventSourceListener(sseEmitter);
AzureOpenAIClient.getInstance().streamCompletions(messages, sourceListener);
LocalCache.CACHE.put(uid, messages, LocalCache.TIMEOUT);
return sseEmitter;
}
/**
* chat with fast chat openai
@ -490,7 +482,7 @@ public class ChatController {
* @return
*/
private List<FastChatMessage> getFastChatMessage(String uid, String prompt) {
List<FastChatMessage> messages = (List<FastChatMessage>)LocalCache.CACHE.get(uid);
List<FastChatMessage> messages = (List<FastChatMessage>) LocalCache.CACHE.get(uid);
if (CollectionUtils.isNotEmpty(messages)) {
if (messages.size() >= contextLength) {
messages = messages.subList(1, contextLength);
@ -746,7 +738,7 @@ public class ChatController {
DataResult<TableSchemaResponse> result = gatewayClientService.schemaVectorSearch(tableSchemaRequest);
List<String> schemas = Lists.newArrayList();
if (Objects.nonNull(result.getData()) && CollectionUtils.isNotEmpty(result.getData().getTableSchemas())) {
for(TableSchema data: result.getData().getTableSchemas()){
for (TableSchema data : result.getData().getTableSchemas()) {
schemas.add(data.getTableSchema());
}
}
@ -786,7 +778,7 @@ public class ChatController {
DataResult<EsTableSchemaResponse> result = gatewayClientService.schemaEsSearch(tableSchemaRequest);
List<String> schemas = Lists.newArrayList();
if (Objects.nonNull(result.getData()) && CollectionUtils.isNotEmpty(result.getData().getTableSchemas())) {
for(EsTableSchema data: result.getData().getTableSchemas()){
for (EsTableSchema data : result.getData().getTableSchemas()) {
schemas.add(data.getTableSchemaContent());
}
}

View File

@ -1,7 +1,12 @@
package ai.chat2db.server.web.api.controller.ai.dify.client;
import ai.chat2db.server.web.api.controller.ai.azure.client.AzureOpenAiStreamClient;
import ai.chat2db.server.web.api.controller.ai.dify.listener.DifyChatAIEventSourceListener;
import com.unfbx.chatgpt.entity.chat.Message;
import lombok.extern.slf4j.Slf4j;
import java.util.List;
@Slf4j
public class DifyChatAIClient {
@ -19,4 +24,12 @@ public class DifyChatAIClient {
public static void refresh() {
}
public static DifyChatAIClient getInstance() {
return null;
}
public void streamCompletions(List<Message> messages, DifyChatAIEventSourceListener eventSourceListener) {
}
}

View File

@ -0,0 +1,4 @@
package ai.chat2db.server.web.api.controller.ai.dify.client;
public class DifyChatAiStreamClient {
}

View File

@ -0,0 +1,39 @@
package ai.chat2db.server.web.api.controller.ai.dify.listener;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Response;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
@Slf4j
public class DifyChatAIEventSourceListener extends EventSourceListener {
private SseEmitter sseEmitter;
public DifyChatAIEventSourceListener(SseEmitter sseEmitter) {
this.sseEmitter = sseEmitter;
}
@Override
public void onClosed(@NotNull EventSource eventSource) {
super.onClosed(eventSource);
}
@Override
public void onEvent(@NotNull EventSource eventSource, @Nullable String id, @Nullable String type, @NotNull String data) {
super.onEvent(eventSource, id, type, data);
}
@Override
public void onFailure(@NotNull EventSource eventSource, @Nullable Throwable t, @Nullable Response response) {
super.onFailure(eventSource, t, response);
}
@Override
public void onOpen(@NotNull EventSource eventSource, @NotNull Response response) {
super.onOpen(eventSource, response);
}
}