sql expressions: improve parser (#87277)

sql expressions: improve parser
This commit is contained in:
Scott Lepper
2024-05-03 13:08:07 +01:00
committed by GitHub
parent 48f77cdebe
commit 1a2bbd61fd
2 changed files with 100 additions and 16 deletions

View File

@ -24,21 +24,7 @@ func TablesList(rawSQL string) ([]string, error) {
switch kind := stmt.(type) {
case *sqlparser.Select:
for _, from := range kind.From {
buf := sqlparser.NewTrackedBuffer(nil)
from.Format(buf)
fromClause := buf.String()
upperFromClause := strings.ToUpper(fromClause)
if strings.Contains(upperFromClause, "JOIN") {
return extractTablesFrom(fromClause), nil
}
if upperFromClause != "DUAL" && !strings.HasPrefix(fromClause, "(") {
if strings.Contains(upperFromClause, " AS") {
name := stripAlias(fromClause)
tables = append(tables, name)
continue
}
tables = append(tables, fromClause)
}
tables = append(tables, getTables(from)...)
}
default:
return parseTables(rawSQL)
@ -46,7 +32,66 @@ func TablesList(rawSQL string) ([]string, error) {
if len(tables) == 0 {
return parseTables(rawSQL)
}
return tables, nil
return validateTables(tables), nil
}
func validateTables(tables []string) []string {
validTables := []string{}
for _, table := range tables {
if strings.ToUpper(table) != "DUAL" {
validTables = append(validTables, table)
}
}
return validTables
}
func joinTables(join *sqlparser.JoinTableExpr) []string {
t := getTables(join.LeftExpr)
t = append(t, getTables(join.RightExpr)...)
return t
}
func getTables(te sqlparser.TableExpr) []string {
tables := []string{}
switch v := te.(type) {
case *sqlparser.AliasedTableExpr:
tables = append(tables, nodeValue(v.Expr))
return tables
case *sqlparser.JoinTableExpr:
tables = append(tables, joinTables(v)...)
return tables
case *sqlparser.ParenTableExpr:
for _, e := range v.Exprs {
tables = getTables(e)
}
default:
tables = append(tables, unknownExpr(te)...)
}
return tables
}
func unknownExpr(te sqlparser.TableExpr) []string {
tables := []string{}
fromClause := nodeValue(te)
upperFromClause := strings.ToUpper(fromClause)
if strings.Contains(upperFromClause, "JOIN") {
return extractTablesFrom(fromClause)
}
if upperFromClause != "DUAL" && !strings.HasPrefix(fromClause, "(") {
if strings.Contains(upperFromClause, " AS") {
name := stripAlias(fromClause)
tables = append(tables, name)
return tables
}
tables = append(tables, fromClause)
}
return tables
}
func nodeValue(node sqlparser.SQLNode) string {
buf := sqlparser.NewTrackedBuffer(nil)
node.Format(buf)
return buf.String()
}
func extractTablesFrom(stmt string) []string {