diff --git a/pkg/expr/sql/db.go b/pkg/expr/sql/db.go index f1af425ad6d..a1badd44974 100644 --- a/pkg/expr/sql/db.go +++ b/pkg/expr/sql/db.go @@ -4,6 +4,7 @@ package sql import ( "context" + "fmt" sqle "github.com/dolthub/go-mysql-server" mysql "github.com/dolthub/go-mysql-server/sql" @@ -15,6 +16,42 @@ import ( // DB is a database that can execute SQL queries against a set of Frames. type DB struct{} +// GoMySQLServerError represents an error from the underlying Go MySQL Server +type GoMySQLServerError struct { + Err error +} + +// Error implements the error interface +func (e *GoMySQLServerError) Error() string { + return fmt.Sprintf("error in go-mysql-server: %v", e.Err) +} + +// Unwrap provides the original error for errors.Is/As +func (e *GoMySQLServerError) Unwrap() error { + return e.Err +} + +// WrapGoMySQLServerError wraps errors from Go MySQL Server with additional context +func WrapGoMySQLServerError(err error) error { + // Don't wrap nil errors + if err == nil { + return nil + } + + // Check if it's a function not found error or other specific GMS errors + if isFunctionNotFoundError(err) { + return &GoMySQLServerError{Err: err} + } + + // Return original error if it's not one we want to wrap + return err +} + +// isFunctionNotFoundError checks if the error is related to a function not being found +func isFunctionNotFoundError(err error) bool { + return mysql.ErrFunctionNotFound.Is(err) +} + // QueryFrames runs the sql query query against a database created from frames, and returns the frame. // The RefID of each frame becomes a table in the database. // It is expected that there is only one frame per RefID. @@ -47,7 +84,7 @@ func (db *DB) QueryFrames(ctx context.Context, name string, query string, frames schema, iter, _, err := engine.Query(mCtx, query) if err != nil { - return nil, err + return nil, WrapGoMySQLServerError(err) } f, err := convertToDataFrame(mCtx, iter, schema) diff --git a/pkg/expr/sql/db_test.go b/pkg/expr/sql/db_test.go index 57171fbca94..77263d028de 100644 --- a/pkg/expr/sql/db_test.go +++ b/pkg/expr/sql/db_test.go @@ -193,6 +193,18 @@ func TestQueryFramesDateTimeSelect(t *testing.T) { } } +func TestErrorsFromGoMySQLServerAreFlagged(t *testing.T) { + const GmsNotImplemented = "STDDEV" // not implemented in go-mysql-server as of 2025-03-18 + + db := DB{} + + query := `SELECT ` + GmsNotImplemented + `(1);` + + _, err := db.QueryFrames(context.Background(), "sqlExpressionRefId", query, nil) + require.Error(t, err) + require.Contains(t, err.Error(), "error in go-mysql-server") +} + // p is a utility for pointers from constants func p[T any](v T) *T { return &v diff --git a/pkg/expr/sql/parser_allow_test.go b/pkg/expr/sql/parser_allow_test.go index fc4b33369df..abb997d73c1 100644 --- a/pkg/expr/sql/parser_allow_test.go +++ b/pkg/expr/sql/parser_allow_test.go @@ -129,7 +129,17 @@ var example_case_statement = `SELECT END AS category FROM metrics` -var example_all_allowed_functions = `SELECT +var example_all_allowed_functions = `WITH sample_data AS ( + SELECT + 100 AS value, + 'example' AS name, + NOW() AS created_at + UNION ALL SELECT + 50 AS value, + 'test' AS name, + DATE_SUB(NOW(), INTERVAL 1 DAY) AS created_at +) +SELECT -- Conditional functions IF(value > 100, 'High', 'Low') AS conditional_if, COALESCE(value, 0) AS conditional_coalesce, @@ -191,6 +201,6 @@ var example_all_allowed_functions = `SELECT -- Type conversion CAST(value AS CHAR) AS type_cast, CONVERT(value, CHAR) AS type_convert -FROM metrics +FROM sample_data GROUP BY name, value, created_at LIMIT 10`