Skip to content

Commit 5ca790c

Browse files
committed
Optimize ast walker
1 parent 63fafe1 commit 5ca790c

File tree

3 files changed

+32
-14
lines changed

3 files changed

+32
-14
lines changed

example/walk.go

+7-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package main
33
import (
44
"log"
55

6+
"github.com/auxten/postgresql-parser/pkg/sql/parser"
67
"github.com/auxten/postgresql-parser/pkg/walk"
78
)
89

@@ -19,6 +20,11 @@ func main() {
1920
return false
2021
},
2122
}
22-
_, _ = w.Walk(sql, nil)
23+
stmts, err := parser.Parse(sql)
24+
if err != nil {
25+
return
26+
}
27+
28+
_, _ = w.Walk(stmts, nil)
2329
return
2430
}

pkg/walk/walker.go

+24-12
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,7 @@ func (rc ReferredCols) ToList() []string {
2727
return cols
2828
}
2929

30-
func (w *AstWalker) Walk(sql string, ctx interface{}) (ok bool, err error) {
31-
stmts, err := parser.Parse(sql)
32-
if err != nil {
33-
return false, err
34-
}
30+
func (w *AstWalker) Walk(stmts parser.Statements, ctx interface{}) (ok bool, err error) {
3531

3632
w.unknownNodes = make([]interface{}, 0)
3733
asts := make([]tree.NodeFormatter, len(stmts))
@@ -67,6 +63,8 @@ func (w *AstWalker) Walk(sql string, ctx interface{}) (ok bool, err error) {
6763
walk(node.Expr)
6864
case *tree.Array:
6965
walk(node.Exprs)
66+
case tree.AsOfClause:
67+
walk(node.Expr)
7068
case *tree.BinaryExpr:
7169
walk(node.Left, node.Right)
7270
case *tree.CaseExpr:
@@ -127,7 +125,6 @@ func (w *AstWalker) Walk(sql string, ctx interface{}) (ok bool, err error) {
127125
if node.With != nil {
128126
walk(node.With)
129127
}
130-
walk(node.Select)
131128
if node.OrderBy != nil {
132129
for _, order := range node.OrderBy {
133130
walk(order)
@@ -136,6 +133,7 @@ func (w *AstWalker) Walk(sql string, ctx interface{}) (ok bool, err error) {
136133
if node.Limit != nil {
137134
walk(node.Limit)
138135
}
136+
walk(node.Select)
139137
case *tree.Order:
140138
walk(node.Expr, node.Table)
141139
case *tree.Limit:
@@ -148,9 +146,6 @@ func (w *AstWalker) Walk(sql string, ctx interface{}) (ok bool, err error) {
148146
if node.Having != nil {
149147
walk(node.Having)
150148
}
151-
for _, table := range node.From.Tables {
152-
walk(table)
153-
}
154149
if node.DistinctOn != nil {
155150
for _, distinct := range node.DistinctOn {
156151
walk(distinct)
@@ -161,6 +156,10 @@ func (w *AstWalker) Walk(sql string, ctx interface{}) (ok bool, err error) {
161156
walk(group)
162157
}
163158
}
159+
walk(node.From.AsOf)
160+
for _, table := range node.From.Tables {
161+
walk(table)
162+
}
164163
case tree.SelectExpr:
165164
walk(node.Expr)
166165
case tree.SelectExprs:
@@ -192,6 +191,10 @@ func (w *AstWalker) Walk(sql string, ctx interface{}) (ok bool, err error) {
192191
}
193192
case *tree.Where:
194193
walk(node.Expr)
194+
case tree.Window:
195+
for _, windowDef := range node {
196+
walk(windowDef)
197+
}
195198
case *tree.WindowDef:
196199
walk(node.Partitions)
197200
if node.Frame != nil {
@@ -206,13 +209,14 @@ func (w *AstWalker) Walk(sql string, ctx interface{}) (ok bool, err error) {
206209
}
207210
case *tree.WindowFrameBound:
208211
walk(node.OffsetExpr)
209-
case *tree.Window:
210212
case *tree.With:
211213
for _, expr := range node.CTEList {
212214
walk(expr)
213215
}
214216
default:
215-
w.unknownNodes = append(w.unknownNodes, node)
217+
if w.unknownNodes != nil {
218+
w.unknownNodes = append(w.unknownNodes, node)
219+
}
216220
}
217221
}
218222
}
@@ -257,7 +261,15 @@ func ColNamesInSelect(sql string) (referredCols ReferredCols, err error) {
257261
return false
258262
},
259263
}
260-
_, err = w.Walk(sql, referredCols)
264+
stmts, err := parser.Parse(sql)
265+
if err != nil {
266+
return
267+
}
268+
269+
_, err = w.Walk(stmts, referredCols)
270+
if err != nil {
271+
return
272+
}
261273
for _, col := range w.unknownNodes {
262274
log.Printf("unhandled column type %T", col)
263275
}

pkg/walk/walker_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ func TestReferredVarsInSelectStatement(t *testing.T) {
195195
referredCols, err := func() (ReferredCols, error) {
196196
return ColNamesInSelect(tc.sql)
197197
}()
198-
if err.Error() != tc.err.Error() {
198+
if err != nil && err.Error() != tc.err.Error() {
199199
t.Errorf("Expect %s, got %s", tc.err, err)
200200
}
201201
cols := referredCols.ToList()

0 commit comments

Comments
 (0)