Support for custom drivers

This commit is contained in:
jipengfei-jpf
2023-06-22 20:30:26 +08:00
parent e67d3dae54
commit a25fae6cfc
112 changed files with 4399 additions and 279 deletions

View File

@ -0,0 +1,180 @@
/**
* alibaba.com Inc.
* Copyright (c) 2004-2022 All Rights Reserved.
*/
package ai.chat2db.spi.sql;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.ServiceLoader;
import java.util.concurrent.ConcurrentHashMap;
import ai.chat2db.spi.DBManage;
import ai.chat2db.spi.MetaData;
import ai.chat2db.spi.Plugin;
import ai.chat2db.spi.config.DBConfig;
import ai.chat2db.spi.config.DriverConfig;
import ai.chat2db.spi.model.SSHInfo;
import com.jcraft.jsch.JSchException;
import com.jcraft.jsch.Session;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
/**
* @author jipengfei
* @version : Chat2DBContext.java
*/
@Slf4j
public class Chat2DBContext {
private static final ThreadLocal<ConnectInfo> CONNECT_INFO_THREAD_LOCAL = new ThreadLocal<>();
public static List<String> JDBC_JAR_DOWNLOAD_URL_LIST;
public static Map<String, Plugin> PLUGIN_MAP = new ConcurrentHashMap<>();
static {
ServiceLoader<Plugin> s = ServiceLoader.load(Plugin.class);
Iterator<Plugin> iterator = s.iterator();
while (iterator.hasNext()) {
Plugin plugin = iterator.next();
PLUGIN_MAP.put(plugin.getDBConfig().getDbType(), plugin);
}
}
/**
* 获取当前线程的ContentContext
*
* @return
*/
public static ConnectInfo getConnectInfo() {
return CONNECT_INFO_THREAD_LOCAL.get();
}
public static MetaData getMetaData() {
return PLUGIN_MAP.get(getConnectInfo().getDbType()).getMetaData();
}
public static DBConfig getDBConfig(){
return PLUGIN_MAP.get(getConnectInfo().getDbType()).getDBConfig();
}
public static DBManage getDBManage() {
return PLUGIN_MAP.get(getConnectInfo().getDbType()).getDBManage();
}
public static Connection getConnection() {
return getConnectInfo().getConnection();
}
/**
* 设置context
*
* @param info
*/
public static void putContext(ConnectInfo info) {
ConnectInfo connectInfo = CONNECT_INFO_THREAD_LOCAL.get();
CONNECT_INFO_THREAD_LOCAL.set(info);
if (connectInfo == null) {
setConnectInfoThreadLocal(info);
if (StringUtils.isNotBlank(info.getDatabaseName())) {
PLUGIN_MAP.get(getConnectInfo().getDbType()).getDBManage().connectDatabase(info.getDatabaseName());
}
}
}
private static void setConnectInfoThreadLocal(ConnectInfo connectInfo) {
Session session = null;
Connection connection = null;
SSHInfo ssh = connectInfo.getSsh();
String url = connectInfo.getUrl();
String host = connectInfo.getHost();
String port = connectInfo.getPort() + "";
try {
session = getSession(ssh);
if (session != null) {
url = url.replace(host, "127.0.0.1").replace(port, ssh.getLocalPort());
}
connection = getConnect(url, host, port, connectInfo.getUser(),
connectInfo.getPassword(), connectInfo.getDbType(),
connectInfo.getDriverConfig(), ssh, connectInfo.getExtendMap());
} catch (Exception e1) {
log.error("getConnect error", e1);
if (connection != null) {
try {
connection.close();
} catch (SQLException e) {
log.error("session close error", e);
}
}
if (session != null) {
try {
session.delPortForwardingL(Integer.parseInt(ssh.getLocalPort()));
session.disconnect();
} catch (JSchException e) {
log.error("session close error", e);
}
}
throw new RuntimeException("getConnect error", e1);
}
connectInfo.setSession(session);
connectInfo.setConnection(connection);
}
/**
* 测试数据库连接
*
* @param url 数据库连接
* @param userName 用户名
* @param password 密码
* @param dbType 数据库类型
* @return
*/
private static Connection getConnect(String url, String host, String port,
String userName, String password, String dbType,
DriverConfig jdbc, SSHInfo ssh, Map<String, Object> properties) throws SQLException {
// 创建连接
return IDriverManager.getConnection(url, userName, password, jdbc, properties);
}
private static Session getSession(SSHInfo ssh) {
Session session = null;
if (ssh != null && ssh.isUse()) {
session = SSHManager.getSSHSession(ssh);
}
return session;
}
/**
* 设置context
*/
public static void removeContext() {
ConnectInfo connectInfo = CONNECT_INFO_THREAD_LOCAL.get();
if (connectInfo != null) {
Connection connection = connectInfo.getConnection();
try {
if (connection != null && !connection.isClosed()) {
connection.close();
}
} catch (SQLException e) {
log.error("close connection error", e);
}
Session session = connectInfo.getSession();
if (session != null) {
try {
session.delPortForwardingL(Integer.parseInt(connectInfo.getSsh().getLocalPort()));
session.disconnect();
} catch (JSchException e) {
log.error("close session error", e);
}
}
CONNECT_INFO_THREAD_LOCAL.remove();
}
}
}

View File

@ -0,0 +1,492 @@
/**
* alibaba.com Inc.
* Copyright (c) 2004-2023 All Rights Reserved.
*/
package ai.chat2db.spi.sql;
import java.sql.Connection;
import java.time.LocalDateTime;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Objects;
import ai.chat2db.spi.config.DriverConfig;
import ai.chat2db.spi.model.KeyValue;
import ai.chat2db.spi.model.SSHInfo;
import ai.chat2db.spi.model.SSLInfo;
import com.jcraft.jsch.Session;
import org.springframework.util.ObjectUtils;
/**
* @author jipengfei
* @version : ConnectInfo.java
*/
public class ConnectInfo {
/**
* 别名
*/
private String alias;
/**
* 数据连接ID
*/
private Long dataSourceId;
/**
* 创建时间
*/
private LocalDateTime gmtCreate;
/**
* 修改时间
*/
private LocalDateTime gmtModified;
/**
* database
*/
private String databaseName;
/**
* 控制台ID
*/
private Long consoleId;
/**
* 数据库URL
*/
private String url;
/**
* 用户名
*/
private String user;
/**
* 密码
*/
private String password;
/**
* console独立占有连接
*/
private Boolean consoleOwn = Boolean.FALSE;
/**
* 数据库类型
*/
private String dbType;
private Integer port;
/**
*
*/
private String urlWithOutDatabase;
/**
* host
*/
private String host;
/**
* ssh
*/
private SSHInfo ssh;
/**
* ssh
*/
private SSLInfo ssl;
/**
* sid
*/
private String sid;
/**
* driver
*/
private String driver;
/**
* jdbc版本
*/
private String jdbc;
/**
* 扩展信息
*/
private List<KeyValue> extendInfo;
public Connection connection;
private DriverConfig driverConfig;
public DriverConfig getDriverConfig() {
return driverConfig;
}
public void setDriverConfig(DriverConfig driverConfig) {
this.driverConfig = driverConfig;
}
public Session getSession() {
return session;
}
public void setSession(Session session) {
this.session = session;
}
public Session session;
public LinkedHashMap<String,Object> getExtendMap() {
if (ObjectUtils.isEmpty(extendInfo)) {
return new LinkedHashMap<>();
}
LinkedHashMap<String,Object> map = new LinkedHashMap<>();
for (KeyValue keyValue : extendInfo) {
map.put(keyValue.getKey(),keyValue.getValue());
}
return map;
}
public void setDatabase(String database) {
this.databaseName = database;
}
public String key() {
return this.dataSourceId + "_" + this.databaseName;
}
public void setUrl(String url) {
this.url = url;
}
@Override
public boolean equals(Object o) {
if (this == o) {return true;}
if (!(o instanceof ConnectInfo)) {return false;}
ConnectInfo that = (ConnectInfo)o;
return Objects.equals(dataSourceId, that.dataSourceId)
&& Objects.equals(gmtModified, that.gmtModified)
;
}
@Override
public int hashCode() {
return Objects.hash(dataSourceId, consoleId, databaseName);
}
public Long getDataSourceId() {
return dataSourceId;
}
public void setDataSourceId(Long dataSourceId) {
this.dataSourceId = dataSourceId;
}
public String getDatabaseName() {
return databaseName;
}
public void setDatabaseName(String databaseName) {
this.databaseName = databaseName;
}
public Long getConsoleId() {
return consoleId;
}
public void setConsoleId(Long consoleId) {
this.consoleId = consoleId;
}
public String getUrl() {
return url;
}
public String getUser() {
return user;
}
public void setUser(String user) {
this.user = user;
}
public String getPassword() {
return password;
}
/**
* Setter method for property <tt>password</tt>.
*
* @param password value to be assigned to property password
*/
public void setPassword(String password) {
this.password = password;
}
/**
* Getter method for property <tt>consoleOwn</tt>.
*
* @return property value of consoleOwn
*/
public Boolean getConsoleOwn() {
return consoleOwn;
}
/**
* Setter method for property <tt>consoleOwn</tt>.
*
* @param consoleOwn value to be assigned to property consoleOwn
*/
public void setConsoleOwn(Boolean consoleOwn) {
this.consoleOwn = consoleOwn;
}
/**
* Getter method for property <tt>dbType</tt>.
*
* @return property value of dbType
*/
public String getDbType() {
return dbType;
}
/**
* Setter method for property <tt>dbType</tt>.
*
* @param dbType value to be assigned to property dbType
*/
public void setDbType(String dbType) {
this.dbType = dbType;
}
/**
* Getter method for property <tt>port</tt>.
*
* @return property value of port
*/
public Integer getPort() {
return port;
}
/**
* Setter method for property <tt>port</tt>.
*
* @param port value to be assigned to property port
*/
public void setPort(Integer port) {
this.port = port;
}
/**
* Getter method for property <tt>urlWithOutDatabase</tt>.
*
* @return property value of urlWithOutDatabase
*/
public String getUrlWithOutDatabase() {
return urlWithOutDatabase;
}
/**
* Setter method for property <tt>urlWithOutDatabase</tt>.
*
* @param urlWithOutDatabase value to be assigned to property urlWithOutDatabase
*/
public void setUrlWithOutDatabase(String urlWithOutDatabase) {
this.urlWithOutDatabase = urlWithOutDatabase;
}
/**
* Getter method for property <tt>host</tt>.
*
* @return property value of host
*/
public String getHost() {
return host;
}
/**
* Setter method for property <tt>host</tt>.
*
* @param host value to be assigned to property host
*/
public void setHost(String host) {
this.host = host;
}
/**
* Getter method for property <tt>ssh</tt>.
*
* @return property value of ssh
*/
public SSHInfo getSsh() {
return ssh;
}
/**
* Setter method for property <tt>ssh</tt>.
*
* @param ssh value to be assigned to property ssh
*/
public void setSsh(SSHInfo ssh) {
this.ssh = ssh;
}
/**
* Getter method for property <tt>ssl</tt>.
*
* @return property value of ssl
*/
public SSLInfo getSsl() {
return ssl;
}
/**
* Setter method for property <tt>ssl</tt>.
*
* @param ssl value to be assigned to property ssl
*/
public void setSsl(SSLInfo ssl) {
this.ssl = ssl;
}
/**
* Getter method for property <tt>sid</tt>.
*
* @return property value of sid
*/
public String getSid() {
return sid;
}
/**
* Setter method for property <tt>sid</tt>.
*
* @param sid value to be assigned to property sid
*/
public void setSid(String sid) {
this.sid = sid;
}
/**
* Getter method for property <tt>driver</tt>.
*
* @return property value of driver
*/
public String getDriver() {
return driver;
}
/**
* Setter method for property <tt>driver</tt>.
*
* @param driver value to be assigned to property driver
*/
public void setDriver(String driver) {
this.driver = driver;
}
/**
* Getter method for property <tt>jdbc</tt>.
*
* @return property value of jdbc
*/
public String getJdbc() {
return jdbc;
}
/**
* Setter method for property <tt>jdbc</tt>.
*
* @param jdbc value to be assigned to property jdbc
*/
public void setJdbc(String jdbc) {
this.jdbc = jdbc;
}
/**
* Getter method for property <tt>extendInfo</tt>.
*
* @return property value of extendInfo
*/
public List<KeyValue> getExtendInfo() {
return extendInfo;
}
/**
* Setter method for property <tt>extendInfo</tt>.
*
* @param extendInfo value to be assigned to property extendInfo
*/
public void setExtendInfo(List<KeyValue> extendInfo) {
this.extendInfo = extendInfo;
}
/**
* Getter method for property <tt>connection</tt>.
*
* @return property value of connection
*/
public Connection getConnection() {
return connection;
}
/**
* Setter method for property <tt>connection</tt>.
*
* @param connection value to be assigned to property connection
*/
public void setConnection(Connection connection) {
this.connection = connection;
}
/**
* Getter method for property <tt>alias</tt>.
*
* @return property value of alias
*/
public String getAlias() {
return alias;
}
/**
* Setter method for property <tt>alias</tt>.
*
* @param alias value to be assigned to property alias
*/
public void setAlias(String alias) {
this.alias = alias;
}
public LocalDateTime getGmtCreate() {
return gmtCreate;
}
public void setGmtCreate(LocalDateTime gmtCreate) {
this.gmtCreate = gmtCreate;
}
public LocalDateTime getGmtModified() {
return gmtModified;
}
public void setGmtModified(LocalDateTime gmtModified) {
this.gmtModified = gmtModified;
}
}

View File

@ -0,0 +1,200 @@
/**
* alibaba.com Inc.
* Copyright (c) 2004-2023 All Rights Reserved.
*/
package ai.chat2db.spi.sql;
import java.io.File;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLClassLoader;
import java.sql.Connection;
import java.sql.Driver;
import java.sql.DriverManager;
import java.sql.SQLException;
import java.util.Map;
import java.util.Properties;
import java.util.concurrent.ConcurrentHashMap;
import com.alibaba.fastjson2.JSON;
import ai.chat2db.spi.config.DriverConfig;
import ai.chat2db.spi.model.DriverEntry;
import ai.chat2db.spi.util.JdbcJarUtils;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import static ai.chat2db.spi.util.JdbcJarUtils.getFullPath;
/**
* @author jipengfei
* @version : IsolationDriverManager.java
*/
public class IDriverManager {
private static final Logger log = LoggerFactory.getLogger(IDriverManager.class);
private static final Map<String, ClassLoader> CLASS_LOADER_MAP = new ConcurrentHashMap();
private static final Map<String, DriverEntry> DRIVER_ENTRY_MAP = new ConcurrentHashMap();
public static Connection getConnection(String url, DriverConfig driver) throws SQLException {
Properties info = new Properties();
return getConnection(url, info, driver);
}
public static Connection getConnection(String url, String user, String password, DriverConfig driver)
throws SQLException {
Properties info = new Properties();
if (user != null) {
info.put("user", user);
}
if (password != null) {
info.put("password", password);
}
return getConnection(url, info, driver);
}
public static Connection getConnection(String url, String user, String password, DriverConfig driver,
Map<String, Object> properties)
throws SQLException {
Properties info = new Properties();
if (StringUtils.isNotEmpty(user)) {
info.put("user", user);
}
if (StringUtils.isNotEmpty(password)) {
info.put("password", password);
}
info.putAll(properties);
return getConnection(url, info, driver);
}
public static Connection getConnection(String url, Properties info, DriverConfig driver)
throws SQLException {
if (url == null) {
throw new SQLException("The url cannot be null", "08001");
}
DriverManager.println("DriverManager.getConnection(\"" + url + "\")");
SQLException reason = null;
DriverEntry driverEntry = DRIVER_ENTRY_MAP.get(driver.getName());
if (driverEntry == null) {
driverEntry = getJDBCDriver(driver);
}
try {
Connection con = driverEntry.getDriver().connect(url, info);
if (con != null) {
DriverManager.println("getConnection returning " + driverEntry.getDriver().getClass().getName());
return con;
}
} catch (SQLException var7) {
Connection con = tryConnectionAgain(driverEntry, url, info);
if (con != null) {
return con;
} else {
throw var7;
}
}
if (reason != null) {
DriverManager.println("getConnection failed: " + reason);
throw reason;
} else {
DriverManager.println("getConnection: no suitable driver found for " + url);
throw new SQLException("No suitable driver found for " + url, "08001");
}
}
private static Connection tryConnectionAgain(DriverEntry driverEntry, String url,
Properties info) throws SQLException {
if (url.contains("mysql")) {
if (!info.containsKey("useSSL")) {
info.put("useSSL", "false");
}
return driverEntry.getDriver().connect(url, info);
}
return null;
}
private static DriverEntry getJDBCDriver(DriverConfig driver)
throws SQLException {
synchronized (driver) {
try {
if (DRIVER_ENTRY_MAP.containsKey(driver.getName())) {
return DRIVER_ENTRY_MAP.get(driver.getName());
}
ClassLoader cl = getClassLoader(driver);
Driver d = (Driver)cl.loadClass(driver.getJdbcDriverClass()).newInstance();
DriverEntry driverEntry = DriverEntry.builder().driverConfig(driver).driver(d).build();
DRIVER_ENTRY_MAP.put(driver.getName(), driverEntry);
return driverEntry;
} catch (Exception e) {
log.error("getJDBCDriver error", e);
throw new SQLException("getJDBCDriver error", "08001");
}
}
}
public static ClassLoader getClassLoader(DriverConfig driverConfig) throws MalformedURLException {
String jarPath = driverConfig.getJdbcDriver();
if (CLASS_LOADER_MAP.containsKey(jarPath)) {
return CLASS_LOADER_MAP.get(jarPath);
} else {
synchronized (jarPath) {
if (CLASS_LOADER_MAP.containsKey(jarPath)) {
return CLASS_LOADER_MAP.get(jarPath);
}
String[] jarPaths = jarPath.split(",");
URL[] urls = new URL[jarPaths.length];
for (int i = 0; i < jarPaths.length; i++) {
File driverFile = new File(getFullPath(jarPaths[i]));
urls[i] = driverFile.toURI().toURL();
}
//urls[jarPaths.length] = new File(JdbcJarUtils.getFullPath("HikariCP-4.0.3.jar")).toURI().toURL();
URLClassLoader cl = new URLClassLoader(urls, ClassLoader.getSystemClassLoader());
log.info("ClassLoader class:{}", cl.hashCode());
log.info("ClassLoader URLs:{}", JSON.toJSONString(cl.getURLs()));
try {
cl.loadClass(driverConfig.getJdbcDriverClass());
} catch (Exception e) {
//如果报错删除目录重试一次
for (int i = 0; i < jarPaths.length; i++) {
File driverFile = new File(JdbcJarUtils.getNewFullPath(jarPaths[i]));
urls[i] = driverFile.toURI().toURL();
}
//urls[jarPaths.length] = new File(JdbcJarUtils.getFullPath("HikariCP-4.0.3.jar")).toURI().toURL();
cl = new URLClassLoader(urls, ClassLoader.getSystemClassLoader());
}
CLASS_LOADER_MAP.put(jarPath, cl);
return cl;
}
}
}
//private static List<Class> loadClass(String jarPath, ClassLoader classLoader) throws IOException {
// Long s1 = System.currentTimeMillis();
// JarFile jarFile = new JarFile(getFullPath(jarPath));
// Enumeration<JarEntry> entries = jarFile.entries();
// List<Class> classes = new ArrayList();
// while (entries.hasMoreElements()) {
// JarEntry jarEntry = entries.nextElement();
// if (jarEntry.getName().endsWith(".class") && !jarEntry.getName().contains("$")) {
// String className = jarEntry.getName().substring(0, jarEntry.getName().length() - 6).replaceAll("/",
// ".");
// try {
// classes.add(classLoader.loadClass(className));
// // log.info("loadClass:{}", className);
// } catch (Throwable var7) {
// //log.error("getClasses error "+className, var7);
// }
// }
// }
// log.info("loadClass cost:{}", System.currentTimeMillis() - s1);
// return classes;
//}
}

View File

@ -0,0 +1,356 @@
/**
* alibaba.com Inc.
* Copyright (c) 2004-2022 All Rights Reserved.
*/
package ai.chat2db.spi.sql;
import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;
import ai.chat2db.spi.model.*;
import cn.hutool.core.date.TimeInterval;
import com.google.common.collect.Lists;
import lombok.extern.slf4j.Slf4j;
import org.springframework.jdbc.support.JdbcUtils;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import static ai.chat2db.spi.util.ResultSetUtils.buildColumn;
import static ai.chat2db.spi.util.ResultSetUtils.buildFunction;
import static ai.chat2db.spi.util.ResultSetUtils.buildProcedure;
import static ai.chat2db.spi.util.ResultSetUtils.buildTable;
import static ai.chat2db.spi.util.ResultSetUtils.buildTableIndexColumn;
/**
* Dbhub 统一数据库连接管理
* TODO 长时间不用连接可以关闭,待优化
*
* @author jipengfei
* @version : DbhubDataSource.java
*/
@Slf4j
public class SQLExecutor {
/**
* 全局单例
*/
private static final SQLExecutor INSTANCE = new SQLExecutor();
private SQLExecutor() {
}
public static SQLExecutor getInstance() {
return INSTANCE;
}
public Connection getConnection() throws SQLException {
return Chat2DBContext.getConnection();
}
public void close() {
}
/**
* 执行sql
*
* @param sql
* @param function
* @return
*/
public <R> R executeSql(String sql, Function<ResultSet, R> function) {
if (StringUtils.isEmpty(sql)) {
return null;
}
log.info("execute:{}", sql);
Statement stmt = null;
try {
stmt = getConnection().createStatement();
boolean query = stmt.execute(sql);
// 代表是查询
if (query) {
ResultSet rs = null;
try {
rs = stmt.getResultSet();
return function.apply(rs);
} finally {
if (rs != null) {
rs.close();
}
}
}
} catch (SQLException e) {
throw new RuntimeException(e);
} finally {
if (stmt != null) {
try {
stmt.close();
} catch (SQLException e) {
throw new RuntimeException(e);
}
}
}
return null;
}
/**
* 执行sql
*
* @param sql
* @return
* @throws SQLException
*/
public ExecuteResult execute(final String sql, Connection connection) throws SQLException {
Assert.notNull(sql, "SQL must not be null");
log.info("execute:{}", sql);
ExecuteResult executeResult = ExecuteResult.builder().sql(sql).success(Boolean.TRUE).build();
Statement stmt = null;
try {
TimeInterval timeInterval=new TimeInterval();
stmt = connection.createStatement();
boolean query = stmt.execute(sql.replaceFirst(";", ""));
executeResult.setDescription("执行成功");
// 代表是查询
if (query) {
ResultSet rs = null;
try {
rs = stmt.getResultSet();
// 获取有几列
ResultSetMetaData resultSetMetaData = rs.getMetaData();
int col = resultSetMetaData.getColumnCount();
// 获取header信息
List<Header> headerList = Lists.newArrayListWithExpectedSize(col);
executeResult.setHeaderList(headerList);
for (int i = 1; i <= col; i++) {
headerList.add(Header.builder()
.dataType(ai.chat2db.spi.util.JdbcUtils.resolveDataType(
resultSetMetaData.getColumnTypeName(i), resultSetMetaData.getColumnType(i)).getCode())
.name(resultSetMetaData.getColumnName(i))
.build());
}
// 获取数据信息
List<List<String>> dataList = Lists.newArrayList();
executeResult.setDataList(dataList);
while (rs.next()) {
List<String> row = Lists.newArrayListWithExpectedSize(col);
dataList.add(row);
for (int i = 1; i <= col; i++) {
row.add(ai.chat2db.spi.util.JdbcUtils.getResultSetValue(rs, i));
}
}
executeResult.setDuration(timeInterval.interval());
return executeResult;
} finally {
JdbcUtils.closeResultSet(rs);
}
} else {
// 修改或者其他
executeResult.setUpdateCount(stmt.getUpdateCount());
}
} finally {
JdbcUtils.closeStatement(stmt);
}
return executeResult;
}
/**
* 执行sql
*
* @param sql
* @return
* @throws SQLException
*/
public ExecuteResult execute(String sql) throws SQLException {
return execute(sql, getConnection());
}
/**
* 获取所有的数据库
*
* @return
*/
public List<String> databases() {
List<String> tables = Lists.newArrayList();
try {
ResultSet resultSet = getConnection().getMetaData().getCatalogs();
if (resultSet != null) {
while (resultSet.next()) {
tables.add(resultSet.getString("TABLE_CAT"));
}
}
} catch (SQLException e) {
close();
throw new RuntimeException(e);
}
return tables;
}
/**
* 获取所有的schema
*
* @param databaseName
* @param schemaName
* @return
*/
public List<String> schemas(String databaseName, String schemaName) {
List<String> schemaList = Lists.newArrayList();
try {
ResultSet resultSet = getConnection().getMetaData().getSchemas(databaseName, schemaName);
if (resultSet != null) {
while (resultSet.next()) {
schemaList.add(resultSet.getString("TABLE_SCHEM"));
}
}
} catch (SQLException e) {
close();
throw new RuntimeException(e);
}
return schemaList;
}
/**
* 获取所有的数据库表
*
* @param databaseName
* @param schemaName
* @param tableName
* @param types
* @return
*/
public List<Table> tables(String databaseName, String schemaName, String tableName, String types[]) {
List<Table> tables = Lists.newArrayList();
try {
ResultSet resultSet = getConnection().getMetaData().getTables(databaseName, schemaName, tableName,
types);
if (resultSet != null) {
while (resultSet.next()) {
tables.add(buildTable(resultSet));
}
}
} catch (SQLException e) {
close();
throw new RuntimeException(e);
}
return tables;
}
/**
* 获取所有的数据库表列
*
* @param databaseName
* @param schemaName
* @param tableName
* @param columnName
* @return
*/
public List<TableColumn> columns(String databaseName, String schemaName, String tableName, String columnName) {
List<TableColumn> tableColumns = Lists.newArrayList();
try {
ResultSet resultSet = getConnection().getMetaData().getColumns(databaseName, schemaName, tableName,
columnName);
if (resultSet != null) {
while (resultSet.next()) {
tableColumns.add(buildColumn(resultSet));
}
}
} catch (Exception e) {
close();
throw new RuntimeException(e);
}
return tableColumns;
}
/**
* 获取所有的数据库表索引
*
* @param databaseName
* @param schemaName
* @param tableName
* @return
*/
public List<TableIndex> indexes(String databaseName, String schemaName, String tableName) {
List<TableIndex> tableIndices = Lists.newArrayList();
try {
List<TableIndexColumn> tableIndexColumns = Lists.newArrayList();
ResultSet resultSet = getConnection().getMetaData().getIndexInfo(databaseName, schemaName, tableName, false,
false);
while (resultSet != null && resultSet.next()) {
tableIndexColumns.add(buildTableIndexColumn(resultSet));
}
tableIndexColumns.stream().filter(c -> c.getIndexName() != null).collect(
Collectors.groupingBy(TableIndexColumn::getIndexName)).entrySet()
.stream().forEach(entry -> {
TableIndex tableIndex = new TableIndex();
TableIndexColumn column = entry.getValue().get(0);
tableIndex.setName(entry.getKey());
tableIndex.setTableName(column.getTableName());
tableIndex.setSchemaName(column.getSchemaName());
tableIndex.setDatabaseName(column.getDatabaseName());
tableIndex.setUnique(!column.getNonUnique());
tableIndex.setColumnList(entry.getValue());
tableIndices.add(tableIndex);
});
} catch (SQLException e) {
close();
throw new RuntimeException(e);
}
return tableIndices;
}
/**
* 获取所有的函数
*
* @param databaseName
* @param schemaName
* @return
*/
public List<ai.chat2db.spi.model.Function> functions(String databaseName,
String schemaName) {
List<ai.chat2db.spi.model.Function> functions = Lists.newArrayList();
try {
ResultSet resultSet = getConnection().getMetaData().getFunctions(databaseName, schemaName, null);
while (resultSet != null && resultSet.next()) {
functions.add(buildFunction(resultSet));
}
} catch (Exception e) {
close();
throw new RuntimeException(e);
}
return functions;
}
/**
* 获取所有的存储过程
*
* @param databaseName
* @param schemaName
* @return
*/
public List<Procedure> procedures(String databaseName, String schemaName) {
List<Procedure> procedures = Lists.newArrayList();
try {
ResultSet resultSet = getConnection().getMetaData().getProcedures(databaseName, schemaName, null);
while (resultSet != null && resultSet.next()) {
procedures.add(buildProcedure(resultSet));
}
} catch (Exception e) {
close();
throw new RuntimeException(e);
}
return procedures;
}
}

View File

@ -0,0 +1,83 @@
/**
* alibaba.com Inc.
* Copyright (c) 2004-2023 All Rights Reserved.
*/
package ai.chat2db.spi.sql;
import java.util.concurrent.ConcurrentHashMap;
import ai.chat2db.spi.model.SSHInfo;
import com.jcraft.jsch.JSch;
import com.jcraft.jsch.Session;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
/**
* @author jipengfei
* @version : SSHSessionManager.java
*/
@Slf4j
public class SSHManager {
private static final ConcurrentHashMap<SSHInfo, Session> SSH_SESSION_MAP = new ConcurrentHashMap();
public static Session getSSHSession(SSHInfo sshInfo) {
Session session = SSH_SESSION_MAP.get(sshInfo);
if (session != null && session.isConnected()) {
return session;
} else {
return createSession(sshInfo);
}
}
private static Session createSession(SSHInfo ssh) {
synchronized (ssh) {
Session session = SSH_SESSION_MAP.get(ssh);
if (session != null && session.isConnected()) {
return session;
}
try {
JSch jSch = new JSch();
session = jSch.getSession(ssh.getUserName(), ssh.getHostName(), Integer.parseInt(ssh.getPort()));
session.setPassword(ssh.getPassword());
session.setConfig("StrictHostKeyChecking", "no");
session.connect();
SSH_SESSION_MAP.put(ssh, session);
} catch (Exception e) {
throw new RuntimeException("create ssh session error", e);
}
if (StringUtils.isNotBlank(ssh.getLocalPort()) && StringUtils.isNotBlank(ssh.getRHost())
&& StringUtils.isNotBlank(ssh.getRPort())) {
try {
int port1 = session.setPortForwardingL(Integer.parseInt(ssh.getLocalPort()), ssh.getRHost(),
Integer.parseInt(ssh.getRPort()));
} catch (Exception e) {
if (session != null && session.isConnected()) {
session.disconnect();
SSH_SESSION_MAP.remove(ssh);
}
throw new RuntimeException(ssh.getLocalPort() + " port is usedplease change to another port ", e);
}
}
return session;
}
}
public static void close() {
SSH_SESSION_MAP.forEach((k, v) -> {
if (v != null && v.isConnected()) {
try {
v.delPortForwardingL(Integer.parseInt(k.getLocalPort()));
} catch (Exception e) {
log.error("delPortForwardingL error", e);
}
try {
v.disconnect();
} catch (Exception e) {
log.error("disconnect error", e);
}
}
});
}
}