|
17 | 17 | public class ConsistentPartitionWindowRelRoundtripTest extends TestBase {
|
18 | 18 |
|
19 | 19 | @Test
|
20 |
| - void consistentPartitionWindowRoundtrip() { |
| 20 | + void consistentPartitionWindowRoundtripSingle() { |
21 | 21 | var windowFunctionDeclaration =
|
22 | 22 | defaultExtensionCollection.getWindowFunction(
|
23 | 23 | SimpleExtension.FunctionAnchor.of(
|
@@ -63,5 +63,83 @@ void consistentPartitionWindowRoundtrip() {
|
63 | 63 | io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel1);
|
64 | 64 | io.substrait.relation.Rel rel2 = protoRelConverter.from(protoRel);
|
65 | 65 | assertEquals(rel1, rel2);
|
| 66 | + |
| 67 | + // Make sure that the record types match I64, I16, I32 and then the I64 from the window |
| 68 | + // function. |
| 69 | + assertEquals(rel2.getRecordType().fields(), Arrays.asList(R.I64, R.I16, R.I32, R.I64)); |
| 70 | + } |
| 71 | + |
| 72 | + @Test |
| 73 | + void consistentPartitionWindowRoundtripMulti() { |
| 74 | + var windowFunctionLeadDeclaration = |
| 75 | + defaultExtensionCollection.getWindowFunction( |
| 76 | + SimpleExtension.FunctionAnchor.of( |
| 77 | + DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, "lead:any")); |
| 78 | + var windowFunctionLagDeclaration = |
| 79 | + defaultExtensionCollection.getWindowFunction( |
| 80 | + SimpleExtension.FunctionAnchor.of( |
| 81 | + DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, "lead:any")); |
| 82 | + Rel input = |
| 83 | + b.namedScan( |
| 84 | + Arrays.asList("test"), |
| 85 | + Arrays.asList("a", "b", "c"), |
| 86 | + Arrays.asList(R.I64, R.I16, R.I32)); |
| 87 | + Rel rel1 = |
| 88 | + ImmutableConsistentPartitionWindow.builder() |
| 89 | + .input(input) |
| 90 | + .windowFunctions( |
| 91 | + Arrays.asList( |
| 92 | + ConsistentPartitionWindow.WindowRelFunctionInvocation.builder() |
| 93 | + .declaration(windowFunctionLeadDeclaration) |
| 94 | + // lead(a) |
| 95 | + .arguments(Arrays.asList(b.fieldReference(input, 0))) |
| 96 | + .options( |
| 97 | + Arrays.asList( |
| 98 | + FunctionOption.builder() |
| 99 | + .name("option") |
| 100 | + .addValues("VALUE1", "VALUE2") |
| 101 | + .build())) |
| 102 | + .outputType(R.I64) |
| 103 | + .aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT) |
| 104 | + .invocation(Expression.AggregationInvocation.ALL) |
| 105 | + .lowerBound(ImmutableWindowBound.Unbounded.UNBOUNDED) |
| 106 | + .upperBound(ImmutableWindowBound.Following.CURRENT_ROW) |
| 107 | + .boundsType(Expression.WindowBoundsType.RANGE) |
| 108 | + .build(), |
| 109 | + ConsistentPartitionWindow.WindowRelFunctionInvocation.builder() |
| 110 | + .declaration(windowFunctionLagDeclaration) |
| 111 | + // lag(a) |
| 112 | + .arguments(Arrays.asList(b.fieldReference(input, 0))) |
| 113 | + .options( |
| 114 | + Arrays.asList( |
| 115 | + FunctionOption.builder() |
| 116 | + .name("option") |
| 117 | + .addValues("VALUE1", "VALUE2") |
| 118 | + .build())) |
| 119 | + .outputType(R.I64) |
| 120 | + .aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT) |
| 121 | + .invocation(Expression.AggregationInvocation.ALL) |
| 122 | + .lowerBound(ImmutableWindowBound.Unbounded.UNBOUNDED) |
| 123 | + .upperBound(ImmutableWindowBound.Following.CURRENT_ROW) |
| 124 | + .boundsType(Expression.WindowBoundsType.RANGE) |
| 125 | + .build())) |
| 126 | + // PARTITION BY b |
| 127 | + .partitionExpressions(Arrays.asList(b.fieldReference(input, 1))) |
| 128 | + .sorts( |
| 129 | + Arrays.asList( |
| 130 | + Expression.SortField.builder() |
| 131 | + // SORT BY c |
| 132 | + .expr(b.fieldReference(input, 2)) |
| 133 | + .direction(Expression.SortDirection.ASC_NULLS_FIRST) |
| 134 | + .build())) |
| 135 | + .build(); |
| 136 | + |
| 137 | + io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel1); |
| 138 | + io.substrait.relation.Rel rel2 = protoRelConverter.from(protoRel); |
| 139 | + assertEquals(rel1, rel2); |
| 140 | + |
| 141 | + // Make sure that the record types match I64, I16, I32 and then the I64 and I64 from the window |
| 142 | + // functions. |
| 143 | + assertEquals(rel2.getRecordType().fields(), Arrays.asList(R.I64, R.I16, R.I32, R.I64, R.I64)); |
66 | 144 | }
|
67 | 145 | }
|
0 commit comments