Merge remote-tracking branch 'origin/developing' into developing

This commit is contained in:
JiaJu Zhuang
2023-07-29 19:30:40 +08:00
17 changed files with 474 additions and 360 deletions

View File

@ -8,6 +8,7 @@ import ai.chat2db.server.domain.api.service.ConfigService;
import ai.chat2db.server.tools.base.wrapper.result.DataResult;
import ai.chat2db.server.tools.common.config.Chat2dbProperties;
import ai.chat2db.server.web.api.aspect.ConnectionInfoAspect;
import ai.chat2db.server.web.api.controller.ai.chat2db.client.Chat2dbAIClient;
import ai.chat2db.server.web.api.http.GatewayClientService;
import ai.chat2db.server.web.api.http.response.ApiKeyResponse;
import ai.chat2db.server.web.api.http.response.InviteQrCodeResponse;
@ -64,15 +65,15 @@ public class AiConfigController {
// Representative successfully logged in
if (StringUtils.isNotBlank(qrCodeResponse.getApiKey())) {
SystemConfigParam param = SystemConfigParam.builder()
.code(OpenAIClient.OPENAI_KEY).content(qrCodeResponse.getApiKey())
.code(Chat2dbAIClient.CHAT2DB_OPENAI_KEY).content(qrCodeResponse.getApiKey())
.build();
configService.createOrUpdate(param);
SystemConfigParam hostParam = SystemConfigParam.builder()
.code(OpenAIClient.OPENAI_HOST)
.code(Chat2dbAIClient.CHAT2DB_OPENAI_HOST)
.content(chat2dbProperties.getGateway().getModelBaseUrl() + "/model")
.build();
configService.createOrUpdate(hostParam);
OpenAIClient.refresh();
Chat2dbAIClient.refresh();
}
return dataResult;
}
@ -107,7 +108,7 @@ public class AiConfigController {
}
private String getApiKey() {
DataResult<Config> apiKey = configService.find(OpenAIClient.OPENAI_KEY);
DataResult<Config> apiKey = configService.find(Chat2dbAIClient.CHAT2DB_OPENAI_KEY);
return Objects.nonNull(apiKey.getData()) ? apiKey.getData().getContent() : null;
}
}

View File

@ -23,9 +23,9 @@ import ai.chat2db.server.web.api.aspect.ConnectionInfoAspect;
import ai.chat2db.server.web.api.controller.ai.azure.client.AzureOpenAIClient;
import ai.chat2db.server.web.api.controller.ai.azure.models.AzureChatMessage;
import ai.chat2db.server.web.api.controller.ai.azure.models.AzureChatRole;
import ai.chat2db.server.web.api.controller.ai.chat2db.client.Chat2dbAIClient;
import ai.chat2db.server.web.api.controller.ai.config.LocalCache;
import ai.chat2db.server.web.api.controller.ai.converter.ChatConverter;
import ai.chat2db.server.web.api.controller.ai.enums.GptVersionType;
import ai.chat2db.server.web.api.controller.ai.enums.PromptType;
import ai.chat2db.server.web.api.controller.ai.listener.AzureOpenAIEventSourceListener;
import ai.chat2db.server.web.api.controller.ai.listener.OpenAIEventSourceListener;
@ -41,9 +41,6 @@ import cn.hutool.json.JSONUtil;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.unfbx.chatgpt.entity.chat.Message;
import com.unfbx.chatgpt.entity.completions.Completion;
import com.unfbx.chatgpt.exception.BaseException;
import com.unfbx.chatgpt.exception.CommonError;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
@ -55,7 +52,6 @@ import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestHeader;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
@ -163,27 +159,6 @@ public class ChatController {
return data;
}
/**
* 问答对话模型
*
* @param msg
* @param headers
* @return
* @throws IOException
*/
@GetMapping("/chat1")
@CrossOrigin
public SseEmitter chat(@RequestParam("message") String msg, @RequestHeader Map<String, String> headers)
throws IOException {
//默认30秒超时,设置为0L则永不超时
SseEmitter sseEmitter = new SseEmitter(CHAT_TIMEOUT);
String uid = headers.get("uid");
if (StrUtil.isBlank(uid)) {
throw new BaseException(CommonError.SYS_ERROR);
}
return distributeAI(msg, sseEmitter, uid);
}
/**
* SQL转换模型
*
@ -211,32 +186,6 @@ public class ChatController {
return distributeAISql(queryRequest, sseEmitter, uid);
}
/**
* distribute with different AI
*
* @return
*/
private SseEmitter distributeAI(String msg, SseEmitter sseEmitter, String uid) throws IOException {
ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
Config config = configService.find(RestAIClient.AI_SQL_SOURCE).getData();
String aiSqlSource = AiSqlSourceEnum.CHAT2DBAI.getCode();
if (Objects.nonNull(config)) {
aiSqlSource = config.getContent();
}
AiSqlSourceEnum aiSqlSourceEnum = AiSqlSourceEnum.getByName(aiSqlSource);
if (Objects.isNull(aiSqlSourceEnum)) {
aiSqlSourceEnum = AiSqlSourceEnum.OPENAI;
}
switch (Objects.requireNonNull(aiSqlSourceEnum)) {
case OPENAI :
case CHAT2DBAI:
return chatWithOpenAi(msg, sseEmitter, uid);
case RESTAI :
return chatWithRestAi(msg, sseEmitter);
}
return chatWithOpenAi(msg, sseEmitter, uid);
}
/**
* distribute with different AI
*
@ -253,16 +202,18 @@ public class ChatController {
if (Objects.isNull(aiSqlSourceEnum)) {
aiSqlSourceEnum = AiSqlSourceEnum.OPENAI;
}
uid = aiSqlSourceEnum.getCode() + uid;
switch (Objects.requireNonNull(aiSqlSourceEnum)) {
case OPENAI :
return chatWithOpenAi(queryRequest, sseEmitter, uid);
case CHAT2DBAI:
return chatWithOpenAiSql(queryRequest, sseEmitter, uid);
return chatWithChat2dbAi(queryRequest, sseEmitter, uid);
case RESTAI :
return chatWithRestAi(queryRequest.getMessage(), sseEmitter);
return chatWithRestAi(queryRequest, sseEmitter);
case AZUREAI :
return chatWithAzureAi(queryRequest, sseEmitter, uid);
}
return chatWithOpenAiSql(queryRequest, sseEmitter, uid);
return chatWithOpenAi(queryRequest, sseEmitter, uid);
}
/**
@ -272,9 +223,9 @@ public class ChatController {
* @param sseEmitter
* @return
*/
private SseEmitter chatWithRestAi(String prompt, SseEmitter sseEmitter) {
private SseEmitter chatWithRestAi(ChatQueryRequest prompt, SseEmitter sseEmitter) {
RestAIEventSourceListener eventSourceListener = new RestAIEventSourceListener(sseEmitter);
RestAIClient.getInstance().restCompletions(prompt, eventSourceListener);
RestAIClient.getInstance().restCompletions(buildPrompt(prompt), eventSourceListener);
return sseEmitter;
}
@ -287,7 +238,7 @@ public class ChatController {
* @return
* @throws IOException
*/
private SseEmitter chatWithOpenAiSql(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid)
private SseEmitter chatWithOpenAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid)
throws IOException {
String prompt = buildPrompt(queryRequest);
if (prompt.length() / TOKEN_CONVERT_CHAR_LENGTH > MAX_PROMPT_LENGTH) {
@ -296,48 +247,48 @@ public class ChatController {
throw new ParamBusinessException();
}
GptVersionType modelType = EasyEnumUtils.getEnum(GptVersionType.class, gptVersion);
switch (modelType) {
case GPT3:
return chatGpt3(prompt, sseEmitter, uid);
case GPT35:
List<Message> messages = new ArrayList<>();
prompt = prompt.replaceAll("#", "");
log.info(prompt);
Message currentMessage = Message.builder().content(prompt).role(Message.Role.USER).build();
messages.add(currentMessage);
return chatGpt35(messages, sseEmitter, uid);
default:
break;
}
return chatGpt3(prompt, sseEmitter, uid);
List<Message> messages = new ArrayList<>();
prompt = prompt.replaceAll("#", "");
log.info(prompt);
Message currentMessage = Message.builder().content(prompt).role(Message.Role.USER).build();
messages.add(currentMessage);
buildSseEmitter(sseEmitter, uid);
OpenAIEventSourceListener openAIEventSourceListener = new OpenAIEventSourceListener(sseEmitter);
OpenAIClient.getInstance().streamChatCompletion(messages, openAIEventSourceListener);
LocalCache.CACHE.put(uid, JSONUtil.toJsonStr(messages), LocalCache.TIMEOUT);
return sseEmitter;
}
/**
* 使用OPENAI聊天相关接口
* 使用OPENAI SQL接口
*
* @param msg
* @param queryRequest
* @param sseEmitter
* @param uid
* @return
* @throws IOException
*/
private SseEmitter chatWithOpenAi(String msg, SseEmitter sseEmitter, String uid) throws IOException {
String messageContext = (String)LocalCache.CACHE.get(uid);
List<Message> messages = new ArrayList<>();
if (StrUtil.isNotBlank(messageContext)) {
messages = JSONUtil.toList(messageContext, Message.class);
if (messages.size() >= contextLength) {
messages = messages.subList(1, contextLength);
}
Message currentMessage = Message.builder().content(msg).role(Message.Role.USER).build();
messages.add(currentMessage);
} else {
Message currentMessage = Message.builder().content(msg).role(Message.Role.USER).build();
messages.add(currentMessage);
private SseEmitter chatWithChat2dbAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid)
throws IOException {
String prompt = buildPrompt(queryRequest);
if (prompt.length() / TOKEN_CONVERT_CHAR_LENGTH > MAX_PROMPT_LENGTH) {
log.error("exceed max token length:{}input length:{}", MAX_PROMPT_LENGTH,
prompt.length() / TOKEN_CONVERT_CHAR_LENGTH);
throw new ParamBusinessException();
}
return chatGpt35(messages, sseEmitter, uid);
prompt = prompt.replaceAll("#", "");
log.info(prompt);
Message currentMessage = Message.builder().content(prompt).role(Message.Role.USER).build();
List<Message> messages = new ArrayList<>();
messages.add(currentMessage);
buildSseEmitter(sseEmitter, uid);
OpenAIEventSourceListener openAIEventSourceListener = new OpenAIEventSourceListener(sseEmitter);
Chat2dbAIClient.getInstance().streamChatCompletion(messages, openAIEventSourceListener);
LocalCache.CACHE.put(uid, JSONUtil.toJsonStr(messages), LocalCache.TIMEOUT);
return sseEmitter;
}
/**
@ -367,24 +318,8 @@ public class ChatController {
AzureChatMessage currentMessage = new AzureChatMessage(AzureChatRole.USER).setContent(prompt);
messages.add(currentMessage);
sseEmitter.send(SseEmitter.event().id(uid).name("sseEmitter connected").data(LocalDateTime.now()).reconnectTime(3000));
sseEmitter.onCompletion(() -> {
log.info(LocalDateTime.now() + ", uid#" + uid + ", sseEmitter on completion");
SseEmitter.event().id("[DONE]").data("[DONE]");
});
sseEmitter.onTimeout(
() -> log.info(LocalDateTime.now() + ", uid#" + uid + ", sseEmitter on timeout#" + sseEmitter.getTimeout()));
sseEmitter.onError(
throwable -> {
try {
log.info(LocalDateTime.now() + ", uid#" + "765431" + ", sseEmitter on error#" + throwable.toString());
sseEmitter.send(SseEmitter.event().id("765431").name("exception occurs").data(throwable.getMessage())
.reconnectTime(3000));
} catch (IOException e) {
e.printStackTrace();
}
}
);
buildSseEmitter(sseEmitter, uid);
AzureOpenAIEventSourceListener sourceListener = new AzureOpenAIEventSourceListener(sseEmitter);
AzureOpenAIClient.getInstance().streamCompletions(messages, sourceListener);
LocalCache.CACHE.put(uid, messages, LocalCache.TIMEOUT);
@ -392,15 +327,15 @@ public class ChatController {
}
/**
* 使用GPT3.5模型
* construct sseEmitter
*
* @param messages
* @param sseEmitter
* @param uid
* @return
* @throws IOException
*/
private SseEmitter chatGpt35(List<Message> messages, SseEmitter sseEmitter, String uid) throws IOException {
sseEmitter.send(SseEmitter.event().id(uid).name("连接成功").data(LocalDateTime.now()).reconnectTime(3000));
private SseEmitter buildSseEmitter(SseEmitter sseEmitter, String uid) throws IOException {
sseEmitter.send(SseEmitter.event().id(uid).name("connect successfully").data(LocalDateTime.now()).reconnectTime(3000));
sseEmitter.onCompletion(() -> {
log.info(LocalDateTime.now() + ", uid#" + uid + ", on completion");
});
@ -417,46 +352,6 @@ public class ChatController {
}
}
);
OpenAIEventSourceListener openAIEventSourceListener = new OpenAIEventSourceListener(sseEmitter);
OpenAIClient.getInstance().streamChatCompletion(messages, openAIEventSourceListener);
LocalCache.CACHE.put(uid, JSONUtil.toJsonStr(messages), LocalCache.TIMEOUT);
return sseEmitter;
}
/**
* 使用GPT3.0模型
*
* @param prompt
* @param sseEmitter
* @param uid
* @return
*/
private SseEmitter chatGpt3(String prompt, SseEmitter sseEmitter, String uid) throws IOException {
sseEmitter.send(SseEmitter.event().id(uid).name("chatGpt3连接成功").data(LocalDateTime.now())
.reconnectTime(3000));
sseEmitter.onCompletion(() -> {
log.info(LocalDateTime.now() + ", uid#" + uid + ", on completion");
});
sseEmitter.onTimeout(
() -> log.info(LocalDateTime.now() + ", uid#" + uid + ", chatGpt3 on timeout#" + sseEmitter.getTimeout()));
sseEmitter.onError(
throwable -> {
try {
log.info(LocalDateTime.now() + ", uid#" + "765431" + ", chatGpt3 on error#" + throwable.toString());
sseEmitter.send(SseEmitter.event().id("765431").name("chatGpt3 发生异常!")
.data(throwable.getMessage())
.reconnectTime(3000));
} catch (IOException e) {
e.printStackTrace();
}
}
);
// 获取返回结果
OpenAIEventSourceListener openAIEventSourceListener = new OpenAIEventSourceListener(sseEmitter);
Completion completion = Completion.builder().maxTokens(RETURN_TOKEN_LENGTH).stream(true).stop(
Lists.newArrayList("#", ";")).user(uid).prompt(prompt).build();
OpenAIClient.getInstance().streamCompletions(completion, openAIEventSourceListener);
return sseEmitter;
}
@ -472,7 +367,12 @@ public class ChatController {
if (CollectionUtils.isEmpty(tableNames)) {
return Maps.newHashMap();
}
List<TableColumn> tableColumns = tableService.queryColumns(tableQueryParam);
List<TableColumn> tableColumns = Lists.newArrayList();
try {
tableColumns = tableService.queryColumns(tableQueryParam);
} catch (Exception exception) {
log.error("query table error, do nothing");
}
if (CollectionUtils.isEmpty(tableColumns)) {
return Maps.newHashMap();
}
@ -520,4 +420,116 @@ public class ChatController {
}
return schemaProperty;
}
///**
// * 问答对话模型
// *
// * @param msg
// * @param headers
// * @return
// * @throws IOException
// */
//@GetMapping("/chat1")
//@CrossOrigin
//public SseEmitter chat(@RequestParam("message") String msg, @RequestHeader Map<String, String> headers)
// throws IOException {
// //默认30秒超时,设置为0L则永不超时
// SseEmitter sseEmitter = new SseEmitter(CHAT_TIMEOUT);
// String uid = headers.get("uid");
// if (StrUtil.isBlank(uid)) {
// throw new BaseException(CommonError.SYS_ERROR);
// }
// return distributeAI(msg, sseEmitter, uid);
//}
///**
// * distribute with different AI
// *
// * @return
// */
//private SseEmitter distributeAI(String msg, SseEmitter sseEmitter, String uid) throws IOException {
// ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
// Config config = configService.find(RestAIClient.AI_SQL_SOURCE).getData();
// String aiSqlSource = AiSqlSourceEnum.CHAT2DBAI.getCode();
// if (Objects.nonNull(config)) {
// aiSqlSource = config.getContent();
// }
// AiSqlSourceEnum aiSqlSourceEnum = AiSqlSourceEnum.getByName(aiSqlSource);
// if (Objects.isNull(aiSqlSourceEnum)) {
// aiSqlSourceEnum = AiSqlSourceEnum.OPENAI;
// }
// switch (Objects.requireNonNull(aiSqlSourceEnum)) {
// case OPENAI :
// return chatWithOpenAi(msg, sseEmitter, uid);
// case CHAT2DBAI:
// return chatWithOpenAi(msg, sseEmitter, uid);
// case RESTAI :
// return chatWithRestAi(msg, sseEmitter);
// }
// return chatWithOpenAi(msg, sseEmitter, uid);
//}
///**
// * 使用OPENAI聊天相关接口
// *
// * @param msg
// * @param sseEmitter
// * @param uid
// * @return
// * @throws IOException
// */
//private SseEmitter chatWithOpenAi(String msg, SseEmitter sseEmitter, String uid) throws IOException {
// String messageContext = (String)LocalCache.CACHE.get(uid);
// List<Message> messages = new ArrayList<>();
// if (StrUtil.isNotBlank(messageContext)) {
// messages = JSONUtil.toList(messageContext, Message.class);
// if (messages.size() >= contextLength) {
// messages = messages.subList(1, contextLength);
// }
// Message currentMessage = Message.builder().content(msg).role(Message.Role.USER).build();
// messages.add(currentMessage);
// } else {
// Message currentMessage = Message.builder().content(msg).role(Message.Role.USER).build();
// messages.add(currentMessage);
// }
//
// return chatGpt35(messages, sseEmitter, uid);
//}
///**
// * 使用GPT3.0模型
// *
// * @param prompt
// * @param sseEmitter
// * @param uid
// * @return
// */
//private SseEmitter chatGpt3(String prompt, SseEmitter sseEmitter, String uid) throws IOException {
// sseEmitter.send(SseEmitter.event().id(uid).name("chatGpt3连接成功").data(LocalDateTime.now())
// .reconnectTime(3000));
// sseEmitter.onCompletion(() -> {
// log.info(LocalDateTime.now() + ", uid#" + uid + ", on completion");
// });
// sseEmitter.onTimeout(
// () -> log.info(LocalDateTime.now() + ", uid#" + uid + ", chatGpt3 on timeout#" + sseEmitter.getTimeout()));
// sseEmitter.onError(
// throwable -> {
// try {
// log.info(LocalDateTime.now() + ", uid#" + "765431" + ", chatGpt3 on error#" + throwable.toString());
// sseEmitter.send(SseEmitter.event().id("765431").name("chatGpt3 发生异常!")
// .data(throwable.getMessage())
// .reconnectTime(3000));
// } catch (IOException e) {
// e.printStackTrace();
// }
// }
// );
//
// // 获取返回结果
// OpenAIEventSourceListener openAIEventSourceListener = new OpenAIEventSourceListener(sseEmitter);
// Completion completion = Completion.builder().maxTokens(RETURN_TOKEN_LENGTH).stream(true).stop(
// Lists.newArrayList("#", ";")).user(uid).prompt(prompt).build();
// OpenAIClient.getInstance().streamCompletions(completion, openAIEventSourceListener);
// return sseEmitter;
//}
}

View File

@ -0,0 +1,84 @@
package ai.chat2db.server.web.api.controller.ai.chat2db.client;
import ai.chat2db.server.domain.api.model.Config;
import ai.chat2db.server.domain.api.service.ConfigService;
import ai.chat2db.server.web.api.util.ApplicationContextUtil;
import com.google.common.collect.Lists;
import com.unfbx.chatgpt.OpenAiStreamClient;
import com.unfbx.chatgpt.constant.OpenAIConst;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
/**
* @author jipengfei
* @version : OpenAIClient.java
*/
@Slf4j
public class Chat2dbAIClient {
public static final String CHAT2DB_OPENAI_KEY = "chat2db.apiKey";
/**
* OPENAI接口域名
*/
public static final String CHAT2DB_OPENAI_HOST = "chat2db.apiHost";
private static OpenAiStreamClient CHAT2DB_AI_STREAM_CLIENT;
private static String apiKey;
public static OpenAiStreamClient getInstance() {
if (CHAT2DB_AI_STREAM_CLIENT != null) {
return CHAT2DB_AI_STREAM_CLIENT;
} else {
return singleton();
}
}
private static OpenAiStreamClient singleton() {
if (CHAT2DB_AI_STREAM_CLIENT == null) {
synchronized (Chat2dbAIClient.class) {
if (CHAT2DB_AI_STREAM_CLIENT == null) {
refresh();
}
}
}
return CHAT2DB_AI_STREAM_CLIENT;
}
public static void refresh() {
String apikey;
String apiHost = ApplicationContextUtil.getProperty(CHAT2DB_OPENAI_HOST);
if (StringUtils.isBlank(apiHost)) {
apiHost = OpenAIConst.OPENAI_HOST;
}
ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
Config apiHostConfig = configService.find(CHAT2DB_OPENAI_HOST).getData();
if (apiHostConfig != null) {
apiHost = apiHostConfig.getContent();
}
Config config = configService.find(CHAT2DB_OPENAI_KEY).getData();
if (config != null) {
apikey = config.getContent();
} else {
apikey = ApplicationContextUtil.getProperty(CHAT2DB_OPENAI_KEY);
}
log.info("refresh chat2db apikey:{}", maskApiKey(apikey));
CHAT2DB_AI_STREAM_CLIENT = OpenAiStreamClient.builder().apiHost(apiHost).apiKey(
Lists.newArrayList(apikey)).build();
apiKey = apikey;
}
private static String maskApiKey(String input) {
if (input == null) {
return input;
}
StringBuilder maskedString = new StringBuilder(input);
for (int i = input.length() / 4; i < input.length() / 2; i++) {
maskedString.setCharAt(i, '*');
}
return maskedString.toString();
}
}

View File

@ -1,5 +1,6 @@
package ai.chat2db.server.web.api.controller.ai.listener;
import java.io.IOException;
import java.util.Objects;
import ai.chat2db.server.web.api.controller.ai.azure.models.AzureChatChoice;
@ -93,8 +94,15 @@ public class AzureOpenAIEventSourceListener extends EventSourceListener {
@Override
public void onClosed(EventSource eventSource) {
try {
sseEmitter.send(SseEmitter.event()
.id("[DONE]")
.data("[DONE]"));
} catch (IOException e) {
throw new RuntimeException(e);
}
sseEmitter.complete();
log.info("AzureOpenAI关闭sse连接...");
log.info("AzureOpenAI close sse connection...");
}
@Override
@ -102,11 +110,6 @@ public class AzureOpenAIEventSourceListener extends EventSourceListener {
try {
if (Objects.isNull(response)) {
String message = t.getMessage();
if ("No route to host".equals(message)) {
message = "网络连接超时,请检查网络连通性,参考文章<https://github.com/chat2db/Chat2DB/blob/main/CHAT2DB_AI_SQL.md>";
} else {
message = "Azure AI无法正常访问请参考文章<https://github.com/chat2db/Chat2DB/blob/main/CHAT2DB_AI_SQL.md>进行配置";
}
Message sseMessage = new Message();
sseMessage.setContent(message);
sseEmitter.send(SseEmitter.event()

View File

@ -13,6 +13,7 @@ import ai.chat2db.server.tools.base.wrapper.result.ActionResult;
import ai.chat2db.server.tools.base.wrapper.result.DataResult;
import ai.chat2db.server.web.api.aspect.ConnectionInfoAspect;
import ai.chat2db.server.web.api.controller.ai.azure.client.AzureOpenAIClient;
import ai.chat2db.server.web.api.controller.ai.chat2db.client.Chat2dbAIClient;
import ai.chat2db.server.web.api.controller.config.request.AIConfigCreateRequest;
import ai.chat2db.server.web.api.controller.config.request.AISystemConfigRequest;
import ai.chat2db.server.web.api.controller.config.request.SystemConfigRequest;
@ -52,41 +53,6 @@ public class ConfigController {
return ActionResult.isSuccess();
}
/**
* save ai config
*
* @param request
* @return
*/
@PostMapping("/system_config/chatgpt")
public ActionResult addAiSystemConfig(@RequestBody AISystemConfigRequest request) {
String sqlSource = StringUtils.isNotBlank(request.getAiSqlSource()) ? request.getAiSqlSource()
: AiSqlSourceEnum.CHAT2DBAI.getCode();
AiSqlSourceEnum aiSqlSourceEnum = AiSqlSourceEnum.getByName(sqlSource);
if (Objects.isNull(aiSqlSourceEnum)) {
aiSqlSourceEnum = AiSqlSourceEnum.CHAT2DBAI;
sqlSource = AiSqlSourceEnum.CHAT2DBAI.getCode();
}
SystemConfigParam param = SystemConfigParam.builder().code(RestAIClient.AI_SQL_SOURCE).content(sqlSource)
.build();
configService.createOrUpdate(param);
switch (Objects.requireNonNull(aiSqlSourceEnum)) {
case OPENAI :
saveOpenAIConfig(request);
break;
case CHAT2DBAI:
saveChat2dbAIConfig(request);
break;
case RESTAI :
saveRestAIConfig(request);
break;
case AZUREAI :
saveAzureAIConfig(request);
break;
}
return ActionResult.isSuccess();
}
/**
* 保存ChatGPT相关配置
@ -129,13 +95,13 @@ public class ConfigController {
* @param request
*/
private void saveChat2dbAIConfig(AIConfigCreateRequest request) {
SystemConfigParam param = SystemConfigParam.builder().code(OpenAIClient.OPENAI_KEY).content(
SystemConfigParam param = SystemConfigParam.builder().code(Chat2dbAIClient.CHAT2DB_OPENAI_KEY).content(
request.getApiKey()).build();
configService.createOrUpdate(param);
SystemConfigParam hostParam = SystemConfigParam.builder().code(OpenAIClient.OPENAI_HOST).content(
SystemConfigParam hostParam = SystemConfigParam.builder().code(Chat2dbAIClient.CHAT2DB_OPENAI_HOST).content(
request.getApiHost()).build();
configService.createOrUpdate(hostParam);
OpenAIClient.refresh();
Chat2dbAIClient.refresh();
}
/**
@ -192,76 +158,6 @@ public class ConfigController {
AzureOpenAIClient.refresh();
}
/**
* 保存OPENAI相关配置
*
* @param request
*/
private void saveChat2dbAIConfig(AISystemConfigRequest request) {
SystemConfigParam param = SystemConfigParam.builder().code(OpenAIClient.OPENAI_KEY).content(
request.getChat2dbApiKey()).build();
configService.createOrUpdate(param);
SystemConfigParam hostParam = SystemConfigParam.builder().code(OpenAIClient.OPENAI_HOST).content(
request.getChat2dbApiHost()).build();
configService.createOrUpdate(hostParam);
OpenAIClient.refresh();
}
/**
* 保存OPENAI相关配置
*
* @param request
*/
private void saveOpenAIConfig(AISystemConfigRequest request) {
SystemConfigParam param = SystemConfigParam.builder().code(OpenAIClient.OPENAI_KEY).content(
request.getApiKey()).build();
configService.createOrUpdate(param);
SystemConfigParam hostParam = SystemConfigParam.builder().code(OpenAIClient.OPENAI_HOST).content(
request.getApiHost()).build();
configService.createOrUpdate(hostParam);
SystemConfigParam httpProxyHostParam = SystemConfigParam.builder().code(OpenAIClient.PROXY_HOST).content(
request.getHttpProxyHost()).build();
configService.createOrUpdate(httpProxyHostParam);
SystemConfigParam httpProxyPortParam = SystemConfigParam.builder().code(OpenAIClient.PROXY_PORT).content(
request.getHttpProxyPort()).build();
configService.createOrUpdate(httpProxyPortParam);
OpenAIClient.refresh();
}
/**
* 保存RESTAI接口相关配置
*
* @param request
*/
private void saveRestAIConfig(AISystemConfigRequest request) {
SystemConfigParam restParam = SystemConfigParam.builder().code(RestAIClient.REST_AI_URL).content(
request.getRestAiUrl())
.build();
configService.createOrUpdate(restParam);
SystemConfigParam methodParam = SystemConfigParam.builder().code(RestAIClient.REST_AI_STREAM_OUT).content(
request.getRestAiStream().toString()).build();
configService.createOrUpdate(methodParam);
RestAIClient.refresh();
}
/**
* 保存azure配置
*
* @param request
*/
private void saveAzureAIConfig(AISystemConfigRequest request) {
SystemConfigParam apikeyParam = SystemConfigParam.builder().code(AzureOpenAIClient.AZURE_CHATGPT_API_KEY).content(
request.getAzureApiKey()).build();
configService.createOrUpdate(apikeyParam);
SystemConfigParam endpointParam = SystemConfigParam.builder().code(AzureOpenAIClient.AZURE_CHATGPT_ENDPOINT).content(
request.getAzureEndpoint()).build();
configService.createOrUpdate(endpointParam);
SystemConfigParam modelParam = SystemConfigParam.builder().code(AzureOpenAIClient.AZURE_CHATGPT_DEPLOYMENT_ID).content(
request.getAzureDeploymentId()).build();
configService.createOrUpdate(modelParam);
AzureOpenAIClient.refresh();
}
@GetMapping("/system_config/{code}")
public DataResult<Config> getSystemConfig(@PathVariable("code") String code) {
DataResult<Config> result = configService.find(code);
@ -269,7 +165,7 @@ public class ConfigController {
}
/**
* 查询ChatGPT相关配置
* ai config info
*
* @return
*/
@ -291,24 +187,20 @@ public class ConfigController {
config.setAiSqlSource(aiSqlSource);
switch (Objects.requireNonNull(aiSqlSourceEnum)) {
case OPENAI :
if (!StringUtils.equals(dbSqlSource.getData().getContent(), AiSqlSourceEnum.CHAT2DBAI.getCode())) {
DataResult<Config> apiKey = configService.find(OpenAIClient.OPENAI_KEY);
DataResult<Config> apiHost = configService.find(OpenAIClient.OPENAI_HOST);
DataResult<Config> httpProxyHost = configService.find(OpenAIClient.PROXY_HOST);
DataResult<Config> httpProxyPort = configService.find(OpenAIClient.PROXY_PORT);
config.setApiKey(Objects.nonNull(apiKey.getData()) ? apiKey.getData().getContent() : "");
config.setApiHost(Objects.nonNull(apiHost.getData()) ? apiHost.getData().getContent() : "");
config.setHttpProxyHost(Objects.nonNull(httpProxyHost.getData()) ? httpProxyHost.getData().getContent() : "");
config.setHttpProxyPort(Objects.nonNull(httpProxyPort.getData()) ? httpProxyPort.getData().getContent() : "");
}
DataResult<Config> apiKey = configService.find(OpenAIClient.OPENAI_KEY);
DataResult<Config> apiHost = configService.find(OpenAIClient.OPENAI_HOST);
DataResult<Config> httpProxyHost = configService.find(OpenAIClient.PROXY_HOST);
DataResult<Config> httpProxyPort = configService.find(OpenAIClient.PROXY_PORT);
config.setApiKey(Objects.nonNull(apiKey.getData()) ? apiKey.getData().getContent() : "");
config.setApiHost(Objects.nonNull(apiHost.getData()) ? apiHost.getData().getContent() : "");
config.setHttpProxyHost(Objects.nonNull(httpProxyHost.getData()) ? httpProxyHost.getData().getContent() : "");
config.setHttpProxyPort(Objects.nonNull(httpProxyPort.getData()) ? httpProxyPort.getData().getContent() : "");
break;
case CHAT2DBAI:
if (!StringUtils.equals(dbSqlSource.getData().getContent(), AiSqlSourceEnum.OPENAI.getCode())) {
DataResult<Config> apiKey = configService.find(OpenAIClient.OPENAI_KEY);
DataResult<Config> apiHost = configService.find(OpenAIClient.OPENAI_HOST);
config.setApiKey(Objects.nonNull(apiKey.getData()) ? apiKey.getData().getContent() : "");
config.setApiHost(Objects.nonNull(apiHost.getData()) ? apiHost.getData().getContent() : "");
}
DataResult<Config> chat2dbApiKey = configService.find(Chat2dbAIClient.CHAT2DB_OPENAI_KEY);
DataResult<Config> chat2dbApiHost = configService.find(Chat2dbAIClient.CHAT2DB_OPENAI_HOST);
config.setApiKey(Objects.nonNull(chat2dbApiKey.getData()) ? chat2dbApiKey.getData().getContent() : "");
config.setApiHost(Objects.nonNull(chat2dbApiHost.getData()) ? chat2dbApiHost.getData().getContent() : "");
break;
case AZUREAI:
DataResult<Config> azureApiKey = configService.find(AzureOpenAIClient.AZURE_CHATGPT_API_KEY);
@ -332,54 +224,145 @@ public class ConfigController {
return DataResult.of(config);
}
/**
* 查询ChatGPT相关配置
*
* @return
*/
@GetMapping("/system_config/chatgpt")
public DataResult<ChatGptConfig> getChatGptSystemConfig() {
DataResult<Config> apiKey = configService.find(OpenAIClient.OPENAI_KEY);
DataResult<Config> apiHost = configService.find(OpenAIClient.OPENAI_HOST);
DataResult<Config> httpProxyHost = configService.find(OpenAIClient.PROXY_HOST);
DataResult<Config> httpProxyPort = configService.find(OpenAIClient.PROXY_PORT);
DataResult<Config> aiSqlSource = configService.find(RestAIClient.AI_SQL_SOURCE);
DataResult<Config> restAiUrl = configService.find(RestAIClient.REST_AI_URL);
DataResult<Config> restAiHttpMethod = configService.find(RestAIClient.REST_AI_STREAM_OUT);
DataResult<Config> azureApiKey = configService.find(AzureOpenAIClient.AZURE_CHATGPT_API_KEY);
DataResult<Config> azureEndpoint = configService.find(AzureOpenAIClient.AZURE_CHATGPT_ENDPOINT);
DataResult<Config> azureDeployId = configService.find(AzureOpenAIClient.AZURE_CHATGPT_DEPLOYMENT_ID);
ChatGptConfig config = new ChatGptConfig();
///**
// * save ai config
// *
// * @param request
// * @return
// */
//@PostMapping("/system_config/chatgpt")
//public ActionResult addAiSystemConfig(@RequestBody AISystemConfigRequest request) {
// String sqlSource = StringUtils.isNotBlank(request.getAiSqlSource()) ? request.getAiSqlSource()
// : AiSqlSourceEnum.CHAT2DBAI.getCode();
// AiSqlSourceEnum aiSqlSourceEnum = AiSqlSourceEnum.getByName(sqlSource);
// if (Objects.isNull(aiSqlSourceEnum)) {
// aiSqlSourceEnum = AiSqlSourceEnum.CHAT2DBAI;
// sqlSource = AiSqlSourceEnum.CHAT2DBAI.getCode();
// }
// SystemConfigParam param = SystemConfigParam.builder().code(RestAIClient.AI_SQL_SOURCE).content(sqlSource)
// .build();
// configService.createOrUpdate(param);
//
// switch (Objects.requireNonNull(aiSqlSourceEnum)) {
// case OPENAI :
// saveOpenAIConfig(request);
// break;
// case CHAT2DBAI:
// saveChat2dbAIConfig(request);
// break;
// case RESTAI :
// saveRestAIConfig(request);
// break;
// case AZUREAI :
// saveAzureAIConfig(request);
// break;
// }
// return ActionResult.isSuccess();
//}
//
///**
// * 保存OPENAI相关配置
// *
// * @param request
// */
//private void saveOpenAIConfig(AISystemConfigRequest request) {
// SystemConfigParam param = SystemConfigParam.builder().code(OpenAIClient.OPENAI_KEY).content(
// request.getApiKey()).build();
// configService.createOrUpdate(param);
// SystemConfigParam hostParam = SystemConfigParam.builder().code(OpenAIClient.OPENAI_HOST).content(
// request.getApiHost()).build();
// configService.createOrUpdate(hostParam);
// SystemConfigParam httpProxyHostParam = SystemConfigParam.builder().code(OpenAIClient.PROXY_HOST).content(
// request.getHttpProxyHost()).build();
// configService.createOrUpdate(httpProxyHostParam);
// SystemConfigParam httpProxyPortParam = SystemConfigParam.builder().code(OpenAIClient.PROXY_PORT).content(
// request.getHttpProxyPort()).build();
// configService.createOrUpdate(httpProxyPortParam);
// OpenAIClient.refresh();
//}
//
///**
// * 保存RESTAI接口相关配置
// *
// * @param request
// */
//private void saveRestAIConfig(AISystemConfigRequest request) {
// SystemConfigParam restParam = SystemConfigParam.builder().code(RestAIClient.REST_AI_URL).content(
// request.getRestAiUrl())
// .build();
// configService.createOrUpdate(restParam);
// SystemConfigParam methodParam = SystemConfigParam.builder().code(RestAIClient.REST_AI_STREAM_OUT).content(
// request.getRestAiStream().toString()).build();
// configService.createOrUpdate(methodParam);
// RestAIClient.refresh();
//}
//
///**
// * 保存azure配置
// *
// * @param request
// */
//private void saveAzureAIConfig(AISystemConfigRequest request) {
// SystemConfigParam apikeyParam = SystemConfigParam.builder().code(AzureOpenAIClient.AZURE_CHATGPT_API_KEY).content(
// request.getAzureApiKey()).build();
// configService.createOrUpdate(apikeyParam);
// SystemConfigParam endpointParam = SystemConfigParam.builder().code(AzureOpenAIClient.AZURE_CHATGPT_ENDPOINT).content(
// request.getAzureEndpoint()).build();
// configService.createOrUpdate(endpointParam);
// SystemConfigParam modelParam = SystemConfigParam.builder().code(AzureOpenAIClient.AZURE_CHATGPT_DEPLOYMENT_ID).content(
// request.getAzureDeploymentId()).build();
// configService.createOrUpdate(modelParam);
// AzureOpenAIClient.refresh();
//}
String sqlSource = Objects.nonNull(aiSqlSource.getData()) ? aiSqlSource.getData().getContent() : AiSqlSourceEnum.CHAT2DBAI.getCode();
AiSqlSourceEnum aiSqlSourceEnum = AiSqlSourceEnum.getByName(sqlSource);
if (Objects.isNull(aiSqlSourceEnum)) {
aiSqlSourceEnum = AiSqlSourceEnum.CHAT2DBAI;
sqlSource = AiSqlSourceEnum.CHAT2DBAI.getCode();
}
config.setAiSqlSource(sqlSource);
switch (Objects.requireNonNull(aiSqlSourceEnum)) {
case OPENAI :
config.setApiKey(Objects.nonNull(apiKey.getData()) ? apiKey.getData().getContent() : null);
config.setApiHost(Objects.nonNull(apiHost.getData()) ? apiHost.getData().getContent() : null);
config.setChat2dbApiKey("");
config.setChat2dbApiHost("");
break;
case CHAT2DBAI:
config.setApiKey("");
config.setApiHost("");
config.setChat2dbApiKey(Objects.nonNull(apiKey.getData()) ? apiKey.getData().getContent() : null);
config.setChat2dbApiHost(Objects.nonNull(apiHost.getData()) ? apiHost.getData().getContent() : null);
break;
}
config.setRestAiUrl(Objects.nonNull(restAiUrl.getData()) ? restAiUrl.getData().getContent() : null);
config.setRestAiStream(Objects.nonNull(restAiHttpMethod.getData()) ? Boolean.valueOf(
restAiHttpMethod.getData().getContent()) : Boolean.TRUE);
config.setHttpProxyHost(Objects.nonNull(httpProxyHost.getData()) ? httpProxyHost.getData().getContent() : null);
config.setHttpProxyPort(Objects.nonNull(httpProxyPort.getData()) ? httpProxyPort.getData().getContent() : null);
config.setAzureApiKey(Objects.nonNull(azureApiKey.getData()) ? azureApiKey.getData().getContent() : null);
config.setAzureEndpoint(Objects.nonNull(azureEndpoint.getData()) ? azureEndpoint.getData().getContent() : null);
config.setAzureDeploymentId(Objects.nonNull(azureDeployId.getData()) ? azureDeployId.getData().getContent() : null);
return DataResult.of(config);
}
///**
// * 查询ChatGPT相关配置
// *
// * @return
// */
//@GetMapping("/system_config/chatgpt")
//public DataResult<ChatGptConfig> getChatGptSystemConfig() {
// DataResult<Config> apiKey = configService.find(OpenAIClient.OPENAI_KEY);
// DataResult<Config> apiHost = configService.find(OpenAIClient.OPENAI_HOST);
// DataResult<Config> httpProxyHost = configService.find(OpenAIClient.PROXY_HOST);
// DataResult<Config> httpProxyPort = configService.find(OpenAIClient.PROXY_PORT);
// DataResult<Config> aiSqlSource = configService.find(RestAIClient.AI_SQL_SOURCE);
// DataResult<Config> restAiUrl = configService.find(RestAIClient.REST_AI_URL);
// DataResult<Config> restAiHttpMethod = configService.find(RestAIClient.REST_AI_STREAM_OUT);
// DataResult<Config> azureApiKey = configService.find(AzureOpenAIClient.AZURE_CHATGPT_API_KEY);
// DataResult<Config> azureEndpoint = configService.find(AzureOpenAIClient.AZURE_CHATGPT_ENDPOINT);
// DataResult<Config> azureDeployId = configService.find(AzureOpenAIClient.AZURE_CHATGPT_DEPLOYMENT_ID);
// ChatGptConfig config = new ChatGptConfig();
//
// String sqlSource = Objects.nonNull(aiSqlSource.getData()) ? aiSqlSource.getData().getContent() : AiSqlSourceEnum.CHAT2DBAI.getCode();
// AiSqlSourceEnum aiSqlSourceEnum = AiSqlSourceEnum.getByName(sqlSource);
// if (Objects.isNull(aiSqlSourceEnum)) {
// aiSqlSourceEnum = AiSqlSourceEnum.CHAT2DBAI;
// sqlSource = AiSqlSourceEnum.CHAT2DBAI.getCode();
// }
// config.setAiSqlSource(sqlSource);
// switch (Objects.requireNonNull(aiSqlSourceEnum)) {
// case OPENAI :
// config.setApiKey(Objects.nonNull(apiKey.getData()) ? apiKey.getData().getContent() : null);
// config.setApiHost(Objects.nonNull(apiHost.getData()) ? apiHost.getData().getContent() : null);
// config.setChat2dbApiKey("");
// config.setChat2dbApiHost("");
// break;
// case CHAT2DBAI:
// config.setApiKey("");
// config.setApiHost("");
// config.setChat2dbApiKey(Objects.nonNull(apiKey.getData()) ? apiKey.getData().getContent() : null);
// config.setChat2dbApiHost(Objects.nonNull(apiHost.getData()) ? apiHost.getData().getContent() : null);
// break;
// }
// config.setRestAiUrl(Objects.nonNull(restAiUrl.getData()) ? restAiUrl.getData().getContent() : null);
// config.setRestAiStream(Objects.nonNull(restAiHttpMethod.getData()) ? Boolean.valueOf(
// restAiHttpMethod.getData().getContent()) : Boolean.TRUE);
// config.setHttpProxyHost(Objects.nonNull(httpProxyHost.getData()) ? httpProxyHost.getData().getContent() : null);
// config.setHttpProxyPort(Objects.nonNull(httpProxyPort.getData()) ? httpProxyPort.getData().getContent() : null);
// config.setAzureApiKey(Objects.nonNull(azureApiKey.getData()) ? azureApiKey.getData().getContent() : null);
// config.setAzureEndpoint(Objects.nonNull(azureEndpoint.getData()) ? azureEndpoint.getData().getContent() : null);
// config.setAzureDeploymentId(Objects.nonNull(azureDeployId.getData()) ? azureDeployId.getData().getContent() : null);
// return DataResult.of(config);
//}
}