diff --git a/chat2db-server/chat2db-server-start/src/main/java/ai/chat2db/server/start/Application.java b/chat2db-server/chat2db-server-start/src/main/java/ai/chat2db/server/start/Application.java index 118baaa2..9a9a9aa0 100644 --- a/chat2db-server/chat2db-server-start/src/main/java/ai/chat2db/server/start/Application.java +++ b/chat2db-server/chat2db-server-start/src/main/java/ai/chat2db/server/start/Application.java @@ -1,14 +1,7 @@ package ai.chat2db.server.start; -import ai.chat2db.server.tools.common.enums.ModeEnum; -import ai.chat2db.server.tools.common.model.ConfigJson; -import ai.chat2db.server.tools.common.util.ConfigUtils; -import ai.chat2db.server.tools.common.util.EasyEnumUtils; -import cn.hutool.core.lang.UUID; +import ai.chat2db.server.domain.repository.Dbutils; import lombok.extern.slf4j.Slf4j; -import org.apache.commons.lang3.ArrayUtils; -import org.apache.commons.lang3.StringUtils; -import org.mybatis.spring.annotation.MapperScan; import org.springframework.boot.SpringApplication; import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.cache.annotation.EnableCaching; @@ -17,6 +10,8 @@ import org.springframework.scheduling.annotation.EnableAsync; import org.springframework.scheduling.annotation.EnableScheduling; import org.springframework.stereotype.Indexed; +import java.util.concurrent.CompletableFuture; + /** * 启动类 * @@ -33,6 +28,9 @@ public class Application { public static void main(String[] args) { //ConfigUtils.pid(); + CompletableFuture.runAsync(() -> { + Dbutils.init(); + }); SpringApplication.run(Application.class, args); } } diff --git a/chat2db-server/chat2db-server-start/src/main/java/ai/chat2db/server/start/config/config/Chat2dbWebMvcConfigurer.java b/chat2db-server/chat2db-server-start/src/main/java/ai/chat2db/server/start/config/config/Chat2dbWebMvcConfigurer.java index 2d486f91..2c178c95 100644 --- a/chat2db-server/chat2db-server-start/src/main/java/ai/chat2db/server/start/config/config/Chat2dbWebMvcConfigurer.java +++ b/chat2db-server/chat2db-server-start/src/main/java/ai/chat2db/server/start/config/config/Chat2dbWebMvcConfigurer.java @@ -89,6 +89,7 @@ public class Chat2dbWebMvcConfigurer implements WebMvcConfigurer { // 代表用户可能被删除了 return true; } + loginUser.setToken(userId.toString()); ContextUtils.setContext(Context.builder() .loginUser(loginUser) diff --git a/chat2db-server/chat2db-server-tools/chat2db-server-tools-common/src/main/java/ai/chat2db/server/tools/common/model/LoginUser.java b/chat2db-server/chat2db-server-tools/chat2db-server-tools-common/src/main/java/ai/chat2db/server/tools/common/model/LoginUser.java index a7c71826..be064187 100644 --- a/chat2db-server/chat2db-server-tools/chat2db-server-tools-common/src/main/java/ai/chat2db/server/tools/common/model/LoginUser.java +++ b/chat2db-server/chat2db-server-tools/chat2db-server-tools-common/src/main/java/ai/chat2db/server/tools/common/model/LoginUser.java @@ -44,4 +44,7 @@ public class LoginUser implements Serializable { * @see RoleCodeEnum */ private String roleCode; + + + private String token; } diff --git a/chat2db-server/chat2db-server-web-start/src/main/java/ai/chat2db/server/web/start/config/config/Chat2dbWebMvcConfigurer.java b/chat2db-server/chat2db-server-web-start/src/main/java/ai/chat2db/server/web/start/config/config/Chat2dbWebMvcConfigurer.java index fef6f9b7..519fa42f 100644 --- a/chat2db-server/chat2db-server-web-start/src/main/java/ai/chat2db/server/web/start/config/config/Chat2dbWebMvcConfigurer.java +++ b/chat2db-server/chat2db-server-web-start/src/main/java/ai/chat2db/server/web/start/config/config/Chat2dbWebMvcConfigurer.java @@ -111,6 +111,7 @@ public class Chat2dbWebMvcConfigurer implements WebMvcConfigurer { return true; } + loginUser.setToken(StpUtil.getTokenValue()); ContextUtils.setContext(Context.builder() .loginUser(loginUser) .build()); diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/pom.xml b/chat2db-server/chat2db-server-web/chat2db-server-web-api/pom.xml index 02e3e7a2..cf32aae0 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/pom.xml +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/pom.xml @@ -100,6 +100,10 @@ com.github.vertical-blank sql-formatter + + org.springframework.boot + spring-boot-starter-websocket + diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/ws/WsConfig.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/ws/WsConfig.java new file mode 100644 index 00000000..1545dc43 --- /dev/null +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/ws/WsConfig.java @@ -0,0 +1,14 @@ +package ai.chat2db.server.web.api.ws; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.socket.server.standard.ServerEndpointExporter; + +@Configuration +public class WsConfig { + + @Bean + public ServerEndpointExporter serverEndpointExporter() { + return new ServerEndpointExporter(); + } +} diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/ws/WsMessage.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/ws/WsMessage.java new file mode 100644 index 00000000..19564a5d --- /dev/null +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/ws/WsMessage.java @@ -0,0 +1,34 @@ +package ai.chat2db.server.web.api.ws; + +import com.alibaba.fastjson2.JSONObject; +import lombok.Data; + +@Data +public class WsMessage { + + /** + * message id + */ + private String uuid; + + /** + * message content + */ + private JSONObject message; + + /** + * message type + */ + private String actionType; + + + public static class ActionType { + public static final String EXECUTE = "execute"; + public static final String LOGIN = "login"; + public static final String PING = "ping"; + public static final String OPEN_SESSION = "open_session"; + public static final String ERROR = "error"; + public static final String MESSAGE = "message"; + } + +} diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/ws/WsResult.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/ws/WsResult.java new file mode 100644 index 00000000..9bd6e6aa --- /dev/null +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/ws/WsResult.java @@ -0,0 +1,23 @@ +package ai.chat2db.server.web.api.ws; + + +import ai.chat2db.server.tools.base.wrapper.Result; +import lombok.Data; + +@Data +public class WsResult { + /** + * message id + */ + private String uuid; + + /** + * message content + */ + private Result message; + + /** + * message type + */ + private String actionType; +} diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/ws/WsServer.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/ws/WsServer.java new file mode 100644 index 00000000..1d192526 --- /dev/null +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/ws/WsServer.java @@ -0,0 +1,250 @@ +package ai.chat2db.server.web.api.ws; + +import ai.chat2db.server.domain.repository.Dbutils; +import ai.chat2db.server.tools.base.wrapper.result.ActionResult; +import ai.chat2db.server.tools.base.wrapper.result.ListResult; +import ai.chat2db.server.tools.common.model.Context; +import ai.chat2db.server.tools.common.model.LoginUser; +import ai.chat2db.server.tools.common.util.ContextUtils; +import ai.chat2db.server.web.api.controller.rdb.request.DmlRequest; +import ai.chat2db.server.web.api.controller.rdb.vo.ExecuteResultVO; +import ai.chat2db.server.web.api.util.ApplicationContextUtil; +import ai.chat2db.spi.sql.Chat2DBContext; +import ai.chat2db.spi.sql.ConnectInfo; +import com.alibaba.fastjson2.JSONObject; +import com.jcraft.jsch.JSchException; +import jakarta.websocket.*; +import jakarta.websocket.server.PathParam; +import jakarta.websocket.server.ServerEndpoint; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Component; +import org.springframework.web.context.support.SpringBeanAutowiringSupport; + +import java.io.IOException; +import java.sql.Connection; +import java.sql.SQLException; +import java.util.Map; +import java.util.Timer; +import java.util.TimerTask; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArraySet; +import java.util.concurrent.atomic.AtomicInteger; + +@Slf4j +@Component +@ServerEndpoint("/api/ws/{token}") +public class WsServer { + private Session session; + + private static final AtomicInteger OnlineCount = new AtomicInteger(0); + + // concurrent包的线程安全Set,用来存放每个客户端对应的Session对象。 + private static CopyOnWriteArraySet SessionSet = new CopyOnWriteArraySet(); + + private static int num = 0; + + private Timer timer = new Timer(); + + + private Map connectInfoMap = new ConcurrentHashMap<>(); + + + private LoginUser loginUser; + + private WsService wsService; + + /** + * 连接建立成功调用的方法 + */ + @OnOpen + public void onOpen(Session session, @PathParam("token") String token) throws IOException { + SessionSet.add(session); + this.session = session; + int cnt = OnlineCount.incrementAndGet(); // 在线数加1 + log.info("有连接加入,当前连接数为:{}", cnt); + + heartBeat(session); + this.wsService = ApplicationContextUtil.getBean(WsService.class); + Dbutils.setSession(); + this.loginUser = wsService.doLogin(token); + if (this.loginUser == null) { + ActionResult actionResult = new ActionResult(); + actionResult.setSuccess(false); + actionResult.setErrorCode("LOGIN_FAIL"); + WsResult wsMessage = new WsResult(); + wsMessage.setActionType(WsMessage.ActionType.OPEN_SESSION); + wsMessage.setUuid(token); + wsMessage.setMessage(actionResult); + SendMessage(this.session, wsMessage); + onClose(); + }else { + ActionResult actionResult = new ActionResult(); + actionResult.setSuccess(true); + WsResult wsMessage = new WsResult(); + wsMessage.setActionType(WsMessage.ActionType.OPEN_SESSION); + wsMessage.setUuid(token); + wsMessage.setMessage(actionResult); + SendMessage(this.session, wsMessage); + } + Dbutils.removeSession(); + } + + + /** + * 连接关闭调用的方法 + */ + @OnClose + public void onClose() throws IOException { + if (SessionSet.contains(session)) { + SessionSet.remove(this.session); + session.close(); + for (Map.Entry entry : connectInfoMap.entrySet()) { + ConnectInfo connectInfo = entry.getValue(); + if (connectInfo != null) { + Connection connection = connectInfo.getConnection(); + try { + if (connection != null && !connection.isClosed()) { + connection.close(); + } + } catch (SQLException e) { + log.error("close connection error", e); + } + + com.jcraft.jsch.Session session = connectInfo.getSession(); + if (session != null && session.isConnected() && connectInfo.getSsh() != null + && connectInfo.getSsh().isUse()) { + try { + session.delPortForwardingL(Integer.parseInt(connectInfo.getSsh().getLocalPort())); + } catch (JSchException e) { + } + } + } + } + int cnt = OnlineCount.decrementAndGet(); + log.info("有连接关闭,session:{},{}", session, this); + log.info("有连接关闭,当前连接数为:{}", cnt); + } + } + + /** + * 收到客户端消息后调用的方法 + * + * @param message 客户端发送过来的消息 + */ + @OnMessage(maxMessageSize = 1024000) + public void onMessage(String message, Session session) { + CompletableFuture.runAsync(() -> { + WsMessage wsMessage = JSONObject.parseObject(message, WsMessage.class); + // 在这里处理你的消息 + try { + String actionType = wsMessage.getActionType(); + if (WsMessage.ActionType.PING.equalsIgnoreCase(actionType)) { + WsResult wsResult = new WsResult(); + ActionResult actionResult = new ActionResult(); + actionResult.setSuccess(true); + wsResult.setActionType(WsMessage.ActionType.PING); + wsResult.setUuid(wsMessage.getUuid()); + wsResult.setMessage(actionResult); + SendMessage(session, wsResult); + timer.cancel(); + heartBeat(session); + } else { + ContextUtils.setContext(Context.builder() + .loginUser(loginUser) + .build()); + Dbutils.setSession(); + JSONObject jsonObject = wsMessage.getMessage(); + Long dataSourceId = jsonObject.getLong("dataSourceId"); + String databaseName = jsonObject.getString("databaseName"); + String schemaName = jsonObject.getString("schemaName"); + Long consoleId = jsonObject.getLong("consoleId"); + String key = connectInfoKey(dataSourceId, databaseName, schemaName, consoleId); + ConnectInfo connectInfo = connectInfoMap.get(key); + if (connectInfo == null) { + connectInfo = wsService.toInfo(dataSourceId, databaseName, consoleId, schemaName); + connectInfoMap.put(key, connectInfo); + } + Chat2DBContext.putContext(connectInfo); + if (WsMessage.ActionType.EXECUTE.equalsIgnoreCase(actionType)) { + DmlRequest request = jsonObject.toJavaObject(DmlRequest.class); + ListResult result = wsService.execute(request); + WsResult resultMessage = new WsResult(); + resultMessage.setUuid(wsMessage.getUuid()); + resultMessage.setActionType(wsMessage.getActionType()); + resultMessage.setMessage(result); + SendMessage(session, resultMessage); + } + } + } catch (Exception e) { + WsResult wsResult = new WsResult(); + ActionResult actionResult = new ActionResult(); + actionResult.setSuccess(false); + actionResult.setErrorCode(e.getMessage()); + wsResult.setActionType(WsMessage.ActionType.ERROR); + wsResult.setUuid(wsMessage.getUuid()); + wsResult.setMessage(actionResult); + SendMessage(session, wsResult); + } finally { + Chat2DBContext.remove(); + ContextUtils.removeContext(); + Dbutils.removeSession(); + } + }); + + } + + + private String connectInfoKey(Long dataSourceId, String databaseName, String schemaName, Long consoleId) { + return dataSourceId + "_" + databaseName + "_" + schemaName + "_" + consoleId; + } + + + /** + * 出现错误 + * + * @param session + * @param error + */ + @OnError + public void onError(Session session, Throwable error) { + log.error("发生错误:{},Session ID: {}", error.getMessage(), session.getId(), error); + error.printStackTrace(); + } + + /** + * 心跳 + * + * @param session + */ + private void heartBeat(Session session) { + timer = new Timer(); + timer.schedule(new TimerTask() { + @Override + public void run() { + try { + onClose(); + } catch (IOException e) { + log.error("发送消息出错:{}", e.getMessage(), e); + } + } + }, 600000); + } + + /** + * 发送消息,实践表明,每次浏览器刷新,session会发生变化。 + * + * @param session + * @param wsResult + */ + public static void SendMessage(Session session, WsResult wsResult) { + try { + if (session.isOpen()) { + session.getBasicRemote().sendText(JSONObject.toJSONString(wsResult)); + } + } catch (IOException e) { + log.error("发送消息出错:{}", e.getMessage()); + e.printStackTrace(); + } + } +} diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/ws/WsService.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/ws/WsService.java new file mode 100644 index 00000000..e2f575ab --- /dev/null +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/ws/WsService.java @@ -0,0 +1,162 @@ +package ai.chat2db.server.web.api.ws; + +import ai.chat2db.server.domain.api.enums.RoleCodeEnum; +import ai.chat2db.server.domain.api.model.Config; +import ai.chat2db.server.domain.api.model.DataSource; +import ai.chat2db.server.domain.api.model.User; +import ai.chat2db.server.domain.api.param.DlExecuteParam; +import ai.chat2db.server.domain.api.service.*; +import ai.chat2db.server.domain.core.cache.CacheKey; +import ai.chat2db.server.domain.core.cache.MemoryCacheManage; +import ai.chat2db.server.tools.base.wrapper.result.DataResult; +import ai.chat2db.server.tools.base.wrapper.result.ListResult; +import ai.chat2db.server.tools.common.exception.ParamBusinessException; +import ai.chat2db.server.tools.common.model.LoginUser; +import ai.chat2db.server.web.api.controller.ai.chat2db.client.Chat2dbAIClient; +import ai.chat2db.server.web.api.controller.rdb.converter.RdbWebConverter; +import ai.chat2db.server.web.api.controller.rdb.request.DmlRequest; +import ai.chat2db.server.web.api.controller.rdb.vo.ExecuteResultVO; +import ai.chat2db.server.web.api.http.GatewayClientService; +import ai.chat2db.server.web.api.http.request.SqlExecuteHistoryCreateRequest; +import ai.chat2db.server.web.api.util.ApplicationContextUtil; +import ai.chat2db.spi.config.DriverConfig; +import ai.chat2db.spi.model.ExecuteResult; +import ai.chat2db.spi.sql.Chat2DBContext; +import ai.chat2db.spi.sql.ConnectInfo; +import org.apache.commons.lang3.StringUtils; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Component; +import org.springframework.web.bind.annotation.RequestBody; + +import java.util.List; +import java.util.Objects; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +@Component +public class WsService { + + @Autowired + private UserService userService; + + + @Autowired + private DataSourceService dataSourceService; + + @Autowired + private DataSourceAccessBusinessService dataSourceAccessBusinessService; + + + @Autowired + private RdbWebConverter rdbWebConverter; + + @Autowired + private DlTemplateService dlTemplateService; + + @Autowired + private GatewayClientService gatewayClientService; + + + public static ExecutorService executorService = Executors.newFixedThreadPool(10); + + public ListResult execute(DmlRequest request) { + DlExecuteParam param = rdbWebConverter.request2param(request); + ListResult resultDTOListResult = dlTemplateService.execute(param); + List resultVOS = rdbWebConverter.dto2vo(resultDTOListResult.getData()); + String type = Chat2DBContext.getConnectInfo().getDbType(); + String clientId = getApiKey(); + String sqlContent = request.getSql(); + executorService.submit(() -> { + try { + addOperationLog(clientId, type, sqlContent, resultDTOListResult.getErrorMessage(), resultDTOListResult.getSuccess(), resultVOS); + } catch (Exception e) { + // do nothing + } + }); + return ListResult.of(resultVOS); + } + + + public LoginUser doLogin(String token) { + Long userId = RoleCodeEnum.DESKTOP.getDefaultUserId(); + LoginUser loginUser = MemoryCacheManage.computeIfAbsent(CacheKey.getLoginUserKey(userId), () -> { + User user = userService.query(userId).getData(); + if (user == null) { + return null; + } + boolean admin = RoleCodeEnum.ADMIN.getCode().equals(user.getRoleCode()); + + return LoginUser.builder() + .id(user.getId()) + .nickName(user.getNickName()) + .admin(admin) + .roleCode(user.getRoleCode()) + .build(); + }); + + + loginUser.setToken(userId.toString()); + return loginUser; + } + + public ConnectInfo toInfo(Long dataSourceId, String database, Long consoleId, String schemaName) { + DataResult result = dataSourceService.queryById(dataSourceId); + DataSource dataSource = result.getData(); + if (!result.success() || dataSource == null) { + throw new ParamBusinessException("dataSourceId"); + } + // Verify permissions + dataSourceAccessBusinessService.checkPermission(dataSource); + ConnectInfo connectInfo = new ConnectInfo(); + connectInfo.setAlias(dataSource.getAlias()); + connectInfo.setUser(dataSource.getUserName()); + connectInfo.setConsoleId(consoleId); + connectInfo.setDataSourceId(dataSourceId); + connectInfo.setPassword(dataSource.getPassword()); + connectInfo.setDbType(dataSource.getType()); + connectInfo.setUrl(dataSource.getUrl()); + connectInfo.setDatabase(database); + connectInfo.setSchemaName(schemaName); + connectInfo.setConsoleOwn(false); + connectInfo.setDriver(dataSource.getDriver()); + connectInfo.setSsh(dataSource.getSsh()); + connectInfo.setSsl(dataSource.getSsl()); + connectInfo.setJdbc(dataSource.getJdbc()); + connectInfo.setExtendInfo(dataSource.getExtendInfo()); + connectInfo.setUrl(dataSource.getUrl()); + connectInfo.setPort(StringUtils.isNotBlank(dataSource.getPort()) ? Integer.parseInt(dataSource.getPort()) : null); + connectInfo.setHost(dataSource.getHost()); + DriverConfig driverConfig = dataSource.getDriverConfig(); + if (driverConfig != null && driverConfig.notEmpty()) { + connectInfo.setDriverConfig(driverConfig); + } + return connectInfo; + } + + + private String getApiKey() { + ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class); + Config keyConfig = configService.find(Chat2dbAIClient.CHAT2DB_OPENAI_KEY).getData(); + if (Objects.isNull(keyConfig) || StringUtils.isBlank(keyConfig.getContent())) { + return null; + } + return keyConfig.getContent(); + } + + private void addOperationLog(String clientId, String sqlType, String sqlContent, String errorMessage, Boolean isSuccess, List executeResultVOS) { + SqlExecuteHistoryCreateRequest createRequest = new SqlExecuteHistoryCreateRequest(); + createRequest.setClientId(clientId); + createRequest.setErrorMessage(errorMessage); + createRequest.setDatabaseType(sqlType); + createRequest.setSqlContent(sqlContent); + createRequest.setExecuteStatus(isSuccess ? "success" : "fail"); + executeResultVOS.forEach(executeResultVO -> { + createRequest.setSqlType(executeResultVO.getSqlType()); + createRequest.setDuration(executeResultVO.getDuration()); + createRequest.setTableName(executeResultVO.getTableName()); + gatewayClientService.addOperationLog(createRequest); + }); + } + + +} diff --git a/chat2db-server/chat2db-spi/src/main/java/ai/chat2db/spi/sql/Chat2DBContext.java b/chat2db-server/chat2db-spi/src/main/java/ai/chat2db/spi/sql/Chat2DBContext.java index 56a28024..9e6fce81 100644 --- a/chat2db-server/chat2db-spi/src/main/java/ai/chat2db/spi/sql/Chat2DBContext.java +++ b/chat2db-server/chat2db-spi/src/main/java/ai/chat2db/spi/sql/Chat2DBContext.java @@ -159,4 +159,13 @@ public class Chat2DBContext { } } + /** + * 设置context + */ + public static void remove() { + ConnectInfo connectInfo = CONNECT_INFO_THREAD_LOCAL.get(); + if (connectInfo != null) { + CONNECT_INFO_THREAD_LOCAL.remove(); + } + } }