Skip to content

Commit ac100d7

Browse files
committed
Created a custom analyzer rule for validating column defaults since the logic diverges from MySQL's. Added tests and a new related function.
1 parent 1644bb9 commit ac100d7

File tree

7 files changed

+296
-3
lines changed

7 files changed

+296
-3
lines changed

server/analyzer/init.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ const (
3737
ruleId_ResolveType // resolveType
3838
ruleId_ReplaceArithmeticExpressions // replaceArithmeticExpressions
3939
ruleId_OptimizeFunctions // optimizeFunctions
40+
ruleId_ValidateColumnDefaults // validateColumnDefaults
4041
)
4142

4243
// Init adds additional rules to the analyzer to handle Doltgres-specific functionality.
@@ -45,13 +46,13 @@ func Init() {
4546
analyzer.Rule{Id: ruleId_ResolveType, Apply: ResolveType},
4647
analyzer.Rule{Id: ruleId_TypeSanitizer, Apply: TypeSanitizer},
4748
analyzer.Rule{Id: ruleId_AddDomainConstraints, Apply: AddDomainConstraints},
48-
getAnalyzerRule(analyzer.OnceBeforeDefault, analyzer.ValidateColumnDefaultsId),
49+
analyzer.Rule{Id: ruleId_ValidateColumnDefaults, Apply: ValidateColumnDefaults},
4950
analyzer.Rule{Id: ruleId_AssignInsertCasts, Apply: AssignInsertCasts},
5051
analyzer.Rule{Id: ruleId_AssignUpdateCasts, Apply: AssignUpdateCasts},
5152
analyzer.Rule{Id: ruleId_ReplaceIndexedTables, Apply: ReplaceIndexedTables},
5253
)
5354

54-
// Column default validation was moved to occur after type sanitization, so we'll remove it from its original place
55+
// We remove the original column default rule, as we have our own implementation
5556
analyzer.OnceBeforeDefault = removeAnalyzerRules(analyzer.OnceBeforeDefault, analyzer.ValidateColumnDefaultsId)
5657

5758
// PostgreSQL doesn't have the concept of prefix lengths, so we add a rule to implicitly add them
Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
// Copyright 2020-2021 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package analyzer
16+
17+
import (
18+
pgnode "github.com/dolthub/doltgresql/server/node"
19+
"github.com/dolthub/go-mysql-server/sql"
20+
"github.com/dolthub/go-mysql-server/sql/analyzer"
21+
"github.com/dolthub/go-mysql-server/sql/expression"
22+
"github.com/dolthub/go-mysql-server/sql/plan"
23+
"github.com/dolthub/go-mysql-server/sql/transform"
24+
)
25+
26+
// validateColumnDefaults ensures that newly created column defaults from a DDL statement are legal for the type of
27+
// column, various other business logic checks to match MySQL's logic.
28+
func ValidateColumnDefaults(ctx *sql.Context, _ *analyzer.Analyzer, n sql.Node, _ *plan.Scope, _ analyzer.RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) {
29+
span, ctx := ctx.Span("validateColumnDefaults")
30+
defer span.End()
31+
32+
return transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
33+
switch node := n.(type) {
34+
case *plan.AlterDefaultSet:
35+
table := getResolvedTable(node)
36+
sch := table.Schema()
37+
index := sch.IndexOfColName(node.ColumnName)
38+
if index == -1 {
39+
return nil, transform.SameTree, sql.ErrColumnNotFound.New(node.ColumnName)
40+
}
41+
col := sch[index]
42+
err := validateColumnDefault(ctx, col, node.Default)
43+
if err != nil {
44+
return node, transform.SameTree, err
45+
}
46+
47+
return node, transform.SameTree, nil
48+
49+
case sql.SchemaTarget:
50+
switch node.(type) {
51+
case *plan.AlterPK, *plan.AddColumn, *plan.ModifyColumn, *plan.AlterDefaultDrop, *plan.CreateTable, *plan.DropColumn, *pgnode.CreateTable:
52+
// DDL nodes must validate any new column defaults, continue to logic below
53+
default:
54+
// other node types are not altering the schema and therefore don't need validation of column defaults
55+
return n, transform.SameTree, nil
56+
}
57+
58+
// There may be multiple DDL nodes in the plan (ALTER TABLE statements can have many clauses), and for each of them
59+
// we need to count the column indexes in the very hacky way outlined above.
60+
i := 0
61+
return transform.NodeExprs(n, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
62+
eWrapper, ok := e.(*expression.Wrapper)
63+
if !ok {
64+
return e, transform.SameTree, nil
65+
}
66+
67+
defer func() {
68+
i++
69+
}()
70+
71+
eVal := eWrapper.Unwrap()
72+
if eVal == nil {
73+
return e, transform.SameTree, nil
74+
}
75+
colDefault, ok := eVal.(*sql.ColumnDefaultValue)
76+
if !ok {
77+
return e, transform.SameTree, nil
78+
}
79+
80+
col, err := lookupColumnForTargetSchema(ctx, node, i)
81+
if err != nil {
82+
return nil, transform.SameTree, err
83+
}
84+
85+
err = validateColumnDefault(ctx, col, colDefault)
86+
if err != nil {
87+
return nil, transform.SameTree, err
88+
}
89+
90+
return e, transform.SameTree, nil
91+
})
92+
default:
93+
return node, transform.SameTree, nil
94+
}
95+
})
96+
}
97+
98+
// lookupColumnForTargetSchema looks at the target schema for the specified SchemaTarget node and returns
99+
// the column based on the specified index. For most node types, this is simply indexing into the target
100+
// schema but a few types require special handling.
101+
func lookupColumnForTargetSchema(_ *sql.Context, node sql.SchemaTarget, colIndex int) (*sql.Column, error) {
102+
schema := node.TargetSchema()
103+
104+
switch n := node.(type) {
105+
case *plan.ModifyColumn:
106+
if colIndex < len(schema) {
107+
return schema[colIndex], nil
108+
} else {
109+
return n.NewColumn(), nil
110+
}
111+
case *plan.AddColumn:
112+
if colIndex < len(schema) {
113+
return schema[colIndex], nil
114+
} else {
115+
return n.Column(), nil
116+
}
117+
case *plan.AlterDefaultSet:
118+
index := schema.IndexOfColName(n.ColumnName)
119+
if index == -1 {
120+
return nil, sql.ErrTableColumnNotFound.New(n.Table, n.ColumnName)
121+
}
122+
return schema[index], nil
123+
default:
124+
if colIndex < len(schema) {
125+
return schema[colIndex], nil
126+
} else {
127+
// TODO: sql.ErrColumnNotFound would be a better error here, but we need to add all the different node types to
128+
// the switch to get it
129+
return nil, expression.ErrIndexOutOfBounds.New(colIndex, len(schema))
130+
}
131+
}
132+
}
133+
134+
// validateColumnDefault validates that the column default expression is valid for the column type and returns an error
135+
// if not
136+
func validateColumnDefault(ctx *sql.Context, col *sql.Column, colDefault *sql.ColumnDefaultValue) error {
137+
if colDefault == nil {
138+
return nil
139+
}
140+
141+
var err error
142+
sql.Inspect(colDefault.Expr, func(e sql.Expression) bool {
143+
switch e.(type) {
144+
case sql.FunctionExpression, *expression.UnresolvedFunction:
145+
// TODO: functions must be deterministic to be used in column defaults
146+
return true
147+
case *plan.Subquery:
148+
err = sql.ErrColumnDefaultSubquery.New(col.Name)
149+
return false
150+
case *expression.GetField:
151+
if !colDefault.IsParenthesized() {
152+
err = sql.ErrInvalidColumnDefaultValue.New(col.Name)
153+
return false
154+
}
155+
return true
156+
default:
157+
return true
158+
}
159+
})
160+
161+
if err != nil {
162+
return err
163+
}
164+
165+
// validate type of default expression
166+
if err = colDefault.CheckType(ctx); err != nil {
167+
return err
168+
}
169+
170+
return nil
171+
}
172+
173+
// Finds first ResolvedTable node that is a descendant of the node given
174+
// This function will not look inside SubqueryAliases
175+
func getResolvedTable(node sql.Node) *plan.ResolvedTable {
176+
var table *plan.ResolvedTable
177+
transform.Inspect(node, func(n sql.Node) bool {
178+
// Inspect is called on all children of a node even if an earlier child's call returns false.
179+
// We only want the first TableNode match.
180+
if table != nil {
181+
return false
182+
}
183+
switch nn := n.(type) {
184+
case *plan.SubqueryAlias:
185+
// We should not be matching with ResolvedTables inside SubqueryAliases
186+
return false
187+
case *plan.ResolvedTable:
188+
if !plan.IsDualTable(nn) {
189+
table = nn
190+
return false
191+
}
192+
case *plan.IndexedTableAccess:
193+
if rt, ok := nn.TableNode.(*plan.ResolvedTable); ok {
194+
table = rt
195+
return false
196+
}
197+
}
198+
return true
199+
})
200+
return table
201+
}

server/doltgres_handler.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ import (
4343
pgtypes "github.com/dolthub/doltgresql/server/types"
4444
)
4545

46-
var printErrorStackTraces = false
46+
var printErrorStackTraces = true
4747

4848
const PrintErrorStackTracesEnvKey = "DOLTGRES_PRINT_ERROR_STACK_TRACES"
4949

@@ -104,11 +104,17 @@ func (h *DoltgresHandler) ComBind(ctx context.Context, c *mysql.Conn, query stri
104104

105105
bvs, err := h.convertBindParameters(sqlCtx, bindVars.varTypes, bindVars.formatCodes, bindVars.parameters)
106106
if err != nil {
107+
if printErrorStackTraces {
108+
fmt.Printf("unable to convert bind params: %+v\n", err)
109+
}
107110
return nil, nil, err
108111
}
109112

110113
queryPlan, err := h.e.BoundQueryPlan(sqlCtx, query, stmt, bvs)
111114
if err != nil {
115+
if printErrorStackTraces {
116+
fmt.Printf("unable to bind query plan: %+v\n", err)
117+
}
112118
return nil, nil, err
113119
}
114120

server/functions/gen_random_uuid.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// Copyright 2025 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package functions
16+
17+
import (
18+
"github.com/dolthub/doltgresql/postgres/parser/uuid"
19+
"github.com/dolthub/doltgresql/server/functions/framework"
20+
pgtypes "github.com/dolthub/doltgresql/server/types"
21+
"github.com/dolthub/go-mysql-server/sql"
22+
)
23+
24+
// initRadians registers the functions to the catalog.
25+
func initGenRandomUuid() {
26+
framework.RegisterFunction(gen_random_uuid)
27+
}
28+
29+
var gen_random_uuid = framework.Function0{
30+
Name: "gen_random_uuid",
31+
Return: pgtypes.Uuid,
32+
Strict: true,
33+
Callable: func(ctx *sql.Context) (any, error) {
34+
return uuid.NewV4()
35+
},
36+
}

server/functions/init.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ func Init() {
102102
initFloor()
103103
initFormatType()
104104
initGcd()
105+
initGenRandomUuid()
105106
initInitcap()
106107
initLcm()
107108
initLeft()

server/node/create_table.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ type CreateTable struct {
2929
}
3030

3131
var _ sql.ExecSourceRel = (*CreateTable)(nil)
32+
var _ sql.SchemaTarget = (*CreateTable)(nil)
3233

3334
// NewCreateTable returns a new *CreateTable.
3435
func NewCreateTable(createTable *plan.CreateTable, sequences []*CreateSequence) *CreateTable {
@@ -101,3 +102,18 @@ func (c *CreateTable) WithChildren(children ...sql.Node) (sql.Node, error) {
101102
sequences: c.sequences,
102103
}, nil
103104
}
105+
106+
func (c *CreateTable) TargetSchema() sql.Schema {
107+
return c.gmsCreateTable.TargetSchema()
108+
}
109+
110+
func (c CreateTable) WithTargetSchema(schema sql.Schema) (sql.Node, error) {
111+
n, err := c.gmsCreateTable.WithTargetSchema(schema)
112+
if err != nil {
113+
return nil, err
114+
}
115+
116+
c.gmsCreateTable = n.(*plan.CreateTable)
117+
118+
return &c, nil
119+
}

testing/go/alter_table_test.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,5 +428,37 @@ func TestAlterTable(t *testing.T) {
428428
},
429429
},
430430
},
431+
{
432+
Name: "alter table add primary key with timestamp column default values",
433+
SetUpScript: []string{
434+
`CREATE TABLE t1 (
435+
id int NOT NULL,
436+
uid uuid NOT NULL,
437+
created_at timestamp with time zone DEFAULT now() NOT NULL,
438+
updated_at timestamp with time zone DEFAULT now() NOT NULL
439+
);`,
440+
"INSERT INTO t1 (id, uid) VALUES (1, '00000000-0000-0000-0000-000000000001');",
441+
},
442+
Assertions: []ScriptTestAssertion{
443+
{
444+
Query: "ALTER TABLE ONLY public.t1 ADD CONSTRAINT t1_pkey PRIMARY KEY (id);",
445+
},
446+
},
447+
},
448+
{
449+
Name: "alter table add primary key with uuid column default values",
450+
SetUpScript: []string{
451+
`CREATE TABLE t1 (
452+
id int NOT NULL,
453+
uid uuid default gen_random_uuid() NOT NULL
454+
);`,
455+
"INSERT INTO t1 (id) VALUES (1);",
456+
},
457+
Assertions: []ScriptTestAssertion{
458+
{
459+
Query: "ALTER TABLE ONLY public.t1 ADD CONSTRAINT t1_pkey PRIMARY KEY (id);",
460+
},
461+
},
462+
},
431463
})
432464
}

0 commit comments

Comments
 (0)