Skip to content

Commit b2eef4f

Browse files
author
James Cor
committed
works for this case
1 parent dc2d5d6 commit b2eef4f

File tree

5 files changed

+63
-13
lines changed

5 files changed

+63
-13
lines changed

enginetest/memory_engine_test.go

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -202,22 +202,37 @@ func TestSingleQueryPrepared(t *testing.T) {
202202

203203
// Convenience test for debugging a single query. Unskip and set to the desired query.
204204
func TestSingleScript(t *testing.T) {
205-
t.Skip()
205+
//t.Skip()
206206
var scripts = []queries.ScriptTest{
207207
{
208-
Name: "AS OF propagates to nested CALLs",
209-
SetUpScript: []string{},
208+
Dialect: "mysql",
209+
Name: "UPDATE join – multiple tables, with trigger",
210+
SetUpScript: []string{
211+
"create table customers (id int primary key, name text, tier text)",
212+
"create table orders (id int primary key, customer_id int, status text)",
213+
"create table trigger_log (msg text)",
214+
`CREATE TRIGGER after_orders_update after update on orders for each row
215+
begin
216+
insert into trigger_log (msg) values(
217+
concat('Order ', OLD.id, ' status changed from ', OLD.status, ' to ', NEW.status));
218+
end;`,
219+
`Create trigger after_customers_update after update on customers for each row
220+
begin
221+
insert into trigger_log (msg) values(
222+
concat('Customer ', OLD.id, ' tier changed from ', OLD.tier, ' to ', NEW.tier));
223+
end;`,
224+
"insert into customers values(1, 'Alice', 'silver'), (2, 'Bob', 'gold');",
225+
"insert into orders values (101, 1, 'pending'), (102, 2, 'pending');",
226+
"update customers c join orders o on c.id = o.customer_id set c.tier = 'platinum', o.status = 'shipped' where o.status = 'pending'",
227+
},
210228
Assertions: []queries.ScriptTestAssertion{
211229
{
212-
Query: "create procedure create_proc() create table t (i int primary key, j int);",
213-
Expected: []sql.Row{
214-
{types.NewOkResult(0)},
215-
},
216-
},
217-
{
218-
Query: "call create_proc()",
230+
Query: "SELECT * FROM trigger_log order by msg;",
219231
Expected: []sql.Row{
220-
{types.NewOkResult(0)},
232+
{"Customer 1 tier changed from silver to platinum"},
233+
{"Customer 2 tier changed from gold to platinum"},
234+
{"Order 101 status changed from pending to shipped"},
235+
{"Order 102 status changed from pending to shipped"},
221236
},
222237
},
223238
},
@@ -232,8 +247,8 @@ func TestSingleScript(t *testing.T) {
232247
panic(err)
233248
}
234249

235-
//engine.EngineAnalyzer().Debug = true
236-
//engine.EngineAnalyzer().Verbose = true
250+
engine.EngineAnalyzer().Debug = true
251+
engine.EngineAnalyzer().Verbose = true
237252

238253
enginetest.TestScriptWithEngine(t, engine, harness, test)
239254
}

sql/analyzer/aliases.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ func getTableAliases(n sql.Node, scope *plan.Scope) (TableAliases, error) {
167167
var recScope *plan.Scope
168168
if !scope.IsEmpty() {
169169
recScope = recScope.WithMemos(scope.Memos)
170+
recScope.InUpdateJoin = scope.InUpdateJoin
170171
}
171172

172173
aliasFn = func(node sql.Node) bool {
@@ -179,6 +180,9 @@ func getTableAliases(n sql.Node, scope *plan.Scope) (TableAliases, error) {
179180
case *plan.RecursiveCte:
180181
case sql.NameableNode:
181182
analysisErr = passAliases.addUnqualified(at.Name(), t)
183+
if scope != nil && scope.InUpdateJoin {
184+
analysisErr = nil
185+
}
182186
case *plan.UnresolvedTable:
183187
panic("Table not resolved")
184188
default:

sql/analyzer/fix_exec_indexes.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -735,6 +735,11 @@ func fixExprToScope(e sql.Expression, scopes ...*idxScope) sql.Expression {
735735
// don't have the destination schema, and column references in default values are determined in the build phase)
736736

737737
idx, _ := newScope.getIdxId(e.Id(), e.String())
738+
739+
if e.String() == "old.id" {
740+
print()
741+
}
742+
738743
if idx >= 0 {
739744
return e.WithIndex(idx), transform.NewTree, nil
740745
}

sql/analyzer/triggers.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,9 +487,33 @@ func getTriggerLogic(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop
487487
plan.NewSubqueryAlias("new", "", updateSrc.Child),
488488
),
489489
)
490+
491+
updateTargets := n.(*plan.Update).Child.(*plan.UpdateJoin).UpdateTargets
492+
if proj, isProj := updateSrc.Child.(*plan.Project); isProj {
493+
oldExprs := make([]sql.Expression, len(proj.Expressions()))
494+
newExprs := make([]sql.Expression, len(proj.Expressions()))
495+
for i, expr := range proj.Expressions() {
496+
if gf, isGf := expr.(*expression.GetField); isGf {
497+
if tbl, ok := updateTargets[gf.Table()]; ok {
498+
if tbl.(*plan.ResolvedTable).Name() == trigger.Table.(*plan.ResolvedTable).Name() {
499+
oldExprs[i] = gf.WithTable("old")
500+
newExprs[i] = gf.WithTable("new")
501+
continue
502+
}
503+
}
504+
}
505+
oldExprs[i] = expr
506+
newExprs[i] = expr
507+
}
508+
scopeNode.Child = plan.NewCrossJoin(
509+
plan.NewProject(oldExprs, proj.Child),
510+
plan.NewProject(newExprs, proj.Child),
511+
)
512+
}
490513
}
491514
// Triggers are wrapped in prepend nodes, which means that the parent scope is included
492515
s := (*plan.Scope)(nil).NewScope(scopeNode).WithMemos(scope.Memo(n).MemoNodes()).WithProcedureCache(scope.ProcedureCache())
516+
s.InUpdateJoin = true
493517
triggerLogic, _, err = a.analyzeWithSelector(ctx, trigger.Body, s, SelectAllBatches, DefaultRuleSelector, qFlags)
494518
case sqlparser.DeleteStr:
495519
scopeNode := plan.NewProject(

sql/plan/scope.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ type Scope struct {
4646
JoinTrees []string
4747

4848
inInsertSource bool
49+
InUpdateJoin bool
4950
}
5051

5152
func (s *Scope) IsEmpty() bool {
@@ -79,6 +80,7 @@ func (s *Scope) NewScope(node sql.Node) *Scope {
7980
recursionDepth: s.recursionDepth + 1,
8081
Procedures: s.Procedures,
8182
joinSiblings: s.joinSiblings,
83+
InUpdateJoin: s.InUpdateJoin,
8284
}
8385
}
8486

0 commit comments

Comments
 (0)