Merge remote-tracking branch 'origin/developing' into developing

This commit is contained in:
JiaJu Zhuang
2023-07-29 19:30:40 +08:00
17 changed files with 474 additions and 360 deletions

View File

@ -6,6 +6,7 @@ import i18n from '@/i18n';
import classnames from 'classnames'; import classnames from 'classnames';
import { IAiConfig } from '@/typings/setting'; import { IAiConfig } from '@/typings/setting';
import styles from './index.less'; import styles from './index.less';
import Popularize from '@/components/Popularize';
interface IProps { interface IProps {
handleApplyAiConfig: (aiConfig: IAiConfig) => void; handleApplyAiConfig: (aiConfig: IAiConfig) => void;
aiConfig: IAiConfig; aiConfig: IAiConfig;
@ -201,10 +202,8 @@ export default function SettingAI(props: IProps) {
{i18n('setting.button.apply')} {i18n('setting.button.apply')}
</Button> </Button>
</div> </div>
{/* {aiConfig?.aiSqlSource === AiSqlSourceType.CHAT2DBAI && (
<Popularize source='setting'></Popularize> {aiConfig?.aiSqlSource === AiSqlSourceType.CHAT2DBAI && !aiConfig.apiKey && <Popularize source="setting" />}
)
} */}
</> </>
); );
} }

View File

@ -238,8 +238,8 @@ function Console(props: IProps) {
}); });
const handleMessage = (message: string) => { const handleMessage = (message: string) => {
// console.log('message', message);
setIsLoading(false); setIsLoading(false);
console.log('message', message);
try { try {
const isEOF = message === '[DONE]'; const isEOF = message === '[DONE]';
if (isEOF) { if (isEOF) {

View File

@ -9,7 +9,8 @@ interface IProps {
imageUrl?: string; imageUrl?: string;
tip?: string; tip?: string;
} }
const url = 'https://oss-chat2db.alibaba.com/static/wechat.webp'; const url =
'http://oss.sqlgpt.cn/static/chat2db-wechat.jpg?x-oss-process=image/auto-orient,1/resize,m_lfit,w_256/quality,Q_80/format,webp';
export default memo<IProps>(function Popularize(props) { export default memo<IProps>(function Popularize(props) {
const { className } = props; const { className } = props;
@ -19,7 +20,7 @@ export default memo<IProps>(function Popularize(props) {
} }
let dom; let dom;
if (props.source === 'setting') { if (props.source === 'setting') {
dom = <p>{i18n('common.text.wechatPopularizeAi2')}</p>; dom = <p>{'关注公众号获取AI Key'}</p>;
} else { } else {
dom = <p>{i18n('common.text.wechatPopularizeAi')}</p>; dom = <p>{i18n('common.text.wechatPopularizeAi')}</p>;
} }

View File

@ -5,7 +5,7 @@ import { Button, InputNumber, Popover, Select } from 'antd';
import { IResultConfig } from '@/typings'; import { IResultConfig } from '@/typings';
import i18n from '@/i18n'; import i18n from '@/i18n';
import _ from 'lodash'; import _ from 'lodash';
import styles from './Pagination.less'; import styles from './index.less';
interface IProps { interface IProps {
onPageSizeChange?: (pageSize: number) => void; onPageSizeChange?: (pageSize: number) => void;

View File

@ -1,4 +1,4 @@
@import '../../styles/var.less'; @import '../../../styles/var.less';
.tableBox { .tableBox {
position: absolute; position: absolute;

View File

@ -1,21 +1,20 @@
import React, { useEffect, useMemo, useState } from 'react'; import React, { useMemo, useState } from 'react';
import { TableDataType } from '@/constants/table'; import { TableDataType } from '@/constants/table';
import { IManageResultData, IResultConfig, ITableHeaderItem } from '@/typings/database'; import { IManageResultData, IResultConfig } from '@/typings/database';
import { formatDate } from '@/utils/date'; import { formatDate } from '@/utils/date';
import { Button, message, Modal, Pagination, Select, Table } from 'antd'; import { Button, message, Modal } from 'antd';
import antd from 'antd'; import { BaseTable, ArtColumn, useTablePipeline, features, SortItem } from 'ali-react-table';
import { BaseTable, ArtColumn, useTablePipeline, features, SortItem, BaseTableProps } from 'ali-react-table'; import Iconfont from '../../Iconfont';
import Iconfont from '../Iconfont';
import classnames from 'classnames'; import classnames from 'classnames';
import StateIndicator from '../StateIndicator'; import StateIndicator from '../../StateIndicator';
import MonacoEditor from '../Console/MonacoEditor'; import MonacoEditor from '../../Console/MonacoEditor';
import { useTheme } from '@/hooks/useTheme'; import { useTheme } from '@/hooks/useTheme';
import styled from 'styled-components'; import styled from 'styled-components';
import styles from './TableBox.less';
import { ThemeType } from '@/constants'; import { ThemeType } from '@/constants';
import i18n from '@/i18n'; import i18n from '@/i18n';
import { compareStrings } from '@/utils/sort'; import { compareStrings } from '@/utils/sort';
import MyPagination from './Pagination'; import MyPagination from '../Pagination';
import styles from './index.less';
interface ITableProps { interface ITableProps {
className?: string; className?: string;
@ -189,6 +188,16 @@ export default function TableBox(props: ITableProps) {
onClickTotalBtn={onClickTotalBtn} onClickTotalBtn={onClickTotalBtn}
/> />
</div> </div>
{/* <div className={styles.toolBarItem}>
<Button
type='text'
onClick={() => {
console.log('config', config);
}}
>
Excel
</Button>
</div> */}
</div> </div>
<DarkSupportBaseTable <DarkSupportBaseTable
className={classnames({ dark: isDarkTheme }, props.className, styles.table)} className={classnames({ dark: isDarkTheme }, props.className, styles.table)}

View File

@ -1,10 +1,9 @@
import React, { memo, useEffect, useState, useRef, useMemo, Fragment } from 'react'; import React, { memo, useEffect, useState, useMemo, Fragment } from 'react';
import classnames from 'classnames'; import classnames from 'classnames';
import Tabs, { IOption } from '@/components/Tabs'; import Tabs, { IOption } from '@/components/Tabs';
import Iconfont from '@/components/Iconfont'; import Iconfont from '@/components/Iconfont';
import StateIndicator from '@/components/StateIndicator'; import StateIndicator from '@/components/StateIndicator';
import { Spin, Popover } from 'antd'; import { Spin, Popover } from 'antd';
import { StatusType } from '@/constants';
import { IManageResultData, IResultConfig } from '@/typings'; import { IManageResultData, IResultConfig } from '@/typings';
import i18n from '@/i18n'; import i18n from '@/i18n';
import TableBox from './TableBox'; import TableBox from './TableBox';

View File

@ -79,10 +79,11 @@ const AIModel: IAIModelType = {
}); });
} catch (error) {} } catch (error) {}
}, },
*fetchRemainingUse({ payload }: { type: any; payload?: { apiKey?: string } }, { put }) { *fetchRemainingUse({ payload }: { type: any; payload?: { apiKey?: string } }, { put, select }) {
const currentState = (yield select((state: any) => state.ai)) as IAIState;
const { apiKey } = payload || {}; const { apiKey } = payload || {};
try { try {
if (!apiKey) { if (!apiKey || currentState.aiConfig.aiSqlSource !== AiSqlSourceType.CHAT2DBAI) {
yield put({ yield put({
type: 'setRemainUse', type: 'setRemainUse',
payload: undefined, payload: undefined,

View File

@ -1,6 +1,7 @@
import createRequest from './base'; import createRequest from './base';
import { IPageResponse, IPageParams, IUniversalTableParams, IManageResultData } from '@/typings'; import { IPageResponse, IPageParams, IUniversalTableParams, IManageResultData } from '@/typings';
import { DatabaseTypeCode } from '@/constants'; import { DatabaseTypeCode } from '@/constants';
import { ExportSizeEnum, ExportTypeEnum } from '@/typings/resultTable';
export interface IGetListParams extends IPageParams { export interface IGetListParams extends IPageParams {
dataSourceId: number; dataSourceId: number;
@ -114,6 +115,16 @@ const deleteTablePin = createRequest<IUniversalTableParams, void>('/api/pin/tabl
/** 获取当前执行SQL 所有行 */ /** 获取当前执行SQL 所有行 */
const getDMLCount = createRequest<IExecuteSqlParams, number>('/api/rdb/dml/count', { method: 'post' }); const getDMLCount = createRequest<IExecuteSqlParams, number>('/api/rdb/dml/count', { method: 'post' });
export interface IExportParams extends IExecuteSqlParams {
originalSql: string;
exportType: ExportTypeEnum;
exportSize: ExportSizeEnum;
}
/**
* 导出-表格
*/
const exportResultTable = createRequest<IExportParams, any>('/api/rdb/dml/export', { method: 'post' });
export default { export default {
getList, getList,
executeSql, executeSql,
@ -131,4 +142,5 @@ export default {
addTablePin, addTablePin,
deleteTablePin, deleteTablePin,
getDMLCount, getDMLCount,
exportResultTable
}; };

View File

@ -0,0 +1,10 @@
import { IExecuteSqlParams } from '@/service/sql';
export enum ExportTypeEnum {
CSV = 'CSV',
INSERT = 'INSERT',
}
export enum ExportSizeEnum {
CURRENT_PAGE = 'CURRENT_PAGE',
ALL = 'ALL',
}

View File

@ -17,7 +17,7 @@ const connectToEventSource = (params: { url: string; uid: string; onMessage: Fun
const eventSource = new EventSourcePolyfill(`${window._BaseURL}${url}`, p); const eventSource = new EventSourcePolyfill(`${window._BaseURL}${url}`, p);
eventSource.onmessage = (event) => { eventSource.onmessage = (event) => {
console.log('onmessage', event); // console.log('onmessage', event);
onMessage(event.data); onMessage(event.data);
}; };

View File

@ -8,6 +8,7 @@ import ai.chat2db.server.domain.api.service.ConfigService;
import ai.chat2db.server.tools.base.wrapper.result.DataResult; import ai.chat2db.server.tools.base.wrapper.result.DataResult;
import ai.chat2db.server.tools.common.config.Chat2dbProperties; import ai.chat2db.server.tools.common.config.Chat2dbProperties;
import ai.chat2db.server.web.api.aspect.ConnectionInfoAspect; import ai.chat2db.server.web.api.aspect.ConnectionInfoAspect;
import ai.chat2db.server.web.api.controller.ai.chat2db.client.Chat2dbAIClient;
import ai.chat2db.server.web.api.http.GatewayClientService; import ai.chat2db.server.web.api.http.GatewayClientService;
import ai.chat2db.server.web.api.http.response.ApiKeyResponse; import ai.chat2db.server.web.api.http.response.ApiKeyResponse;
import ai.chat2db.server.web.api.http.response.InviteQrCodeResponse; import ai.chat2db.server.web.api.http.response.InviteQrCodeResponse;
@ -64,15 +65,15 @@ public class AiConfigController {
// Representative successfully logged in // Representative successfully logged in
if (StringUtils.isNotBlank(qrCodeResponse.getApiKey())) { if (StringUtils.isNotBlank(qrCodeResponse.getApiKey())) {
SystemConfigParam param = SystemConfigParam.builder() SystemConfigParam param = SystemConfigParam.builder()
.code(OpenAIClient.OPENAI_KEY).content(qrCodeResponse.getApiKey()) .code(Chat2dbAIClient.CHAT2DB_OPENAI_KEY).content(qrCodeResponse.getApiKey())
.build(); .build();
configService.createOrUpdate(param); configService.createOrUpdate(param);
SystemConfigParam hostParam = SystemConfigParam.builder() SystemConfigParam hostParam = SystemConfigParam.builder()
.code(OpenAIClient.OPENAI_HOST) .code(Chat2dbAIClient.CHAT2DB_OPENAI_HOST)
.content(chat2dbProperties.getGateway().getModelBaseUrl() + "/model") .content(chat2dbProperties.getGateway().getModelBaseUrl() + "/model")
.build(); .build();
configService.createOrUpdate(hostParam); configService.createOrUpdate(hostParam);
OpenAIClient.refresh(); Chat2dbAIClient.refresh();
} }
return dataResult; return dataResult;
} }
@ -107,7 +108,7 @@ public class AiConfigController {
} }
private String getApiKey() { private String getApiKey() {
DataResult<Config> apiKey = configService.find(OpenAIClient.OPENAI_KEY); DataResult<Config> apiKey = configService.find(Chat2dbAIClient.CHAT2DB_OPENAI_KEY);
return Objects.nonNull(apiKey.getData()) ? apiKey.getData().getContent() : null; return Objects.nonNull(apiKey.getData()) ? apiKey.getData().getContent() : null;
} }
} }

View File

@ -23,9 +23,9 @@ import ai.chat2db.server.web.api.aspect.ConnectionInfoAspect;
import ai.chat2db.server.web.api.controller.ai.azure.client.AzureOpenAIClient; import ai.chat2db.server.web.api.controller.ai.azure.client.AzureOpenAIClient;
import ai.chat2db.server.web.api.controller.ai.azure.models.AzureChatMessage; import ai.chat2db.server.web.api.controller.ai.azure.models.AzureChatMessage;
import ai.chat2db.server.web.api.controller.ai.azure.models.AzureChatRole; import ai.chat2db.server.web.api.controller.ai.azure.models.AzureChatRole;
import ai.chat2db.server.web.api.controller.ai.chat2db.client.Chat2dbAIClient;
import ai.chat2db.server.web.api.controller.ai.config.LocalCache; import ai.chat2db.server.web.api.controller.ai.config.LocalCache;
import ai.chat2db.server.web.api.controller.ai.converter.ChatConverter; import ai.chat2db.server.web.api.controller.ai.converter.ChatConverter;
import ai.chat2db.server.web.api.controller.ai.enums.GptVersionType;
import ai.chat2db.server.web.api.controller.ai.enums.PromptType; import ai.chat2db.server.web.api.controller.ai.enums.PromptType;
import ai.chat2db.server.web.api.controller.ai.listener.AzureOpenAIEventSourceListener; import ai.chat2db.server.web.api.controller.ai.listener.AzureOpenAIEventSourceListener;
import ai.chat2db.server.web.api.controller.ai.listener.OpenAIEventSourceListener; import ai.chat2db.server.web.api.controller.ai.listener.OpenAIEventSourceListener;
@ -41,9 +41,6 @@ import cn.hutool.json.JSONUtil;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.google.common.collect.Maps; import com.google.common.collect.Maps;
import com.unfbx.chatgpt.entity.chat.Message; import com.unfbx.chatgpt.entity.chat.Message;
import com.unfbx.chatgpt.entity.completions.Completion;
import com.unfbx.chatgpt.exception.BaseException;
import com.unfbx.chatgpt.exception.CommonError;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils; import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
@ -55,7 +52,6 @@ import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestHeader; import org.springframework.web.bind.annotation.RequestHeader;
import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController; import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
@ -163,27 +159,6 @@ public class ChatController {
return data; return data;
} }
/**
* 问答对话模型
*
* @param msg
* @param headers
* @return
* @throws IOException
*/
@GetMapping("/chat1")
@CrossOrigin
public SseEmitter chat(@RequestParam("message") String msg, @RequestHeader Map<String, String> headers)
throws IOException {
//默认30秒超时,设置为0L则永不超时
SseEmitter sseEmitter = new SseEmitter(CHAT_TIMEOUT);
String uid = headers.get("uid");
if (StrUtil.isBlank(uid)) {
throw new BaseException(CommonError.SYS_ERROR);
}
return distributeAI(msg, sseEmitter, uid);
}
/** /**
* SQL转换模型 * SQL转换模型
* *
@ -211,32 +186,6 @@ public class ChatController {
return distributeAISql(queryRequest, sseEmitter, uid); return distributeAISql(queryRequest, sseEmitter, uid);
} }
/**
* distribute with different AI
*
* @return
*/
private SseEmitter distributeAI(String msg, SseEmitter sseEmitter, String uid) throws IOException {
ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
Config config = configService.find(RestAIClient.AI_SQL_SOURCE).getData();
String aiSqlSource = AiSqlSourceEnum.CHAT2DBAI.getCode();
if (Objects.nonNull(config)) {
aiSqlSource = config.getContent();
}
AiSqlSourceEnum aiSqlSourceEnum = AiSqlSourceEnum.getByName(aiSqlSource);
if (Objects.isNull(aiSqlSourceEnum)) {
aiSqlSourceEnum = AiSqlSourceEnum.OPENAI;
}
switch (Objects.requireNonNull(aiSqlSourceEnum)) {
case OPENAI :
case CHAT2DBAI:
return chatWithOpenAi(msg, sseEmitter, uid);
case RESTAI :
return chatWithRestAi(msg, sseEmitter);
}
return chatWithOpenAi(msg, sseEmitter, uid);
}
/** /**
* distribute with different AI * distribute with different AI
* *
@ -253,16 +202,18 @@ public class ChatController {
if (Objects.isNull(aiSqlSourceEnum)) { if (Objects.isNull(aiSqlSourceEnum)) {
aiSqlSourceEnum = AiSqlSourceEnum.OPENAI; aiSqlSourceEnum = AiSqlSourceEnum.OPENAI;
} }
uid = aiSqlSourceEnum.getCode() + uid;
switch (Objects.requireNonNull(aiSqlSourceEnum)) { switch (Objects.requireNonNull(aiSqlSourceEnum)) {
case OPENAI : case OPENAI :
return chatWithOpenAi(queryRequest, sseEmitter, uid);
case CHAT2DBAI: case CHAT2DBAI:
return chatWithOpenAiSql(queryRequest, sseEmitter, uid); return chatWithChat2dbAi(queryRequest, sseEmitter, uid);
case RESTAI : case RESTAI :
return chatWithRestAi(queryRequest.getMessage(), sseEmitter); return chatWithRestAi(queryRequest, sseEmitter);
case AZUREAI : case AZUREAI :
return chatWithAzureAi(queryRequest, sseEmitter, uid); return chatWithAzureAi(queryRequest, sseEmitter, uid);
} }
return chatWithOpenAiSql(queryRequest, sseEmitter, uid); return chatWithOpenAi(queryRequest, sseEmitter, uid);
} }
/** /**
@ -272,9 +223,9 @@ public class ChatController {
* @param sseEmitter * @param sseEmitter
* @return * @return
*/ */
private SseEmitter chatWithRestAi(String prompt, SseEmitter sseEmitter) { private SseEmitter chatWithRestAi(ChatQueryRequest prompt, SseEmitter sseEmitter) {
RestAIEventSourceListener eventSourceListener = new RestAIEventSourceListener(sseEmitter); RestAIEventSourceListener eventSourceListener = new RestAIEventSourceListener(sseEmitter);
RestAIClient.getInstance().restCompletions(prompt, eventSourceListener); RestAIClient.getInstance().restCompletions(buildPrompt(prompt), eventSourceListener);
return sseEmitter; return sseEmitter;
} }
@ -287,7 +238,7 @@ public class ChatController {
* @return * @return
* @throws IOException * @throws IOException
*/ */
private SseEmitter chatWithOpenAiSql(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) private SseEmitter chatWithOpenAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid)
throws IOException { throws IOException {
String prompt = buildPrompt(queryRequest); String prompt = buildPrompt(queryRequest);
if (prompt.length() / TOKEN_CONVERT_CHAR_LENGTH > MAX_PROMPT_LENGTH) { if (prompt.length() / TOKEN_CONVERT_CHAR_LENGTH > MAX_PROMPT_LENGTH) {
@ -296,48 +247,48 @@ public class ChatController {
throw new ParamBusinessException(); throw new ParamBusinessException();
} }
GptVersionType modelType = EasyEnumUtils.getEnum(GptVersionType.class, gptVersion); List<Message> messages = new ArrayList<>();
switch (modelType) { prompt = prompt.replaceAll("#", "");
case GPT3: log.info(prompt);
return chatGpt3(prompt, sseEmitter, uid); Message currentMessage = Message.builder().content(prompt).role(Message.Role.USER).build();
case GPT35: messages.add(currentMessage);
List<Message> messages = new ArrayList<>(); buildSseEmitter(sseEmitter, uid);
prompt = prompt.replaceAll("#", "");
log.info(prompt); OpenAIEventSourceListener openAIEventSourceListener = new OpenAIEventSourceListener(sseEmitter);
Message currentMessage = Message.builder().content(prompt).role(Message.Role.USER).build(); OpenAIClient.getInstance().streamChatCompletion(messages, openAIEventSourceListener);
messages.add(currentMessage); LocalCache.CACHE.put(uid, JSONUtil.toJsonStr(messages), LocalCache.TIMEOUT);
return chatGpt35(messages, sseEmitter, uid); return sseEmitter;
default:
break;
}
return chatGpt3(prompt, sseEmitter, uid);
} }
/** /**
* 使用OPENAI聊天相关接口 * 使用OPENAI SQL接口
* *
* @param msg * @param queryRequest
* @param sseEmitter * @param sseEmitter
* @param uid * @param uid
* @return * @return
* @throws IOException * @throws IOException
*/ */
private SseEmitter chatWithOpenAi(String msg, SseEmitter sseEmitter, String uid) throws IOException { private SseEmitter chatWithChat2dbAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid)
String messageContext = (String)LocalCache.CACHE.get(uid); throws IOException {
List<Message> messages = new ArrayList<>(); String prompt = buildPrompt(queryRequest);
if (StrUtil.isNotBlank(messageContext)) { if (prompt.length() / TOKEN_CONVERT_CHAR_LENGTH > MAX_PROMPT_LENGTH) {
messages = JSONUtil.toList(messageContext, Message.class); log.error("exceed max token length:{}input length:{}", MAX_PROMPT_LENGTH,
if (messages.size() >= contextLength) { prompt.length() / TOKEN_CONVERT_CHAR_LENGTH);
messages = messages.subList(1, contextLength); throw new ParamBusinessException();
}
Message currentMessage = Message.builder().content(msg).role(Message.Role.USER).build();
messages.add(currentMessage);
} else {
Message currentMessage = Message.builder().content(msg).role(Message.Role.USER).build();
messages.add(currentMessage);
} }
return chatGpt35(messages, sseEmitter, uid); prompt = prompt.replaceAll("#", "");
log.info(prompt);
Message currentMessage = Message.builder().content(prompt).role(Message.Role.USER).build();
List<Message> messages = new ArrayList<>();
messages.add(currentMessage);
buildSseEmitter(sseEmitter, uid);
OpenAIEventSourceListener openAIEventSourceListener = new OpenAIEventSourceListener(sseEmitter);
Chat2dbAIClient.getInstance().streamChatCompletion(messages, openAIEventSourceListener);
LocalCache.CACHE.put(uid, JSONUtil.toJsonStr(messages), LocalCache.TIMEOUT);
return sseEmitter;
} }
/** /**
@ -367,24 +318,8 @@ public class ChatController {
AzureChatMessage currentMessage = new AzureChatMessage(AzureChatRole.USER).setContent(prompt); AzureChatMessage currentMessage = new AzureChatMessage(AzureChatRole.USER).setContent(prompt);
messages.add(currentMessage); messages.add(currentMessage);
sseEmitter.send(SseEmitter.event().id(uid).name("sseEmitter connected").data(LocalDateTime.now()).reconnectTime(3000)); buildSseEmitter(sseEmitter, uid);
sseEmitter.onCompletion(() -> {
log.info(LocalDateTime.now() + ", uid#" + uid + ", sseEmitter on completion");
SseEmitter.event().id("[DONE]").data("[DONE]");
});
sseEmitter.onTimeout(
() -> log.info(LocalDateTime.now() + ", uid#" + uid + ", sseEmitter on timeout#" + sseEmitter.getTimeout()));
sseEmitter.onError(
throwable -> {
try {
log.info(LocalDateTime.now() + ", uid#" + "765431" + ", sseEmitter on error#" + throwable.toString());
sseEmitter.send(SseEmitter.event().id("765431").name("exception occurs").data(throwable.getMessage())
.reconnectTime(3000));
} catch (IOException e) {
e.printStackTrace();
}
}
);
AzureOpenAIEventSourceListener sourceListener = new AzureOpenAIEventSourceListener(sseEmitter); AzureOpenAIEventSourceListener sourceListener = new AzureOpenAIEventSourceListener(sseEmitter);
AzureOpenAIClient.getInstance().streamCompletions(messages, sourceListener); AzureOpenAIClient.getInstance().streamCompletions(messages, sourceListener);
LocalCache.CACHE.put(uid, messages, LocalCache.TIMEOUT); LocalCache.CACHE.put(uid, messages, LocalCache.TIMEOUT);
@ -392,15 +327,15 @@ public class ChatController {
} }
/** /**
* 使用GPT3.5模型 * construct sseEmitter
* *
* @param messages
* @param sseEmitter * @param sseEmitter
* @param uid * @param uid
* @return * @return
* @throws IOException
*/ */
private SseEmitter chatGpt35(List<Message> messages, SseEmitter sseEmitter, String uid) throws IOException { private SseEmitter buildSseEmitter(SseEmitter sseEmitter, String uid) throws IOException {
sseEmitter.send(SseEmitter.event().id(uid).name("连接成功").data(LocalDateTime.now()).reconnectTime(3000)); sseEmitter.send(SseEmitter.event().id(uid).name("connect successfully").data(LocalDateTime.now()).reconnectTime(3000));
sseEmitter.onCompletion(() -> { sseEmitter.onCompletion(() -> {
log.info(LocalDateTime.now() + ", uid#" + uid + ", on completion"); log.info(LocalDateTime.now() + ", uid#" + uid + ", on completion");
}); });
@ -417,46 +352,6 @@ public class ChatController {
} }
} }
); );
OpenAIEventSourceListener openAIEventSourceListener = new OpenAIEventSourceListener(sseEmitter);
OpenAIClient.getInstance().streamChatCompletion(messages, openAIEventSourceListener);
LocalCache.CACHE.put(uid, JSONUtil.toJsonStr(messages), LocalCache.TIMEOUT);
return sseEmitter;
}
/**
* 使用GPT3.0模型
*
* @param prompt
* @param sseEmitter
* @param uid
* @return
*/
private SseEmitter chatGpt3(String prompt, SseEmitter sseEmitter, String uid) throws IOException {
sseEmitter.send(SseEmitter.event().id(uid).name("chatGpt3连接成功").data(LocalDateTime.now())
.reconnectTime(3000));
sseEmitter.onCompletion(() -> {
log.info(LocalDateTime.now() + ", uid#" + uid + ", on completion");
});
sseEmitter.onTimeout(
() -> log.info(LocalDateTime.now() + ", uid#" + uid + ", chatGpt3 on timeout#" + sseEmitter.getTimeout()));
sseEmitter.onError(
throwable -> {
try {
log.info(LocalDateTime.now() + ", uid#" + "765431" + ", chatGpt3 on error#" + throwable.toString());
sseEmitter.send(SseEmitter.event().id("765431").name("chatGpt3 发生异常!")
.data(throwable.getMessage())
.reconnectTime(3000));
} catch (IOException e) {
e.printStackTrace();
}
}
);
// 获取返回结果
OpenAIEventSourceListener openAIEventSourceListener = new OpenAIEventSourceListener(sseEmitter);
Completion completion = Completion.builder().maxTokens(RETURN_TOKEN_LENGTH).stream(true).stop(
Lists.newArrayList("#", ";")).user(uid).prompt(prompt).build();
OpenAIClient.getInstance().streamCompletions(completion, openAIEventSourceListener);
return sseEmitter; return sseEmitter;
} }
@ -472,7 +367,12 @@ public class ChatController {
if (CollectionUtils.isEmpty(tableNames)) { if (CollectionUtils.isEmpty(tableNames)) {
return Maps.newHashMap(); return Maps.newHashMap();
} }
List<TableColumn> tableColumns = tableService.queryColumns(tableQueryParam); List<TableColumn> tableColumns = Lists.newArrayList();
try {
tableColumns = tableService.queryColumns(tableQueryParam);
} catch (Exception exception) {
log.error("query table error, do nothing");
}
if (CollectionUtils.isEmpty(tableColumns)) { if (CollectionUtils.isEmpty(tableColumns)) {
return Maps.newHashMap(); return Maps.newHashMap();
} }
@ -520,4 +420,116 @@ public class ChatController {
} }
return schemaProperty; return schemaProperty;
} }
///**
// * 问答对话模型
// *
// * @param msg
// * @param headers
// * @return
// * @throws IOException
// */
//@GetMapping("/chat1")
//@CrossOrigin
//public SseEmitter chat(@RequestParam("message") String msg, @RequestHeader Map<String, String> headers)
// throws IOException {
// //默认30秒超时,设置为0L则永不超时
// SseEmitter sseEmitter = new SseEmitter(CHAT_TIMEOUT);
// String uid = headers.get("uid");
// if (StrUtil.isBlank(uid)) {
// throw new BaseException(CommonError.SYS_ERROR);
// }
// return distributeAI(msg, sseEmitter, uid);
//}
///**
// * distribute with different AI
// *
// * @return
// */
//private SseEmitter distributeAI(String msg, SseEmitter sseEmitter, String uid) throws IOException {
// ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
// Config config = configService.find(RestAIClient.AI_SQL_SOURCE).getData();
// String aiSqlSource = AiSqlSourceEnum.CHAT2DBAI.getCode();
// if (Objects.nonNull(config)) {
// aiSqlSource = config.getContent();
// }
// AiSqlSourceEnum aiSqlSourceEnum = AiSqlSourceEnum.getByName(aiSqlSource);
// if (Objects.isNull(aiSqlSourceEnum)) {
// aiSqlSourceEnum = AiSqlSourceEnum.OPENAI;
// }
// switch (Objects.requireNonNull(aiSqlSourceEnum)) {
// case OPENAI :
// return chatWithOpenAi(msg, sseEmitter, uid);
// case CHAT2DBAI:
// return chatWithOpenAi(msg, sseEmitter, uid);
// case RESTAI :
// return chatWithRestAi(msg, sseEmitter);
// }
// return chatWithOpenAi(msg, sseEmitter, uid);
//}
///**
// * 使用OPENAI聊天相关接口
// *
// * @param msg
// * @param sseEmitter
// * @param uid
// * @return
// * @throws IOException
// */
//private SseEmitter chatWithOpenAi(String msg, SseEmitter sseEmitter, String uid) throws IOException {
// String messageContext = (String)LocalCache.CACHE.get(uid);
// List<Message> messages = new ArrayList<>();
// if (StrUtil.isNotBlank(messageContext)) {
// messages = JSONUtil.toList(messageContext, Message.class);
// if (messages.size() >= contextLength) {
// messages = messages.subList(1, contextLength);
// }
// Message currentMessage = Message.builder().content(msg).role(Message.Role.USER).build();
// messages.add(currentMessage);
// } else {
// Message currentMessage = Message.builder().content(msg).role(Message.Role.USER).build();
// messages.add(currentMessage);
// }
//
// return chatGpt35(messages, sseEmitter, uid);
//}
///**
// * 使用GPT3.0模型
// *
// * @param prompt
// * @param sseEmitter
// * @param uid
// * @return
// */
//private SseEmitter chatGpt3(String prompt, SseEmitter sseEmitter, String uid) throws IOException {
// sseEmitter.send(SseEmitter.event().id(uid).name("chatGpt3连接成功").data(LocalDateTime.now())
// .reconnectTime(3000));
// sseEmitter.onCompletion(() -> {
// log.info(LocalDateTime.now() + ", uid#" + uid + ", on completion");
// });
// sseEmitter.onTimeout(
// () -> log.info(LocalDateTime.now() + ", uid#" + uid + ", chatGpt3 on timeout#" + sseEmitter.getTimeout()));
// sseEmitter.onError(
// throwable -> {
// try {
// log.info(LocalDateTime.now() + ", uid#" + "765431" + ", chatGpt3 on error#" + throwable.toString());
// sseEmitter.send(SseEmitter.event().id("765431").name("chatGpt3 发生异常!")
// .data(throwable.getMessage())
// .reconnectTime(3000));
// } catch (IOException e) {
// e.printStackTrace();
// }
// }
// );
//
// // 获取返回结果
// OpenAIEventSourceListener openAIEventSourceListener = new OpenAIEventSourceListener(sseEmitter);
// Completion completion = Completion.builder().maxTokens(RETURN_TOKEN_LENGTH).stream(true).stop(
// Lists.newArrayList("#", ";")).user(uid).prompt(prompt).build();
// OpenAIClient.getInstance().streamCompletions(completion, openAIEventSourceListener);
// return sseEmitter;
//}
} }

View File

@ -0,0 +1,84 @@
package ai.chat2db.server.web.api.controller.ai.chat2db.client;
import ai.chat2db.server.domain.api.model.Config;
import ai.chat2db.server.domain.api.service.ConfigService;
import ai.chat2db.server.web.api.util.ApplicationContextUtil;
import com.google.common.collect.Lists;
import com.unfbx.chatgpt.OpenAiStreamClient;
import com.unfbx.chatgpt.constant.OpenAIConst;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
/**
* @author jipengfei
* @version : OpenAIClient.java
*/
@Slf4j
public class Chat2dbAIClient {
public static final String CHAT2DB_OPENAI_KEY = "chat2db.apiKey";
/**
* OPENAI接口域名
*/
public static final String CHAT2DB_OPENAI_HOST = "chat2db.apiHost";
private static OpenAiStreamClient CHAT2DB_AI_STREAM_CLIENT;
private static String apiKey;
public static OpenAiStreamClient getInstance() {
if (CHAT2DB_AI_STREAM_CLIENT != null) {
return CHAT2DB_AI_STREAM_CLIENT;
} else {
return singleton();
}
}
private static OpenAiStreamClient singleton() {
if (CHAT2DB_AI_STREAM_CLIENT == null) {
synchronized (Chat2dbAIClient.class) {
if (CHAT2DB_AI_STREAM_CLIENT == null) {
refresh();
}
}
}
return CHAT2DB_AI_STREAM_CLIENT;
}
public static void refresh() {
String apikey;
String apiHost = ApplicationContextUtil.getProperty(CHAT2DB_OPENAI_HOST);
if (StringUtils.isBlank(apiHost)) {
apiHost = OpenAIConst.OPENAI_HOST;
}
ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
Config apiHostConfig = configService.find(CHAT2DB_OPENAI_HOST).getData();
if (apiHostConfig != null) {
apiHost = apiHostConfig.getContent();
}
Config config = configService.find(CHAT2DB_OPENAI_KEY).getData();
if (config != null) {
apikey = config.getContent();
} else {
apikey = ApplicationContextUtil.getProperty(CHAT2DB_OPENAI_KEY);
}
log.info("refresh chat2db apikey:{}", maskApiKey(apikey));
CHAT2DB_AI_STREAM_CLIENT = OpenAiStreamClient.builder().apiHost(apiHost).apiKey(
Lists.newArrayList(apikey)).build();
apiKey = apikey;
}
private static String maskApiKey(String input) {
if (input == null) {
return input;
}
StringBuilder maskedString = new StringBuilder(input);
for (int i = input.length() / 4; i < input.length() / 2; i++) {
maskedString.setCharAt(i, '*');
}
return maskedString.toString();
}
}

View File

@ -1,5 +1,6 @@
package ai.chat2db.server.web.api.controller.ai.listener; package ai.chat2db.server.web.api.controller.ai.listener;
import java.io.IOException;
import java.util.Objects; import java.util.Objects;
import ai.chat2db.server.web.api.controller.ai.azure.models.AzureChatChoice; import ai.chat2db.server.web.api.controller.ai.azure.models.AzureChatChoice;
@ -93,8 +94,15 @@ public class AzureOpenAIEventSourceListener extends EventSourceListener {
@Override @Override
public void onClosed(EventSource eventSource) { public void onClosed(EventSource eventSource) {
try {
sseEmitter.send(SseEmitter.event()
.id("[DONE]")
.data("[DONE]"));
} catch (IOException e) {
throw new RuntimeException(e);
}
sseEmitter.complete(); sseEmitter.complete();
log.info("AzureOpenAI关闭sse连接..."); log.info("AzureOpenAI close sse connection...");
} }
@Override @Override
@ -102,11 +110,6 @@ public class AzureOpenAIEventSourceListener extends EventSourceListener {
try { try {
if (Objects.isNull(response)) { if (Objects.isNull(response)) {
String message = t.getMessage(); String message = t.getMessage();
if ("No route to host".equals(message)) {
message = "网络连接超时,请检查网络连通性,参考文章<https://github.com/chat2db/Chat2DB/blob/main/CHAT2DB_AI_SQL.md>";
} else {
message = "Azure AI无法正常访问请参考文章<https://github.com/chat2db/Chat2DB/blob/main/CHAT2DB_AI_SQL.md>进行配置";
}
Message sseMessage = new Message(); Message sseMessage = new Message();
sseMessage.setContent(message); sseMessage.setContent(message);
sseEmitter.send(SseEmitter.event() sseEmitter.send(SseEmitter.event()

View File

@ -13,6 +13,7 @@ import ai.chat2db.server.tools.base.wrapper.result.ActionResult;
import ai.chat2db.server.tools.base.wrapper.result.DataResult; import ai.chat2db.server.tools.base.wrapper.result.DataResult;
import ai.chat2db.server.web.api.aspect.ConnectionInfoAspect; import ai.chat2db.server.web.api.aspect.ConnectionInfoAspect;
import ai.chat2db.server.web.api.controller.ai.azure.client.AzureOpenAIClient; import ai.chat2db.server.web.api.controller.ai.azure.client.AzureOpenAIClient;
import ai.chat2db.server.web.api.controller.ai.chat2db.client.Chat2dbAIClient;
import ai.chat2db.server.web.api.controller.config.request.AIConfigCreateRequest; import ai.chat2db.server.web.api.controller.config.request.AIConfigCreateRequest;
import ai.chat2db.server.web.api.controller.config.request.AISystemConfigRequest; import ai.chat2db.server.web.api.controller.config.request.AISystemConfigRequest;
import ai.chat2db.server.web.api.controller.config.request.SystemConfigRequest; import ai.chat2db.server.web.api.controller.config.request.SystemConfigRequest;
@ -52,41 +53,6 @@ public class ConfigController {
return ActionResult.isSuccess(); return ActionResult.isSuccess();
} }
/**
* save ai config
*
* @param request
* @return
*/
@PostMapping("/system_config/chatgpt")
public ActionResult addAiSystemConfig(@RequestBody AISystemConfigRequest request) {
String sqlSource = StringUtils.isNotBlank(request.getAiSqlSource()) ? request.getAiSqlSource()
: AiSqlSourceEnum.CHAT2DBAI.getCode();
AiSqlSourceEnum aiSqlSourceEnum = AiSqlSourceEnum.getByName(sqlSource);
if (Objects.isNull(aiSqlSourceEnum)) {
aiSqlSourceEnum = AiSqlSourceEnum.CHAT2DBAI;
sqlSource = AiSqlSourceEnum.CHAT2DBAI.getCode();
}
SystemConfigParam param = SystemConfigParam.builder().code(RestAIClient.AI_SQL_SOURCE).content(sqlSource)
.build();
configService.createOrUpdate(param);
switch (Objects.requireNonNull(aiSqlSourceEnum)) {
case OPENAI :
saveOpenAIConfig(request);
break;
case CHAT2DBAI:
saveChat2dbAIConfig(request);
break;
case RESTAI :
saveRestAIConfig(request);
break;
case AZUREAI :
saveAzureAIConfig(request);
break;
}
return ActionResult.isSuccess();
}
/** /**
* 保存ChatGPT相关配置 * 保存ChatGPT相关配置
@ -129,13 +95,13 @@ public class ConfigController {
* @param request * @param request
*/ */
private void saveChat2dbAIConfig(AIConfigCreateRequest request) { private void saveChat2dbAIConfig(AIConfigCreateRequest request) {
SystemConfigParam param = SystemConfigParam.builder().code(OpenAIClient.OPENAI_KEY).content( SystemConfigParam param = SystemConfigParam.builder().code(Chat2dbAIClient.CHAT2DB_OPENAI_KEY).content(
request.getApiKey()).build(); request.getApiKey()).build();
configService.createOrUpdate(param); configService.createOrUpdate(param);
SystemConfigParam hostParam = SystemConfigParam.builder().code(OpenAIClient.OPENAI_HOST).content( SystemConfigParam hostParam = SystemConfigParam.builder().code(Chat2dbAIClient.CHAT2DB_OPENAI_HOST).content(
request.getApiHost()).build(); request.getApiHost()).build();
configService.createOrUpdate(hostParam); configService.createOrUpdate(hostParam);
OpenAIClient.refresh(); Chat2dbAIClient.refresh();
} }
/** /**
@ -192,76 +158,6 @@ public class ConfigController {
AzureOpenAIClient.refresh(); AzureOpenAIClient.refresh();
} }
/**
* 保存OPENAI相关配置
*
* @param request
*/
private void saveChat2dbAIConfig(AISystemConfigRequest request) {
SystemConfigParam param = SystemConfigParam.builder().code(OpenAIClient.OPENAI_KEY).content(
request.getChat2dbApiKey()).build();
configService.createOrUpdate(param);
SystemConfigParam hostParam = SystemConfigParam.builder().code(OpenAIClient.OPENAI_HOST).content(
request.getChat2dbApiHost()).build();
configService.createOrUpdate(hostParam);
OpenAIClient.refresh();
}
/**
* 保存OPENAI相关配置
*
* @param request
*/
private void saveOpenAIConfig(AISystemConfigRequest request) {
SystemConfigParam param = SystemConfigParam.builder().code(OpenAIClient.OPENAI_KEY).content(
request.getApiKey()).build();
configService.createOrUpdate(param);
SystemConfigParam hostParam = SystemConfigParam.builder().code(OpenAIClient.OPENAI_HOST).content(
request.getApiHost()).build();
configService.createOrUpdate(hostParam);
SystemConfigParam httpProxyHostParam = SystemConfigParam.builder().code(OpenAIClient.PROXY_HOST).content(
request.getHttpProxyHost()).build();
configService.createOrUpdate(httpProxyHostParam);
SystemConfigParam httpProxyPortParam = SystemConfigParam.builder().code(OpenAIClient.PROXY_PORT).content(
request.getHttpProxyPort()).build();
configService.createOrUpdate(httpProxyPortParam);
OpenAIClient.refresh();
}
/**
* 保存RESTAI接口相关配置
*
* @param request
*/
private void saveRestAIConfig(AISystemConfigRequest request) {
SystemConfigParam restParam = SystemConfigParam.builder().code(RestAIClient.REST_AI_URL).content(
request.getRestAiUrl())
.build();
configService.createOrUpdate(restParam);
SystemConfigParam methodParam = SystemConfigParam.builder().code(RestAIClient.REST_AI_STREAM_OUT).content(
request.getRestAiStream().toString()).build();
configService.createOrUpdate(methodParam);
RestAIClient.refresh();
}
/**
* 保存azure配置
*
* @param request
*/
private void saveAzureAIConfig(AISystemConfigRequest request) {
SystemConfigParam apikeyParam = SystemConfigParam.builder().code(AzureOpenAIClient.AZURE_CHATGPT_API_KEY).content(
request.getAzureApiKey()).build();
configService.createOrUpdate(apikeyParam);
SystemConfigParam endpointParam = SystemConfigParam.builder().code(AzureOpenAIClient.AZURE_CHATGPT_ENDPOINT).content(
request.getAzureEndpoint()).build();
configService.createOrUpdate(endpointParam);
SystemConfigParam modelParam = SystemConfigParam.builder().code(AzureOpenAIClient.AZURE_CHATGPT_DEPLOYMENT_ID).content(
request.getAzureDeploymentId()).build();
configService.createOrUpdate(modelParam);
AzureOpenAIClient.refresh();
}
@GetMapping("/system_config/{code}") @GetMapping("/system_config/{code}")
public DataResult<Config> getSystemConfig(@PathVariable("code") String code) { public DataResult<Config> getSystemConfig(@PathVariable("code") String code) {
DataResult<Config> result = configService.find(code); DataResult<Config> result = configService.find(code);
@ -269,7 +165,7 @@ public class ConfigController {
} }
/** /**
* 查询ChatGPT相关配置 * ai config info
* *
* @return * @return
*/ */
@ -291,24 +187,20 @@ public class ConfigController {
config.setAiSqlSource(aiSqlSource); config.setAiSqlSource(aiSqlSource);
switch (Objects.requireNonNull(aiSqlSourceEnum)) { switch (Objects.requireNonNull(aiSqlSourceEnum)) {
case OPENAI : case OPENAI :
if (!StringUtils.equals(dbSqlSource.getData().getContent(), AiSqlSourceEnum.CHAT2DBAI.getCode())) { DataResult<Config> apiKey = configService.find(OpenAIClient.OPENAI_KEY);
DataResult<Config> apiKey = configService.find(OpenAIClient.OPENAI_KEY); DataResult<Config> apiHost = configService.find(OpenAIClient.OPENAI_HOST);
DataResult<Config> apiHost = configService.find(OpenAIClient.OPENAI_HOST); DataResult<Config> httpProxyHost = configService.find(OpenAIClient.PROXY_HOST);
DataResult<Config> httpProxyHost = configService.find(OpenAIClient.PROXY_HOST); DataResult<Config> httpProxyPort = configService.find(OpenAIClient.PROXY_PORT);
DataResult<Config> httpProxyPort = configService.find(OpenAIClient.PROXY_PORT); config.setApiKey(Objects.nonNull(apiKey.getData()) ? apiKey.getData().getContent() : "");
config.setApiKey(Objects.nonNull(apiKey.getData()) ? apiKey.getData().getContent() : ""); config.setApiHost(Objects.nonNull(apiHost.getData()) ? apiHost.getData().getContent() : "");
config.setApiHost(Objects.nonNull(apiHost.getData()) ? apiHost.getData().getContent() : ""); config.setHttpProxyHost(Objects.nonNull(httpProxyHost.getData()) ? httpProxyHost.getData().getContent() : "");
config.setHttpProxyHost(Objects.nonNull(httpProxyHost.getData()) ? httpProxyHost.getData().getContent() : ""); config.setHttpProxyPort(Objects.nonNull(httpProxyPort.getData()) ? httpProxyPort.getData().getContent() : "");
config.setHttpProxyPort(Objects.nonNull(httpProxyPort.getData()) ? httpProxyPort.getData().getContent() : "");
}
break; break;
case CHAT2DBAI: case CHAT2DBAI:
if (!StringUtils.equals(dbSqlSource.getData().getContent(), AiSqlSourceEnum.OPENAI.getCode())) { DataResult<Config> chat2dbApiKey = configService.find(Chat2dbAIClient.CHAT2DB_OPENAI_KEY);
DataResult<Config> apiKey = configService.find(OpenAIClient.OPENAI_KEY); DataResult<Config> chat2dbApiHost = configService.find(Chat2dbAIClient.CHAT2DB_OPENAI_HOST);
DataResult<Config> apiHost = configService.find(OpenAIClient.OPENAI_HOST); config.setApiKey(Objects.nonNull(chat2dbApiKey.getData()) ? chat2dbApiKey.getData().getContent() : "");
config.setApiKey(Objects.nonNull(apiKey.getData()) ? apiKey.getData().getContent() : ""); config.setApiHost(Objects.nonNull(chat2dbApiHost.getData()) ? chat2dbApiHost.getData().getContent() : "");
config.setApiHost(Objects.nonNull(apiHost.getData()) ? apiHost.getData().getContent() : "");
}
break; break;
case AZUREAI: case AZUREAI:
DataResult<Config> azureApiKey = configService.find(AzureOpenAIClient.AZURE_CHATGPT_API_KEY); DataResult<Config> azureApiKey = configService.find(AzureOpenAIClient.AZURE_CHATGPT_API_KEY);
@ -332,54 +224,145 @@ public class ConfigController {
return DataResult.of(config); return DataResult.of(config);
} }
/** ///**
* 查询ChatGPT相关配置 // * save ai config
* // *
* @return // * @param request
*/ // * @return
@GetMapping("/system_config/chatgpt") // */
public DataResult<ChatGptConfig> getChatGptSystemConfig() { //@PostMapping("/system_config/chatgpt")
DataResult<Config> apiKey = configService.find(OpenAIClient.OPENAI_KEY); //public ActionResult addAiSystemConfig(@RequestBody AISystemConfigRequest request) {
DataResult<Config> apiHost = configService.find(OpenAIClient.OPENAI_HOST); // String sqlSource = StringUtils.isNotBlank(request.getAiSqlSource()) ? request.getAiSqlSource()
DataResult<Config> httpProxyHost = configService.find(OpenAIClient.PROXY_HOST); // : AiSqlSourceEnum.CHAT2DBAI.getCode();
DataResult<Config> httpProxyPort = configService.find(OpenAIClient.PROXY_PORT); // AiSqlSourceEnum aiSqlSourceEnum = AiSqlSourceEnum.getByName(sqlSource);
DataResult<Config> aiSqlSource = configService.find(RestAIClient.AI_SQL_SOURCE); // if (Objects.isNull(aiSqlSourceEnum)) {
DataResult<Config> restAiUrl = configService.find(RestAIClient.REST_AI_URL); // aiSqlSourceEnum = AiSqlSourceEnum.CHAT2DBAI;
DataResult<Config> restAiHttpMethod = configService.find(RestAIClient.REST_AI_STREAM_OUT); // sqlSource = AiSqlSourceEnum.CHAT2DBAI.getCode();
DataResult<Config> azureApiKey = configService.find(AzureOpenAIClient.AZURE_CHATGPT_API_KEY); // }
DataResult<Config> azureEndpoint = configService.find(AzureOpenAIClient.AZURE_CHATGPT_ENDPOINT); // SystemConfigParam param = SystemConfigParam.builder().code(RestAIClient.AI_SQL_SOURCE).content(sqlSource)
DataResult<Config> azureDeployId = configService.find(AzureOpenAIClient.AZURE_CHATGPT_DEPLOYMENT_ID); // .build();
ChatGptConfig config = new ChatGptConfig(); // configService.createOrUpdate(param);
//
// switch (Objects.requireNonNull(aiSqlSourceEnum)) {
// case OPENAI :
// saveOpenAIConfig(request);
// break;
// case CHAT2DBAI:
// saveChat2dbAIConfig(request);
// break;
// case RESTAI :
// saveRestAIConfig(request);
// break;
// case AZUREAI :
// saveAzureAIConfig(request);
// break;
// }
// return ActionResult.isSuccess();
//}
//
///**
// * 保存OPENAI相关配置
// *
// * @param request
// */
//private void saveOpenAIConfig(AISystemConfigRequest request) {
// SystemConfigParam param = SystemConfigParam.builder().code(OpenAIClient.OPENAI_KEY).content(
// request.getApiKey()).build();
// configService.createOrUpdate(param);
// SystemConfigParam hostParam = SystemConfigParam.builder().code(OpenAIClient.OPENAI_HOST).content(
// request.getApiHost()).build();
// configService.createOrUpdate(hostParam);
// SystemConfigParam httpProxyHostParam = SystemConfigParam.builder().code(OpenAIClient.PROXY_HOST).content(
// request.getHttpProxyHost()).build();
// configService.createOrUpdate(httpProxyHostParam);
// SystemConfigParam httpProxyPortParam = SystemConfigParam.builder().code(OpenAIClient.PROXY_PORT).content(
// request.getHttpProxyPort()).build();
// configService.createOrUpdate(httpProxyPortParam);
// OpenAIClient.refresh();
//}
//
///**
// * 保存RESTAI接口相关配置
// *
// * @param request
// */
//private void saveRestAIConfig(AISystemConfigRequest request) {
// SystemConfigParam restParam = SystemConfigParam.builder().code(RestAIClient.REST_AI_URL).content(
// request.getRestAiUrl())
// .build();
// configService.createOrUpdate(restParam);
// SystemConfigParam methodParam = SystemConfigParam.builder().code(RestAIClient.REST_AI_STREAM_OUT).content(
// request.getRestAiStream().toString()).build();
// configService.createOrUpdate(methodParam);
// RestAIClient.refresh();
//}
//
///**
// * 保存azure配置
// *
// * @param request
// */
//private void saveAzureAIConfig(AISystemConfigRequest request) {
// SystemConfigParam apikeyParam = SystemConfigParam.builder().code(AzureOpenAIClient.AZURE_CHATGPT_API_KEY).content(
// request.getAzureApiKey()).build();
// configService.createOrUpdate(apikeyParam);
// SystemConfigParam endpointParam = SystemConfigParam.builder().code(AzureOpenAIClient.AZURE_CHATGPT_ENDPOINT).content(
// request.getAzureEndpoint()).build();
// configService.createOrUpdate(endpointParam);
// SystemConfigParam modelParam = SystemConfigParam.builder().code(AzureOpenAIClient.AZURE_CHATGPT_DEPLOYMENT_ID).content(
// request.getAzureDeploymentId()).build();
// configService.createOrUpdate(modelParam);
// AzureOpenAIClient.refresh();
//}
String sqlSource = Objects.nonNull(aiSqlSource.getData()) ? aiSqlSource.getData().getContent() : AiSqlSourceEnum.CHAT2DBAI.getCode(); ///**
AiSqlSourceEnum aiSqlSourceEnum = AiSqlSourceEnum.getByName(sqlSource); // * 查询ChatGPT相关配置
if (Objects.isNull(aiSqlSourceEnum)) { // *
aiSqlSourceEnum = AiSqlSourceEnum.CHAT2DBAI; // * @return
sqlSource = AiSqlSourceEnum.CHAT2DBAI.getCode(); // */
} //@GetMapping("/system_config/chatgpt")
config.setAiSqlSource(sqlSource); //public DataResult<ChatGptConfig> getChatGptSystemConfig() {
switch (Objects.requireNonNull(aiSqlSourceEnum)) { // DataResult<Config> apiKey = configService.find(OpenAIClient.OPENAI_KEY);
case OPENAI : // DataResult<Config> apiHost = configService.find(OpenAIClient.OPENAI_HOST);
config.setApiKey(Objects.nonNull(apiKey.getData()) ? apiKey.getData().getContent() : null); // DataResult<Config> httpProxyHost = configService.find(OpenAIClient.PROXY_HOST);
config.setApiHost(Objects.nonNull(apiHost.getData()) ? apiHost.getData().getContent() : null); // DataResult<Config> httpProxyPort = configService.find(OpenAIClient.PROXY_PORT);
config.setChat2dbApiKey(""); // DataResult<Config> aiSqlSource = configService.find(RestAIClient.AI_SQL_SOURCE);
config.setChat2dbApiHost(""); // DataResult<Config> restAiUrl = configService.find(RestAIClient.REST_AI_URL);
break; // DataResult<Config> restAiHttpMethod = configService.find(RestAIClient.REST_AI_STREAM_OUT);
case CHAT2DBAI: // DataResult<Config> azureApiKey = configService.find(AzureOpenAIClient.AZURE_CHATGPT_API_KEY);
config.setApiKey(""); // DataResult<Config> azureEndpoint = configService.find(AzureOpenAIClient.AZURE_CHATGPT_ENDPOINT);
config.setApiHost(""); // DataResult<Config> azureDeployId = configService.find(AzureOpenAIClient.AZURE_CHATGPT_DEPLOYMENT_ID);
config.setChat2dbApiKey(Objects.nonNull(apiKey.getData()) ? apiKey.getData().getContent() : null); // ChatGptConfig config = new ChatGptConfig();
config.setChat2dbApiHost(Objects.nonNull(apiHost.getData()) ? apiHost.getData().getContent() : null); //
break; // String sqlSource = Objects.nonNull(aiSqlSource.getData()) ? aiSqlSource.getData().getContent() : AiSqlSourceEnum.CHAT2DBAI.getCode();
} // AiSqlSourceEnum aiSqlSourceEnum = AiSqlSourceEnum.getByName(sqlSource);
config.setRestAiUrl(Objects.nonNull(restAiUrl.getData()) ? restAiUrl.getData().getContent() : null); // if (Objects.isNull(aiSqlSourceEnum)) {
config.setRestAiStream(Objects.nonNull(restAiHttpMethod.getData()) ? Boolean.valueOf( // aiSqlSourceEnum = AiSqlSourceEnum.CHAT2DBAI;
restAiHttpMethod.getData().getContent()) : Boolean.TRUE); // sqlSource = AiSqlSourceEnum.CHAT2DBAI.getCode();
config.setHttpProxyHost(Objects.nonNull(httpProxyHost.getData()) ? httpProxyHost.getData().getContent() : null); // }
config.setHttpProxyPort(Objects.nonNull(httpProxyPort.getData()) ? httpProxyPort.getData().getContent() : null); // config.setAiSqlSource(sqlSource);
config.setAzureApiKey(Objects.nonNull(azureApiKey.getData()) ? azureApiKey.getData().getContent() : null); // switch (Objects.requireNonNull(aiSqlSourceEnum)) {
config.setAzureEndpoint(Objects.nonNull(azureEndpoint.getData()) ? azureEndpoint.getData().getContent() : null); // case OPENAI :
config.setAzureDeploymentId(Objects.nonNull(azureDeployId.getData()) ? azureDeployId.getData().getContent() : null); // config.setApiKey(Objects.nonNull(apiKey.getData()) ? apiKey.getData().getContent() : null);
return DataResult.of(config); // config.setApiHost(Objects.nonNull(apiHost.getData()) ? apiHost.getData().getContent() : null);
} // config.setChat2dbApiKey("");
// config.setChat2dbApiHost("");
// break;
// case CHAT2DBAI:
// config.setApiKey("");
// config.setApiHost("");
// config.setChat2dbApiKey(Objects.nonNull(apiKey.getData()) ? apiKey.getData().getContent() : null);
// config.setChat2dbApiHost(Objects.nonNull(apiHost.getData()) ? apiHost.getData().getContent() : null);
// break;
// }
// config.setRestAiUrl(Objects.nonNull(restAiUrl.getData()) ? restAiUrl.getData().getContent() : null);
// config.setRestAiStream(Objects.nonNull(restAiHttpMethod.getData()) ? Boolean.valueOf(
// restAiHttpMethod.getData().getContent()) : Boolean.TRUE);
// config.setHttpProxyHost(Objects.nonNull(httpProxyHost.getData()) ? httpProxyHost.getData().getContent() : null);
// config.setHttpProxyPort(Objects.nonNull(httpProxyPort.getData()) ? httpProxyPort.getData().getContent() : null);
// config.setAzureApiKey(Objects.nonNull(azureApiKey.getData()) ? azureApiKey.getData().getContent() : null);
// config.setAzureEndpoint(Objects.nonNull(azureEndpoint.getData()) ? azureEndpoint.getData().getContent() : null);
// config.setAzureDeploymentId(Objects.nonNull(azureDeployId.getData()) ? azureDeployId.getData().getContent() : null);
// return DataResult.of(config);
//}
} }