Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added parser tests #62

Merged
merged 2 commits into from
Dec 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 93 additions & 15 deletions testing/generation/command_docs/create_tests.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,17 @@
package main

import (
"context"
"errors"
"fmt"
"math/big"
"os"
"regexp"
"strings"
"time"

"github.com/dolthub/vitess/go/vt/sqlparser"
"github.com/jackc/pgx/v5"
"github.com/sergi/go-diff/diffmatchpatch"

"github.com/dolthub/doltgresql/postgres/parser/parser"
Expand Down Expand Up @@ -54,6 +59,8 @@ const TestFooter = ` }
}
`

// MaxTestCount is the maximum number of test queries generated for a single synopsis. When a statement generator has
// at most this many permutations, every permutation is emitted. Otherwise, this many randomly chosen permutations are
// emitted instead.
const MaxTestCount = 10000

// GenerateTestsFromSynopses generates a test file in the output directory for each file in the synopses directory.
func GenerateTestsFromSynopses(repetitionDisabled ...string) (err error) {
parentFolder, err := GetCommandDocsFolder()
Expand All @@ -64,13 +71,10 @@ func GenerateTestsFromSynopses(repetitionDisabled ...string) (err error) {
if err != nil {
return err
}
removeComments := regexp.MustCompile(`\/\/[^\r\n]*\r?\n?`)

FileLoop:
for i, fileInfo := range fileInfos {
if i != 0 {
//TODO: this runs a single file to prevent writing all of the files, since some are unbelievably large
continue FileLoop
}
for _, fileInfo := range fileInfos {
prefix := strings.ToUpper(
strings.ReplaceAll(
strings.ReplaceAll(
Expand All @@ -96,7 +100,7 @@ FileLoop:
continue FileLoop
}
scannerString := scanner.String()
if dataStr != scannerString {
if removeComments.ReplaceAllString(dataStr, "") != scannerString {
sb := strings.Builder{}
dmp := diffmatchpatch.New()
diffs := dmp.DiffMain(dataStr, scannerString, true)
Expand All @@ -111,6 +115,7 @@ FileLoop:
} else {
sb.WriteString(dmp.DiffPrettyText(diffs))
}
fmt.Println(sb.String())
err = errors.Join(err, errors.New(sb.String()))
continue FileLoop
}
Expand All @@ -126,22 +131,64 @@ FileLoop:
err = errors.Join(err, nErr)
continue FileLoop
}
sb := strings.Builder{}
sb.WriteString(fmt.Sprintf(TestHeader, time.Now().Year(), strings.ReplaceAll(strings.Title(strings.ToLower(prefix)), " ", "")))

result, nErr := GetQueryResult(stmtGen.String())
// Not all variables have their definitions set in the synopsis, so we'll handle them here
unsetVariables, nErr := UnsetVariables(stmtGen)
if nErr != nil {
err = errors.Join(err, nErr)
continue FileLoop
}
sb.WriteString(result)
for stmtGen.Consume() {
result, nErr = GetQueryResult(stmtGen.String())
customVariableDefinitions := make(map[string]StatementGenerator)
for _, unsetVariable := range unsetVariables {
// Check for a specific definition first
if prefixVariables, ok := PrefixCustomVariables[prefix]; ok {
if variableDefinition, ok := prefixVariables[unsetVariable]; ok {
customVariableDefinitions[unsetVariable] = variableDefinition
continue
}
}
// Check the global definitions if there isn't a specific definition
if variableDefinition, ok := GlobalCustomVariables[unsetVariable]; ok {
customVariableDefinitions[unsetVariable] = variableDefinition
continue
}
}
if nErr = ApplyVariableDefinition(stmtGen, customVariableDefinitions); nErr != nil {
err = errors.Join(err, nErr)
continue FileLoop
}
sb := strings.Builder{}
sb.WriteString(fmt.Sprintf(TestHeader, time.Now().Year(), strings.ReplaceAll(strings.Title(strings.ToLower(prefix)), " ", "")))

permutations := stmtGen.Permutations()
if permutations.Cmp(big.NewInt(MaxTestCount)) <= 0 {
result, nErr := GetQueryResult(stmtGen.String())
if nErr != nil {
err = errors.Join(err, nErr)
continue FileLoop
}
sb.WriteString(result)
for stmtGen.Consume() {
result, nErr = GetQueryResult(stmtGen.String())
if nErr != nil {
err = errors.Join(err, nErr)
continue FileLoop
}
sb.WriteString(result)
}
} else {
randomInts, nErr := GenerateRandomInts(MaxTestCount, permutations)
if nErr != nil {
err = errors.Join(err, nErr)
}
for _, randomInt := range randomInts {
stmtGen.SetConsumeIterations(randomInt)
result, nErr := GetQueryResult(stmtGen.String())
if nErr != nil {
err = errors.Join(err, nErr)
continue FileLoop
}
sb.WriteString(result)
}
}

sb.WriteString(TestFooter)
Expand Down Expand Up @@ -261,18 +308,49 @@ ForLoop:
return finalStatementGenerator, nil
}

// postgresVerificationConnection is the connection that GetQueryResult uses to check queries against a Postgres
// server. It starts out nil, is opened the first time GetQueryResult needs it, and is then reused by later calls.
var postgresVerificationConnection *pgx.Conn

// GetQueryResult runs the query against a Postgres server to validate that the query is syntactically valid. It then
// tests the query against the Postgres parser and Postgres-Vitess AST converter to check the current level of support.
// It returns a string that may be inserted directly into a test source file (two tabs are prefixed).
func GetQueryResult(query string) (string, error) {
//TODO: verify the query against a Postgres server
var err error
ctx := context.Background()
if postgresVerificationConnection == nil {
connectionString := fmt.Sprintf("postgres://postgres:[email protected]:%d/", 5432)
postgresVerificationConnection, err = pgx.Connect(ctx, connectionString)
if err != nil {
return "", err
}
}
testQuery := fmt.Sprintf("DO $SYNTAX_CHECK$ BEGIN RETURN; %s; END; $SYNTAX_CHECK$;", query)
_, err = postgresVerificationConnection.Exec(ctx, testQuery)
if err != nil && strings.Contains(err.Error(), "syntax error") {
// We only care about syntax errors, as statements may rely on internal state, which is not what we're testing
// There are statements that will not execute inside our DO block due to how Postgres handles some queries, so
// to confirm that they're syntax errors, we'll run them outside the block. All such queries should be
// non-destructive, so this should be safe. All other queries will still return a syntax error.
_, err = postgresVerificationConnection.Exec(ctx, query)
// Run a ROLLBACK as some commands may put the connection (not the database) in a bad state
_, _ = postgresVerificationConnection.Exec(ctx, "ROLLBACK;")
if err != nil && strings.Contains(err.Error(), "syntax error") {
return "", fmt.Errorf("%s\n%s", err, query)
}
}
formattedQuery := strings.ReplaceAll(query, `"`, `\"`)
statements, err := parser.Parse(query)
if err != nil || len(statements) == 0 {
return fmt.Sprintf("\t\tUnimplemented(\"%s\"),\n", formattedQuery), nil
}
for _, statement := range statements {
vitessAST, err := ast.Convert(statement)
vitessAST, err := func() (vitessAST sqlparser.Statement, err error) {
defer func() {
if recoverVal := recover(); recoverVal != nil {
vitessAST = nil
}
}()
return ast.Convert(statement)
}()
if err != nil || vitessAST == nil {
return fmt.Sprintf("\t\tParses(\"%s\"),\n", formattedQuery), nil
}
Expand Down
133 changes: 133 additions & 0 deletions testing/generation/command_docs/custom_variables.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
// Copyright 2023 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package main

import (
"strings"
)

// GlobalCustomVariables holds the fallback definitions for variables that a synopsis references without defining. An
// entry is only consulted when PrefixCustomVariables has no definition for the same variable under the statement's
// prefix. Each definition is written in the synopsis layout format. Entries are kept in alphabetical order by name.
var GlobalCustomVariables = map[string]StatementGenerator{
	"access_method_type":  customDefinition(`TABLE | INDEX`),
	"argmode":             customDefinition(`IN | VARIADIC`),
	"argtype":             customDefinition(`FLOAT8`),
	"boolean":             customDefinition(`true`),
	"cache":               customDefinition(`1`),
	"code":                customDefinition(`'code'`),
	"collation":           customDefinition(`en_US`),
	"column_definition":   customDefinition(`v1 INTEGER`),
	"connlimit":           customDefinition(`-1`),
	"cycle_mark_default":  customDefinition(`'cycle_mark_default'`),
	"cycle_mark_value":    customDefinition(`'cycle_mark_value'`),
	"delete":              customDefinition(`DELETE FROM tablename`),
	"dest_encoding":       customDefinition(`'UTF8'`),
	"domain_constraint":   customDefinition(`CONSTRAINT name CHECK (condition)`),
	"execution_cost":      customDefinition(`10`),
	"existing_enum_value": customDefinition(`'1'`),
	"filter_value":        customDefinition(`'Active'`),
	"from_item_recursive": customDefinition(`function_name()`),
	"increment":           customDefinition(`1`),
	"insert":              customDefinition(`INSERT INTO tablename VALUES (1)`),
	"integer":             customDefinition(`1`),
	"join_type":           customDefinition(`[ INNER ] JOIN | LEFT [ OUTER ] JOIN | RIGHT [ OUTER ] JOIN | FULL [ OUTER ] JOIN`),
	"large_object_oid":    customDefinition(`99999`),
	"loid":                customDefinition(`99999`),
	"maxvalue":            customDefinition(`1`),
	"minvalue":            customDefinition(`1`),
	"neighbor_enum_value": customDefinition(`'1'`),
	"new_enum_value":      customDefinition(`'1'`),
	"numeric_literal":     customDefinition(`1`),
	"oid":                 customDefinition(`99999`),
	"operator":            customDefinition(`+`),
	"operator_name":       customDefinition(`@@`),
	"output_expression":   customDefinition(`colname`),
	"payload":             customDefinition(`'payload'`),
	"query":               customDefinition(`SELECT 1`),
	"restart":             customDefinition(`0`),
	"result_rows":         customDefinition(`10`),
	"select":              customDefinition(`SELECT 1`),
	"sequence_options":    customDefinition(`NO MINVALUE`),
	"snapshot_id":         customDefinition(`'snapshot_id'`),
	"source_encoding":     customDefinition(`'UTF8'`),
	"source_query":        customDefinition(`SELECT 1`),
	"sql_body":            customDefinition(`BEGIN ATOMIC END | RETURN 1`),
	"start":               customDefinition(`0`),
	"storage_parameter":   customDefinition(`fillfactor`),
	"strategy_number":     customDefinition(`3`),
	"string_literal":      customDefinition(`'str'`),
	"sub-SELECT":          customDefinition(`SELECT 1`),
	"support_number":      customDefinition(`3`),
	"transaction_id":      customDefinition(`'id'`),
	"uid":                 customDefinition(`1`),
	"update":              customDefinition(`UPDATE tablename SET x = 1`),
	"values":              customDefinition(`VALUES (1)`),
	"with_query":          customDefinition(`queryname AS (select)`),
}

// PrefixCustomVariables are variable definitions that only apply to statements with a specific prefix. The outer key
// is the statement prefix and the inner key is the variable name. When a synopsis does not define a variable itself,
// these definitions are checked before the ones in GlobalCustomVariables.
var PrefixCustomVariables = map[string]map[string]StatementGenerator{
"ALTER FOREIGN TABLE": {
"index_parameters": customDefinition(`USING INDEX TABLESPACE tablespace_name`),
},
"ALTER OPERATOR": {
"name": customDefinition(`@@`),
},
"ALTER STATISTICS": {
"new_target": customDefinition(`1`),
},
"CREATE OPERATOR": {
"name": customDefinition(`@@`),
},
"CREATE RULE": {
"command": customDefinition(`SELECT 'abc'`),
},
"CREATE SCHEMA": {
"schema_element": customDefinition(`CREATE TABLE tablename()`),
},
"DROP OPERATOR": {
"name": customDefinition(`@@`),
},
"EXPLAIN": {
"statement": customDefinition(`SELECT 1 | INSERT INTO tablename VALUES (1)`),
},
"MERGE": {
"query": customDefinition(`SELECT 1`),
},
"PREPARE": {
"statement": customDefinition(`SELECT 1 | INSERT INTO tablename VALUES (1)`),
},
"SET": {
"value": customDefinition(`1`),
},
}

// customDefinition returns a StatementGenerator for a custom variable definition. The variable definition should follow
// the same layout format as synopses.
//
// The definition is trimmed of surrounding whitespace, scanned into tokens, and then parsed. Definitions are written
// directly in this package's source, so a definition that cannot be scanned or parsed, or that yields no generator, is
// treated as a programmer error and causes a panic. The panic message quotes the offending definition so that it can
// be found among the many entries in the variable tables.
func customDefinition(str string) StatementGenerator {
	str = strings.TrimSpace(str)
	scanner := NewScanner(str)
	tokens, err := scanner.Process()
	if err != nil {
		panic("custom definition `" + str + "` could not be scanned: " + err.Error())
	}
	stmtGen, err := ParseTokens(tokens, true)
	if err != nil {
		panic("custom definition `" + str + "` could not be parsed: " + err.Error())
	}
	if stmtGen == nil {
		panic("custom definition `" + str + "` did not create a statement generator")
	}
	return stmtGen
}
20 changes: 18 additions & 2 deletions testing/generation/command_docs/ints.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,17 @@ package main

import (
"crypto/rand"
"math"
"math/big"
"sort"
)

var (
bigIntZero = big.NewInt(0)
bigIntOne = big.NewInt(1)
bigIntZero = big.NewInt(0)
bigIntOne = big.NewInt(1)
bigIntTwo = big.NewInt(2)
bigIntMaxInt64 = big.NewInt(math.MaxInt64)
bigIntMaxUint64 = new(big.Int).Add(new(big.Int).Mul(bigIntMaxInt64, bigIntTwo), bigIntOne)
)

// GenerateRandomInts generates a slice of random integers, with each integer ranging from [0, max). The returned slice
Expand All @@ -46,3 +50,15 @@ func GenerateRandomInts(count int64, max *big.Int) (randInts []*big.Int, err err
})
return randInts, nil
}

// GetPercentages converts the slice of numbers to percentages. The max defines the number that would equal 100%. All
// floats will be between [0.0, 100.0], unless the number is not between [0, max].
//
// The division is carried out with arbitrary-precision floats and only the final percentage is converted to a
// float64. Converting the operands first would turn any value beyond the float64 range into +Inf, which makes the
// quotient NaN for exactly the very large counts that require a big.Int in the first place.
//
// A max of zero has no meaningful percentage. In that case the result follows floating-point division by zero: NaN
// for a zero number, +Inf for a positive number, and -Inf for a negative number.
func GetPercentages(numbers []*big.Int, max *big.Int) []float64 {
	percentages := make([]float64, len(numbers))
	if max.Sign() == 0 {
		// big.Float division panics on 0/0, so the zero denominator is handled with plain float64 arithmetic.
		var zero float64
		for i, number := range numbers {
			percentages[i] = (float64(number.Sign()) / zero) * 100.0
		}
		return percentages
	}
	// SetInt on a fresh big.Float picks a precision large enough to hold the integer exactly.
	maxAsFloat := new(big.Float).SetInt(max)
	hundred := big.NewFloat(100)
	for i, number := range numbers {
		percentage := new(big.Float).SetInt(number)
		percentage.Quo(percentage, maxAsFloat)
		percentage.Mul(percentage, hundred)
		// The accuracy is ignored on purpose, as the nearest float64 is all that a percentage needs.
		percentages[i], _ = percentage.Float64()
	}
	return percentages
}
Loading