Skip to content

Commit cd45932

Browse files
authored
Merge pull request #3032 from dolthub/angela/ifnull
Generalize types in `CASE`, `IF`, and `IFNULL`
2 parents c1e5dc8 + a8173e2 commit cd45932

File tree

11 files changed

+299
-110
lines changed

11 files changed

+299
-110
lines changed

enginetest/queries/integration_plans.go

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

enginetest/queries/queries.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6092,7 +6092,7 @@ SELECT * FROM cte WHERE d = 2;`,
60926092
Query: `SELECT if(123 = 123, NULL, NULL = 1)`,
60936093
Expected: []sql.Row{{nil}},
60946094
ExpectedColumns: []*sql.Column{
6095-
{Name: "if(123 = 123, NULL, NULL = 1)", Type: types.Int64}, // TODO: this should be getting coerced to bool
6095+
{Name: "if(123 = 123, NULL, NULL = 1)", Type: types.Boolean},
60966096
},
60976097
},
60986098
{

enginetest/queries/script_queries.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8712,6 +8712,26 @@ where
87128712
},
87138713
},
87148714
},
8715+
{
8716+
Name: "tinyint column does not restrict IF or IFNULL output",
8717+
// https://github.com/dolthub/dolt/issues/9321
8718+
SetUpScript: []string{
8719+
"create table t0 (c0 tinyint);",
8720+
"insert into t0 values (null);",
8721+
},
8722+
Assertions: []ScriptTestAssertion{
8723+
{
8724+
Query: "select ifnull(t0.c0, 128) as ref0 from t0",
8725+
Expected: []sql.Row{
8726+
{128},
8727+
},
8728+
},
8729+
{
8730+
Query: "select if(t0.c0 = 1, t0.c0, 128) as ref0 from t0",
8731+
Expected: []sql.Row{{128}},
8732+
},
8733+
},
8734+
},
87158735
}
87168736

87178737
var SpatialScriptTests = []ScriptTest{

server/handler_test.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ func TestHandlerOutput(t *testing.T) {
212212
})
213213
require.NoError(t, err)
214214
require.Equal(t, 1, len(result.Rows))
215-
require.Equal(t, sqltypes.Int64, result.Rows[0][0].Type())
215+
require.Equal(t, sqltypes.Int16, result.Rows[0][0].Type())
216216
require.Equal(t, []byte("456"), result.Rows[0][0].ToBytes())
217217
})
218218
}
@@ -471,7 +471,8 @@ func TestHandlerComPrepareExecute(t *testing.T) {
471471
},
472472
},
473473
schema: []*query.Field{
474-
{Name: "c1", OrgName: "c1", Table: "test", OrgTable: "test", Database: "test", Type: query.Type_INT32, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 11, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
474+
{Name: "c1", OrgName: "c1", Table: "test", OrgTable: "test", Database: "test", Type: query.Type_INT32,
475+
Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 11, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
475476
},
476477
expected: []sql.Row{
477478
{0}, {1}, {2}, {3}, {4},
@@ -550,7 +551,8 @@ func TestHandlerComPrepareExecuteWithPreparedDisabled(t *testing.T) {
550551
},
551552
},
552553
schema: []*query.Field{
553-
{Name: "c1", OrgName: "c1", Table: "test", OrgTable: "test", Database: "test", Type: query.Type_INT32, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 11, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
554+
{Name: "c1", OrgName: "c1", Table: "test", OrgTable: "test", Database: "test", Type: query.Type_INT32,
555+
Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 11, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
554556
},
555557
expected: []sql.Row{
556558
{0}, {1}, {2}, {3}, {4},
@@ -567,7 +569,8 @@ func TestHandlerComPrepareExecuteWithPreparedDisabled(t *testing.T) {
567569
BindVars: nil,
568570
},
569571
schema: []*query.Field{
570-
{Name: "a", OrgName: "a", Table: "", OrgTable: "", Database: "", Type: query.Type_INT16, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 6, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
572+
{Name: "a", OrgName: "a", Table: "", OrgTable: "", Database: "", Type: query.Type_INT16,
573+
Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 6, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
571574
},
572575
expected: []sql.Row{
573576
{1000},
@@ -584,7 +587,8 @@ func TestHandlerComPrepareExecuteWithPreparedDisabled(t *testing.T) {
584587
BindVars: nil,
585588
},
586589
schema: []*query.Field{
587-
{Name: "a", OrgName: "a", Table: "", OrgTable: "", Database: "", Type: query.Type_INT16, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 6, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
590+
{Name: "a", OrgName: "a", Table: "", OrgTable: "", Database: "", Type: query.Type_INT16,
591+
Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 6, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
588592
},
589593
expected: []sql.Row{
590594
{-129},

sql/expression/case.go

Lines changed: 2 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -43,71 +43,14 @@ func NewCase(expr sql.Expression, branches []CaseBranch, elseExpr sql.Expression
4343
return &Case{expr, branches, elseExpr}
4444
}
4545

46-
// From the description of operator typing here:
47-
// https://dev.mysql.com/doc/refman/8.0/en/flow-control-functions.html#operator_case
48-
func combinedCaseBranchType(left, right sql.Type) sql.Type {
49-
if left == types.Null {
50-
return right
51-
}
52-
if right == types.Null {
53-
return left
54-
}
55-
56-
// Our current implementation of StringType.Convert(enum), does not match MySQL's behavior.
57-
// So, we make sure to return Enums in this particular case.
58-
// More details: https://github.com/dolthub/dolt/issues/8598
59-
if types.IsEnum(left) && types.IsEnum(right) {
60-
return right
61-
}
62-
if types.IsSet(left) && types.IsSet(right) {
63-
return right
64-
}
65-
if types.IsTextOnly(left) && types.IsTextOnly(right) {
66-
return types.LongText
67-
}
68-
if types.IsTextBlob(left) && types.IsTextBlob(right) {
69-
return types.LongBlob
70-
}
71-
if types.IsTime(left) && types.IsTime(right) {
72-
if left == right {
73-
return left
74-
}
75-
return types.DatetimeMaxPrecision
76-
}
77-
if types.IsNumber(left) && types.IsNumber(right) {
78-
if left == types.Float64 || right == types.Float64 {
79-
return types.Float64
80-
}
81-
if left == types.Float32 || right == types.Float32 {
82-
return types.Float32
83-
}
84-
if types.IsDecimal(left) || types.IsDecimal(right) {
85-
return types.MustCreateDecimalType(65, 10)
86-
}
87-
if left == types.Uint64 && types.IsSigned(right) ||
88-
right == types.Uint64 && types.IsSigned(left) {
89-
return types.MustCreateDecimalType(65, 10)
90-
}
91-
if !types.IsSigned(left) && !types.IsSigned(right) {
92-
return types.Uint64
93-
} else {
94-
return types.Int64
95-
}
96-
}
97-
if types.IsJSON(left) && types.IsJSON(right) {
98-
return types.JSON
99-
}
100-
return types.LongText
101-
}
102-
10346
// Type implements the sql.Expression interface.
10447
func (c *Case) Type() sql.Type {
10548
curr := types.Null
10649
for _, b := range c.Branches {
107-
curr = combinedCaseBranchType(curr, b.Value.Type())
50+
curr = types.GeneralizeTypes(curr, b.Value.Type())
10851
}
10952
if c.Else != nil {
110-
curr = combinedCaseBranchType(curr, c.Else.Type())
53+
curr = types.GeneralizeTypes(curr, c.Else.Type())
11154
}
11255
return curr
11356
}

sql/expression/case_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,8 @@ func TestCaseType(t *testing.T) {
161161
}
162162
}
163163

164-
decimalType := types.MustCreateDecimalType(65, 10)
165-
164+
decimalType := types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, types.DecimalTypeMaxScale)
165+
uint64DecimalType := types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, 0)
166166
testCases := []struct {
167167
name string
168168
c *Case
@@ -175,13 +175,13 @@ func TestCaseType(t *testing.T) {
175175
},
176176
{
177177
"unsigned promoted and unsigned",
178-
caseExpr(NewLiteral(uint32(0), types.Uint32), NewLiteral(uint32(1), types.Uint32)),
178+
caseExpr(NewLiteral(uint32(0), types.Uint32), NewLiteral(uint32(1), types.Uint64)),
179179
types.Uint64,
180180
},
181181
{
182182
"signed promoted and signed",
183183
caseExpr(NewLiteral(int8(0), types.Int8), NewLiteral(int32(1), types.Int32)),
184-
types.Int64,
184+
types.Int32,
185185
},
186186
{
187187
"int and float to float",
@@ -216,7 +216,7 @@ func TestCaseType(t *testing.T) {
216216
{
217217
"uint64 and int8 to decimal",
218218
caseExpr(NewLiteral(uint64(10), types.Uint64), NewLiteral(int8(0), types.Int8)),
219-
decimalType,
219+
uint64DecimalType,
220220
},
221221
{
222222
"int and text to text",

sql/expression/function/if.go

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -89,23 +89,15 @@ func (f *If) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
8989
return nil, err
9090
}
9191
}
92-
eval, _, err = f.Type().Convert(ctx, eval)
92+
if ret, _, err := f.Type().Convert(ctx, eval); err == nil {
93+
return ret, nil
94+
}
9395
return eval, err
9496
}
9597

9698
// Type implements the Expression interface.
9799
func (f *If) Type() sql.Type {
98-
// if either type is string type, this should be a string type, regardless need to promote
99-
typ1 := f.ifTrue.Type()
100-
typ2 := f.ifFalse.Type()
101-
if types.IsText(typ1) || types.IsText(typ2) {
102-
return types.Text
103-
}
104-
105-
if typ1 == types.Null {
106-
return typ2.Promote()
107-
}
108-
return typ1.Promote()
100+
return types.GeneralizeTypes(f.ifTrue.Type(), f.ifFalse.Type())
109101
}
110102

111103
// CollationCoercibility implements the interface sql.CollationCoercible.

sql/expression/function/ifnull.go

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -52,30 +52,32 @@ func (f *IfNull) Description() string {
5252

5353
// Eval implements the Expression interface.
5454
func (f *IfNull) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
55+
t := f.Type()
56+
5557
left, err := f.LeftChild.Eval(ctx, row)
5658
if err != nil {
5759
return nil, err
5860
}
5961
if left != nil {
60-
return left, nil
62+
if ret, _, err := t.Convert(ctx, left); err == nil {
63+
return ret, nil
64+
}
65+
return left, err
6166
}
6267

6368
right, err := f.RightChild.Eval(ctx, row)
6469
if err != nil {
6570
return nil, err
6671
}
67-
return right, nil
72+
if ret, _, err := t.Convert(ctx, right); err == nil {
73+
return ret, nil
74+
}
75+
return right, err
6876
}
6977

7078
// Type implements the Expression interface.
7179
func (f *IfNull) Type() sql.Type {
72-
if types.IsNull(f.LeftChild) {
73-
if types.IsNull(f.RightChild) {
74-
return types.Null
75-
}
76-
return f.RightChild.Type()
77-
}
78-
return f.LeftChild.Type()
80+
return types.GeneralizeTypes(f.LeftChild.Type(), f.RightChild.Type())
7981
}
8082

8183
// CollationCoercibility implements the interface sql.CollationCoercible.

sql/expression/function/ifnull_test.go

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,25 +26,28 @@ import (
2626

2727
func TestIfNull(t *testing.T) {
2828
testCases := []struct {
29-
expression interface{}
30-
value interface{}
31-
expected interface{}
29+
expression interface{}
30+
expressionType sql.Type
31+
value interface{}
32+
valueType sql.Type
33+
expected interface{}
34+
expectedType sql.Type
3235
}{
33-
{"foo", "bar", "foo"},
34-
{"foo", "foo", "foo"},
35-
{nil, "foo", "foo"},
36-
{"foo", nil, "foo"},
37-
{nil, nil, nil},
38-
{"", nil, ""},
36+
{"foo", types.LongText, "bar", types.LongText, "foo", types.LongText},
37+
{"foo", types.LongText, "foo", types.LongText, "foo", types.LongText},
38+
{nil, types.LongText, "foo", types.LongText, "foo", types.LongText},
39+
{"foo", types.LongText, nil, types.LongText, "foo", types.LongText},
40+
{nil, types.LongText, nil, types.LongText, nil, types.LongText},
41+
{"", types.LongText, nil, types.LongText, "", types.LongText},
42+
{nil, types.Int8, 128, types.Int64, int64(128), types.Int64},
3943
}
4044

41-
f := NewIfNull(
42-
expression.NewGetField(0, types.LongText, "expression", true),
43-
expression.NewGetField(1, types.LongText, "value", true),
44-
)
45-
require.Equal(t, types.LongText, f.Type())
46-
4745
for _, tc := range testCases {
46+
f := NewIfNull(
47+
expression.NewGetField(0, tc.expressionType, "expression", true),
48+
expression.NewGetField(1, tc.valueType, "value", true),
49+
)
50+
require.Equal(t, tc.expectedType, f.Type())
4851
v, err := f.Eval(sql.NewEmptyContext(), sql.NewRow(tc.expression, tc.value))
4952
require.NoError(t, err)
5053
require.Equal(t, tc.expected, v)

0 commit comments

Comments
 (0)