Merge remote-tracking branch 'origin/dev' into dev

This commit is contained in:
SwallowGG
2023-11-28 16:01:22 +08:00
11 changed files with 155 additions and 68 deletions

View File

@ -1,15 +1,5 @@
package ai.chat2db.spi.sql;
import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.List;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import ai.chat2db.server.tools.base.constant.EasyToolsConstant;
import ai.chat2db.server.tools.common.util.I18nUtils;
import ai.chat2db.spi.ValueHandler;
@ -24,6 +14,12 @@ import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.Assert;
import java.sql.*;
import java.util.List;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
/**
* Dbhub 统一数据库连接管理
*
@ -94,12 +90,12 @@ public class SQLExecutor {
}
public void executeSql(Connection connection, String sql, Consumer<List<Header>> headerConsumer,
Consumer<List<String>> rowConsumer,ValueHandler valueHandler) {
executeSql(connection, sql, headerConsumer, rowConsumer, true,valueHandler);
Consumer<List<String>> rowConsumer, ValueHandler valueHandler) {
executeSql(connection, sql, headerConsumer, rowConsumer, true, valueHandler);
}
public void executeSql(Connection connection, String sql, Consumer<List<Header>> headerConsumer,
Consumer<List<String>> rowConsumer, boolean limitSize,ValueHandler valueHandler) {
Consumer<List<String>> rowConsumer, boolean limitSize, ValueHandler valueHandler) {
Assert.notNull(sql, "SQL must not be null");
log.info("execute:{}", sql);
try (Statement stmt = connection.createStatement();) {
@ -147,8 +143,8 @@ public class SQLExecutor {
* @return
* @throws SQLException
*/
public ExecuteResult execute(final String sql, Connection connection,ValueHandler valueHandler) throws SQLException {
return execute(sql, connection, true, null, null,valueHandler);
public ExecuteResult execute(final String sql, Connection connection, ValueHandler valueHandler) throws SQLException {
return execute(sql, connection, true, null, null, valueHandler);
}
public ExecuteResult executeUpdate(final String sql, Connection connection, int n)
@ -162,7 +158,7 @@ public class SQLExecutor {
if (affectedRows != n) {
executeResult.setSuccess(false);
executeResult.setMessage("Update error " + sql + " update affectedRows = " + affectedRows + ", Each SQL statement should update no more than one record. Please use a unique key for updates.");
// connection.rollback();
// connection.rollback();
}
}
return executeResult;
@ -270,12 +266,12 @@ public class SQLExecutor {
* @return
* @throws SQLException
*/
public ExecuteResult execute(Connection connection, String sql,ValueHandler valueHandler) throws SQLException {
return execute(sql, connection, true, null, null,valueHandler);
public ExecuteResult execute(Connection connection, String sql, ValueHandler valueHandler) throws SQLException {
return execute(sql, connection, true, null, null, valueHandler);
}
public ExecuteResult execute(Connection connection, String sql) throws SQLException {
return execute(sql, connection, true, null, null,new DefaultValueHandler());
return execute(sql, connection, true, null, null, new DefaultValueHandler());
}
/**
@ -342,8 +338,33 @@ public class SQLExecutor {
*/
public List<Table> tables(Connection connection, String databaseName, String schemaName, String tableName,
String types[]) {
try (ResultSet resultSet = connection.getMetaData().getTables(databaseName, schemaName, tableName,
types)) {
try {
DatabaseMetaData metadata = connection.getMetaData();
ResultSet resultSet = metadata.getTables(databaseName, schemaName, tableName,
types);
// 如果connection为mysql
if ("MySQL".equalsIgnoreCase(metadata.getDatabaseProductName())) {
// 获取mysql表的comment
List<Table> tables = ResultSetUtils.toObjectList(resultSet, Table.class);
if (CollectionUtils.isNotEmpty(tables)) {
for (Table table : tables) {
String sql = "show table status where name = '" + table.getName() + "'";
try (Statement stmt = connection.createStatement()) {
boolean query = stmt.execute(sql);
if (query) {
try (ResultSet rs = stmt.getResultSet();) {
while (rs.next()) {
table.setComment(rs.getString("Comment"));
}
}
}
}
}
return tables;
}
}
return ResultSetUtils.toObjectList(resultSet, Table.class);
} catch (SQLException e) {
throw new RuntimeException(e);