feature: add dify-chat-ai controller type

This commit is contained in:
leven.chen
2023-12-22 17:02:42 +08:00
parent 04e100ca70
commit bea9a52f52
4 changed files with 114 additions and 66 deletions

View File

@ -27,6 +27,8 @@ import ai.chat2db.server.web.api.controller.ai.claude.model.ClaudeChatCompletion
import ai.chat2db.server.web.api.controller.ai.claude.model.ClaudeChatMessage;
import ai.chat2db.server.web.api.controller.ai.config.LocalCache;
import ai.chat2db.server.web.api.controller.ai.converter.ChatConverter;
import ai.chat2db.server.web.api.controller.ai.dify.client.DifyChatAIClient;
import ai.chat2db.server.web.api.controller.ai.dify.listener.DifyChatAIEventSourceListener;
import ai.chat2db.server.web.api.controller.ai.enums.PromptType;
import ai.chat2db.server.web.api.controller.ai.fastchat.client.FastChatAIClient;
import ai.chat2db.server.web.api.controller.ai.fastchat.embeddings.FastChatEmbeddingResponse;
@ -132,7 +134,7 @@ public class ChatController {
/**
* 自定义模型流式输出接口DEMO
* <p>
* Note:使用自己本地的流式输出的自定义AI接口输入和输出需与该样例保持一致
* Note:使用自己本地的流式输出的自定义AI接口输入和输出需与该样例保持一致
* </p>
*
* @param queryRequest
@ -171,7 +173,7 @@ public class ChatController {
/**
* 自定义模型非流式输出接口DEMO
* <p>
* Note:使用自己本地的飞流式输出自定义AI接口输入和输出需与该样例保持一致
* Note:使用自己本地的飞流式输出自定义AI接口输入和输出需与该样例保持一致
* </p>
*
* @param queryRequest
@ -196,7 +198,7 @@ public class ChatController {
@GetMapping("/chat")
@CrossOrigin
public SseEmitter completions(ChatQueryRequest queryRequest, @RequestHeader Map<String, String> headers)
throws IOException {
throws IOException {
//默认30秒超时,设置为0L则永不超时
SseEmitter sseEmitter = new SseEmitter(CHAT_TIMEOUT);
String uid = headers.get("uid");
@ -230,14 +232,14 @@ public class ChatController {
}
uid = aiSqlSourceEnum.getCode() + uid;
switch (Objects.requireNonNull(aiSqlSourceEnum)) {
case OPENAI :
case OPENAI:
return chatWithOpenAi(queryRequest, sseEmitter, uid);
case CHAT2DBAI:
return chatWithChat2dbAi(queryRequest, sseEmitter, uid);
case RESTAI :
case RESTAI:
case FASTCHATAI:
return chatWithFastChatAi(queryRequest, sseEmitter, uid);
case AZUREAI :
case AZUREAI:
return chatWithAzureAi(queryRequest, sseEmitter, uid);
case CLAUDEAI:
return chatWithClaudeAi(queryRequest, sseEmitter, uid);
@ -249,10 +251,34 @@ public class ChatController {
return chatWithTongyiChatAi(queryRequest, sseEmitter, uid);
case ZHIPUAI:
return chatWithZhipuChatAi(queryRequest, sseEmitter, uid);
case DIFYCHAT:
return chatWithDifyChatAi(queryRequest, sseEmitter, uid);
}
return chatWithOpenAi(queryRequest, sseEmitter, uid);
}
private SseEmitter chatWithDifyChatAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException {
String prompt = buildPrompt(queryRequest);
if (prompt.length() / TOKEN_CONVERT_CHAR_LENGTH > MAX_PROMPT_LENGTH) {
log.error("exceed max token length:{}input length:{}", MAX_PROMPT_LENGTH,
prompt.length() / TOKEN_CONVERT_CHAR_LENGTH);
throw new ParamBusinessException();
}
prompt = prompt.replaceAll("#", "");
log.info(prompt);
Message currentMessage = Message.builder().content(prompt).role(Message.Role.USER).build();
List<Message> messages = new ArrayList<>();
messages.add(currentMessage);
buildSseEmitter(sseEmitter, uid);
DifyChatAIEventSourceListener eventSourceListener = new DifyChatAIEventSourceListener(sseEmitter);
DifyChatAIClient.getInstance().streamCompletions(messages, eventSourceListener);
LocalCache.CACHE.put(uid, JSONUtil.toJsonStr(messages), LocalCache.TIMEOUT);
return sseEmitter;
}
/**
* 使用自定义AI接口进行聊天
*
@ -276,11 +302,11 @@ public class ChatController {
* @throws IOException
*/
private SseEmitter chatWithOpenAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid)
throws IOException {
throws IOException {
String prompt = buildPrompt(queryRequest);
if (prompt.length() / TOKEN_CONVERT_CHAR_LENGTH > MAX_PROMPT_LENGTH) {
log.error("提示语超出最大长度:{},输入长度:{}, 请重新输入", MAX_PROMPT_LENGTH,
prompt.length() / TOKEN_CONVERT_CHAR_LENGTH);
prompt.length() / TOKEN_CONVERT_CHAR_LENGTH);
throw new ParamBusinessException();
}
@ -307,11 +333,11 @@ public class ChatController {
* @throws IOException
*/
private SseEmitter chatWithChat2dbAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid)
throws IOException {
throws IOException {
String prompt = buildPrompt(queryRequest);
if (prompt.length() / TOKEN_CONVERT_CHAR_LENGTH > MAX_PROMPT_LENGTH) {
log.error("exceed max token length:{}input length:{}", MAX_PROMPT_LENGTH,
prompt.length() / TOKEN_CONVERT_CHAR_LENGTH);
prompt.length() / TOKEN_CONVERT_CHAR_LENGTH);
throw new ParamBusinessException();
}
@ -344,7 +370,7 @@ public class ChatController {
prompt.length() / TOKEN_CONVERT_CHAR_LENGTH);
throw new ParamBusinessException();
}
List<AzureChatMessage> messages = (List<AzureChatMessage>)LocalCache.CACHE.get(uid);
List<AzureChatMessage> messages = (List<AzureChatMessage>) LocalCache.CACHE.get(uid);
if (CollectionUtils.isNotEmpty(messages)) {
if (messages.size() >= contextLength) {
messages = messages.subList(1, contextLength);
@ -363,40 +389,6 @@ public class ChatController {
return sseEmitter;
}
/**
* chat with azure openai
*
* @param queryRequest
* @param sseEmitter
* @param uid
* @return
* @throws IOException
*/
private SseEmitter chatWithDifyChat(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException {
String prompt = buildPrompt(queryRequest);
if (prompt.length() / TOKEN_CONVERT_CHAR_LENGTH > MAX_PROMPT_LENGTH) {
log.error("提示语超出最大长度:{},输入长度:{}, 请重新输入", MAX_PROMPT_LENGTH,
prompt.length() / TOKEN_CONVERT_CHAR_LENGTH);
throw new ParamBusinessException();
}
List<AzureChatMessage> messages = (List<AzureChatMessage>)LocalCache.CACHE.get(uid);
if (CollectionUtils.isNotEmpty(messages)) {
if (messages.size() >= contextLength) {
messages = messages.subList(1, contextLength);
}
} else {
messages = Lists.newArrayList();
}
AzureChatMessage currentMessage = new AzureChatMessage(AzureChatRole.USER).setContent(prompt);
messages.add(currentMessage);
buildSseEmitter(sseEmitter, uid);
AzureOpenAIEventSourceListener sourceListener = new AzureOpenAIEventSourceListener(sseEmitter);
AzureOpenAIClient.getInstance().streamCompletions(messages, sourceListener);
LocalCache.CACHE.put(uid, messages, LocalCache.TIMEOUT);
return sseEmitter;
}
/**
* chat with fast chat openai
@ -490,7 +482,7 @@ public class ChatController {
* @return
*/
private List<FastChatMessage> getFastChatMessage(String uid, String prompt) {
List<FastChatMessage> messages = (List<FastChatMessage>)LocalCache.CACHE.get(uid);
List<FastChatMessage> messages = (List<FastChatMessage>) LocalCache.CACHE.get(uid);
if (CollectionUtils.isNotEmpty(messages)) {
if (messages.size() >= contextLength) {
messages = messages.subList(1, contextLength);
@ -566,17 +558,17 @@ public class ChatController {
log.info(LocalDateTime.now() + ", uid#" + uid + ", on completion");
});
sseEmitter.onTimeout(
() -> log.info(LocalDateTime.now() + ", uid#" + uid + ", on timeout#" + sseEmitter.getTimeout()));
() -> log.info(LocalDateTime.now() + ", uid#" + uid + ", on timeout#" + sseEmitter.getTimeout()));
sseEmitter.onError(
throwable -> {
try {
log.info(LocalDateTime.now() + ", uid#" + "765431" + ", on error#" + throwable.toString());
sseEmitter.send(SseEmitter.event().id("765431").name("发生异常!").data(throwable.getMessage())
.reconnectTime(3000));
} catch (IOException e) {
e.printStackTrace();
throwable -> {
try {
log.info(LocalDateTime.now() + ", uid#" + "765431" + ", on error#" + throwable.toString());
sseEmitter.send(SseEmitter.event().id("765431").name("发生异常!").data(throwable.getMessage())
.reconnectTime(3000));
} catch (IOException e) {
e.printStackTrace();
}
}
}
);
return sseEmitter;
}
@ -589,13 +581,13 @@ public class ChatController {
* @return
*/
private String buildTableColumn(TableQueryParam tableQueryParam,
List<String> tableNames) {
List<String> tableNames) {
if (CollectionUtils.isEmpty(tableNames)) {
return "";
}
List<String> schemaContent = Lists.newArrayList();
try {
schemaContent = tableNames.stream().map(tableName -> {
schemaContent = tableNames.stream().map(tableName -> {
tableQueryParam.setTableName(tableName);
return queryTableDdl(tableName, tableQueryParam);
}).collect(Collectors.toList());
@ -645,19 +637,19 @@ public class ChatController {
}
String prompt = queryRequest.getMessage();
String promptType = StringUtils.isBlank(queryRequest.getPromptType()) ? PromptType.NL_2_SQL.getCode()
: queryRequest.getPromptType();
: queryRequest.getPromptType();
PromptType pType = EasyEnumUtils.getEnum(PromptType.class, promptType);
String ext = StringUtils.isNotBlank(queryRequest.getExt()) ? queryRequest.getExt() : "";
String schemaProperty = StringUtils.isNotEmpty(properties) ? String.format(
"### 请根据以下table properties和SQL input%s. %s\n#\n### %s SQL tables, with their properties:\n#\n# "
+ "%s\n#\n#\n### SQL input: %s", pType.getDescription(), ext, dataSourceType,
properties, prompt) : String.format("### 请根据以下SQL input%s. %s\n#\n### SQL input: %s",
pType.getDescription(), ext, prompt);
"### 请根据以下table properties和SQL input%s. %s\n#\n### %s SQL tables, with their properties:\n#\n# "
+ "%s\n#\n#\n### SQL input: %s", pType.getDescription(), ext, dataSourceType,
properties, prompt) : String.format("### 请根据以下SQL input%s. %s\n#\n### SQL input: %s",
pType.getDescription(), ext, prompt);
switch (pType) {
case SQL_2_SQL:
schemaProperty = StringUtils.isNotBlank(queryRequest.getDestSqlType()) ? String.format(
"%s\n#\n### 目标SQL类型: %s", schemaProperty, queryRequest.getDestSqlType()) : String.format(
"%s\n#\n### 目标SQL类型: %s", schemaProperty, dataSourceType);
"%s\n#\n### 目标SQL类型: %s", schemaProperty, queryRequest.getDestSqlType()) : String.format(
"%s\n#\n### 目标SQL类型: %s", schemaProperty, dataSourceType);
default:
break;
}
@ -746,7 +738,7 @@ public class ChatController {
DataResult<TableSchemaResponse> result = gatewayClientService.schemaVectorSearch(tableSchemaRequest);
List<String> schemas = Lists.newArrayList();
if (Objects.nonNull(result.getData()) && CollectionUtils.isNotEmpty(result.getData().getTableSchemas())) {
for(TableSchema data: result.getData().getTableSchemas()){
for (TableSchema data : result.getData().getTableSchemas()) {
schemas.add(data.getTableSchema());
}
}
@ -786,7 +778,7 @@ public class ChatController {
DataResult<EsTableSchemaResponse> result = gatewayClientService.schemaEsSearch(tableSchemaRequest);
List<String> schemas = Lists.newArrayList();
if (Objects.nonNull(result.getData()) && CollectionUtils.isNotEmpty(result.getData().getTableSchemas())) {
for(EsTableSchema data: result.getData().getTableSchemas()){
for (EsTableSchema data : result.getData().getTableSchemas()) {
schemas.add(data.getTableSchemaContent());
}
}

View File

@ -1,7 +1,12 @@
package ai.chat2db.server.web.api.controller.ai.dify.client;
import ai.chat2db.server.web.api.controller.ai.azure.client.AzureOpenAiStreamClient;
import ai.chat2db.server.web.api.controller.ai.dify.listener.DifyChatAIEventSourceListener;
import com.unfbx.chatgpt.entity.chat.Message;
import lombok.extern.slf4j.Slf4j;
import java.util.List;
@Slf4j
public class DifyChatAIClient {
@ -19,4 +24,12 @@ public class DifyChatAIClient {
public static void refresh() {
}
public static DifyChatAIClient getInstance() {
return null;
}
public void streamCompletions(List<Message> messages, DifyChatAIEventSourceListener eventSourceListener) {
}
}

View File

@ -0,0 +1,4 @@
package ai.chat2db.server.web.api.controller.ai.dify.client;
public class DifyChatAiStreamClient {
}

View File

@ -0,0 +1,39 @@
package ai.chat2db.server.web.api.controller.ai.dify.listener;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Response;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
@Slf4j
public class DifyChatAIEventSourceListener extends EventSourceListener {
private SseEmitter sseEmitter;
public DifyChatAIEventSourceListener(SseEmitter sseEmitter) {
this.sseEmitter = sseEmitter;
}
@Override
public void onClosed(@NotNull EventSource eventSource) {
super.onClosed(eventSource);
}
@Override
public void onEvent(@NotNull EventSource eventSource, @Nullable String id, @Nullable String type, @NotNull String data) {
super.onEvent(eventSource, id, type, data);
}
@Override
public void onFailure(@NotNull EventSource eventSource, @Nullable Throwable t, @Nullable Response response) {
super.onFailure(eventSource, t, response);
}
@Override
public void onOpen(@NotNull EventSource eventSource, @NotNull Response response) {
super.onOpen(eventSource, response);
}
}