mirror of
https://github.com/CodePhiliaX/Chat2DB.git
synced 2025-07-31 11:42:41 +08:00
Support websocket long connections
This commit is contained in:
@ -0,0 +1,14 @@
|
||||
package ai.chat2db.server.web.api.ws;
|
||||
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.web.socket.server.standard.ServerEndpointExporter;
|
||||
|
||||
@Configuration
|
||||
public class WsConfig {
|
||||
|
||||
@Bean
|
||||
public ServerEndpointExporter serverEndpointExporter() {
|
||||
return new ServerEndpointExporter();
|
||||
}
|
||||
}
|
@ -0,0 +1,34 @@
|
||||
package ai.chat2db.server.web.api.ws;
|
||||
|
||||
import com.alibaba.fastjson2.JSONObject;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class WsMessage {
|
||||
|
||||
/**
|
||||
* message id
|
||||
*/
|
||||
private String uuid;
|
||||
|
||||
/**
|
||||
* message content
|
||||
*/
|
||||
private JSONObject message;
|
||||
|
||||
/**
|
||||
* message type
|
||||
*/
|
||||
private String actionType;
|
||||
|
||||
|
||||
public static class ActionType {
|
||||
public static final String EXECUTE = "execute";
|
||||
public static final String LOGIN = "login";
|
||||
public static final String PING = "ping";
|
||||
public static final String OPEN_SESSION = "open_session";
|
||||
public static final String ERROR = "error";
|
||||
public static final String MESSAGE = "message";
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,23 @@
|
||||
package ai.chat2db.server.web.api.ws;
|
||||
|
||||
|
||||
import ai.chat2db.server.tools.base.wrapper.Result;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class WsResult {
|
||||
/**
|
||||
* message id
|
||||
*/
|
||||
private String uuid;
|
||||
|
||||
/**
|
||||
* message content
|
||||
*/
|
||||
private Result message;
|
||||
|
||||
/**
|
||||
* message type
|
||||
*/
|
||||
private String actionType;
|
||||
}
|
@ -0,0 +1,250 @@
|
||||
package ai.chat2db.server.web.api.ws;
|
||||
|
||||
import ai.chat2db.server.domain.repository.Dbutils;
|
||||
import ai.chat2db.server.tools.base.wrapper.result.ActionResult;
|
||||
import ai.chat2db.server.tools.base.wrapper.result.ListResult;
|
||||
import ai.chat2db.server.tools.common.model.Context;
|
||||
import ai.chat2db.server.tools.common.model.LoginUser;
|
||||
import ai.chat2db.server.tools.common.util.ContextUtils;
|
||||
import ai.chat2db.server.web.api.controller.rdb.request.DmlRequest;
|
||||
import ai.chat2db.server.web.api.controller.rdb.vo.ExecuteResultVO;
|
||||
import ai.chat2db.server.web.api.util.ApplicationContextUtil;
|
||||
import ai.chat2db.spi.sql.Chat2DBContext;
|
||||
import ai.chat2db.spi.sql.ConnectInfo;
|
||||
import com.alibaba.fastjson2.JSONObject;
|
||||
import com.jcraft.jsch.JSchException;
|
||||
import jakarta.websocket.*;
|
||||
import jakarta.websocket.server.PathParam;
|
||||
import jakarta.websocket.server.ServerEndpoint;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.web.context.support.SpringBeanAutowiringSupport;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.sql.Connection;
|
||||
import java.sql.SQLException;
|
||||
import java.util.Map;
|
||||
import java.util.Timer;
|
||||
import java.util.TimerTask;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.CopyOnWriteArraySet;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
@ServerEndpoint("/api/ws/{token}")
|
||||
public class WsServer {
|
||||
private Session session;
|
||||
|
||||
private static final AtomicInteger OnlineCount = new AtomicInteger(0);
|
||||
|
||||
// concurrent包的线程安全Set,用来存放每个客户端对应的Session对象。
|
||||
private static CopyOnWriteArraySet<Session> SessionSet = new CopyOnWriteArraySet<Session>();
|
||||
|
||||
private static int num = 0;
|
||||
|
||||
private Timer timer = new Timer();
|
||||
|
||||
|
||||
private Map<String, ConnectInfo> connectInfoMap = new ConcurrentHashMap<>();
|
||||
|
||||
|
||||
private LoginUser loginUser;
|
||||
|
||||
private WsService wsService;
|
||||
|
||||
/**
|
||||
* 连接建立成功调用的方法
|
||||
*/
|
||||
@OnOpen
|
||||
public void onOpen(Session session, @PathParam("token") String token) throws IOException {
|
||||
SessionSet.add(session);
|
||||
this.session = session;
|
||||
int cnt = OnlineCount.incrementAndGet(); // 在线数加1
|
||||
log.info("有连接加入,当前连接数为:{}", cnt);
|
||||
|
||||
heartBeat(session);
|
||||
this.wsService = ApplicationContextUtil.getBean(WsService.class);
|
||||
Dbutils.setSession();
|
||||
this.loginUser = wsService.doLogin(token);
|
||||
if (this.loginUser == null) {
|
||||
ActionResult actionResult = new ActionResult();
|
||||
actionResult.setSuccess(false);
|
||||
actionResult.setErrorCode("LOGIN_FAIL");
|
||||
WsResult wsMessage = new WsResult();
|
||||
wsMessage.setActionType(WsMessage.ActionType.OPEN_SESSION);
|
||||
wsMessage.setUuid(token);
|
||||
wsMessage.setMessage(actionResult);
|
||||
SendMessage(this.session, wsMessage);
|
||||
onClose();
|
||||
}else {
|
||||
ActionResult actionResult = new ActionResult();
|
||||
actionResult.setSuccess(true);
|
||||
WsResult wsMessage = new WsResult();
|
||||
wsMessage.setActionType(WsMessage.ActionType.OPEN_SESSION);
|
||||
wsMessage.setUuid(token);
|
||||
wsMessage.setMessage(actionResult);
|
||||
SendMessage(this.session, wsMessage);
|
||||
}
|
||||
Dbutils.removeSession();
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 连接关闭调用的方法
|
||||
*/
|
||||
@OnClose
|
||||
public void onClose() throws IOException {
|
||||
if (SessionSet.contains(session)) {
|
||||
SessionSet.remove(this.session);
|
||||
session.close();
|
||||
for (Map.Entry<String, ConnectInfo> entry : connectInfoMap.entrySet()) {
|
||||
ConnectInfo connectInfo = entry.getValue();
|
||||
if (connectInfo != null) {
|
||||
Connection connection = connectInfo.getConnection();
|
||||
try {
|
||||
if (connection != null && !connection.isClosed()) {
|
||||
connection.close();
|
||||
}
|
||||
} catch (SQLException e) {
|
||||
log.error("close connection error", e);
|
||||
}
|
||||
|
||||
com.jcraft.jsch.Session session = connectInfo.getSession();
|
||||
if (session != null && session.isConnected() && connectInfo.getSsh() != null
|
||||
&& connectInfo.getSsh().isUse()) {
|
||||
try {
|
||||
session.delPortForwardingL(Integer.parseInt(connectInfo.getSsh().getLocalPort()));
|
||||
} catch (JSchException e) {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
int cnt = OnlineCount.decrementAndGet();
|
||||
log.info("有连接关闭,session:{},{}", session, this);
|
||||
log.info("有连接关闭,当前连接数为:{}", cnt);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 收到客户端消息后调用的方法
|
||||
*
|
||||
* @param message 客户端发送过来的消息
|
||||
*/
|
||||
@OnMessage(maxMessageSize = 1024000)
|
||||
public void onMessage(String message, Session session) {
|
||||
CompletableFuture.runAsync(() -> {
|
||||
WsMessage wsMessage = JSONObject.parseObject(message, WsMessage.class);
|
||||
// 在这里处理你的消息
|
||||
try {
|
||||
String actionType = wsMessage.getActionType();
|
||||
if (WsMessage.ActionType.PING.equalsIgnoreCase(actionType)) {
|
||||
WsResult wsResult = new WsResult();
|
||||
ActionResult actionResult = new ActionResult();
|
||||
actionResult.setSuccess(true);
|
||||
wsResult.setActionType(WsMessage.ActionType.PING);
|
||||
wsResult.setUuid(wsMessage.getUuid());
|
||||
wsResult.setMessage(actionResult);
|
||||
SendMessage(session, wsResult);
|
||||
timer.cancel();
|
||||
heartBeat(session);
|
||||
} else {
|
||||
ContextUtils.setContext(Context.builder()
|
||||
.loginUser(loginUser)
|
||||
.build());
|
||||
Dbutils.setSession();
|
||||
JSONObject jsonObject = wsMessage.getMessage();
|
||||
Long dataSourceId = jsonObject.getLong("dataSourceId");
|
||||
String databaseName = jsonObject.getString("databaseName");
|
||||
String schemaName = jsonObject.getString("schemaName");
|
||||
Long consoleId = jsonObject.getLong("consoleId");
|
||||
String key = connectInfoKey(dataSourceId, databaseName, schemaName, consoleId);
|
||||
ConnectInfo connectInfo = connectInfoMap.get(key);
|
||||
if (connectInfo == null) {
|
||||
connectInfo = wsService.toInfo(dataSourceId, databaseName, consoleId, schemaName);
|
||||
connectInfoMap.put(key, connectInfo);
|
||||
}
|
||||
Chat2DBContext.putContext(connectInfo);
|
||||
if (WsMessage.ActionType.EXECUTE.equalsIgnoreCase(actionType)) {
|
||||
DmlRequest request = jsonObject.toJavaObject(DmlRequest.class);
|
||||
ListResult<ExecuteResultVO> result = wsService.execute(request);
|
||||
WsResult resultMessage = new WsResult();
|
||||
resultMessage.setUuid(wsMessage.getUuid());
|
||||
resultMessage.setActionType(wsMessage.getActionType());
|
||||
resultMessage.setMessage(result);
|
||||
SendMessage(session, resultMessage);
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
WsResult wsResult = new WsResult();
|
||||
ActionResult actionResult = new ActionResult();
|
||||
actionResult.setSuccess(false);
|
||||
actionResult.setErrorCode(e.getMessage());
|
||||
wsResult.setActionType(WsMessage.ActionType.ERROR);
|
||||
wsResult.setUuid(wsMessage.getUuid());
|
||||
wsResult.setMessage(actionResult);
|
||||
SendMessage(session, wsResult);
|
||||
} finally {
|
||||
Chat2DBContext.remove();
|
||||
ContextUtils.removeContext();
|
||||
Dbutils.removeSession();
|
||||
}
|
||||
});
|
||||
|
||||
}
|
||||
|
||||
|
||||
private String connectInfoKey(Long dataSourceId, String databaseName, String schemaName, Long consoleId) {
|
||||
return dataSourceId + "_" + databaseName + "_" + schemaName + "_" + consoleId;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 出现错误
|
||||
*
|
||||
* @param session
|
||||
* @param error
|
||||
*/
|
||||
@OnError
|
||||
public void onError(Session session, Throwable error) {
|
||||
log.error("发生错误:{},Session ID: {}", error.getMessage(), session.getId(), error);
|
||||
error.printStackTrace();
|
||||
}
|
||||
|
||||
/**
|
||||
* 心跳
|
||||
*
|
||||
* @param session
|
||||
*/
|
||||
private void heartBeat(Session session) {
|
||||
timer = new Timer();
|
||||
timer.schedule(new TimerTask() {
|
||||
@Override
|
||||
public void run() {
|
||||
try {
|
||||
onClose();
|
||||
} catch (IOException e) {
|
||||
log.error("发送消息出错:{}", e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
}, 600000);
|
||||
}
|
||||
|
||||
/**
|
||||
* 发送消息,实践表明,每次浏览器刷新,session会发生变化。
|
||||
*
|
||||
* @param session
|
||||
* @param wsResult
|
||||
*/
|
||||
public static void SendMessage(Session session, WsResult wsResult) {
|
||||
try {
|
||||
if (session.isOpen()) {
|
||||
session.getBasicRemote().sendText(JSONObject.toJSONString(wsResult));
|
||||
}
|
||||
} catch (IOException e) {
|
||||
log.error("发送消息出错:{}", e.getMessage());
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,162 @@
|
||||
package ai.chat2db.server.web.api.ws;
|
||||
|
||||
import ai.chat2db.server.domain.api.enums.RoleCodeEnum;
|
||||
import ai.chat2db.server.domain.api.model.Config;
|
||||
import ai.chat2db.server.domain.api.model.DataSource;
|
||||
import ai.chat2db.server.domain.api.model.User;
|
||||
import ai.chat2db.server.domain.api.param.DlExecuteParam;
|
||||
import ai.chat2db.server.domain.api.service.*;
|
||||
import ai.chat2db.server.domain.core.cache.CacheKey;
|
||||
import ai.chat2db.server.domain.core.cache.MemoryCacheManage;
|
||||
import ai.chat2db.server.tools.base.wrapper.result.DataResult;
|
||||
import ai.chat2db.server.tools.base.wrapper.result.ListResult;
|
||||
import ai.chat2db.server.tools.common.exception.ParamBusinessException;
|
||||
import ai.chat2db.server.tools.common.model.LoginUser;
|
||||
import ai.chat2db.server.web.api.controller.ai.chat2db.client.Chat2dbAIClient;
|
||||
import ai.chat2db.server.web.api.controller.rdb.converter.RdbWebConverter;
|
||||
import ai.chat2db.server.web.api.controller.rdb.request.DmlRequest;
|
||||
import ai.chat2db.server.web.api.controller.rdb.vo.ExecuteResultVO;
|
||||
import ai.chat2db.server.web.api.http.GatewayClientService;
|
||||
import ai.chat2db.server.web.api.http.request.SqlExecuteHistoryCreateRequest;
|
||||
import ai.chat2db.server.web.api.util.ApplicationContextUtil;
|
||||
import ai.chat2db.spi.config.DriverConfig;
|
||||
import ai.chat2db.spi.model.ExecuteResult;
|
||||
import ai.chat2db.spi.sql.Chat2DBContext;
|
||||
import ai.chat2db.spi.sql.ConnectInfo;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.web.bind.annotation.RequestBody;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
|
||||
@Component
|
||||
public class WsService {
|
||||
|
||||
@Autowired
|
||||
private UserService userService;
|
||||
|
||||
|
||||
@Autowired
|
||||
private DataSourceService dataSourceService;
|
||||
|
||||
@Autowired
|
||||
private DataSourceAccessBusinessService dataSourceAccessBusinessService;
|
||||
|
||||
|
||||
@Autowired
|
||||
private RdbWebConverter rdbWebConverter;
|
||||
|
||||
@Autowired
|
||||
private DlTemplateService dlTemplateService;
|
||||
|
||||
@Autowired
|
||||
private GatewayClientService gatewayClientService;
|
||||
|
||||
|
||||
public static ExecutorService executorService = Executors.newFixedThreadPool(10);
|
||||
|
||||
public ListResult<ExecuteResultVO> execute(DmlRequest request) {
|
||||
DlExecuteParam param = rdbWebConverter.request2param(request);
|
||||
ListResult<ExecuteResult> resultDTOListResult = dlTemplateService.execute(param);
|
||||
List<ExecuteResultVO> resultVOS = rdbWebConverter.dto2vo(resultDTOListResult.getData());
|
||||
String type = Chat2DBContext.getConnectInfo().getDbType();
|
||||
String clientId = getApiKey();
|
||||
String sqlContent = request.getSql();
|
||||
executorService.submit(() -> {
|
||||
try {
|
||||
addOperationLog(clientId, type, sqlContent, resultDTOListResult.getErrorMessage(), resultDTOListResult.getSuccess(), resultVOS);
|
||||
} catch (Exception e) {
|
||||
// do nothing
|
||||
}
|
||||
});
|
||||
return ListResult.of(resultVOS);
|
||||
}
|
||||
|
||||
|
||||
public LoginUser doLogin(String token) {
|
||||
Long userId = RoleCodeEnum.DESKTOP.getDefaultUserId();
|
||||
LoginUser loginUser = MemoryCacheManage.computeIfAbsent(CacheKey.getLoginUserKey(userId), () -> {
|
||||
User user = userService.query(userId).getData();
|
||||
if (user == null) {
|
||||
return null;
|
||||
}
|
||||
boolean admin = RoleCodeEnum.ADMIN.getCode().equals(user.getRoleCode());
|
||||
|
||||
return LoginUser.builder()
|
||||
.id(user.getId())
|
||||
.nickName(user.getNickName())
|
||||
.admin(admin)
|
||||
.roleCode(user.getRoleCode())
|
||||
.build();
|
||||
});
|
||||
|
||||
|
||||
loginUser.setToken(userId.toString());
|
||||
return loginUser;
|
||||
}
|
||||
|
||||
public ConnectInfo toInfo(Long dataSourceId, String database, Long consoleId, String schemaName) {
|
||||
DataResult<DataSource> result = dataSourceService.queryById(dataSourceId);
|
||||
DataSource dataSource = result.getData();
|
||||
if (!result.success() || dataSource == null) {
|
||||
throw new ParamBusinessException("dataSourceId");
|
||||
}
|
||||
// Verify permissions
|
||||
dataSourceAccessBusinessService.checkPermission(dataSource);
|
||||
ConnectInfo connectInfo = new ConnectInfo();
|
||||
connectInfo.setAlias(dataSource.getAlias());
|
||||
connectInfo.setUser(dataSource.getUserName());
|
||||
connectInfo.setConsoleId(consoleId);
|
||||
connectInfo.setDataSourceId(dataSourceId);
|
||||
connectInfo.setPassword(dataSource.getPassword());
|
||||
connectInfo.setDbType(dataSource.getType());
|
||||
connectInfo.setUrl(dataSource.getUrl());
|
||||
connectInfo.setDatabase(database);
|
||||
connectInfo.setSchemaName(schemaName);
|
||||
connectInfo.setConsoleOwn(false);
|
||||
connectInfo.setDriver(dataSource.getDriver());
|
||||
connectInfo.setSsh(dataSource.getSsh());
|
||||
connectInfo.setSsl(dataSource.getSsl());
|
||||
connectInfo.setJdbc(dataSource.getJdbc());
|
||||
connectInfo.setExtendInfo(dataSource.getExtendInfo());
|
||||
connectInfo.setUrl(dataSource.getUrl());
|
||||
connectInfo.setPort(StringUtils.isNotBlank(dataSource.getPort()) ? Integer.parseInt(dataSource.getPort()) : null);
|
||||
connectInfo.setHost(dataSource.getHost());
|
||||
DriverConfig driverConfig = dataSource.getDriverConfig();
|
||||
if (driverConfig != null && driverConfig.notEmpty()) {
|
||||
connectInfo.setDriverConfig(driverConfig);
|
||||
}
|
||||
return connectInfo;
|
||||
}
|
||||
|
||||
|
||||
private String getApiKey() {
|
||||
ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
|
||||
Config keyConfig = configService.find(Chat2dbAIClient.CHAT2DB_OPENAI_KEY).getData();
|
||||
if (Objects.isNull(keyConfig) || StringUtils.isBlank(keyConfig.getContent())) {
|
||||
return null;
|
||||
}
|
||||
return keyConfig.getContent();
|
||||
}
|
||||
|
||||
private void addOperationLog(String clientId, String sqlType, String sqlContent, String errorMessage, Boolean isSuccess, List<ExecuteResultVO> executeResultVOS) {
|
||||
SqlExecuteHistoryCreateRequest createRequest = new SqlExecuteHistoryCreateRequest();
|
||||
createRequest.setClientId(clientId);
|
||||
createRequest.setErrorMessage(errorMessage);
|
||||
createRequest.setDatabaseType(sqlType);
|
||||
createRequest.setSqlContent(sqlContent);
|
||||
createRequest.setExecuteStatus(isSuccess ? "success" : "fail");
|
||||
executeResultVOS.forEach(executeResultVO -> {
|
||||
createRequest.setSqlType(executeResultVO.getSqlType());
|
||||
createRequest.setDuration(executeResultVO.getDuration());
|
||||
createRequest.setTableName(executeResultVO.getTableName());
|
||||
gatewayClientService.addOperationLog(createRequest);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
}
|
Reference in New Issue
Block a user