fix exception

This commit is contained in:
jipengfei-jpf
2023-07-09 21:03:29 +08:00
parent 46e2d5279e
commit e533a31f39
5 changed files with 44 additions and 69 deletions

View File

@ -12,6 +12,7 @@ import ai.chat2db.server.domain.api.param.DataSourceSelector;
import ai.chat2db.server.domain.api.param.DataSourceUpdateParam;
import ai.chat2db.server.domain.api.service.ConsoleService;
import ai.chat2db.server.domain.api.service.DataSourceService;
import ai.chat2db.server.tools.common.exception.ConnectionException;
import ai.chat2db.spi.model.Database;
import ai.chat2db.spi.ssh.SSHManager;
import ai.chat2db.server.tools.base.wrapper.result.ActionResult;
@ -99,7 +100,7 @@ public class DataSourceController {
session = SSHManager.getSSHSession(sshWebConverter.toInfo(request));
} catch (Exception e) {
log.error("sshConnect error", e);
throw new RuntimeException(e);
throw new ConnectionException("connection.ssh.error",null,e);
} finally {
if (session != null) {
session.disconnect();

View File

@ -32,9 +32,9 @@
<artifactId>spring-jdbc</artifactId>
</dependency>
<dependency>
<groupId>com.jcraft</groupId>
<groupId>com.github.mwiede</groupId>
<artifactId>jsch</artifactId>
<version>0.1.53</version>
<version>0.2.9</version>
</dependency>
<dependency>
<groupId>com.oracle.ojdbc</groupId>

View File

@ -103,8 +103,8 @@ public class Chat2DBContext {
if (session != null) {
url = url.replace(host, "127.0.0.1").replace(port, ssh.getLocalPort());
}
}catch (Exception e){
throw new ConnectionException("connection.ssh.error",null,e);
} catch (Exception e) {
throw new ConnectionException("connection.ssh.error", null, e);
}
try {
DriverConfig config = connectInfo.getDriverConfig();
@ -133,7 +133,7 @@ public class Chat2DBContext {
} catch (Exception e) {
}
}
throw new BusinessException("connection.error",null,e1);
throw new BusinessException("connection.error", null, e1);
}
connectInfo.setSession(session);
connectInfo.setConnection(connection);
@ -167,7 +167,17 @@ public class Chat2DBContext {
} catch (SQLException e) {
log.error("close connection error", e);
}
CONNECT_INFO_THREAD_LOCAL.remove();
Session session = connectInfo.getSession();
if (session != null && session.isConnected() && connectInfo.getSsh() != null
&& connectInfo.getSsh().isUse()) {
try {
session.delPortForwardingL(Integer.parseInt(connectInfo.getSsh().getLocalPort()));
} catch (JSchException e) {
}
}
}
}

View File

@ -1,11 +1,10 @@
package ai.chat2db.spi.ssh;
import java.util.concurrent.ConcurrentHashMap;
import ai.chat2db.server.tools.common.exception.ConnectionException;
import ai.chat2db.spi.model.SSHInfo;
import cn.hutool.core.net.NetUtil;
import com.jcraft.jsch.JSch;
import cn.hutool.extra.ssh.JschUtil;
import com.jcraft.jsch.Session;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
@ -17,72 +16,35 @@ import org.apache.commons.lang3.StringUtils;
@Slf4j
public class SSHManager {
private static final ConcurrentHashMap<SSHInfo, Session> SSH_SESSION_MAP = new ConcurrentHashMap();
public static Session getSSHSession(SSHInfo ssh) {
Session session;
try {
byte[] passphrase = StringUtils.isNotBlank(ssh.getPassphrase()) ? StringUtils.getBytes(ssh.getPassphrase(),
"UTF-8") : null;
session = JschUtil.getSession(ssh.getHostName(), Integer.parseInt(ssh.getPort()), ssh.getUserName(),
ssh.getKeyFile(), passphrase);
public static Session getSSHSession(SSHInfo sshInfo) {
Session session = SSH_SESSION_MAP.get(sshInfo);
if (session != null && session.isConnected()) {
return session;
} else {
return createSession(sshInfo);
} catch (Exception e) {
throw new ConnectionException("connection.ssh.error", null, e);
}
}
private static Session createSession(SSHInfo ssh) {
synchronized (ssh) {
Session session = SSH_SESSION_MAP.get(ssh);
if (session != null && session.isConnected()) {
return session;
}
if (StringUtils.isNotBlank(ssh.getRHost()) && StringUtils.isNotBlank(ssh.getRPort())) {
try {
JSch jSch = new JSch();
if (StringUtils.isNotBlank(ssh.getKeyFile()) && StringUtils.isNotBlank(ssh.getPassphrase())) {
jSch.addIdentity(ssh.getKeyFile(), ssh.getPassphrase());
}
session = jSch.getSession(ssh.getUserName(), ssh.getHostName(), Integer.parseInt(ssh.getPort()));
if (StringUtils.isBlank(ssh.getKeyFile()) || StringUtils.isBlank(ssh.getPassphrase())) {
session.setPassword(ssh.getPassword());
}
session.setConfig("StrictHostKeyChecking", "no");
session.connect();
SSH_SESSION_MAP.put(ssh, session);
int localPort = !StringUtils.isBlank(ssh.getLocalPort()) ? Integer.parseInt(ssh.getLocalPort())
: NetUtil.getUsableLocalPort();
ssh.setLocalPort(String.valueOf(localPort));
session.setPortForwardingL(localPort, ssh.getRHost(),
Integer.parseInt(ssh.getRPort()));
} catch (Exception e) {
throw new RuntimeException("create ssh session error", e);
}
if (StringUtils.isNotBlank(ssh.getRHost()) && StringUtils.isNotBlank(ssh.getRPort())) {
try {
int localPort = !StringUtils.isBlank(ssh.getLocalPort()) ? Integer.parseInt(ssh.getLocalPort())
: NetUtil.getUsableLocalPort();
session.setPortForwardingL(localPort, ssh.getRHost(),
Integer.parseInt(ssh.getRPort()));
} catch (Exception e) {
if (session != null && session.isConnected()) {
session.disconnect();
SSH_SESSION_MAP.remove(ssh);
}
throw new RuntimeException(ssh.getLocalPort() + " port is usedplease change to another port ", e);
if (session != null && session.isConnected()) {
session.disconnect();
}
throw new ConnectionException("connection.ssh.error", null, e);
}
return session;
}
return session;
}
public static void close() {
SSH_SESSION_MAP.forEach((k, v) -> {
if (v != null && v.isConnected()) {
try {
v.delPortForwardingL(Integer.parseInt(k.getLocalPort()));
} catch (Exception e) {
log.error("delPortForwardingL error", e);
}
try {
v.disconnect();
} catch (Exception e) {
log.error("disconnect error", e);
}
}
});
JschUtil.closeAll();
}
}

View File

@ -21,9 +21,9 @@ import ai.chat2db.spi.model.DataSourceConnect;
import ai.chat2db.spi.model.SSHInfo;
import ai.chat2db.spi.sql.IDriverManager;
import ai.chat2db.spi.ssh.SSHManager;
import com.jcraft.jsch.JSchException;
import com.jcraft.jsch.Session;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
/**
* jdbc工具类
@ -235,10 +235,12 @@ public class JdbcUtils {
}
if (session != null) {
try {
session.delPortForwardingL(Integer.parseInt(ssh.getLocalPort()));
if(StringUtils.isNotBlank(ssh.getLocalPort())) {
session.delPortForwardingL(Integer.parseInt(ssh.getLocalPort()));
}
session.disconnect();
} catch (Exception e) {
} catch (JSchException e) {
}
}
}