From 21b9d45ca6be144b284305bf3b27481555cf4ac1 Mon Sep 17 00:00:00 2001 From: Sam Jewell <2903904+samjewell@users.noreply.github.com> Date: Wed, 12 Mar 2025 15:57:50 +0000 Subject: [PATCH] SQL Expressions: Add CASE/WHEN nodes and fixes (and test) for functions just added to allowlist (#102040) * SQL Expressions: Add CASE/WHEN SQL nodes to allowlist * Fixed and test for functions added in #102011 * Add remaining functions to the test-case These are mostly aliases, so the LLM chose to omit them originally. But adding now for completeness * Fix ordering of allowed nodes --- pkg/expr/sql/parser_allow.go | 15 ++++-- pkg/expr/sql/parser_allow_test.go | 85 +++++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+), 3 deletions(-) diff --git a/pkg/expr/sql/parser_allow.go b/pkg/expr/sql/parser_allow.go index e7b1beed175..f05d410d9ef 100644 --- a/pkg/expr/sql/parser_allow.go +++ b/pkg/expr/sql/parser_allow.go @@ -63,6 +63,9 @@ func allowedNode(node sqlparser.SQLNode) (b bool) { case sqlparser.BoolVal: return + case *sqlparser.CaseExpr, *sqlparser.When: + return + case sqlparser.ColIdent, *sqlparser.ColName, sqlparser.Columns: return @@ -75,7 +78,7 @@ func allowedNode(node sqlparser.SQLNode) (b bool) { case *sqlparser.ComparisonExpr: return - case *sqlparser.ConvertExpr: + case *sqlparser.ConvertExpr, *sqlparser.ConvertType: return case sqlparser.GroupBy: @@ -84,6 +87,9 @@ func allowedNode(node sqlparser.SQLNode) (b bool) { case *sqlparser.IndexHints: return + case *sqlparser.IntervalExpr: + return + case *sqlparser.Into: return @@ -120,6 +126,9 @@ func allowedNode(node sqlparser.SQLNode) (b bool) { case sqlparser.TableName, sqlparser.TableExprs, sqlparser.TableIdent: return + case *sqlparser.TrimExpr: + return + case *sqlparser.With: return @@ -165,7 +174,7 @@ func allowedFunction(f *sqlparser.FuncExpr) (b bool) { return case "lower", "upper": return - case "substring", "trim": + case "substring": return // Date functions @@ -183,7 +192,7 @@ func allowedFunction(f *sqlparser.FuncExpr) (b bool) { return // Type conversion - case "cast", "convert": + case "cast": return default: diff --git a/pkg/expr/sql/parser_allow_test.go b/pkg/expr/sql/parser_allow_test.go index 98dd3277173..fc4b33369df 100644 --- a/pkg/expr/sql/parser_allow_test.go +++ b/pkg/expr/sql/parser_allow_test.go @@ -22,6 +22,16 @@ func TestAllowQuery(t *testing.T) { q: example_argo_commit_example, err: nil, }, + { + name: "case statement", + q: example_case_statement, + err: nil, + }, + { + name: "all allowed functions", + q: example_all_allowed_functions, + err: nil, + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { @@ -109,3 +119,78 @@ FROM drone, argo_success, argo_failure, workflows;` + +var example_case_statement = `SELECT + value, + CASE + WHEN value > 100 THEN 'High' + WHEN value > 50 THEN 'Medium' + ELSE 'Low' + END AS category +FROM metrics` + +var example_all_allowed_functions = `SELECT + -- Conditional functions + IF(value > 100, 'High', 'Low') AS conditional_if, + COALESCE(value, 0) AS conditional_coalesce, + IFNULL(value, 0) AS conditional_ifnull, + NULLIF(value, 0) AS conditional_nullif, + + -- Aggregation functions + SUM(value) AS agg_sum, + AVG(value) AS agg_avg, + COUNT(*) AS agg_count, + MIN(value) AS agg_min, + MAX(value) AS agg_max, + STDDEV(value) AS agg_stddev, + STD(value) AS agg_std, + STDDEV_POP(value) AS agg_stddev_pop, + VARIANCE(value) AS agg_variance, + VAR_POP(value) AS agg_var_pop, + + -- Mathematical functions + ABS(value) AS math_abs, + ROUND(value, 2) AS math_round, + FLOOR(value) AS math_floor, + CEILING(value) AS math_ceiling, + CEIL(value) AS math_ceil, + SQRT(ABS(value)) AS math_sqrt, + POW(value, 2) AS math_pow, + POWER(value, 2) AS math_power, + MOD(value, 10) AS math_mod, + LOG(value) AS math_log, + LOG10(value) AS math_log10, + EXP(value) AS math_exp, + SIGN(value) AS math_sign, + + -- String functions + CONCAT('value: ', CAST(value AS CHAR)) AS str_concat, + LENGTH(name) AS str_length, + CHAR_LENGTH(name) AS str_char_length, + LOWER(name) AS str_lower, + UPPER(name) AS str_upper, + SUBSTRING(name, 1, 5) AS str_substring, + TRIM(name) AS str_trim, + + -- Date functions + STR_TO_DATE('2023-01-01', '%Y-%m-%d') AS date_str_to_date, + DATE_FORMAT(NOW(), '%Y-%m-%d') AS date_format, + NOW() AS date_now, + CURDATE() AS date_curdate, + CURTIME() AS date_curtime, + DATE_ADD(created_at, INTERVAL 1 DAY) AS date_add, + DATE_SUB(created_at, INTERVAL 1 DAY) AS date_sub, + YEAR(created_at) AS date_year, + MONTH(created_at) AS date_month, + DAY(created_at) AS date_day, + WEEKDAY(created_at) AS date_weekday, + DATEDIFF(NOW(), created_at) AS date_datediff, + UNIX_TIMESTAMP(created_at) AS date_unix_timestamp, + FROM_UNIXTIME(1634567890) AS date_from_unixtime, + + -- Type conversion + CAST(value AS CHAR) AS type_cast, + CONVERT(value, CHAR) AS type_convert +FROM metrics +GROUP BY name, value, created_at +LIMIT 10`