Skip to content

Commit a0ca17b

Browse files
authored
fix(core): output nullability of IfThen depends on all possible outputs
For an IfThen expression, the output should be nullable if any of the possible outputs is nullable
1 parent 32f45fe commit a0ca17b

File tree

2 files changed

+52
-1
lines changed

2 files changed

+52
-1
lines changed

core/src/main/java/io/substrait/expression/Expression.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,13 @@ abstract static class IfThen implements Expression {
584584
public abstract Expression elseClause();
585585

586586
public Type getType() {
587-
return elseClause().getType();
587+
Type elseType = elseClause().getType();
588+
589+
// If any of the clauses are nullable, the whole expression is also nullable.
590+
if (ifClauses().stream().anyMatch(clause -> clause.then().getType().nullable())) {
591+
return TypeCreator.asNullable(elseType);
592+
}
593+
return elseType;
588594
}
589595

590596
public static ImmutableExpression.IfThen.Builder builder() {
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package io.substrait.type.proto;
2+
3+
import static io.substrait.expression.proto.ProtoExpressionConverter.EMPTY_TYPE;
4+
import static org.junit.jupiter.api.Assertions.assertEquals;
5+
import static org.junit.jupiter.api.Assertions.assertFalse;
6+
import static org.junit.jupiter.api.Assertions.assertTrue;
7+
8+
import io.substrait.TestBase;
9+
import io.substrait.expression.Expression;
10+
import io.substrait.expression.ExpressionCreator;
11+
import io.substrait.expression.proto.ExpressionProtoConverter;
12+
import io.substrait.expression.proto.ProtoExpressionConverter;
13+
import java.util.Arrays;
14+
import org.junit.jupiter.api.Test;
15+
16+
public class IfThenRoundtripTest extends TestBase {
17+
18+
@Test
19+
void ifThenNotNullable() {
20+
final Expression.IfThen ifRel =
21+
b.ifThen(
22+
Arrays.asList(
23+
b.ifClause(ExpressionCreator.bool(false, false), ExpressionCreator.i64(false, 1))),
24+
ExpressionCreator.i64(false, 2));
25+
assertFalse(ifRel.getType().nullable());
26+
27+
var to = new ExpressionProtoConverter(null, null);
28+
var from = new ProtoExpressionConverter(null, null, EMPTY_TYPE, protoRelConverter);
29+
assertEquals(ifRel, from.from(ifRel.accept(to)));
30+
}
31+
32+
@Test
33+
void ifThenNullable() {
34+
final Expression.IfThen ifRel =
35+
b.ifThen(
36+
Arrays.asList(
37+
b.ifClause(ExpressionCreator.bool(true, false), ExpressionCreator.i64(true, 1))),
38+
ExpressionCreator.i64(false, 2));
39+
assertTrue(ifRel.getType().nullable());
40+
41+
var to = new ExpressionProtoConverter(null, null);
42+
var from = new ProtoExpressionConverter(null, null, EMPTY_TYPE, protoRelConverter);
43+
assertEquals(ifRel, from.from(ifRel.accept(to)));
44+
}
45+
}

0 commit comments

Comments
 (0)