Fix bug with invalid schema specification

This commit is contained in:
SwallowGG
2023-08-29 23:38:24 +08:00
parent 656cd3a338
commit 9b659016f6
15 changed files with 269 additions and 241 deletions

View File

@ -5,22 +5,14 @@ import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import ai.chat2db.server.tools.base.constant.EasyToolsConstant;
import ai.chat2db.server.tools.common.util.I18nUtils;
import ai.chat2db.spi.model.ExecuteResult;
import ai.chat2db.spi.model.Header;
import ai.chat2db.spi.model.Procedure;
import ai.chat2db.spi.model.Table;
import ai.chat2db.spi.model.TableColumn;
import ai.chat2db.spi.model.TableIndex;
import ai.chat2db.spi.model.TableIndexColumn;
import ai.chat2db.spi.model.*;
import ai.chat2db.spi.util.ResultSetUtils;
import cn.hutool.core.date.TimeInterval;
import com.google.common.collect.Lists;
@ -29,18 +21,10 @@ import org.apache.commons.lang3.StringUtils;
import org.springframework.jdbc.support.JdbcUtils;
import org.springframework.util.Assert;
import static ai.chat2db.spi.util.ResultSetUtils.buildColumn;
import static ai.chat2db.spi.util.ResultSetUtils.buildFunction;
import static ai.chat2db.spi.util.ResultSetUtils.buildProcedure;
import static ai.chat2db.spi.util.ResultSetUtils.buildTable;
import static ai.chat2db.spi.util.ResultSetUtils.buildTableIndexColumn;
/**
* Dbhub 统一数据库连接管理
* TODO 长时间不用连接可以关闭,待优化
*
* @author jipengfei
* @version : DbhubDataSource.java
*/
@Slf4j
public class SQLExecutor {
@ -56,9 +40,6 @@ public class SQLExecutor {
return INSTANCE;
}
//public Connection connection throws SQLException {
// return Chat2DBContext.connection;
//}
public void close() {
}
@ -72,7 +53,7 @@ public class SQLExecutor {
* @return
*/
public <R> R executeSql(Connection connection, String sql, Function<ResultSet, R> function) {
public <R> R executeSql(Connection connection, String sql, Function<ResultSet, R> function) {
if (StringUtils.isBlank(sql)) {
return null;
}
@ -111,12 +92,12 @@ public class SQLExecutor {
}
public void executeSql(Connection connection, String sql, Consumer<List<Header>> headerConsumer,
Consumer<List<String>> rowConsumer) {
Consumer<List<String>> rowConsumer) {
executeSql(connection, sql, headerConsumer, rowConsumer, true);
}
public void executeSql(Connection connection, String sql, Consumer<List<Header>> headerConsumer,
Consumer<List<String>> rowConsumer, boolean limitSize) {
Consumer<List<String>> rowConsumer, boolean limitSize) {
Assert.notNull(sql, "SQL must not be null");
log.info("execute:{}", sql);
try (Statement stmt = connection.createStatement();) {
@ -134,10 +115,10 @@ public class SQLExecutor {
List<Header> headerList = Lists.newArrayListWithExpectedSize(col);
for (int i = 1; i <= col; i++) {
headerList.add(Header.builder()
.dataType(ai.chat2db.spi.util.JdbcUtils.resolveDataType(
resultSetMetaData.getColumnTypeName(i), resultSetMetaData.getColumnType(i)).getCode())
.name(ResultSetUtils.getColumnName(resultSetMetaData, i))
.build());
.dataType(ai.chat2db.spi.util.JdbcUtils.resolveDataType(
resultSetMetaData.getColumnTypeName(i), resultSetMetaData.getColumnType(i)).getCode())
.name(ResultSetUtils.getColumnName(resultSetMetaData, i))
.build());
}
headerConsumer.accept(headerList);
@ -199,10 +180,10 @@ public class SQLExecutor {
executeResult.setHeaderList(headerList);
for (int i = 1; i <= col; i++) {
headerList.add(Header.builder()
.dataType(ai.chat2db.spi.util.JdbcUtils.resolveDataType(
resultSetMetaData.getColumnTypeName(i), resultSetMetaData.getColumnType(i)).getCode())
.name(ResultSetUtils.getColumnName(resultSetMetaData, i))
.build());
.dataType(ai.chat2db.spi.util.JdbcUtils.resolveDataType(
resultSetMetaData.getColumnTypeName(i), resultSetMetaData.getColumnType(i)).getCode())
.name(ResultSetUtils.getColumnName(resultSetMetaData, i))
.build());
}
// 获取数据信息
@ -249,58 +230,43 @@ public class SQLExecutor {
* @param connection
* @return
*/
public List<String> databases(Connection connection) {
List<String> tables = Lists.newArrayList();
public List<Database> databases(Connection connection) {
try (ResultSet resultSet = connection.getMetaData().getCatalogs();) {
if (resultSet != null) {
while (resultSet.next()) {
tables.add(resultSet.getString("TABLE_CAT"));
}
}
return ResultSetUtils.toObjectList(resultSet, Database.class);
} catch (SQLException e) {
throw new RuntimeException(e);
}
return tables;
}
/**
* 获取所有的schema
*
* @param connection
* @param databaseName
* @param schemaName
* @return
* Retrieves the schema names available in this database. The results are ordered by TABLE_CATALOG and TABLE_SCHEM.
* The schema columns are:
* TABLE_SCHEM String => schema name
* TABLE_CATALOG String => catalog name (may be null)
* Params:
* catalog a catalog name; must match the catalog name as it is stored in the database;"" retrieves those without a catalog; null means catalog name should not be used to narrow down the search. schemaPattern a schema name; must match the schema name as it is stored in the database; null means schema name should not be used to narrow down the search.
* Returns:
* a ResultSet object in which each row is a schema description
* Throws:
* SQLException if a database access error occurs
* Since:
* 1.6
* See Also:
* getSearchStringEscape
*/
public List<Map<String, String>> schemas(Connection connection, String databaseName, String schemaName) {
List<Map<String, String>> schemaList = Lists.newArrayList();
public List<Schema> schemas(Connection connection, String databaseName, String schemaName) {
if (StringUtils.isEmpty(databaseName) && StringUtils.isEmpty(schemaName)) {
try (ResultSet resultSet = connection.getMetaData().getSchemas()) {
if (resultSet != null) {
while (resultSet.next()) {
Map<String, String> map = new HashMap<>();
map.put("name", resultSet.getString("TABLE_SCHEM"));
map.put("databaseName", resultSet.getString("TABLE_CATALOG"));
schemaList.add(map);
}
}
return ResultSetUtils.toObjectList(resultSet, Schema.class);
} catch (SQLException e) {
throw new RuntimeException("Get schemas error", e);
}
return schemaList;
}
try (ResultSet resultSet = connection.getMetaData().getSchemas(databaseName, schemaName)) {
if (resultSet != null) {
while (resultSet.next()) {
Map<String, String> map = new HashMap<>();
map.put("name", resultSet.getString("TABLE_SCHEM"));
map.put("databaseName", resultSet.getString("TABLE_CATALOG"));
schemaList.add(map);
}
}
return ResultSetUtils.toObjectList(resultSet, Schema.class);
} catch (SQLException e) {
throw new RuntimeException("Get schemas error", e);
}
return schemaList;
}
/**
@ -314,24 +280,13 @@ public class SQLExecutor {
* @return
*/
public List<Table> tables(Connection connection, String databaseName, String schemaName, String tableName,
String types[]) {
List<Table> tables = Lists.newArrayList();
int n = 0;
String types[]) {
try (ResultSet resultSet = connection.getMetaData().getTables(databaseName, schemaName, tableName,
types)) {
if (resultSet != null) {
while (resultSet.next()) {
n++;
tables.add(buildTable(resultSet));
if (n >= 1000) {// 最多只取1000条
break;
}
}
}
types)) {
return ResultSetUtils.toObjectList(resultSet, Table.class);
} catch (SQLException e) {
throw new RuntimeException(e);
}
return tables;
}
/**
@ -345,99 +300,81 @@ public class SQLExecutor {
* @return
*/
public List<TableColumn> columns(Connection connection, String databaseName, String schemaName, String tableName,
String columnName) {
List<TableColumn> tableColumns = Lists.newArrayList();
String columnName) {
try (ResultSet resultSet = connection.getMetaData().getColumns(databaseName, schemaName, tableName,
columnName)) {
if (resultSet != null) {
while (resultSet.next()) {
tableColumns.add(buildColumn(resultSet));
}
}
columnName)) {
return ResultSetUtils.toObjectList(resultSet, TableColumn.class);
} catch (Exception e) {
throw new RuntimeException(e);
}
return tableColumns;
}
/**
* 获取所有的数据库表索引
* get all table index info
*
* @param connection
* @param databaseName
* @param schemaName
* @param tableName
* @return
* @param connection connection
* @param databaseName databaseName of the index
* @param schemaName schemaName of the index
* @param tableName tableName of the index
* @return List<TableIndex> table index list
*/
public List<TableIndex> indexes(Connection connection, String databaseName, String schemaName, String tableName) {
List<TableIndex> tableIndices = Lists.newArrayList();
try (ResultSet resultSet = connection.getMetaData().getIndexInfo(databaseName, schemaName, tableName,
false,
false)) {
List<TableIndexColumn> tableIndexColumns = Lists.newArrayList();
while (resultSet != null && resultSet.next()) {
tableIndexColumns.add(buildTableIndexColumn(resultSet));
}
false,
false)) {
List<TableIndexColumn> tableIndexColumns = ResultSetUtils.toObjectList(resultSet, TableIndexColumn.class);
tableIndexColumns.stream().filter(c -> c.getIndexName() != null).collect(
Collectors.groupingBy(TableIndexColumn::getIndexName)).entrySet()
.stream().forEach(entry -> {
TableIndex tableIndex = new TableIndex();
TableIndexColumn column = entry.getValue().get(0);
tableIndex.setName(entry.getKey());
tableIndex.setTableName(column.getTableName());
tableIndex.setSchemaName(column.getSchemaName());
tableIndex.setDatabaseName(column.getDatabaseName());
tableIndex.setUnique(!column.getNonUnique());
tableIndex.setColumnList(entry.getValue());
tableIndices.add(tableIndex);
});
Collectors.groupingBy(TableIndexColumn::getIndexName)).entrySet()
.stream().forEach(entry -> {
TableIndex tableIndex = new TableIndex();
TableIndexColumn column = entry.getValue().get(0);
tableIndex.setName(entry.getKey());
tableIndex.setTableName(column.getTableName());
tableIndex.setSchemaName(column.getSchemaName());
tableIndex.setDatabaseName(column.getDatabaseName());
tableIndex.setUnique(!column.getNonUnique());
tableIndex.setColumnList(entry.getValue());
tableIndices.add(tableIndex);
});
} catch (SQLException e) {
throw new RuntimeException(e);
}
return tableIndices;
}
/**
* 获取所有的函数
* Get all functions available in a catalog.
*
* @param connection
* @param databaseName
* @param schemaName
* @return
* @param connection connection
* @param databaseName databaseName of the function
* @param schemaName schemaName of the function
* @return List<Function>
*/
public List<ai.chat2db.spi.model.Function> functions(Connection connection, String databaseName,
String schemaName) {
List<ai.chat2db.spi.model.Function> functions = Lists.newArrayList();
String schemaName) {
try (ResultSet resultSet = connection.getMetaData().getFunctions(databaseName, schemaName, null);) {
while (resultSet != null && resultSet.next()) {
functions.add(buildFunction(resultSet));
}
return ResultSetUtils.toObjectList(resultSet, ai.chat2db.spi.model.Function.class);
} catch (Exception e) {
throw new RuntimeException(e);
}
return functions;
}
/**
* 获取所有的存储过程
* procedure list
*
* @param connection
* @param databaseName
* @param schemaName
* @return
* @param connection connection
* @param databaseName databaseName
* @param schemaName schemaName
* @return List<Procedure>
*/
public List<Procedure> procedures(Connection connection, String databaseName, String schemaName) {
List<Procedure> procedures = Lists.newArrayList();
try (ResultSet resultSet = connection.getMetaData().getProcedures(databaseName, schemaName, null)) {
while (resultSet != null && resultSet.next()) {
procedures.add(buildProcedure(resultSet));
}
return ResultSetUtils.toObjectList(resultSet, Procedure.class);
} catch (Exception e) {
throw new RuntimeException(e);
}
return procedures;
}
}