From 237179f9f170c30f34ed09a1810db19e3725a5bc Mon Sep 17 00:00:00 2001 From: Dane Pitkin <48041712+danepitkin@users.noreply.github.com> Date: Thu, 16 Nov 2023 16:57:48 -0500 Subject: [PATCH] feat: add MergeJoinRel (#201) --- .../io/substrait/dsl/SubstraitBuilder.java | 67 ++++++++++----- .../relation/AbstractRelVisitor.java | 16 ++-- .../substrait/relation/ProtoRelConverter.java | 43 +++++++++- .../relation/RelCopyOnWriteVisitor.java | 59 +++++++++---- .../substrait/relation/RelProtoConverter.java | 55 +++++++++--- .../io/substrait/relation/RelVisitor.java | 7 +- .../relation/physical/MergeJoin.java | 85 +++++++++++++++++++ .../type/proto/ExtensionRoundtripTest.java | 21 +++++ .../type/proto/JoinRoundtripTest.java | 12 +++ 9 files changed, 304 insertions(+), 61 deletions(-) create mode 100644 core/src/main/java/io/substrait/relation/physical/MergeJoin.java diff --git a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java index 4ebf178b7..35e41787a 100644 --- a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java +++ b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java @@ -28,6 +28,7 @@ import io.substrait.relation.Set; import io.substrait.relation.Sort; import io.substrait.relation.physical.HashJoin; +import io.substrait.relation.physical.MergeJoin; import io.substrait.relation.physical.NestedLoopJoin; import io.substrait.type.ImmutableType; import io.substrait.type.NamedStruct; @@ -201,27 +202,32 @@ public HashJoin hashJoin( .build(); } - public NamedScan namedScan( - Iterable tableName, Iterable columnNames, Iterable types) { - return namedScan(tableName, columnNames, types, Optional.empty()); - } - - public NamedScan namedScan( - Iterable tableName, - Iterable columnNames, - Iterable types, - Rel.Remap remap) { - return namedScan(tableName, columnNames, types, Optional.of(remap)); + public MergeJoin mergeJoin( + List leftKeys, + List rightKeys, + MergeJoin.JoinType joinType, + Rel left, + Rel right) { + return mergeJoin(leftKeys, rightKeys, joinType, Optional.empty(), left, right); } - private NamedScan namedScan( - Iterable tableName, - Iterable columnNames, - Iterable types, - Optional remap) { - var struct = Type.Struct.builder().addAllFields(types).nullable(false).build(); - var namedStruct = NamedStruct.of(columnNames, struct); - return NamedScan.builder().names(tableName).initialSchema(namedStruct).remap(remap).build(); + public MergeJoin mergeJoin( + List leftKeys, + List rightKeys, + MergeJoin.JoinType joinType, + Optional remap, + Rel left, + Rel right) { + return MergeJoin.builder() + .left(left) + .right(right) + .leftKeys( + this.fieldReferences(left, leftKeys.stream().mapToInt(Integer::intValue).toArray())) + .rightKeys( + this.fieldReferences(right, rightKeys.stream().mapToInt(Integer::intValue).toArray())) + .joinType(joinType) + .remap(remap) + .build(); } public NestedLoopJoin nestedLoopJoin( @@ -248,6 +254,29 @@ private NestedLoopJoin nestedLoopJoin( .build(); } + public NamedScan namedScan( + Iterable tableName, Iterable columnNames, Iterable types) { + return namedScan(tableName, columnNames, types, Optional.empty()); + } + + public NamedScan namedScan( + Iterable tableName, + Iterable columnNames, + Iterable types, + Rel.Remap remap) { + return namedScan(tableName, columnNames, types, Optional.of(remap)); + } + + private NamedScan namedScan( + Iterable tableName, + Iterable columnNames, + Iterable types, + Optional remap) { + var struct = Type.Struct.builder().addAllFields(types).nullable(false).build(); + var namedStruct = NamedStruct.of(columnNames, struct); + return NamedScan.builder().names(tableName).initialSchema(namedStruct).remap(remap).build(); + } + public Project project(Function> expressionsFn, Rel input) { return project(expressionsFn, Optional.empty(), input); } diff --git a/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java b/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java index 52a70bf33..b04b90348 100644 --- a/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java +++ b/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java @@ -1,6 +1,7 @@ package io.substrait.relation; import io.substrait.relation.physical.HashJoin; +import io.substrait.relation.physical.MergeJoin; import io.substrait.relation.physical.NestedLoopJoin; public abstract class AbstractRelVisitor @@ -32,11 +33,6 @@ public OUTPUT visit(Join join) throws EXCEPTION { return visitFallback(join); } - @Override - public OUTPUT visit(NestedLoopJoin nestedLoopJoin) throws EXCEPTION { - return visitFallback(nestedLoopJoin); - } - @Override public OUTPUT visit(Set set) throws EXCEPTION { return visitFallback(set); @@ -96,4 +92,14 @@ public OUTPUT visit(ExtensionTable extensionTable) throws EXCEPTION { public OUTPUT visit(HashJoin hashJoin) throws EXCEPTION { return visitFallback(hashJoin); } + + @Override + public OUTPUT visit(MergeJoin mergeJoin) throws EXCEPTION { + return visitFallback(mergeJoin); + } + + @Override + public OUTPUT visit(NestedLoopJoin nestedLoopJoin) throws EXCEPTION { + return visitFallback(nestedLoopJoin); + } } diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java index 05c595419..c72d47c45 100644 --- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java @@ -17,6 +17,7 @@ import io.substrait.proto.FilterRel; import io.substrait.proto.HashJoinRel; import io.substrait.proto.JoinRel; +import io.substrait.proto.MergeJoinRel; import io.substrait.proto.NestedLoopJoinRel; import io.substrait.proto.ProjectRel; import io.substrait.proto.ReadRel; @@ -28,6 +29,7 @@ import io.substrait.relation.files.ImmutableFileFormat; import io.substrait.relation.files.ImmutableFileOrFiles; import io.substrait.relation.physical.HashJoin; +import io.substrait.relation.physical.MergeJoin; import io.substrait.relation.physical.NestedLoopJoin; import io.substrait.type.ImmutableNamedStruct; import io.substrait.type.NamedStruct; @@ -79,9 +81,6 @@ public Rel from(io.substrait.proto.Rel rel) { case JOIN -> { return newJoin(rel.getJoin()); } - case NESTED_LOOP_JOIN -> { - return newNestedLoopJoin(rel.getNestedLoopJoin()); - } case SET -> { return newSet(rel.getSet()); } @@ -103,6 +102,12 @@ public Rel from(io.substrait.proto.Rel rel) { case HASH_JOIN -> { return newHashJoin(rel.getHashJoin()); } + case MERGE_JOIN -> { + return newMergeJoin(rel.getMergeJoin()); + } + case NESTED_LOOP_JOIN -> { + return newNestedLoopJoin(rel.getNestedLoopJoin()); + } default -> { throw new UnsupportedOperationException("Unsupported RelTypeCase of " + relType); } @@ -537,6 +542,38 @@ private Rel newHashJoin(HashJoinRel rel) { return builder.build(); } + private Rel newMergeJoin(MergeJoinRel rel) { + Rel left = from(rel.getLeft()); + Rel right = from(rel.getRight()); + var leftKeys = rel.getLeftKeysList(); + var rightKeys = rel.getRightKeysList(); + + Type.Struct leftStruct = left.getRecordType(); + Type.Struct rightStruct = right.getRecordType(); + Type.Struct unionedStruct = Type.Struct.builder().from(leftStruct).from(rightStruct).build(); + var leftConverter = new ProtoExpressionConverter(lookup, extensions, leftStruct, this); + var rightConverter = new ProtoExpressionConverter(lookup, extensions, rightStruct, this); + var unionConverter = new ProtoExpressionConverter(lookup, extensions, unionedStruct, this); + var builder = + MergeJoin.builder() + .left(left) + .right(right) + .leftKeys(leftKeys.stream().map(leftConverter::from).collect(Collectors.toList())) + .rightKeys(rightKeys.stream().map(rightConverter::from).collect(Collectors.toList())) + .joinType(MergeJoin.JoinType.fromProto(rel.getType())) + .postJoinFilter( + Optional.ofNullable( + rel.hasPostJoinFilter() ? unionConverter.from(rel.getPostJoinFilter()) : null)); + + builder + .commonExtension(optionalAdvancedExtension(rel.getCommon())) + .remap(optionalRelmap(rel.getCommon())); + if (rel.hasAdvancedExtension()) { + builder.extension(advancedExtension(rel.getAdvancedExtension())); + } + return builder.build(); + } + private NestedLoopJoin newNestedLoopJoin(NestedLoopJoinRel rel) { Rel left = from(rel.getLeft()); Rel right = from(rel.getRight()); diff --git a/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java b/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java index defcfe820..f0a7e19a9 100644 --- a/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java +++ b/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java @@ -9,6 +9,7 @@ import io.substrait.expression.FieldReference; import io.substrait.expression.FunctionArg; import io.substrait.relation.physical.HashJoin; +import io.substrait.relation.physical.MergeJoin; import io.substrait.relation.physical.NestedLoopJoin; import java.util.List; import java.util.Optional; @@ -156,24 +157,6 @@ public Optional visit(Join join) throws EXCEPTION { .build()); } - @Override - public Optional visit(NestedLoopJoin nestedLoopJoin) throws EXCEPTION { - var left = nestedLoopJoin.getLeft().accept(this); - var right = nestedLoopJoin.getRight().accept(this); - var condition = nestedLoopJoin.getCondition().accept(getExpressionCopyOnWriteVisitor()); - - if (allEmpty(left, right, condition)) { - return Optional.empty(); - } - return Optional.of( - NestedLoopJoin.builder() - .from(nestedLoopJoin) - .left(left.orElse(nestedLoopJoin.getLeft())) - .right(right.orElse(nestedLoopJoin.getRight())) - .condition(condition.orElse(nestedLoopJoin.getCondition())) - .build()); - } - @Override public Optional visit(Set set) throws EXCEPTION { return transformList(set.getInputs(), t -> t.accept(this)) @@ -319,6 +302,46 @@ public Optional visit(HashJoin hashJoin) throws EXCEPTION { .build()); } + @Override + public Optional visit(MergeJoin mergeJoin) throws EXCEPTION { + var left = mergeJoin.getLeft().accept(this); + var right = mergeJoin.getRight().accept(this); + var leftKeys = transformList(mergeJoin.getLeftKeys(), this::visitFieldReference); + var rightKeys = transformList(mergeJoin.getRightKeys(), this::visitFieldReference); + var postFilter = visitOptionalExpression(mergeJoin.getPostJoinFilter()); + + if (allEmpty(left, right, leftKeys, rightKeys, postFilter)) { + return Optional.empty(); + } + return Optional.of( + MergeJoin.builder() + .from(mergeJoin) + .left(left.orElse(mergeJoin.getLeft())) + .right(right.orElse(mergeJoin.getRight())) + .leftKeys(leftKeys.orElse(mergeJoin.getLeftKeys())) + .rightKeys(rightKeys.orElse(mergeJoin.getRightKeys())) + .postJoinFilter(or(postFilter, mergeJoin::getPostJoinFilter)) + .build()); + } + + @Override + public Optional visit(NestedLoopJoin nestedLoopJoin) throws EXCEPTION { + var left = nestedLoopJoin.getLeft().accept(this); + var right = nestedLoopJoin.getRight().accept(this); + var condition = nestedLoopJoin.getCondition().accept(getExpressionCopyOnWriteVisitor()); + + if (allEmpty(left, right, condition)) { + return Optional.empty(); + } + return Optional.of( + NestedLoopJoin.builder() + .from(nestedLoopJoin) + .left(left.orElse(nestedLoopJoin.getLeft())) + .right(right.orElse(nestedLoopJoin.getRight())) + .condition(condition.orElse(nestedLoopJoin.getCondition())) + .build()); + } + // utilities protected Optional> visitExprList(List exprs) throws EXCEPTION { diff --git a/core/src/main/java/io/substrait/relation/RelProtoConverter.java b/core/src/main/java/io/substrait/relation/RelProtoConverter.java index 2ab4c0527..9ac73eb3f 100644 --- a/core/src/main/java/io/substrait/relation/RelProtoConverter.java +++ b/core/src/main/java/io/substrait/relation/RelProtoConverter.java @@ -15,6 +15,7 @@ import io.substrait.proto.FilterRel; import io.substrait.proto.HashJoinRel; import io.substrait.proto.JoinRel; +import io.substrait.proto.MergeJoinRel; import io.substrait.proto.NestedLoopJoinRel; import io.substrait.proto.ProjectRel; import io.substrait.proto.ReadRel; @@ -25,6 +26,7 @@ import io.substrait.proto.SortRel; import io.substrait.relation.files.FileOrFiles; import io.substrait.relation.physical.HashJoin; +import io.substrait.relation.physical.MergeJoin; import io.substrait.relation.physical.NestedLoopJoin; import io.substrait.type.proto.TypeProtoConverter; import java.util.Collection; @@ -181,20 +183,6 @@ public Rel visit(Join join) throws RuntimeException { return Rel.newBuilder().setJoin(builder).build(); } - @Override - public Rel visit(NestedLoopJoin nestedLoopJoin) throws RuntimeException { - var builder = - NestedLoopJoinRel.newBuilder() - .setCommon(common(nestedLoopJoin)) - .setLeft(toProto(nestedLoopJoin.getLeft())) - .setRight(toProto(nestedLoopJoin.getRight())) - .setExpression(toProto(nestedLoopJoin.getCondition())) - .setType(nestedLoopJoin.getJoinType().toProto()); - - nestedLoopJoin.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto())); - return Rel.newBuilder().setNestedLoopJoin(builder).build(); - } - @Override public Rel visit(Set set) throws RuntimeException { var builder = SetRel.newBuilder().setCommon(common(set)).setOp(set.getSetOp().toProto()); @@ -280,6 +268,45 @@ public Rel visit(HashJoin hashJoin) throws RuntimeException { return Rel.newBuilder().setHashJoin(builder).build(); } + @Override + public Rel visit(MergeJoin mergeJoin) throws RuntimeException { + var builder = + MergeJoinRel.newBuilder() + .setCommon(common(mergeJoin)) + .setLeft(toProto(mergeJoin.getLeft())) + .setRight(toProto(mergeJoin.getRight())) + .setType(mergeJoin.getJoinType().toProto()); + + List leftKeys = mergeJoin.getLeftKeys(); + List rightKeys = mergeJoin.getRightKeys(); + + if (leftKeys.size() != rightKeys.size()) { + throw new RuntimeException("Number of left and right keys must be equal."); + } + + builder.addAllLeftKeys(leftKeys.stream().map(this::toProto).collect(Collectors.toList())); + builder.addAllRightKeys(rightKeys.stream().map(this::toProto).collect(Collectors.toList())); + + mergeJoin.getPostJoinFilter().ifPresent(t -> builder.setPostJoinFilter(toProto(t))); + + mergeJoin.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto())); + return Rel.newBuilder().setMergeJoin(builder).build(); + } + + @Override + public Rel visit(NestedLoopJoin nestedLoopJoin) throws RuntimeException { + var builder = + NestedLoopJoinRel.newBuilder() + .setCommon(common(nestedLoopJoin)) + .setLeft(toProto(nestedLoopJoin.getLeft())) + .setRight(toProto(nestedLoopJoin.getRight())) + .setExpression(toProto(nestedLoopJoin.getCondition())) + .setType(nestedLoopJoin.getJoinType().toProto()); + + nestedLoopJoin.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto())); + return Rel.newBuilder().setNestedLoopJoin(builder).build(); + } + @Override public Rel visit(Project project) throws RuntimeException { var builder = diff --git a/core/src/main/java/io/substrait/relation/RelVisitor.java b/core/src/main/java/io/substrait/relation/RelVisitor.java index 38b70816c..e901986fd 100644 --- a/core/src/main/java/io/substrait/relation/RelVisitor.java +++ b/core/src/main/java/io/substrait/relation/RelVisitor.java @@ -1,6 +1,7 @@ package io.substrait.relation; import io.substrait.relation.physical.HashJoin; +import io.substrait.relation.physical.MergeJoin; import io.substrait.relation.physical.NestedLoopJoin; public interface RelVisitor { @@ -14,8 +15,6 @@ public interface RelVisitor { OUTPUT visit(Join join) throws EXCEPTION; - OUTPUT visit(NestedLoopJoin nestedLoopJoin) throws EXCEPTION; - OUTPUT visit(Set set) throws EXCEPTION; OUTPUT visit(NamedScan namedScan) throws EXCEPTION; @@ -39,4 +38,8 @@ public interface RelVisitor { OUTPUT visit(ExtensionTable extensionTable) throws EXCEPTION; OUTPUT visit(HashJoin hashJoin) throws EXCEPTION; + + OUTPUT visit(MergeJoin mergeJoin) throws EXCEPTION; + + OUTPUT visit(NestedLoopJoin nestedLoopJoin) throws EXCEPTION; } diff --git a/core/src/main/java/io/substrait/relation/physical/MergeJoin.java b/core/src/main/java/io/substrait/relation/physical/MergeJoin.java new file mode 100644 index 000000000..5435a4c2a --- /dev/null +++ b/core/src/main/java/io/substrait/relation/physical/MergeJoin.java @@ -0,0 +1,85 @@ +package io.substrait.relation.physical; + +import io.substrait.expression.Expression; +import io.substrait.expression.FieldReference; +import io.substrait.proto.MergeJoinRel; +import io.substrait.relation.BiRel; +import io.substrait.relation.HasExtension; +import io.substrait.relation.RelVisitor; +import io.substrait.type.Type; +import io.substrait.type.TypeCreator; +import java.util.List; +import java.util.Optional; +import java.util.stream.Stream; +import org.immutables.value.Value; + +@Value.Immutable +public abstract class MergeJoin extends BiRel implements HasExtension { + + public abstract List getLeftKeys(); + + public abstract List getRightKeys(); + + public abstract JoinType getJoinType(); + + public abstract Optional getPostJoinFilter(); + + public static enum JoinType { + UNKNOWN(MergeJoinRel.JoinType.JOIN_TYPE_UNSPECIFIED), + INNER(MergeJoinRel.JoinType.JOIN_TYPE_INNER), + OUTER(MergeJoinRel.JoinType.JOIN_TYPE_OUTER), + LEFT(MergeJoinRel.JoinType.JOIN_TYPE_LEFT), + RIGHT(MergeJoinRel.JoinType.JOIN_TYPE_RIGHT), + LEFT_SEMI(MergeJoinRel.JoinType.JOIN_TYPE_LEFT_SEMI), + RIGHT_SEMI(MergeJoinRel.JoinType.JOIN_TYPE_RIGHT_SEMI), + LEFT_ANTI(MergeJoinRel.JoinType.JOIN_TYPE_LEFT_ANTI), + RIGHT_ANTI(MergeJoinRel.JoinType.JOIN_TYPE_RIGHT_ANTI); + + private MergeJoinRel.JoinType proto; + + JoinType(MergeJoinRel.JoinType proto) { + this.proto = proto; + } + + public static JoinType fromProto(MergeJoinRel.JoinType proto) { + for (var v : values()) { + if (v.proto == proto) { + return v; + } + } + throw new IllegalArgumentException("Unknown type: " + proto); + } + + public MergeJoinRel.JoinType toProto() { + return proto; + } + } + + @Override + protected Type.Struct deriveRecordType() { + Stream leftTypes = + switch (getJoinType()) { + case RIGHT, OUTER -> getLeft().getRecordType().fields().stream() + .map(TypeCreator::asNullable); + case RIGHT_ANTI, RIGHT_SEMI -> Stream.empty(); + default -> getLeft().getRecordType().fields().stream(); + }; + Stream rightTypes = + switch (getJoinType()) { + case LEFT, OUTER -> getRight().getRecordType().fields().stream() + .map(TypeCreator::asNullable); + case LEFT_ANTI, LEFT_SEMI -> Stream.empty(); + default -> getRight().getRecordType().fields().stream(); + }; + return TypeCreator.REQUIRED.struct(Stream.concat(leftTypes, rightTypes)); + } + + @Override + public O accept(RelVisitor visitor) throws E { + return visitor.visit(this); + } + + public static ImmutableMergeJoin.Builder builder() { + return ImmutableMergeJoin.builder(); + } +} diff --git a/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java index b076f03c1..33081c0df 100644 --- a/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java @@ -23,6 +23,7 @@ import io.substrait.relation.Sort; import io.substrait.relation.VirtualTableScan; import io.substrait.relation.physical.HashJoin; +import io.substrait.relation.physical.MergeJoin; import io.substrait.relation.physical.NestedLoopJoin; import io.substrait.relation.utils.StringHolder; import io.substrait.relation.utils.StringHolderHandlingProtoRelConverter; @@ -187,6 +188,26 @@ void hashJoin() { verifyRoundTrip(relWithoutKeys); } + @Test + void mergeJoin() { + // with empty keys + List leftEmptyKeys = Collections.emptyList(); + List rightEmptyKeys = Collections.emptyList(); + Rel relWithoutKeys = + MergeJoin.builder() + .from( + b.mergeJoin( + leftEmptyKeys, + rightEmptyKeys, + MergeJoin.JoinType.INNER, + commonTable, + commonTable)) + .commonExtension(commonExtension) + .extension(relExtension) + .build(); + verifyRoundTrip(relWithoutKeys); + } + @Test void nestedLoopJoin() { Rel rel = diff --git a/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java index 8ae8a7da5..d9f116cb3 100644 --- a/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java @@ -3,6 +3,7 @@ import io.substrait.TestBase; import io.substrait.relation.Rel; import io.substrait.relation.physical.HashJoin; +import io.substrait.relation.physical.MergeJoin; import io.substrait.relation.physical.NestedLoopJoin; import java.util.Arrays; import java.util.List; @@ -33,6 +34,17 @@ void hashJoin() { verifyRoundTrip(relWithoutKeys); } + @Test + void mergeJoin() { + List leftKeys = Arrays.asList(0, 1); + List rightKeys = Arrays.asList(2, 0); + Rel relWithoutKeys = + MergeJoin.builder() + .from(b.mergeJoin(leftKeys, rightKeys, MergeJoin.JoinType.INNER, leftTable, rightTable)) + .build(); + verifyRoundTrip(relWithoutKeys); + } + @Test void nestedLoopJoin() { List inputRels = Arrays.asList(leftTable, rightTable);