fix exception

This commit is contained in:
jipengfei-jpf
2023-07-09 21:03:29 +08:00
parent 46e2d5279e
commit e533a31f39
5 changed files with 44 additions and 69 deletions

View File

@ -12,6 +12,7 @@ import ai.chat2db.server.domain.api.param.DataSourceSelector;
import ai.chat2db.server.domain.api.param.DataSourceUpdateParam; import ai.chat2db.server.domain.api.param.DataSourceUpdateParam;
import ai.chat2db.server.domain.api.service.ConsoleService; import ai.chat2db.server.domain.api.service.ConsoleService;
import ai.chat2db.server.domain.api.service.DataSourceService; import ai.chat2db.server.domain.api.service.DataSourceService;
import ai.chat2db.server.tools.common.exception.ConnectionException;
import ai.chat2db.spi.model.Database; import ai.chat2db.spi.model.Database;
import ai.chat2db.spi.ssh.SSHManager; import ai.chat2db.spi.ssh.SSHManager;
import ai.chat2db.server.tools.base.wrapper.result.ActionResult; import ai.chat2db.server.tools.base.wrapper.result.ActionResult;
@ -99,7 +100,7 @@ public class DataSourceController {
session = SSHManager.getSSHSession(sshWebConverter.toInfo(request)); session = SSHManager.getSSHSession(sshWebConverter.toInfo(request));
} catch (Exception e) { } catch (Exception e) {
log.error("sshConnect error", e); log.error("sshConnect error", e);
throw new RuntimeException(e); throw new ConnectionException("connection.ssh.error",null,e);
} finally { } finally {
if (session != null) { if (session != null) {
session.disconnect(); session.disconnect();

View File

@ -32,9 +32,9 @@
<artifactId>spring-jdbc</artifactId> <artifactId>spring-jdbc</artifactId>
</dependency> </dependency>
<dependency> <dependency>
<groupId>com.jcraft</groupId> <groupId>com.github.mwiede</groupId>
<artifactId>jsch</artifactId> <artifactId>jsch</artifactId>
<version>0.1.53</version> <version>0.2.9</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>com.oracle.ojdbc</groupId> <groupId>com.oracle.ojdbc</groupId>

View File

@ -103,8 +103,8 @@ public class Chat2DBContext {
if (session != null) { if (session != null) {
url = url.replace(host, "127.0.0.1").replace(port, ssh.getLocalPort()); url = url.replace(host, "127.0.0.1").replace(port, ssh.getLocalPort());
} }
}catch (Exception e){ } catch (Exception e) {
throw new ConnectionException("connection.ssh.error",null,e); throw new ConnectionException("connection.ssh.error", null, e);
} }
try { try {
DriverConfig config = connectInfo.getDriverConfig(); DriverConfig config = connectInfo.getDriverConfig();
@ -133,7 +133,7 @@ public class Chat2DBContext {
} catch (Exception e) { } catch (Exception e) {
} }
} }
throw new BusinessException("connection.error",null,e1); throw new BusinessException("connection.error", null, e1);
} }
connectInfo.setSession(session); connectInfo.setSession(session);
connectInfo.setConnection(connection); connectInfo.setConnection(connection);
@ -167,7 +167,17 @@ public class Chat2DBContext {
} catch (SQLException e) { } catch (SQLException e) {
log.error("close connection error", e); log.error("close connection error", e);
} }
CONNECT_INFO_THREAD_LOCAL.remove(); CONNECT_INFO_THREAD_LOCAL.remove();
Session session = connectInfo.getSession();
if (session != null && session.isConnected() && connectInfo.getSsh() != null
&& connectInfo.getSsh().isUse()) {
try {
session.delPortForwardingL(Integer.parseInt(connectInfo.getSsh().getLocalPort()));
} catch (JSchException e) {
}
}
} }
} }

View File

@ -1,11 +1,10 @@
package ai.chat2db.spi.ssh; package ai.chat2db.spi.ssh;
import java.util.concurrent.ConcurrentHashMap; import ai.chat2db.server.tools.common.exception.ConnectionException;
import ai.chat2db.spi.model.SSHInfo; import ai.chat2db.spi.model.SSHInfo;
import cn.hutool.core.net.NetUtil; import cn.hutool.core.net.NetUtil;
import com.jcraft.jsch.JSch; import cn.hutool.extra.ssh.JschUtil;
import com.jcraft.jsch.Session; import com.jcraft.jsch.Session;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
@ -17,72 +16,35 @@ import org.apache.commons.lang3.StringUtils;
@Slf4j @Slf4j
public class SSHManager { public class SSHManager {
private static final ConcurrentHashMap<SSHInfo, Session> SSH_SESSION_MAP = new ConcurrentHashMap(); public static Session getSSHSession(SSHInfo ssh) {
Session session;
try {
byte[] passphrase = StringUtils.isNotBlank(ssh.getPassphrase()) ? StringUtils.getBytes(ssh.getPassphrase(),
"UTF-8") : null;
session = JschUtil.getSession(ssh.getHostName(), Integer.parseInt(ssh.getPort()), ssh.getUserName(),
ssh.getKeyFile(), passphrase);
public static Session getSSHSession(SSHInfo sshInfo) { } catch (Exception e) {
Session session = SSH_SESSION_MAP.get(sshInfo); throw new ConnectionException("connection.ssh.error", null, e);
if (session != null && session.isConnected()) {
return session;
} else {
return createSession(sshInfo);
} }
} if (StringUtils.isNotBlank(ssh.getRHost()) && StringUtils.isNotBlank(ssh.getRPort())) {
private static Session createSession(SSHInfo ssh) {
synchronized (ssh) {
Session session = SSH_SESSION_MAP.get(ssh);
if (session != null && session.isConnected()) {
return session;
}
try { try {
JSch jSch = new JSch(); int localPort = !StringUtils.isBlank(ssh.getLocalPort()) ? Integer.parseInt(ssh.getLocalPort())
if (StringUtils.isNotBlank(ssh.getKeyFile()) && StringUtils.isNotBlank(ssh.getPassphrase())) { : NetUtil.getUsableLocalPort();
jSch.addIdentity(ssh.getKeyFile(), ssh.getPassphrase()); ssh.setLocalPort(String.valueOf(localPort));
} session.setPortForwardingL(localPort, ssh.getRHost(),
session = jSch.getSession(ssh.getUserName(), ssh.getHostName(), Integer.parseInt(ssh.getPort())); Integer.parseInt(ssh.getRPort()));
if (StringUtils.isBlank(ssh.getKeyFile()) || StringUtils.isBlank(ssh.getPassphrase())) {
session.setPassword(ssh.getPassword());
}
session.setConfig("StrictHostKeyChecking", "no");
session.connect();
SSH_SESSION_MAP.put(ssh, session);
} catch (Exception e) { } catch (Exception e) {
throw new RuntimeException("create ssh session error", e); if (session != null && session.isConnected()) {
} session.disconnect();
if (StringUtils.isNotBlank(ssh.getRHost()) && StringUtils.isNotBlank(ssh.getRPort())) {
try {
int localPort = !StringUtils.isBlank(ssh.getLocalPort()) ? Integer.parseInt(ssh.getLocalPort())
: NetUtil.getUsableLocalPort();
session.setPortForwardingL(localPort, ssh.getRHost(),
Integer.parseInt(ssh.getRPort()));
} catch (Exception e) {
if (session != null && session.isConnected()) {
session.disconnect();
SSH_SESSION_MAP.remove(ssh);
}
throw new RuntimeException(ssh.getLocalPort() + " port is usedplease change to another port ", e);
} }
throw new ConnectionException("connection.ssh.error", null, e);
} }
return session;
} }
return session;
} }
public static void close() { public static void close() {
SSH_SESSION_MAP.forEach((k, v) -> { JschUtil.closeAll();
if (v != null && v.isConnected()) {
try {
v.delPortForwardingL(Integer.parseInt(k.getLocalPort()));
} catch (Exception e) {
log.error("delPortForwardingL error", e);
}
try {
v.disconnect();
} catch (Exception e) {
log.error("disconnect error", e);
}
}
});
} }
} }

View File

@ -21,9 +21,9 @@ import ai.chat2db.spi.model.DataSourceConnect;
import ai.chat2db.spi.model.SSHInfo; import ai.chat2db.spi.model.SSHInfo;
import ai.chat2db.spi.sql.IDriverManager; import ai.chat2db.spi.sql.IDriverManager;
import ai.chat2db.spi.ssh.SSHManager; import ai.chat2db.spi.ssh.SSHManager;
import com.jcraft.jsch.JSchException;
import com.jcraft.jsch.Session; import com.jcraft.jsch.Session;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
/** /**
* jdbc工具类 * jdbc工具类
@ -235,10 +235,12 @@ public class JdbcUtils {
} }
if (session != null) { if (session != null) {
try { try {
session.delPortForwardingL(Integer.parseInt(ssh.getLocalPort())); if(StringUtils.isNotBlank(ssh.getLocalPort())) {
session.delPortForwardingL(Integer.parseInt(ssh.getLocalPort()));
}
session.disconnect(); session.disconnect();
} catch (Exception e) {
} catch (JSchException e) {
} }
} }
} }