mirror of
https://github.com/CodePhiliaX/Chat2DB.git
synced 2025-09-23 13:37:10 +08:00
Support for custom drivers
This commit is contained in:
@ -0,0 +1,180 @@
|
||||
/**
|
||||
* alibaba.com Inc.
|
||||
* Copyright (c) 2004-2022 All Rights Reserved.
|
||||
*/
|
||||
package ai.chat2db.spi.sql;
|
||||
|
||||
import java.sql.Connection;
|
||||
import java.sql.SQLException;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.ServiceLoader;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
import ai.chat2db.spi.DBManage;
|
||||
import ai.chat2db.spi.MetaData;
|
||||
import ai.chat2db.spi.Plugin;
|
||||
import ai.chat2db.spi.config.DBConfig;
|
||||
import ai.chat2db.spi.config.DriverConfig;
|
||||
import ai.chat2db.spi.model.SSHInfo;
|
||||
|
||||
import com.jcraft.jsch.JSchException;
|
||||
import com.jcraft.jsch.Session;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
/**
|
||||
* @author jipengfei
|
||||
* @version : Chat2DBContext.java
|
||||
*/
|
||||
@Slf4j
|
||||
public class Chat2DBContext {
|
||||
|
||||
private static final ThreadLocal<ConnectInfo> CONNECT_INFO_THREAD_LOCAL = new ThreadLocal<>();
|
||||
|
||||
public static List<String> JDBC_JAR_DOWNLOAD_URL_LIST;
|
||||
|
||||
public static Map<String, Plugin> PLUGIN_MAP = new ConcurrentHashMap<>();
|
||||
|
||||
static {
|
||||
ServiceLoader<Plugin> s = ServiceLoader.load(Plugin.class);
|
||||
Iterator<Plugin> iterator = s.iterator();
|
||||
while (iterator.hasNext()) {
|
||||
Plugin plugin = iterator.next();
|
||||
PLUGIN_MAP.put(plugin.getDBConfig().getDbType(), plugin);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取当前线程的ContentContext
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
public static ConnectInfo getConnectInfo() {
|
||||
return CONNECT_INFO_THREAD_LOCAL.get();
|
||||
}
|
||||
|
||||
public static MetaData getMetaData() {
|
||||
return PLUGIN_MAP.get(getConnectInfo().getDbType()).getMetaData();
|
||||
}
|
||||
|
||||
public static DBConfig getDBConfig(){
|
||||
return PLUGIN_MAP.get(getConnectInfo().getDbType()).getDBConfig();
|
||||
}
|
||||
|
||||
public static DBManage getDBManage() {
|
||||
return PLUGIN_MAP.get(getConnectInfo().getDbType()).getDBManage();
|
||||
}
|
||||
|
||||
public static Connection getConnection() {
|
||||
return getConnectInfo().getConnection();
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置context
|
||||
*
|
||||
* @param info
|
||||
*/
|
||||
public static void putContext(ConnectInfo info) {
|
||||
ConnectInfo connectInfo = CONNECT_INFO_THREAD_LOCAL.get();
|
||||
CONNECT_INFO_THREAD_LOCAL.set(info);
|
||||
if (connectInfo == null) {
|
||||
setConnectInfoThreadLocal(info);
|
||||
if (StringUtils.isNotBlank(info.getDatabaseName())) {
|
||||
PLUGIN_MAP.get(getConnectInfo().getDbType()).getDBManage().connectDatabase(info.getDatabaseName());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static void setConnectInfoThreadLocal(ConnectInfo connectInfo) {
|
||||
Session session = null;
|
||||
Connection connection = null;
|
||||
SSHInfo ssh = connectInfo.getSsh();
|
||||
String url = connectInfo.getUrl();
|
||||
String host = connectInfo.getHost();
|
||||
String port = connectInfo.getPort() + "";
|
||||
try {
|
||||
session = getSession(ssh);
|
||||
if (session != null) {
|
||||
url = url.replace(host, "127.0.0.1").replace(port, ssh.getLocalPort());
|
||||
}
|
||||
connection = getConnect(url, host, port, connectInfo.getUser(),
|
||||
connectInfo.getPassword(), connectInfo.getDbType(),
|
||||
connectInfo.getDriverConfig(), ssh, connectInfo.getExtendMap());
|
||||
} catch (Exception e1) {
|
||||
log.error("getConnect error", e1);
|
||||
if (connection != null) {
|
||||
try {
|
||||
connection.close();
|
||||
} catch (SQLException e) {
|
||||
log.error("session close error", e);
|
||||
}
|
||||
}
|
||||
if (session != null) {
|
||||
try {
|
||||
session.delPortForwardingL(Integer.parseInt(ssh.getLocalPort()));
|
||||
session.disconnect();
|
||||
} catch (JSchException e) {
|
||||
log.error("session close error", e);
|
||||
}
|
||||
}
|
||||
throw new RuntimeException("getConnect error", e1);
|
||||
}
|
||||
connectInfo.setSession(session);
|
||||
connectInfo.setConnection(connection);
|
||||
}
|
||||
|
||||
/**
|
||||
* 测试数据库连接
|
||||
*
|
||||
* @param url 数据库连接
|
||||
* @param userName 用户名
|
||||
* @param password 密码
|
||||
* @param dbType 数据库类型
|
||||
* @return
|
||||
*/
|
||||
private static Connection getConnect(String url, String host, String port,
|
||||
String userName, String password, String dbType,
|
||||
DriverConfig jdbc, SSHInfo ssh, Map<String, Object> properties) throws SQLException {
|
||||
// 创建连接
|
||||
return IDriverManager.getConnection(url, userName, password, jdbc, properties);
|
||||
|
||||
}
|
||||
|
||||
private static Session getSession(SSHInfo ssh) {
|
||||
Session session = null;
|
||||
if (ssh != null && ssh.isUse()) {
|
||||
session = SSHManager.getSSHSession(ssh);
|
||||
}
|
||||
return session;
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置context
|
||||
*/
|
||||
public static void removeContext() {
|
||||
ConnectInfo connectInfo = CONNECT_INFO_THREAD_LOCAL.get();
|
||||
if (connectInfo != null) {
|
||||
Connection connection = connectInfo.getConnection();
|
||||
try {
|
||||
if (connection != null && !connection.isClosed()) {
|
||||
connection.close();
|
||||
}
|
||||
} catch (SQLException e) {
|
||||
log.error("close connection error", e);
|
||||
}
|
||||
Session session = connectInfo.getSession();
|
||||
if (session != null) {
|
||||
try {
|
||||
session.delPortForwardingL(Integer.parseInt(connectInfo.getSsh().getLocalPort()));
|
||||
session.disconnect();
|
||||
} catch (JSchException e) {
|
||||
log.error("close session error", e);
|
||||
}
|
||||
}
|
||||
CONNECT_INFO_THREAD_LOCAL.remove();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,492 @@
|
||||
/**
|
||||
* alibaba.com Inc.
|
||||
* Copyright (c) 2004-2023 All Rights Reserved.
|
||||
*/
|
||||
package ai.chat2db.spi.sql;
|
||||
|
||||
import java.sql.Connection;
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
import ai.chat2db.spi.config.DriverConfig;
|
||||
import ai.chat2db.spi.model.KeyValue;
|
||||
import ai.chat2db.spi.model.SSHInfo;
|
||||
import ai.chat2db.spi.model.SSLInfo;
|
||||
import com.jcraft.jsch.Session;
|
||||
import org.springframework.util.ObjectUtils;
|
||||
|
||||
/**
|
||||
* @author jipengfei
|
||||
* @version : ConnectInfo.java
|
||||
*/
|
||||
public class ConnectInfo {
|
||||
/**
|
||||
* 别名
|
||||
*/
|
||||
private String alias;
|
||||
/**
|
||||
* 数据连接ID
|
||||
*/
|
||||
private Long dataSourceId;
|
||||
|
||||
|
||||
/**
|
||||
* 创建时间
|
||||
*/
|
||||
private LocalDateTime gmtCreate;
|
||||
|
||||
/**
|
||||
* 修改时间
|
||||
*/
|
||||
private LocalDateTime gmtModified;
|
||||
/**
|
||||
* database
|
||||
*/
|
||||
private String databaseName;
|
||||
|
||||
/**
|
||||
* 控制台ID
|
||||
*/
|
||||
private Long consoleId;
|
||||
|
||||
/**
|
||||
* 数据库URL
|
||||
*/
|
||||
private String url;
|
||||
|
||||
/**
|
||||
* 用户名
|
||||
*/
|
||||
private String user;
|
||||
|
||||
/**
|
||||
* 密码
|
||||
*/
|
||||
private String password;
|
||||
|
||||
/**
|
||||
* console独立占有连接
|
||||
*/
|
||||
private Boolean consoleOwn = Boolean.FALSE;
|
||||
|
||||
/**
|
||||
* 数据库类型
|
||||
*/
|
||||
private String dbType;
|
||||
|
||||
private Integer port;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private String urlWithOutDatabase;
|
||||
|
||||
/**
|
||||
* host
|
||||
*/
|
||||
private String host;
|
||||
|
||||
/**
|
||||
* ssh
|
||||
*/
|
||||
private SSHInfo ssh;
|
||||
|
||||
/**
|
||||
* ssh
|
||||
*/
|
||||
private SSLInfo ssl;
|
||||
|
||||
/**
|
||||
* sid
|
||||
*/
|
||||
private String sid;
|
||||
|
||||
/**
|
||||
* driver
|
||||
*/
|
||||
private String driver;
|
||||
|
||||
/**
|
||||
* jdbc版本
|
||||
*/
|
||||
private String jdbc;
|
||||
|
||||
/**
|
||||
* 扩展信息
|
||||
*/
|
||||
private List<KeyValue> extendInfo;
|
||||
|
||||
|
||||
|
||||
public Connection connection;
|
||||
|
||||
|
||||
|
||||
|
||||
private DriverConfig driverConfig;
|
||||
|
||||
|
||||
public DriverConfig getDriverConfig() {
|
||||
return driverConfig;
|
||||
}
|
||||
|
||||
|
||||
public void setDriverConfig(DriverConfig driverConfig) {
|
||||
this.driverConfig = driverConfig;
|
||||
}
|
||||
|
||||
public Session getSession() {
|
||||
return session;
|
||||
}
|
||||
|
||||
public void setSession(Session session) {
|
||||
this.session = session;
|
||||
}
|
||||
|
||||
public Session session;
|
||||
|
||||
|
||||
public LinkedHashMap<String,Object> getExtendMap() {
|
||||
if (ObjectUtils.isEmpty(extendInfo)) {
|
||||
return new LinkedHashMap<>();
|
||||
}
|
||||
LinkedHashMap<String,Object> map = new LinkedHashMap<>();
|
||||
for (KeyValue keyValue : extendInfo) {
|
||||
map.put(keyValue.getKey(),keyValue.getValue());
|
||||
}
|
||||
return map;
|
||||
}
|
||||
|
||||
|
||||
public void setDatabase(String database) {
|
||||
this.databaseName = database;
|
||||
}
|
||||
|
||||
public String key() {
|
||||
return this.dataSourceId + "_" + this.databaseName;
|
||||
}
|
||||
|
||||
public void setUrl(String url) {
|
||||
this.url = url;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) {return true;}
|
||||
if (!(o instanceof ConnectInfo)) {return false;}
|
||||
ConnectInfo that = (ConnectInfo)o;
|
||||
return Objects.equals(dataSourceId, that.dataSourceId)
|
||||
&& Objects.equals(gmtModified, that.gmtModified)
|
||||
;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(dataSourceId, consoleId, databaseName);
|
||||
}
|
||||
|
||||
public Long getDataSourceId() {
|
||||
return dataSourceId;
|
||||
}
|
||||
|
||||
public void setDataSourceId(Long dataSourceId) {
|
||||
this.dataSourceId = dataSourceId;
|
||||
}
|
||||
|
||||
public String getDatabaseName() {
|
||||
return databaseName;
|
||||
}
|
||||
|
||||
public void setDatabaseName(String databaseName) {
|
||||
this.databaseName = databaseName;
|
||||
}
|
||||
|
||||
public Long getConsoleId() {
|
||||
return consoleId;
|
||||
}
|
||||
|
||||
public void setConsoleId(Long consoleId) {
|
||||
this.consoleId = consoleId;
|
||||
}
|
||||
|
||||
public String getUrl() {
|
||||
return url;
|
||||
}
|
||||
|
||||
public String getUser() {
|
||||
return user;
|
||||
}
|
||||
|
||||
public void setUser(String user) {
|
||||
this.user = user;
|
||||
}
|
||||
|
||||
public String getPassword() {
|
||||
return password;
|
||||
}
|
||||
|
||||
/**
|
||||
* Setter method for property <tt>password</tt>.
|
||||
*
|
||||
* @param password value to be assigned to property password
|
||||
*/
|
||||
public void setPassword(String password) {
|
||||
this.password = password;
|
||||
}
|
||||
|
||||
/**
|
||||
* Getter method for property <tt>consoleOwn</tt>.
|
||||
*
|
||||
* @return property value of consoleOwn
|
||||
*/
|
||||
public Boolean getConsoleOwn() {
|
||||
return consoleOwn;
|
||||
}
|
||||
|
||||
/**
|
||||
* Setter method for property <tt>consoleOwn</tt>.
|
||||
*
|
||||
* @param consoleOwn value to be assigned to property consoleOwn
|
||||
*/
|
||||
public void setConsoleOwn(Boolean consoleOwn) {
|
||||
this.consoleOwn = consoleOwn;
|
||||
}
|
||||
|
||||
/**
|
||||
* Getter method for property <tt>dbType</tt>.
|
||||
*
|
||||
* @return property value of dbType
|
||||
*/
|
||||
public String getDbType() {
|
||||
return dbType;
|
||||
}
|
||||
|
||||
/**
|
||||
* Setter method for property <tt>dbType</tt>.
|
||||
*
|
||||
* @param dbType value to be assigned to property dbType
|
||||
*/
|
||||
public void setDbType(String dbType) {
|
||||
this.dbType = dbType;
|
||||
}
|
||||
|
||||
/**
|
||||
* Getter method for property <tt>port</tt>.
|
||||
*
|
||||
* @return property value of port
|
||||
*/
|
||||
public Integer getPort() {
|
||||
return port;
|
||||
}
|
||||
|
||||
/**
|
||||
* Setter method for property <tt>port</tt>.
|
||||
*
|
||||
* @param port value to be assigned to property port
|
||||
*/
|
||||
public void setPort(Integer port) {
|
||||
this.port = port;
|
||||
}
|
||||
|
||||
/**
|
||||
* Getter method for property <tt>urlWithOutDatabase</tt>.
|
||||
*
|
||||
* @return property value of urlWithOutDatabase
|
||||
*/
|
||||
public String getUrlWithOutDatabase() {
|
||||
return urlWithOutDatabase;
|
||||
}
|
||||
|
||||
/**
|
||||
* Setter method for property <tt>urlWithOutDatabase</tt>.
|
||||
*
|
||||
* @param urlWithOutDatabase value to be assigned to property urlWithOutDatabase
|
||||
*/
|
||||
public void setUrlWithOutDatabase(String urlWithOutDatabase) {
|
||||
this.urlWithOutDatabase = urlWithOutDatabase;
|
||||
}
|
||||
|
||||
/**
|
||||
* Getter method for property <tt>host</tt>.
|
||||
*
|
||||
* @return property value of host
|
||||
*/
|
||||
public String getHost() {
|
||||
return host;
|
||||
}
|
||||
|
||||
/**
|
||||
* Setter method for property <tt>host</tt>.
|
||||
*
|
||||
* @param host value to be assigned to property host
|
||||
*/
|
||||
public void setHost(String host) {
|
||||
this.host = host;
|
||||
}
|
||||
|
||||
/**
|
||||
* Getter method for property <tt>ssh</tt>.
|
||||
*
|
||||
* @return property value of ssh
|
||||
*/
|
||||
public SSHInfo getSsh() {
|
||||
return ssh;
|
||||
}
|
||||
|
||||
/**
|
||||
* Setter method for property <tt>ssh</tt>.
|
||||
*
|
||||
* @param ssh value to be assigned to property ssh
|
||||
*/
|
||||
public void setSsh(SSHInfo ssh) {
|
||||
this.ssh = ssh;
|
||||
}
|
||||
|
||||
/**
|
||||
* Getter method for property <tt>ssl</tt>.
|
||||
*
|
||||
* @return property value of ssl
|
||||
*/
|
||||
public SSLInfo getSsl() {
|
||||
return ssl;
|
||||
}
|
||||
|
||||
/**
|
||||
* Setter method for property <tt>ssl</tt>.
|
||||
*
|
||||
* @param ssl value to be assigned to property ssl
|
||||
*/
|
||||
public void setSsl(SSLInfo ssl) {
|
||||
this.ssl = ssl;
|
||||
}
|
||||
|
||||
/**
|
||||
* Getter method for property <tt>sid</tt>.
|
||||
*
|
||||
* @return property value of sid
|
||||
*/
|
||||
public String getSid() {
|
||||
return sid;
|
||||
}
|
||||
|
||||
/**
|
||||
* Setter method for property <tt>sid</tt>.
|
||||
*
|
||||
* @param sid value to be assigned to property sid
|
||||
*/
|
||||
public void setSid(String sid) {
|
||||
this.sid = sid;
|
||||
}
|
||||
|
||||
/**
|
||||
* Getter method for property <tt>driver</tt>.
|
||||
*
|
||||
* @return property value of driver
|
||||
*/
|
||||
public String getDriver() {
|
||||
return driver;
|
||||
}
|
||||
|
||||
/**
|
||||
* Setter method for property <tt>driver</tt>.
|
||||
*
|
||||
* @param driver value to be assigned to property driver
|
||||
*/
|
||||
public void setDriver(String driver) {
|
||||
this.driver = driver;
|
||||
}
|
||||
|
||||
/**
|
||||
* Getter method for property <tt>jdbc</tt>.
|
||||
*
|
||||
* @return property value of jdbc
|
||||
*/
|
||||
public String getJdbc() {
|
||||
return jdbc;
|
||||
}
|
||||
|
||||
/**
|
||||
* Setter method for property <tt>jdbc</tt>.
|
||||
*
|
||||
* @param jdbc value to be assigned to property jdbc
|
||||
*/
|
||||
public void setJdbc(String jdbc) {
|
||||
this.jdbc = jdbc;
|
||||
}
|
||||
|
||||
/**
|
||||
* Getter method for property <tt>extendInfo</tt>.
|
||||
*
|
||||
* @return property value of extendInfo
|
||||
*/
|
||||
public List<KeyValue> getExtendInfo() {
|
||||
return extendInfo;
|
||||
}
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* Setter method for property <tt>extendInfo</tt>.
|
||||
*
|
||||
* @param extendInfo value to be assigned to property extendInfo
|
||||
*/
|
||||
public void setExtendInfo(List<KeyValue> extendInfo) {
|
||||
this.extendInfo = extendInfo;
|
||||
}
|
||||
|
||||
/**
|
||||
* Getter method for property <tt>connection</tt>.
|
||||
*
|
||||
* @return property value of connection
|
||||
*/
|
||||
public Connection getConnection() {
|
||||
return connection;
|
||||
}
|
||||
|
||||
/**
|
||||
* Setter method for property <tt>connection</tt>.
|
||||
*
|
||||
* @param connection value to be assigned to property connection
|
||||
*/
|
||||
public void setConnection(Connection connection) {
|
||||
this.connection = connection;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Getter method for property <tt>alias</tt>.
|
||||
*
|
||||
* @return property value of alias
|
||||
*/
|
||||
public String getAlias() {
|
||||
return alias;
|
||||
}
|
||||
|
||||
/**
|
||||
* Setter method for property <tt>alias</tt>.
|
||||
*
|
||||
* @param alias value to be assigned to property alias
|
||||
*/
|
||||
public void setAlias(String alias) {
|
||||
this.alias = alias;
|
||||
}
|
||||
|
||||
public LocalDateTime getGmtCreate() {
|
||||
return gmtCreate;
|
||||
}
|
||||
|
||||
public void setGmtCreate(LocalDateTime gmtCreate) {
|
||||
this.gmtCreate = gmtCreate;
|
||||
}
|
||||
|
||||
public LocalDateTime getGmtModified() {
|
||||
return gmtModified;
|
||||
}
|
||||
|
||||
public void setGmtModified(LocalDateTime gmtModified) {
|
||||
this.gmtModified = gmtModified;
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,200 @@
|
||||
/**
|
||||
* alibaba.com Inc.
|
||||
* Copyright (c) 2004-2023 All Rights Reserved.
|
||||
*/
|
||||
package ai.chat2db.spi.sql;
|
||||
|
||||
import java.io.File;
|
||||
import java.net.MalformedURLException;
|
||||
import java.net.URL;
|
||||
import java.net.URLClassLoader;
|
||||
import java.sql.Connection;
|
||||
import java.sql.Driver;
|
||||
import java.sql.DriverManager;
|
||||
import java.sql.SQLException;
|
||||
import java.util.Map;
|
||||
import java.util.Properties;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
import com.alibaba.fastjson2.JSON;
|
||||
|
||||
import ai.chat2db.spi.config.DriverConfig;
|
||||
import ai.chat2db.spi.model.DriverEntry;
|
||||
import ai.chat2db.spi.util.JdbcJarUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import static ai.chat2db.spi.util.JdbcJarUtils.getFullPath;
|
||||
|
||||
/**
|
||||
* @author jipengfei
|
||||
* @version : IsolationDriverManager.java
|
||||
*/
|
||||
public class IDriverManager {
|
||||
private static final Logger log = LoggerFactory.getLogger(IDriverManager.class);
|
||||
private static final Map<String, ClassLoader> CLASS_LOADER_MAP = new ConcurrentHashMap();
|
||||
private static final Map<String, DriverEntry> DRIVER_ENTRY_MAP = new ConcurrentHashMap();
|
||||
|
||||
public static Connection getConnection(String url, DriverConfig driver) throws SQLException {
|
||||
Properties info = new Properties();
|
||||
return getConnection(url, info, driver);
|
||||
}
|
||||
|
||||
public static Connection getConnection(String url, String user, String password, DriverConfig driver)
|
||||
throws SQLException {
|
||||
Properties info = new Properties();
|
||||
if (user != null) {
|
||||
info.put("user", user);
|
||||
}
|
||||
|
||||
if (password != null) {
|
||||
info.put("password", password);
|
||||
}
|
||||
|
||||
return getConnection(url, info, driver);
|
||||
}
|
||||
|
||||
public static Connection getConnection(String url, String user, String password, DriverConfig driver,
|
||||
Map<String, Object> properties)
|
||||
throws SQLException {
|
||||
Properties info = new Properties();
|
||||
if (StringUtils.isNotEmpty(user)) {
|
||||
info.put("user", user);
|
||||
}
|
||||
|
||||
if (StringUtils.isNotEmpty(password)) {
|
||||
info.put("password", password);
|
||||
}
|
||||
info.putAll(properties);
|
||||
return getConnection(url, info, driver);
|
||||
}
|
||||
|
||||
public static Connection getConnection(String url, Properties info, DriverConfig driver)
|
||||
throws SQLException {
|
||||
if (url == null) {
|
||||
throw new SQLException("The url cannot be null", "08001");
|
||||
}
|
||||
DriverManager.println("DriverManager.getConnection(\"" + url + "\")");
|
||||
SQLException reason = null;
|
||||
DriverEntry driverEntry = DRIVER_ENTRY_MAP.get(driver.getName());
|
||||
if (driverEntry == null) {
|
||||
driverEntry = getJDBCDriver(driver);
|
||||
}
|
||||
try {
|
||||
Connection con = driverEntry.getDriver().connect(url, info);
|
||||
if (con != null) {
|
||||
DriverManager.println("getConnection returning " + driverEntry.getDriver().getClass().getName());
|
||||
return con;
|
||||
}
|
||||
} catch (SQLException var7) {
|
||||
Connection con = tryConnectionAgain(driverEntry, url, info);
|
||||
if (con != null) {
|
||||
return con;
|
||||
} else {
|
||||
throw var7;
|
||||
}
|
||||
}
|
||||
|
||||
if (reason != null) {
|
||||
DriverManager.println("getConnection failed: " + reason);
|
||||
throw reason;
|
||||
} else {
|
||||
DriverManager.println("getConnection: no suitable driver found for " + url);
|
||||
throw new SQLException("No suitable driver found for " + url, "08001");
|
||||
}
|
||||
}
|
||||
|
||||
private static Connection tryConnectionAgain(DriverEntry driverEntry, String url,
|
||||
Properties info) throws SQLException {
|
||||
if (url.contains("mysql")) {
|
||||
if (!info.containsKey("useSSL")) {
|
||||
info.put("useSSL", "false");
|
||||
}
|
||||
return driverEntry.getDriver().connect(url, info);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private static DriverEntry getJDBCDriver(DriverConfig driver)
|
||||
throws SQLException {
|
||||
synchronized (driver) {
|
||||
try {
|
||||
if (DRIVER_ENTRY_MAP.containsKey(driver.getName())) {
|
||||
return DRIVER_ENTRY_MAP.get(driver.getName());
|
||||
}
|
||||
ClassLoader cl = getClassLoader(driver);
|
||||
Driver d = (Driver)cl.loadClass(driver.getJdbcDriverClass()).newInstance();
|
||||
DriverEntry driverEntry = DriverEntry.builder().driverConfig(driver).driver(d).build();
|
||||
DRIVER_ENTRY_MAP.put(driver.getName(), driverEntry);
|
||||
return driverEntry;
|
||||
} catch (Exception e) {
|
||||
log.error("getJDBCDriver error", e);
|
||||
throw new SQLException("getJDBCDriver error", "08001");
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public static ClassLoader getClassLoader(DriverConfig driverConfig) throws MalformedURLException {
|
||||
String jarPath = driverConfig.getJdbcDriver();
|
||||
if (CLASS_LOADER_MAP.containsKey(jarPath)) {
|
||||
return CLASS_LOADER_MAP.get(jarPath);
|
||||
} else {
|
||||
synchronized (jarPath) {
|
||||
if (CLASS_LOADER_MAP.containsKey(jarPath)) {
|
||||
return CLASS_LOADER_MAP.get(jarPath);
|
||||
}
|
||||
String[] jarPaths = jarPath.split(",");
|
||||
URL[] urls = new URL[jarPaths.length];
|
||||
for (int i = 0; i < jarPaths.length; i++) {
|
||||
File driverFile = new File(getFullPath(jarPaths[i]));
|
||||
urls[i] = driverFile.toURI().toURL();
|
||||
}
|
||||
//urls[jarPaths.length] = new File(JdbcJarUtils.getFullPath("HikariCP-4.0.3.jar")).toURI().toURL();
|
||||
|
||||
URLClassLoader cl = new URLClassLoader(urls, ClassLoader.getSystemClassLoader());
|
||||
log.info("ClassLoader class:{}", cl.hashCode());
|
||||
log.info("ClassLoader URLs:{}", JSON.toJSONString(cl.getURLs()));
|
||||
|
||||
try {
|
||||
cl.loadClass(driverConfig.getJdbcDriverClass());
|
||||
} catch (Exception e) {
|
||||
//如果报错删除目录重试一次
|
||||
for (int i = 0; i < jarPaths.length; i++) {
|
||||
File driverFile = new File(JdbcJarUtils.getNewFullPath(jarPaths[i]));
|
||||
urls[i] = driverFile.toURI().toURL();
|
||||
}
|
||||
//urls[jarPaths.length] = new File(JdbcJarUtils.getFullPath("HikariCP-4.0.3.jar")).toURI().toURL();
|
||||
cl = new URLClassLoader(urls, ClassLoader.getSystemClassLoader());
|
||||
|
||||
}
|
||||
CLASS_LOADER_MAP.put(jarPath, cl);
|
||||
return cl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//private static List<Class> loadClass(String jarPath, ClassLoader classLoader) throws IOException {
|
||||
// Long s1 = System.currentTimeMillis();
|
||||
// JarFile jarFile = new JarFile(getFullPath(jarPath));
|
||||
// Enumeration<JarEntry> entries = jarFile.entries();
|
||||
// List<Class> classes = new ArrayList();
|
||||
// while (entries.hasMoreElements()) {
|
||||
// JarEntry jarEntry = entries.nextElement();
|
||||
// if (jarEntry.getName().endsWith(".class") && !jarEntry.getName().contains("$")) {
|
||||
// String className = jarEntry.getName().substring(0, jarEntry.getName().length() - 6).replaceAll("/",
|
||||
// ".");
|
||||
// try {
|
||||
// classes.add(classLoader.loadClass(className));
|
||||
// // log.info("loadClass:{}", className);
|
||||
// } catch (Throwable var7) {
|
||||
// //log.error("getClasses error "+className, var7);
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// log.info("loadClass cost:{}", System.currentTimeMillis() - s1);
|
||||
// return classes;
|
||||
//}
|
||||
|
||||
}
|
@ -0,0 +1,356 @@
|
||||
/**
|
||||
* alibaba.com Inc.
|
||||
* Copyright (c) 2004-2022 All Rights Reserved.
|
||||
*/
|
||||
package ai.chat2db.spi.sql;
|
||||
|
||||
import java.sql.Connection;
|
||||
import java.sql.ResultSet;
|
||||
import java.sql.ResultSetMetaData;
|
||||
import java.sql.SQLException;
|
||||
import java.sql.Statement;
|
||||
import java.util.List;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import ai.chat2db.spi.model.*;
|
||||
|
||||
import cn.hutool.core.date.TimeInterval;
|
||||
import com.google.common.collect.Lists;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.jdbc.support.JdbcUtils;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
import static ai.chat2db.spi.util.ResultSetUtils.buildColumn;
|
||||
import static ai.chat2db.spi.util.ResultSetUtils.buildFunction;
|
||||
import static ai.chat2db.spi.util.ResultSetUtils.buildProcedure;
|
||||
import static ai.chat2db.spi.util.ResultSetUtils.buildTable;
|
||||
import static ai.chat2db.spi.util.ResultSetUtils.buildTableIndexColumn;
|
||||
|
||||
/**
|
||||
* Dbhub 统一数据库连接管理
|
||||
* TODO 长时间不用连接可以关闭,待优化
|
||||
*
|
||||
* @author jipengfei
|
||||
* @version : DbhubDataSource.java
|
||||
*/
|
||||
@Slf4j
|
||||
public class SQLExecutor {
|
||||
/**
|
||||
* 全局单例
|
||||
*/
|
||||
private static final SQLExecutor INSTANCE = new SQLExecutor();
|
||||
|
||||
private SQLExecutor() {
|
||||
}
|
||||
|
||||
public static SQLExecutor getInstance() {
|
||||
return INSTANCE;
|
||||
}
|
||||
|
||||
public Connection getConnection() throws SQLException {
|
||||
return Chat2DBContext.getConnection();
|
||||
}
|
||||
|
||||
public void close() {
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行sql
|
||||
*
|
||||
* @param sql
|
||||
* @param function
|
||||
* @return
|
||||
*/
|
||||
|
||||
public <R> R executeSql(String sql, Function<ResultSet, R> function) {
|
||||
if (StringUtils.isEmpty(sql)) {
|
||||
return null;
|
||||
}
|
||||
log.info("execute:{}", sql);
|
||||
Statement stmt = null;
|
||||
try {
|
||||
stmt = getConnection().createStatement();
|
||||
boolean query = stmt.execute(sql);
|
||||
// 代表是查询
|
||||
if (query) {
|
||||
ResultSet rs = null;
|
||||
try {
|
||||
rs = stmt.getResultSet();
|
||||
return function.apply(rs);
|
||||
} finally {
|
||||
if (rs != null) {
|
||||
rs.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (SQLException e) {
|
||||
throw new RuntimeException(e);
|
||||
} finally {
|
||||
if (stmt != null) {
|
||||
try {
|
||||
stmt.close();
|
||||
} catch (SQLException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行sql
|
||||
*
|
||||
* @param sql
|
||||
* @return
|
||||
* @throws SQLException
|
||||
*/
|
||||
public ExecuteResult execute(final String sql, Connection connection) throws SQLException {
|
||||
Assert.notNull(sql, "SQL must not be null");
|
||||
log.info("execute:{}", sql);
|
||||
|
||||
ExecuteResult executeResult = ExecuteResult.builder().sql(sql).success(Boolean.TRUE).build();
|
||||
Statement stmt = null;
|
||||
try {
|
||||
TimeInterval timeInterval=new TimeInterval();
|
||||
stmt = connection.createStatement();
|
||||
boolean query = stmt.execute(sql.replaceFirst(";", ""));
|
||||
executeResult.setDescription("执行成功");
|
||||
// 代表是查询
|
||||
if (query) {
|
||||
ResultSet rs = null;
|
||||
try {
|
||||
rs = stmt.getResultSet();
|
||||
// 获取有几列
|
||||
ResultSetMetaData resultSetMetaData = rs.getMetaData();
|
||||
int col = resultSetMetaData.getColumnCount();
|
||||
|
||||
// 获取header信息
|
||||
List<Header> headerList = Lists.newArrayListWithExpectedSize(col);
|
||||
executeResult.setHeaderList(headerList);
|
||||
for (int i = 1; i <= col; i++) {
|
||||
headerList.add(Header.builder()
|
||||
.dataType(ai.chat2db.spi.util.JdbcUtils.resolveDataType(
|
||||
resultSetMetaData.getColumnTypeName(i), resultSetMetaData.getColumnType(i)).getCode())
|
||||
.name(resultSetMetaData.getColumnName(i))
|
||||
.build());
|
||||
}
|
||||
|
||||
// 获取数据信息
|
||||
List<List<String>> dataList = Lists.newArrayList();
|
||||
executeResult.setDataList(dataList);
|
||||
|
||||
while (rs.next()) {
|
||||
List<String> row = Lists.newArrayListWithExpectedSize(col);
|
||||
dataList.add(row);
|
||||
for (int i = 1; i <= col; i++) {
|
||||
row.add(ai.chat2db.spi.util.JdbcUtils.getResultSetValue(rs, i));
|
||||
}
|
||||
}
|
||||
executeResult.setDuration(timeInterval.interval());
|
||||
return executeResult;
|
||||
} finally {
|
||||
JdbcUtils.closeResultSet(rs);
|
||||
}
|
||||
} else {
|
||||
// 修改或者其他
|
||||
executeResult.setUpdateCount(stmt.getUpdateCount());
|
||||
}
|
||||
} finally {
|
||||
JdbcUtils.closeStatement(stmt);
|
||||
}
|
||||
return executeResult;
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行sql
|
||||
*
|
||||
* @param sql
|
||||
* @return
|
||||
* @throws SQLException
|
||||
*/
|
||||
public ExecuteResult execute(String sql) throws SQLException {
|
||||
return execute(sql, getConnection());
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 获取所有的数据库
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
public List<String> databases() {
|
||||
List<String> tables = Lists.newArrayList();
|
||||
try {
|
||||
ResultSet resultSet = getConnection().getMetaData().getCatalogs();
|
||||
if (resultSet != null) {
|
||||
while (resultSet.next()) {
|
||||
tables.add(resultSet.getString("TABLE_CAT"));
|
||||
}
|
||||
}
|
||||
} catch (SQLException e) {
|
||||
close();
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
return tables;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有的schema
|
||||
*
|
||||
* @param databaseName
|
||||
* @param schemaName
|
||||
* @return
|
||||
*/
|
||||
public List<String> schemas(String databaseName, String schemaName) {
|
||||
List<String> schemaList = Lists.newArrayList();
|
||||
try {
|
||||
ResultSet resultSet = getConnection().getMetaData().getSchemas(databaseName, schemaName);
|
||||
if (resultSet != null) {
|
||||
while (resultSet.next()) {
|
||||
schemaList.add(resultSet.getString("TABLE_SCHEM"));
|
||||
}
|
||||
}
|
||||
} catch (SQLException e) {
|
||||
close();
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
return schemaList;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有的数据库表
|
||||
*
|
||||
* @param databaseName
|
||||
* @param schemaName
|
||||
* @param tableName
|
||||
* @param types
|
||||
* @return
|
||||
*/
|
||||
public List<Table> tables(String databaseName, String schemaName, String tableName, String types[]) {
|
||||
List<Table> tables = Lists.newArrayList();
|
||||
try {
|
||||
ResultSet resultSet = getConnection().getMetaData().getTables(databaseName, schemaName, tableName,
|
||||
types);
|
||||
if (resultSet != null) {
|
||||
while (resultSet.next()) {
|
||||
tables.add(buildTable(resultSet));
|
||||
}
|
||||
}
|
||||
} catch (SQLException e) {
|
||||
close();
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
return tables;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有的数据库表列
|
||||
*
|
||||
* @param databaseName
|
||||
* @param schemaName
|
||||
* @param tableName
|
||||
* @param columnName
|
||||
* @return
|
||||
*/
|
||||
public List<TableColumn> columns(String databaseName, String schemaName, String tableName, String columnName) {
|
||||
List<TableColumn> tableColumns = Lists.newArrayList();
|
||||
try {
|
||||
ResultSet resultSet = getConnection().getMetaData().getColumns(databaseName, schemaName, tableName,
|
||||
columnName);
|
||||
if (resultSet != null) {
|
||||
while (resultSet.next()) {
|
||||
tableColumns.add(buildColumn(resultSet));
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
close();
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
return tableColumns;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有的数据库表索引
|
||||
*
|
||||
* @param databaseName
|
||||
* @param schemaName
|
||||
* @param tableName
|
||||
* @return
|
||||
*/
|
||||
public List<TableIndex> indexes(String databaseName, String schemaName, String tableName) {
|
||||
List<TableIndex> tableIndices = Lists.newArrayList();
|
||||
try {
|
||||
List<TableIndexColumn> tableIndexColumns = Lists.newArrayList();
|
||||
ResultSet resultSet = getConnection().getMetaData().getIndexInfo(databaseName, schemaName, tableName, false,
|
||||
false);
|
||||
|
||||
while (resultSet != null && resultSet.next()) {
|
||||
tableIndexColumns.add(buildTableIndexColumn(resultSet));
|
||||
}
|
||||
|
||||
tableIndexColumns.stream().filter(c -> c.getIndexName() != null).collect(
|
||||
Collectors.groupingBy(TableIndexColumn::getIndexName)).entrySet()
|
||||
.stream().forEach(entry -> {
|
||||
TableIndex tableIndex = new TableIndex();
|
||||
TableIndexColumn column = entry.getValue().get(0);
|
||||
tableIndex.setName(entry.getKey());
|
||||
tableIndex.setTableName(column.getTableName());
|
||||
tableIndex.setSchemaName(column.getSchemaName());
|
||||
tableIndex.setDatabaseName(column.getDatabaseName());
|
||||
tableIndex.setUnique(!column.getNonUnique());
|
||||
tableIndex.setColumnList(entry.getValue());
|
||||
tableIndices.add(tableIndex);
|
||||
});
|
||||
} catch (SQLException e) {
|
||||
close();
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
return tableIndices;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有的函数
|
||||
*
|
||||
* @param databaseName
|
||||
* @param schemaName
|
||||
* @return
|
||||
*/
|
||||
public List<ai.chat2db.spi.model.Function> functions(String databaseName,
|
||||
String schemaName) {
|
||||
List<ai.chat2db.spi.model.Function> functions = Lists.newArrayList();
|
||||
try {
|
||||
ResultSet resultSet = getConnection().getMetaData().getFunctions(databaseName, schemaName, null);
|
||||
while (resultSet != null && resultSet.next()) {
|
||||
functions.add(buildFunction(resultSet));
|
||||
}
|
||||
} catch (Exception e) {
|
||||
close();
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
return functions;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有的存储过程
|
||||
*
|
||||
* @param databaseName
|
||||
* @param schemaName
|
||||
* @return
|
||||
*/
|
||||
public List<Procedure> procedures(String databaseName, String schemaName) {
|
||||
List<Procedure> procedures = Lists.newArrayList();
|
||||
try {
|
||||
ResultSet resultSet = getConnection().getMetaData().getProcedures(databaseName, schemaName, null);
|
||||
while (resultSet != null && resultSet.next()) {
|
||||
procedures.add(buildProcedure(resultSet));
|
||||
}
|
||||
} catch (Exception e) {
|
||||
close();
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
return procedures;
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,83 @@
|
||||
/**
|
||||
* alibaba.com Inc.
|
||||
* Copyright (c) 2004-2023 All Rights Reserved.
|
||||
*/
|
||||
package ai.chat2db.spi.sql;
|
||||
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
import ai.chat2db.spi.model.SSHInfo;
|
||||
import com.jcraft.jsch.JSch;
|
||||
import com.jcraft.jsch.Session;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
/**
|
||||
* @author jipengfei
|
||||
* @version : SSHSessionManager.java
|
||||
*/
|
||||
@Slf4j
|
||||
public class SSHManager {
|
||||
|
||||
private static final ConcurrentHashMap<SSHInfo, Session> SSH_SESSION_MAP = new ConcurrentHashMap();
|
||||
|
||||
public static Session getSSHSession(SSHInfo sshInfo) {
|
||||
Session session = SSH_SESSION_MAP.get(sshInfo);
|
||||
if (session != null && session.isConnected()) {
|
||||
return session;
|
||||
} else {
|
||||
return createSession(sshInfo);
|
||||
}
|
||||
}
|
||||
|
||||
private static Session createSession(SSHInfo ssh) {
|
||||
synchronized (ssh) {
|
||||
Session session = SSH_SESSION_MAP.get(ssh);
|
||||
if (session != null && session.isConnected()) {
|
||||
return session;
|
||||
}
|
||||
try {
|
||||
JSch jSch = new JSch();
|
||||
session = jSch.getSession(ssh.getUserName(), ssh.getHostName(), Integer.parseInt(ssh.getPort()));
|
||||
session.setPassword(ssh.getPassword());
|
||||
session.setConfig("StrictHostKeyChecking", "no");
|
||||
session.connect();
|
||||
SSH_SESSION_MAP.put(ssh, session);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("create ssh session error", e);
|
||||
}
|
||||
|
||||
if (StringUtils.isNotBlank(ssh.getLocalPort()) && StringUtils.isNotBlank(ssh.getRHost())
|
||||
&& StringUtils.isNotBlank(ssh.getRPort())) {
|
||||
try {
|
||||
int port1 = session.setPortForwardingL(Integer.parseInt(ssh.getLocalPort()), ssh.getRHost(),
|
||||
Integer.parseInt(ssh.getRPort()));
|
||||
} catch (Exception e) {
|
||||
if (session != null && session.isConnected()) {
|
||||
session.disconnect();
|
||||
SSH_SESSION_MAP.remove(ssh);
|
||||
}
|
||||
throw new RuntimeException(ssh.getLocalPort() + " port is used,please change to another port ", e);
|
||||
}
|
||||
}
|
||||
return session;
|
||||
}
|
||||
}
|
||||
|
||||
public static void close() {
|
||||
SSH_SESSION_MAP.forEach((k, v) -> {
|
||||
if (v != null && v.isConnected()) {
|
||||
try {
|
||||
v.delPortForwardingL(Integer.parseInt(k.getLocalPort()));
|
||||
} catch (Exception e) {
|
||||
log.error("delPortForwardingL error", e);
|
||||
}
|
||||
try {
|
||||
v.disconnect();
|
||||
} catch (Exception e) {
|
||||
log.error("disconnect error", e);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user