Sql Expressions: State when error is from GMS (#102112)

This commit is contained in:
Sam Jewell
2025-03-18 22:41:42 +00:00
committed by GitHub
parent 671ba2ab02
commit d7aeebe5e3
3 changed files with 62 additions and 3 deletions

View File

@ -4,6 +4,7 @@ package sql
import (
"context"
"fmt"
sqle "github.com/dolthub/go-mysql-server"
mysql "github.com/dolthub/go-mysql-server/sql"
@ -15,6 +16,42 @@ import (
// DB is a database that can execute SQL queries against a set of Frames.
type DB struct{}
// GoMySQLServerError represents an error from the underlying Go MySQL Server
type GoMySQLServerError struct {
Err error
}
// Error implements the error interface
func (e *GoMySQLServerError) Error() string {
return fmt.Sprintf("error in go-mysql-server: %v", e.Err)
}
// Unwrap provides the original error for errors.Is/As
func (e *GoMySQLServerError) Unwrap() error {
return e.Err
}
// WrapGoMySQLServerError wraps errors from Go MySQL Server with additional context
func WrapGoMySQLServerError(err error) error {
// Don't wrap nil errors
if err == nil {
return nil
}
// Check if it's a function not found error or other specific GMS errors
if isFunctionNotFoundError(err) {
return &GoMySQLServerError{Err: err}
}
// Return original error if it's not one we want to wrap
return err
}
// isFunctionNotFoundError checks if the error is related to a function not being found
func isFunctionNotFoundError(err error) bool {
return mysql.ErrFunctionNotFound.Is(err)
}
// QueryFrames runs the sql query query against a database created from frames, and returns the frame.
// The RefID of each frame becomes a table in the database.
// It is expected that there is only one frame per RefID.
@ -47,7 +84,7 @@ func (db *DB) QueryFrames(ctx context.Context, name string, query string, frames
schema, iter, _, err := engine.Query(mCtx, query)
if err != nil {
return nil, err
return nil, WrapGoMySQLServerError(err)
}
f, err := convertToDataFrame(mCtx, iter, schema)

View File

@ -193,6 +193,18 @@ func TestQueryFramesDateTimeSelect(t *testing.T) {
}
}
func TestErrorsFromGoMySQLServerAreFlagged(t *testing.T) {
const GmsNotImplemented = "STDDEV" // not implemented in go-mysql-server as of 2025-03-18
db := DB{}
query := `SELECT ` + GmsNotImplemented + `(1);`
_, err := db.QueryFrames(context.Background(), "sqlExpressionRefId", query, nil)
require.Error(t, err)
require.Contains(t, err.Error(), "error in go-mysql-server")
}
// p is a utility for pointers from constants
func p[T any](v T) *T {
return &v

View File

@ -129,7 +129,17 @@ var example_case_statement = `SELECT
END AS category
FROM metrics`
var example_all_allowed_functions = `SELECT
var example_all_allowed_functions = `WITH sample_data AS (
SELECT
100 AS value,
'example' AS name,
NOW() AS created_at
UNION ALL SELECT
50 AS value,
'test' AS name,
DATE_SUB(NOW(), INTERVAL 1 DAY) AS created_at
)
SELECT
-- Conditional functions
IF(value > 100, 'High', 'Low') AS conditional_if,
COALESCE(value, 0) AS conditional_coalesce,
@ -191,6 +201,6 @@ var example_all_allowed_functions = `SELECT
-- Type conversion
CAST(value AS CHAR) AS type_cast,
CONVERT(value, CHAR) AS type_convert
FROM metrics
FROM sample_data
GROUP BY name, value, created_at
LIMIT 10`