This commit is contained in:
jipengfei-jpf
2023-07-17 20:51:06 +08:00
parent de6fa3cb58
commit c81985e916
41 changed files with 294 additions and 795 deletions

View File

@ -138,7 +138,7 @@ public class Chat2DBContext {
connectInfo.setSession(session);
connectInfo.setConnection(connection);
if (StringUtils.isNotBlank(connectInfo.getDatabaseName())) {
PLUGIN_MAP.get(getConnectInfo().getDbType()).getDBManage().connectDatabase(
PLUGIN_MAP.get(getConnectInfo().getDbType()).getDBManage().connectDatabase(connection,
connectInfo.getDatabaseName());
}
return connection;
@ -181,4 +181,4 @@ public class Chat2DBContext {
}
}
}
}

View File

@ -48,27 +48,27 @@ public class SQLExecutor {
return INSTANCE;
}
public Connection getConnection() throws SQLException {
return Chat2DBContext.getConnection();
}
//public Connection connection throws SQLException {
// return Chat2DBContext.connection;
//}
public void close() {
}
/**
* 执行sql
*
* @param connection
* @param sql
* @param function
* @return
*/
public <R> R executeSql(String sql, Function<ResultSet, R> function) {
public <R> R executeSql(Connection connection,String sql, Function<ResultSet, R> function) {
if (StringUtils.isEmpty(sql)) {
return null;
}
log.info("execute:{}", sql);
try (Statement stmt = getConnection().createStatement();) {
try (Statement stmt = connection.createStatement();) {
boolean query = stmt.execute(sql);
// 代表是查询
if (query) {
@ -148,22 +148,23 @@ public class SQLExecutor {
/**
* 执行sql
*
* @param connection
* @param sql
* @return
* @throws SQLException
*/
public ExecuteResult execute(String sql) throws SQLException {
return execute(sql, getConnection());
public ExecuteResult execute(Connection connection,String sql) throws SQLException {
return execute(sql, connection);
}
/**
* 获取所有的数据库
*
* @param connection
* @return
*/
public List<String> databases() {
public List<String> databases(Connection connection) {
List<String> tables = Lists.newArrayList();
try (ResultSet resultSet = getConnection().getMetaData().getCatalogs();) {
try (ResultSet resultSet = connection.getMetaData().getCatalogs();) {
if (resultSet != null) {
while (resultSet.next()) {
tables.add(resultSet.getString("TABLE_CAT"));
@ -177,15 +178,15 @@ public class SQLExecutor {
/**
* 获取所有的schema
*
* @param connection
* @param databaseName
* @param schemaName
* @return
*/
public List<Map<String, String>> schemas(String databaseName, String schemaName) {
public List<Map<String, String>> schemas(Connection connection,String databaseName, String schemaName) {
List<Map<String, String>> schemaList = Lists.newArrayList();
if (StringUtils.isEmpty(databaseName) && StringUtils.isEmpty(schemaName)) {
try (ResultSet resultSet = getConnection().getMetaData().getSchemas()) {
try (ResultSet resultSet = connection.getMetaData().getSchemas()) {
if (resultSet != null) {
while (resultSet.next()) {
Map<String, String> map = new HashMap<>();
@ -199,7 +200,7 @@ public class SQLExecutor {
}
return schemaList;
}
try (ResultSet resultSet = getConnection().getMetaData().getSchemas(databaseName, schemaName)) {
try (ResultSet resultSet = connection.getMetaData().getSchemas(databaseName, schemaName)) {
if (resultSet != null) {
while (resultSet.next()) {
Map<String, String> map = new HashMap<>();
@ -216,17 +217,17 @@ public class SQLExecutor {
/**
* 获取所有的数据库表
*
* @param connection
* @param databaseName
* @param schemaName
* @param tableName
* @param types
* @return
*/
public List<Table> tables(String databaseName, String schemaName, String tableName, String types[]) {
public List<Table> tables(Connection connection,String databaseName, String schemaName, String tableName, String types[]) {
List<Table> tables = Lists.newArrayList();
int n = 0;
try (ResultSet resultSet = getConnection().getMetaData().getTables(databaseName, schemaName, tableName,
try (ResultSet resultSet = connection.getMetaData().getTables(databaseName, schemaName, tableName,
types)) {
if (resultSet != null) {
while (resultSet.next()) {
@ -245,16 +246,16 @@ public class SQLExecutor {
/**
* 获取所有的数据库表列
*
* @param connection
* @param databaseName
* @param schemaName
* @param tableName
* @param columnName
* @return
*/
public List<TableColumn> columns(String databaseName, String schemaName, String tableName, String columnName) {
public List<TableColumn> columns(Connection connection,String databaseName, String schemaName, String tableName, String columnName) {
List<TableColumn> tableColumns = Lists.newArrayList();
try (ResultSet resultSet = getConnection().getMetaData().getColumns(databaseName, schemaName, tableName,
try (ResultSet resultSet = connection.getMetaData().getColumns(databaseName, schemaName, tableName,
columnName)) {
if (resultSet != null) {
while (resultSet.next()) {
@ -269,15 +270,15 @@ public class SQLExecutor {
/**
* 获取所有的数据库表索引
*
* @param connection
* @param databaseName
* @param schemaName
* @param tableName
* @return
*/
public List<TableIndex> indexes(String databaseName, String schemaName, String tableName) {
public List<TableIndex> indexes(Connection connection,String databaseName, String schemaName, String tableName) {
List<TableIndex> tableIndices = Lists.newArrayList();
try (ResultSet resultSet = getConnection().getMetaData().getIndexInfo(databaseName, schemaName, tableName,
try (ResultSet resultSet = connection.getMetaData().getIndexInfo(databaseName, schemaName, tableName,
false,
false)) {
List<TableIndexColumn> tableIndexColumns = Lists.newArrayList();
@ -307,15 +308,15 @@ public class SQLExecutor {
/**
* 获取所有的函数
*
* @param connection
* @param databaseName
* @param schemaName
* @return
*/
public List<ai.chat2db.spi.model.Function> functions(String databaseName,
public List<ai.chat2db.spi.model.Function> functions(Connection connection,String databaseName,
String schemaName) {
List<ai.chat2db.spi.model.Function> functions = Lists.newArrayList();
try (ResultSet resultSet = getConnection().getMetaData().getFunctions(databaseName, schemaName, null);) {
try (ResultSet resultSet = connection.getMetaData().getFunctions(databaseName, schemaName, null);) {
while (resultSet != null && resultSet.next()) {
functions.add(buildFunction(resultSet));
}
@ -327,14 +328,14 @@ public class SQLExecutor {
/**
* 获取所有的存储过程
*
* @param connection
* @param databaseName
* @param schemaName
* @return
*/
public List<Procedure> procedures(String databaseName, String schemaName) {
public List<Procedure> procedures(Connection connection,String databaseName, String schemaName) {
List<Procedure> procedures = Lists.newArrayList();
try (ResultSet resultSet = getConnection().getMetaData().getProcedures(databaseName, schemaName, null)) {
try (ResultSet resultSet = connection.getMetaData().getProcedures(databaseName, schemaName, null)) {
while (resultSet != null && resultSet.next()) {
procedures.add(buildProcedure(resultSet));
}