Skip to content

Commit 222fa00

Browse files
committed
get rid of nested switch in apply_foreign_keys
1 parent 0146174 commit 222fa00

File tree

2 files changed

+28
-33
lines changed

2 files changed

+28
-33
lines changed

sql/analyzer/apply_foreign_keys.go

Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -122,22 +122,14 @@ func applyForeignKeysToNodes(ctx *sql.Context, a *Analyzer, n sql.Node, cache *f
122122
if plan.IsEmptyTable(n.Child) {
123123
return n, transform.SameTree, nil
124124
}
125-
updateDest, err := plan.GetUpdatable(n.Child)
126-
if err != nil {
127-
return nil, transform.SameTree, err
128-
}
129-
switch updateDest.(type) {
130-
case *plan.UpdatableJoinTable:
131-
updateTargets := updateDest.(*plan.UpdatableJoinTable).UpdateTargets
125+
if n.IsJoin {
126+
uj := n.Child.(*plan.UpdateJoin)
127+
updateTargets := uj.UpdateTargets
132128
fkHandlerMap := make(map[string]sql.Node, len(updateTargets))
133129
for tableName, updateTarget := range updateTargets {
134130
fkHandlerMap[tableName] = updateTarget
135-
updateDest, err := plan.GetUpdatable(updateTarget)
136-
if err != nil {
137-
return nil, transform.SameTree, err
138-
}
139131
fkHandler, err :=
140-
getForeignKeyHandlerFromUpdateDestination(updateDest, ctx, a, cache, fkChain, updateTarget)
132+
getForeignKeyHandlerFromUpdateTarget(updateTarget, ctx, a, cache, fkChain)
141133
if err != nil {
142134
return nil, transform.SameTree, err
143135
}
@@ -147,20 +139,19 @@ func applyForeignKeysToNodes(ctx *sql.Context, a *Analyzer, n sql.Node, cache *f
147139
fkHandlerMap[tableName] = fkHandler
148140
}
149141
}
150-
uj := plan.NewUpdateJoin(fkHandlerMap, n.Child.(*plan.UpdateJoin).Child)
142+
uj = plan.NewUpdateJoin(fkHandlerMap, uj.Child)
151143
nn, err := n.WithChildren(uj)
152144
return nn, transform.NewTree, err
153-
default:
154-
fkHandler, err := getForeignKeyHandlerFromUpdateDestination(updateDest, ctx, a, cache, fkChain, n.Child)
155-
if err != nil {
156-
return nil, transform.SameTree, err
157-
}
158-
if fkHandler == nil {
159-
return n, transform.SameTree, nil
160-
}
161-
nn, err := n.WithChildren(fkHandler)
162-
return nn, transform.NewTree, err
163145
}
146+
fkHandler, err := getForeignKeyHandlerFromUpdateTarget(n.Child, ctx, a, cache, fkChain)
147+
if err != nil {
148+
return nil, transform.SameTree, err
149+
}
150+
if fkHandler == nil {
151+
return n, transform.SameTree, nil
152+
}
153+
nn, err := n.WithChildren(fkHandler)
154+
return nn, transform.NewTree, err
164155
case *plan.DeleteFrom:
165156
if plan.IsEmptyTable(n.Child) {
166157
return n, transform.SameTree, nil
@@ -457,10 +448,14 @@ func getForeignKeyRefActions(ctx *sql.Context, a *Analyzer, tbl sql.ForeignKeyTa
457448
return fkEditor, nil
458449
}
459450

460-
// getForeignKeyHandlerFromUpdateDestination creates a ForeignKeyHandler from a given UpdatableTable. It's used in
461-
// applying foreign keys to Update nodes
462-
func getForeignKeyHandlerFromUpdateDestination(updateDest sql.UpdatableTable, ctx *sql.Context, a *Analyzer,
463-
cache *foreignKeyCache, fkChain foreignKeyChain, originalNode sql.Node) (*plan.ForeignKeyHandler, error) {
451+
// getForeignKeyHandlerFromUpdateTarget creates a ForeignKeyHandler from a given update target Node. It is used for
452+
// applying foreign key constrains to Update nodes
453+
func getForeignKeyHandlerFromUpdateTarget(updateTarget sql.Node, ctx *sql.Context, a *Analyzer,
454+
cache *foreignKeyCache, fkChain foreignKeyChain) (*plan.ForeignKeyHandler, error) {
455+
updateDest, err := plan.GetUpdatable(updateTarget)
456+
if err != nil {
457+
return nil, err
458+
}
464459
fkTbl, ok := updateDest.(sql.ForeignKeyTable)
465460
if !ok {
466461
return nil, nil
@@ -477,7 +472,7 @@ func getForeignKeyHandlerFromUpdateDestination(updateDest sql.UpdatableTable, ct
477472
return &plan.ForeignKeyHandler{
478473
Table: fkTbl,
479474
Sch: updateDest.Schema(),
480-
OriginalNode: originalNode,
475+
OriginalNode: updateTarget,
481476
Editor: fkEditor,
482477
AllUpdaters: fkChain.GetUpdaters(),
483478
}, nil

sql/plan/update_join.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@ import (
2121
)
2222

2323
type UpdateJoin struct {
24-
updateTargets map[string]sql.Node
24+
UpdateTargets map[string]sql.Node
2525
UnaryNode
2626
}
2727

2828
// NewUpdateJoin returns a new *UpdateJoin node.
2929
func NewUpdateJoin(updateTargets map[string]sql.Node, child sql.Node) *UpdateJoin {
3030
return &UpdateJoin{
31-
updateTargets: updateTargets,
31+
UpdateTargets: updateTargets,
3232
UnaryNode: UnaryNode{Child: child},
3333
}
3434
}
@@ -55,7 +55,7 @@ func (u *UpdateJoin) DebugString() string {
5555
// GetUpdatable returns an updateJoinTable which implements sql.UpdatableTable.
5656
func (u *UpdateJoin) GetUpdatable() sql.UpdatableTable {
5757
return &UpdatableJoinTable{
58-
UpdateTargets: u.updateTargets,
58+
UpdateTargets: u.UpdateTargets,
5959
joinNode: u.Child.(*UpdateSource).Child,
6060
}
6161
}
@@ -66,7 +66,7 @@ func (u *UpdateJoin) WithChildren(children ...sql.Node) (sql.Node, error) {
6666
return nil, sql.ErrInvalidChildrenNumber.New(u, len(children), 1)
6767
}
6868

69-
return NewUpdateJoin(u.updateTargets, children[0]), nil
69+
return NewUpdateJoin(u.UpdateTargets, children[0]), nil
7070
}
7171

7272
func (u *UpdateJoin) IsReadOnly() bool {
@@ -79,7 +79,7 @@ func (u *UpdateJoin) CollationCoercibility(ctx *sql.Context) (collation sql.Coll
7979
}
8080

8181
func (u *UpdateJoin) GetUpdaters(ctx *sql.Context) (map[string]sql.RowUpdater, error) {
82-
return getUpdaters(u.updateTargets, ctx)
82+
return getUpdaters(u.UpdateTargets, ctx)
8383
}
8484

8585
func getUpdaters(updateTargets map[string]sql.Node, ctx *sql.Context) (map[string]sql.RowUpdater, error) {

0 commit comments

Comments
 (0)