15
15
package main
16
16
17
17
import (
18
+ "context"
18
19
"errors"
19
20
"fmt"
21
+ "math/big"
20
22
"os"
23
+ "regexp"
21
24
"strings"
22
25
"time"
23
26
27
+ "github.com/dolthub/vitess/go/vt/sqlparser"
28
+ "github.com/jackc/pgx/v5"
24
29
"github.com/sergi/go-diff/diffmatchpatch"
25
30
26
31
"github.com/dolthub/doltgresql/postgres/parser/parser"
@@ -54,6 +59,8 @@ const TestFooter = ` }
54
59
}
55
60
`
56
61
62
+ const MaxTestCount = 10000
63
+
57
64
// GenerateTestsFromSynopses generates a test file in the output directory for each file in the synopses directory.
58
65
func GenerateTestsFromSynopses (repetitionDisabled ... string ) (err error ) {
59
66
parentFolder , err := GetCommandDocsFolder ()
@@ -64,13 +71,10 @@ func GenerateTestsFromSynopses(repetitionDisabled ...string) (err error) {
64
71
if err != nil {
65
72
return err
66
73
}
74
+ removeComments := regexp .MustCompile (`\/\/[^\r\n]*\r?\n?` )
67
75
68
76
FileLoop:
69
- for i , fileInfo := range fileInfos {
70
- if i != 0 {
71
- //TODO: this runs a single file to prevent writing all of the files, since some are unbelievably large
72
- continue FileLoop
73
- }
77
+ for _ , fileInfo := range fileInfos {
74
78
prefix := strings .ToUpper (
75
79
strings .ReplaceAll (
76
80
strings .ReplaceAll (
@@ -96,7 +100,7 @@ FileLoop:
96
100
continue FileLoop
97
101
}
98
102
scannerString := scanner .String ()
99
- if dataStr != scannerString {
103
+ if removeComments . ReplaceAllString ( dataStr , "" ) != scannerString {
100
104
sb := strings.Builder {}
101
105
dmp := diffmatchpatch .New ()
102
106
diffs := dmp .DiffMain (dataStr , scannerString , true )
@@ -111,6 +115,7 @@ FileLoop:
111
115
} else {
112
116
sb .WriteString (dmp .DiffPrettyText (diffs ))
113
117
}
118
+ fmt .Println (sb .String ())
114
119
err = errors .Join (err , errors .New (sb .String ()))
115
120
continue FileLoop
116
121
}
@@ -126,22 +131,64 @@ FileLoop:
126
131
err = errors .Join (err , nErr )
127
132
continue FileLoop
128
133
}
129
- sb := strings.Builder {}
130
- sb .WriteString (fmt .Sprintf (TestHeader , time .Now ().Year (), strings .ReplaceAll (strings .Title (strings .ToLower (prefix )), " " , "" )))
131
-
132
- result , nErr := GetQueryResult (stmtGen .String ())
134
+ // Not all variables have their definitions set in the synopsis, so we'll handle them here
135
+ unsetVariables , nErr := UnsetVariables (stmtGen )
133
136
if nErr != nil {
134
137
err = errors .Join (err , nErr )
135
138
continue FileLoop
136
139
}
137
- sb .WriteString (result )
138
- for stmtGen .Consume () {
139
- result , nErr = GetQueryResult (stmtGen .String ())
140
+ customVariableDefinitions := make (map [string ]StatementGenerator )
141
+ for _ , unsetVariable := range unsetVariables {
142
+ // Check for a specific definition first
143
+ if prefixVariables , ok := PrefixCustomVariables [prefix ]; ok {
144
+ if variableDefinition , ok := prefixVariables [unsetVariable ]; ok {
145
+ customVariableDefinitions [unsetVariable ] = variableDefinition
146
+ continue
147
+ }
148
+ }
149
+ // Check the global definitions if there isn't a specific definition
150
+ if variableDefinition , ok := GlobalCustomVariables [unsetVariable ]; ok {
151
+ customVariableDefinitions [unsetVariable ] = variableDefinition
152
+ continue
153
+ }
154
+ }
155
+ if nErr = ApplyVariableDefinition (stmtGen , customVariableDefinitions ); nErr != nil {
156
+ err = errors .Join (err , nErr )
157
+ continue FileLoop
158
+ }
159
+ sb := strings.Builder {}
160
+ sb .WriteString (fmt .Sprintf (TestHeader , time .Now ().Year (), strings .ReplaceAll (strings .Title (strings .ToLower (prefix )), " " , "" )))
161
+
162
+ permutations := stmtGen .Permutations ()
163
+ if permutations .Cmp (big .NewInt (MaxTestCount )) <= 0 {
164
+ result , nErr := GetQueryResult (stmtGen .String ())
140
165
if nErr != nil {
141
166
err = errors .Join (err , nErr )
142
167
continue FileLoop
143
168
}
144
169
sb .WriteString (result )
170
+ for stmtGen .Consume () {
171
+ result , nErr = GetQueryResult (stmtGen .String ())
172
+ if nErr != nil {
173
+ err = errors .Join (err , nErr )
174
+ continue FileLoop
175
+ }
176
+ sb .WriteString (result )
177
+ }
178
+ } else {
179
+ randomInts , nErr := GenerateRandomInts (MaxTestCount , permutations )
180
+ if nErr != nil {
181
+ err = errors .Join (err , nErr )
182
+ }
183
+ for _ , randomInt := range randomInts {
184
+ stmtGen .SetConsumeIterations (randomInt )
185
+ result , nErr := GetQueryResult (stmtGen .String ())
186
+ if nErr != nil {
187
+ err = errors .Join (err , nErr )
188
+ continue FileLoop
189
+ }
190
+ sb .WriteString (result )
191
+ }
145
192
}
146
193
147
194
sb .WriteString (TestFooter )
@@ -261,18 +308,49 @@ ForLoop:
261
308
return finalStatementGenerator , nil
262
309
}
263
310
311
+ var postgresVerificationConnection * pgx.Conn
312
+
264
313
// GetQueryResult runs the query against a Postgres server to validate that the query is syntactically valid. It then
265
314
// tests the query against the Postgres parser and Postgres-Vitess AST converter to check the current level of support.
266
315
// It returns a string that may be inserted directly into a test source file (two tabs are prefixed).
267
316
func GetQueryResult (query string ) (string , error ) {
268
- //TODO: verify the query against a Postgres server
317
+ var err error
318
+ ctx := context .Background ()
319
+ if postgresVerificationConnection == nil {
320
+ connectionString := fmt .
Sprintf (
"postgres://postgres:[email protected] :%d/" ,
5432 )
321
+ postgresVerificationConnection , err = pgx .Connect (ctx , connectionString )
322
+ if err != nil {
323
+ return "" , err
324
+ }
325
+ }
326
+ testQuery := fmt .Sprintf ("DO $SYNTAX_CHECK$ BEGIN RETURN; %s; END; $SYNTAX_CHECK$;" , query )
327
+ _ , err = postgresVerificationConnection .Exec (ctx , testQuery )
328
+ if err != nil && strings .Contains (err .Error (), "syntax error" ) {
329
+ // We only care about syntax errors, as statements may rely on internal state, which is not what we're testing
330
+ // There are statements that will not execute inside our DO block due to how Postgres handles some queries, so
331
+ // to confirm that they're syntax errors, we'll run them outside the block. All such queries should be
332
+ // non-destructive, so this should be safe. All other queries will still return a syntax error.
333
+ _ , err = postgresVerificationConnection .Exec (ctx , query )
334
+ // Run a ROLLBACK as some commands may put the connection (not the database) in a bad state
335
+ _ , _ = postgresVerificationConnection .Exec (ctx , "ROLLBACK;" )
336
+ if err != nil && strings .Contains (err .Error (), "syntax error" ) {
337
+ return "" , fmt .Errorf ("%s\n %s" , err , query )
338
+ }
339
+ }
269
340
formattedQuery := strings .ReplaceAll (query , `"` , `\"` )
270
341
statements , err := parser .Parse (query )
271
342
if err != nil || len (statements ) == 0 {
272
343
return fmt .Sprintf ("\t \t Unimplemented(\" %s\" ),\n " , formattedQuery ), nil
273
344
}
274
345
for _ , statement := range statements {
275
- vitessAST , err := ast .Convert (statement )
346
+ vitessAST , err := func () (vitessAST sqlparser.Statement , err error ) {
347
+ defer func () {
348
+ if recoverVal := recover (); recoverVal != nil {
349
+ vitessAST = nil
350
+ }
351
+ }()
352
+ return ast .Convert (statement )
353
+ }()
276
354
if err != nil || vitessAST == nil {
277
355
return fmt .Sprintf ("\t \t Parses(\" %s\" ),\n " , formattedQuery ), nil
278
356
}
0 commit comments