diff --git a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java index 35e41787a..7b7c70824 100644 --- a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java +++ b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java @@ -375,6 +375,13 @@ public Expression singleOrList(Expression condition, Expression... options) { return SingleOrList.builder().condition(condition).addOptions(options).build(); } + public Expression.InPredicate inPredicate(Rel haystack, Expression... needles) { + return Expression.InPredicate.builder() + .addAllNeedles(Arrays.asList(needles)) + .haystack(haystack) + .build(); + } + public List sortFields(Rel input, int... indexes) { return Arrays.stream(indexes) .mapToObj( diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java index 036e30d84..fb31f8d22 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java @@ -110,6 +110,7 @@ public SubstraitRelNodeConverter( this.scalarFunctionConverter = scalarFunctionConverter; this.aggregateFunctionConverter = aggregateFunctionConverter; this.expressionRexConverter = expressionRexConverter; + this.expressionRexConverter.setRelNodeConverter(this); } public static RelNode convert( diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java index b8ef20c29..2a121abaf 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java @@ -10,6 +10,7 @@ import io.substrait.expression.FunctionArg; import io.substrait.expression.WindowBound; import io.substrait.extension.SimpleExtension; +import io.substrait.isthmus.SubstraitRelNodeConverter; import io.substrait.isthmus.TypeConverter; import io.substrait.type.StringTypeVisitor; import io.substrait.type.Type; @@ -22,12 +23,14 @@ import java.util.stream.IntStream; import java.util.stream.Stream; import org.apache.calcite.avatica.util.ByteString; +import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexFieldCollation; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexSubQuery; import org.apache.calcite.rex.RexWindowBound; import org.apache.calcite.rex.RexWindowBounds; import org.apache.calcite.sql.SqlAggFunction; @@ -50,6 +53,7 @@ public class ExpressionRexConverter extends AbstractExpressionVisitor needles = + expr.needles().stream().map(e -> e.accept(this)).collect(Collectors.toList()); + RelNode rel = expr.haystack().accept(relNodeConverter); + return RexSubQuery.in(rel, ImmutableList.copyOf(needles)); + } + static class ToRexWindowBound implements WindowBound.WindowBoundVisitor { diff --git a/isthmus/src/test/java/io/substrait/isthmus/ExpressionConvertabilityTest.java b/isthmus/src/test/java/io/substrait/isthmus/ExpressionConvertabilityTest.java index c2ee6e08c..58a39f613 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ExpressionConvertabilityTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ExpressionConvertabilityTest.java @@ -57,6 +57,13 @@ public void mapLiteral() throws IOException, SqlParseException { assertFullRoundTrip("select MAP[1, 'hello'] from ORDERS"); } + @Test + public void inPredicate() throws IOException, SqlParseException { + assertFullRoundTrip( + "select L_PARTKEY from LINEITEM where L_PARTKEY in " + + "(SELECT L_SUPPKEY from LINEITEM where L_SUPPKEY < L_ORDERKEY)"); + } + @Test public void singleOrList() { Expression singleOrList = b.singleOrList(b.fieldReference(commonTable, 0), b.i32(5), b.i32(10));