Skip to content

Commit 67ff12c

Browse files
fix(spark): incorrect deriveRecordType() for Expand
In the Expand relation, the record type was being calculated incorrectly, leading to errors when round-tripping to protobuf and back. Signed-off-by: Andrew Coleman <[email protected]>
1 parent e3139c6 commit 67ff12c

File tree

3 files changed

+17
-8
lines changed

3 files changed

+17
-8
lines changed

core/src/main/java/io/substrait/relation/Expand.java

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,11 @@ public abstract class Expand extends SingleInputRel {
1717
@Override
1818
public Type.Struct deriveRecordType() {
1919
Type.Struct initial = getInput().getRecordType();
20-
return TypeCreator.of(initial.nullable())
21-
.struct(Stream.concat(initial.fields().stream(), Stream.of(TypeCreator.REQUIRED.I64)));
20+
var fields =
21+
getFields().isEmpty()
22+
? initial.fields().stream()
23+
: Stream.concat(initial.fields().stream(), getFields().get(0).getTypes());
24+
return TypeCreator.of(initial.nullable()).struct(fields);
2225
}
2326

2427
@Override
@@ -31,15 +34,15 @@ public static ImmutableExpand.Builder builder() {
3134
}
3235

3336
public interface ExpandField {
34-
Type getType();
37+
Stream<Type> getTypes();
3538
}
3639

3740
@Value.Immutable
3841
public abstract static class ConsistentField implements ExpandField {
3942
public abstract Expression getExpression();
4043

41-
public Type getType() {
42-
return getExpression().getType();
44+
public Stream<Type> getTypes() {
45+
return Stream.of(getExpression().getType());
4346
}
4447

4548
public static ImmutableExpand.ConsistentField.Builder builder() {
@@ -51,8 +54,8 @@ public static ImmutableExpand.ConsistentField.Builder builder() {
5154
public abstract static class SwitchingField implements ExpandField {
5255
public abstract List<Expression> getDuplicates();
5356

54-
public Type getType() {
55-
return getDuplicates().get(0).getType();
57+
public Stream<Type> getTypes() {
58+
return getDuplicates().stream().map(Expression::getType);
5659
}
5760

5861
public static ImmutableExpand.SwitchingField.Builder builder() {

spark/src/main/scala/io/substrait/spark/SparkExtension.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ object SparkExtension {
3434
private val EXTENSION_COLLECTION: SimpleExtension.ExtensionCollection =
3535
SimpleExtension.loadDefaults()
3636

37+
val COLLECTION: SimpleExtension.ExtensionCollection = EXTENSION_COLLECTION.merge(SparkImpls)
38+
3739
lazy val SparkScalarFunctions: Seq[SimpleExtension.ScalarFunctionVariant] = {
3840
val ret = new collection.mutable.ArrayBuffer[SimpleExtension.ScalarFunctionVariant]()
3941
ret.appendAll(EXTENSION_COLLECTION.scalarFunctions().asScala)

spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import io.substrait.debug.TreePrinter
2626
import io.substrait.extension.ExtensionCollector
2727
import io.substrait.plan.{Plan, PlanProtoConverter, ProtoPlanConverter}
2828
import io.substrait.proto
29-
import io.substrait.relation.RelProtoConverter
29+
import io.substrait.relation.{ProtoRelConverter, RelProtoConverter}
3030
import org.scalactic.Equality
3131
import org.scalactic.source.Position
3232
import org.scalatest.Succeeded
@@ -93,6 +93,10 @@ trait SubstraitPlanTestBase { self: SharedSparkSession =>
9393
require(logicalPlan2.resolved);
9494
val pojoRel2 = new ToSubstraitRel().visit(logicalPlan2)
9595

96+
val extensionCollector = new ExtensionCollector;
97+
val proto = new RelProtoConverter(extensionCollector).toProto(pojoRel)
98+
new ProtoRelConverter(extensionCollector, SparkExtension.COLLECTION).from(proto)
99+
96100
pojoRel2.shouldEqualPlainly(pojoRel)
97101
logicalPlan2
98102
}

0 commit comments

Comments
 (0)