Skip to content

Commit fc9dc40

Browse files
committed
Add regression test for #1889
1 parent 01dac86 commit fc9dc40

File tree

1 file changed

+145
-0
lines changed

1 file changed

+145
-0
lines changed

enginetest/join_planning_tests.go

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import (
3131
type JoinPlanTest struct {
3232
q string
3333
types []plan.JoinType
34+
indexes []string
3435
exp []sql.Row
3536
order []string
3637
skipOld bool
@@ -918,6 +919,50 @@ join uv d on d.u = c.x`,
918919
},
919920
},
920921
},
922+
{
923+
// This is a regression test for https://github.com/dolthub/go-mysql-server/pull/1889.
924+
// We should always prefer a more specific index over a less specific index for lookups.
925+
name: "lookup join multiple indexes",
926+
setup: []string{
927+
"create table lhs (a int, b int, c int);",
928+
"create table rhs (a int, b int, c int, d int, index a_idx(a), index abcd_idx(a,b,c,d));",
929+
"insert into lhs values (0, 0, 0), (0, 0, 1), (0, 1, 1), (1, 1, 1);",
930+
"insert into rhs values " +
931+
"(0, 0, 0, 0)," +
932+
"(0, 0, 0, 1)," +
933+
"(0, 0, 1, 0)," +
934+
"(0, 0, 1, 1)," +
935+
"(0, 1, 0, 0)," +
936+
"(0, 1, 0, 1)," +
937+
"(0, 1, 1, 0)," +
938+
"(0, 1, 1, 1)," +
939+
"(1, 0, 0, 0)," +
940+
"(1, 0, 0, 1)," +
941+
"(1, 0, 1, 0)," +
942+
"(1, 0, 1, 1)," +
943+
"(1, 1, 0, 0)," +
944+
"(1, 1, 0, 1)," +
945+
"(1, 1, 1, 0)," +
946+
"(1, 1, 1, 1);",
947+
},
948+
tests: []JoinPlanTest{
949+
{
950+
q: "select rhs.* from lhs left join rhs on lhs.a = rhs.a and lhs.b = rhs.b and lhs.c = rhs.c",
951+
types: []plan.JoinType{plan.JoinTypeLeftOuterLookup},
952+
indexes: []string{"abcd_idx"},
953+
exp: []sql.Row{
954+
{0, 0, 0, 0},
955+
{0, 0, 0, 1},
956+
{0, 0, 1, 0},
957+
{0, 0, 1, 1},
958+
{0, 1, 1, 0},
959+
{0, 1, 1, 1},
960+
{1, 1, 1, 0},
961+
{1, 1, 1, 1},
962+
},
963+
},
964+
},
965+
},
921966
}
922967

923968
func TestJoinPlanning(t *testing.T, harness Harness) {
@@ -930,6 +975,9 @@ func TestJoinPlanning(t *testing.T, harness Harness) {
930975
if tt.types != nil {
931976
evalJoinTypeTest(t, harness, e, tt)
932977
}
978+
if tt.indexes != nil {
979+
evalIndexTest(t, harness, e, tt)
980+
}
933981
if tt.exp != nil {
934982
evalJoinCorrectness(t, harness, e, tt.q, tt.q, tt.exp, tt.skipOld)
935983
}
@@ -966,6 +1014,31 @@ func evalJoinTypeTest(t *testing.T, harness Harness, e *sqle.Engine, tt JoinPlan
9661014
})
9671015
}
9681016

1017+
func evalIndexTest(t *testing.T, harness Harness, e *sqle.Engine, tt JoinPlanTest) {
1018+
t.Run(tt.q+" join indexes", func(t *testing.T) {
1019+
if tt.skipOld {
1020+
t.Skip()
1021+
}
1022+
1023+
ctx := NewContext(harness)
1024+
ctx = ctx.WithQuery(tt.q)
1025+
1026+
a, err := e.AnalyzeQuery(ctx, tt.q)
1027+
require.NoError(t, err)
1028+
1029+
idxs := collectIndexes(a)
1030+
var exp []string
1031+
for _, i := range tt.indexes {
1032+
exp = append(exp, i)
1033+
}
1034+
var cmp []string
1035+
for _, i := range idxs {
1036+
cmp = append(cmp, i.ID())
1037+
}
1038+
require.Equal(t, exp, cmp, fmt.Sprintf("unexpected plan:\n%s", sql.DebugString(a)))
1039+
})
1040+
}
1041+
9691042
func evalJoinCorrectness(t *testing.T, harness Harness, e *sqle.Engine, name, q string, exp []sql.Row, skipOld bool) {
9701043
t.Run(name, func(t *testing.T) {
9711044
if vh, ok := harness.(VersionedHarness); (ok && vh.Version() == sql.VersionStable && skipOld) || (!ok && skipOld) {
@@ -1018,6 +1091,35 @@ func collectJoinTypes(n sql.Node) []plan.JoinType {
10181091
return types
10191092
}
10201093

1094+
func collectIndexes(n sql.Node) []sql.Index {
1095+
var indexes []sql.Index
1096+
transform.Inspect(n, func(n sql.Node) bool {
1097+
if n == nil {
1098+
return true
1099+
}
1100+
access, ok := n.(*plan.IndexedTableAccess)
1101+
if ok {
1102+
indexes = append(indexes, access.Index())
1103+
return true
1104+
}
1105+
1106+
if ex, ok := n.(sql.Expressioner); ok {
1107+
for _, e := range ex.Expressions() {
1108+
transform.InspectExpr(e, func(e sql.Expression) bool {
1109+
sq, ok := e.(*plan.Subquery)
1110+
if !ok {
1111+
return false
1112+
}
1113+
indexes = append(indexes, collectIndexes(sq.Query)...)
1114+
return false
1115+
})
1116+
}
1117+
}
1118+
return true
1119+
})
1120+
return indexes
1121+
}
1122+
10211123
func evalJoinOrder(t *testing.T, harness Harness, e *sqle.Engine, q string, exp []string, skipOld bool) {
10221124
t.Run(q+" join order", func(t *testing.T) {
10231125
if vh, ok := harness.(VersionedHarness); (ok && vh.Version() == sql.VersionStable && skipOld) || (!ok && skipOld) {
@@ -1064,6 +1166,9 @@ func TestJoinPlanningPrepared(t *testing.T, harness Harness) {
10641166
if tt.types != nil {
10651167
evalJoinTypeTestPrepared(t, harness, e, tt, tt.skipOld)
10661168
}
1169+
if tt.indexes != nil {
1170+
evalJoinIndexTestPrepared(t, harness, e, tt, tt.skipOld)
1171+
}
10671172
if tt.exp != nil {
10681173
evalJoinCorrectnessPrepared(t, harness, e, tt.q, tt.q, tt.exp, tt.skipOld)
10691174
}
@@ -1115,6 +1220,46 @@ func evalJoinTypeTestPrepared(t *testing.T, harness Harness, e *sqle.Engine, tt
11151220
})
11161221
}
11171222

1223+
func evalJoinIndexTestPrepared(t *testing.T, harness Harness, e *sqle.Engine, tt JoinPlanTest, skipOld bool) {
1224+
t.Run(tt.q+" join indexes", func(t *testing.T) {
1225+
if vh, ok := harness.(VersionedHarness); (ok && vh.Version() == sql.VersionStable && skipOld) || (!ok && skipOld) {
1226+
t.Skip()
1227+
}
1228+
1229+
ctx := NewContext(harness)
1230+
ctx = ctx.WithQuery(tt.q)
1231+
1232+
bindings, err := injectBindVarsAndPrepare(t, ctx, e, tt.q)
1233+
require.NoError(t, err)
1234+
1235+
p, ok := e.PreparedDataCache.GetCachedStmt(ctx.Session.ID(), tt.q)
1236+
require.True(t, ok, "prepared statement not found")
1237+
1238+
if len(bindings) > 0 {
1239+
var usedBindings map[string]bool
1240+
p, usedBindings, err = plan.ApplyBindings(p, bindings)
1241+
require.NoError(t, err)
1242+
for binding := range bindings {
1243+
require.True(t, usedBindings[binding], "unused binding %s", binding)
1244+
}
1245+
}
1246+
1247+
a, _, err := e.Analyzer.AnalyzePrepared(ctx, p, nil)
1248+
require.NoError(t, err)
1249+
1250+
idxs := collectIndexes(a)
1251+
var exp []string
1252+
for _, i := range tt.indexes {
1253+
exp = append(exp, i)
1254+
}
1255+
var cmp []string
1256+
for _, i := range idxs {
1257+
cmp = append(cmp, i.ID())
1258+
}
1259+
require.Equal(t, exp, cmp, fmt.Sprintf("unexpected plan:\n%s", sql.DebugString(a)))
1260+
})
1261+
}
1262+
11181263
func evalJoinCorrectnessPrepared(t *testing.T, harness Harness, e *sqle.Engine, name, q string, exp []sql.Row, skipOld bool) {
11191264
t.Run(q, func(t *testing.T) {
11201265
if vh, ok := harness.(VersionedHarness); (ok && vh.Version() == sql.VersionStable && skipOld) || (!ok && skipOld) {

0 commit comments

Comments
 (0)