mirror of
https://github.com/grafana/grafana.git
synced 2025-08-01 01:51:49 +08:00
[sql expressions] fix: use ast to read tables (#87867)
* [sql expressions] fix: use ast to read tables * can't run tests during ci yet. need to install duckdb * skip for now. need duckdb cli
This commit is contained in:
@ -1,199 +1,72 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
parser "github.com/krasun/gosqlparser"
|
||||
"github.com/xwb1989/sqlparser"
|
||||
"github.com/jeremywohl/flatten"
|
||||
"github.com/scottlepp/go-duck/duck"
|
||||
)
|
||||
|
||||
const (
|
||||
TABLE_NAME = "table_name"
|
||||
ERROR = ".error"
|
||||
ERROR_MESSAGE = ".error_message"
|
||||
)
|
||||
|
||||
// TablesList returns a list of tables for the sql statement
|
||||
// TODO: should we just return all query refs instead of trying to parse them from the sql?
|
||||
func TablesList(rawSQL string) ([]string, error) {
|
||||
stmt, err := sqlparser.Parse(rawSQL)
|
||||
duckDB := duck.NewInMemoryDB()
|
||||
cmd := fmt.Sprintf("SELECT json_serialize_sql('%s')", rawSQL)
|
||||
ret, err := duckDB.RunCommands([]string{cmd})
|
||||
if err != nil {
|
||||
tables, err := parse(rawSQL)
|
||||
if err != nil {
|
||||
return parseTables(rawSQL)
|
||||
}
|
||||
return tables, nil
|
||||
return nil, fmt.Errorf("error serializing sql: %s", err.Error())
|
||||
}
|
||||
|
||||
tables := []string{}
|
||||
switch kind := stmt.(type) {
|
||||
case *sqlparser.Select:
|
||||
for _, from := range kind.From {
|
||||
tables = append(tables, getTables(from)...)
|
||||
}
|
||||
default:
|
||||
return parseTables(rawSQL)
|
||||
}
|
||||
if len(tables) == 0 {
|
||||
return parseTables(rawSQL)
|
||||
}
|
||||
return validateTables(tables), nil
|
||||
}
|
||||
|
||||
func validateTables(tables []string) []string {
|
||||
validTables := []string{}
|
||||
for _, table := range tables {
|
||||
if strings.ToUpper(table) != "DUAL" {
|
||||
validTables = append(validTables, table)
|
||||
}
|
||||
}
|
||||
return validTables
|
||||
}
|
||||
|
||||
func joinTables(join *sqlparser.JoinTableExpr) []string {
|
||||
t := getTables(join.LeftExpr)
|
||||
t = append(t, getTables(join.RightExpr)...)
|
||||
return t
|
||||
}
|
||||
|
||||
func getTables(te sqlparser.TableExpr) []string {
|
||||
tables := []string{}
|
||||
switch v := te.(type) {
|
||||
case *sqlparser.AliasedTableExpr:
|
||||
tables = append(tables, nodeValue(v.Expr))
|
||||
return tables
|
||||
case *sqlparser.JoinTableExpr:
|
||||
tables = append(tables, joinTables(v)...)
|
||||
return tables
|
||||
case *sqlparser.ParenTableExpr:
|
||||
for _, e := range v.Exprs {
|
||||
tables = getTables(e)
|
||||
}
|
||||
default:
|
||||
tables = append(tables, unknownExpr(te)...)
|
||||
}
|
||||
return tables
|
||||
}
|
||||
|
||||
func unknownExpr(te sqlparser.TableExpr) []string {
|
||||
tables := []string{}
|
||||
fromClause := nodeValue(te)
|
||||
upperFromClause := strings.ToUpper(fromClause)
|
||||
if strings.Contains(upperFromClause, "JOIN") {
|
||||
return extractTablesFrom(fromClause)
|
||||
}
|
||||
if upperFromClause != "DUAL" && !strings.HasPrefix(fromClause, "(") {
|
||||
if strings.Contains(upperFromClause, " AS") {
|
||||
name := stripAlias(fromClause)
|
||||
tables = append(tables, name)
|
||||
return tables
|
||||
}
|
||||
tables = append(tables, fromClause)
|
||||
}
|
||||
return tables
|
||||
}
|
||||
|
||||
func nodeValue(node sqlparser.SQLNode) string {
|
||||
buf := sqlparser.NewTrackedBuffer(nil)
|
||||
node.Format(buf)
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func extractTablesFrom(stmt string) []string {
|
||||
// example: A join B on A.name = B.name
|
||||
tables := []string{}
|
||||
parts := strings.Split(stmt, " ")
|
||||
for _, part := range parts {
|
||||
part = strings.ToUpper(part)
|
||||
if isJoin(part) {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(part, "ON") {
|
||||
break
|
||||
}
|
||||
if part != "" {
|
||||
if !existsInList(part, tables) {
|
||||
tables = append(tables, part)
|
||||
}
|
||||
}
|
||||
}
|
||||
return tables
|
||||
}
|
||||
|
||||
func stripAlias(table string) string {
|
||||
tableParts := []string{}
|
||||
for _, part := range strings.Split(table, " ") {
|
||||
if strings.ToUpper(part) == "AS" {
|
||||
break
|
||||
}
|
||||
tableParts = append(tableParts, part)
|
||||
}
|
||||
return strings.Join(tableParts, " ")
|
||||
}
|
||||
|
||||
// uses a simple tokenizer
|
||||
func parse(rawSQL string) ([]string, error) {
|
||||
query, err := parser.Parse(rawSQL)
|
||||
ast := []map[string]any{}
|
||||
err = json.Unmarshal([]byte(ret), &ast)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("error converting json to ast: %s", err.Error())
|
||||
}
|
||||
if query.GetType() == parser.StatementSelect {
|
||||
sel, ok := query.(*parser.Select)
|
||||
if ok {
|
||||
return []string{sel.Table}, nil
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
|
||||
return tablesFromAST(ast)
|
||||
}
|
||||
|
||||
func parseTables(rawSQL string) ([]string, error) {
|
||||
checkSql := strings.ToUpper(rawSQL)
|
||||
rawSQL = strings.ReplaceAll(rawSQL, "\n", " ")
|
||||
rawSQL = strings.ReplaceAll(rawSQL, "\r", " ")
|
||||
if strings.HasPrefix(checkSql, "SELECT") || strings.HasPrefix(rawSQL, "WITH") {
|
||||
tables := []string{}
|
||||
tokens := strings.Split(rawSQL, " ")
|
||||
checkNext := false
|
||||
takeNext := false
|
||||
for _, token := range tokens {
|
||||
t := strings.ToUpper(token)
|
||||
t = strings.TrimSpace(t)
|
||||
func tablesFromAST(ast []map[string]any) ([]string, error) {
|
||||
flat, err := flatten.Flatten(ast[0], "", flatten.DotStyle)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error flattening ast: %s", err.Error())
|
||||
}
|
||||
|
||||
if takeNext {
|
||||
if !existsInList(token, tables) {
|
||||
tables = append(tables, token)
|
||||
}
|
||||
checkNext = false
|
||||
takeNext = false
|
||||
continue
|
||||
}
|
||||
if checkNext {
|
||||
if strings.Contains(t, "(") {
|
||||
checkNext = false
|
||||
continue
|
||||
}
|
||||
if strings.Contains(t, ",") {
|
||||
values := strings.Split(token, ",")
|
||||
for _, v := range values {
|
||||
v := strings.TrimSpace(v)
|
||||
if v != "" {
|
||||
if !existsInList(token, tables) {
|
||||
tables = append(tables, v)
|
||||
}
|
||||
} else {
|
||||
takeNext = true
|
||||
break
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
if !existsInList(token, tables) {
|
||||
tables = append(tables, token)
|
||||
}
|
||||
checkNext = false
|
||||
}
|
||||
if t == "FROM" {
|
||||
checkNext = true
|
||||
tables := []string{}
|
||||
for k, v := range flat {
|
||||
if strings.HasSuffix(k, ERROR) {
|
||||
v, ok := v.(bool)
|
||||
if ok && v {
|
||||
return nil, astError(k, flat)
|
||||
}
|
||||
}
|
||||
if strings.Contains(k, TABLE_NAME) {
|
||||
table, ok := v.(string)
|
||||
if ok && !existsInList(table, tables) {
|
||||
tables = append(tables, v.(string))
|
||||
}
|
||||
}
|
||||
return tables, nil
|
||||
}
|
||||
return nil, errors.New("not a select statement")
|
||||
sort.Strings(tables)
|
||||
|
||||
return tables, nil
|
||||
}
|
||||
|
||||
func astError(k string, flat map[string]any) error {
|
||||
key := strings.Replace(k, ERROR, "", 1)
|
||||
message, ok := flat[key+ERROR_MESSAGE]
|
||||
if !ok {
|
||||
message = "unknown error in sql"
|
||||
}
|
||||
return fmt.Errorf("error in sql: %s", message)
|
||||
}
|
||||
|
||||
func existsInList(table string, list []string) bool {
|
||||
@ -204,15 +77,3 @@ func existsInList(table string, list []string) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
var joins = []string{"JOIN", "INNER", "LEFT", "RIGHT", "FULL", "OUTER"}
|
||||
|
||||
func isJoin(token string) bool {
|
||||
token = strings.ToUpper(token)
|
||||
for _, join := range joins {
|
||||
if token == join {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
Reference in New Issue
Block a user