Support websocket long connections

This commit is contained in:
SwallowGG
2024-01-02 15:19:04 +08:00
parent 11ed8a4203
commit 9156ac3d5f
11 changed files with 507 additions and 8 deletions

View File

@ -0,0 +1,14 @@
package ai.chat2db.server.web.api.ws;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.server.standard.ServerEndpointExporter;
@Configuration
public class WsConfig {
@Bean
public ServerEndpointExporter serverEndpointExporter() {
return new ServerEndpointExporter();
}
}

View File

@ -0,0 +1,34 @@
package ai.chat2db.server.web.api.ws;
import com.alibaba.fastjson2.JSONObject;
import lombok.Data;
@Data
public class WsMessage {
/**
* message id
*/
private String uuid;
/**
* message content
*/
private JSONObject message;
/**
* message type
*/
private String actionType;
public static class ActionType {
public static final String EXECUTE = "execute";
public static final String LOGIN = "login";
public static final String PING = "ping";
public static final String OPEN_SESSION = "open_session";
public static final String ERROR = "error";
public static final String MESSAGE = "message";
}
}

View File

@ -0,0 +1,23 @@
package ai.chat2db.server.web.api.ws;
import ai.chat2db.server.tools.base.wrapper.Result;
import lombok.Data;
@Data
public class WsResult {
/**
* message id
*/
private String uuid;
/**
* message content
*/
private Result message;
/**
* message type
*/
private String actionType;
}

View File

@ -0,0 +1,250 @@
package ai.chat2db.server.web.api.ws;
import ai.chat2db.server.domain.repository.Dbutils;
import ai.chat2db.server.tools.base.wrapper.result.ActionResult;
import ai.chat2db.server.tools.base.wrapper.result.ListResult;
import ai.chat2db.server.tools.common.model.Context;
import ai.chat2db.server.tools.common.model.LoginUser;
import ai.chat2db.server.tools.common.util.ContextUtils;
import ai.chat2db.server.web.api.controller.rdb.request.DmlRequest;
import ai.chat2db.server.web.api.controller.rdb.vo.ExecuteResultVO;
import ai.chat2db.server.web.api.util.ApplicationContextUtil;
import ai.chat2db.spi.sql.Chat2DBContext;
import ai.chat2db.spi.sql.ConnectInfo;
import com.alibaba.fastjson2.JSONObject;
import com.jcraft.jsch.JSchException;
import jakarta.websocket.*;
import jakarta.websocket.server.PathParam;
import jakarta.websocket.server.ServerEndpoint;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import org.springframework.web.context.support.SpringBeanAutowiringSupport;
import java.io.IOException;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.Map;
import java.util.Timer;
import java.util.TimerTask;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.atomic.AtomicInteger;
@Slf4j
@Component
@ServerEndpoint("/api/ws/{token}")
public class WsServer {
private Session session;
private static final AtomicInteger OnlineCount = new AtomicInteger(0);
// concurrent包的线程安全Set用来存放每个客户端对应的Session对象。
private static CopyOnWriteArraySet<Session> SessionSet = new CopyOnWriteArraySet<Session>();
private static int num = 0;
private Timer timer = new Timer();
private Map<String, ConnectInfo> connectInfoMap = new ConcurrentHashMap<>();
private LoginUser loginUser;
private WsService wsService;
/**
* 连接建立成功调用的方法
*/
@OnOpen
public void onOpen(Session session, @PathParam("token") String token) throws IOException {
SessionSet.add(session);
this.session = session;
int cnt = OnlineCount.incrementAndGet(); // 在线数加1
log.info("有连接加入,当前连接数为:{}", cnt);
heartBeat(session);
this.wsService = ApplicationContextUtil.getBean(WsService.class);
Dbutils.setSession();
this.loginUser = wsService.doLogin(token);
if (this.loginUser == null) {
ActionResult actionResult = new ActionResult();
actionResult.setSuccess(false);
actionResult.setErrorCode("LOGIN_FAIL");
WsResult wsMessage = new WsResult();
wsMessage.setActionType(WsMessage.ActionType.OPEN_SESSION);
wsMessage.setUuid(token);
wsMessage.setMessage(actionResult);
SendMessage(this.session, wsMessage);
onClose();
}else {
ActionResult actionResult = new ActionResult();
actionResult.setSuccess(true);
WsResult wsMessage = new WsResult();
wsMessage.setActionType(WsMessage.ActionType.OPEN_SESSION);
wsMessage.setUuid(token);
wsMessage.setMessage(actionResult);
SendMessage(this.session, wsMessage);
}
Dbutils.removeSession();
}
/**
* 连接关闭调用的方法
*/
@OnClose
public void onClose() throws IOException {
if (SessionSet.contains(session)) {
SessionSet.remove(this.session);
session.close();
for (Map.Entry<String, ConnectInfo> entry : connectInfoMap.entrySet()) {
ConnectInfo connectInfo = entry.getValue();
if (connectInfo != null) {
Connection connection = connectInfo.getConnection();
try {
if (connection != null && !connection.isClosed()) {
connection.close();
}
} catch (SQLException e) {
log.error("close connection error", e);
}
com.jcraft.jsch.Session session = connectInfo.getSession();
if (session != null && session.isConnected() && connectInfo.getSsh() != null
&& connectInfo.getSsh().isUse()) {
try {
session.delPortForwardingL(Integer.parseInt(connectInfo.getSsh().getLocalPort()));
} catch (JSchException e) {
}
}
}
}
int cnt = OnlineCount.decrementAndGet();
log.info("有连接关闭session:{},{}", session, this);
log.info("有连接关闭,当前连接数为:{}", cnt);
}
}
/**
* 收到客户端消息后调用的方法
*
* @param message 客户端发送过来的消息
*/
@OnMessage(maxMessageSize = 1024000)
public void onMessage(String message, Session session) {
CompletableFuture.runAsync(() -> {
WsMessage wsMessage = JSONObject.parseObject(message, WsMessage.class);
// 在这里处理你的消息
try {
String actionType = wsMessage.getActionType();
if (WsMessage.ActionType.PING.equalsIgnoreCase(actionType)) {
WsResult wsResult = new WsResult();
ActionResult actionResult = new ActionResult();
actionResult.setSuccess(true);
wsResult.setActionType(WsMessage.ActionType.PING);
wsResult.setUuid(wsMessage.getUuid());
wsResult.setMessage(actionResult);
SendMessage(session, wsResult);
timer.cancel();
heartBeat(session);
} else {
ContextUtils.setContext(Context.builder()
.loginUser(loginUser)
.build());
Dbutils.setSession();
JSONObject jsonObject = wsMessage.getMessage();
Long dataSourceId = jsonObject.getLong("dataSourceId");
String databaseName = jsonObject.getString("databaseName");
String schemaName = jsonObject.getString("schemaName");
Long consoleId = jsonObject.getLong("consoleId");
String key = connectInfoKey(dataSourceId, databaseName, schemaName, consoleId);
ConnectInfo connectInfo = connectInfoMap.get(key);
if (connectInfo == null) {
connectInfo = wsService.toInfo(dataSourceId, databaseName, consoleId, schemaName);
connectInfoMap.put(key, connectInfo);
}
Chat2DBContext.putContext(connectInfo);
if (WsMessage.ActionType.EXECUTE.equalsIgnoreCase(actionType)) {
DmlRequest request = jsonObject.toJavaObject(DmlRequest.class);
ListResult<ExecuteResultVO> result = wsService.execute(request);
WsResult resultMessage = new WsResult();
resultMessage.setUuid(wsMessage.getUuid());
resultMessage.setActionType(wsMessage.getActionType());
resultMessage.setMessage(result);
SendMessage(session, resultMessage);
}
}
} catch (Exception e) {
WsResult wsResult = new WsResult();
ActionResult actionResult = new ActionResult();
actionResult.setSuccess(false);
actionResult.setErrorCode(e.getMessage());
wsResult.setActionType(WsMessage.ActionType.ERROR);
wsResult.setUuid(wsMessage.getUuid());
wsResult.setMessage(actionResult);
SendMessage(session, wsResult);
} finally {
Chat2DBContext.remove();
ContextUtils.removeContext();
Dbutils.removeSession();
}
});
}
private String connectInfoKey(Long dataSourceId, String databaseName, String schemaName, Long consoleId) {
return dataSourceId + "_" + databaseName + "_" + schemaName + "_" + consoleId;
}
/**
* 出现错误
*
* @param session
* @param error
*/
@OnError
public void onError(Session session, Throwable error) {
log.error("发生错误:{}Session ID {}", error.getMessage(), session.getId(), error);
error.printStackTrace();
}
/**
* 心跳
*
* @param session
*/
private void heartBeat(Session session) {
timer = new Timer();
timer.schedule(new TimerTask() {
@Override
public void run() {
try {
onClose();
} catch (IOException e) {
log.error("发送消息出错:{}", e.getMessage(), e);
}
}
}, 600000);
}
/**
* 发送消息实践表明每次浏览器刷新session会发生变化。
*
* @param session
* @param wsResult
*/
public static void SendMessage(Session session, WsResult wsResult) {
try {
if (session.isOpen()) {
session.getBasicRemote().sendText(JSONObject.toJSONString(wsResult));
}
} catch (IOException e) {
log.error("发送消息出错:{}", e.getMessage());
e.printStackTrace();
}
}
}

View File

@ -0,0 +1,162 @@
package ai.chat2db.server.web.api.ws;
import ai.chat2db.server.domain.api.enums.RoleCodeEnum;
import ai.chat2db.server.domain.api.model.Config;
import ai.chat2db.server.domain.api.model.DataSource;
import ai.chat2db.server.domain.api.model.User;
import ai.chat2db.server.domain.api.param.DlExecuteParam;
import ai.chat2db.server.domain.api.service.*;
import ai.chat2db.server.domain.core.cache.CacheKey;
import ai.chat2db.server.domain.core.cache.MemoryCacheManage;
import ai.chat2db.server.tools.base.wrapper.result.DataResult;
import ai.chat2db.server.tools.base.wrapper.result.ListResult;
import ai.chat2db.server.tools.common.exception.ParamBusinessException;
import ai.chat2db.server.tools.common.model.LoginUser;
import ai.chat2db.server.web.api.controller.ai.chat2db.client.Chat2dbAIClient;
import ai.chat2db.server.web.api.controller.rdb.converter.RdbWebConverter;
import ai.chat2db.server.web.api.controller.rdb.request.DmlRequest;
import ai.chat2db.server.web.api.controller.rdb.vo.ExecuteResultVO;
import ai.chat2db.server.web.api.http.GatewayClientService;
import ai.chat2db.server.web.api.http.request.SqlExecuteHistoryCreateRequest;
import ai.chat2db.server.web.api.util.ApplicationContextUtil;
import ai.chat2db.spi.config.DriverConfig;
import ai.chat2db.spi.model.ExecuteResult;
import ai.chat2db.spi.sql.Chat2DBContext;
import ai.chat2db.spi.sql.ConnectInfo;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.web.bind.annotation.RequestBody;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
@Component
public class WsService {
@Autowired
private UserService userService;
@Autowired
private DataSourceService dataSourceService;
@Autowired
private DataSourceAccessBusinessService dataSourceAccessBusinessService;
@Autowired
private RdbWebConverter rdbWebConverter;
@Autowired
private DlTemplateService dlTemplateService;
@Autowired
private GatewayClientService gatewayClientService;
public static ExecutorService executorService = Executors.newFixedThreadPool(10);
public ListResult<ExecuteResultVO> execute(DmlRequest request) {
DlExecuteParam param = rdbWebConverter.request2param(request);
ListResult<ExecuteResult> resultDTOListResult = dlTemplateService.execute(param);
List<ExecuteResultVO> resultVOS = rdbWebConverter.dto2vo(resultDTOListResult.getData());
String type = Chat2DBContext.getConnectInfo().getDbType();
String clientId = getApiKey();
String sqlContent = request.getSql();
executorService.submit(() -> {
try {
addOperationLog(clientId, type, sqlContent, resultDTOListResult.getErrorMessage(), resultDTOListResult.getSuccess(), resultVOS);
} catch (Exception e) {
// do nothing
}
});
return ListResult.of(resultVOS);
}
public LoginUser doLogin(String token) {
Long userId = RoleCodeEnum.DESKTOP.getDefaultUserId();
LoginUser loginUser = MemoryCacheManage.computeIfAbsent(CacheKey.getLoginUserKey(userId), () -> {
User user = userService.query(userId).getData();
if (user == null) {
return null;
}
boolean admin = RoleCodeEnum.ADMIN.getCode().equals(user.getRoleCode());
return LoginUser.builder()
.id(user.getId())
.nickName(user.getNickName())
.admin(admin)
.roleCode(user.getRoleCode())
.build();
});
loginUser.setToken(userId.toString());
return loginUser;
}
public ConnectInfo toInfo(Long dataSourceId, String database, Long consoleId, String schemaName) {
DataResult<DataSource> result = dataSourceService.queryById(dataSourceId);
DataSource dataSource = result.getData();
if (!result.success() || dataSource == null) {
throw new ParamBusinessException("dataSourceId");
}
// Verify permissions
dataSourceAccessBusinessService.checkPermission(dataSource);
ConnectInfo connectInfo = new ConnectInfo();
connectInfo.setAlias(dataSource.getAlias());
connectInfo.setUser(dataSource.getUserName());
connectInfo.setConsoleId(consoleId);
connectInfo.setDataSourceId(dataSourceId);
connectInfo.setPassword(dataSource.getPassword());
connectInfo.setDbType(dataSource.getType());
connectInfo.setUrl(dataSource.getUrl());
connectInfo.setDatabase(database);
connectInfo.setSchemaName(schemaName);
connectInfo.setConsoleOwn(false);
connectInfo.setDriver(dataSource.getDriver());
connectInfo.setSsh(dataSource.getSsh());
connectInfo.setSsl(dataSource.getSsl());
connectInfo.setJdbc(dataSource.getJdbc());
connectInfo.setExtendInfo(dataSource.getExtendInfo());
connectInfo.setUrl(dataSource.getUrl());
connectInfo.setPort(StringUtils.isNotBlank(dataSource.getPort()) ? Integer.parseInt(dataSource.getPort()) : null);
connectInfo.setHost(dataSource.getHost());
DriverConfig driverConfig = dataSource.getDriverConfig();
if (driverConfig != null && driverConfig.notEmpty()) {
connectInfo.setDriverConfig(driverConfig);
}
return connectInfo;
}
private String getApiKey() {
ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
Config keyConfig = configService.find(Chat2dbAIClient.CHAT2DB_OPENAI_KEY).getData();
if (Objects.isNull(keyConfig) || StringUtils.isBlank(keyConfig.getContent())) {
return null;
}
return keyConfig.getContent();
}
private void addOperationLog(String clientId, String sqlType, String sqlContent, String errorMessage, Boolean isSuccess, List<ExecuteResultVO> executeResultVOS) {
SqlExecuteHistoryCreateRequest createRequest = new SqlExecuteHistoryCreateRequest();
createRequest.setClientId(clientId);
createRequest.setErrorMessage(errorMessage);
createRequest.setDatabaseType(sqlType);
createRequest.setSqlContent(sqlContent);
createRequest.setExecuteStatus(isSuccess ? "success" : "fail");
executeResultVOS.forEach(executeResultVO -> {
createRequest.setSqlType(executeResultVO.getSqlType());
createRequest.setDuration(executeResultVO.getDuration());
createRequest.setTableName(executeResultVO.getTableName());
gatewayClientService.addOperationLog(createRequest);
});
}
}