feature: add dify-chat-ai controller type

This commit is contained in:
leven.chen
2023-12-22 17:02:42 +08:00
parent 04e100ca70
commit bea9a52f52
4 changed files with 114 additions and 66 deletions

View File

@ -27,6 +27,8 @@ import ai.chat2db.server.web.api.controller.ai.claude.model.ClaudeChatCompletion
import ai.chat2db.server.web.api.controller.ai.claude.model.ClaudeChatMessage;
import ai.chat2db.server.web.api.controller.ai.config.LocalCache;
import ai.chat2db.server.web.api.controller.ai.converter.ChatConverter;
import ai.chat2db.server.web.api.controller.ai.dify.client.DifyChatAIClient;
import ai.chat2db.server.web.api.controller.ai.dify.listener.DifyChatAIEventSourceListener;
import ai.chat2db.server.web.api.controller.ai.enums.PromptType;
import ai.chat2db.server.web.api.controller.ai.fastchat.client.FastChatAIClient;
import ai.chat2db.server.web.api.controller.ai.fastchat.embeddings.FastChatEmbeddingResponse;
@ -249,10 +251,34 @@ public class ChatController {
return chatWithTongyiChatAi(queryRequest, sseEmitter, uid);
case ZHIPUAI:
return chatWithZhipuChatAi(queryRequest, sseEmitter, uid);
case DIFYCHAT:
return chatWithDifyChatAi(queryRequest, sseEmitter, uid);
}
return chatWithOpenAi(queryRequest, sseEmitter, uid);
}
private SseEmitter chatWithDifyChatAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException {
String prompt = buildPrompt(queryRequest);
if (prompt.length() / TOKEN_CONVERT_CHAR_LENGTH > MAX_PROMPT_LENGTH) {
log.error("exceed max token length:{}input length:{}", MAX_PROMPT_LENGTH,
prompt.length() / TOKEN_CONVERT_CHAR_LENGTH);
throw new ParamBusinessException();
}
prompt = prompt.replaceAll("#", "");
log.info(prompt);
Message currentMessage = Message.builder().content(prompt).role(Message.Role.USER).build();
List<Message> messages = new ArrayList<>();
messages.add(currentMessage);
buildSseEmitter(sseEmitter, uid);
DifyChatAIEventSourceListener eventSourceListener = new DifyChatAIEventSourceListener(sseEmitter);
DifyChatAIClient.getInstance().streamCompletions(messages, eventSourceListener);
LocalCache.CACHE.put(uid, JSONUtil.toJsonStr(messages), LocalCache.TIMEOUT);
return sseEmitter;
}
/**
* 使用自定义AI接口进行聊天
*
@ -363,40 +389,6 @@ public class ChatController {
return sseEmitter;
}
/**
* chat with azure openai
*
* @param queryRequest
* @param sseEmitter
* @param uid
* @return
* @throws IOException
*/
private SseEmitter chatWithDifyChat(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException {
String prompt = buildPrompt(queryRequest);
if (prompt.length() / TOKEN_CONVERT_CHAR_LENGTH > MAX_PROMPT_LENGTH) {
log.error("提示语超出最大长度:{},输入长度:{}, 请重新输入", MAX_PROMPT_LENGTH,
prompt.length() / TOKEN_CONVERT_CHAR_LENGTH);
throw new ParamBusinessException();
}
List<AzureChatMessage> messages = (List<AzureChatMessage>)LocalCache.CACHE.get(uid);
if (CollectionUtils.isNotEmpty(messages)) {
if (messages.size() >= contextLength) {
messages = messages.subList(1, contextLength);
}
} else {
messages = Lists.newArrayList();
}
AzureChatMessage currentMessage = new AzureChatMessage(AzureChatRole.USER).setContent(prompt);
messages.add(currentMessage);
buildSseEmitter(sseEmitter, uid);
AzureOpenAIEventSourceListener sourceListener = new AzureOpenAIEventSourceListener(sseEmitter);
AzureOpenAIClient.getInstance().streamCompletions(messages, sourceListener);
LocalCache.CACHE.put(uid, messages, LocalCache.TIMEOUT);
return sseEmitter;
}
/**
* chat with fast chat openai

View File

@ -1,7 +1,12 @@
package ai.chat2db.server.web.api.controller.ai.dify.client;
import ai.chat2db.server.web.api.controller.ai.azure.client.AzureOpenAiStreamClient;
import ai.chat2db.server.web.api.controller.ai.dify.listener.DifyChatAIEventSourceListener;
import com.unfbx.chatgpt.entity.chat.Message;
import lombok.extern.slf4j.Slf4j;
import java.util.List;
@Slf4j
public class DifyChatAIClient {
@ -19,4 +24,12 @@ public class DifyChatAIClient {
public static void refresh() {
}
public static DifyChatAIClient getInstance() {
return null;
}
public void streamCompletions(List<Message> messages, DifyChatAIEventSourceListener eventSourceListener) {
}
}

View File

@ -0,0 +1,4 @@
package ai.chat2db.server.web.api.controller.ai.dify.client;
public class DifyChatAiStreamClient {
}

View File

@ -0,0 +1,39 @@
package ai.chat2db.server.web.api.controller.ai.dify.listener;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Response;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
@Slf4j
public class DifyChatAIEventSourceListener extends EventSourceListener {
private SseEmitter sseEmitter;
public DifyChatAIEventSourceListener(SseEmitter sseEmitter) {
this.sseEmitter = sseEmitter;
}
@Override
public void onClosed(@NotNull EventSource eventSource) {
super.onClosed(eventSource);
}
@Override
public void onEvent(@NotNull EventSource eventSource, @Nullable String id, @Nullable String type, @NotNull String data) {
super.onEvent(eventSource, id, type, data);
}
@Override
public void onFailure(@NotNull EventSource eventSource, @Nullable Throwable t, @Nullable Response response) {
super.onFailure(eventSource, t, response);
}
@Override
public void onOpen(@NotNull EventSource eventSource, @NotNull Response response) {
super.onOpen(eventSource, response);
}
}