Skip to content

Commit 60575b3

Browse files
authored
fix(core): wrong type derivation for ConsistentPartitionWindow (#286)
The output of a ConsistentPartitionWindow consists of: * all input columns * all window expressions The deriveRecordType() has been updated to reflect this
1 parent a0ca17b commit 60575b3

File tree

2 files changed

+80
-2
lines changed

2 files changed

+80
-2
lines changed

core/src/main/java/io/substrait/relation/ConsistentPartitionWindow.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ protected Type.Struct deriveRecordType() {
2929
.struct(
3030
Stream.concat(
3131
initial.fields().stream(),
32-
getPartitionExpressions().stream().map(Expression::getType)));
32+
getWindowFunctions().stream().map(WindowRelFunctionInvocation::outputType)));
3333
}
3434

3535
@Override

core/src/test/java/io/substrait/type/proto/ConsistentPartitionWindowRelRoundtripTest.java

+79-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
public class ConsistentPartitionWindowRelRoundtripTest extends TestBase {
1818

1919
@Test
20-
void consistentPartitionWindowRoundtrip() {
20+
void consistentPartitionWindowRoundtripSingle() {
2121
var windowFunctionDeclaration =
2222
defaultExtensionCollection.getWindowFunction(
2323
SimpleExtension.FunctionAnchor.of(
@@ -63,5 +63,83 @@ void consistentPartitionWindowRoundtrip() {
6363
io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel1);
6464
io.substrait.relation.Rel rel2 = protoRelConverter.from(protoRel);
6565
assertEquals(rel1, rel2);
66+
67+
// Make sure that the record types match I64, I16, I32 and then the I64 from the window
68+
// function.
69+
assertEquals(rel2.getRecordType().fields(), Arrays.asList(R.I64, R.I16, R.I32, R.I64));
70+
}
71+
72+
@Test
73+
void consistentPartitionWindowRoundtripMulti() {
74+
var windowFunctionLeadDeclaration =
75+
defaultExtensionCollection.getWindowFunction(
76+
SimpleExtension.FunctionAnchor.of(
77+
DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, "lead:any"));
78+
var windowFunctionLagDeclaration =
79+
defaultExtensionCollection.getWindowFunction(
80+
SimpleExtension.FunctionAnchor.of(
81+
DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, "lead:any"));
82+
Rel input =
83+
b.namedScan(
84+
Arrays.asList("test"),
85+
Arrays.asList("a", "b", "c"),
86+
Arrays.asList(R.I64, R.I16, R.I32));
87+
Rel rel1 =
88+
ImmutableConsistentPartitionWindow.builder()
89+
.input(input)
90+
.windowFunctions(
91+
Arrays.asList(
92+
ConsistentPartitionWindow.WindowRelFunctionInvocation.builder()
93+
.declaration(windowFunctionLeadDeclaration)
94+
// lead(a)
95+
.arguments(Arrays.asList(b.fieldReference(input, 0)))
96+
.options(
97+
Arrays.asList(
98+
FunctionOption.builder()
99+
.name("option")
100+
.addValues("VALUE1", "VALUE2")
101+
.build()))
102+
.outputType(R.I64)
103+
.aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT)
104+
.invocation(Expression.AggregationInvocation.ALL)
105+
.lowerBound(ImmutableWindowBound.Unbounded.UNBOUNDED)
106+
.upperBound(ImmutableWindowBound.Following.CURRENT_ROW)
107+
.boundsType(Expression.WindowBoundsType.RANGE)
108+
.build(),
109+
ConsistentPartitionWindow.WindowRelFunctionInvocation.builder()
110+
.declaration(windowFunctionLagDeclaration)
111+
// lag(a)
112+
.arguments(Arrays.asList(b.fieldReference(input, 0)))
113+
.options(
114+
Arrays.asList(
115+
FunctionOption.builder()
116+
.name("option")
117+
.addValues("VALUE1", "VALUE2")
118+
.build()))
119+
.outputType(R.I64)
120+
.aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT)
121+
.invocation(Expression.AggregationInvocation.ALL)
122+
.lowerBound(ImmutableWindowBound.Unbounded.UNBOUNDED)
123+
.upperBound(ImmutableWindowBound.Following.CURRENT_ROW)
124+
.boundsType(Expression.WindowBoundsType.RANGE)
125+
.build()))
126+
// PARTITION BY b
127+
.partitionExpressions(Arrays.asList(b.fieldReference(input, 1)))
128+
.sorts(
129+
Arrays.asList(
130+
Expression.SortField.builder()
131+
// SORT BY c
132+
.expr(b.fieldReference(input, 2))
133+
.direction(Expression.SortDirection.ASC_NULLS_FIRST)
134+
.build()))
135+
.build();
136+
137+
io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel1);
138+
io.substrait.relation.Rel rel2 = protoRelConverter.from(protoRel);
139+
assertEquals(rel1, rel2);
140+
141+
// Make sure that the record types match I64, I16, I32 and then the I64 and I64 from the window
142+
// functions.
143+
assertEquals(rel2.getRecordType().fields(), Arrays.asList(R.I64, R.I16, R.I32, R.I64, R.I64));
66144
}
67145
}

0 commit comments

Comments
 (0)