From 00854d552a0d168b330d851e20a2e6a2117391fa Mon Sep 17 00:00:00 2001 From: Bruno Volpato Date: Tue, 27 Feb 2024 13:20:40 -0500 Subject: [PATCH] fix(isthmus): fix rel converter for sort when slot is wrapped --- .../isthmus/SubstraitRelNodeConverter.java | 13 +++- .../isthmus/calcite/RexVisitorFinder.java | 67 +++++++++++++++++++ 2 files changed, 78 insertions(+), 2 deletions(-) create mode 100644 isthmus/src/main/java/io/substrait/isthmus/calcite/RexVisitorFinder.java diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java index 8ff98e6f6..034e3417a 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java @@ -4,6 +4,7 @@ import io.substrait.expression.Expression; import io.substrait.extension.SimpleExtension; +import io.substrait.isthmus.calcite.RexVisitorFinder; import io.substrait.isthmus.expression.AggregateFunctionConverter; import io.substrait.isthmus.expression.ExpressionRexConverter; import io.substrait.isthmus.expression.ScalarFunctionConverter; @@ -358,8 +359,16 @@ private RelFieldCollation toRelFieldCollation(Expression.SortField sortField) { var expression = sortField.expr(); var rex = expression.accept(expressionRexConverter); var sortDirection = sortField.direction(); - RexSlot rexSlot = (RexSlot) rex; - int fieldIndex = rexSlot.getIndex(); + + RexSlot slot = + new RexVisitorFinder<>(RexSlot.class) + .findUnique(rex) + .orElseThrow( + () -> + new RuntimeException( + String.format( + "No slot found in sort field, expression type: %s", rex.getKind()))); + int fieldIndex = slot.getIndex(); var fieldDirection = RelFieldCollation.Direction.ASCENDING; var nullDirection = RelFieldCollation.NullDirection.UNSPECIFIED; switch (sortDirection) { diff --git a/isthmus/src/main/java/io/substrait/isthmus/calcite/RexVisitorFinder.java b/isthmus/src/main/java/io/substrait/isthmus/calcite/RexVisitorFinder.java new file mode 100644 index 000000000..21d893ec5 --- /dev/null +++ b/isthmus/src/main/java/io/substrait/isthmus/calcite/RexVisitorFinder.java @@ -0,0 +1,67 @@ +package io.substrait.isthmus.calcite; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexVisitorImpl; + +/** + * Visitor that finds all instances of a given class in a RexNode tree. + * + * @param Class type to find instances. + */ +public class RexVisitorFinder extends RexVisitorImpl { + final List found; + final Class findClass; + + public RexVisitorFinder(Class findClass) { + super(true); + this.found = new ArrayList<>(); + this.findClass = findClass; + } + + @Override + public void visitEach(Iterable expressions) { + for (RexNode expr : expressions) { + if (findClass.isInstance(expr)) { + found.add(findClass.cast(expr)); + } + } + super.visitEach(expressions); + } + + /** + * Find all instances of the class in the given call. + * + * @param call The call to search + * @return List of instances of the class + */ + public List find(RexNode call) { + if (findClass.isInstance(call)) { + found.add(findClass.cast(call)); + } + call.accept(this); + return found; + } + + /** + * Find a unique instance of the class in the given call. + * + *

Throws an exception if more than one instance is found. + * + * @param call The call to search + * @return Optional of the instance of the class + */ + public Optional findUnique(RexNode call) { + this.find(call); + + if (this.found.isEmpty()) { + return Optional.empty(); + } + if (this.found.size() > 1) { + throw new IllegalStateException("Found more than one instance of " + findClass); + } + return Optional.of(this.found.get(0)); + } +}