Skip to content

Commit c44b30c

Browse files
authored
Merge pull request #3016 from dolthub/fulghum/update_join
Minor updates to support `UPDATE ... FROM` in Doltgres
2 parents 18900d9 + 5e9e2f3 commit c44b30c

File tree

4 files changed

+52
-57
lines changed

4 files changed

+52
-57
lines changed

sql/analyzer/apply_foreign_keys.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@ func applyForeignKeysToNodes(ctx *sql.Context, a *Analyzer, n sql.Node, cache *f
122122
if plan.IsEmptyTable(n.Child) {
123123
return n, transform.SameTree, nil
124124
}
125+
// TODO: UPDATE JOIN can update multiple tables. Because updatableJoinTable does not implement
126+
// sql.ForeignKeyTable, we do not currenly support FK checks for UPDATE JOIN statements.
125127
updateDest, err := plan.GetUpdatable(n.Child)
126128
if err != nil {
127129
return nil, transform.SameTree, err

sql/plan/update_join.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,11 @@ func (u *UpdateJoin) DebugString() string {
5454

5555
// GetUpdatable returns an updateJoinTable which implements sql.UpdatableTable.
5656
func (u *UpdateJoin) GetUpdatable() sql.UpdatableTable {
57+
// TODO: UpdateJoin can update multiple tables, but this interface only allows for a single table.
58+
// Additionally, updatableJoinTable doesn't implement interfaces that other parts of the code
59+
// expect, so UpdateJoins don't always work correctly. For example, because updatableJoinTable
60+
// doesn't implement ForeignKeyTable, UpdateJoin statements don't enforce foreign key checks.
61+
// We should revamp this function so that we can communicate multiple tables being updated.
5762
return &updatableJoinTable{
5863
updaters: u.Updaters,
5964
joinNode: u.Child.(*UpdateSource).Child,

sql/planbuilder/dml.go

Lines changed: 39 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,11 @@ func (b *Builder) buildDelete(inScope *scope, d *ast.Delete) (outScope *scope) {
490490
return
491491
}
492492

493+
// buildUpdate builds a Update node from |u|. If the update joins tables, the returned Update node's
494+
// children will have a JoinNode, which will later be replaced by an UpdateJoin node during analysis. We
495+
// don't create the UpdateJoin node here, because some query plans, such as IN SUBQUERY nodes, require
496+
// analyzer processing that converts the subquery into a join, and then requires the same logic to
497+
// create an UpdateJoin node under the original Update node.
493498
func (b *Builder) buildUpdate(inScope *scope, u *ast.Update) (outScope *scope) {
494499
// TODO: this shouldn't be called during ComPrepare or `PREPARE ... FROM ...` statements, but currently it is.
495500
// The end result is that the ComDelete counter is incremented during prepare statements, which is incorrect.
@@ -532,44 +537,26 @@ func (b *Builder) buildUpdate(inScope *scope, u *ast.Update) (outScope *scope) {
532537
update.IsProcNested = b.ProcCtx().DbName != ""
533538

534539
var checks []*sql.CheckConstraint
535-
if join, ok := outScope.node.(*plan.JoinNode); ok {
536-
// TODO this doesn't work, a lot of the time the top node
537-
// is a filter. This would have to go before we build the
538-
// filter/accessory nodes. But that errors for a lot of queries.
539-
source := plan.NewUpdateSource(
540-
join,
541-
ignore,
542-
updateExprs,
543-
)
544-
updaters, err := rowUpdatersByTable(b.ctx, source, join)
540+
if hasJoinNode(outScope.node) {
541+
tablesToUpdate, err := getResolvedTablesToUpdate(b.ctx, update.Child, outScope.node)
545542
if err != nil {
546543
b.handleErr(err)
547544
}
548-
updateJoin := plan.NewUpdateJoin(updaters, source)
549-
update.Child = updateJoin
550-
transform.Inspect(update, func(n sql.Node) bool {
551-
// todo maybe this should be later stage
552-
switch n := n.(type) {
553-
case sql.NameableNode:
554-
if _, ok := updaters[n.Name()]; ok {
555-
rt := getResolvedTable(n)
556-
tableScope := inScope.push()
557-
for _, c := range rt.Schema() {
558-
tableScope.addColumn(scopeColumn{
559-
db: rt.SqlDatabase.Name(),
560-
table: strings.ToLower(n.Name()),
561-
tableId: tableScope.tables[strings.ToLower(n.Name())],
562-
col: strings.ToLower(c.Name),
563-
typ: c.Type,
564-
nullable: c.Nullable,
565-
})
566-
}
567-
checks = append(checks, b.loadChecksFromTable(tableScope, rt.Table)...)
568-
}
569-
default:
545+
546+
for _, rt := range tablesToUpdate {
547+
tableScope := inScope.push()
548+
for _, c := range rt.Schema() {
549+
tableScope.addColumn(scopeColumn{
550+
db: rt.SqlDatabase.Name(),
551+
table: strings.ToLower(rt.Name()),
552+
tableId: tableScope.tables[strings.ToLower(rt.Name())],
553+
col: strings.ToLower(c.Name),
554+
typ: c.Type,
555+
nullable: c.Nullable,
556+
})
570557
}
571-
return true
572-
})
558+
checks = append(checks, b.loadChecksFromTable(tableScope, rt.Table)...)
559+
}
573560
} else {
574561
transform.Inspect(update, func(n sql.Node) bool {
575562
// todo maybe this should be later stage
@@ -588,35 +575,32 @@ func (b *Builder) buildUpdate(inScope *scope, u *ast.Update) (outScope *scope) {
588575
return
589576
}
590577

591-
// rowUpdatersByTable maps a set of tables to their RowUpdater objects.
592-
func rowUpdatersByTable(ctx *sql.Context, node sql.Node, ij sql.Node) (map[string]sql.RowUpdater, error) {
593-
namesOfTableToBeUpdated := getTablesToBeUpdated(node)
594-
resolvedTables := getTablesByName(ij)
595-
596-
rowUpdatersByTable := make(map[string]sql.RowUpdater)
597-
for tableToBeUpdated, _ := range namesOfTableToBeUpdated {
598-
resolvedTable, ok := resolvedTables[strings.ToLower(tableToBeUpdated)]
599-
if !ok {
600-
return nil, plan.ErrUpdateForTableNotSupported.New(tableToBeUpdated)
578+
// hasJoinNode returns true if |node| or any child is a JoinNode.
579+
func hasJoinNode(node sql.Node) bool {
580+
updateJoinFound := false
581+
transform.Inspect(node, func(n sql.Node) bool {
582+
if _, ok := n.(*plan.JoinNode); ok {
583+
updateJoinFound = true
601584
}
585+
return !updateJoinFound
586+
})
587+
return updateJoinFound
588+
}
602589

603-
var table = resolvedTable.UnderlyingTable()
590+
func getResolvedTablesToUpdate(_ *sql.Context, node sql.Node, ij sql.Node) (resolvedTables []*plan.ResolvedTable, err error) {
591+
namesOfTablesToBeUpdated := getTablesToBeUpdated(node)
592+
resolvedTablesMap := getTablesByName(ij)
604593

605-
// If there is no UpdatableTable for a table being updated, error out
606-
updatable, ok := table.(sql.UpdatableTable)
607-
if !ok && updatable == nil {
594+
for tableToBeUpdated, _ := range namesOfTablesToBeUpdated {
595+
resolvedTable, ok := resolvedTablesMap[strings.ToLower(tableToBeUpdated)]
596+
if !ok {
608597
return nil, plan.ErrUpdateForTableNotSupported.New(tableToBeUpdated)
609598
}
610599

611-
keyless := sql.IsKeyless(updatable.Schema())
612-
if keyless {
613-
return nil, sql.ErrUnsupportedFeature.New("error: keyless tables unsupported for UPDATE JOIN")
614-
}
615-
616-
rowUpdatersByTable[tableToBeUpdated] = updatable.Updater(ctx)
600+
resolvedTables = append(resolvedTables, resolvedTable)
617601
}
618602

619-
return rowUpdatersByTable, nil
603+
return resolvedTables, nil
620604
}
621605

622606
// getTablesByName takes a node and returns all found resolved tables in a map.

sql/rowexec/update.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -258,8 +258,12 @@ func (u *updateJoinIter) Next(ctx *sql.Context) (sql.Row, error) {
258258
if errors.Is(err, sql.ErrKeyNotFound) {
259259
cache.Put(hash, struct{}{})
260260

261-
// updateJoin counts matched rows from join output
262-
u.accumulator.handleRowMatched()
261+
// updateJoin counts matched rows from join output, unless a RETURNING clause
262+
// is in use, in which case there will not be an accumulator assigned, since we
263+
// don't need to return the count of updated rows, just the RETURNING expressions.
264+
if u.accumulator != nil {
265+
u.accumulator.handleRowMatched()
266+
}
263267

264268
continue
265269
} else if err != nil {

0 commit comments

Comments
 (0)