fix(chat2db): ensure proper escaping of string values in SQL queries

String values in SQL queries are now properly escaped to prevent potential security issues
and incorrect query syntax. This update affects the JDBC value processing logic and the
way columns are built into SQL queries, streamlining the escaping mechanism for various
data types.

The changes include:
- Removal of unnecessary null checks that were redundant with Objects.isNull().
- Streamlined string escaping logic using EasyStringUtils.escapeAndQuoteString().- Utilization of the stream API for more concise and readable code.

BREAKING CHANGE: If any external code relies on the previous behavior of not escaping
string values, it must now handle the escaped values appropriately to avoid syntax
errors or potential SQL injection vulnerabilities.
This commit is contained in:
zgq
2024-07-08 17:01:42 +08:00
parent 6e3b58a8f1
commit ec9121bf35
12 changed files with 251 additions and 38 deletions

View File

@ -8,10 +8,13 @@ import ai.chat2db.spi.model.Database;
import ai.chat2db.spi.model.Table;
import ai.chat2db.spi.model.TableColumn;
import ai.chat2db.spi.model.TableIndex;
import ai.chat2db.spi.util.SqlUtils;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
public class MysqlSqlBuilder extends DefaultSqlBuilder {
@ -103,7 +106,7 @@ public class MysqlSqlBuilder extends DefaultSqlBuilder {
// 判断新增字段
List<TableColumn> addColumnList = new ArrayList<>();
for (TableColumn tableColumn : newTable.getColumnList()) {
if (tableColumn.getEditStatus() != null ? tableColumn.getEditStatus().equals("ADD") : false) {
if (tableColumn.getEditStatus() != null ? tableColumn.getEditStatus().equals("ADD") : false) {
addColumnList.add(tableColumn);
}
}
@ -116,7 +119,7 @@ public class MysqlSqlBuilder extends DefaultSqlBuilder {
if ((StringUtils.isNotBlank(tableColumn.getEditStatus()) && StringUtils.isNotBlank(tableColumn.getColumnType())
&& StringUtils.isNotBlank(tableColumn.getName())) || moveColumnList.contains(tableColumn) || addColumnList.contains(tableColumn)) {
MysqlColumnTypeEnum typeEnum = MysqlColumnTypeEnum.getByType(tableColumn.getColumnType());
if(typeEnum == null){
if (typeEnum == null) {
continue;
}
if (moveColumnList.contains(tableColumn) || addColumnList.contains(tableColumn)) {
@ -131,7 +134,7 @@ public class MysqlSqlBuilder extends DefaultSqlBuilder {
for (TableIndex tableIndex : newTable.getIndexList()) {
if (StringUtils.isNotBlank(tableIndex.getEditStatus()) && StringUtils.isNotBlank(tableIndex.getType())) {
MysqlIndexTypeEnum mysqlIndexTypeEnum = MysqlIndexTypeEnum.getByType(tableIndex.getType());
if(mysqlIndexTypeEnum == null){
if (mysqlIndexTypeEnum == null) {
continue;
}
script.append("\t").append(mysqlIndexTypeEnum.buildModifyIndex(tableIndex)).append(",\n");
@ -139,13 +142,13 @@ public class MysqlSqlBuilder extends DefaultSqlBuilder {
}
// append reorder column
// script.append(buildGenerateReorderColumnSql(oldTable, newTable));
// script.append(buildGenerateReorderColumnSql(oldTable, newTable));
if (script.length() > 2) {
script = new StringBuilder(script.substring(0, script.length() - 2));
script.append(";");
return tableBuilder.append(script).toString();
}else {
} else {
return StringUtils.EMPTY;
}
@ -400,8 +403,22 @@ public class MysqlSqlBuilder extends DefaultSqlBuilder {
@Override
protected void buildTableName(String databaseName, String schemaName, String tableName, StringBuilder script) {
if (StringUtils.isNotBlank(databaseName)) {
script.append("`").append(databaseName).append("`").append('.');
script.append(SqlUtils.quoteObjectName(databaseName, "`")).append('.');
}
script.append("`").append(tableName).append("`");
script.append(SqlUtils.quoteObjectName(tableName, "`"));
}
/**
* @param columnList
* @param script
*/
@Override
protected void buildColumns(List<String> columnList, StringBuilder script) {
if (CollectionUtils.isNotEmpty(columnList)) {
script.append(" (")
.append(columnList.stream().map(s -> SqlUtils.quoteObjectName(s, "`")).collect(Collectors.joining(",")))
.append(") ");
}
}
}

View File

@ -1,10 +1,13 @@
package ai.chat2db.plugin.mysql.value;
import ai.chat2db.plugin.mysql.value.factory.MysqlValueProcessorFactory;
import ai.chat2db.server.tools.common.util.EasyStringUtils;
import ai.chat2db.spi.jdbc.DefaultValueProcessor;
import ai.chat2db.spi.model.JDBCDataValue;
import ai.chat2db.spi.model.SQLDataValue;
import org.apache.commons.lang3.StringUtils;
import java.util.Objects;
import java.util.Set;
/**
@ -17,6 +20,46 @@ import java.util.Set;
public class MysqlValueProcessor extends DefaultValueProcessor {
public static final Set<String> FUNCTION_SET = Set.of("now()", "default");
@Override
public String getJdbcValue(JDBCDataValue dataValue) {
Object value = dataValue.getObject();
if (Objects.isNull(value)) {
// mysql -> example: [date]->0000-00-00
String stringValue = dataValue.getStringValue();
if (Objects.nonNull(stringValue)) {
return stringValue;
}
return null;
}
if (value instanceof String emptyStr) {
if (StringUtils.isBlank(emptyStr)) {
return emptyStr;
}
}
return convertJDBCValueByType(dataValue);
}
@Override
public String getJdbcValueString(JDBCDataValue dataValue) {
Object value = dataValue.getObject();
if (Objects.isNull(value)) {
// mysql -> example: [date]->0000-00-00
String stringValue = dataValue.getStringValue();
if (Objects.nonNull(stringValue)) {
return EasyStringUtils.escapeAndQuoteString(stringValue);
}
return "NULL";
}
if (value instanceof String stringValue) {
if (StringUtils.isBlank(stringValue)) {
return EasyStringUtils.quoteString(stringValue);
}
}
return convertJDBCValueStrByType(dataValue);
}
@Override
public String convertSQLValueByType(SQLDataValue dataValue) {
if (FUNCTION_SET.contains(dataValue.getValue().toLowerCase())) {

View File

@ -13,7 +13,7 @@ public class MysqlTimestampProcessor extends DefaultValueProcessor {
@Override
public String convertSQLValueByType(SQLDataValue dataValue) {
return dataValue.getValue();
return EasyStringUtils.quoteString(dataValue.getValue());
}