@@ -490,6 +490,11 @@ func (b *Builder) buildDelete(inScope *scope, d *ast.Delete) (outScope *scope) {
490
490
return
491
491
}
492
492
493
+ // buildUpdate builds a Update node from |u|. If the update joins tables, the returned Update node's
494
+ // children will have a JoinNode, which will later be replaced by an UpdateJoin node during analysis. We
495
+ // don't create the UpdateJoin node here, because some query plans, such as IN SUBQUERY nodes, require
496
+ // analyzer processing that converts the subquery into a join, and then requires the same logic to
497
+ // create an UpdateJoin node under the original Update node.
493
498
func (b * Builder ) buildUpdate (inScope * scope , u * ast.Update ) (outScope * scope ) {
494
499
// TODO: this shouldn't be called during ComPrepare or `PREPARE ... FROM ...` statements, but currently it is.
495
500
// The end result is that the ComDelete counter is incremented during prepare statements, which is incorrect.
@@ -532,44 +537,26 @@ func (b *Builder) buildUpdate(inScope *scope, u *ast.Update) (outScope *scope) {
532
537
update .IsProcNested = b .ProcCtx ().DbName != ""
533
538
534
539
var checks []* sql.CheckConstraint
535
- if join , ok := outScope .node .(* plan.JoinNode ); ok {
536
- // TODO this doesn't work, a lot of the time the top node
537
- // is a filter. This would have to go before we build the
538
- // filter/accessory nodes. But that errors for a lot of queries.
539
- source := plan .NewUpdateSource (
540
- join ,
541
- ignore ,
542
- updateExprs ,
543
- )
544
- updaters , err := rowUpdatersByTable (b .ctx , source , join )
540
+ if hasJoinNode (outScope .node ) {
541
+ tablesToUpdate , err := getResolvedTablesToUpdate (b .ctx , update .Child , outScope .node )
545
542
if err != nil {
546
543
b .handleErr (err )
547
544
}
548
- updateJoin := plan .NewUpdateJoin (updaters , source )
549
- update .Child = updateJoin
550
- transform .Inspect (update , func (n sql.Node ) bool {
551
- // todo maybe this should be later stage
552
- switch n := n .(type ) {
553
- case sql.NameableNode :
554
- if _ , ok := updaters [n .Name ()]; ok {
555
- rt := getResolvedTable (n )
556
- tableScope := inScope .push ()
557
- for _ , c := range rt .Schema () {
558
- tableScope .addColumn (scopeColumn {
559
- db : rt .SqlDatabase .Name (),
560
- table : strings .ToLower (n .Name ()),
561
- tableId : tableScope .tables [strings .ToLower (n .Name ())],
562
- col : strings .ToLower (c .Name ),
563
- typ : c .Type ,
564
- nullable : c .Nullable ,
565
- })
566
- }
567
- checks = append (checks , b .loadChecksFromTable (tableScope , rt .Table )... )
568
- }
569
- default :
545
+
546
+ for _ , rt := range tablesToUpdate {
547
+ tableScope := inScope .push ()
548
+ for _ , c := range rt .Schema () {
549
+ tableScope .addColumn (scopeColumn {
550
+ db : rt .SqlDatabase .Name (),
551
+ table : strings .ToLower (rt .Name ()),
552
+ tableId : tableScope .tables [strings .ToLower (rt .Name ())],
553
+ col : strings .ToLower (c .Name ),
554
+ typ : c .Type ,
555
+ nullable : c .Nullable ,
556
+ })
570
557
}
571
- return true
572
- })
558
+ checks = append ( checks , b . loadChecksFromTable ( tableScope , rt . Table ) ... )
559
+ }
573
560
} else {
574
561
transform .Inspect (update , func (n sql.Node ) bool {
575
562
// todo maybe this should be later stage
@@ -588,35 +575,32 @@ func (b *Builder) buildUpdate(inScope *scope, u *ast.Update) (outScope *scope) {
588
575
return
589
576
}
590
577
591
- // rowUpdatersByTable maps a set of tables to their RowUpdater objects.
592
- func rowUpdatersByTable (ctx * sql.Context , node sql.Node , ij sql.Node ) (map [string ]sql.RowUpdater , error ) {
593
- namesOfTableToBeUpdated := getTablesToBeUpdated (node )
594
- resolvedTables := getTablesByName (ij )
595
-
596
- rowUpdatersByTable := make (map [string ]sql.RowUpdater )
597
- for tableToBeUpdated , _ := range namesOfTableToBeUpdated {
598
- resolvedTable , ok := resolvedTables [strings .ToLower (tableToBeUpdated )]
599
- if ! ok {
600
- return nil , plan .ErrUpdateForTableNotSupported .New (tableToBeUpdated )
578
+ // hasJoinNode returns true if |node| or any child is a JoinNode.
579
+ func hasJoinNode (node sql.Node ) bool {
580
+ updateJoinFound := false
581
+ transform .Inspect (node , func (n sql.Node ) bool {
582
+ if _ , ok := n .(* plan.JoinNode ); ok {
583
+ updateJoinFound = true
601
584
}
585
+ return ! updateJoinFound
586
+ })
587
+ return updateJoinFound
588
+ }
602
589
603
- var table = resolvedTable .UnderlyingTable ()
590
+ func getResolvedTablesToUpdate (_ * sql.Context , node sql.Node , ij sql.Node ) (resolvedTables []* plan.ResolvedTable , err error ) {
591
+ namesOfTablesToBeUpdated := getTablesToBeUpdated (node )
592
+ resolvedTablesMap := getTablesByName (ij )
604
593
605
- // If there is no UpdatableTable for a table being updated, error out
606
- updatable , ok := table .(sql. UpdatableTable )
607
- if ! ok && updatable == nil {
594
+ for tableToBeUpdated , _ := range namesOfTablesToBeUpdated {
595
+ resolvedTable , ok := resolvedTablesMap [ strings . ToLower ( tableToBeUpdated )]
596
+ if ! ok {
608
597
return nil , plan .ErrUpdateForTableNotSupported .New (tableToBeUpdated )
609
598
}
610
599
611
- keyless := sql .IsKeyless (updatable .Schema ())
612
- if keyless {
613
- return nil , sql .ErrUnsupportedFeature .New ("error: keyless tables unsupported for UPDATE JOIN" )
614
- }
615
-
616
- rowUpdatersByTable [tableToBeUpdated ] = updatable .Updater (ctx )
600
+ resolvedTables = append (resolvedTables , resolvedTable )
617
601
}
618
602
619
- return rowUpdatersByTable , nil
603
+ return resolvedTables , nil
620
604
}
621
605
622
606
// getTablesByName takes a node and returns all found resolved tables in a map.
0 commit comments