1
+ // Copyright 2020-2021 Dolthub, Inc.
2
+ //
3
+ // Licensed under the Apache License, Version 2.0 (the "License");
4
+ // you may not use this file except in compliance with the License.
5
+ // You may obtain a copy of the License at
6
+ //
7
+ // http://www.apache.org/licenses/LICENSE-2.0
8
+ //
9
+ // Unless required by applicable law or agreed to in writing, software
10
+ // distributed under the License is distributed on an "AS IS" BASIS,
11
+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ // See the License for the specific language governing permissions and
13
+ // limitations under the License.
14
+
15
+ package analyzer
16
+
17
+ import (
18
+ pgnode "github.com/dolthub/doltgresql/server/node"
19
+ "github.com/dolthub/go-mysql-server/sql"
20
+ "github.com/dolthub/go-mysql-server/sql/analyzer"
21
+ "github.com/dolthub/go-mysql-server/sql/expression"
22
+ "github.com/dolthub/go-mysql-server/sql/plan"
23
+ "github.com/dolthub/go-mysql-server/sql/transform"
24
+ )
25
+
26
+ // validateColumnDefaults ensures that newly created column defaults from a DDL statement are legal for the type of
27
+ // column, various other business logic checks to match MySQL's logic.
28
+ func ValidateColumnDefaults (ctx * sql.Context , _ * analyzer.Analyzer , n sql.Node , _ * plan.Scope , _ analyzer.RuleSelector , qFlags * sql.QueryFlags ) (sql.Node , transform.TreeIdentity , error ) {
29
+ span , ctx := ctx .Span ("validateColumnDefaults" )
30
+ defer span .End ()
31
+
32
+ return transform .Node (n , func (n sql.Node ) (sql.Node , transform.TreeIdentity , error ) {
33
+ switch node := n .(type ) {
34
+ case * plan.AlterDefaultSet :
35
+ table := getResolvedTable (node )
36
+ sch := table .Schema ()
37
+ index := sch .IndexOfColName (node .ColumnName )
38
+ if index == - 1 {
39
+ return nil , transform .SameTree , sql .ErrColumnNotFound .New (node .ColumnName )
40
+ }
41
+ col := sch [index ]
42
+ err := validateColumnDefault (ctx , col , node .Default )
43
+ if err != nil {
44
+ return node , transform .SameTree , err
45
+ }
46
+
47
+ return node , transform .SameTree , nil
48
+
49
+ case sql.SchemaTarget :
50
+ switch node .(type ) {
51
+ case * plan.AlterPK , * plan.AddColumn , * plan.ModifyColumn , * plan.AlterDefaultDrop , * plan.CreateTable , * plan.DropColumn , * pgnode.CreateTable :
52
+ // DDL nodes must validate any new column defaults, continue to logic below
53
+ default :
54
+ // other node types are not altering the schema and therefore don't need validation of column defaults
55
+ return n , transform .SameTree , nil
56
+ }
57
+
58
+ // There may be multiple DDL nodes in the plan (ALTER TABLE statements can have many clauses), and for each of them
59
+ // we need to count the column indexes in the very hacky way outlined above.
60
+ i := 0
61
+ return transform .NodeExprs (n , func (e sql.Expression ) (sql.Expression , transform.TreeIdentity , error ) {
62
+ eWrapper , ok := e .(* expression.Wrapper )
63
+ if ! ok {
64
+ return e , transform .SameTree , nil
65
+ }
66
+
67
+ defer func () {
68
+ i ++
69
+ }()
70
+
71
+ eVal := eWrapper .Unwrap ()
72
+ if eVal == nil {
73
+ return e , transform .SameTree , nil
74
+ }
75
+ colDefault , ok := eVal .(* sql.ColumnDefaultValue )
76
+ if ! ok {
77
+ return e , transform .SameTree , nil
78
+ }
79
+
80
+ col , err := lookupColumnForTargetSchema (ctx , node , i )
81
+ if err != nil {
82
+ return nil , transform .SameTree , err
83
+ }
84
+
85
+ err = validateColumnDefault (ctx , col , colDefault )
86
+ if err != nil {
87
+ return nil , transform .SameTree , err
88
+ }
89
+
90
+ return e , transform .SameTree , nil
91
+ })
92
+ default :
93
+ return node , transform .SameTree , nil
94
+ }
95
+ })
96
+ }
97
+
98
+ // lookupColumnForTargetSchema looks at the target schema for the specified SchemaTarget node and returns
99
+ // the column based on the specified index. For most node types, this is simply indexing into the target
100
+ // schema but a few types require special handling.
101
+ func lookupColumnForTargetSchema (_ * sql.Context , node sql.SchemaTarget , colIndex int ) (* sql.Column , error ) {
102
+ schema := node .TargetSchema ()
103
+
104
+ switch n := node .(type ) {
105
+ case * plan.ModifyColumn :
106
+ if colIndex < len (schema ) {
107
+ return schema [colIndex ], nil
108
+ } else {
109
+ return n .NewColumn (), nil
110
+ }
111
+ case * plan.AddColumn :
112
+ if colIndex < len (schema ) {
113
+ return schema [colIndex ], nil
114
+ } else {
115
+ return n .Column (), nil
116
+ }
117
+ case * plan.AlterDefaultSet :
118
+ index := schema .IndexOfColName (n .ColumnName )
119
+ if index == - 1 {
120
+ return nil , sql .ErrTableColumnNotFound .New (n .Table , n .ColumnName )
121
+ }
122
+ return schema [index ], nil
123
+ default :
124
+ if colIndex < len (schema ) {
125
+ return schema [colIndex ], nil
126
+ } else {
127
+ // TODO: sql.ErrColumnNotFound would be a better error here, but we need to add all the different node types to
128
+ // the switch to get it
129
+ return nil , expression .ErrIndexOutOfBounds .New (colIndex , len (schema ))
130
+ }
131
+ }
132
+ }
133
+
134
+ // validateColumnDefault validates that the column default expression is valid for the column type and returns an error
135
+ // if not
136
+ func validateColumnDefault (ctx * sql.Context , col * sql.Column , colDefault * sql.ColumnDefaultValue ) error {
137
+ if colDefault == nil {
138
+ return nil
139
+ }
140
+
141
+ var err error
142
+ sql .Inspect (colDefault .Expr , func (e sql.Expression ) bool {
143
+ switch e .(type ) {
144
+ case sql.FunctionExpression , * expression.UnresolvedFunction :
145
+ // TODO: functions must be deterministic to be used in column defaults
146
+ return true
147
+ case * plan.Subquery :
148
+ err = sql .ErrColumnDefaultSubquery .New (col .Name )
149
+ return false
150
+ case * expression.GetField :
151
+ if ! colDefault .IsParenthesized () {
152
+ err = sql .ErrInvalidColumnDefaultValue .New (col .Name )
153
+ return false
154
+ }
155
+ return true
156
+ default :
157
+ return true
158
+ }
159
+ })
160
+
161
+ if err != nil {
162
+ return err
163
+ }
164
+
165
+ // validate type of default expression
166
+ if err = colDefault .CheckType (ctx ); err != nil {
167
+ return err
168
+ }
169
+
170
+ return nil
171
+ }
172
+
173
+ // Finds first ResolvedTable node that is a descendant of the node given
174
+ // This function will not look inside SubqueryAliases
175
+ func getResolvedTable (node sql.Node ) * plan.ResolvedTable {
176
+ var table * plan.ResolvedTable
177
+ transform .Inspect (node , func (n sql.Node ) bool {
178
+ // Inspect is called on all children of a node even if an earlier child's call returns false.
179
+ // We only want the first TableNode match.
180
+ if table != nil {
181
+ return false
182
+ }
183
+ switch nn := n .(type ) {
184
+ case * plan.SubqueryAlias :
185
+ // We should not be matching with ResolvedTables inside SubqueryAliases
186
+ return false
187
+ case * plan.ResolvedTable :
188
+ if ! plan .IsDualTable (nn ) {
189
+ table = nn
190
+ return false
191
+ }
192
+ case * plan.IndexedTableAccess :
193
+ if rt , ok := nn .TableNode .(* plan.ResolvedTable ); ok {
194
+ table = rt
195
+ return false
196
+ }
197
+ }
198
+ return true
199
+ })
200
+ return table
201
+ }
0 commit comments