Skip to content

Commit 14bde38

Browse files
authored
Merge pull request #62 from dolthub/daylon/command-tests
Added parser tests
2 parents 7baab9e + d45f723 commit 14bde38

File tree

241 files changed

+337516
-169
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

241 files changed

+337516
-169
lines changed

testing/generation/command_docs/create_tests.go

+93-15
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,17 @@
1515
package main
1616

1717
import (
18+
"context"
1819
"errors"
1920
"fmt"
21+
"math/big"
2022
"os"
23+
"regexp"
2124
"strings"
2225
"time"
2326

27+
"github.com/dolthub/vitess/go/vt/sqlparser"
28+
"github.com/jackc/pgx/v5"
2429
"github.com/sergi/go-diff/diffmatchpatch"
2530

2631
"github.com/dolthub/doltgresql/postgres/parser/parser"
@@ -54,6 +59,8 @@ const TestFooter = ` }
5459
}
5560
`
5661

62+
const MaxTestCount = 10000
63+
5764
// GenerateTestsFromSynopses generates a test file in the output directory for each file in the synopses directory.
5865
func GenerateTestsFromSynopses(repetitionDisabled ...string) (err error) {
5966
parentFolder, err := GetCommandDocsFolder()
@@ -64,13 +71,10 @@ func GenerateTestsFromSynopses(repetitionDisabled ...string) (err error) {
6471
if err != nil {
6572
return err
6673
}
74+
removeComments := regexp.MustCompile(`\/\/[^\r\n]*\r?\n?`)
6775

6876
FileLoop:
69-
for i, fileInfo := range fileInfos {
70-
if i != 0 {
71-
//TODO: this runs a single file to prevent writing all of the files, since some are unbelievably large
72-
continue FileLoop
73-
}
77+
for _, fileInfo := range fileInfos {
7478
prefix := strings.ToUpper(
7579
strings.ReplaceAll(
7680
strings.ReplaceAll(
@@ -96,7 +100,7 @@ FileLoop:
96100
continue FileLoop
97101
}
98102
scannerString := scanner.String()
99-
if dataStr != scannerString {
103+
if removeComments.ReplaceAllString(dataStr, "") != scannerString {
100104
sb := strings.Builder{}
101105
dmp := diffmatchpatch.New()
102106
diffs := dmp.DiffMain(dataStr, scannerString, true)
@@ -111,6 +115,7 @@ FileLoop:
111115
} else {
112116
sb.WriteString(dmp.DiffPrettyText(diffs))
113117
}
118+
fmt.Println(sb.String())
114119
err = errors.Join(err, errors.New(sb.String()))
115120
continue FileLoop
116121
}
@@ -126,22 +131,64 @@ FileLoop:
126131
err = errors.Join(err, nErr)
127132
continue FileLoop
128133
}
129-
sb := strings.Builder{}
130-
sb.WriteString(fmt.Sprintf(TestHeader, time.Now().Year(), strings.ReplaceAll(strings.Title(strings.ToLower(prefix)), " ", "")))
131-
132-
result, nErr := GetQueryResult(stmtGen.String())
134+
// Not all variables have their definitions set in the synopsis, so we'll handle them here
135+
unsetVariables, nErr := UnsetVariables(stmtGen)
133136
if nErr != nil {
134137
err = errors.Join(err, nErr)
135138
continue FileLoop
136139
}
137-
sb.WriteString(result)
138-
for stmtGen.Consume() {
139-
result, nErr = GetQueryResult(stmtGen.String())
140+
customVariableDefinitions := make(map[string]StatementGenerator)
141+
for _, unsetVariable := range unsetVariables {
142+
// Check for a specific definition first
143+
if prefixVariables, ok := PrefixCustomVariables[prefix]; ok {
144+
if variableDefinition, ok := prefixVariables[unsetVariable]; ok {
145+
customVariableDefinitions[unsetVariable] = variableDefinition
146+
continue
147+
}
148+
}
149+
// Check the global definitions if there isn't a specific definition
150+
if variableDefinition, ok := GlobalCustomVariables[unsetVariable]; ok {
151+
customVariableDefinitions[unsetVariable] = variableDefinition
152+
continue
153+
}
154+
}
155+
if nErr = ApplyVariableDefinition(stmtGen, customVariableDefinitions); nErr != nil {
156+
err = errors.Join(err, nErr)
157+
continue FileLoop
158+
}
159+
sb := strings.Builder{}
160+
sb.WriteString(fmt.Sprintf(TestHeader, time.Now().Year(), strings.ReplaceAll(strings.Title(strings.ToLower(prefix)), " ", "")))
161+
162+
permutations := stmtGen.Permutations()
163+
if permutations.Cmp(big.NewInt(MaxTestCount)) <= 0 {
164+
result, nErr := GetQueryResult(stmtGen.String())
140165
if nErr != nil {
141166
err = errors.Join(err, nErr)
142167
continue FileLoop
143168
}
144169
sb.WriteString(result)
170+
for stmtGen.Consume() {
171+
result, nErr = GetQueryResult(stmtGen.String())
172+
if nErr != nil {
173+
err = errors.Join(err, nErr)
174+
continue FileLoop
175+
}
176+
sb.WriteString(result)
177+
}
178+
} else {
179+
randomInts, nErr := GenerateRandomInts(MaxTestCount, permutations)
180+
if nErr != nil {
181+
err = errors.Join(err, nErr)
182+
}
183+
for _, randomInt := range randomInts {
184+
stmtGen.SetConsumeIterations(randomInt)
185+
result, nErr := GetQueryResult(stmtGen.String())
186+
if nErr != nil {
187+
err = errors.Join(err, nErr)
188+
continue FileLoop
189+
}
190+
sb.WriteString(result)
191+
}
145192
}
146193

147194
sb.WriteString(TestFooter)
@@ -261,18 +308,49 @@ ForLoop:
261308
return finalStatementGenerator, nil
262309
}
263310

311+
var postgresVerificationConnection *pgx.Conn
312+
264313
// GetQueryResult runs the query against a Postgres server to validate that the query is syntactically valid. It then
265314
// tests the query against the Postgres parser and Postgres-Vitess AST converter to check the current level of support.
266315
// It returns a string that may be inserted directly into a test source file (two tabs are prefixed).
267316
func GetQueryResult(query string) (string, error) {
268-
//TODO: verify the query against a Postgres server
317+
var err error
318+
ctx := context.Background()
319+
if postgresVerificationConnection == nil {
320+
connectionString := fmt.Sprintf("postgres://postgres:[email protected]:%d/", 5432)
321+
postgresVerificationConnection, err = pgx.Connect(ctx, connectionString)
322+
if err != nil {
323+
return "", err
324+
}
325+
}
326+
testQuery := fmt.Sprintf("DO $SYNTAX_CHECK$ BEGIN RETURN; %s; END; $SYNTAX_CHECK$;", query)
327+
_, err = postgresVerificationConnection.Exec(ctx, testQuery)
328+
if err != nil && strings.Contains(err.Error(), "syntax error") {
329+
// We only care about syntax errors, as statements may rely on internal state, which is not what we're testing
330+
// There are statements that will not execute inside our DO block due to how Postgres handles some queries, so
331+
// to confirm that they're syntax errors, we'll run them outside the block. All such queries should be
332+
// non-destructive, so this should be safe. All other queries will still return a syntax error.
333+
_, err = postgresVerificationConnection.Exec(ctx, query)
334+
// Run a ROLLBACK as some commands may put the connection (not the database) in a bad state
335+
_, _ = postgresVerificationConnection.Exec(ctx, "ROLLBACK;")
336+
if err != nil && strings.Contains(err.Error(), "syntax error") {
337+
return "", fmt.Errorf("%s\n%s", err, query)
338+
}
339+
}
269340
formattedQuery := strings.ReplaceAll(query, `"`, `\"`)
270341
statements, err := parser.Parse(query)
271342
if err != nil || len(statements) == 0 {
272343
return fmt.Sprintf("\t\tUnimplemented(\"%s\"),\n", formattedQuery), nil
273344
}
274345
for _, statement := range statements {
275-
vitessAST, err := ast.Convert(statement)
346+
vitessAST, err := func() (vitessAST sqlparser.Statement, err error) {
347+
defer func() {
348+
if recoverVal := recover(); recoverVal != nil {
349+
vitessAST = nil
350+
}
351+
}()
352+
return ast.Convert(statement)
353+
}()
276354
if err != nil || vitessAST == nil {
277355
return fmt.Sprintf("\t\tParses(\"%s\"),\n", formattedQuery), nil
278356
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
// Copyright 2023 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package main
16+
17+
import (
18+
"strings"
19+
)
20+
21+
// GlobalCustomVariables are variable definitions that are used when a synopsis does not define the definition itself,
22+
// and there isn't a more specific definition in PrefixCustomVariables.
23+
var GlobalCustomVariables = map[string]StatementGenerator{
24+
"access_method_type": customDefinition(`TABLE | INDEX`),
25+
"argmode": customDefinition(`IN | VARIADIC`),
26+
"argtype": customDefinition(`FLOAT8`),
27+
"boolean": customDefinition(`true`),
28+
"cache": customDefinition(`1`),
29+
"code": customDefinition(`'code'`),
30+
"collation": customDefinition(`en_US`),
31+
"column_definition": customDefinition(`v1 INTEGER`),
32+
"connlimit": customDefinition(`-1`),
33+
"cycle_mark_default": customDefinition(`'cycle_mark_default'`),
34+
"cycle_mark_value": customDefinition(`'cycle_mark_value'`),
35+
"delete": customDefinition(`DELETE FROM tablename`),
36+
"dest_encoding": customDefinition(`'UTF8'`),
37+
"domain_constraint": customDefinition(`CONSTRAINT name CHECK (condition)`),
38+
"execution_cost": customDefinition(`10`),
39+
"existing_enum_value": customDefinition(`'1'`),
40+
"filter_value": customDefinition(`'Active'`),
41+
"from_item_recursive": customDefinition(`function_name()`),
42+
"increment": customDefinition(`1`),
43+
"insert": customDefinition(`INSERT INTO tablename VALUES (1)`),
44+
"integer": customDefinition(`1`),
45+
"join_type": customDefinition(`[ INNER ] JOIN | LEFT [ OUTER ] JOIN | RIGHT [ OUTER ] JOIN | FULL [ OUTER ] JOIN`),
46+
"large_object_oid": customDefinition(`99999`),
47+
"loid": customDefinition(`99999`),
48+
"maxvalue": customDefinition(`1`),
49+
"minvalue": customDefinition(`1`),
50+
"neighbor_enum_value": customDefinition(`'1'`),
51+
"new_enum_value": customDefinition(`'1'`),
52+
"numeric_literal": customDefinition(`1`),
53+
"operator": customDefinition(`+`),
54+
"output_expression": customDefinition(`colname`),
55+
"payload": customDefinition(`'payload'`),
56+
"query": customDefinition(`SELECT 1`),
57+
"restart": customDefinition(`0`),
58+
"select": customDefinition(`SELECT 1`),
59+
"sequence_options": customDefinition(`NO MINVALUE`),
60+
"snapshot_id": customDefinition(`'snapshot_id'`),
61+
"source_encoding": customDefinition(`'UTF8'`),
62+
"source_query": customDefinition(`SELECT 1`),
63+
"sql_body": customDefinition(`BEGIN ATOMIC END | RETURN 1`),
64+
"start": customDefinition(`0`),
65+
"storage_parameter": customDefinition(`fillfactor`),
66+
"strategy_number": customDefinition(`3`),
67+
"string_literal": customDefinition(`'str'`),
68+
"sub-SELECT": customDefinition(`SELECT 1`),
69+
"support_number": customDefinition(`3`),
70+
"transaction_id": customDefinition(`'id'`),
71+
"oid": customDefinition(`99999`),
72+
"operator_name": customDefinition(`@@`),
73+
"result_rows": customDefinition(`10`),
74+
"uid": customDefinition(`1`),
75+
"update": customDefinition(`UPDATE tablename SET x = 1`),
76+
"values": customDefinition(`VALUES (1)`),
77+
"with_query": customDefinition(`queryname AS (select)`),
78+
}
79+
80+
var PrefixCustomVariables = map[string]map[string]StatementGenerator{
81+
"ALTER FOREIGN TABLE": {
82+
"index_parameters": customDefinition(`USING INDEX TABLESPACE tablespace_name`),
83+
},
84+
"ALTER OPERATOR": {
85+
"name": customDefinition(`@@`),
86+
},
87+
"ALTER STATISTICS": {
88+
"new_target": customDefinition(`1`),
89+
},
90+
"CREATE OPERATOR": {
91+
"name": customDefinition(`@@`),
92+
},
93+
"CREATE RULE": {
94+
"command": customDefinition(`SELECT 'abc'`),
95+
},
96+
"CREATE SCHEMA": {
97+
"schema_element": customDefinition(`CREATE TABLE tablename()`),
98+
},
99+
"DROP OPERATOR": {
100+
"name": customDefinition(`@@`),
101+
},
102+
"EXPLAIN": {
103+
"statement": customDefinition(`SELECT 1 | INSERT INTO tablename VALUES (1)`),
104+
},
105+
"MERGE": {
106+
"query": customDefinition(`SELECT 1`),
107+
},
108+
"PREPARE": {
109+
"statement": customDefinition(`SELECT 1 | INSERT INTO tablename VALUES (1)`),
110+
},
111+
"SET": {
112+
"value": customDefinition(`1`),
113+
},
114+
}
115+
116+
// customDefinition returns a StatementGenerator for a custom variable definition. The variable definition should follow
117+
// the same layout format as synopses.
118+
func customDefinition(str string) StatementGenerator {
119+
str = strings.TrimSpace(str)
120+
scanner := NewScanner(str)
121+
tokens, err := scanner.Process()
122+
if err != nil {
123+
panic(err)
124+
}
125+
stmtGen, err := ParseTokens(tokens, true)
126+
if err != nil {
127+
panic(err)
128+
}
129+
if stmtGen == nil {
130+
panic("definition did not create a statement generator")
131+
}
132+
return stmtGen
133+
}

testing/generation/command_docs/ints.go

+18-2
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,17 @@ package main
1616

1717
import (
1818
"crypto/rand"
19+
"math"
1920
"math/big"
2021
"sort"
2122
)
2223

2324
var (
24-
bigIntZero = big.NewInt(0)
25-
bigIntOne = big.NewInt(1)
25+
bigIntZero = big.NewInt(0)
26+
bigIntOne = big.NewInt(1)
27+
bigIntTwo = big.NewInt(2)
28+
bigIntMaxInt64 = big.NewInt(math.MaxInt64)
29+
bigIntMaxUint64 = new(big.Int).Add(new(big.Int).Mul(bigIntMaxInt64, bigIntTwo), bigIntOne)
2630
)
2731

2832
// GenerateRandomInts generates a slice of random integers, with each integer ranging from [0, max). The returned slice
@@ -46,3 +50,15 @@ func GenerateRandomInts(count int64, max *big.Int) (randInts []*big.Int, err err
4650
})
4751
return randInts, nil
4852
}
53+
54+
// GetPercentages converts the slice of numbers to percentages. The max defines the number that would equal 100%. All
55+
// floats will be between [0.0, 100.0], unless the number is not between [0, max].
56+
func GetPercentages(numbers []*big.Int, max *big.Int) []float64 {
57+
maxAsFloat, _ := max.Float64()
58+
percentages := make([]float64, len(numbers))
59+
for i, number := range numbers {
60+
numberAsFloat, _ := number.Float64()
61+
percentages[i] = (numberAsFloat / maxAsFloat) * 100.0
62+
}
63+
return percentages
64+
}

0 commit comments

Comments
 (0)