mirror of
https://github.com/CodePhiliaX/Chat2DB.git
synced 2025-07-30 11:12:55 +08:00
feature: add dify-chat-ai controller type
This commit is contained in:
@ -27,6 +27,8 @@ import ai.chat2db.server.web.api.controller.ai.claude.model.ClaudeChatCompletion
|
||||
import ai.chat2db.server.web.api.controller.ai.claude.model.ClaudeChatMessage;
|
||||
import ai.chat2db.server.web.api.controller.ai.config.LocalCache;
|
||||
import ai.chat2db.server.web.api.controller.ai.converter.ChatConverter;
|
||||
import ai.chat2db.server.web.api.controller.ai.dify.client.DifyChatAIClient;
|
||||
import ai.chat2db.server.web.api.controller.ai.dify.listener.DifyChatAIEventSourceListener;
|
||||
import ai.chat2db.server.web.api.controller.ai.enums.PromptType;
|
||||
import ai.chat2db.server.web.api.controller.ai.fastchat.client.FastChatAIClient;
|
||||
import ai.chat2db.server.web.api.controller.ai.fastchat.embeddings.FastChatEmbeddingResponse;
|
||||
@ -132,7 +134,7 @@ public class ChatController {
|
||||
/**
|
||||
* 自定义模型流式输出接口DEMO
|
||||
* <p>
|
||||
* Note:使用自己本地的流式输出的自定义AI,接口输入和输出需与该样例保持一致
|
||||
* Note:使用自己本地的流式输出的自定义AI,接口输入和输出需与该样例保持一致
|
||||
* </p>
|
||||
*
|
||||
* @param queryRequest
|
||||
@ -171,7 +173,7 @@ public class ChatController {
|
||||
/**
|
||||
* 自定义模型非流式输出接口DEMO
|
||||
* <p>
|
||||
* Note:使用自己本地的飞流式输出自定义AI,接口输入和输出需与该样例保持一致
|
||||
* Note:使用自己本地的飞流式输出自定义AI,接口输入和输出需与该样例保持一致
|
||||
* </p>
|
||||
*
|
||||
* @param queryRequest
|
||||
@ -196,7 +198,7 @@ public class ChatController {
|
||||
@GetMapping("/chat")
|
||||
@CrossOrigin
|
||||
public SseEmitter completions(ChatQueryRequest queryRequest, @RequestHeader Map<String, String> headers)
|
||||
throws IOException {
|
||||
throws IOException {
|
||||
//默认30秒超时,设置为0L则永不超时
|
||||
SseEmitter sseEmitter = new SseEmitter(CHAT_TIMEOUT);
|
||||
String uid = headers.get("uid");
|
||||
@ -230,14 +232,14 @@ public class ChatController {
|
||||
}
|
||||
uid = aiSqlSourceEnum.getCode() + uid;
|
||||
switch (Objects.requireNonNull(aiSqlSourceEnum)) {
|
||||
case OPENAI :
|
||||
case OPENAI:
|
||||
return chatWithOpenAi(queryRequest, sseEmitter, uid);
|
||||
case CHAT2DBAI:
|
||||
return chatWithChat2dbAi(queryRequest, sseEmitter, uid);
|
||||
case RESTAI :
|
||||
case RESTAI:
|
||||
case FASTCHATAI:
|
||||
return chatWithFastChatAi(queryRequest, sseEmitter, uid);
|
||||
case AZUREAI :
|
||||
case AZUREAI:
|
||||
return chatWithAzureAi(queryRequest, sseEmitter, uid);
|
||||
case CLAUDEAI:
|
||||
return chatWithClaudeAi(queryRequest, sseEmitter, uid);
|
||||
@ -249,10 +251,34 @@ public class ChatController {
|
||||
return chatWithTongyiChatAi(queryRequest, sseEmitter, uid);
|
||||
case ZHIPUAI:
|
||||
return chatWithZhipuChatAi(queryRequest, sseEmitter, uid);
|
||||
case DIFYCHAT:
|
||||
return chatWithDifyChatAi(queryRequest, sseEmitter, uid);
|
||||
}
|
||||
return chatWithOpenAi(queryRequest, sseEmitter, uid);
|
||||
}
|
||||
|
||||
private SseEmitter chatWithDifyChatAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException {
|
||||
String prompt = buildPrompt(queryRequest);
|
||||
if (prompt.length() / TOKEN_CONVERT_CHAR_LENGTH > MAX_PROMPT_LENGTH) {
|
||||
log.error("exceed max token length:{},input length:{}", MAX_PROMPT_LENGTH,
|
||||
prompt.length() / TOKEN_CONVERT_CHAR_LENGTH);
|
||||
throw new ParamBusinessException();
|
||||
}
|
||||
|
||||
prompt = prompt.replaceAll("#", "");
|
||||
log.info(prompt);
|
||||
|
||||
Message currentMessage = Message.builder().content(prompt).role(Message.Role.USER).build();
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(currentMessage);
|
||||
buildSseEmitter(sseEmitter, uid);
|
||||
|
||||
DifyChatAIEventSourceListener eventSourceListener = new DifyChatAIEventSourceListener(sseEmitter);
|
||||
DifyChatAIClient.getInstance().streamCompletions(messages, eventSourceListener);
|
||||
LocalCache.CACHE.put(uid, JSONUtil.toJsonStr(messages), LocalCache.TIMEOUT);
|
||||
return sseEmitter;
|
||||
}
|
||||
|
||||
/**
|
||||
* 使用自定义AI接口进行聊天
|
||||
*
|
||||
@ -276,11 +302,11 @@ public class ChatController {
|
||||
* @throws IOException
|
||||
*/
|
||||
private SseEmitter chatWithOpenAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid)
|
||||
throws IOException {
|
||||
throws IOException {
|
||||
String prompt = buildPrompt(queryRequest);
|
||||
if (prompt.length() / TOKEN_CONVERT_CHAR_LENGTH > MAX_PROMPT_LENGTH) {
|
||||
log.error("提示语超出最大长度:{},输入长度:{}, 请重新输入", MAX_PROMPT_LENGTH,
|
||||
prompt.length() / TOKEN_CONVERT_CHAR_LENGTH);
|
||||
prompt.length() / TOKEN_CONVERT_CHAR_LENGTH);
|
||||
throw new ParamBusinessException();
|
||||
}
|
||||
|
||||
@ -307,11 +333,11 @@ public class ChatController {
|
||||
* @throws IOException
|
||||
*/
|
||||
private SseEmitter chatWithChat2dbAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid)
|
||||
throws IOException {
|
||||
throws IOException {
|
||||
String prompt = buildPrompt(queryRequest);
|
||||
if (prompt.length() / TOKEN_CONVERT_CHAR_LENGTH > MAX_PROMPT_LENGTH) {
|
||||
log.error("exceed max token length:{},input length:{}", MAX_PROMPT_LENGTH,
|
||||
prompt.length() / TOKEN_CONVERT_CHAR_LENGTH);
|
||||
prompt.length() / TOKEN_CONVERT_CHAR_LENGTH);
|
||||
throw new ParamBusinessException();
|
||||
}
|
||||
|
||||
@ -344,7 +370,7 @@ public class ChatController {
|
||||
prompt.length() / TOKEN_CONVERT_CHAR_LENGTH);
|
||||
throw new ParamBusinessException();
|
||||
}
|
||||
List<AzureChatMessage> messages = (List<AzureChatMessage>)LocalCache.CACHE.get(uid);
|
||||
List<AzureChatMessage> messages = (List<AzureChatMessage>) LocalCache.CACHE.get(uid);
|
||||
if (CollectionUtils.isNotEmpty(messages)) {
|
||||
if (messages.size() >= contextLength) {
|
||||
messages = messages.subList(1, contextLength);
|
||||
@ -363,40 +389,6 @@ public class ChatController {
|
||||
return sseEmitter;
|
||||
}
|
||||
|
||||
/**
|
||||
* chat with azure openai
|
||||
*
|
||||
* @param queryRequest
|
||||
* @param sseEmitter
|
||||
* @param uid
|
||||
* @return
|
||||
* @throws IOException
|
||||
*/
|
||||
private SseEmitter chatWithDifyChat(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException {
|
||||
String prompt = buildPrompt(queryRequest);
|
||||
if (prompt.length() / TOKEN_CONVERT_CHAR_LENGTH > MAX_PROMPT_LENGTH) {
|
||||
log.error("提示语超出最大长度:{},输入长度:{}, 请重新输入", MAX_PROMPT_LENGTH,
|
||||
prompt.length() / TOKEN_CONVERT_CHAR_LENGTH);
|
||||
throw new ParamBusinessException();
|
||||
}
|
||||
List<AzureChatMessage> messages = (List<AzureChatMessage>)LocalCache.CACHE.get(uid);
|
||||
if (CollectionUtils.isNotEmpty(messages)) {
|
||||
if (messages.size() >= contextLength) {
|
||||
messages = messages.subList(1, contextLength);
|
||||
}
|
||||
} else {
|
||||
messages = Lists.newArrayList();
|
||||
}
|
||||
AzureChatMessage currentMessage = new AzureChatMessage(AzureChatRole.USER).setContent(prompt);
|
||||
messages.add(currentMessage);
|
||||
|
||||
buildSseEmitter(sseEmitter, uid);
|
||||
|
||||
AzureOpenAIEventSourceListener sourceListener = new AzureOpenAIEventSourceListener(sseEmitter);
|
||||
AzureOpenAIClient.getInstance().streamCompletions(messages, sourceListener);
|
||||
LocalCache.CACHE.put(uid, messages, LocalCache.TIMEOUT);
|
||||
return sseEmitter;
|
||||
}
|
||||
|
||||
/**
|
||||
* chat with fast chat openai
|
||||
@ -490,7 +482,7 @@ public class ChatController {
|
||||
* @return
|
||||
*/
|
||||
private List<FastChatMessage> getFastChatMessage(String uid, String prompt) {
|
||||
List<FastChatMessage> messages = (List<FastChatMessage>)LocalCache.CACHE.get(uid);
|
||||
List<FastChatMessage> messages = (List<FastChatMessage>) LocalCache.CACHE.get(uid);
|
||||
if (CollectionUtils.isNotEmpty(messages)) {
|
||||
if (messages.size() >= contextLength) {
|
||||
messages = messages.subList(1, contextLength);
|
||||
@ -566,17 +558,17 @@ public class ChatController {
|
||||
log.info(LocalDateTime.now() + ", uid#" + uid + ", on completion");
|
||||
});
|
||||
sseEmitter.onTimeout(
|
||||
() -> log.info(LocalDateTime.now() + ", uid#" + uid + ", on timeout#" + sseEmitter.getTimeout()));
|
||||
() -> log.info(LocalDateTime.now() + ", uid#" + uid + ", on timeout#" + sseEmitter.getTimeout()));
|
||||
sseEmitter.onError(
|
||||
throwable -> {
|
||||
try {
|
||||
log.info(LocalDateTime.now() + ", uid#" + "765431" + ", on error#" + throwable.toString());
|
||||
sseEmitter.send(SseEmitter.event().id("765431").name("发生异常!").data(throwable.getMessage())
|
||||
.reconnectTime(3000));
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
throwable -> {
|
||||
try {
|
||||
log.info(LocalDateTime.now() + ", uid#" + "765431" + ", on error#" + throwable.toString());
|
||||
sseEmitter.send(SseEmitter.event().id("765431").name("发生异常!").data(throwable.getMessage())
|
||||
.reconnectTime(3000));
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
}
|
||||
);
|
||||
return sseEmitter;
|
||||
}
|
||||
@ -589,13 +581,13 @@ public class ChatController {
|
||||
* @return
|
||||
*/
|
||||
private String buildTableColumn(TableQueryParam tableQueryParam,
|
||||
List<String> tableNames) {
|
||||
List<String> tableNames) {
|
||||
if (CollectionUtils.isEmpty(tableNames)) {
|
||||
return "";
|
||||
}
|
||||
List<String> schemaContent = Lists.newArrayList();
|
||||
try {
|
||||
schemaContent = tableNames.stream().map(tableName -> {
|
||||
schemaContent = tableNames.stream().map(tableName -> {
|
||||
tableQueryParam.setTableName(tableName);
|
||||
return queryTableDdl(tableName, tableQueryParam);
|
||||
}).collect(Collectors.toList());
|
||||
@ -645,19 +637,19 @@ public class ChatController {
|
||||
}
|
||||
String prompt = queryRequest.getMessage();
|
||||
String promptType = StringUtils.isBlank(queryRequest.getPromptType()) ? PromptType.NL_2_SQL.getCode()
|
||||
: queryRequest.getPromptType();
|
||||
: queryRequest.getPromptType();
|
||||
PromptType pType = EasyEnumUtils.getEnum(PromptType.class, promptType);
|
||||
String ext = StringUtils.isNotBlank(queryRequest.getExt()) ? queryRequest.getExt() : "";
|
||||
String schemaProperty = StringUtils.isNotEmpty(properties) ? String.format(
|
||||
"### 请根据以下table properties和SQL input%s. %s\n#\n### %s SQL tables, with their properties:\n#\n# "
|
||||
+ "%s\n#\n#\n### SQL input: %s", pType.getDescription(), ext, dataSourceType,
|
||||
properties, prompt) : String.format("### 请根据以下SQL input%s. %s\n#\n### SQL input: %s",
|
||||
pType.getDescription(), ext, prompt);
|
||||
"### 请根据以下table properties和SQL input%s. %s\n#\n### %s SQL tables, with their properties:\n#\n# "
|
||||
+ "%s\n#\n#\n### SQL input: %s", pType.getDescription(), ext, dataSourceType,
|
||||
properties, prompt) : String.format("### 请根据以下SQL input%s. %s\n#\n### SQL input: %s",
|
||||
pType.getDescription(), ext, prompt);
|
||||
switch (pType) {
|
||||
case SQL_2_SQL:
|
||||
schemaProperty = StringUtils.isNotBlank(queryRequest.getDestSqlType()) ? String.format(
|
||||
"%s\n#\n### 目标SQL类型: %s", schemaProperty, queryRequest.getDestSqlType()) : String.format(
|
||||
"%s\n#\n### 目标SQL类型: %s", schemaProperty, dataSourceType);
|
||||
"%s\n#\n### 目标SQL类型: %s", schemaProperty, queryRequest.getDestSqlType()) : String.format(
|
||||
"%s\n#\n### 目标SQL类型: %s", schemaProperty, dataSourceType);
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@ -746,7 +738,7 @@ public class ChatController {
|
||||
DataResult<TableSchemaResponse> result = gatewayClientService.schemaVectorSearch(tableSchemaRequest);
|
||||
List<String> schemas = Lists.newArrayList();
|
||||
if (Objects.nonNull(result.getData()) && CollectionUtils.isNotEmpty(result.getData().getTableSchemas())) {
|
||||
for(TableSchema data: result.getData().getTableSchemas()){
|
||||
for (TableSchema data : result.getData().getTableSchemas()) {
|
||||
schemas.add(data.getTableSchema());
|
||||
}
|
||||
}
|
||||
@ -786,7 +778,7 @@ public class ChatController {
|
||||
DataResult<EsTableSchemaResponse> result = gatewayClientService.schemaEsSearch(tableSchemaRequest);
|
||||
List<String> schemas = Lists.newArrayList();
|
||||
if (Objects.nonNull(result.getData()) && CollectionUtils.isNotEmpty(result.getData().getTableSchemas())) {
|
||||
for(EsTableSchema data: result.getData().getTableSchemas()){
|
||||
for (EsTableSchema data : result.getData().getTableSchemas()) {
|
||||
schemas.add(data.getTableSchemaContent());
|
||||
}
|
||||
}
|
||||
|
@ -1,7 +1,12 @@
|
||||
package ai.chat2db.server.web.api.controller.ai.dify.client;
|
||||
|
||||
import ai.chat2db.server.web.api.controller.ai.azure.client.AzureOpenAiStreamClient;
|
||||
import ai.chat2db.server.web.api.controller.ai.dify.listener.DifyChatAIEventSourceListener;
|
||||
import com.unfbx.chatgpt.entity.chat.Message;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Slf4j
|
||||
public class DifyChatAIClient {
|
||||
|
||||
@ -19,4 +24,12 @@ public class DifyChatAIClient {
|
||||
public static void refresh() {
|
||||
|
||||
}
|
||||
|
||||
public static DifyChatAIClient getInstance() {
|
||||
return null;
|
||||
}
|
||||
|
||||
public void streamCompletions(List<Message> messages, DifyChatAIEventSourceListener eventSourceListener) {
|
||||
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,4 @@
|
||||
package ai.chat2db.server.web.api.controller.ai.dify.client;
|
||||
|
||||
public class DifyChatAiStreamClient {
|
||||
}
|
@ -0,0 +1,39 @@
|
||||
package ai.chat2db.server.web.api.controller.ai.dify.listener;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import okhttp3.Response;
|
||||
import okhttp3.sse.EventSource;
|
||||
import okhttp3.sse.EventSourceListener;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
@Slf4j
|
||||
public class DifyChatAIEventSourceListener extends EventSourceListener {
|
||||
|
||||
private SseEmitter sseEmitter;
|
||||
|
||||
public DifyChatAIEventSourceListener(SseEmitter sseEmitter) {
|
||||
this.sseEmitter = sseEmitter;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onClosed(@NotNull EventSource eventSource) {
|
||||
super.onClosed(eventSource);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onEvent(@NotNull EventSource eventSource, @Nullable String id, @Nullable String type, @NotNull String data) {
|
||||
super.onEvent(eventSource, id, type, data);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onFailure(@NotNull EventSource eventSource, @Nullable Throwable t, @Nullable Response response) {
|
||||
super.onFailure(eventSource, t, response);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onOpen(@NotNull EventSource eventSource, @NotNull Response response) {
|
||||
super.onOpen(eventSource, response);
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user