From 333989ffa70efc18c6707c3cf670bac826925fe1 Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Tue, 24 Oct 2023 12:10:01 -0500 Subject: [PATCH 01/34] feat: convert sql expression into proto extended expressions --- .../extension/ExtensionCollector.java | 49 +++++ .../io/substrait/isthmus/SqlToSubstrait.java | 196 ++++++++++++++++++ .../isthmus/ExtendedExpressionTestBase.java | 51 +++++ .../SimpleExtendedExpressionsTest.java | 13 ++ 4 files changed, 309 insertions(+) create mode 100644 isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java create mode 100644 isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java diff --git a/core/src/main/java/io/substrait/extension/ExtensionCollector.java b/core/src/main/java/io/substrait/extension/ExtensionCollector.java index bcdd969d4..714eaec93 100644 --- a/core/src/main/java/io/substrait/extension/ExtensionCollector.java +++ b/core/src/main/java/io/substrait/extension/ExtensionCollector.java @@ -1,5 +1,6 @@ package io.substrait.extension; +import io.substrait.proto.ExtendedExpression; import io.substrait.proto.Plan; import io.substrait.proto.SimpleExtensionDeclaration; import io.substrait.proto.SimpleExtensionURI; @@ -98,6 +99,54 @@ public void addExtensionsToPlan(Plan.Builder builder) { builder.addAllExtensions(extensionList); } + public void addExtensionsToExtendedExpression(ExtendedExpression.Builder builder) { + var uriPos = new AtomicInteger(1); + var uris = new HashMap(); + + var extensionList = new ArrayList(); + for (var e : funcMap.forwardMap.entrySet()) { + SimpleExtensionURI uri = + uris.computeIfAbsent( + e.getValue().namespace(), + k -> + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(uriPos.getAndIncrement()) + .setUri(k) + .build()); + var decl = + SimpleExtensionDeclaration.newBuilder() + .setExtensionFunction( + SimpleExtensionDeclaration.ExtensionFunction.newBuilder() + .setFunctionAnchor(e.getKey()) + .setName(e.getValue().key()) + .setExtensionUriReference(uri.getExtensionUriAnchor())) + .build(); + extensionList.add(decl); + } + for (var e : typeMap.forwardMap.entrySet()) { + SimpleExtensionURI uri = + uris.computeIfAbsent( + e.getValue().namespace(), + k -> + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(uriPos.getAndIncrement()) + .setUri(k) + .build()); + var decl = + SimpleExtensionDeclaration.newBuilder() + .setExtensionType( + SimpleExtensionDeclaration.ExtensionType.newBuilder() + .setTypeAnchor(e.getKey()) + .setName(e.getValue().key()) + .setExtensionUriReference(uri.getExtensionUriAnchor())) + .build(); + extensionList.add(decl); + } + + builder.addAllExtensionUris(uris.values()); + builder.addAllExtensions(extensionList); + } + /** We don't depend on guava... */ private static class BidiMap { private final Map forwardMap; diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java index 7a850499a..e0f4c0d9b 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java @@ -1,22 +1,44 @@ package io.substrait.isthmus; +import com.github.bsideup.jabel.Desugar; import com.google.common.annotations.VisibleForTesting; import io.substrait.extension.ExtensionCollector; +import io.substrait.proto.Expression; +import io.substrait.proto.Expression.ScalarFunction; +import io.substrait.proto.ExpressionReference; +import io.substrait.proto.ExtendedExpression; +import io.substrait.proto.FunctionArgument; import io.substrait.proto.Plan; import io.substrait.proto.PlanRel; +import io.substrait.proto.SimpleExtensionDeclaration; +import io.substrait.proto.SimpleExtensionURI; import io.substrait.relation.RelProtoConverter; import io.substrait.type.NamedStruct; +import io.substrait.type.Type; +import io.substrait.type.TypeCreator; +import io.substrait.type.proto.TypeProtoConverter; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.function.Function; import org.apache.calcite.plan.hep.HepPlanner; import org.apache.calcite.plan.hep.HepProgram; import org.apache.calcite.prepare.CalciteCatalogReader; import org.apache.calcite.rel.RelRoot; +import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.rex.RexNode; import org.apache.calcite.schema.Schema; import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.parser.SqlParseException; import org.apache.calcite.sql.parser.SqlParser; +import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.validate.SqlValidator; import org.apache.calcite.sql2rel.SqlToRelConverter; import org.apache.calcite.sql2rel.StandardConvertletTable; @@ -48,6 +70,12 @@ public Plan execute(String sql, String name, Schema schema) throws SqlParseExcep return executeInner(sql, factory, pair.left, pair.right); } + public ExtendedExpression executeExpression(String expr, List tables) + throws SqlParseException { + var pair = registerCreateTables(tables); + return executeInnerExpression(expr, pair.left, pair.right); + } + // Package protected for testing List sqlToRelNode(String sql, List tables) throws SqlParseException { var pair = registerCreateTables(tables); @@ -91,6 +119,138 @@ private Plan executeInner( return plan.build(); } + private ExtendedExpression executeInnerExpression( + String sql, SqlValidator validator, CalciteCatalogReader catalogReader) + throws SqlParseException { + ExtendedExpression.Builder extendedExpressionBuilder = ExtendedExpression.newBuilder(); + ExtensionCollector functionCollector = new ExtensionCollector(); + sqlToRexNode(sql, validator, catalogReader) + .forEach( + rexNode -> { + // FIXME! Implement it dynamically for more expression types + ResulTraverseRowExpression result = TraverseRexNode.getRowExpression(rexNode); + + // FIXME! Get output type dynamically: + // final static Map getTypeCreator = new HashMap<>(){{put("BOOLEAN", + // TypeCreator.of(true).BOOLEAN);}}; + // getTypeCreator.get(rexNode.getType()).accept(...) + io.substrait.proto.Type output = + TypeCreator.NULLABLE.BOOLEAN.accept(new TypeProtoConverter(functionCollector)); + + // FIXME! setFunctionReference, addArguments(index: 0, 1) + Expression.Builder expressionBuilder = + Expression.newBuilder() + .setScalarFunction( + ScalarFunction.newBuilder() + .setFunctionReference(1) + .setOutputType(output) + .addArguments( + 0, + FunctionArgument.newBuilder().setValue(result.referenceBuilder())) + .addArguments( + 1, + FunctionArgument.newBuilder() + .setValue(result.expressionBuilderLiteral()))); + ExpressionReference.Builder expressionReferenceBuilder = + ExpressionReference.newBuilder() + .setExpression(expressionBuilder) + .addOutputNames(result.ref().getName()); + + // FIXME! Get schema dynamically + // (as the same for Plan with: + // TypeConverter.DEFAULT.toNamedStruct(rexNode.getType());) + List columnNames = + Arrays.asList("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"); + List dataTypes = + Arrays.asList( + TypeCreator.NULLABLE.I32, + TypeCreator.NULLABLE.STRING, + TypeCreator.NULLABLE.I32, + TypeCreator.NULLABLE.STRING); + NamedStruct namedStruct = + NamedStruct.of( + columnNames, Type.Struct.builder().fields(dataTypes).nullable(false).build()); + + extendedExpressionBuilder + .addReferredExpr(0, expressionReferenceBuilder) + .setBaseSchema(namedStruct.toProto(new TypeProtoConverter(functionCollector))); + + // Extensions URI FIXME! Populate/create this dynamically + HashMap extensionUris = new HashMap<>(); + extensionUris.put( + "key-001", + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(1) + .setUri("/functions_comparison.yaml") + .build()); + + // Extensions FIXME! Populate/create this dynamically, maybe use rexNode.getKind() + ArrayList extensions = new ArrayList<>(); + SimpleExtensionDeclaration extensionFunctionLowerThan = + SimpleExtensionDeclaration.newBuilder() + .setExtensionFunction( + SimpleExtensionDeclaration.ExtensionFunction.newBuilder() + .setFunctionAnchor(1) + .setName("gt:any_any") + .setExtensionUriReference(1)) + .build(); + extensions.add(extensionFunctionLowerThan); + + extendedExpressionBuilder.addAllExtensionUris(extensionUris.values()); + extendedExpressionBuilder.addAllExtensions(extensions); + }); + return extendedExpressionBuilder.build(); + } + + static class TraverseRexNode { + static RexInputRef ref = null; + static Expression.Builder referenceBuilder = null; + static Expression.Builder expressionBuilderLiteral = null; + + static ResulTraverseRowExpression getRowExpression(RexNode rexNode) { + + switch (rexNode.getClass().getSimpleName().toUpperCase()) { + case "REXCALL": + for (RexNode rexInternal : ((RexCall) rexNode).operands) { + getRowExpression(rexInternal); + } + ; + break; + case "REXINPUTREF": + ref = (RexInputRef) rexNode; + referenceBuilder = + Expression.newBuilder() + .setSelection( + Expression.FieldReference.newBuilder() + .setDirectReference( + Expression.ReferenceSegment.newBuilder() + .setStructField( + Expression.ReferenceSegment.StructField.newBuilder() + .setField(ref.getIndex())))); + break; + case "REXLITERAL": + RexLiteral literal = (RexLiteral) rexNode; + expressionBuilderLiteral = + Expression.newBuilder() + .setLiteral( + Expression.Literal.newBuilder().setI32(literal.getValueAs(Integer.class))); + break; + default: + throw new AssertionError( + "Unsupported type for: " + rexNode.getClass().getSimpleName().toUpperCase()); + } + ResulTraverseRowExpression result = + new ResulTraverseRowExpression(ref, referenceBuilder, expressionBuilderLiteral); + return result; + } + } + + @Desugar + private record ResulTraverseRowExpression( + RexInputRef ref, + Expression.Builder referenceBuilder, + Expression.Builder expressionBuilderLiteral) {} + private List sqlToRelNode( String sql, SqlValidator validator, CalciteCatalogReader catalogReader) throws SqlParseException { @@ -107,6 +267,42 @@ private List sqlToRelNode( return roots; } + private List sqlToRexNode( + String sql, SqlValidator validator, CalciteCatalogReader catalogReader) + throws SqlParseException { + SqlParser parser = SqlParser.create(sql, parserConfig); + SqlNode sqlNode = parser.parseExpression(); + Result result = getResult(validator); + SqlNode validSQLNode = + validator.validateParameterizedExpression( + sqlNode, + result.nameToTypeMap()); // FIXME! It may be optional to include this validation + SqlToRelConverter converter = createSqlToRelConverter(validator, catalogReader); + RexNode rexNode = converter.convertExpression(validSQLNode, result.nameToNodeMap()); + + return Collections.singletonList(rexNode); + } + + private static Result getResult(SqlValidator validator) { + // FIXME! Needs to be created dinamycally, this is for PoC purpose + HashMap nameToNodeMap = new HashMap<>(); + nameToNodeMap.put( + "N_NATIONKEY", + new RexInputRef(0, validator.getTypeFactory().createSqlType(SqlTypeName.BIGINT))); + nameToNodeMap.put( + "N_REGIONKEY", + new RexInputRef(1, validator.getTypeFactory().createSqlType(SqlTypeName.BIGINT))); + final Map nameToTypeMap = new HashMap<>(); + for (Map.Entry entry : nameToNodeMap.entrySet()) { + nameToTypeMap.put(entry.getKey(), entry.getValue().getType()); + } + Result result = new Result(nameToNodeMap, nameToTypeMap); + return result; + } + + private @Desugar record Result( + HashMap nameToNodeMap, Map nameToTypeMap) {} + @VisibleForTesting SqlToRelConverter createSqlToRelConverter( SqlValidator validator, CalciteCatalogReader catalogReader) { diff --git a/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java new file mode 100644 index 000000000..10f3f57e3 --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java @@ -0,0 +1,51 @@ +package io.substrait.isthmus; + +import com.google.common.base.Charsets; +import com.google.common.io.Resources; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.util.JsonFormat; +import io.substrait.proto.ExtendedExpression; +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import org.apache.calcite.sql.parser.SqlParseException; + +public class ExtendedExpressionTestBase { + public static String asString(String resource) throws IOException { + return Resources.toString(Resources.getResource(resource), Charsets.UTF_8); + } + + public static List tpchSchemaCreateStatements() throws IOException { + String[] values = asString("tpch/schema.sql").split(";"); + return Arrays.stream(values) + .filter(t -> !t.trim().isBlank()) + .collect(java.util.stream.Collectors.toList()); + } + + protected ExtendedExpression assertProtoExtendedExpressionRoundrip(String query) + throws IOException, SqlParseException { + return assertProtoExtendedExpressionRoundrip(query, new SqlToSubstrait()); + } + + protected ExtendedExpression assertProtoExtendedExpressionRoundrip(String query, SqlToSubstrait s) + throws IOException, SqlParseException { + return assertProtoExtendedExpressionRoundrip(query, s, tpchSchemaCreateStatements()); + } + + protected ExtendedExpression assertProtoExtendedExpressionRoundrip( + String query, SqlToSubstrait s, List creates) throws SqlParseException { + io.substrait.proto.ExtendedExpression protoExtendedExpression = + s.executeExpression(query, creates); + + try { + String ee = JsonFormat.printer().print(protoExtendedExpression); + System.out.println("Proto Extended Expression: \n" + ee); + + // FIXME! Implement test validation as the same as proto Plan implementation + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException(e); + } + + return protoExtendedExpression; + } +} diff --git a/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java new file mode 100644 index 000000000..bfcea38c9 --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java @@ -0,0 +1,13 @@ +package io.substrait.isthmus; + +import java.io.IOException; +import org.apache.calcite.sql.parser.SqlParseException; +import org.junit.jupiter.api.Test; + +public class SimpleExtendedExpressionsTest extends ExtendedExpressionTestBase { + + @Test + public void filter() throws IOException, SqlParseException { + assertProtoExtendedExpressionRoundrip("N_NATIONKEY > 18"); + } +} From f4b6581a6b177654eff632c9c3719fcc83c1b7b3 Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Thu, 26 Oct 2023 11:58:18 -0500 Subject: [PATCH 02/34] fix: implement nameToNodeMap and nameToTypeMap dyamically instead of hard coded --- build.gradle.kts | 9 ++- isthmus/build.gradle.kts | 2 + .../substrait/isthmus/SqlConverterBase.java | 38 +++++++++- .../io/substrait/isthmus/SqlToSubstrait.java | 46 ++++-------- .../io/substrait/isthmus/SubstraitToSql.java | 4 +- .../ExtendedExpressionIntegrationTest.java | 68 ++++++++++++++++++ .../test/resources/tpch/data/nation.parquet | Bin 0 -> 2319 bytes 7 files changed, 129 insertions(+), 38 deletions(-) create mode 100644 isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java create mode 100644 isthmus/src/test/resources/tpch/data/nation.parquet diff --git a/build.gradle.kts b/build.gradle.kts index 47a9da29f..3d711fbd5 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -33,8 +33,13 @@ val submodulesUpdate by } allprojects { - repositories { mavenCentral() } - + repositories { + mavenCentral() + maven { + name = "github" + url = uri("https://nightlies.apache.org/arrow/java") + } + } tasks.configureEach { val javaToolchains = project.extensions.getByType() useJUnitPlatform() diff --git a/isthmus/build.gradle.kts b/isthmus/build.gradle.kts index 9941f51de..a3437d076 100644 --- a/isthmus/build.gradle.kts +++ b/isthmus/build.gradle.kts @@ -94,6 +94,8 @@ dependencies { implementation("org.immutables:value-annotations:2.8.8") annotationProcessor("org.immutables:value:2.8.8") testImplementation("org.apache.calcite:calcite-plus:${CALCITE_VERSION}") + testImplementation("org.apache.arrow:arrow-dataset:14.0.0-SNAPSHOT") + testImplementation("org.apache.arrow:arrow-memory-netty:14.0.0-SNAPSHOT") annotationProcessor("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") compileOnly("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java index ec83bbc82..40a853539 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java @@ -1,11 +1,15 @@ package io.substrait.isthmus; +import com.github.bsideup.jabel.Desugar; import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.calcite.SubstraitOperatorTable; import io.substrait.type.NamedStruct; import java.io.IOException; import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.function.Function; import org.apache.calcite.config.CalciteConnectionConfig; import org.apache.calcite.config.CalciteConnectionProperty; @@ -22,8 +26,12 @@ import org.apache.calcite.rel.metadata.RelMetadataQuery; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexNode; import org.apache.calcite.schema.Schema; +import org.apache.calcite.schema.Schemas; import org.apache.calcite.schema.Table; import org.apache.calcite.schema.impl.AbstractTable; import org.apache.calcite.sql.SqlNode; @@ -86,8 +94,24 @@ protected SqlConverterBase(FeatureBoard features) { EXTENSION_COLLECTION = defaults; } - Pair registerCreateTables(List tables) + /* + HashMap nameToNodeMap = new HashMap<>(); + nameToNodeMap.put( + "N_NATIONKEY", + new RexInputRef(0, validator.getTypeFactory().createSqlType(SqlTypeName.BIGINT))); + nameToNodeMap.put( + "N_REGIONKEY", + new RexInputRef(1, validator.getTypeFactory().createSqlType(SqlTypeName.BIGINT))); + final Map nameToTypeMap = new HashMap<>(); + for (Map.Entry entry : nameToNodeMap.entrySet()) { + nameToTypeMap.put(entry.getKey(), entry.getValue().getType()); + } + */ + + Result registerCreateTables(List tables) throws SqlParseException { + Map nameToTypeMap = new HashMap<>(); + Map nameToNodeMap = new HashMap<>(); CalciteSchema rootSchema = CalciteSchema.createRootSchema(false); CalciteCatalogReader catalogReader = new CalciteCatalogReader(rootSchema, List.of(), factory, config); @@ -97,10 +121,20 @@ Pair registerCreateTables(List table List tList = parseCreateTable(factory, validator, tableDef); for (DefinedTable t : tList) { rootSchema.add(t.getName(), t); + for (RelDataTypeField field : t.type.getFieldList()) { + nameToTypeMap.put(field.getName(), field.getType()); + nameToNodeMap.put(field.getName(), new RexInputRef(field.getIndex(), field.getType())); + } } } } - return Pair.of(validator, catalogReader); + return new Result(validator, catalogReader, nameToTypeMap, nameToNodeMap); + } + + @Desugar + public record Result(SqlValidator validator, CalciteCatalogReader catalogReader, + Map nameToTypeMap, Map nameToNodeMap) { + } Pair registerCreateTables( diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java index e0f4c0d9b..c275dcdaa 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java @@ -61,8 +61,8 @@ public Plan execute(String sql, Function, NamedStruct> tableLookup) } public Plan execute(String sql, List tables) throws SqlParseException { - var pair = registerCreateTables(tables); - return executeInner(sql, factory, pair.left, pair.right); + var result = registerCreateTables(tables); + return executeInner(sql, factory, result.validator(), result.catalogReader()); } public Plan execute(String sql, String name, Schema schema) throws SqlParseException { @@ -72,14 +72,15 @@ public Plan execute(String sql, String name, Schema schema) throws SqlParseExcep public ExtendedExpression executeExpression(String expr, List tables) throws SqlParseException { - var pair = registerCreateTables(tables); - return executeInnerExpression(expr, pair.left, pair.right); + var result = registerCreateTables(tables); + return executeInnerExpression(expr, result.validator(), result.catalogReader(), + result.nameToTypeMap(), result.nameToNodeMap()); } // Package protected for testing List sqlToRelNode(String sql, List tables) throws SqlParseException { - var pair = registerCreateTables(tables); - return sqlToRelNode(sql, pair.left, pair.right); + var result = registerCreateTables(tables); + return sqlToRelNode(sql, result.validator(), result.catalogReader()); } // Package protected for testing @@ -120,11 +121,12 @@ private Plan executeInner( } private ExtendedExpression executeInnerExpression( - String sql, SqlValidator validator, CalciteCatalogReader catalogReader) + String sql, SqlValidator validator, CalciteCatalogReader catalogReader, + Map nameToTypeMap, Map nameToNodeMap) throws SqlParseException { ExtendedExpression.Builder extendedExpressionBuilder = ExtendedExpression.newBuilder(); ExtensionCollector functionCollector = new ExtensionCollector(); - sqlToRexNode(sql, validator, catalogReader) + sqlToRexNode(sql, validator, catalogReader, nameToTypeMap, nameToNodeMap) .forEach( rexNode -> { // FIXME! Implement it dynamically for more expression types @@ -268,41 +270,21 @@ private List sqlToRelNode( } private List sqlToRexNode( - String sql, SqlValidator validator, CalciteCatalogReader catalogReader) + String sql, SqlValidator validator, CalciteCatalogReader catalogReader, + Map nameToTypeMap, Map nameToNodeMap) throws SqlParseException { SqlParser parser = SqlParser.create(sql, parserConfig); SqlNode sqlNode = parser.parseExpression(); - Result result = getResult(validator); SqlNode validSQLNode = validator.validateParameterizedExpression( sqlNode, - result.nameToTypeMap()); // FIXME! It may be optional to include this validation + nameToTypeMap); // FIXME! It may be optional to include this validation SqlToRelConverter converter = createSqlToRelConverter(validator, catalogReader); - RexNode rexNode = converter.convertExpression(validSQLNode, result.nameToNodeMap()); + RexNode rexNode = converter.convertExpression(validSQLNode, nameToNodeMap); return Collections.singletonList(rexNode); } - private static Result getResult(SqlValidator validator) { - // FIXME! Needs to be created dinamycally, this is for PoC purpose - HashMap nameToNodeMap = new HashMap<>(); - nameToNodeMap.put( - "N_NATIONKEY", - new RexInputRef(0, validator.getTypeFactory().createSqlType(SqlTypeName.BIGINT))); - nameToNodeMap.put( - "N_REGIONKEY", - new RexInputRef(1, validator.getTypeFactory().createSqlType(SqlTypeName.BIGINT))); - final Map nameToTypeMap = new HashMap<>(); - for (Map.Entry entry : nameToNodeMap.entrySet()) { - nameToTypeMap.put(entry.getKey(), entry.getValue().getType()); - } - Result result = new Result(nameToNodeMap, nameToTypeMap); - return result; - } - - private @Desugar record Result( - HashMap nameToNodeMap, Map nameToTypeMap) {} - @VisibleForTesting SqlToRelConverter createSqlToRelConverter( SqlValidator validator, CalciteCatalogReader catalogReader) { diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java index f402aacd3..5a18cd27e 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java @@ -22,8 +22,8 @@ public SubstraitToSql() { public RelNode substraitRelToCalciteRel(Rel relRoot, List tables) throws SqlParseException { - var pair = registerCreateTables(tables); - return SubstraitRelNodeConverter.convert(relRoot, relOptCluster, pair.right, parserConfig); + var result = registerCreateTables(tables); + return SubstraitRelNodeConverter.convert(relRoot, relOptCluster, result.catalogReader(), parserConfig); } public RelNode substraitRelToCalciteRel( diff --git a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java new file mode 100644 index 000000000..4e0b29ec6 --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java @@ -0,0 +1,68 @@ +package io.substrait.isthmus.integration; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import com.ibm.icu.impl.ClassLoaderUtil; +import io.substrait.isthmus.ExtendedExpressionTestBase; +import io.substrait.isthmus.SqlToSubstrait; +import io.substrait.proto.ExtendedExpression; +import java.io.IOException; +import java.net.URISyntaxException; +import java.net.URL; +import java.nio.ByteBuffer; +import java.util.Base64; +import java.util.Optional; +import org.apache.arrow.dataset.file.FileFormat; +import org.apache.arrow.dataset.file.FileSystemDatasetFactory; +import org.apache.arrow.dataset.jni.NativeMemoryPool; +import org.apache.arrow.dataset.scanner.ScanOptions; +import org.apache.arrow.dataset.scanner.Scanner; +import org.apache.arrow.dataset.source.Dataset; +import org.apache.arrow.dataset.source.DatasetFactory; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.apache.calcite.sql.parser.SqlParseException; +import org.junit.jupiter.api.Test; + +public class ExtendedExpressionIntegrationTest { + + @Test + public void projectAndFilterDataset() throws SqlParseException, IOException, URISyntaxException { + URL resource = ClassLoaderUtil.getClassLoader().getResource("./tpch/data/nation.parquet"); + ScanOptions options = + new ScanOptions.Builder(/*batchSize*/ 32768) + .columns(Optional.empty()) + .substraitFilter(getSubstraitExpressionFilter()) + .build(); + try (BufferAllocator allocator = new RootAllocator(); + DatasetFactory datasetFactory = + new FileSystemDatasetFactory( + allocator, NativeMemoryPool.getDefault(), FileFormat.PARQUET, resource.toURI().toString()); + Dataset dataset = datasetFactory.finish(); + Scanner scanner = dataset.newScan(options); + ArrowReader reader = scanner.scanBatches()) { + int count = 0; + while (reader.loadNextBatch()) { + count += reader.getVectorSchemaRoot().getRowCount(); + System.out.println(reader.getVectorSchemaRoot().contentToTSVString()); + } + assertEquals(4, count); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private static ByteBuffer getSubstraitExpressionFilter() throws IOException, SqlParseException { + ExtendedExpression extendedExpression = + new SqlToSubstrait() + .executeExpression( + "N_NATIONKEY > 20", ExtendedExpressionTestBase.tpchSchemaCreateStatements()); + byte[] extendedExpressions = + Base64.getDecoder() + .decode(Base64.getEncoder().encodeToString(extendedExpression.toByteArray())); + ByteBuffer substraitExpressionFilter = ByteBuffer.allocateDirect(extendedExpressions.length); + substraitExpressionFilter.put(extendedExpressions); + return substraitExpressionFilter; + } +} diff --git a/isthmus/src/test/resources/tpch/data/nation.parquet b/isthmus/src/test/resources/tpch/data/nation.parquet new file mode 100644 index 0000000000000000000000000000000000000000..0189118ce7344297b1ad6f1095a11162ddd960ea GIT binary patch literal 2319 zcmdT`U5p#m6~5OQZ#?W6cEDrC$}Zfk*6!}sJH*az2w_7yyLOWy>!0=7K~71-_4wMJ z-SLcfev)+p51>k*Nc(^qP^Cy!C{zmSLsb>=5Gqv(EfJNdih}yI;-L>nZFxu~S^-Kc z&g@Tq#9QAg&1aoC_xztT=en#`G7^$L!SJM|ERX}z07A=SA%svC!wB5^Wmr;cIgO|l zbxsTKk&kQnYEdm<{kckQ3ETz+sui_rK1Yse#Ur^=A)0tFwp3NC`K7IHoV}|a_clZ1 z5QoLU$F3-cQ<%dt=0rik2+smJEsmvwUIhr+2j>3ri1d5$E_=Ux;O4tq1>Oe&q(j?n zlPN@}4r}?Q*(WW-q9$pwp6wc*3xvmXkGN(Z&S;w&3!nx9Eqjf*r03XO!(}_ix^6p7 z!)g=HCSHdThqZl^I)uW3Z+Wgl8n)R4_NvcZFiuU|S^;}t8K|~vG#Rm5o2or#ZCct1 z1VWVF?6^Iq8{oZ1^%dHN03d8a8@BHe{Nw?{u`NS~&>T2Y2>+fqon~O3Q5O=m+R@43 zG;s}+S)NoDr;!u=UABX!TfSr1K4eJj)65C43ZCJ0@bYlW^jj^%?1w)a(t{xJ>)dG; zPCn6)eam-!YT|gzZrY~dg=dkGJTv6CArSuK*s5>nE9U5N!LeNzD#^LuF7O>%Ax&( z;`=Baeh!&ae_c>P&UFb}(|s3eq0?pCtnZQjkSQW%W@9gO$FMvo5!BvMU!VBR0k+n2 z7-&$p-|cyJ^1ITPh!rkrSNW`RJ%kJ`9P8Nn3hUz=xju7T=qqDA(ts2@%pBP(WWi$> zMW9>z&xiQ+sY`?fjbfH=jJ_kx!Ku0|t&cTq+Xcs--An!|bB&+P_~@`C^B)OgAdz20 z3O|9;(pN&;Xn8Y7W%@X<4L#9S|0VuW?G*l7@9?LF9O&}N(Hg!Ifr*;AGQY8b<=j4*&n&s zy~g>j@@Xi1sP-~FTQY6<}b`dhn!_FBpn`kEwf9xj`?$|R}igVMvOg^ z_)7UvXgy;*Ec_5el0QAU=ZWJ`!4HNl+p`_Za0@@yj%VKFHR=d3%rrWN6nr^myLOOa z(~l%?A^O@squsrv>zkg@GYu%SBCTaVM-hv*@qzW9AcKt|`o_QcLGcx#2zq!Oxo^vh zR|@-}t=Gkk)f>@`L;1`FL4iZ+juV(?Lx&=ned(#he@0y{O^nA-Nqt8;d-W``)Mc*2 z>WSogoW=t5XsC4Efv8`WMj`paFbFt6dl-4Tn*^L-wD+oU7C}hIe`*g%Z!sAUWFf^_6Dh zl&g1Gmr4;Ng_0q8&^l#N&w$UK^^@O0k^sDcfY)V7sQI=C?2UUxc z;lx4;4TC2MCGN2)h2q)x*BG8uWLD;aOGXa%JX|~DcfA24&=)rKxBYLbZgBCRbE%g? zv-0lky)EzFzJa76?B#7ZoAa~yG3FJpk)008`~HV} g0m}3L5GE8n4il7$pz(nmOlWQn{TlAWGW-?#1~YkXEC2ui literal 0 HcmV?d00001 From a79f57d742277d44b5952188d8446b3c1fd09ff6 Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Thu, 26 Oct 2023 17:29:38 -0500 Subject: [PATCH 03/34] fix: cover support also for project extended expression --- .../substrait/isthmus/SqlConverterBase.java | 14 ++-- .../io/substrait/isthmus/SqlToSubstrait.java | 43 +++++++++--- .../io/substrait/isthmus/SubstraitToSql.java | 3 +- .../ExtendedExpressionIntegrationTest.java | 65 +++++++++++++++++-- 4 files changed, 101 insertions(+), 24 deletions(-) diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java index 40a853539..9448c34ce 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java @@ -6,7 +6,6 @@ import io.substrait.type.NamedStruct; import java.io.IOException; import java.util.ArrayList; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -31,7 +30,6 @@ import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; import org.apache.calcite.schema.Schema; -import org.apache.calcite.schema.Schemas; import org.apache.calcite.schema.Table; import org.apache.calcite.schema.impl.AbstractTable; import org.apache.calcite.sql.SqlNode; @@ -108,8 +106,7 @@ protected SqlConverterBase(FeatureBoard features) { } */ - Result registerCreateTables(List tables) - throws SqlParseException { + Result registerCreateTables(List tables) throws SqlParseException { Map nameToTypeMap = new HashMap<>(); Map nameToNodeMap = new HashMap<>(); CalciteSchema rootSchema = CalciteSchema.createRootSchema(false); @@ -132,10 +129,11 @@ Result registerCreateTables(List tables) } @Desugar - public record Result(SqlValidator validator, CalciteCatalogReader catalogReader, - Map nameToTypeMap, Map nameToNodeMap) { - - } + public record Result( + SqlValidator validator, + CalciteCatalogReader catalogReader, + Map nameToTypeMap, + Map nameToNodeMap) {} Pair registerCreateTables( Function, NamedStruct> tableLookup) throws SqlParseException { diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java index c275dcdaa..baf67514c 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java @@ -3,6 +3,8 @@ import com.github.bsideup.jabel.Desugar; import com.google.common.annotations.VisibleForTesting; import io.substrait.extension.ExtensionCollector; +import io.substrait.isthmus.expression.RexExpressionConverter; +import io.substrait.isthmus.expression.ScalarFunctionConverter; import io.substrait.proto.Expression; import io.substrait.proto.Expression.ScalarFunction; import io.substrait.proto.ExpressionReference; @@ -38,7 +40,6 @@ import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.parser.SqlParseException; import org.apache.calcite.sql.parser.SqlParser; -import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.validate.SqlValidator; import org.apache.calcite.sql2rel.SqlToRelConverter; import org.apache.calcite.sql2rel.StandardConvertletTable; @@ -46,6 +47,12 @@ /** Take a SQL statement and a set of table definitions and return a substrait plan. */ public class SqlToSubstrait extends SqlConverterBase { + private final ScalarFunctionConverter functionConverter = + new ScalarFunctionConverter(EXTENSION_COLLECTION.scalarFunctions(), factory); + + private final RexExpressionConverter rexExpressionConverter = + new RexExpressionConverter(functionConverter); + public SqlToSubstrait() { this(null); } @@ -73,8 +80,12 @@ public Plan execute(String sql, String name, Schema schema) throws SqlParseExcep public ExtendedExpression executeExpression(String expr, List tables) throws SqlParseException { var result = registerCreateTables(tables); - return executeInnerExpression(expr, result.validator(), result.catalogReader(), - result.nameToTypeMap(), result.nameToNodeMap()); + return executeInnerExpression( + expr, + result.validator(), + result.catalogReader(), + result.nameToTypeMap(), + result.nameToNodeMap()); } // Package protected for testing @@ -121,11 +132,15 @@ private Plan executeInner( } private ExtendedExpression executeInnerExpression( - String sql, SqlValidator validator, CalciteCatalogReader catalogReader, - Map nameToTypeMap, Map nameToNodeMap) + String sql, + SqlValidator validator, + CalciteCatalogReader catalogReader, + Map nameToTypeMap, + Map nameToNodeMap) throws SqlParseException { ExtendedExpression.Builder extendedExpressionBuilder = ExtendedExpression.newBuilder(); ExtensionCollector functionCollector = new ExtensionCollector(); + RelProtoConverter relProtoConverter = new RelProtoConverter(functionCollector); sqlToRexNode(sql, validator, catalogReader, nameToTypeMap, nameToNodeMap) .forEach( rexNode -> { @@ -186,6 +201,11 @@ private ExtendedExpression executeInnerExpression( .setUri("/functions_comparison.yaml") .build()); + io.substrait.expression.Expression.ScalarFunctionInvocation func = + (io.substrait.expression.Expression.ScalarFunctionInvocation) + rexNode.accept(rexExpressionConverter); + String declaration = func.declaration().key(); + // Extensions FIXME! Populate/create this dynamically, maybe use rexNode.getKind() ArrayList extensions = new ArrayList<>(); SimpleExtensionDeclaration extensionFunctionLowerThan = @@ -193,7 +213,7 @@ private ExtendedExpression executeInnerExpression( .setExtensionFunction( SimpleExtensionDeclaration.ExtensionFunction.newBuilder() .setFunctionAnchor(1) - .setName("gt:any_any") + .setName(declaration) .setExtensionUriReference(1)) .build(); extensions.add(extensionFunctionLowerThan); @@ -229,6 +249,7 @@ static ResulTraverseRowExpression getRowExpression(RexNode rexNode) { .setStructField( Expression.ReferenceSegment.StructField.newBuilder() .setField(ref.getIndex())))); + break; case "REXLITERAL": RexLiteral literal = (RexLiteral) rexNode; @@ -270,15 +291,17 @@ private List sqlToRelNode( } private List sqlToRexNode( - String sql, SqlValidator validator, CalciteCatalogReader catalogReader, - Map nameToTypeMap, Map nameToNodeMap) + String sql, + SqlValidator validator, + CalciteCatalogReader catalogReader, + Map nameToTypeMap, + Map nameToNodeMap) throws SqlParseException { SqlParser parser = SqlParser.create(sql, parserConfig); SqlNode sqlNode = parser.parseExpression(); SqlNode validSQLNode = validator.validateParameterizedExpression( - sqlNode, - nameToTypeMap); // FIXME! It may be optional to include this validation + sqlNode, nameToTypeMap); // FIXME! It may be optional to include this validation SqlToRelConverter converter = createSqlToRelConverter(validator, catalogReader); RexNode rexNode = converter.convertExpression(validSQLNode, nameToNodeMap); diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java index 5a18cd27e..d43fda1c1 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java @@ -23,7 +23,8 @@ public SubstraitToSql() { public RelNode substraitRelToCalciteRel(Rel relRoot, List tables) throws SqlParseException { var result = registerCreateTables(tables); - return SubstraitRelNodeConverter.convert(relRoot, relOptCluster, result.catalogReader(), parserConfig); + return SubstraitRelNodeConverter.convert( + relRoot, relOptCluster, result.catalogReader(), parserConfig); } public RelNode substraitRelToCalciteRel( diff --git a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java index 4e0b29ec6..b5cae82eb 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java @@ -21,6 +21,7 @@ import org.apache.arrow.dataset.source.DatasetFactory; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; import org.apache.arrow.vector.ipc.ArrowReader; import org.apache.calcite.sql.parser.SqlParseException; import org.junit.jupiter.api.Test; @@ -28,17 +29,21 @@ public class ExtendedExpressionIntegrationTest { @Test - public void projectAndFilterDataset() throws SqlParseException, IOException, URISyntaxException { + public void filterDataset() throws SqlParseException, IOException, URISyntaxException { URL resource = ClassLoaderUtil.getClassLoader().getResource("./tpch/data/nation.parquet"); + String sqlExpression = "N_NATIONKEY > 20"; ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768) .columns(Optional.empty()) - .substraitFilter(getSubstraitExpressionFilter()) + .substraitFilter(getFilterExtendedExpression(sqlExpression)) .build(); try (BufferAllocator allocator = new RootAllocator(); DatasetFactory datasetFactory = new FileSystemDatasetFactory( - allocator, NativeMemoryPool.getDefault(), FileFormat.PARQUET, resource.toURI().toString()); + allocator, + NativeMemoryPool.getDefault(), + FileFormat.PARQUET, + resource.toURI().toString()); Dataset dataset = datasetFactory.finish(); Scanner scanner = dataset.newScan(options); ArrowReader reader = scanner.scanBatches()) { @@ -53,11 +58,47 @@ public void projectAndFilterDataset() throws SqlParseException, IOException, URI } } - private static ByteBuffer getSubstraitExpressionFilter() throws IOException, SqlParseException { + @Test + public void projectDataset() throws SqlParseException, IOException, URISyntaxException { + URL resource = ClassLoaderUtil.getClassLoader().getResource("./tpch/data/nation.parquet"); + String sqlExpression = "N_NATIONKEY + 20"; + ScanOptions options = + new ScanOptions.Builder(/*batchSize*/ 32768) + .columns(Optional.empty()) + .substraitProjection(getProjectExtendedExpression(sqlExpression)) + .build(); + try (BufferAllocator allocator = new RootAllocator(); + DatasetFactory datasetFactory = + new FileSystemDatasetFactory( + allocator, + NativeMemoryPool.getDefault(), + FileFormat.PARQUET, + resource.toURI().toString()); + Dataset dataset = datasetFactory.finish(); + Scanner scanner = dataset.newScan(options); + ArrowReader reader = scanner.scanBatches()) { + int count = 0; + int sum = 0; + while (reader.loadNextBatch()) { + count += reader.getVectorSchemaRoot().getRowCount(); + IntVector intVector = (IntVector) reader.getVectorSchemaRoot().getVector(0); + for (int i = 0; i < intVector.getValueCount(); i++) { + sum += intVector.get(i); + } + } + assertEquals(25, count); + assertEquals(24 * 25 / 2 + 20 * count, sum); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private static ByteBuffer getFilterExtendedExpression(String sqlExpression) + throws IOException, SqlParseException { ExtendedExpression extendedExpression = new SqlToSubstrait() .executeExpression( - "N_NATIONKEY > 20", ExtendedExpressionTestBase.tpchSchemaCreateStatements()); + sqlExpression, ExtendedExpressionTestBase.tpchSchemaCreateStatements()); byte[] extendedExpressions = Base64.getDecoder() .decode(Base64.getEncoder().encodeToString(extendedExpression.toByteArray())); @@ -65,4 +106,18 @@ private static ByteBuffer getSubstraitExpressionFilter() throws IOException, Sql substraitExpressionFilter.put(extendedExpressions); return substraitExpressionFilter; } + + private static ByteBuffer getProjectExtendedExpression(String sqlExpression) + throws IOException, SqlParseException { + ExtendedExpression extendedExpression = + new SqlToSubstrait() + .executeExpression( + sqlExpression, ExtendedExpressionTestBase.tpchSchemaCreateStatements()); + byte[] extendedExpressions = + Base64.getDecoder() + .decode(Base64.getEncoder().encodeToString(extendedExpression.toByteArray())); + ByteBuffer substraitExpressionProject = ByteBuffer.allocateDirect(extendedExpressions.length); + substraitExpressionProject.put(extendedExpressions); + return substraitExpressionProject; + } } From a37be9224dc11ba9b08b1a428ffeccc9143b3d39 Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Thu, 26 Oct 2023 23:15:54 -0500 Subject: [PATCH 04/34] fix: cover support also for project extended expression --- .editorconfig | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.editorconfig b/.editorconfig index 3d674d593..984db0c67 100644 --- a/.editorconfig +++ b/.editorconfig @@ -10,7 +10,7 @@ trim_trailing_whitespace = true [*.{yaml,yml}] indent_size = 2 -[{**/*.sql,**/OuterReferenceResolver.md,gradlew.bat}] +[{**/*.sql,**/OuterReferenceResolver.md,gradlew.bat,**/*.parquet}] charset = unset end_of_line = unset insert_final_newline = unset From 9f6aaf3eb5dce49239c9b15d3d71b5ade7d7f413 Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Wed, 15 Nov 2023 10:16:52 -0500 Subject: [PATCH 05/34] fix: create schema dynamically --- build.gradle.kts | 8 +----- isthmus/build.gradle.kts | 4 +-- .../substrait/isthmus/SqlConverterBase.java | 7 ++---- .../io/substrait/isthmus/SqlToSubstrait.java | 25 ++++--------------- .../io/substrait/isthmus/TypeConverter.java | 13 ++++++++++ .../ExtendedExpressionIntegrationTest.java | 8 ++++++ 6 files changed, 31 insertions(+), 34 deletions(-) diff --git a/build.gradle.kts b/build.gradle.kts index 3d711fbd5..293163045 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -33,13 +33,7 @@ val submodulesUpdate by } allprojects { - repositories { - mavenCentral() - maven { - name = "github" - url = uri("https://nightlies.apache.org/arrow/java") - } - } + repositories { mavenCentral() } tasks.configureEach { val javaToolchains = project.extensions.getByType() useJUnitPlatform() diff --git a/isthmus/build.gradle.kts b/isthmus/build.gradle.kts index a3437d076..a5ae3abd2 100644 --- a/isthmus/build.gradle.kts +++ b/isthmus/build.gradle.kts @@ -94,8 +94,8 @@ dependencies { implementation("org.immutables:value-annotations:2.8.8") annotationProcessor("org.immutables:value:2.8.8") testImplementation("org.apache.calcite:calcite-plus:${CALCITE_VERSION}") - testImplementation("org.apache.arrow:arrow-dataset:14.0.0-SNAPSHOT") - testImplementation("org.apache.arrow:arrow-memory-netty:14.0.0-SNAPSHOT") + testImplementation("org.apache.arrow:arrow-dataset:14.0.0") + testImplementation("org.apache.arrow:arrow-memory-netty:14.0.0") annotationProcessor("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") compileOnly("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java index 9448c34ce..3466cf826 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java @@ -5,10 +5,7 @@ import io.substrait.isthmus.calcite.SubstraitOperatorTable; import io.substrait.type.NamedStruct; import java.io.IOException; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.util.*; import java.util.function.Function; import org.apache.calcite.config.CalciteConnectionConfig; import org.apache.calcite.config.CalciteConnectionProperty; @@ -107,7 +104,7 @@ protected SqlConverterBase(FeatureBoard features) { */ Result registerCreateTables(List tables) throws SqlParseException { - Map nameToTypeMap = new HashMap<>(); + Map nameToTypeMap = new LinkedHashMap<>(); Map nameToNodeMap = new HashMap<>(); CalciteSchema rootSchema = CalciteSchema.createRootSchema(false); CalciteCatalogReader catalogReader = diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java index baf67514c..70ff4c087 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java @@ -16,11 +16,9 @@ import io.substrait.proto.SimpleExtensionURI; import io.substrait.relation.RelProtoConverter; import io.substrait.type.NamedStruct; -import io.substrait.type.Type; import io.substrait.type.TypeCreator; import io.substrait.type.proto.TypeProtoConverter; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -173,24 +171,7 @@ private ExtendedExpression executeInnerExpression( .setExpression(expressionBuilder) .addOutputNames(result.ref().getName()); - // FIXME! Get schema dynamically - // (as the same for Plan with: - // TypeConverter.DEFAULT.toNamedStruct(rexNode.getType());) - List columnNames = - Arrays.asList("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"); - List dataTypes = - Arrays.asList( - TypeCreator.NULLABLE.I32, - TypeCreator.NULLABLE.STRING, - TypeCreator.NULLABLE.I32, - TypeCreator.NULLABLE.STRING); - NamedStruct namedStruct = - NamedStruct.of( - columnNames, Type.Struct.builder().fields(dataTypes).nullable(false).build()); - - extendedExpressionBuilder - .addReferredExpr(0, expressionReferenceBuilder) - .setBaseSchema(namedStruct.toProto(new TypeProtoConverter(functionCollector))); + extendedExpressionBuilder.addReferredExpr(0, expressionReferenceBuilder); // Extensions URI FIXME! Populate/create this dynamically HashMap extensionUris = new HashMap<>(); @@ -221,6 +202,10 @@ private ExtendedExpression executeInnerExpression( extendedExpressionBuilder.addAllExtensionUris(extensionUris.values()); extendedExpressionBuilder.addAllExtensions(extensions); }); + NamedStruct namedStruct = TypeConverter.DEFAULT.toNamedStruct(nameToTypeMap); + extendedExpressionBuilder.setBaseSchema( + namedStruct.toProto(new TypeProtoConverter(functionCollector))); + return extendedExpressionBuilder.build(); } diff --git a/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java index ba68d5cfc..73c846d45 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java @@ -3,6 +3,7 @@ import static io.substrait.isthmus.SubstraitTypeSystem.DAY_SECOND_INTERVAL; import static io.substrait.isthmus.SubstraitTypeSystem.YEAR_MONTH_INTERVAL; +import com.google.common.collect.Lists; import io.substrait.function.NullableType; import io.substrait.function.TypeExpression; import io.substrait.type.NamedStruct; @@ -11,6 +12,7 @@ import io.substrait.type.TypeVisitor; import java.util.ArrayList; import java.util.List; +import java.util.Map; import javax.annotation.Nullable; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; @@ -56,6 +58,17 @@ public NamedStruct toNamedStruct(RelDataType type) { return NamedStruct.of(names, struct); } + public NamedStruct toNamedStruct(Map nameToTypeMap) { + var names = Lists.newArrayList(); + var types = Lists.newArrayList(); + nameToTypeMap.forEach( + (k, v) -> { + names.add(k); + types.add(toSubstrait(v, names)); + }); + return NamedStruct.of(names, Type.Struct.builder().fields(types).nullable(false).build()); + } + private Type toSubstrait(RelDataType type, List names) { // Check for user mapped types first as they may re-use SqlTypeNames var userType = userTypeMapper.toSubstrait(type); diff --git a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java index b5cae82eb..267b2525b 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java @@ -2,6 +2,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; +import com.google.protobuf.util.JsonFormat; import com.ibm.icu.impl.ClassLoaderUtil; import io.substrait.isthmus.ExtendedExpressionTestBase; import io.substrait.isthmus.SqlToSubstrait; @@ -85,6 +86,7 @@ public void projectDataset() throws SqlParseException, IOException, URISyntaxExc for (int i = 0; i < intVector.getValueCount(); i++) { sum += intVector.get(i); } + System.out.println(reader.getVectorSchemaRoot().contentToTSVString()); } assertEquals(25, count); assertEquals(24 * 25 / 2 + 20 * count, sum); @@ -99,6 +101,9 @@ private static ByteBuffer getFilterExtendedExpression(String sqlExpression) new SqlToSubstrait() .executeExpression( sqlExpression, ExtendedExpressionTestBase.tpchSchemaCreateStatements()); + System.out.println( + "JsonFormat.printer().print(getFilterExtendedExpression): " + + JsonFormat.printer().print(extendedExpression)); byte[] extendedExpressions = Base64.getDecoder() .decode(Base64.getEncoder().encodeToString(extendedExpression.toByteArray())); @@ -113,6 +118,9 @@ private static ByteBuffer getProjectExtendedExpression(String sqlExpression) new SqlToSubstrait() .executeExpression( sqlExpression, ExtendedExpressionTestBase.tpchSchemaCreateStatements()); + System.out.println( + "JsonFormat.printer().print(getProjectExtendedExpression): " + + JsonFormat.printer().print(extendedExpression)); byte[] extendedExpressions = Base64.getDecoder() .decode(Base64.getEncoder().encodeToString(extendedExpression.toByteArray())); From 52b41e341940e7ade458560c4befbd5408de054d Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Thu, 16 Nov 2023 16:37:27 -0500 Subject: [PATCH 06/34] fix: set function reference and extensions dinamically --- .pre-commit-config.yaml | 4 +- .../io/substrait/isthmus/SqlToSubstrait.java | 120 +++++++++++------- .../ExtendedExpressionIntegrationTest.java | 8 +- 3 files changed, 77 insertions(+), 55 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 102b99017..a2505b4ca 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,11 +1,11 @@ repos: - repo: https://github.com/adrienverge/yamllint.git - rev: v1.26.0 + rev: v1.33.0 hooks: - id: yamllint args: [-c=.yamllint.yaml] - repo: https://github.com/alessandrojcm/commitlint-pre-commit-hook - rev: v8.0.0 + rev: v9.9.0 hooks: - id: commitlint stages: [commit-msg] diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java index 70ff4c087..37d088a3e 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java @@ -3,6 +3,7 @@ import com.github.bsideup.jabel.Desugar; import com.google.common.annotations.VisibleForTesting; import io.substrait.extension.ExtensionCollector; +import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.expression.RexExpressionConverter; import io.substrait.isthmus.expression.ScalarFunctionConverter; import io.substrait.proto.Expression; @@ -18,11 +19,8 @@ import io.substrait.type.NamedStruct; import io.substrait.type.TypeCreator; import io.substrait.type.proto.TypeProtoConverter; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.io.IOException; +import java.util.*; import java.util.function.Function; import org.apache.calcite.plan.hep.HepPlanner; import org.apache.calcite.plan.hep.HepProgram; @@ -138,34 +136,33 @@ private ExtendedExpression executeInnerExpression( throws SqlParseException { ExtendedExpression.Builder extendedExpressionBuilder = ExtendedExpression.newBuilder(); ExtensionCollector functionCollector = new ExtensionCollector(); - RelProtoConverter relProtoConverter = new RelProtoConverter(functionCollector); sqlToRexNode(sql, validator, catalogReader, nameToTypeMap, nameToNodeMap) .forEach( rexNode -> { // FIXME! Implement it dynamically for more expression types - ResulTraverseRowExpression result = TraverseRexNode.getRowExpression(rexNode); + // ResulTraverseRowExpression result = TraverseRexNode.getRowExpression(rexNode); + ResulTraverseRowExpression result = new TraverseRexNode().getRowExpression(rexNode); - // FIXME! Get output type dynamically: - // final static Map getTypeCreator = new HashMap<>(){{put("BOOLEAN", - // TypeCreator.of(true).BOOLEAN);}}; - // getTypeCreator.get(rexNode.getType()).accept(...) io.substrait.proto.Type output = TypeCreator.NULLABLE.BOOLEAN.accept(new TypeProtoConverter(functionCollector)); - // FIXME! setFunctionReference, addArguments(index: 0, 1) + List functionArgumentList = new ArrayList<>(); + result + .expressionBuilderMap() + .forEach( + (k, v) -> { + functionArgumentList.add(FunctionArgument.newBuilder().setValue(v).build()); + }); + + ScalarFunction.Builder scalarFunctionBuilder = + ScalarFunction.newBuilder() + .setFunctionReference(1) // rel_01 + .setOutputType(output) + .addAllArguments(functionArgumentList); + Expression.Builder expressionBuilder = - Expression.newBuilder() - .setScalarFunction( - ScalarFunction.newBuilder() - .setFunctionReference(1) - .setOutputType(output) - .addArguments( - 0, - FunctionArgument.newBuilder().setValue(result.referenceBuilder())) - .addArguments( - 1, - FunctionArgument.newBuilder() - .setValue(result.expressionBuilderLiteral()))); + Expression.newBuilder().setScalarFunction(scalarFunctionBuilder); + ExpressionReference.Builder expressionReferenceBuilder = ExpressionReference.newBuilder() .setExpression(expressionBuilder) @@ -173,32 +170,53 @@ private ExtendedExpression executeInnerExpression( extendedExpressionBuilder.addReferredExpr(0, expressionReferenceBuilder); - // Extensions URI FIXME! Populate/create this dynamically - HashMap extensionUris = new HashMap<>(); - extensionUris.put( - "key-001", - SimpleExtensionURI.newBuilder() - .setExtensionUriAnchor(1) - .setUri("/functions_comparison.yaml") - .build()); - io.substrait.expression.Expression.ScalarFunctionInvocation func = (io.substrait.expression.Expression.ScalarFunctionInvocation) rexNode.accept(rexExpressionConverter); - String declaration = func.declaration().key(); + String declaration = + func.declaration().key(); // values example: gt:any_any, add:i64_i64 + + // this is not mandatory to be defined; it is working without this definition. It is + // only created here to create a proto message that has the correct semantics + HashMap extensionUris = new HashMap<>(); + SimpleExtensionURI simpleExtensionURI; + try { + simpleExtensionURI = + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(1) // rel_02 + .setUri( + SimpleExtension.loadDefaults().scalarFunctions().stream() + .filter(s -> s.toString().equalsIgnoreCase(declaration)) + .findFirst() + .orElseThrow( + () -> + new IllegalArgumentException( + String.format( + "Failed to get URI resource for %s.", declaration))) + .uri()) + .build(); + } catch (IOException e) { + throw new RuntimeException(e); + } + extensionUris.put("uri", simpleExtensionURI); - // Extensions FIXME! Populate/create this dynamically, maybe use rexNode.getKind() ArrayList extensions = new ArrayList<>(); SimpleExtensionDeclaration extensionFunctionLowerThan = SimpleExtensionDeclaration.newBuilder() .setExtensionFunction( SimpleExtensionDeclaration.ExtensionFunction.newBuilder() - .setFunctionAnchor(1) + .setFunctionAnchor( + scalarFunctionBuilder.getFunctionReference()) // rel_01 .setName(declaration) - .setExtensionUriReference(1)) + .setExtensionUriReference( + simpleExtensionURI.getExtensionUriAnchor())) // rel_02 .build(); extensions.add(extensionFunctionLowerThan); + System.out.println( + "extendedExpressionBuilder.getExtensionUrisList(): " + + extendedExpressionBuilder.getExtensionUrisList()); + // adding it for semantic purposes, it is not mandatory or needed extendedExpressionBuilder.addAllExtensionUris(extensionUris.values()); extendedExpressionBuilder.addAllExtensions(extensions); }); @@ -209,13 +227,14 @@ private ExtendedExpression executeInnerExpression( return extendedExpressionBuilder.build(); } - static class TraverseRexNode { - static RexInputRef ref = null; - static Expression.Builder referenceBuilder = null; - static Expression.Builder expressionBuilderLiteral = null; - - static ResulTraverseRowExpression getRowExpression(RexNode rexNode) { + class TraverseRexNode { + RexInputRef ref = null; + int control = 0; + Expression.Builder referenceBuilder = null; + Expression.Builder literalBuilder = null; + Map expressionBuilderMap = new LinkedHashMap<>(); + ResulTraverseRowExpression getRowExpression(RexNode rexNode) { switch (rexNode.getClass().getSimpleName().toUpperCase()) { case "REXCALL": for (RexNode rexInternal : ((RexCall) rexNode).operands) { @@ -234,22 +253,24 @@ static ResulTraverseRowExpression getRowExpression(RexNode rexNode) { .setStructField( Expression.ReferenceSegment.StructField.newBuilder() .setField(ref.getIndex())))); - + expressionBuilderMap.put(control, referenceBuilder); + control++; break; case "REXLITERAL": RexLiteral literal = (RexLiteral) rexNode; - expressionBuilderLiteral = + literalBuilder = Expression.newBuilder() .setLiteral( Expression.Literal.newBuilder().setI32(literal.getValueAs(Integer.class))); + expressionBuilderMap.put(control, literalBuilder); + control++; break; default: throw new AssertionError( "Unsupported type for: " + rexNode.getClass().getSimpleName().toUpperCase()); } - ResulTraverseRowExpression result = - new ResulTraverseRowExpression(ref, referenceBuilder, expressionBuilderLiteral); - return result; + return new ResulTraverseRowExpression( + ref, referenceBuilder, literalBuilder, expressionBuilderMap); } } @@ -257,7 +278,8 @@ static ResulTraverseRowExpression getRowExpression(RexNode rexNode) { private record ResulTraverseRowExpression( RexInputRef ref, Expression.Builder referenceBuilder, - Expression.Builder expressionBuilderLiteral) {} + Expression.Builder literalBuilder, + Map expressionBuilderMap) {} private List sqlToRelNode( String sql, SqlValidator validator, CalciteCatalogReader catalogReader) diff --git a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java index 267b2525b..77472d714 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java @@ -8,7 +8,6 @@ import io.substrait.isthmus.SqlToSubstrait; import io.substrait.proto.ExtendedExpression; import java.io.IOException; -import java.net.URISyntaxException; import java.net.URL; import java.nio.ByteBuffer; import java.util.Base64; @@ -30,7 +29,7 @@ public class ExtendedExpressionIntegrationTest { @Test - public void filterDataset() throws SqlParseException, IOException, URISyntaxException { + public void filterDataset() throws SqlParseException, IOException { URL resource = ClassLoaderUtil.getClassLoader().getResource("./tpch/data/nation.parquet"); String sqlExpression = "N_NATIONKEY > 20"; ScanOptions options = @@ -55,14 +54,15 @@ public void filterDataset() throws SqlParseException, IOException, URISyntaxExce } assertEquals(4, count); } catch (Exception e) { + e.printStackTrace(); throw new RuntimeException(e); } } @Test - public void projectDataset() throws SqlParseException, IOException, URISyntaxException { + public void projectDataset() throws SqlParseException, IOException { URL resource = ClassLoaderUtil.getClassLoader().getResource("./tpch/data/nation.parquet"); - String sqlExpression = "N_NATIONKEY + 20"; + String sqlExpression = "20 + N_NATIONKEY"; ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768) .columns(Optional.empty()) From 3d80d1f8087081fa6fc1f56856a335a0b984b553 Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Thu, 16 Nov 2023 17:40:03 -0500 Subject: [PATCH 07/34] fix: clean code --- build.gradle.kts | 1 + .../extension/ExtensionCollector.java | 49 ------ isthmus/build.gradle.kts | 5 +- .../substrait/isthmus/SqlConverterBase.java | 22 +-- .../io/substrait/isthmus/SqlToSubstrait.java | 162 ++++++++---------- 5 files changed, 84 insertions(+), 155 deletions(-) diff --git a/build.gradle.kts b/build.gradle.kts index 293163045..47a9da29f 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -34,6 +34,7 @@ val submodulesUpdate by allprojects { repositories { mavenCentral() } + tasks.configureEach { val javaToolchains = project.extensions.getByType() useJUnitPlatform() diff --git a/core/src/main/java/io/substrait/extension/ExtensionCollector.java b/core/src/main/java/io/substrait/extension/ExtensionCollector.java index 714eaec93..bcdd969d4 100644 --- a/core/src/main/java/io/substrait/extension/ExtensionCollector.java +++ b/core/src/main/java/io/substrait/extension/ExtensionCollector.java @@ -1,6 +1,5 @@ package io.substrait.extension; -import io.substrait.proto.ExtendedExpression; import io.substrait.proto.Plan; import io.substrait.proto.SimpleExtensionDeclaration; import io.substrait.proto.SimpleExtensionURI; @@ -99,54 +98,6 @@ public void addExtensionsToPlan(Plan.Builder builder) { builder.addAllExtensions(extensionList); } - public void addExtensionsToExtendedExpression(ExtendedExpression.Builder builder) { - var uriPos = new AtomicInteger(1); - var uris = new HashMap(); - - var extensionList = new ArrayList(); - for (var e : funcMap.forwardMap.entrySet()) { - SimpleExtensionURI uri = - uris.computeIfAbsent( - e.getValue().namespace(), - k -> - SimpleExtensionURI.newBuilder() - .setExtensionUriAnchor(uriPos.getAndIncrement()) - .setUri(k) - .build()); - var decl = - SimpleExtensionDeclaration.newBuilder() - .setExtensionFunction( - SimpleExtensionDeclaration.ExtensionFunction.newBuilder() - .setFunctionAnchor(e.getKey()) - .setName(e.getValue().key()) - .setExtensionUriReference(uri.getExtensionUriAnchor())) - .build(); - extensionList.add(decl); - } - for (var e : typeMap.forwardMap.entrySet()) { - SimpleExtensionURI uri = - uris.computeIfAbsent( - e.getValue().namespace(), - k -> - SimpleExtensionURI.newBuilder() - .setExtensionUriAnchor(uriPos.getAndIncrement()) - .setUri(k) - .build()); - var decl = - SimpleExtensionDeclaration.newBuilder() - .setExtensionType( - SimpleExtensionDeclaration.ExtensionType.newBuilder() - .setTypeAnchor(e.getKey()) - .setName(e.getValue().key()) - .setExtensionUriReference(uri.getExtensionUriAnchor())) - .build(); - extensionList.add(decl); - } - - builder.addAllExtensionUris(uris.values()); - builder.addAllExtensions(extensionList); - } - /** We don't depend on guava... */ private static class BidiMap { private final Map forwardMap; diff --git a/isthmus/build.gradle.kts b/isthmus/build.gradle.kts index a5ae3abd2..abf5e412c 100644 --- a/isthmus/build.gradle.kts +++ b/isthmus/build.gradle.kts @@ -72,6 +72,7 @@ java { } var CALCITE_VERSION = "1.34.0" +var ARROW_VERSION = "14.0.0" dependencies { implementation(project(":core")) @@ -94,8 +95,8 @@ dependencies { implementation("org.immutables:value-annotations:2.8.8") annotationProcessor("org.immutables:value:2.8.8") testImplementation("org.apache.calcite:calcite-plus:${CALCITE_VERSION}") - testImplementation("org.apache.arrow:arrow-dataset:14.0.0") - testImplementation("org.apache.arrow:arrow-memory-netty:14.0.0") + testImplementation("org.apache.arrow:arrow-dataset:${ARROW_VERSION}") + testImplementation("org.apache.arrow:arrow-memory-netty:${ARROW_VERSION}") annotationProcessor("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") compileOnly("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java index 3466cf826..716fdb66e 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java @@ -89,20 +89,6 @@ protected SqlConverterBase(FeatureBoard features) { EXTENSION_COLLECTION = defaults; } - /* - HashMap nameToNodeMap = new HashMap<>(); - nameToNodeMap.put( - "N_NATIONKEY", - new RexInputRef(0, validator.getTypeFactory().createSqlType(SqlTypeName.BIGINT))); - nameToNodeMap.put( - "N_REGIONKEY", - new RexInputRef(1, validator.getTypeFactory().createSqlType(SqlTypeName.BIGINT))); - final Map nameToTypeMap = new HashMap<>(); - for (Map.Entry entry : nameToNodeMap.entrySet()) { - nameToTypeMap.put(entry.getKey(), entry.getValue().getType()); - } - */ - Result registerCreateTables(List tables) throws SqlParseException { Map nameToTypeMap = new LinkedHashMap<>(); Map nameToNodeMap = new HashMap<>(); @@ -116,8 +102,12 @@ Result registerCreateTables(List tables) throws SqlParseException { for (DefinedTable t : tList) { rootSchema.add(t.getName(), t); for (RelDataTypeField field : t.type.getFieldList()) { - nameToTypeMap.put(field.getName(), field.getType()); - nameToNodeMap.put(field.getName(), new RexInputRef(field.getIndex(), field.getType())); + nameToTypeMap.put( + field.getName(), field.getType()); // to validate the sql expression tree + nameToNodeMap.put( + field.getName(), + new RexInputRef( + field.getIndex(), field.getType())); // to convert sql expression into RexNode } } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java index 37d088a3e..50d18f361 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java @@ -136,90 +136,80 @@ private ExtendedExpression executeInnerExpression( throws SqlParseException { ExtendedExpression.Builder extendedExpressionBuilder = ExtendedExpression.newBuilder(); ExtensionCollector functionCollector = new ExtensionCollector(); - sqlToRexNode(sql, validator, catalogReader, nameToTypeMap, nameToNodeMap) + RexNode rexNode = sqlToRexNode(sql, validator, catalogReader, nameToTypeMap, nameToNodeMap); + ResulTraverseRowExpression result = new TraverseRexNode().getRowExpression(rexNode); + io.substrait.proto.Type output = + TypeCreator.NULLABLE.BOOLEAN.accept(new TypeProtoConverter(functionCollector)); + List functionArgumentList = new ArrayList<>(); + result + .expressionBuilderMap() .forEach( - rexNode -> { - // FIXME! Implement it dynamically for more expression types - // ResulTraverseRowExpression result = TraverseRexNode.getRowExpression(rexNode); - ResulTraverseRowExpression result = new TraverseRexNode().getRowExpression(rexNode); - - io.substrait.proto.Type output = - TypeCreator.NULLABLE.BOOLEAN.accept(new TypeProtoConverter(functionCollector)); - - List functionArgumentList = new ArrayList<>(); - result - .expressionBuilderMap() - .forEach( - (k, v) -> { - functionArgumentList.add(FunctionArgument.newBuilder().setValue(v).build()); - }); - - ScalarFunction.Builder scalarFunctionBuilder = - ScalarFunction.newBuilder() - .setFunctionReference(1) // rel_01 - .setOutputType(output) - .addAllArguments(functionArgumentList); - - Expression.Builder expressionBuilder = - Expression.newBuilder().setScalarFunction(scalarFunctionBuilder); - - ExpressionReference.Builder expressionReferenceBuilder = - ExpressionReference.newBuilder() - .setExpression(expressionBuilder) - .addOutputNames(result.ref().getName()); - - extendedExpressionBuilder.addReferredExpr(0, expressionReferenceBuilder); - - io.substrait.expression.Expression.ScalarFunctionInvocation func = - (io.substrait.expression.Expression.ScalarFunctionInvocation) - rexNode.accept(rexExpressionConverter); - String declaration = - func.declaration().key(); // values example: gt:any_any, add:i64_i64 - - // this is not mandatory to be defined; it is working without this definition. It is - // only created here to create a proto message that has the correct semantics - HashMap extensionUris = new HashMap<>(); - SimpleExtensionURI simpleExtensionURI; - try { - simpleExtensionURI = - SimpleExtensionURI.newBuilder() - .setExtensionUriAnchor(1) // rel_02 - .setUri( - SimpleExtension.loadDefaults().scalarFunctions().stream() - .filter(s -> s.toString().equalsIgnoreCase(declaration)) - .findFirst() - .orElseThrow( - () -> - new IllegalArgumentException( - String.format( - "Failed to get URI resource for %s.", declaration))) - .uri()) - .build(); - } catch (IOException e) { - throw new RuntimeException(e); - } - extensionUris.put("uri", simpleExtensionURI); + (k, v) -> { + functionArgumentList.add(FunctionArgument.newBuilder().setValue(v).build()); + }); - ArrayList extensions = new ArrayList<>(); - SimpleExtensionDeclaration extensionFunctionLowerThan = - SimpleExtensionDeclaration.newBuilder() - .setExtensionFunction( - SimpleExtensionDeclaration.ExtensionFunction.newBuilder() - .setFunctionAnchor( - scalarFunctionBuilder.getFunctionReference()) // rel_01 - .setName(declaration) - .setExtensionUriReference( - simpleExtensionURI.getExtensionUriAnchor())) // rel_02 - .build(); - extensions.add(extensionFunctionLowerThan); + ScalarFunction.Builder scalarFunctionBuilder = + ScalarFunction.newBuilder() + .setFunctionReference(1) // rel_01 + .setOutputType(output) + .addAllArguments(functionArgumentList); + + Expression.Builder expressionBuilder = + Expression.newBuilder().setScalarFunction(scalarFunctionBuilder); + + ExpressionReference.Builder expressionReferenceBuilder = + ExpressionReference.newBuilder() + .setExpression(expressionBuilder) + .addOutputNames(result.ref().getName()); + + extendedExpressionBuilder.addReferredExpr(0, expressionReferenceBuilder); + + io.substrait.expression.Expression.ScalarFunctionInvocation func = + (io.substrait.expression.Expression.ScalarFunctionInvocation) + rexNode.accept(rexExpressionConverter); + String declaration = func.declaration().key(); // values example: gt:any_any, add:i64_i64 + + // this is not mandatory to be defined; it is working without this definition. It is + // only created here to create a proto message that has the correct semantics + HashMap extensionUris = new HashMap<>(); + SimpleExtensionURI simpleExtensionURI; + try { + simpleExtensionURI = + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(1) // rel_02 + .setUri( + SimpleExtension.loadDefaults().scalarFunctions().stream() + .filter(s -> s.toString().equalsIgnoreCase(declaration)) + .findFirst() + .orElseThrow( + () -> + new IllegalArgumentException( + String.format("Failed to get URI resource for %s.", declaration))) + .uri()) + .build(); + } catch (IOException e) { + throw new RuntimeException(e); + } + extensionUris.put("uri", simpleExtensionURI); + + ArrayList extensions = new ArrayList<>(); + SimpleExtensionDeclaration extensionFunctionLowerThan = + SimpleExtensionDeclaration.newBuilder() + .setExtensionFunction( + SimpleExtensionDeclaration.ExtensionFunction.newBuilder() + .setFunctionAnchor(scalarFunctionBuilder.getFunctionReference()) // rel_01 + .setName(declaration) + .setExtensionUriReference(simpleExtensionURI.getExtensionUriAnchor())) // rel_02 + .build(); + extensions.add(extensionFunctionLowerThan); + + System.out.println( + "extendedExpressionBuilder.getExtensionUrisList(): " + + extendedExpressionBuilder.getExtensionUrisList()); + // adding it for semantic purposes, it is not mandatory or needed + extendedExpressionBuilder.addAllExtensionUris(extensionUris.values()); + extendedExpressionBuilder.addAllExtensions(extensions); - System.out.println( - "extendedExpressionBuilder.getExtensionUrisList(): " - + extendedExpressionBuilder.getExtensionUrisList()); - // adding it for semantic purposes, it is not mandatory or needed - extendedExpressionBuilder.addAllExtensionUris(extensionUris.values()); - extendedExpressionBuilder.addAllExtensions(extensions); - }); NamedStruct namedStruct = TypeConverter.DEFAULT.toNamedStruct(nameToTypeMap); extendedExpressionBuilder.setBaseSchema( namedStruct.toProto(new TypeProtoConverter(functionCollector))); @@ -297,7 +287,7 @@ private List sqlToRelNode( return roots; } - private List sqlToRexNode( + private RexNode sqlToRexNode( String sql, SqlValidator validator, CalciteCatalogReader catalogReader, @@ -306,13 +296,9 @@ private List sqlToRexNode( throws SqlParseException { SqlParser parser = SqlParser.create(sql, parserConfig); SqlNode sqlNode = parser.parseExpression(); - SqlNode validSQLNode = - validator.validateParameterizedExpression( - sqlNode, nameToTypeMap); // FIXME! It may be optional to include this validation + SqlNode validSQLNode = validator.validateParameterizedExpression(sqlNode, nameToTypeMap); SqlToRelConverter converter = createSqlToRelConverter(validator, catalogReader); - RexNode rexNode = converter.convertExpression(validSQLNode, nameToNodeMap); - - return Collections.singletonList(rexNode); + return converter.convertExpression(validSQLNode, nameToNodeMap); } @VisibleForTesting From 5954a626af7aab6045158581d0baa6d17d6e645d Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Thu, 16 Nov 2023 18:24:23 -0500 Subject: [PATCH 08/34] fix: clean code --- .../integration/ExtendedExpressionIntegrationTest.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java index 77472d714..da34e62f1 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java @@ -29,7 +29,7 @@ public class ExtendedExpressionIntegrationTest { @Test - public void filterDataset() throws SqlParseException, IOException { + public void filterDatasetUsingExtendedExpression() throws SqlParseException, IOException { URL resource = ClassLoaderUtil.getClassLoader().getResource("./tpch/data/nation.parquet"); String sqlExpression = "N_NATIONKEY > 20"; ScanOptions options = @@ -60,7 +60,7 @@ public void filterDataset() throws SqlParseException, IOException { } @Test - public void projectDataset() throws SqlParseException, IOException { + public void projectDatasetUsingExtendedExpression() throws SqlParseException, IOException { URL resource = ClassLoaderUtil.getClassLoader().getResource("./tpch/data/nation.parquet"); String sqlExpression = "20 + N_NATIONKEY"; ScanOptions options = From fc33a3233800f91300393e1f001dd7e91aa40f5b Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Fri, 17 Nov 2023 17:05:30 -0500 Subject: [PATCH 09/34] fix: rename variables to clean code --- .../io/substrait/isthmus/SqlToSubstrait.java | 22 ++++++++--- .../isthmus/ExtendedExpressionTestBase.java | 2 +- .../SimpleExtendedExpressionsTest.java | 22 ++++++++++- .../ExtendedExpressionIntegrationTest.java | 39 +++++++++++++++++-- 4 files changed, 73 insertions(+), 12 deletions(-) diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java index 50d18f361..fa61c5e81 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java @@ -73,11 +73,20 @@ public Plan execute(String sql, String name, Schema schema) throws SqlParseExcep return executeInner(sql, factory, pair.left, pair.right); } - public ExtendedExpression executeExpression(String expr, List tables) + /** + * Process to execute an SQL Expression to convert into an Extended expression protobuf message + * + * @param sqlExpression expression defined by the user + * @param tables of names of table needed to consider to load into memory for catalog, schema, + * validate and parse sql + * @return extended expression protobuf message + * @throws SqlParseException + */ + public ExtendedExpression executeSQLExpression(String sqlExpression, List tables) throws SqlParseException { var result = registerCreateTables(tables); - return executeInnerExpression( - expr, + return executeInnerSQLExpression( + sqlExpression, result.validator(), result.catalogReader(), result.nameToTypeMap(), @@ -127,8 +136,8 @@ private Plan executeInner( return plan.build(); } - private ExtendedExpression executeInnerExpression( - String sql, + private ExtendedExpression executeInnerSQLExpression( + String sqlExpression, SqlValidator validator, CalciteCatalogReader catalogReader, Map nameToTypeMap, @@ -136,7 +145,8 @@ private ExtendedExpression executeInnerExpression( throws SqlParseException { ExtendedExpression.Builder extendedExpressionBuilder = ExtendedExpression.newBuilder(); ExtensionCollector functionCollector = new ExtensionCollector(); - RexNode rexNode = sqlToRexNode(sql, validator, catalogReader, nameToTypeMap, nameToNodeMap); + RexNode rexNode = + sqlToRexNode(sqlExpression, validator, catalogReader, nameToTypeMap, nameToNodeMap); ResulTraverseRowExpression result = new TraverseRexNode().getRowExpression(rexNode); io.substrait.proto.Type output = TypeCreator.NULLABLE.BOOLEAN.accept(new TypeProtoConverter(functionCollector)); diff --git a/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java index 10f3f57e3..3bee0b61e 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java @@ -35,7 +35,7 @@ protected ExtendedExpression assertProtoExtendedExpressionRoundrip(String query, protected ExtendedExpression assertProtoExtendedExpressionRoundrip( String query, SqlToSubstrait s, List creates) throws SqlParseException { io.substrait.proto.ExtendedExpression protoExtendedExpression = - s.executeExpression(query, creates); + s.executeSQLExpression(query, creates); try { String ee = JsonFormat.printer().print(protoExtendedExpression); diff --git a/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java index bfcea38c9..8b0248ca4 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java @@ -1,5 +1,6 @@ package io.substrait.isthmus; +import io.substrait.proto.ExtendedExpression; import java.io.IOException; import org.apache.calcite.sql.parser.SqlParseException; import org.junit.jupiter.api.Test; @@ -8,6 +9,25 @@ public class SimpleExtendedExpressionsTest extends ExtendedExpressionTestBase { @Test public void filter() throws IOException, SqlParseException { - assertProtoExtendedExpressionRoundrip("N_NATIONKEY > 18"); + ExtendedExpression extendedExpression = + assertProtoExtendedExpressionRoundrip("L_ORDERKEY > 10"); + } + + @Test + public void in() throws IOException, SqlParseException { + ExtendedExpression extendedExpression = + assertProtoExtendedExpressionRoundrip("L_ORDERKEY IN (10, 20)"); + } + + @Test + public void isNotNull() throws IOException, SqlParseException { + ExtendedExpression extendedExpression = + assertProtoExtendedExpressionRoundrip("L_ORDERKEY is not null"); + } + + @Test + public void isNull() throws IOException, SqlParseException { + ExtendedExpression extendedExpression = + assertProtoExtendedExpressionRoundrip("L_ORDERKEY is null"); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java index da34e62f1..03d5f63ad 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java @@ -29,7 +29,7 @@ public class ExtendedExpressionIntegrationTest { @Test - public void filterDatasetUsingExtendedExpression() throws SqlParseException, IOException { + public void filterDataset() throws SqlParseException, IOException { URL resource = ClassLoaderUtil.getClassLoader().getResource("./tpch/data/nation.parquet"); String sqlExpression = "N_NATIONKEY > 20"; ScanOptions options = @@ -60,7 +60,7 @@ public void filterDatasetUsingExtendedExpression() throws SqlParseException, IOE } @Test - public void projectDatasetUsingExtendedExpression() throws SqlParseException, IOException { + public void projectDataset() throws SqlParseException, IOException { URL resource = ClassLoaderUtil.getClassLoader().getResource("./tpch/data/nation.parquet"); String sqlExpression = "20 + N_NATIONKEY"; ScanOptions options = @@ -95,11 +95,42 @@ public void projectDatasetUsingExtendedExpression() throws SqlParseException, IO } } + @Test + public void filterDatasetUsingExtendedExpression() throws SqlParseException, IOException { + URL resource = ClassLoaderUtil.getClassLoader().getResource("./tpch/data/nation.parquet"); + String sqlExpression = "N_NATIONKEY > 20"; + ScanOptions options = + new ScanOptions.Builder(/*batchSize*/ 32768) + .columns(Optional.empty()) + .substraitFilter(getFilterExtendedExpression(sqlExpression)) + .build(); + try (BufferAllocator allocator = new RootAllocator(); + DatasetFactory datasetFactory = + new FileSystemDatasetFactory( + allocator, + NativeMemoryPool.getDefault(), + FileFormat.PARQUET, + resource.toURI().toString()); + Dataset dataset = datasetFactory.finish(); + Scanner scanner = dataset.newScan(options); + ArrowReader reader = scanner.scanBatches()) { + int count = 0; + while (reader.loadNextBatch()) { + count += reader.getVectorSchemaRoot().getRowCount(); + System.out.println(reader.getVectorSchemaRoot().contentToTSVString()); + } + assertEquals(4, count); + } catch (Exception e) { + e.printStackTrace(); + throw new RuntimeException(e); + } + } + private static ByteBuffer getFilterExtendedExpression(String sqlExpression) throws IOException, SqlParseException { ExtendedExpression extendedExpression = new SqlToSubstrait() - .executeExpression( + .executeSQLExpression( sqlExpression, ExtendedExpressionTestBase.tpchSchemaCreateStatements()); System.out.println( "JsonFormat.printer().print(getFilterExtendedExpression): " @@ -116,7 +147,7 @@ private static ByteBuffer getProjectExtendedExpression(String sqlExpression) throws IOException, SqlParseException { ExtendedExpression extendedExpression = new SqlToSubstrait() - .executeExpression( + .executeSQLExpression( sqlExpression, ExtendedExpressionTestBase.tpchSchemaCreateStatements()); System.out.println( "JsonFormat.printer().print(getProjectExtendedExpression): " From 217f2a0a6160de5e229e7abe9d882786076dcb55 Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Thu, 23 Nov 2023 16:46:51 -0500 Subject: [PATCH 10/34] fix: from/to pojo/protobuf --- core/build.gradle.kts | 5 + .../expression/ExpressionReference.java | 11 + .../expression/ExtendedExpression.java | 20 ++ .../ExtendedExpressionProtoConverter.java | 210 ++++++++++++++++++ .../extended/expression/MyCliente.java | 75 +++++++ .../ProtoExtendedExpressionConverter.java | 176 +++++++++++++++ .../extension/ExtensionCollector.java | 49 ++++ .../extension/ImmutableExtensionLookup.java | 44 ++++ .../io/substrait/plan/PlanProtoConverter.java | 4 + .../ExtendedExpressionProtoConverterTest.java | 76 +++++++ .../ProtoExtendedExpressionConverterTest.java | 94 ++++++++ .../io/substrait/isthmus/SqlToSubstrait.java | 105 +++++++++ .../isthmus/ExtendedExpressionTestBase.java | 11 +- .../SimpleExtendedExpressionsTest.java | 8 +- .../ExtendedExpressionIntegrationTest.java | 2 +- isthmus/src/test/resources/tpch/schema.sql | 71 ------ 16 files changed, 887 insertions(+), 74 deletions(-) create mode 100644 core/src/main/java/io/substrait/extended/expression/ExpressionReference.java create mode 100644 core/src/main/java/io/substrait/extended/expression/ExtendedExpression.java create mode 100644 core/src/main/java/io/substrait/extended/expression/ExtendedExpressionProtoConverter.java create mode 100644 core/src/main/java/io/substrait/extended/expression/MyCliente.java create mode 100644 core/src/main/java/io/substrait/extended/expression/ProtoExtendedExpressionConverter.java create mode 100644 core/src/test/java/io/substrait/extended/expression/ExtendedExpressionProtoConverterTest.java create mode 100644 core/src/test/java/io/substrait/extended/expression/ProtoExtendedExpressionConverterTest.java diff --git a/core/build.gradle.kts b/core/build.gradle.kts index 6b8cfac66..c06a00e86 100644 --- a/core/build.gradle.kts +++ b/core/build.gradle.kts @@ -85,6 +85,11 @@ dependencies { compileOnly("org.immutables:value-annotations:2.8.8") annotationProcessor("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") compileOnly("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") + + implementation("com.google.protobuf:protobuf-java-util:3.17.3") { + exclude("com.google.guava", "guava") + .because("Brings in Guava for Android, which we don't want (and breaks multimaps).") + } } java { diff --git a/core/src/main/java/io/substrait/extended/expression/ExpressionReference.java b/core/src/main/java/io/substrait/extended/expression/ExpressionReference.java new file mode 100644 index 000000000..2214f0438 --- /dev/null +++ b/core/src/main/java/io/substrait/extended/expression/ExpressionReference.java @@ -0,0 +1,11 @@ +package io.substrait.extended.expression; + +import io.substrait.expression.Expression; +import org.immutables.value.Value; + +@Value.Immutable +public abstract class ExpressionReference { + public abstract Expression getExpression(); + + public abstract String getOutputNames(); +} diff --git a/core/src/main/java/io/substrait/extended/expression/ExtendedExpression.java b/core/src/main/java/io/substrait/extended/expression/ExtendedExpression.java new file mode 100644 index 000000000..e0a5c03f9 --- /dev/null +++ b/core/src/main/java/io/substrait/extended/expression/ExtendedExpression.java @@ -0,0 +1,20 @@ +package io.substrait.extended.expression; + +import io.substrait.expression.Expression; +import io.substrait.proto.AdvancedExtension; +import io.substrait.type.NamedStruct; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import org.immutables.value.Value; + +@Value.Immutable +public abstract class ExtendedExpression { + public abstract Map getReferredExpr(); + + public abstract NamedStruct getBaseSchema(); + + public abstract List getExpectedTypeUrls(); + + public abstract Optional getAdvancedExtension(); +} diff --git a/core/src/main/java/io/substrait/extended/expression/ExtendedExpressionProtoConverter.java b/core/src/main/java/io/substrait/extended/expression/ExtendedExpressionProtoConverter.java new file mode 100644 index 000000000..e4015f3cf --- /dev/null +++ b/core/src/main/java/io/substrait/extended/expression/ExtendedExpressionProtoConverter.java @@ -0,0 +1,210 @@ +package io.substrait.extended.expression; + +import com.google.protobuf.util.JsonFormat; +import io.substrait.expression.*; +import io.substrait.expression.proto.ExpressionProtoConverter; +import io.substrait.extension.ExtensionCollector; +import io.substrait.extension.SimpleExtension; +import io.substrait.proto.ExpressionReference; +import io.substrait.proto.ExtendedExpression; +import io.substrait.type.NamedStruct; +import io.substrait.type.Type; +import io.substrait.type.TypeCreator; +import io.substrait.type.proto.TypeProtoConverter; +import java.io.IOException; +import java.util.*; + +public class ExtendedExpressionProtoConverter { + public ExtendedExpression toProto( + io.substrait.extended.expression.ExtendedExpression extendedExpressionPojo) { + + ExtendedExpression.Builder extendedExpressionBuilder = ExtendedExpression.newBuilder(); + ExtensionCollector functionCollector = new ExtensionCollector(); + + final ExpressionProtoConverter expressionProtoConverter = + new ExpressionProtoConverter(functionCollector, null); + + // convert expression pojo into expression protobuf + io.substrait.proto.Expression expressionProto = + expressionProtoConverter.visit( + (Expression.ScalarFunctionInvocation) extendedExpressionPojo.getReferredExpr().get(0)); + + ExpressionReference.Builder expressionReferenceBuilder = + ExpressionReference.newBuilder().setExpression(expressionProto).addOutputNames("column-01"); + + extendedExpressionBuilder.addReferredExpr(0, expressionReferenceBuilder); + extendedExpressionBuilder.setBaseSchema( + extendedExpressionPojo.getBaseSchema().toProto(new TypeProtoConverter(functionCollector))); + + + functionCollector.addExtensionsToPlan(extendedExpressionBuilder); + if (extendedExpressionPojo.getAdvancedExtension().isPresent()) { + extendedExpressionBuilder.setAdvancedExtensions( + extendedExpressionPojo.getAdvancedExtension().get()); + } + return extendedExpressionBuilder.build(); + } + + public static void main(String[] args) throws IOException { + SimpleExtension.ExtensionCollection defaultExtensionCollection = SimpleExtension.loadDefaults(); + System.out.println( + "defaultExtensionCollection.scalarFunctions(): " + + defaultExtensionCollection.scalarFunctions()); + System.out.println( + "defaultExtensionCollection.windowFunctions(): " + + defaultExtensionCollection.windowFunctions()); + System.out.println( + "defaultExtensionCollection.aggregateFunctions(): " + + defaultExtensionCollection.aggregateFunctions()); + + Optional equal = + defaultExtensionCollection.scalarFunctions().stream() + .filter( + s -> { + return s.name().equalsIgnoreCase("add"); + }) + .findFirst() + .map( + declaration -> + ExpressionCreator.scalarFunction( + declaration, + TypeCreator.REQUIRED.BOOLEAN, + ImmutableFieldReference.builder() + .addSegments(FieldReference.StructField.of(0)) + .type(TypeCreator.REQUIRED.I32) + .build(), + ExpressionCreator.i32(false, 183))); + + Map indexToExpressionMap = new HashMap<>(); + indexToExpressionMap.put(0, equal.get()); + List columnNames = Arrays.asList("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"); + List dataTypes = + Arrays.asList( + TypeCreator.NULLABLE.I32, + TypeCreator.NULLABLE.STRING, + TypeCreator.NULLABLE.I32, + TypeCreator.NULLABLE.STRING); + NamedStruct namedStruct = + NamedStruct.of( + columnNames, Type.Struct.builder().fields(dataTypes).nullable(false).build()); + ImmutableExtendedExpression.Builder builder = + ImmutableExtendedExpression.builder() + .putAllReferredExpr(indexToExpressionMap) + .baseSchema(namedStruct); + + ExtendedExpression proto = new ExtendedExpressionProtoConverter().toProto(builder.build()); + + System.out.println( + "JsonFormat.printer().print(getFilterExtendedExpression): " + + JsonFormat.printer().print(proto)); + } + + public static ExtendedExpression createExtendedExpression( + io.substrait.expression.Expression.ScalarFunctionInvocation expr) { + ExtendedExpression.Builder extendedExpressionBuilder = ExtendedExpression.newBuilder(); + + io.substrait.proto.Expression expression = new ExpressionProtoConverter(null, null).visit(expr); + ExpressionReference.Builder expressionReferenceBuilder = + ExpressionReference.newBuilder() + .setExpression(expression.toBuilder()) + .addOutputNames("col-01"); + + extendedExpressionBuilder.addReferredExpr(0, expressionReferenceBuilder); + + return extendedExpressionBuilder.build(); + } + + public static void createExtendedExpressionManually() { + + Map nameToExpressionMap = new HashMap<>(); + ImmutableExpression.I32Literal build = Expression.I32Literal.builder().value(10).build(); + nameToExpressionMap.put("out_01", build); + + List expressionList = new ArrayList<>(); + expressionList.add(0, null); + + // nation table + List columnNames = Arrays.asList("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"); + List dataTypes = + Arrays.asList( + TypeCreator.NULLABLE.I32, + TypeCreator.NULLABLE.STRING, + TypeCreator.NULLABLE.I32, + TypeCreator.NULLABLE.STRING); + NamedStruct namedStruct = + NamedStruct.of( + columnNames, Type.Struct.builder().fields(dataTypes).nullable(false).build()); + + ExtensionCollector functionCollector = new ExtensionCollector(); + + new ExpressionProtoConverter(new ExtensionCollector(), null); + + FunctionArg functionArg = + new FunctionArg() { + @Override + public R accept( + SimpleExtension.Function fnDef, int argIdx, FuncArgVisitor fnArgVisitor) + throws E { + return null; + } + }; + + // var argVisitor = FunctionArg.toProto(new TypeProtoConverter(new ExtensionCollector()), this); + + /* + + // FIXME! setFunctionReference, addArguments(index: 0, 1) + io.substrait.proto.Expression.Builder expressionBuilder = + io.substrait.proto.Expression.newBuilder() + .setScalarFunction( + io.substrait.proto.Expression.ScalarFunction.newBuilder() + .setFunctionReference(1) + .setOutputType(output) + .addArguments( + 0, + FunctionArgument.newBuilder().setValue(result.referenceBuilder())) + .addArguments( + 1, + FunctionArgument.newBuilder() + .setValue(result.expressionBuilderLiteral()))); + io.substrait.proto.ExpressionReference.Builder expressionReferenceBuilder = + ExpressionReference.newBuilder() + .setExpression(expressionBuilder) + .addOutputNames(result.ref().getName()); + + */ + + /* + + io.substrait.extended.expression.ExtendedExpression extendedExpression = new io.substrait.extended.expression.ExtendedExpression() { + @Override + public List getReferredExpr() { + io.substrait.extended.expression.ExpressionReference + + @Override + public NamedStruct getBaseSchema() { + return null; + } + + @Override + public List getExpectedTypeUrls() { + return null; + } + + @Override + public Optional getAdvancedExtension() { + return Optional.empty(); + } + }; + + System.out.println("inicio"); + System.out.println(extendedExpression.getReferredExpr().get(0)); + System.out.println(extendedExpression.getReferredExpr().get(0).getType()); + System.out.println("fin"); + + ExpressionReferenceOrBuilder + + */ + + } +} diff --git a/core/src/main/java/io/substrait/extended/expression/MyCliente.java b/core/src/main/java/io/substrait/extended/expression/MyCliente.java new file mode 100644 index 000000000..0ec8a90ed --- /dev/null +++ b/core/src/main/java/io/substrait/extended/expression/MyCliente.java @@ -0,0 +1,75 @@ +package io.substrait.extended.expression; + +import com.google.protobuf.util.JsonFormat; +import io.substrait.expression.Expression; +import io.substrait.expression.ExpressionCreator; +import io.substrait.expression.FieldReference; +import io.substrait.expression.ImmutableFieldReference; +import io.substrait.extension.SimpleExtension; +import io.substrait.type.ImmutableNamedStruct; +import io.substrait.type.NamedStruct; +import io.substrait.type.Type; +import io.substrait.type.TypeCreator; +import java.io.IOException; +import java.util.*; + +public class MyCliente { + public static void main(String[] args) throws IOException { + SimpleExtension.ExtensionCollection defaultExtensionCollection = SimpleExtension.loadDefaults(); + Optional equal = + defaultExtensionCollection.scalarFunctions().stream() + .filter( + s -> { + System.out.println(":>>>>"); + System.out.println(s); + System.out.println(s.uri()); + System.out.println(s.returnType()); + System.out.println(s.description()); + System.out.println("s.name(): " + s.name()); + System.out.println(s.key()); + return s.name().equalsIgnoreCase("add"); + }) + .findFirst() + .map( + declaration -> { + System.out.println("declaration: " + declaration); + System.out.println("declaration.name(): " + declaration.name()); + return ExpressionCreator.scalarFunction( + declaration, + TypeCreator.REQUIRED.BOOLEAN, + ImmutableFieldReference.builder() + .addSegments(FieldReference.StructField.of(0)) + .type(TypeCreator.REQUIRED.I32) + .build(), + ExpressionCreator.i32(false, 183)); + }); + + Map indexToExpressionMap = new HashMap<>(); + indexToExpressionMap.put(0, equal.get()); + List columnNames = Arrays.asList("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"); + List dataTypes = + Arrays.asList( + TypeCreator.NULLABLE.I32, + TypeCreator.NULLABLE.STRING, + TypeCreator.NULLABLE.I32, + TypeCreator.NULLABLE.STRING); + NamedStruct namedStruct = + NamedStruct.of( + columnNames, Type.Struct.builder().fields(dataTypes).nullable(false).build()); + + ImmutableNamedStruct.builder() + .addNames("id") + .struct(Type.Struct.builder().nullable(false).addFields(TypeCreator.REQUIRED.I32).build()) + .build(); + + ImmutableExtendedExpression.Builder builder = + ImmutableExtendedExpression.builder() + .putAllReferredExpr(indexToExpressionMap) + .baseSchema(namedStruct); + + System.out.println( + "JsonFormat.printer().print(getFilterExtendedExpression): " + + JsonFormat.printer() + .print(new ExtendedExpressionProtoConverter().toProto(builder.build()))); + } +} diff --git a/core/src/main/java/io/substrait/extended/expression/ProtoExtendedExpressionConverter.java b/core/src/main/java/io/substrait/extended/expression/ProtoExtendedExpressionConverter.java new file mode 100644 index 000000000..ed1fc377b --- /dev/null +++ b/core/src/main/java/io/substrait/extended/expression/ProtoExtendedExpressionConverter.java @@ -0,0 +1,176 @@ +package io.substrait.extended.expression; + +import io.substrait.expression.Expression; +import io.substrait.expression.proto.ProtoExpressionConverter; +import io.substrait.extension.*; +import io.substrait.proto.ExpressionReference; +import io.substrait.proto.NamedStruct; +import io.substrait.relation.ProtoRelConverter; +import io.substrait.type.ImmutableNamedStruct; +import io.substrait.type.Type; +import io.substrait.type.proto.ProtoTypeConverter; +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; + +public class ProtoExtendedExpressionConverter { + private ExtensionCollector lookup = new ExtensionCollector(); + private ProtoTypeConverter protoTypeConverter = + new ProtoTypeConverter( + lookup, ImmutableSimpleExtension.ExtensionCollection.builder().build()); + + private ProtoExpressionConverter getPprotoExpressionConverter(ExtensionLookup functionLookup) { + return new ProtoExpressionConverter( + functionLookup, + this.extensionCollection, + null, + null); + } + + private ProtoExpressionConverter getPprotoExpressionConverter(ExtensionLookup functionLookup, io.substrait.type.NamedStruct namedStruct) { + return new ProtoExpressionConverter( + functionLookup, + this.extensionCollection, + namedStruct.struct(), + null); + } + + protected final SimpleExtension.ExtensionCollection extensionCollection; + + public ProtoExtendedExpressionConverter() throws IOException { + this(SimpleExtension.loadDefaults()); + } + + public ProtoExtendedExpressionConverter(SimpleExtension.ExtensionCollection extensionCollection) { + this.extensionCollection = extensionCollection; + } + + protected ProtoRelConverter getProtoRelConverter(ExtensionLookup functionLookup) { + return new ProtoRelConverter(functionLookup, this.extensionCollection); + } + + public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExpressionProto) { + ExtensionLookup functionLookup = + ImmutableExtensionLookup.builder().from(extendedExpressionProto).build(); + + + // para struct + NamedStruct baseSchema = extendedExpressionProto.getBaseSchema(); + io.substrait.type.NamedStruct namedStruct = newNamedStruct(baseSchema); + + System.out.println("namedStruct"); + System.out.println(namedStruct); + + ProtoExpressionConverter protoExpressionConverter = + getPprotoExpressionConverter(functionLookup, namedStruct); + + Map indexToExpressionMap = new HashMap<>(); + for (ExpressionReference expressionReference : extendedExpressionProto.getReferredExprList()) { + System.out.println( + "expressionReference.getExpression(): " + expressionReference.getExpression()); + indexToExpressionMap.put( + 0, protoExpressionConverter.from(expressionReference.getExpression())); + } + + // para struct + /* + NamedStruct baseSchema = extendedExpressionProto.getBaseSchema(); + io.substrait.type.NamedStruct namedStruct = newNamedStruct(baseSchema); + + System.out.println("namedStruct"); + System.out.println(namedStruct); + + */ + + ImmutableExtendedExpression.Builder builder = + ImmutableExtendedExpression.builder() + .putAllReferredExpr(indexToExpressionMap) + .advancedExtension( + Optional.ofNullable( + extendedExpressionProto.hasAdvancedExtensions() + ? extendedExpressionProto.getAdvancedExtensions() + : null)) + .baseSchema(namedStruct); + /* + ProtocolStringList namesList = baseSchema.getNamesList(); + + Type.Struct struct = baseSchema.getStruct(); + Type types = struct.getTypes(0); + System.out.println("types.getDescriptorForType().getName(): " + types.getDescriptorForType().); + + + */ + + /* + System.out.println("namesList: " + namesList); + System.out.println("baseSchema.getStruct(): " + baseSchema.getStruct()); + System.out.println("}}{{{{{{{{{{''------>"); + System.out.println("baseSchema.getStruct(): " + baseSchema.getStruct().getTypes(0)); + + + */ + + /* + ImmutableNamedStruct.builder(). + + // para expression + + Optional equal = + defaultExtensionCollection.scalarFunctions().stream() + .filter( + s -> { + System.out.println(":>>>>"); + System.out.println(s); + System.out.println(s.uri()); + System.out.println(s.returnType()); + System.out.println(s.description()); + System.out.println("s.name(): " + s.name()); + System.out.println(s.key()); + return s.name().equalsIgnoreCase("add"); + }) + .findFirst() + .map( + declaration -> { + System.out.println("declaration: " + declaration); + System.out.println("declaration.name(): " + declaration.name()); + return ExpressionCreator.scalarFunction( + declaration, + TypeCreator.REQUIRED.BOOLEAN, + ImmutableFieldReference.builder() + .addSegments(FieldReference.StructField.of(0)) + .type(TypeCreator.REQUIRED.I32) + .build(), + ExpressionCreator.i32(false, 183) + ); + } + ); + + Map indexToExpressionMap = new HashMap<>(); + indexToExpressionMap.put(0, equal.get()); + + ImmutableExtendedExpression.Builder builder = + ImmutableExtendedExpression.builder() + .putAllReferredExpr(indexToExpressionMap) + .baseSchema(namedStruct); + + */ + + return builder.build(); + } + + private io.substrait.type.NamedStruct newNamedStruct(NamedStruct namedStruct) { + var struct = namedStruct.getStruct(); + return ImmutableNamedStruct.builder() + .names(namedStruct.getNamesList()) + .struct( + Type.Struct.builder() + .fields( + struct.getTypesList().stream() + .map(protoTypeConverter::from) + .collect(java.util.stream.Collectors.toList())) + .nullable(ProtoTypeConverter.isNullable(struct.getNullability())) + .build()) + .build(); + } +} diff --git a/core/src/main/java/io/substrait/extension/ExtensionCollector.java b/core/src/main/java/io/substrait/extension/ExtensionCollector.java index bcdd969d4..f2a4a6f18 100644 --- a/core/src/main/java/io/substrait/extension/ExtensionCollector.java +++ b/core/src/main/java/io/substrait/extension/ExtensionCollector.java @@ -1,5 +1,6 @@ package io.substrait.extension; +import io.substrait.proto.ExtendedExpression; import io.substrait.proto.Plan; import io.substrait.proto.SimpleExtensionDeclaration; import io.substrait.proto.SimpleExtensionURI; @@ -98,6 +99,54 @@ public void addExtensionsToPlan(Plan.Builder builder) { builder.addAllExtensions(extensionList); } + public void addExtensionsToPlan(ExtendedExpression.Builder builder) { + var uriPos = new AtomicInteger(1); + var uris = new HashMap(); + + var extensionList = new ArrayList(); + for (var e : funcMap.forwardMap.entrySet()) { + SimpleExtensionURI uri = + uris.computeIfAbsent( + e.getValue().namespace(), + k -> + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(uriPos.getAndIncrement()) + .setUri(k) + .build()); + var decl = + SimpleExtensionDeclaration.newBuilder() + .setExtensionFunction( + SimpleExtensionDeclaration.ExtensionFunction.newBuilder() + .setFunctionAnchor(e.getKey()) + .setName(e.getValue().key()) + .setExtensionUriReference(uri.getExtensionUriAnchor())) + .build(); + extensionList.add(decl); + } + for (var e : typeMap.forwardMap.entrySet()) { + SimpleExtensionURI uri = + uris.computeIfAbsent( + e.getValue().namespace(), + k -> + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(uriPos.getAndIncrement()) + .setUri(k) + .build()); + var decl = + SimpleExtensionDeclaration.newBuilder() + .setExtensionType( + SimpleExtensionDeclaration.ExtensionType.newBuilder() + .setTypeAnchor(e.getKey()) + .setName(e.getValue().key()) + .setExtensionUriReference(uri.getExtensionUriAnchor())) + .build(); + extensionList.add(decl); + } + + builder.addAllExtensionUris(uris.values()); + builder.addAllExtensions(extensionList); + } + /** We don't depend on guava... */ private static class BidiMap { private final Map forwardMap; diff --git a/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java b/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java index 2cab2cce8..3d600002b 100644 --- a/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java +++ b/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java @@ -1,5 +1,6 @@ package io.substrait.extension; +import io.substrait.proto.ExtendedExpression; import io.substrait.proto.Plan; import io.substrait.proto.SimpleExtensionDeclaration; import java.util.Collections; @@ -73,6 +74,49 @@ public Builder from(Plan p) { return this; } + public Builder from(ExtendedExpression p) { + Map namespaceMap = new HashMap<>(); + for (var extension : p.getExtensionUrisList()) { + namespaceMap.put(extension.getExtensionUriAnchor(), extension.getUri()); + } + + // Add all functions used in plan to the functionMap + for (var extension : p.getExtensionsList()) { + if (!extension.hasExtensionFunction()) { + continue; + } + SimpleExtensionDeclaration.ExtensionFunction func = extension.getExtensionFunction(); + int reference = func.getFunctionAnchor(); + String namespace = namespaceMap.get(func.getExtensionUriReference()); + if (namespace == null) { + throw new IllegalStateException( + "Could not find extension URI of " + func.getExtensionUriReference()); + } + String name = func.getName(); + SimpleExtension.FunctionAnchor anchor = SimpleExtension.FunctionAnchor.of(namespace, name); + functionMap.put(reference, anchor); + } + + // Add all types used in plan to the typeMap + for (var extension : p.getExtensionsList()) { + if (!extension.hasExtensionType()) { + continue; + } + SimpleExtensionDeclaration.ExtensionType type = extension.getExtensionType(); + int reference = type.getTypeAnchor(); + String namespace = namespaceMap.get(type.getExtensionUriReference()); + if (namespace == null) { + throw new IllegalStateException( + "Could not find extension URI of " + type.getExtensionUriReference()); + } + String name = type.getName(); + SimpleExtension.TypeAnchor anchor = SimpleExtension.TypeAnchor.of(namespace, name); + typeMap.put(reference, anchor); + } + + return this; + } + public ImmutableExtensionLookup build() { return new ImmutableExtensionLookup( Collections.unmodifiableMap(functionMap), Collections.unmodifiableMap(typeMap)); diff --git a/core/src/main/java/io/substrait/plan/PlanProtoConverter.java b/core/src/main/java/io/substrait/plan/PlanProtoConverter.java index 0bdf7d68c..af0f6d69a 100644 --- a/core/src/main/java/io/substrait/plan/PlanProtoConverter.java +++ b/core/src/main/java/io/substrait/plan/PlanProtoConverter.java @@ -34,6 +34,10 @@ public Plan toProto(io.substrait.plan.Plan plan) { if (plan.getAdvancedExtension().isPresent()) { builder.setAdvancedExtensions(plan.getAdvancedExtension().get()); } + /* + extendedExpressionBuilder.addAllExtensionUris(extensionUris.values()); + extendedExpressionBuilder.addAllExtensions(extensions); + */ return builder.build(); } } diff --git a/core/src/test/java/io/substrait/extended/expression/ExtendedExpressionProtoConverterTest.java b/core/src/test/java/io/substrait/extended/expression/ExtendedExpressionProtoConverterTest.java new file mode 100644 index 000000000..4004e9de7 --- /dev/null +++ b/core/src/test/java/io/substrait/extended/expression/ExtendedExpressionProtoConverterTest.java @@ -0,0 +1,76 @@ +package io.substrait.extended.expression; + +import com.google.protobuf.util.JsonFormat; +import io.substrait.TestBase; +import io.substrait.expression.Expression; +import io.substrait.expression.ExpressionCreator; +import io.substrait.expression.FieldReference; +import io.substrait.expression.ImmutableFieldReference; +import io.substrait.type.ImmutableNamedStruct; +import io.substrait.type.NamedStruct; +import io.substrait.type.Type; +import io.substrait.type.TypeCreator; +import java.io.IOException; +import java.util.*; +import org.junit.jupiter.api.Test; + +public class ExtendedExpressionProtoConverterTest extends TestBase { + @Test + public void toProtoTest() throws IOException { + Optional equal = + defaultExtensionCollection.scalarFunctions().stream() + .filter( + s -> { + System.out.println(":>>>>"); + System.out.println(s); + System.out.println(s.uri()); + System.out.println(s.returnType()); + System.out.println(s.description()); + System.out.println("s.name(): " + s.name()); + System.out.println(s.key()); + return s.name().equalsIgnoreCase("add"); + }) + .findFirst() + .map( + declaration -> { + System.out.println("declaration: " + declaration); + System.out.println("declaration.name(): " + declaration.name()); + return ExpressionCreator.scalarFunction( + declaration, + TypeCreator.REQUIRED.BOOLEAN, + ImmutableFieldReference.builder() + .addSegments(FieldReference.StructField.of(0)) + .type(TypeCreator.REQUIRED.I32) + .build(), + ExpressionCreator.i32(false, 183)); + }); + + Map indexToExpressionMap = new HashMap<>(); + indexToExpressionMap.put(0, equal.get()); + List columnNames = Arrays.asList("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"); + List dataTypes = + Arrays.asList( + TypeCreator.NULLABLE.I32, + TypeCreator.NULLABLE.STRING, + TypeCreator.NULLABLE.I32, + TypeCreator.NULLABLE.STRING); + NamedStruct namedStruct = + NamedStruct.of( + columnNames, Type.Struct.builder().fields(dataTypes).nullable(false).build()); + + ImmutableNamedStruct.builder() + .addNames("id") + .struct(Type.Struct.builder().nullable(false).addFields(TypeCreator.REQUIRED.I32).build()) + .build(); + + ImmutableExtendedExpression.Builder builder = + ImmutableExtendedExpression.builder() + .putAllReferredExpr(indexToExpressionMap) + .baseSchema(namedStruct); + + System.out.println( + "JsonFormat.printer().print(getFilterExtendedExpression): " + + JsonFormat.printer() + .print(new ExtendedExpressionProtoConverter().toProto(builder.build()))); + } +} diff --git a/core/src/test/java/io/substrait/extended/expression/ProtoExtendedExpressionConverterTest.java b/core/src/test/java/io/substrait/extended/expression/ProtoExtendedExpressionConverterTest.java new file mode 100644 index 000000000..1c1fa40b5 --- /dev/null +++ b/core/src/test/java/io/substrait/extended/expression/ProtoExtendedExpressionConverterTest.java @@ -0,0 +1,94 @@ +package io.substrait.extended.expression; + +import com.google.protobuf.util.JsonFormat; +import io.substrait.TestBase; +import io.substrait.expression.Expression; +import io.substrait.expression.ExpressionCreator; +import io.substrait.expression.FieldReference; +import io.substrait.expression.ImmutableFieldReference; +import io.substrait.proto.ExtendedExpression; +import io.substrait.type.ImmutableNamedStruct; +import io.substrait.type.NamedStruct; +import io.substrait.type.Type; +import io.substrait.type.TypeCreator; +import java.io.IOException; +import java.util.*; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +public class ProtoExtendedExpressionConverterTest extends TestBase { + @Test + public void fromTest() throws IOException { + Optional equal = + defaultExtensionCollection.scalarFunctions().stream() + .filter( + s -> { + System.out.println(":>>>>"); + System.out.println(s); + System.out.println(s.uri()); + System.out.println(s.returnType()); + System.out.println(s.description()); + System.out.println("s.name(): " + s.name()); + System.out.println(s.key()); + return s.name().equalsIgnoreCase("add"); + }) + .findFirst() + .map( + declaration -> { + System.out.println("declaration: " + declaration); + System.out.println("declaration.name(): " + declaration.name()); + return ExpressionCreator.scalarFunction( + declaration, + TypeCreator.REQUIRED.BOOLEAN, + ImmutableFieldReference.builder() + .addSegments(FieldReference.StructField.of(0)) + .type(TypeCreator.REQUIRED.I32) + .build(), + ExpressionCreator.i32(false, 183)); + }); + + Map indexToExpressionMap = new HashMap<>(); + indexToExpressionMap.put(0, equal.get()); + List columnNames = Arrays.asList("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"); + /* + List dataTypes = + Arrays.asList( + TypeCreator.NULLABLE.I32, + TypeCreator.NULLABLE.STRING, + TypeCreator.NULLABLE.I32, + TypeCreator.NULLABLE.STRING); + NamedStruct namedStruct = + NamedStruct.of( + columnNames, Type.Struct.builder().fields(dataTypes).nullable(false).build()); + */ + ImmutableNamedStruct id = ImmutableNamedStruct.builder() + .addNames("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT") + .struct(Type.Struct.builder().nullable(false).addFields(TypeCreator.NULLABLE.I32, + TypeCreator.REQUIRED.STRING, + TypeCreator.REQUIRED.I32, + TypeCreator.REQUIRED.STRING).build()) + .build(); + + ImmutableExtendedExpression.Builder builder = + ImmutableExtendedExpression.builder() + .putAllReferredExpr(indexToExpressionMap) + .baseSchema(id); + + ExtendedExpression proto = new ExtendedExpressionProtoConverter().toProto(builder.build()); + + System.out.println("=======POJO 01======="); + System.out.println("xxxx: " + builder); + System.out.println("=======PROTO 02======="); + System.out.println("yyyy: " + JsonFormat.printer().print(proto)); + + System.out.println("=======POJO 03======="); + io.substrait.extended.expression.ExtendedExpression from = + new ProtoExtendedExpressionConverter().from(proto); + System.out.println("zzzz: " + from); + System.out.println("11111111"); + + + Assertions.assertEquals(from, builder.build()); + } +} diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java index fa61c5e81..2b35dffef 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java @@ -155,6 +155,8 @@ private ExtendedExpression executeInnerSQLExpression( .expressionBuilderMap() .forEach( (k, v) -> { + System.out.println("k->" + k); + System.out.println("v->" + v); functionArgumentList.add(FunctionArgument.newBuilder().setValue(v).build()); }); @@ -224,6 +226,109 @@ private ExtendedExpression executeInnerSQLExpression( extendedExpressionBuilder.setBaseSchema( namedStruct.toProto(new TypeProtoConverter(functionCollector))); + /* + builder.addAllExtensionUris(uris.values()); + builder.addAllExtensions(extensionList); + */ + + return extendedExpressionBuilder.build(); + } + + private ExtendedExpression executeInnerSQLExpressionPojo( + String sqlExpression, + SqlValidator validator, + CalciteCatalogReader catalogReader, + Map nameToTypeMap, + Map nameToNodeMap) + throws SqlParseException { + ExtendedExpression.Builder extendedExpressionBuilder = ExtendedExpression.newBuilder(); + ExtensionCollector functionCollector = new ExtensionCollector(); + RexNode rexNode = + sqlToRexNode(sqlExpression, validator, catalogReader, nameToTypeMap, nameToNodeMap); + ResulTraverseRowExpression result = new TraverseRexNode().getRowExpression(rexNode); + io.substrait.proto.Type output = + TypeCreator.NULLABLE.BOOLEAN.accept(new TypeProtoConverter(functionCollector)); + List functionArgumentList = new ArrayList<>(); + result + .expressionBuilderMap() + .forEach( + (k, v) -> { + System.out.println("k->" + k); + System.out.println("v->" + v); + functionArgumentList.add(FunctionArgument.newBuilder().setValue(v).build()); + }); + + ScalarFunction.Builder scalarFunctionBuilder = + ScalarFunction.newBuilder() + .setFunctionReference(1) // rel_01 + .setOutputType(output) + .addAllArguments(functionArgumentList); + + Expression.Builder expressionBuilder = + Expression.newBuilder().setScalarFunction(scalarFunctionBuilder); + + ExpressionReference.Builder expressionReferenceBuilder = + ExpressionReference.newBuilder() + .setExpression(expressionBuilder) + .addOutputNames(result.ref().getName()); + + extendedExpressionBuilder.addReferredExpr(0, expressionReferenceBuilder); + + io.substrait.expression.Expression.ScalarFunctionInvocation func = + (io.substrait.expression.Expression.ScalarFunctionInvocation) + rexNode.accept(rexExpressionConverter); + String declaration = func.declaration().key(); // values example: gt:any_any, add:i64_i64 + + // this is not mandatory to be defined; it is working without this definition. It is + // only created here to create a proto message that has the correct semantics + HashMap extensionUris = new HashMap<>(); + SimpleExtensionURI simpleExtensionURI; + try { + simpleExtensionURI = + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(1) // rel_02 + .setUri( + SimpleExtension.loadDefaults().scalarFunctions().stream() + .filter(s -> s.toString().equalsIgnoreCase(declaration)) + .findFirst() + .orElseThrow( + () -> + new IllegalArgumentException( + String.format("Failed to get URI resource for %s.", declaration))) + .uri()) + .build(); + } catch (IOException e) { + throw new RuntimeException(e); + } + extensionUris.put("uri", simpleExtensionURI); + + ArrayList extensions = new ArrayList<>(); + SimpleExtensionDeclaration extensionFunctionLowerThan = + SimpleExtensionDeclaration.newBuilder() + .setExtensionFunction( + SimpleExtensionDeclaration.ExtensionFunction.newBuilder() + .setFunctionAnchor(scalarFunctionBuilder.getFunctionReference()) // rel_01 + .setName(declaration) + .setExtensionUriReference(simpleExtensionURI.getExtensionUriAnchor())) // rel_02 + .build(); + extensions.add(extensionFunctionLowerThan); + + System.out.println( + "extendedExpressionBuilder.getExtensionUrisList(): " + + extendedExpressionBuilder.getExtensionUrisList()); + // adding it for semantic purposes, it is not mandatory or needed + extendedExpressionBuilder.addAllExtensionUris(extensionUris.values()); + extendedExpressionBuilder.addAllExtensions(extensions); + + NamedStruct namedStruct = TypeConverter.DEFAULT.toNamedStruct(nameToTypeMap); + extendedExpressionBuilder.setBaseSchema( + namedStruct.toProto(new TypeProtoConverter(functionCollector))); + + /* + builder.addAllExtensionUris(uris.values()); + builder.addAllExtensions(extensionList); + */ + return extendedExpressionBuilder.build(); } diff --git a/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java index 3bee0b61e..1c3742b52 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java @@ -4,11 +4,14 @@ import com.google.common.io.Resources; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.util.JsonFormat; +import io.substrait.extended.expression.ExtendedExpressionProtoConverter; +import io.substrait.extended.expression.ProtoExtendedExpressionConverter; import io.substrait.proto.ExtendedExpression; import java.io.IOException; import java.util.Arrays; import java.util.List; import org.apache.calcite.sql.parser.SqlParseException; +import org.junit.jupiter.api.Assertions; public class ExtendedExpressionTestBase { public static String asString(String resource) throws IOException { @@ -33,7 +36,7 @@ protected ExtendedExpression assertProtoExtendedExpressionRoundrip(String query, } protected ExtendedExpression assertProtoExtendedExpressionRoundrip( - String query, SqlToSubstrait s, List creates) throws SqlParseException { + String query, SqlToSubstrait s, List creates) throws SqlParseException, IOException { io.substrait.proto.ExtendedExpression protoExtendedExpression = s.executeSQLExpression(query, creates); @@ -41,6 +44,12 @@ protected ExtendedExpression assertProtoExtendedExpressionRoundrip( String ee = JsonFormat.printer().print(protoExtendedExpression); System.out.println("Proto Extended Expression: \n" + ee); + io.substrait.extended.expression.ExtendedExpression from = + new ProtoExtendedExpressionConverter().from(protoExtendedExpression); + + ExtendedExpression proto = new ExtendedExpressionProtoConverter().toProto(from); + + Assertions.assertEquals(proto, protoExtendedExpression); // FIXME! Implement test validation as the same as proto Plan implementation } catch (InvalidProtocolBufferException e) { throw new RuntimeException(e); diff --git a/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java index 8b0248ca4..5a407a936 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java @@ -10,7 +10,13 @@ public class SimpleExtendedExpressionsTest extends ExtendedExpressionTestBase { @Test public void filter() throws IOException, SqlParseException { ExtendedExpression extendedExpression = - assertProtoExtendedExpressionRoundrip("L_ORDERKEY > 10"); + assertProtoExtendedExpressionRoundrip("N_NATIONKEY > 10"); + } + + @Test + public void projection() throws IOException, SqlParseException { + ExtendedExpression extendedExpression = + assertProtoExtendedExpressionRoundrip("L_ORDERKEY + 10"); } @Test diff --git a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java index 03d5f63ad..0121fae9c 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java @@ -31,7 +31,7 @@ public class ExtendedExpressionIntegrationTest { @Test public void filterDataset() throws SqlParseException, IOException { URL resource = ClassLoaderUtil.getClassLoader().getResource("./tpch/data/nation.parquet"); - String sqlExpression = "N_NATIONKEY > 20"; + String sqlExpression = "N_REGIONKEY > 20"; ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768) .columns(Optional.empty()) diff --git a/isthmus/src/test/resources/tpch/schema.sql b/isthmus/src/test/resources/tpch/schema.sql index 81f6f927b..b8fb4cfd0 100644 --- a/isthmus/src/test/resources/tpch/schema.sql +++ b/isthmus/src/test/resources/tpch/schema.sql @@ -1,77 +1,6 @@ -CREATE TABLE PART ( - P_PARTKEY BIGINT NOT NULL, - P_NAME VARCHAR(55), - P_MFGR CHAR(25), - P_BRAND CHAR(10), - P_TYPE VARCHAR(25), - P_SIZE INTEGER, - P_CONTAINER CHAR(10), - P_RETAILPRICE DECIMAL, - P_COMMENT VARCHAR(23) -); -CREATE TABLE SUPPLIER ( - S_SUPPKEY BIGINT NOT NULL, - S_NAME CHAR(25), - S_ADDRESS VARCHAR(40), - S_NATIONKEY BIGINT NOT NULL, - S_PHONE CHAR(15), - S_ACCTBAL DECIMAL, - S_COMMENT VARCHAR(101) -); -CREATE TABLE PARTSUPP ( - PS_PARTKEY BIGINT NOT NULL, - PS_SUPPKEY BIGINT NOT NULL, - PS_AVAILQTY INTEGER, - PS_SUPPLYCOST DECIMAL, - PS_COMMENT VARCHAR(199) -); -CREATE TABLE CUSTOMER ( - C_CUSTKEY BIGINT NOT NULL, - C_NAME VARCHAR(25), - C_ADDRESS VARCHAR(40), - C_NATIONKEY BIGINT NOT NULL, - C_PHONE CHAR(15), - C_ACCTBAL DECIMAL, - C_MKTSEGMENT CHAR(10), - C_COMMENT VARCHAR(117) -); -CREATE TABLE ORDERS ( - O_ORDERKEY BIGINT NOT NULL, - O_CUSTKEY BIGINT NOT NULL, - O_ORDERSTATUS CHAR(1), - O_TOTALPRICE DECIMAL, - O_ORDERDATE DATE, - O_ORDERPRIORITY CHAR(15), - O_CLERK CHAR(15), - O_SHIPPRIORITY INTEGER, - O_COMMENT VARCHAR(79) -); -CREATE TABLE LINEITEM ( - L_ORDERKEY BIGINT NOT NULL, - L_PARTKEY BIGINT NOT NULL, - L_SUPPKEY BIGINT NOT NULL, - L_LINENUMBER INTEGER, - L_QUANTITY DECIMAL, - L_EXTENDEDPRICE DECIMAL, - L_DISCOUNT DECIMAL, - L_TAX DECIMAL, - L_RETURNFLAG CHAR(1), - L_LINESTATUS CHAR(1), - L_SHIPDATE DATE, - L_COMMITDATE DATE, - L_RECEIPTDATE DATE, - L_SHIPINSTRUCT CHAR(25), - L_SHIPMODE CHAR(10), - L_COMMENT VARCHAR(44) -); CREATE TABLE NATION ( N_NATIONKEY BIGINT NOT NULL, N_NAME CHAR(25), N_REGIONKEY BIGINT NOT NULL, N_COMMENT VARCHAR(152) ); -CREATE TABLE REGION ( - R_REGIONKEY BIGINT NOT NULL, - R_NAME CHAR(25), - R_COMMENT VARCHAR(152) -); From 75e4f48621a5ed589c484531d5f8d4409aa3b943 Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Fri, 24 Nov 2023 16:56:10 -0500 Subject: [PATCH 11/34] feat: enable support from/to pojo/protobuf for extended expressions --- .../expression/ExtendedExpression.java | 28 +++++++ .../ExtendedExpressionProtoConverter.java | 50 +++++++++++ .../ProtoExtendedExpressionConverter.java | 84 +++++++++++++++++++ .../extension/ExtensionCollector.java | 25 +++++- .../extension/ImmutableExtensionLookup.java | 13 +-- .../io/substrait/plan/ProtoPlanConverter.java | 5 +- .../ExtendedExpressionProtoConverterTest.java | 73 ++++++++++++++++ .../ProtoExtendedExpressionConverterTest.java | 80 ++++++++++++++++++ 8 files changed, 349 insertions(+), 9 deletions(-) create mode 100644 core/src/main/java/io/substrait/extended/expression/ExtendedExpression.java create mode 100644 core/src/main/java/io/substrait/extended/expression/ExtendedExpressionProtoConverter.java create mode 100644 core/src/main/java/io/substrait/extended/expression/ProtoExtendedExpressionConverter.java create mode 100644 core/src/test/java/io/substrait/extended/expression/ExtendedExpressionProtoConverterTest.java create mode 100644 core/src/test/java/io/substrait/extended/expression/ProtoExtendedExpressionConverterTest.java diff --git a/core/src/main/java/io/substrait/extended/expression/ExtendedExpression.java b/core/src/main/java/io/substrait/extended/expression/ExtendedExpression.java new file mode 100644 index 000000000..4f705e82c --- /dev/null +++ b/core/src/main/java/io/substrait/extended/expression/ExtendedExpression.java @@ -0,0 +1,28 @@ +package io.substrait.extended.expression; + +import io.substrait.expression.Expression; +import io.substrait.proto.AdvancedExtension; +import io.substrait.type.NamedStruct; +import java.util.List; +import java.util.Optional; +import org.immutables.value.Value; + +@Value.Immutable +public abstract class ExtendedExpression { + public abstract List getReferredExpr(); + + public abstract NamedStruct getBaseSchema(); + + public abstract List getExpectedTypeUrls(); + + // creating simple extensions, such as extensionURIs and extensions, is performed on the fly + + public abstract Optional getAdvancedExtension(); + + @Value.Immutable + public abstract static class ExpressionReference { + public abstract Expression getReferredExpr(); + + public abstract List getOutputNames(); + } +} diff --git a/core/src/main/java/io/substrait/extended/expression/ExtendedExpressionProtoConverter.java b/core/src/main/java/io/substrait/extended/expression/ExtendedExpressionProtoConverter.java new file mode 100644 index 000000000..6ace03df4 --- /dev/null +++ b/core/src/main/java/io/substrait/extended/expression/ExtendedExpressionProtoConverter.java @@ -0,0 +1,50 @@ +package io.substrait.extended.expression; + +import io.substrait.expression.Expression; +import io.substrait.expression.proto.ExpressionProtoConverter; +import io.substrait.extension.ExtensionCollector; +import io.substrait.proto.ExpressionReference; +import io.substrait.proto.ExtendedExpression; +import io.substrait.type.proto.TypeProtoConverter; + +/** + * Converts from {@link io.substrait.extended.expression.ExtendedExpression} to {@link + * ExtendedExpression} + */ +public class ExtendedExpressionProtoConverter { + public ExtendedExpression toProto( + io.substrait.extended.expression.ExtendedExpression extendedExpression) { + + ExtendedExpression.Builder extendedExpressionBuilder = ExtendedExpression.newBuilder(); + ExtensionCollector functionCollector = new ExtensionCollector(); + + final ExpressionProtoConverter expressionProtoConverter = + new ExpressionProtoConverter(functionCollector, null); + + for (io.substrait.extended.expression.ExtendedExpression.ExpressionReference + expressionReference : extendedExpression.getReferredExpr()) { + + io.substrait.proto.Expression expressionProto = + expressionProtoConverter.visit( + (Expression.ScalarFunctionInvocation) expressionReference.getReferredExpr()); + + ExpressionReference.Builder expressionReferenceBuilder = + ExpressionReference.newBuilder() + .setExpression(expressionProto) + .addAllOutputNames(expressionReference.getOutputNames()); + + extendedExpressionBuilder.addReferredExpr(expressionReferenceBuilder); + } + extendedExpressionBuilder.setBaseSchema( + extendedExpression.getBaseSchema().toProto(new TypeProtoConverter(functionCollector))); + + // the process of adding simple extensions, such as extensionURIs and extensions, is handled on + // the fly + functionCollector.addExtensionsToExtendedExpression(extendedExpressionBuilder); + if (extendedExpression.getAdvancedExtension().isPresent()) { + extendedExpressionBuilder.setAdvancedExtensions( + extendedExpression.getAdvancedExtension().get()); + } + return extendedExpressionBuilder.build(); + } +} diff --git a/core/src/main/java/io/substrait/extended/expression/ProtoExtendedExpressionConverter.java b/core/src/main/java/io/substrait/extended/expression/ProtoExtendedExpressionConverter.java new file mode 100644 index 000000000..14c82b5ac --- /dev/null +++ b/core/src/main/java/io/substrait/extended/expression/ProtoExtendedExpressionConverter.java @@ -0,0 +1,84 @@ +package io.substrait.extended.expression; + +import io.substrait.expression.Expression; +import io.substrait.expression.proto.ProtoExpressionConverter; +import io.substrait.extension.*; +import io.substrait.proto.ExpressionReference; +import io.substrait.proto.NamedStruct; +import io.substrait.type.ImmutableNamedStruct; +import io.substrait.type.Type; +import io.substrait.type.proto.ProtoTypeConverter; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +/** Converts from {@link io.substrait.proto.ExtendedExpression} to {@link ExtendedExpression} */ +public class ProtoExtendedExpressionConverter { + private final SimpleExtension.ExtensionCollection extensionCollection; + + public ProtoExtendedExpressionConverter() throws IOException { + this(SimpleExtension.loadDefaults()); + } + + public ProtoExtendedExpressionConverter(SimpleExtension.ExtensionCollection extensionCollection) { + this.extensionCollection = extensionCollection; + } + + private final ProtoTypeConverter protoTypeConverter = + new ProtoTypeConverter( + new ExtensionCollector(), ImmutableSimpleExtension.ExtensionCollection.builder().build()); + + public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExpression) { + // fill in simple extension information through a discovery in the current proto-extended + // expression + ExtensionLookup functionLookup = + ImmutableExtensionLookup.builder() + .from(extendedExpression.getExtensionUrisList(), extendedExpression.getExtensionsList()) + .build(); + + NamedStruct baseSchemaProto = extendedExpression.getBaseSchema(); + io.substrait.type.NamedStruct namedStruct = convertNamedStrutProtoToPojo(baseSchemaProto); + + ProtoExpressionConverter protoExpressionConverter = + new ProtoExpressionConverter( + functionLookup, this.extensionCollection, namedStruct.struct(), null); + + List expressionReferences = new ArrayList<>(); + for (ExpressionReference expressionReference : extendedExpression.getReferredExprList()) { + Expression expressionPojo = + protoExpressionConverter.from(expressionReference.getExpression()); + expressionReferences.add( + ImmutableExpressionReference.builder() + .referredExpr(expressionPojo) + .addAllOutputNames(expressionReference.getOutputNamesList()) + .build()); + } + + ImmutableExtendedExpression.Builder builder = + ImmutableExtendedExpression.builder() + .referredExpr(expressionReferences) + .advancedExtension( + Optional.ofNullable( + extendedExpression.hasAdvancedExtensions() + ? extendedExpression.getAdvancedExtensions() + : null)) + .baseSchema(namedStruct); + return builder.build(); + } + + private io.substrait.type.NamedStruct convertNamedStrutProtoToPojo(NamedStruct namedStruct) { + var struct = namedStruct.getStruct(); + return ImmutableNamedStruct.builder() + .names(namedStruct.getNamesList()) + .struct( + Type.Struct.builder() + .fields( + struct.getTypesList().stream() + .map(protoTypeConverter::from) + .collect(java.util.stream.Collectors.toList())) + .nullable(ProtoTypeConverter.isNullable(struct.getNullability())) + .build()) + .build(); + } +} diff --git a/core/src/main/java/io/substrait/extension/ExtensionCollector.java b/core/src/main/java/io/substrait/extension/ExtensionCollector.java index bcdd969d4..402a8c94c 100644 --- a/core/src/main/java/io/substrait/extension/ExtensionCollector.java +++ b/core/src/main/java/io/substrait/extension/ExtensionCollector.java @@ -1,5 +1,7 @@ package io.substrait.extension; +import com.github.bsideup.jabel.Desugar; +import io.substrait.proto.ExtendedExpression; import io.substrait.proto.Plan; import io.substrait.proto.SimpleExtensionDeclaration; import io.substrait.proto.SimpleExtensionURI; @@ -51,6 +53,20 @@ public int getTypeReference(SimpleExtension.TypeAnchor typeAnchor) { } public void addExtensionsToPlan(Plan.Builder builder) { + SimpleExtensions simpleExtensions = getExtensions(); + + builder.addAllExtensionUris(simpleExtensions.uris().values()); + builder.addAllExtensions(simpleExtensions.extensionList()); + } + + public void addExtensionsToExtendedExpression(ExtendedExpression.Builder builder) { + SimpleExtensions simpleExtensions = getExtensions(); + + builder.addAllExtensionUris(simpleExtensions.uris().values()); + builder.addAllExtensions(simpleExtensions.extensionList()); + } + + private SimpleExtensions getExtensions() { var uriPos = new AtomicInteger(1); var uris = new HashMap(); @@ -93,11 +109,14 @@ public void addExtensionsToPlan(Plan.Builder builder) { .build(); extensionList.add(decl); } - - builder.addAllExtensionUris(uris.values()); - builder.addAllExtensions(extensionList); + return new SimpleExtensions(uris, extensionList); } + @Desugar + private record SimpleExtensions( + HashMap uris, + ArrayList extensionList) {} + /** We don't depend on guava... */ private static class BidiMap { private final Map forwardMap; diff --git a/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java b/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java index 2cab2cce8..c88bafc1c 100644 --- a/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java +++ b/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java @@ -1,9 +1,10 @@ package io.substrait.extension; -import io.substrait.proto.Plan; import io.substrait.proto.SimpleExtensionDeclaration; +import io.substrait.proto.SimpleExtensionURI; import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; /** @@ -30,14 +31,16 @@ public static class Builder { private final Map functionMap = new HashMap<>(); private final Map typeMap = new HashMap<>(); - public Builder from(Plan p) { + public Builder from( + List simpleExtensionURIs, + List simpleExtensionDeclarations) { Map namespaceMap = new HashMap<>(); - for (var extension : p.getExtensionUrisList()) { + for (var extension : simpleExtensionURIs) { namespaceMap.put(extension.getExtensionUriAnchor(), extension.getUri()); } // Add all functions used in plan to the functionMap - for (var extension : p.getExtensionsList()) { + for (var extension : simpleExtensionDeclarations) { if (!extension.hasExtensionFunction()) { continue; } @@ -54,7 +57,7 @@ public Builder from(Plan p) { } // Add all types used in plan to the typeMap - for (var extension : p.getExtensionsList()) { + for (var extension : simpleExtensionDeclarations) { if (!extension.hasExtensionType()) { continue; } diff --git a/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java b/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java index be4f4ad9f..7222eb7ed 100644 --- a/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java +++ b/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java @@ -32,7 +32,10 @@ protected ProtoRelConverter getProtoRelConverter(ExtensionLookup functionLookup) } public Plan from(io.substrait.proto.Plan plan) { - ExtensionLookup functionLookup = ImmutableExtensionLookup.builder().from(plan).build(); + ExtensionLookup functionLookup = + ImmutableExtensionLookup.builder() + .from(plan.getExtensionUrisList(), plan.getExtensionsList()) + .build(); ProtoRelConverter relConverter = getProtoRelConverter(functionLookup); List roots = new ArrayList<>(); for (PlanRel planRel : plan.getRelationsList()) { diff --git a/core/src/test/java/io/substrait/extended/expression/ExtendedExpressionProtoConverterTest.java b/core/src/test/java/io/substrait/extended/expression/ExtendedExpressionProtoConverterTest.java new file mode 100644 index 000000000..20079e24f --- /dev/null +++ b/core/src/test/java/io/substrait/extended/expression/ExtendedExpressionProtoConverterTest.java @@ -0,0 +1,73 @@ +package io.substrait.extended.expression; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import io.substrait.TestBase; +import io.substrait.expression.Expression; +import io.substrait.expression.ExpressionCreator; +import io.substrait.expression.FieldReference; +import io.substrait.expression.ImmutableFieldReference; +import io.substrait.type.ImmutableNamedStruct; +import io.substrait.type.Type; +import io.substrait.type.TypeCreator; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import org.junit.jupiter.api.Test; + +public class ExtendedExpressionProtoConverterTest extends TestBase { + @Test + public void toProtoTest() { + // create predefined POJO extended expression + Optional scalarFunctionExpression = + defaultExtensionCollection.scalarFunctions().stream() + .filter(s -> s.name().equalsIgnoreCase("add")) + .findFirst() + .map( + declaration -> + ExpressionCreator.scalarFunction( + declaration, + TypeCreator.REQUIRED.BOOLEAN, + ImmutableFieldReference.builder() + .addSegments(FieldReference.StructField.of(0)) + .type(TypeCreator.REQUIRED.decimal(10, 2)) + .build(), + ExpressionCreator.i32(false, 183))); + + ImmutableExpressionReference expressionReference = + ImmutableExpressionReference.builder() + .referredExpr(scalarFunctionExpression.get()) + .addOutputNames("new-column") + .build(); + + List expressionReferences = new ArrayList<>(); + expressionReferences.add(expressionReference); + + ImmutableNamedStruct namedStruct = + ImmutableNamedStruct.builder() + .addNames("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT") + .struct( + Type.Struct.builder() + .nullable(false) + .addFields( + TypeCreator.NULLABLE.decimal(10, 2), + TypeCreator.REQUIRED.STRING, + TypeCreator.REQUIRED.decimal(10, 2), + TypeCreator.REQUIRED.STRING) + .build()) + .build(); + + ImmutableExtendedExpression.Builder extendedExpression = + ImmutableExtendedExpression.builder() + .referredExpr(expressionReferences) + .baseSchema(namedStruct); + + // convert POJO extended expression into PROTOBUF extended expression + io.substrait.proto.ExtendedExpression proto = + new ExtendedExpressionProtoConverter().toProto(extendedExpression.build()); + + assertEquals( + "/functions_arithmetic_decimal.yaml", proto.getExtensionUrisList().get(0).getUri()); + assertEquals("add:dec_dec", proto.getExtensionsList().get(0).getExtensionFunction().getName()); + } +} diff --git a/core/src/test/java/io/substrait/extended/expression/ProtoExtendedExpressionConverterTest.java b/core/src/test/java/io/substrait/extended/expression/ProtoExtendedExpressionConverterTest.java new file mode 100644 index 000000000..9ab84f274 --- /dev/null +++ b/core/src/test/java/io/substrait/extended/expression/ProtoExtendedExpressionConverterTest.java @@ -0,0 +1,80 @@ +package io.substrait.extended.expression; + +import io.substrait.TestBase; +import io.substrait.expression.Expression; +import io.substrait.expression.ExpressionCreator; +import io.substrait.expression.FieldReference; +import io.substrait.expression.ImmutableFieldReference; +import io.substrait.proto.ExtendedExpression; +import io.substrait.type.ImmutableNamedStruct; +import io.substrait.type.Type; +import io.substrait.type.TypeCreator; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +public class ProtoExtendedExpressionConverterTest extends TestBase { + @Test + public void fromTest() throws IOException { + // create predefined POJO extended expression + Optional scalarFunctionExpression = + defaultExtensionCollection.scalarFunctions().stream() + .filter(s -> s.name().equalsIgnoreCase("add")) + .findFirst() + .map( + declaration -> + ExpressionCreator.scalarFunction( + declaration, + TypeCreator.REQUIRED.BOOLEAN, + ImmutableFieldReference.builder() + .addSegments(FieldReference.StructField.of(0)) + .type(TypeCreator.REQUIRED.decimal(10, 2)) + .build(), + ExpressionCreator.i32(false, 183))); + + ImmutableExpressionReference expressionReference = + ImmutableExpressionReference.builder() + .referredExpr(scalarFunctionExpression.get()) + .addOutputNames("new-column") + .build(); + + List + expressionReferences = new ArrayList<>(); + expressionReferences.add(expressionReference); + + ImmutableNamedStruct namedStruct = + ImmutableNamedStruct.builder() + .addNames("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT") + .struct( + Type.Struct.builder() + .nullable(false) + .addFields( + TypeCreator.REQUIRED.decimal(10, 2), + TypeCreator.REQUIRED.STRING, + TypeCreator.REQUIRED.decimal(10, 2), + TypeCreator.REQUIRED.STRING) + .build()) + .build(); + + // pojo initial extended expression + ImmutableExtendedExpression extendedExpressionPojoInitial = + ImmutableExtendedExpression.builder() + .referredExpr(expressionReferences) + .baseSchema(namedStruct) + .build(); + + // proto extended expression + ExtendedExpression extendedExpressionProto = + new ExtendedExpressionProtoConverter().toProto(extendedExpressionPojoInitial); + + // pojo final extended expression + io.substrait.extended.expression.ExtendedExpression extendedExpressionPojoFinal = + new ProtoExtendedExpressionConverter().from(extendedExpressionProto); + + // validate extended expression pojo initial equals to final roundtrip + Assertions.assertEquals(extendedExpressionPojoInitial, extendedExpressionPojoFinal); + } +} From 5adc79fe735d2cf0107a04bdacda32a1c4404279 Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Fri, 24 Nov 2023 18:53:24 -0500 Subject: [PATCH 12/34] fix: consume core module for proto/pojo conversions --- core/build.gradle.kts | 5 - .../expression/ExpressionReference.java | 11 - .../expression/ExtendedExpression.java | 12 +- .../ExtendedExpressionProtoConverter.java | 206 ++------------ .../extended/expression/MyCliente.java | 75 ----- .../ProtoExtendedExpressionConverter.java | 157 +++------- .../extension/ExtensionCollector.java | 70 ++--- .../extension/ImmutableExtensionLookup.java | 57 +--- .../io/substrait/plan/PlanProtoConverter.java | 4 - .../io/substrait/plan/ProtoPlanConverter.java | 5 +- .../ExtendedExpressionProtoConverterTest.java | 99 ++++--- .../ProtoExtendedExpressionConverterTest.java | 114 ++++---- .../substrait/isthmus/SqlConverterBase.java | 19 +- .../io/substrait/isthmus/SqlToSubstrait.java | 269 ++---------------- .../io/substrait/isthmus/SubstraitToSql.java | 5 +- .../isthmus/ExtendedExpressionTestBase.java | 28 +- .../SimpleExtendedExpressionsTest.java | 2 +- .../ExtendedExpressionIntegrationTest.java | 61 +--- isthmus/src/test/resources/tpch/schema.sql | 71 +++++ 19 files changed, 329 insertions(+), 941 deletions(-) delete mode 100644 core/src/main/java/io/substrait/extended/expression/ExpressionReference.java delete mode 100644 core/src/main/java/io/substrait/extended/expression/MyCliente.java diff --git a/core/build.gradle.kts b/core/build.gradle.kts index c06a00e86..6b8cfac66 100644 --- a/core/build.gradle.kts +++ b/core/build.gradle.kts @@ -85,11 +85,6 @@ dependencies { compileOnly("org.immutables:value-annotations:2.8.8") annotationProcessor("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") compileOnly("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") - - implementation("com.google.protobuf:protobuf-java-util:3.17.3") { - exclude("com.google.guava", "guava") - .because("Brings in Guava for Android, which we don't want (and breaks multimaps).") - } } java { diff --git a/core/src/main/java/io/substrait/extended/expression/ExpressionReference.java b/core/src/main/java/io/substrait/extended/expression/ExpressionReference.java deleted file mode 100644 index 2214f0438..000000000 --- a/core/src/main/java/io/substrait/extended/expression/ExpressionReference.java +++ /dev/null @@ -1,11 +0,0 @@ -package io.substrait.extended.expression; - -import io.substrait.expression.Expression; -import org.immutables.value.Value; - -@Value.Immutable -public abstract class ExpressionReference { - public abstract Expression getExpression(); - - public abstract String getOutputNames(); -} diff --git a/core/src/main/java/io/substrait/extended/expression/ExtendedExpression.java b/core/src/main/java/io/substrait/extended/expression/ExtendedExpression.java index e0a5c03f9..4f705e82c 100644 --- a/core/src/main/java/io/substrait/extended/expression/ExtendedExpression.java +++ b/core/src/main/java/io/substrait/extended/expression/ExtendedExpression.java @@ -4,17 +4,25 @@ import io.substrait.proto.AdvancedExtension; import io.substrait.type.NamedStruct; import java.util.List; -import java.util.Map; import java.util.Optional; import org.immutables.value.Value; @Value.Immutable public abstract class ExtendedExpression { - public abstract Map getReferredExpr(); + public abstract List getReferredExpr(); public abstract NamedStruct getBaseSchema(); public abstract List getExpectedTypeUrls(); + // creating simple extensions, such as extensionURIs and extensions, is performed on the fly + public abstract Optional getAdvancedExtension(); + + @Value.Immutable + public abstract static class ExpressionReference { + public abstract Expression getReferredExpr(); + + public abstract List getOutputNames(); + } } diff --git a/core/src/main/java/io/substrait/extended/expression/ExtendedExpressionProtoConverter.java b/core/src/main/java/io/substrait/extended/expression/ExtendedExpressionProtoConverter.java index e4015f3cf..cffdfefd0 100644 --- a/core/src/main/java/io/substrait/extended/expression/ExtendedExpressionProtoConverter.java +++ b/core/src/main/java/io/substrait/extended/expression/ExtendedExpressionProtoConverter.java @@ -1,22 +1,19 @@ package io.substrait.extended.expression; -import com.google.protobuf.util.JsonFormat; import io.substrait.expression.*; import io.substrait.expression.proto.ExpressionProtoConverter; import io.substrait.extension.ExtensionCollector; -import io.substrait.extension.SimpleExtension; import io.substrait.proto.ExpressionReference; import io.substrait.proto.ExtendedExpression; -import io.substrait.type.NamedStruct; -import io.substrait.type.Type; -import io.substrait.type.TypeCreator; import io.substrait.type.proto.TypeProtoConverter; -import java.io.IOException; -import java.util.*; +/** + * Converts from {@link io.substrait.extended.expression.ExtendedExpression} to {@link + * io.substrait.proto.ExtendedExpression} + */ public class ExtendedExpressionProtoConverter { public ExtendedExpression toProto( - io.substrait.extended.expression.ExtendedExpression extendedExpressionPojo) { + io.substrait.extended.expression.ExtendedExpression extendedExpression) { ExtendedExpression.Builder extendedExpressionBuilder = ExtendedExpression.newBuilder(); ExtensionCollector functionCollector = new ExtensionCollector(); @@ -24,187 +21,30 @@ public ExtendedExpression toProto( final ExpressionProtoConverter expressionProtoConverter = new ExpressionProtoConverter(functionCollector, null); - // convert expression pojo into expression protobuf - io.substrait.proto.Expression expressionProto = - expressionProtoConverter.visit( - (Expression.ScalarFunctionInvocation) extendedExpressionPojo.getReferredExpr().get(0)); + for (io.substrait.extended.expression.ExtendedExpression.ExpressionReference + expressionReference : extendedExpression.getReferredExpr()) { - ExpressionReference.Builder expressionReferenceBuilder = - ExpressionReference.newBuilder().setExpression(expressionProto).addOutputNames("column-01"); + io.substrait.proto.Expression expressionProto = + expressionProtoConverter.visit( + (Expression.ScalarFunctionInvocation) expressionReference.getReferredExpr()); - extendedExpressionBuilder.addReferredExpr(0, expressionReferenceBuilder); - extendedExpressionBuilder.setBaseSchema( - extendedExpressionPojo.getBaseSchema().toProto(new TypeProtoConverter(functionCollector))); + ExpressionReference.Builder expressionReferenceBuilder = + ExpressionReference.newBuilder() + .setExpression(expressionProto) + .addAllOutputNames(expressionReference.getOutputNames()); + extendedExpressionBuilder.addReferredExpr(expressionReferenceBuilder); + } + extendedExpressionBuilder.setBaseSchema( + extendedExpression.getBaseSchema().toProto(new TypeProtoConverter(functionCollector))); - functionCollector.addExtensionsToPlan(extendedExpressionBuilder); - if (extendedExpressionPojo.getAdvancedExtension().isPresent()) { + // the process of adding simple extensions, such as extensionURIs and extensions, is handled on + // the fly + functionCollector.addExtensionsToExtendedExpression(extendedExpressionBuilder); + if (extendedExpression.getAdvancedExtension().isPresent()) { extendedExpressionBuilder.setAdvancedExtensions( - extendedExpressionPojo.getAdvancedExtension().get()); + extendedExpression.getAdvancedExtension().get()); } return extendedExpressionBuilder.build(); } - - public static void main(String[] args) throws IOException { - SimpleExtension.ExtensionCollection defaultExtensionCollection = SimpleExtension.loadDefaults(); - System.out.println( - "defaultExtensionCollection.scalarFunctions(): " - + defaultExtensionCollection.scalarFunctions()); - System.out.println( - "defaultExtensionCollection.windowFunctions(): " - + defaultExtensionCollection.windowFunctions()); - System.out.println( - "defaultExtensionCollection.aggregateFunctions(): " - + defaultExtensionCollection.aggregateFunctions()); - - Optional equal = - defaultExtensionCollection.scalarFunctions().stream() - .filter( - s -> { - return s.name().equalsIgnoreCase("add"); - }) - .findFirst() - .map( - declaration -> - ExpressionCreator.scalarFunction( - declaration, - TypeCreator.REQUIRED.BOOLEAN, - ImmutableFieldReference.builder() - .addSegments(FieldReference.StructField.of(0)) - .type(TypeCreator.REQUIRED.I32) - .build(), - ExpressionCreator.i32(false, 183))); - - Map indexToExpressionMap = new HashMap<>(); - indexToExpressionMap.put(0, equal.get()); - List columnNames = Arrays.asList("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"); - List dataTypes = - Arrays.asList( - TypeCreator.NULLABLE.I32, - TypeCreator.NULLABLE.STRING, - TypeCreator.NULLABLE.I32, - TypeCreator.NULLABLE.STRING); - NamedStruct namedStruct = - NamedStruct.of( - columnNames, Type.Struct.builder().fields(dataTypes).nullable(false).build()); - ImmutableExtendedExpression.Builder builder = - ImmutableExtendedExpression.builder() - .putAllReferredExpr(indexToExpressionMap) - .baseSchema(namedStruct); - - ExtendedExpression proto = new ExtendedExpressionProtoConverter().toProto(builder.build()); - - System.out.println( - "JsonFormat.printer().print(getFilterExtendedExpression): " - + JsonFormat.printer().print(proto)); - } - - public static ExtendedExpression createExtendedExpression( - io.substrait.expression.Expression.ScalarFunctionInvocation expr) { - ExtendedExpression.Builder extendedExpressionBuilder = ExtendedExpression.newBuilder(); - - io.substrait.proto.Expression expression = new ExpressionProtoConverter(null, null).visit(expr); - ExpressionReference.Builder expressionReferenceBuilder = - ExpressionReference.newBuilder() - .setExpression(expression.toBuilder()) - .addOutputNames("col-01"); - - extendedExpressionBuilder.addReferredExpr(0, expressionReferenceBuilder); - - return extendedExpressionBuilder.build(); - } - - public static void createExtendedExpressionManually() { - - Map nameToExpressionMap = new HashMap<>(); - ImmutableExpression.I32Literal build = Expression.I32Literal.builder().value(10).build(); - nameToExpressionMap.put("out_01", build); - - List expressionList = new ArrayList<>(); - expressionList.add(0, null); - - // nation table - List columnNames = Arrays.asList("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"); - List dataTypes = - Arrays.asList( - TypeCreator.NULLABLE.I32, - TypeCreator.NULLABLE.STRING, - TypeCreator.NULLABLE.I32, - TypeCreator.NULLABLE.STRING); - NamedStruct namedStruct = - NamedStruct.of( - columnNames, Type.Struct.builder().fields(dataTypes).nullable(false).build()); - - ExtensionCollector functionCollector = new ExtensionCollector(); - - new ExpressionProtoConverter(new ExtensionCollector(), null); - - FunctionArg functionArg = - new FunctionArg() { - @Override - public R accept( - SimpleExtension.Function fnDef, int argIdx, FuncArgVisitor fnArgVisitor) - throws E { - return null; - } - }; - - // var argVisitor = FunctionArg.toProto(new TypeProtoConverter(new ExtensionCollector()), this); - - /* - - // FIXME! setFunctionReference, addArguments(index: 0, 1) - io.substrait.proto.Expression.Builder expressionBuilder = - io.substrait.proto.Expression.newBuilder() - .setScalarFunction( - io.substrait.proto.Expression.ScalarFunction.newBuilder() - .setFunctionReference(1) - .setOutputType(output) - .addArguments( - 0, - FunctionArgument.newBuilder().setValue(result.referenceBuilder())) - .addArguments( - 1, - FunctionArgument.newBuilder() - .setValue(result.expressionBuilderLiteral()))); - io.substrait.proto.ExpressionReference.Builder expressionReferenceBuilder = - ExpressionReference.newBuilder() - .setExpression(expressionBuilder) - .addOutputNames(result.ref().getName()); - - */ - - /* - - io.substrait.extended.expression.ExtendedExpression extendedExpression = new io.substrait.extended.expression.ExtendedExpression() { - @Override - public List getReferredExpr() { - io.substrait.extended.expression.ExpressionReference - - @Override - public NamedStruct getBaseSchema() { - return null; - } - - @Override - public List getExpectedTypeUrls() { - return null; - } - - @Override - public Optional getAdvancedExtension() { - return Optional.empty(); - } - }; - - System.out.println("inicio"); - System.out.println(extendedExpression.getReferredExpr().get(0)); - System.out.println(extendedExpression.getReferredExpr().get(0).getType()); - System.out.println("fin"); - - ExpressionReferenceOrBuilder - - */ - - } } diff --git a/core/src/main/java/io/substrait/extended/expression/MyCliente.java b/core/src/main/java/io/substrait/extended/expression/MyCliente.java deleted file mode 100644 index 0ec8a90ed..000000000 --- a/core/src/main/java/io/substrait/extended/expression/MyCliente.java +++ /dev/null @@ -1,75 +0,0 @@ -package io.substrait.extended.expression; - -import com.google.protobuf.util.JsonFormat; -import io.substrait.expression.Expression; -import io.substrait.expression.ExpressionCreator; -import io.substrait.expression.FieldReference; -import io.substrait.expression.ImmutableFieldReference; -import io.substrait.extension.SimpleExtension; -import io.substrait.type.ImmutableNamedStruct; -import io.substrait.type.NamedStruct; -import io.substrait.type.Type; -import io.substrait.type.TypeCreator; -import java.io.IOException; -import java.util.*; - -public class MyCliente { - public static void main(String[] args) throws IOException { - SimpleExtension.ExtensionCollection defaultExtensionCollection = SimpleExtension.loadDefaults(); - Optional equal = - defaultExtensionCollection.scalarFunctions().stream() - .filter( - s -> { - System.out.println(":>>>>"); - System.out.println(s); - System.out.println(s.uri()); - System.out.println(s.returnType()); - System.out.println(s.description()); - System.out.println("s.name(): " + s.name()); - System.out.println(s.key()); - return s.name().equalsIgnoreCase("add"); - }) - .findFirst() - .map( - declaration -> { - System.out.println("declaration: " + declaration); - System.out.println("declaration.name(): " + declaration.name()); - return ExpressionCreator.scalarFunction( - declaration, - TypeCreator.REQUIRED.BOOLEAN, - ImmutableFieldReference.builder() - .addSegments(FieldReference.StructField.of(0)) - .type(TypeCreator.REQUIRED.I32) - .build(), - ExpressionCreator.i32(false, 183)); - }); - - Map indexToExpressionMap = new HashMap<>(); - indexToExpressionMap.put(0, equal.get()); - List columnNames = Arrays.asList("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"); - List dataTypes = - Arrays.asList( - TypeCreator.NULLABLE.I32, - TypeCreator.NULLABLE.STRING, - TypeCreator.NULLABLE.I32, - TypeCreator.NULLABLE.STRING); - NamedStruct namedStruct = - NamedStruct.of( - columnNames, Type.Struct.builder().fields(dataTypes).nullable(false).build()); - - ImmutableNamedStruct.builder() - .addNames("id") - .struct(Type.Struct.builder().nullable(false).addFields(TypeCreator.REQUIRED.I32).build()) - .build(); - - ImmutableExtendedExpression.Builder builder = - ImmutableExtendedExpression.builder() - .putAllReferredExpr(indexToExpressionMap) - .baseSchema(namedStruct); - - System.out.println( - "JsonFormat.printer().print(getFilterExtendedExpression): " - + JsonFormat.printer() - .print(new ExtendedExpressionProtoConverter().toProto(builder.build()))); - } -} diff --git a/core/src/main/java/io/substrait/extended/expression/ProtoExtendedExpressionConverter.java b/core/src/main/java/io/substrait/extended/expression/ProtoExtendedExpressionConverter.java index ed1fc377b..41fc1bef7 100644 --- a/core/src/main/java/io/substrait/extended/expression/ProtoExtendedExpressionConverter.java +++ b/core/src/main/java/io/substrait/extended/expression/ProtoExtendedExpressionConverter.java @@ -5,38 +5,18 @@ import io.substrait.extension.*; import io.substrait.proto.ExpressionReference; import io.substrait.proto.NamedStruct; -import io.substrait.relation.ProtoRelConverter; import io.substrait.type.ImmutableNamedStruct; import io.substrait.type.Type; import io.substrait.type.proto.ProtoTypeConverter; import java.io.IOException; -import java.util.HashMap; -import java.util.Map; -import java.util.Optional; +import java.util.*; +/** + * Converts from {@link io.substrait.proto.ExtendedExpression} to {@link + * io.substrait.extended.expression.ExtendedExpression} + */ public class ProtoExtendedExpressionConverter { - private ExtensionCollector lookup = new ExtensionCollector(); - private ProtoTypeConverter protoTypeConverter = - new ProtoTypeConverter( - lookup, ImmutableSimpleExtension.ExtensionCollection.builder().build()); - - private ProtoExpressionConverter getPprotoExpressionConverter(ExtensionLookup functionLookup) { - return new ProtoExpressionConverter( - functionLookup, - this.extensionCollection, - null, - null); - } - - private ProtoExpressionConverter getPprotoExpressionConverter(ExtensionLookup functionLookup, io.substrait.type.NamedStruct namedStruct) { - return new ProtoExpressionConverter( - functionLookup, - this.extensionCollection, - namedStruct.struct(), - null); - } - - protected final SimpleExtension.ExtensionCollection extensionCollection; + private final SimpleExtension.ExtensionCollection extensionCollection; public ProtoExtendedExpressionConverter() throws IOException { this(SimpleExtension.loadDefaults()); @@ -46,120 +26,49 @@ public ProtoExtendedExpressionConverter(SimpleExtension.ExtensionCollection exte this.extensionCollection = extensionCollection; } - protected ProtoRelConverter getProtoRelConverter(ExtensionLookup functionLookup) { - return new ProtoRelConverter(functionLookup, this.extensionCollection); - } + private final ProtoTypeConverter protoTypeConverter = + new ProtoTypeConverter( + new ExtensionCollector(), ImmutableSimpleExtension.ExtensionCollection.builder().build()); - public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExpressionProto) { + public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExpression) { + // fill in simple extension information through a discovery in the current proto-extended + // expression ExtensionLookup functionLookup = - ImmutableExtensionLookup.builder().from(extendedExpressionProto).build(); - + ImmutableExtensionLookup.builder() + .from(extendedExpression.getExtensionUrisList(), extendedExpression.getExtensionsList()) + .build(); - // para struct - NamedStruct baseSchema = extendedExpressionProto.getBaseSchema(); - io.substrait.type.NamedStruct namedStruct = newNamedStruct(baseSchema); - - System.out.println("namedStruct"); - System.out.println(namedStruct); + NamedStruct baseSchemaProto = extendedExpression.getBaseSchema(); + io.substrait.type.NamedStruct namedStruct = convertNamedStrutProtoToPojo(baseSchemaProto); ProtoExpressionConverter protoExpressionConverter = - getPprotoExpressionConverter(functionLookup, namedStruct); - - Map indexToExpressionMap = new HashMap<>(); - for (ExpressionReference expressionReference : extendedExpressionProto.getReferredExprList()) { - System.out.println( - "expressionReference.getExpression(): " + expressionReference.getExpression()); - indexToExpressionMap.put( - 0, protoExpressionConverter.from(expressionReference.getExpression())); + new ProtoExpressionConverter( + functionLookup, this.extensionCollection, namedStruct.struct(), null); + + List expressionReferences = new ArrayList<>(); + for (ExpressionReference expressionReference : extendedExpression.getReferredExprList()) { + Expression expressionPojo = + protoExpressionConverter.from(expressionReference.getExpression()); + expressionReferences.add( + ImmutableExpressionReference.builder() + .referredExpr(expressionPojo) + .addAllOutputNames(expressionReference.getOutputNamesList()) + .build()); } - // para struct - /* - NamedStruct baseSchema = extendedExpressionProto.getBaseSchema(); - io.substrait.type.NamedStruct namedStruct = newNamedStruct(baseSchema); - - System.out.println("namedStruct"); - System.out.println(namedStruct); - - */ - ImmutableExtendedExpression.Builder builder = ImmutableExtendedExpression.builder() - .putAllReferredExpr(indexToExpressionMap) + .referredExpr(expressionReferences) .advancedExtension( Optional.ofNullable( - extendedExpressionProto.hasAdvancedExtensions() - ? extendedExpressionProto.getAdvancedExtensions() + extendedExpression.hasAdvancedExtensions() + ? extendedExpression.getAdvancedExtensions() : null)) .baseSchema(namedStruct); - /* - ProtocolStringList namesList = baseSchema.getNamesList(); - - Type.Struct struct = baseSchema.getStruct(); - Type types = struct.getTypes(0); - System.out.println("types.getDescriptorForType().getName(): " + types.getDescriptorForType().); - - - */ - - /* - System.out.println("namesList: " + namesList); - System.out.println("baseSchema.getStruct(): " + baseSchema.getStruct()); - System.out.println("}}{{{{{{{{{{''------>"); - System.out.println("baseSchema.getStruct(): " + baseSchema.getStruct().getTypes(0)); - - - */ - - /* - ImmutableNamedStruct.builder(). - - // para expression - - Optional equal = - defaultExtensionCollection.scalarFunctions().stream() - .filter( - s -> { - System.out.println(":>>>>"); - System.out.println(s); - System.out.println(s.uri()); - System.out.println(s.returnType()); - System.out.println(s.description()); - System.out.println("s.name(): " + s.name()); - System.out.println(s.key()); - return s.name().equalsIgnoreCase("add"); - }) - .findFirst() - .map( - declaration -> { - System.out.println("declaration: " + declaration); - System.out.println("declaration.name(): " + declaration.name()); - return ExpressionCreator.scalarFunction( - declaration, - TypeCreator.REQUIRED.BOOLEAN, - ImmutableFieldReference.builder() - .addSegments(FieldReference.StructField.of(0)) - .type(TypeCreator.REQUIRED.I32) - .build(), - ExpressionCreator.i32(false, 183) - ); - } - ); - - Map indexToExpressionMap = new HashMap<>(); - indexToExpressionMap.put(0, equal.get()); - - ImmutableExtendedExpression.Builder builder = - ImmutableExtendedExpression.builder() - .putAllReferredExpr(indexToExpressionMap) - .baseSchema(namedStruct); - - */ - return builder.build(); } - private io.substrait.type.NamedStruct newNamedStruct(NamedStruct namedStruct) { + private io.substrait.type.NamedStruct convertNamedStrutProtoToPojo(NamedStruct namedStruct) { var struct = namedStruct.getStruct(); return ImmutableNamedStruct.builder() .names(namedStruct.getNamesList()) diff --git a/core/src/main/java/io/substrait/extension/ExtensionCollector.java b/core/src/main/java/io/substrait/extension/ExtensionCollector.java index f2a4a6f18..402a8c94c 100644 --- a/core/src/main/java/io/substrait/extension/ExtensionCollector.java +++ b/core/src/main/java/io/substrait/extension/ExtensionCollector.java @@ -1,5 +1,6 @@ package io.substrait.extension; +import com.github.bsideup.jabel.Desugar; import io.substrait.proto.ExtendedExpression; import io.substrait.proto.Plan; import io.substrait.proto.SimpleExtensionDeclaration; @@ -52,6 +53,20 @@ public int getTypeReference(SimpleExtension.TypeAnchor typeAnchor) { } public void addExtensionsToPlan(Plan.Builder builder) { + SimpleExtensions simpleExtensions = getExtensions(); + + builder.addAllExtensionUris(simpleExtensions.uris().values()); + builder.addAllExtensions(simpleExtensions.extensionList()); + } + + public void addExtensionsToExtendedExpression(ExtendedExpression.Builder builder) { + SimpleExtensions simpleExtensions = getExtensions(); + + builder.addAllExtensionUris(simpleExtensions.uris().values()); + builder.addAllExtensions(simpleExtensions.extensionList()); + } + + private SimpleExtensions getExtensions() { var uriPos = new AtomicInteger(1); var uris = new HashMap(); @@ -94,58 +109,13 @@ public void addExtensionsToPlan(Plan.Builder builder) { .build(); extensionList.add(decl); } - - builder.addAllExtensionUris(uris.values()); - builder.addAllExtensions(extensionList); + return new SimpleExtensions(uris, extensionList); } - public void addExtensionsToPlan(ExtendedExpression.Builder builder) { - var uriPos = new AtomicInteger(1); - var uris = new HashMap(); - - var extensionList = new ArrayList(); - for (var e : funcMap.forwardMap.entrySet()) { - SimpleExtensionURI uri = - uris.computeIfAbsent( - e.getValue().namespace(), - k -> - SimpleExtensionURI.newBuilder() - .setExtensionUriAnchor(uriPos.getAndIncrement()) - .setUri(k) - .build()); - var decl = - SimpleExtensionDeclaration.newBuilder() - .setExtensionFunction( - SimpleExtensionDeclaration.ExtensionFunction.newBuilder() - .setFunctionAnchor(e.getKey()) - .setName(e.getValue().key()) - .setExtensionUriReference(uri.getExtensionUriAnchor())) - .build(); - extensionList.add(decl); - } - for (var e : typeMap.forwardMap.entrySet()) { - SimpleExtensionURI uri = - uris.computeIfAbsent( - e.getValue().namespace(), - k -> - SimpleExtensionURI.newBuilder() - .setExtensionUriAnchor(uriPos.getAndIncrement()) - .setUri(k) - .build()); - var decl = - SimpleExtensionDeclaration.newBuilder() - .setExtensionType( - SimpleExtensionDeclaration.ExtensionType.newBuilder() - .setTypeAnchor(e.getKey()) - .setName(e.getValue().key()) - .setExtensionUriReference(uri.getExtensionUriAnchor())) - .build(); - extensionList.add(decl); - } - - builder.addAllExtensionUris(uris.values()); - builder.addAllExtensions(extensionList); - } + @Desugar + private record SimpleExtensions( + HashMap uris, + ArrayList extensionList) {} /** We don't depend on guava... */ private static class BidiMap { diff --git a/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java b/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java index 3d600002b..c88bafc1c 100644 --- a/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java +++ b/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java @@ -1,10 +1,10 @@ package io.substrait.extension; -import io.substrait.proto.ExtendedExpression; -import io.substrait.proto.Plan; import io.substrait.proto.SimpleExtensionDeclaration; +import io.substrait.proto.SimpleExtensionURI; import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; /** @@ -31,14 +31,16 @@ public static class Builder { private final Map functionMap = new HashMap<>(); private final Map typeMap = new HashMap<>(); - public Builder from(Plan p) { + public Builder from( + List simpleExtensionURIs, + List simpleExtensionDeclarations) { Map namespaceMap = new HashMap<>(); - for (var extension : p.getExtensionUrisList()) { + for (var extension : simpleExtensionURIs) { namespaceMap.put(extension.getExtensionUriAnchor(), extension.getUri()); } // Add all functions used in plan to the functionMap - for (var extension : p.getExtensionsList()) { + for (var extension : simpleExtensionDeclarations) { if (!extension.hasExtensionFunction()) { continue; } @@ -55,50 +57,7 @@ public Builder from(Plan p) { } // Add all types used in plan to the typeMap - for (var extension : p.getExtensionsList()) { - if (!extension.hasExtensionType()) { - continue; - } - SimpleExtensionDeclaration.ExtensionType type = extension.getExtensionType(); - int reference = type.getTypeAnchor(); - String namespace = namespaceMap.get(type.getExtensionUriReference()); - if (namespace == null) { - throw new IllegalStateException( - "Could not find extension URI of " + type.getExtensionUriReference()); - } - String name = type.getName(); - SimpleExtension.TypeAnchor anchor = SimpleExtension.TypeAnchor.of(namespace, name); - typeMap.put(reference, anchor); - } - - return this; - } - - public Builder from(ExtendedExpression p) { - Map namespaceMap = new HashMap<>(); - for (var extension : p.getExtensionUrisList()) { - namespaceMap.put(extension.getExtensionUriAnchor(), extension.getUri()); - } - - // Add all functions used in plan to the functionMap - for (var extension : p.getExtensionsList()) { - if (!extension.hasExtensionFunction()) { - continue; - } - SimpleExtensionDeclaration.ExtensionFunction func = extension.getExtensionFunction(); - int reference = func.getFunctionAnchor(); - String namespace = namespaceMap.get(func.getExtensionUriReference()); - if (namespace == null) { - throw new IllegalStateException( - "Could not find extension URI of " + func.getExtensionUriReference()); - } - String name = func.getName(); - SimpleExtension.FunctionAnchor anchor = SimpleExtension.FunctionAnchor.of(namespace, name); - functionMap.put(reference, anchor); - } - - // Add all types used in plan to the typeMap - for (var extension : p.getExtensionsList()) { + for (var extension : simpleExtensionDeclarations) { if (!extension.hasExtensionType()) { continue; } diff --git a/core/src/main/java/io/substrait/plan/PlanProtoConverter.java b/core/src/main/java/io/substrait/plan/PlanProtoConverter.java index af0f6d69a..0bdf7d68c 100644 --- a/core/src/main/java/io/substrait/plan/PlanProtoConverter.java +++ b/core/src/main/java/io/substrait/plan/PlanProtoConverter.java @@ -34,10 +34,6 @@ public Plan toProto(io.substrait.plan.Plan plan) { if (plan.getAdvancedExtension().isPresent()) { builder.setAdvancedExtensions(plan.getAdvancedExtension().get()); } - /* - extendedExpressionBuilder.addAllExtensionUris(extensionUris.values()); - extendedExpressionBuilder.addAllExtensions(extensions); - */ return builder.build(); } } diff --git a/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java b/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java index be4f4ad9f..7222eb7ed 100644 --- a/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java +++ b/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java @@ -32,7 +32,10 @@ protected ProtoRelConverter getProtoRelConverter(ExtensionLookup functionLookup) } public Plan from(io.substrait.proto.Plan plan) { - ExtensionLookup functionLookup = ImmutableExtensionLookup.builder().from(plan).build(); + ExtensionLookup functionLookup = + ImmutableExtensionLookup.builder() + .from(plan.getExtensionUrisList(), plan.getExtensionsList()) + .build(); ProtoRelConverter relConverter = getProtoRelConverter(functionLookup); List roots = new ArrayList<>(); for (PlanRel planRel : plan.getRelationsList()) { diff --git a/core/src/test/java/io/substrait/extended/expression/ExtendedExpressionProtoConverterTest.java b/core/src/test/java/io/substrait/extended/expression/ExtendedExpressionProtoConverterTest.java index 4004e9de7..20079e24f 100644 --- a/core/src/test/java/io/substrait/extended/expression/ExtendedExpressionProtoConverterTest.java +++ b/core/src/test/java/io/substrait/extended/expression/ExtendedExpressionProtoConverterTest.java @@ -1,76 +1,73 @@ package io.substrait.extended.expression; -import com.google.protobuf.util.JsonFormat; +import static org.junit.jupiter.api.Assertions.assertEquals; + import io.substrait.TestBase; import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; import io.substrait.expression.FieldReference; import io.substrait.expression.ImmutableFieldReference; import io.substrait.type.ImmutableNamedStruct; -import io.substrait.type.NamedStruct; import io.substrait.type.Type; import io.substrait.type.TypeCreator; -import java.io.IOException; -import java.util.*; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; import org.junit.jupiter.api.Test; public class ExtendedExpressionProtoConverterTest extends TestBase { @Test - public void toProtoTest() throws IOException { - Optional equal = + public void toProtoTest() { + // create predefined POJO extended expression + Optional scalarFunctionExpression = defaultExtensionCollection.scalarFunctions().stream() - .filter( - s -> { - System.out.println(":>>>>"); - System.out.println(s); - System.out.println(s.uri()); - System.out.println(s.returnType()); - System.out.println(s.description()); - System.out.println("s.name(): " + s.name()); - System.out.println(s.key()); - return s.name().equalsIgnoreCase("add"); - }) + .filter(s -> s.name().equalsIgnoreCase("add")) .findFirst() .map( - declaration -> { - System.out.println("declaration: " + declaration); - System.out.println("declaration.name(): " + declaration.name()); - return ExpressionCreator.scalarFunction( - declaration, - TypeCreator.REQUIRED.BOOLEAN, - ImmutableFieldReference.builder() - .addSegments(FieldReference.StructField.of(0)) - .type(TypeCreator.REQUIRED.I32) - .build(), - ExpressionCreator.i32(false, 183)); - }); + declaration -> + ExpressionCreator.scalarFunction( + declaration, + TypeCreator.REQUIRED.BOOLEAN, + ImmutableFieldReference.builder() + .addSegments(FieldReference.StructField.of(0)) + .type(TypeCreator.REQUIRED.decimal(10, 2)) + .build(), + ExpressionCreator.i32(false, 183))); + + ImmutableExpressionReference expressionReference = + ImmutableExpressionReference.builder() + .referredExpr(scalarFunctionExpression.get()) + .addOutputNames("new-column") + .build(); - Map indexToExpressionMap = new HashMap<>(); - indexToExpressionMap.put(0, equal.get()); - List columnNames = Arrays.asList("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"); - List dataTypes = - Arrays.asList( - TypeCreator.NULLABLE.I32, - TypeCreator.NULLABLE.STRING, - TypeCreator.NULLABLE.I32, - TypeCreator.NULLABLE.STRING); - NamedStruct namedStruct = - NamedStruct.of( - columnNames, Type.Struct.builder().fields(dataTypes).nullable(false).build()); + List expressionReferences = new ArrayList<>(); + expressionReferences.add(expressionReference); - ImmutableNamedStruct.builder() - .addNames("id") - .struct(Type.Struct.builder().nullable(false).addFields(TypeCreator.REQUIRED.I32).build()) - .build(); + ImmutableNamedStruct namedStruct = + ImmutableNamedStruct.builder() + .addNames("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT") + .struct( + Type.Struct.builder() + .nullable(false) + .addFields( + TypeCreator.NULLABLE.decimal(10, 2), + TypeCreator.REQUIRED.STRING, + TypeCreator.REQUIRED.decimal(10, 2), + TypeCreator.REQUIRED.STRING) + .build()) + .build(); - ImmutableExtendedExpression.Builder builder = + ImmutableExtendedExpression.Builder extendedExpression = ImmutableExtendedExpression.builder() - .putAllReferredExpr(indexToExpressionMap) + .referredExpr(expressionReferences) .baseSchema(namedStruct); - System.out.println( - "JsonFormat.printer().print(getFilterExtendedExpression): " - + JsonFormat.printer() - .print(new ExtendedExpressionProtoConverter().toProto(builder.build()))); + // convert POJO extended expression into PROTOBUF extended expression + io.substrait.proto.ExtendedExpression proto = + new ExtendedExpressionProtoConverter().toProto(extendedExpression.build()); + + assertEquals( + "/functions_arithmetic_decimal.yaml", proto.getExtensionUrisList().get(0).getUri()); + assertEquals("add:dec_dec", proto.getExtensionsList().get(0).getExtensionFunction().getName()); } } diff --git a/core/src/test/java/io/substrait/extended/expression/ProtoExtendedExpressionConverterTest.java b/core/src/test/java/io/substrait/extended/expression/ProtoExtendedExpressionConverterTest.java index 1c1fa40b5..9ab84f274 100644 --- a/core/src/test/java/io/substrait/extended/expression/ProtoExtendedExpressionConverterTest.java +++ b/core/src/test/java/io/substrait/extended/expression/ProtoExtendedExpressionConverterTest.java @@ -1,6 +1,5 @@ package io.substrait.extended.expression; -import com.google.protobuf.util.JsonFormat; import io.substrait.TestBase; import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; @@ -8,87 +7,74 @@ import io.substrait.expression.ImmutableFieldReference; import io.substrait.proto.ExtendedExpression; import io.substrait.type.ImmutableNamedStruct; -import io.substrait.type.NamedStruct; import io.substrait.type.Type; import io.substrait.type.TypeCreator; import java.io.IOException; -import java.util.*; - +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; public class ProtoExtendedExpressionConverterTest extends TestBase { @Test public void fromTest() throws IOException { - Optional equal = + // create predefined POJO extended expression + Optional scalarFunctionExpression = defaultExtensionCollection.scalarFunctions().stream() - .filter( - s -> { - System.out.println(":>>>>"); - System.out.println(s); - System.out.println(s.uri()); - System.out.println(s.returnType()); - System.out.println(s.description()); - System.out.println("s.name(): " + s.name()); - System.out.println(s.key()); - return s.name().equalsIgnoreCase("add"); - }) + .filter(s -> s.name().equalsIgnoreCase("add")) .findFirst() .map( - declaration -> { - System.out.println("declaration: " + declaration); - System.out.println("declaration.name(): " + declaration.name()); - return ExpressionCreator.scalarFunction( - declaration, - TypeCreator.REQUIRED.BOOLEAN, - ImmutableFieldReference.builder() - .addSegments(FieldReference.StructField.of(0)) - .type(TypeCreator.REQUIRED.I32) - .build(), - ExpressionCreator.i32(false, 183)); - }); + declaration -> + ExpressionCreator.scalarFunction( + declaration, + TypeCreator.REQUIRED.BOOLEAN, + ImmutableFieldReference.builder() + .addSegments(FieldReference.StructField.of(0)) + .type(TypeCreator.REQUIRED.decimal(10, 2)) + .build(), + ExpressionCreator.i32(false, 183))); - Map indexToExpressionMap = new HashMap<>(); - indexToExpressionMap.put(0, equal.get()); - List columnNames = Arrays.asList("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"); - /* - List dataTypes = - Arrays.asList( - TypeCreator.NULLABLE.I32, - TypeCreator.NULLABLE.STRING, - TypeCreator.NULLABLE.I32, - TypeCreator.NULLABLE.STRING); - NamedStruct namedStruct = - NamedStruct.of( - columnNames, Type.Struct.builder().fields(dataTypes).nullable(false).build()); - */ - ImmutableNamedStruct id = ImmutableNamedStruct.builder() - .addNames("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT") - .struct(Type.Struct.builder().nullable(false).addFields(TypeCreator.NULLABLE.I32, - TypeCreator.REQUIRED.STRING, - TypeCreator.REQUIRED.I32, - TypeCreator.REQUIRED.STRING).build()) - .build(); + ImmutableExpressionReference expressionReference = + ImmutableExpressionReference.builder() + .referredExpr(scalarFunctionExpression.get()) + .addOutputNames("new-column") + .build(); - ImmutableExtendedExpression.Builder builder = - ImmutableExtendedExpression.builder() - .putAllReferredExpr(indexToExpressionMap) - .baseSchema(id); + List + expressionReferences = new ArrayList<>(); + expressionReferences.add(expressionReference); - ExtendedExpression proto = new ExtendedExpressionProtoConverter().toProto(builder.build()); + ImmutableNamedStruct namedStruct = + ImmutableNamedStruct.builder() + .addNames("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT") + .struct( + Type.Struct.builder() + .nullable(false) + .addFields( + TypeCreator.REQUIRED.decimal(10, 2), + TypeCreator.REQUIRED.STRING, + TypeCreator.REQUIRED.decimal(10, 2), + TypeCreator.REQUIRED.STRING) + .build()) + .build(); - System.out.println("=======POJO 01======="); - System.out.println("xxxx: " + builder); - System.out.println("=======PROTO 02======="); - System.out.println("yyyy: " + JsonFormat.printer().print(proto)); + // pojo initial extended expression + ImmutableExtendedExpression extendedExpressionPojoInitial = + ImmutableExtendedExpression.builder() + .referredExpr(expressionReferences) + .baseSchema(namedStruct) + .build(); - System.out.println("=======POJO 03======="); - io.substrait.extended.expression.ExtendedExpression from = - new ProtoExtendedExpressionConverter().from(proto); - System.out.println("zzzz: " + from); - System.out.println("11111111"); + // proto extended expression + ExtendedExpression extendedExpressionProto = + new ExtendedExpressionProtoConverter().toProto(extendedExpressionPojoInitial); + // pojo final extended expression + io.substrait.extended.expression.ExtendedExpression extendedExpressionPojoFinal = + new ProtoExtendedExpressionConverter().from(extendedExpressionProto); - Assertions.assertEquals(from, builder.build()); + // validate extended expression pojo initial equals to final roundtrip + Assertions.assertEquals(extendedExpressionPojoInitial, extendedExpressionPojoFinal); } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java index 716fdb66e..67ae9a8ea 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java @@ -89,7 +89,24 @@ protected SqlConverterBase(FeatureBoard features) { EXTENSION_COLLECTION = defaults; } - Result registerCreateTables(List tables) throws SqlParseException { + Pair registerCreateTables(List tables) + throws SqlParseException { + CalciteSchema rootSchema = CalciteSchema.createRootSchema(false); + CalciteCatalogReader catalogReader = + new CalciteCatalogReader(rootSchema, List.of(), factory, config); + SqlValidator validator = Validator.create(factory, catalogReader, SqlValidator.Config.DEFAULT); + if (tables != null) { + for (String tableDef : tables) { + List tList = parseCreateTable(factory, validator, tableDef); + for (DefinedTable t : tList) { + rootSchema.add(t.getName(), t); + } + } + } + return Pair.of(validator, catalogReader); + } + + Result registerCreateTablesForExtendedExpression(List tables) throws SqlParseException { Map nameToTypeMap = new LinkedHashMap<>(); Map nameToNodeMap = new HashMap<>(); CalciteSchema rootSchema = CalciteSchema.createRootSchema(false); diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java index 2b35dffef..3052898e1 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java @@ -1,25 +1,17 @@ package io.substrait.isthmus; -import com.github.bsideup.jabel.Desugar; import com.google.common.annotations.VisibleForTesting; +import io.substrait.extended.expression.ExtendedExpressionProtoConverter; +import io.substrait.extended.expression.ImmutableExpressionReference; +import io.substrait.extended.expression.ImmutableExtendedExpression; import io.substrait.extension.ExtensionCollector; -import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.expression.RexExpressionConverter; import io.substrait.isthmus.expression.ScalarFunctionConverter; -import io.substrait.proto.Expression; -import io.substrait.proto.Expression.ScalarFunction; -import io.substrait.proto.ExpressionReference; import io.substrait.proto.ExtendedExpression; -import io.substrait.proto.FunctionArgument; import io.substrait.proto.Plan; import io.substrait.proto.PlanRel; -import io.substrait.proto.SimpleExtensionDeclaration; -import io.substrait.proto.SimpleExtensionURI; import io.substrait.relation.RelProtoConverter; import io.substrait.type.NamedStruct; -import io.substrait.type.TypeCreator; -import io.substrait.type.proto.TypeProtoConverter; -import java.io.IOException; import java.util.*; import java.util.function.Function; import org.apache.calcite.plan.hep.HepPlanner; @@ -28,9 +20,6 @@ import org.apache.calcite.rel.RelRoot; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; -import org.apache.calcite.rex.RexCall; -import org.apache.calcite.rex.RexInputRef; -import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.schema.Schema; import org.apache.calcite.sql.SqlNode; @@ -64,8 +53,8 @@ public Plan execute(String sql, Function, NamedStruct> tableLookup) } public Plan execute(String sql, List tables) throws SqlParseException { - var result = registerCreateTables(tables); - return executeInner(sql, factory, result.validator(), result.catalogReader()); + var pair = registerCreateTables(tables); + return executeInner(sql, factory, pair.left, pair.right); } public Plan execute(String sql, String name, Schema schema) throws SqlParseException { @@ -84,7 +73,7 @@ public Plan execute(String sql, String name, Schema schema) throws SqlParseExcep */ public ExtendedExpression executeSQLExpression(String sqlExpression, List tables) throws SqlParseException { - var result = registerCreateTables(tables); + var result = registerCreateTablesForExtendedExpression(tables); return executeInnerSQLExpression( sqlExpression, result.validator(), @@ -95,8 +84,8 @@ public ExtendedExpression executeSQLExpression(String sqlExpression, List sqlToRelNode(String sql, List tables) throws SqlParseException { - var result = registerCreateTables(tables); - return sqlToRelNode(sql, result.validator(), result.catalogReader()); + var pair = registerCreateTables(tables); + return sqlToRelNode(sql, pair.left, pair.right); } // Package protected for testing @@ -143,249 +132,27 @@ private ExtendedExpression executeInnerSQLExpression( Map nameToTypeMap, Map nameToNodeMap) throws SqlParseException { - ExtendedExpression.Builder extendedExpressionBuilder = ExtendedExpression.newBuilder(); - ExtensionCollector functionCollector = new ExtensionCollector(); - RexNode rexNode = - sqlToRexNode(sqlExpression, validator, catalogReader, nameToTypeMap, nameToNodeMap); - ResulTraverseRowExpression result = new TraverseRexNode().getRowExpression(rexNode); - io.substrait.proto.Type output = - TypeCreator.NULLABLE.BOOLEAN.accept(new TypeProtoConverter(functionCollector)); - List functionArgumentList = new ArrayList<>(); - result - .expressionBuilderMap() - .forEach( - (k, v) -> { - System.out.println("k->" + k); - System.out.println("v->" + v); - functionArgumentList.add(FunctionArgument.newBuilder().setValue(v).build()); - }); - - ScalarFunction.Builder scalarFunctionBuilder = - ScalarFunction.newBuilder() - .setFunctionReference(1) // rel_01 - .setOutputType(output) - .addAllArguments(functionArgumentList); - - Expression.Builder expressionBuilder = - Expression.newBuilder().setScalarFunction(scalarFunctionBuilder); - - ExpressionReference.Builder expressionReferenceBuilder = - ExpressionReference.newBuilder() - .setExpression(expressionBuilder) - .addOutputNames(result.ref().getName()); - - extendedExpressionBuilder.addReferredExpr(0, expressionReferenceBuilder); - - io.substrait.expression.Expression.ScalarFunctionInvocation func = - (io.substrait.expression.Expression.ScalarFunctionInvocation) - rexNode.accept(rexExpressionConverter); - String declaration = func.declaration().key(); // values example: gt:any_any, add:i64_i64 - - // this is not mandatory to be defined; it is working without this definition. It is - // only created here to create a proto message that has the correct semantics - HashMap extensionUris = new HashMap<>(); - SimpleExtensionURI simpleExtensionURI; - try { - simpleExtensionURI = - SimpleExtensionURI.newBuilder() - .setExtensionUriAnchor(1) // rel_02 - .setUri( - SimpleExtension.loadDefaults().scalarFunctions().stream() - .filter(s -> s.toString().equalsIgnoreCase(declaration)) - .findFirst() - .orElseThrow( - () -> - new IllegalArgumentException( - String.format("Failed to get URI resource for %s.", declaration))) - .uri()) - .build(); - } catch (IOException e) { - throw new RuntimeException(e); - } - extensionUris.put("uri", simpleExtensionURI); - - ArrayList extensions = new ArrayList<>(); - SimpleExtensionDeclaration extensionFunctionLowerThan = - SimpleExtensionDeclaration.newBuilder() - .setExtensionFunction( - SimpleExtensionDeclaration.ExtensionFunction.newBuilder() - .setFunctionAnchor(scalarFunctionBuilder.getFunctionReference()) // rel_01 - .setName(declaration) - .setExtensionUriReference(simpleExtensionURI.getExtensionUriAnchor())) // rel_02 - .build(); - extensions.add(extensionFunctionLowerThan); - - System.out.println( - "extendedExpressionBuilder.getExtensionUrisList(): " - + extendedExpressionBuilder.getExtensionUrisList()); - // adding it for semantic purposes, it is not mandatory or needed - extendedExpressionBuilder.addAllExtensionUris(extensionUris.values()); - extendedExpressionBuilder.addAllExtensions(extensions); - - NamedStruct namedStruct = TypeConverter.DEFAULT.toNamedStruct(nameToTypeMap); - extendedExpressionBuilder.setBaseSchema( - namedStruct.toProto(new TypeProtoConverter(functionCollector))); - - /* - builder.addAllExtensionUris(uris.values()); - builder.addAllExtensions(extensionList); - */ - - return extendedExpressionBuilder.build(); - } - - private ExtendedExpression executeInnerSQLExpressionPojo( - String sqlExpression, - SqlValidator validator, - CalciteCatalogReader catalogReader, - Map nameToTypeMap, - Map nameToNodeMap) - throws SqlParseException { - ExtendedExpression.Builder extendedExpressionBuilder = ExtendedExpression.newBuilder(); - ExtensionCollector functionCollector = new ExtensionCollector(); RexNode rexNode = sqlToRexNode(sqlExpression, validator, catalogReader, nameToTypeMap, nameToNodeMap); - ResulTraverseRowExpression result = new TraverseRexNode().getRowExpression(rexNode); - io.substrait.proto.Type output = - TypeCreator.NULLABLE.BOOLEAN.accept(new TypeProtoConverter(functionCollector)); - List functionArgumentList = new ArrayList<>(); - result - .expressionBuilderMap() - .forEach( - (k, v) -> { - System.out.println("k->" + k); - System.out.println("v->" + v); - functionArgumentList.add(FunctionArgument.newBuilder().setValue(v).build()); - }); - - ScalarFunction.Builder scalarFunctionBuilder = - ScalarFunction.newBuilder() - .setFunctionReference(1) // rel_01 - .setOutputType(output) - .addAllArguments(functionArgumentList); - - Expression.Builder expressionBuilder = - Expression.newBuilder().setScalarFunction(scalarFunctionBuilder); - - ExpressionReference.Builder expressionReferenceBuilder = - ExpressionReference.newBuilder() - .setExpression(expressionBuilder) - .addOutputNames(result.ref().getName()); - - extendedExpressionBuilder.addReferredExpr(0, expressionReferenceBuilder); - io.substrait.expression.Expression.ScalarFunctionInvocation func = (io.substrait.expression.Expression.ScalarFunctionInvocation) rexNode.accept(rexExpressionConverter); - String declaration = func.declaration().key(); // values example: gt:any_any, add:i64_i64 - - // this is not mandatory to be defined; it is working without this definition. It is - // only created here to create a proto message that has the correct semantics - HashMap extensionUris = new HashMap<>(); - SimpleExtensionURI simpleExtensionURI; - try { - simpleExtensionURI = - SimpleExtensionURI.newBuilder() - .setExtensionUriAnchor(1) // rel_02 - .setUri( - SimpleExtension.loadDefaults().scalarFunctions().stream() - .filter(s -> s.toString().equalsIgnoreCase(declaration)) - .findFirst() - .orElseThrow( - () -> - new IllegalArgumentException( - String.format("Failed to get URI resource for %s.", declaration))) - .uri()) - .build(); - } catch (IOException e) { - throw new RuntimeException(e); - } - extensionUris.put("uri", simpleExtensionURI); - - ArrayList extensions = new ArrayList<>(); - SimpleExtensionDeclaration extensionFunctionLowerThan = - SimpleExtensionDeclaration.newBuilder() - .setExtensionFunction( - SimpleExtensionDeclaration.ExtensionFunction.newBuilder() - .setFunctionAnchor(scalarFunctionBuilder.getFunctionReference()) // rel_01 - .setName(declaration) - .setExtensionUriReference(simpleExtensionURI.getExtensionUriAnchor())) // rel_02 - .build(); - extensions.add(extensionFunctionLowerThan); - - System.out.println( - "extendedExpressionBuilder.getExtensionUrisList(): " - + extendedExpressionBuilder.getExtensionUrisList()); - // adding it for semantic purposes, it is not mandatory or needed - extendedExpressionBuilder.addAllExtensionUris(extensionUris.values()); - extendedExpressionBuilder.addAllExtensions(extensions); - NamedStruct namedStruct = TypeConverter.DEFAULT.toNamedStruct(nameToTypeMap); - extendedExpressionBuilder.setBaseSchema( - namedStruct.toProto(new TypeProtoConverter(functionCollector))); - - /* - builder.addAllExtensionUris(uris.values()); - builder.addAllExtensions(extensionList); - */ + ImmutableExpressionReference expressionReference = + ImmutableExpressionReference.builder().referredExpr(func).addOutputNames("output").build(); - return extendedExpressionBuilder.build(); - } + List + expressionReferences = new ArrayList<>(); + expressionReferences.add(expressionReference); - class TraverseRexNode { - RexInputRef ref = null; - int control = 0; - Expression.Builder referenceBuilder = null; - Expression.Builder literalBuilder = null; - Map expressionBuilderMap = new LinkedHashMap<>(); + ImmutableExtendedExpression.Builder extendedExpression = + ImmutableExtendedExpression.builder() + .referredExpr(expressionReferences) + .baseSchema(namedStruct); - ResulTraverseRowExpression getRowExpression(RexNode rexNode) { - switch (rexNode.getClass().getSimpleName().toUpperCase()) { - case "REXCALL": - for (RexNode rexInternal : ((RexCall) rexNode).operands) { - getRowExpression(rexInternal); - } - ; - break; - case "REXINPUTREF": - ref = (RexInputRef) rexNode; - referenceBuilder = - Expression.newBuilder() - .setSelection( - Expression.FieldReference.newBuilder() - .setDirectReference( - Expression.ReferenceSegment.newBuilder() - .setStructField( - Expression.ReferenceSegment.StructField.newBuilder() - .setField(ref.getIndex())))); - expressionBuilderMap.put(control, referenceBuilder); - control++; - break; - case "REXLITERAL": - RexLiteral literal = (RexLiteral) rexNode; - literalBuilder = - Expression.newBuilder() - .setLiteral( - Expression.Literal.newBuilder().setI32(literal.getValueAs(Integer.class))); - expressionBuilderMap.put(control, literalBuilder); - control++; - break; - default: - throw new AssertionError( - "Unsupported type for: " + rexNode.getClass().getSimpleName().toUpperCase()); - } - return new ResulTraverseRowExpression( - ref, referenceBuilder, literalBuilder, expressionBuilderMap); - } + return new ExtendedExpressionProtoConverter().toProto(extendedExpression.build()); } - @Desugar - private record ResulTraverseRowExpression( - RexInputRef ref, - Expression.Builder referenceBuilder, - Expression.Builder literalBuilder, - Map expressionBuilderMap) {} - private List sqlToRelNode( String sql, SqlValidator validator, CalciteCatalogReader catalogReader) throws SqlParseException { diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java index d43fda1c1..f402aacd3 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java @@ -22,9 +22,8 @@ public SubstraitToSql() { public RelNode substraitRelToCalciteRel(Rel relRoot, List tables) throws SqlParseException { - var result = registerCreateTables(tables); - return SubstraitRelNodeConverter.convert( - relRoot, relOptCluster, result.catalogReader(), parserConfig); + var pair = registerCreateTables(tables); + return SubstraitRelNodeConverter.convert(relRoot, relOptCluster, pair.right, parserConfig); } public RelNode substraitRelToCalciteRel( diff --git a/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java index 1c3742b52..d29ad3c07 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java @@ -2,8 +2,6 @@ import com.google.common.base.Charsets; import com.google.common.io.Resources; -import com.google.protobuf.InvalidProtocolBufferException; -import com.google.protobuf.util.JsonFormat; import io.substrait.extended.expression.ExtendedExpressionProtoConverter; import io.substrait.extended.expression.ProtoExtendedExpressionConverter; import io.substrait.proto.ExtendedExpression; @@ -37,24 +35,20 @@ protected ExtendedExpression assertProtoExtendedExpressionRoundrip(String query, protected ExtendedExpression assertProtoExtendedExpressionRoundrip( String query, SqlToSubstrait s, List creates) throws SqlParseException, IOException { - io.substrait.proto.ExtendedExpression protoExtendedExpression = - s.executeSQLExpression(query, creates); + // proto initial extended expression + ExtendedExpression extendedExpressionProtoInitial = s.executeSQLExpression(query, creates); - try { - String ee = JsonFormat.printer().print(protoExtendedExpression); - System.out.println("Proto Extended Expression: \n" + ee); + // pojo final extended expression + io.substrait.extended.expression.ExtendedExpression extendedExpressionPojoFinal = + new ProtoExtendedExpressionConverter().from(extendedExpressionProtoInitial); - io.substrait.extended.expression.ExtendedExpression from = - new ProtoExtendedExpressionConverter().from(protoExtendedExpression); + // proto final extended expression + ExtendedExpression extendedExpressionProtoFinal = + new ExtendedExpressionProtoConverter().toProto(extendedExpressionPojoFinal); - ExtendedExpression proto = new ExtendedExpressionProtoConverter().toProto(from); + // round-trip to validate extended expression proto initial equals to final + Assertions.assertEquals(extendedExpressionProtoFinal, extendedExpressionProtoInitial); - Assertions.assertEquals(proto, protoExtendedExpression); - // FIXME! Implement test validation as the same as proto Plan implementation - } catch (InvalidProtocolBufferException e) { - throw new RuntimeException(e); - } - - return protoExtendedExpression; + return extendedExpressionProtoInitial; } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java index 5a407a936..d08f588d6 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java @@ -10,7 +10,7 @@ public class SimpleExtendedExpressionsTest extends ExtendedExpressionTestBase { @Test public void filter() throws IOException, SqlParseException { ExtendedExpression extendedExpression = - assertProtoExtendedExpressionRoundrip("N_NATIONKEY > 10"); + assertProtoExtendedExpressionRoundrip("L_ORDERKEY > 10"); } @Test diff --git a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java index 0121fae9c..4badb4a5e 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java @@ -2,7 +2,6 @@ import static org.junit.jupiter.api.Assertions.assertEquals; -import com.google.protobuf.util.JsonFormat; import com.ibm.icu.impl.ClassLoaderUtil; import io.substrait.isthmus.ExtendedExpressionTestBase; import io.substrait.isthmus.SqlToSubstrait; @@ -21,7 +20,7 @@ import org.apache.arrow.dataset.source.DatasetFactory; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; -import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.BigIntVector; import org.apache.arrow.vector.ipc.ArrowReader; import org.apache.calcite.sql.parser.SqlParseException; import org.junit.jupiter.api.Test; @@ -31,7 +30,9 @@ public class ExtendedExpressionIntegrationTest { @Test public void filterDataset() throws SqlParseException, IOException { URL resource = ClassLoaderUtil.getClassLoader().getResource("./tpch/data/nation.parquet"); - String sqlExpression = "N_REGIONKEY > 20"; + // Make sure you pass appropriate data, for example, if you pass N_NATIONKEY > 20 the engine + // creates an i64 but casts it to i32 = 20, causing casting problems. + String sqlExpression = "N_NATIONKEY > 9223372036854771827 - 9223372036854771807"; ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768) .columns(Optional.empty()) @@ -50,11 +51,9 @@ public void filterDataset() throws SqlParseException, IOException { int count = 0; while (reader.loadNextBatch()) { count += reader.getVectorSchemaRoot().getRowCount(); - System.out.println(reader.getVectorSchemaRoot().contentToTSVString()); } assertEquals(4, count); } catch (Exception e) { - e.printStackTrace(); throw new RuntimeException(e); } } @@ -62,7 +61,9 @@ public void filterDataset() throws SqlParseException, IOException { @Test public void projectDataset() throws SqlParseException, IOException { URL resource = ClassLoaderUtil.getClassLoader().getResource("./tpch/data/nation.parquet"); - String sqlExpression = "20 + N_NATIONKEY"; + // Make sure you pass appropriate data, for example, if you pass N_NATIONKEY + 20 the engine + // creates an i64 but casts it to i32 = 20, causing casting problems. + String sqlExpression = "N_NATIONKEY + 9888486986"; ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768) .columns(Optional.empty()) @@ -79,62 +80,27 @@ public void projectDataset() throws SqlParseException, IOException { Scanner scanner = dataset.newScan(options); ArrowReader reader = scanner.scanBatches()) { int count = 0; - int sum = 0; + Long sum = 0L; while (reader.loadNextBatch()) { count += reader.getVectorSchemaRoot().getRowCount(); - IntVector intVector = (IntVector) reader.getVectorSchemaRoot().getVector(0); - for (int i = 0; i < intVector.getValueCount(); i++) { - sum += intVector.get(i); + BigIntVector bigIntVector = (BigIntVector) reader.getVectorSchemaRoot().getVector(0); + for (int i = 0; i < bigIntVector.getValueCount(); i++) { + sum += bigIntVector.get(i); } - System.out.println(reader.getVectorSchemaRoot().contentToTSVString()); } assertEquals(25, count); - assertEquals(24 * 25 / 2 + 20 * count, sum); + assertEquals(24 * 25 / 2 + 9888486986L * count, sum); } catch (Exception e) { throw new RuntimeException(e); } } - @Test - public void filterDatasetUsingExtendedExpression() throws SqlParseException, IOException { - URL resource = ClassLoaderUtil.getClassLoader().getResource("./tpch/data/nation.parquet"); - String sqlExpression = "N_NATIONKEY > 20"; - ScanOptions options = - new ScanOptions.Builder(/*batchSize*/ 32768) - .columns(Optional.empty()) - .substraitFilter(getFilterExtendedExpression(sqlExpression)) - .build(); - try (BufferAllocator allocator = new RootAllocator(); - DatasetFactory datasetFactory = - new FileSystemDatasetFactory( - allocator, - NativeMemoryPool.getDefault(), - FileFormat.PARQUET, - resource.toURI().toString()); - Dataset dataset = datasetFactory.finish(); - Scanner scanner = dataset.newScan(options); - ArrowReader reader = scanner.scanBatches()) { - int count = 0; - while (reader.loadNextBatch()) { - count += reader.getVectorSchemaRoot().getRowCount(); - System.out.println(reader.getVectorSchemaRoot().contentToTSVString()); - } - assertEquals(4, count); - } catch (Exception e) { - e.printStackTrace(); - throw new RuntimeException(e); - } - } - private static ByteBuffer getFilterExtendedExpression(String sqlExpression) throws IOException, SqlParseException { ExtendedExpression extendedExpression = new SqlToSubstrait() .executeSQLExpression( sqlExpression, ExtendedExpressionTestBase.tpchSchemaCreateStatements()); - System.out.println( - "JsonFormat.printer().print(getFilterExtendedExpression): " - + JsonFormat.printer().print(extendedExpression)); byte[] extendedExpressions = Base64.getDecoder() .decode(Base64.getEncoder().encodeToString(extendedExpression.toByteArray())); @@ -149,9 +115,6 @@ private static ByteBuffer getProjectExtendedExpression(String sqlExpression) new SqlToSubstrait() .executeSQLExpression( sqlExpression, ExtendedExpressionTestBase.tpchSchemaCreateStatements()); - System.out.println( - "JsonFormat.printer().print(getProjectExtendedExpression): " - + JsonFormat.printer().print(extendedExpression)); byte[] extendedExpressions = Base64.getDecoder() .decode(Base64.getEncoder().encodeToString(extendedExpression.toByteArray())); diff --git a/isthmus/src/test/resources/tpch/schema.sql b/isthmus/src/test/resources/tpch/schema.sql index b8fb4cfd0..81f6f927b 100644 --- a/isthmus/src/test/resources/tpch/schema.sql +++ b/isthmus/src/test/resources/tpch/schema.sql @@ -1,6 +1,77 @@ +CREATE TABLE PART ( + P_PARTKEY BIGINT NOT NULL, + P_NAME VARCHAR(55), + P_MFGR CHAR(25), + P_BRAND CHAR(10), + P_TYPE VARCHAR(25), + P_SIZE INTEGER, + P_CONTAINER CHAR(10), + P_RETAILPRICE DECIMAL, + P_COMMENT VARCHAR(23) +); +CREATE TABLE SUPPLIER ( + S_SUPPKEY BIGINT NOT NULL, + S_NAME CHAR(25), + S_ADDRESS VARCHAR(40), + S_NATIONKEY BIGINT NOT NULL, + S_PHONE CHAR(15), + S_ACCTBAL DECIMAL, + S_COMMENT VARCHAR(101) +); +CREATE TABLE PARTSUPP ( + PS_PARTKEY BIGINT NOT NULL, + PS_SUPPKEY BIGINT NOT NULL, + PS_AVAILQTY INTEGER, + PS_SUPPLYCOST DECIMAL, + PS_COMMENT VARCHAR(199) +); +CREATE TABLE CUSTOMER ( + C_CUSTKEY BIGINT NOT NULL, + C_NAME VARCHAR(25), + C_ADDRESS VARCHAR(40), + C_NATIONKEY BIGINT NOT NULL, + C_PHONE CHAR(15), + C_ACCTBAL DECIMAL, + C_MKTSEGMENT CHAR(10), + C_COMMENT VARCHAR(117) +); +CREATE TABLE ORDERS ( + O_ORDERKEY BIGINT NOT NULL, + O_CUSTKEY BIGINT NOT NULL, + O_ORDERSTATUS CHAR(1), + O_TOTALPRICE DECIMAL, + O_ORDERDATE DATE, + O_ORDERPRIORITY CHAR(15), + O_CLERK CHAR(15), + O_SHIPPRIORITY INTEGER, + O_COMMENT VARCHAR(79) +); +CREATE TABLE LINEITEM ( + L_ORDERKEY BIGINT NOT NULL, + L_PARTKEY BIGINT NOT NULL, + L_SUPPKEY BIGINT NOT NULL, + L_LINENUMBER INTEGER, + L_QUANTITY DECIMAL, + L_EXTENDEDPRICE DECIMAL, + L_DISCOUNT DECIMAL, + L_TAX DECIMAL, + L_RETURNFLAG CHAR(1), + L_LINESTATUS CHAR(1), + L_SHIPDATE DATE, + L_COMMITDATE DATE, + L_RECEIPTDATE DATE, + L_SHIPINSTRUCT CHAR(25), + L_SHIPMODE CHAR(10), + L_COMMENT VARCHAR(44) +); CREATE TABLE NATION ( N_NATIONKEY BIGINT NOT NULL, N_NAME CHAR(25), N_REGIONKEY BIGINT NOT NULL, N_COMMENT VARCHAR(152) ); +CREATE TABLE REGION ( + R_REGIONKEY BIGINT NOT NULL, + R_NAME CHAR(25), + R_COMMENT VARCHAR(152) +); From 940f70399fa4c4c22c581f855004a1de44463c85 Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Fri, 24 Nov 2023 19:16:58 -0500 Subject: [PATCH 13/34] fix: clean code redundant method --- .../ExtendedExpressionIntegrationTest.java | 26 +++++-------------- 1 file changed, 6 insertions(+), 20 deletions(-) diff --git a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java index 4badb4a5e..297517ec8 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java @@ -36,7 +36,7 @@ public void filterDataset() throws SqlParseException, IOException { ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768) .columns(Optional.empty()) - .substraitFilter(getFilterExtendedExpression(sqlExpression)) + .substraitFilter(getExtendedExpression(sqlExpression)) .build(); try (BufferAllocator allocator = new RootAllocator(); DatasetFactory datasetFactory = @@ -67,7 +67,7 @@ public void projectDataset() throws SqlParseException, IOException { ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768) .columns(Optional.empty()) - .substraitProjection(getProjectExtendedExpression(sqlExpression)) + .substraitProjection(getExtendedExpression(sqlExpression)) .build(); try (BufferAllocator allocator = new RootAllocator(); DatasetFactory datasetFactory = @@ -95,7 +95,7 @@ public void projectDataset() throws SqlParseException, IOException { } } - private static ByteBuffer getFilterExtendedExpression(String sqlExpression) + private static ByteBuffer getExtendedExpression(String sqlExpression) throws IOException, SqlParseException { ExtendedExpression extendedExpression = new SqlToSubstrait() @@ -104,22 +104,8 @@ private static ByteBuffer getFilterExtendedExpression(String sqlExpression) byte[] extendedExpressions = Base64.getDecoder() .decode(Base64.getEncoder().encodeToString(extendedExpression.toByteArray())); - ByteBuffer substraitExpressionFilter = ByteBuffer.allocateDirect(extendedExpressions.length); - substraitExpressionFilter.put(extendedExpressions); - return substraitExpressionFilter; - } - - private static ByteBuffer getProjectExtendedExpression(String sqlExpression) - throws IOException, SqlParseException { - ExtendedExpression extendedExpression = - new SqlToSubstrait() - .executeSQLExpression( - sqlExpression, ExtendedExpressionTestBase.tpchSchemaCreateStatements()); - byte[] extendedExpressions = - Base64.getDecoder() - .decode(Base64.getEncoder().encodeToString(extendedExpression.toByteArray())); - ByteBuffer substraitExpressionProject = ByteBuffer.allocateDirect(extendedExpressions.length); - substraitExpressionProject.put(extendedExpressions); - return substraitExpressionProject; + ByteBuffer substraitExpression = ByteBuffer.allocateDirect(extendedExpressions.length); + substraitExpression.put(extendedExpressions); + return substraitExpression; } } From f817eb0cc0012bb4357cb619fd496c64f2e88ed7 Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Wed, 29 Nov 2023 10:29:04 -0500 Subject: [PATCH 14/34] fix: apply suggestions from code review Co-authored-by: Dane Pitkin <48041712+danepitkin@users.noreply.github.com> --- .../io/substrait/extended/expression/ExtendedExpression.java | 4 ++-- .../extended/expression/ExtendedExpressionProtoConverter.java | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/java/io/substrait/extended/expression/ExtendedExpression.java b/core/src/main/java/io/substrait/extended/expression/ExtendedExpression.java index 4f705e82c..2aee599c2 100644 --- a/core/src/main/java/io/substrait/extended/expression/ExtendedExpression.java +++ b/core/src/main/java/io/substrait/extended/expression/ExtendedExpression.java @@ -9,7 +9,7 @@ @Value.Immutable public abstract class ExtendedExpression { - public abstract List getReferredExpr(); + public abstract List getReferredExpressions(); public abstract NamedStruct getBaseSchema(); @@ -21,7 +21,7 @@ public abstract class ExtendedExpression { @Value.Immutable public abstract static class ExpressionReference { - public abstract Expression getReferredExpr(); + public abstract Expression getExpression(); public abstract List getOutputNames(); } diff --git a/core/src/main/java/io/substrait/extended/expression/ExtendedExpressionProtoConverter.java b/core/src/main/java/io/substrait/extended/expression/ExtendedExpressionProtoConverter.java index 6ace03df4..f3d8441ae 100644 --- a/core/src/main/java/io/substrait/extended/expression/ExtendedExpressionProtoConverter.java +++ b/core/src/main/java/io/substrait/extended/expression/ExtendedExpressionProtoConverter.java @@ -15,7 +15,7 @@ public class ExtendedExpressionProtoConverter { public ExtendedExpression toProto( io.substrait.extended.expression.ExtendedExpression extendedExpression) { - ExtendedExpression.Builder extendedExpressionBuilder = ExtendedExpression.newBuilder(); + ExtendedExpression.Builder builder = ExtendedExpression.newBuilder(); ExtensionCollector functionCollector = new ExtensionCollector(); final ExpressionProtoConverter expressionProtoConverter = From b1c96bd458e8096c7d8663ab2a0710199f080d7a Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Wed, 29 Nov 2023 11:29:03 -0500 Subject: [PATCH 15/34] fix: code review core module --- .../ExtendedExpression.java | 2 +- .../ExtendedExpressionProtoConverter.java | 26 +++++++-------- .../ProtoExtendedExpressionConverter.java | 32 +++++-------------- .../extension/ImmutableExtensionLookup.java | 13 +++++++- .../io/substrait/plan/ProtoPlanConverter.java | 5 +-- .../java/io/substrait/type/NamedStruct.java | 17 ++++++++++ .../ExtendedExpressionProtoConverterTest.java | 6 ++-- .../ProtoExtendedExpressionConverterTest.java | 10 +++--- 8 files changed, 58 insertions(+), 53 deletions(-) rename core/src/main/java/io/substrait/{extended/expression => extendedexpression}/ExtendedExpression.java (94%) rename core/src/main/java/io/substrait/{extended/expression => extendedexpression}/ExtendedExpressionProtoConverter.java (66%) rename core/src/main/java/io/substrait/{extended/expression => extendedexpression}/ProtoExtendedExpressionConverter.java (69%) rename core/src/test/java/io/substrait/{extended/expression => extendedexpression}/ExtendedExpressionProtoConverterTest.java (94%) rename core/src/test/java/io/substrait/{extended/expression => extendedexpression}/ProtoExtendedExpressionConverterTest.java (90%) diff --git a/core/src/main/java/io/substrait/extended/expression/ExtendedExpression.java b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpression.java similarity index 94% rename from core/src/main/java/io/substrait/extended/expression/ExtendedExpression.java rename to core/src/main/java/io/substrait/extendedexpression/ExtendedExpression.java index 2aee599c2..de405f9a3 100644 --- a/core/src/main/java/io/substrait/extended/expression/ExtendedExpression.java +++ b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpression.java @@ -1,4 +1,4 @@ -package io.substrait.extended.expression; +package io.substrait.extendedexpression; import io.substrait.expression.Expression; import io.substrait.proto.AdvancedExtension; diff --git a/core/src/main/java/io/substrait/extended/expression/ExtendedExpressionProtoConverter.java b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java similarity index 66% rename from core/src/main/java/io/substrait/extended/expression/ExtendedExpressionProtoConverter.java rename to core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java index f3d8441ae..a123e4b9f 100644 --- a/core/src/main/java/io/substrait/extended/expression/ExtendedExpressionProtoConverter.java +++ b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java @@ -1,4 +1,4 @@ -package io.substrait.extended.expression; +package io.substrait.extendedexpression; import io.substrait.expression.Expression; import io.substrait.expression.proto.ExpressionProtoConverter; @@ -7,13 +7,10 @@ import io.substrait.proto.ExtendedExpression; import io.substrait.type.proto.TypeProtoConverter; -/** - * Converts from {@link io.substrait.extended.expression.ExtendedExpression} to {@link - * ExtendedExpression} - */ +/** Converts from {@link ExtendedExpression} to {@link ExtendedExpression} */ public class ExtendedExpressionProtoConverter { public ExtendedExpression toProto( - io.substrait.extended.expression.ExtendedExpression extendedExpression) { + io.substrait.extendedexpression.ExtendedExpression extendedExpression) { ExtendedExpression.Builder builder = ExtendedExpression.newBuilder(); ExtensionCollector functionCollector = new ExtensionCollector(); @@ -21,30 +18,29 @@ public ExtendedExpression toProto( final ExpressionProtoConverter expressionProtoConverter = new ExpressionProtoConverter(functionCollector, null); - for (io.substrait.extended.expression.ExtendedExpression.ExpressionReference - expressionReference : extendedExpression.getReferredExpr()) { + for (io.substrait.extendedexpression.ExtendedExpression.ExpressionReference + expressionReference : extendedExpression.getReferredExpressions()) { io.substrait.proto.Expression expressionProto = expressionProtoConverter.visit( - (Expression.ScalarFunctionInvocation) expressionReference.getReferredExpr()); + (Expression.ScalarFunctionInvocation) expressionReference.getExpression()); ExpressionReference.Builder expressionReferenceBuilder = ExpressionReference.newBuilder() .setExpression(expressionProto) .addAllOutputNames(expressionReference.getOutputNames()); - extendedExpressionBuilder.addReferredExpr(expressionReferenceBuilder); + builder.addReferredExpr(expressionReferenceBuilder); } - extendedExpressionBuilder.setBaseSchema( + builder.setBaseSchema( extendedExpression.getBaseSchema().toProto(new TypeProtoConverter(functionCollector))); // the process of adding simple extensions, such as extensionURIs and extensions, is handled on // the fly - functionCollector.addExtensionsToExtendedExpression(extendedExpressionBuilder); + functionCollector.addExtensionsToExtendedExpression(builder); if (extendedExpression.getAdvancedExtension().isPresent()) { - extendedExpressionBuilder.setAdvancedExtensions( - extendedExpression.getAdvancedExtension().get()); + builder.setAdvancedExtensions(extendedExpression.getAdvancedExtension().get()); } - return extendedExpressionBuilder.build(); + return builder.build(); } } diff --git a/core/src/main/java/io/substrait/extended/expression/ProtoExtendedExpressionConverter.java b/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java similarity index 69% rename from core/src/main/java/io/substrait/extended/expression/ProtoExtendedExpressionConverter.java rename to core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java index 14c82b5ac..3af6ee20d 100644 --- a/core/src/main/java/io/substrait/extended/expression/ProtoExtendedExpressionConverter.java +++ b/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java @@ -1,12 +1,10 @@ -package io.substrait.extended.expression; +package io.substrait.extendedexpression; import io.substrait.expression.Expression; import io.substrait.expression.proto.ProtoExpressionConverter; import io.substrait.extension.*; import io.substrait.proto.ExpressionReference; import io.substrait.proto.NamedStruct; -import io.substrait.type.ImmutableNamedStruct; -import io.substrait.type.Type; import io.substrait.type.proto.ProtoTypeConverter; import java.io.IOException; import java.util.ArrayList; @@ -33,12 +31,13 @@ public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExp // fill in simple extension information through a discovery in the current proto-extended // expression ExtensionLookup functionLookup = - ImmutableExtensionLookup.builder() - .from(extendedExpression.getExtensionUrisList(), extendedExpression.getExtensionsList()) - .build(); + ImmutableExtensionLookup.builder().from(extendedExpression).build(); NamedStruct baseSchemaProto = extendedExpression.getBaseSchema(); - io.substrait.type.NamedStruct namedStruct = convertNamedStrutProtoToPojo(baseSchemaProto); + + io.substrait.type.NamedStruct namedStruct = + io.substrait.type.NamedStruct.convertNamedStructProtoToPojo( + baseSchemaProto, protoTypeConverter); ProtoExpressionConverter protoExpressionConverter = new ProtoExpressionConverter( @@ -50,14 +49,14 @@ public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExp protoExpressionConverter.from(expressionReference.getExpression()); expressionReferences.add( ImmutableExpressionReference.builder() - .referredExpr(expressionPojo) + .expression(expressionPojo) .addAllOutputNames(expressionReference.getOutputNamesList()) .build()); } ImmutableExtendedExpression.Builder builder = ImmutableExtendedExpression.builder() - .referredExpr(expressionReferences) + .referredExpressions(expressionReferences) .advancedExtension( Optional.ofNullable( extendedExpression.hasAdvancedExtensions() @@ -66,19 +65,4 @@ public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExp .baseSchema(namedStruct); return builder.build(); } - - private io.substrait.type.NamedStruct convertNamedStrutProtoToPojo(NamedStruct namedStruct) { - var struct = namedStruct.getStruct(); - return ImmutableNamedStruct.builder() - .names(namedStruct.getNamesList()) - .struct( - Type.Struct.builder() - .fields( - struct.getTypesList().stream() - .map(protoTypeConverter::from) - .collect(java.util.stream.Collectors.toList())) - .nullable(ProtoTypeConverter.isNullable(struct.getNullability())) - .build()) - .build(); - } } diff --git a/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java b/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java index c88bafc1c..70034d9b1 100644 --- a/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java +++ b/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java @@ -1,5 +1,7 @@ package io.substrait.extension; +import io.substrait.proto.ExtendedExpression; +import io.substrait.proto.Plan; import io.substrait.proto.SimpleExtensionDeclaration; import io.substrait.proto.SimpleExtensionURI; import java.util.Collections; @@ -31,7 +33,16 @@ public static class Builder { private final Map functionMap = new HashMap<>(); private final Map typeMap = new HashMap<>(); - public Builder from( + public Builder from(Plan plan) { + return from(plan.getExtensionUrisList(), plan.getExtensionsList()); + } + + public Builder from(ExtendedExpression extendedExpression) { + return from( + extendedExpression.getExtensionUrisList(), extendedExpression.getExtensionsList()); + } + + private Builder from( List simpleExtensionURIs, List simpleExtensionDeclarations) { Map namespaceMap = new HashMap<>(); diff --git a/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java b/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java index 7222eb7ed..be4f4ad9f 100644 --- a/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java +++ b/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java @@ -32,10 +32,7 @@ protected ProtoRelConverter getProtoRelConverter(ExtensionLookup functionLookup) } public Plan from(io.substrait.proto.Plan plan) { - ExtensionLookup functionLookup = - ImmutableExtensionLookup.builder() - .from(plan.getExtensionUrisList(), plan.getExtensionsList()) - .build(); + ExtensionLookup functionLookup = ImmutableExtensionLookup.builder().from(plan).build(); ProtoRelConverter relConverter = getProtoRelConverter(functionLookup); List roots = new ArrayList<>(); for (PlanRel planRel : plan.getRelationsList()) { diff --git a/core/src/main/java/io/substrait/type/NamedStruct.java b/core/src/main/java/io/substrait/type/NamedStruct.java index 8bf345aa9..11fdd38ad 100644 --- a/core/src/main/java/io/substrait/type/NamedStruct.java +++ b/core/src/main/java/io/substrait/type/NamedStruct.java @@ -1,5 +1,6 @@ package io.substrait.type; +import io.substrait.type.proto.ProtoTypeConverter; import io.substrait.type.proto.TypeProtoConverter; import java.util.List; import org.immutables.value.Value; @@ -21,4 +22,20 @@ default io.substrait.proto.NamedStruct toProto(TypeProtoConverter typeProtoConve .addAllNames(names()) .build(); } + + static io.substrait.type.NamedStruct convertNamedStructProtoToPojo( + io.substrait.proto.NamedStruct namedStruct, ProtoTypeConverter protoTypeConverter) { + var struct = namedStruct.getStruct(); + return ImmutableNamedStruct.builder() + .names(namedStruct.getNamesList()) + .struct( + Type.Struct.builder() + .fields( + struct.getTypesList().stream() + .map(protoTypeConverter::from) + .collect(java.util.stream.Collectors.toList())) + .nullable(ProtoTypeConverter.isNullable(struct.getNullability())) + .build()) + .build(); + } } diff --git a/core/src/test/java/io/substrait/extended/expression/ExtendedExpressionProtoConverterTest.java b/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverterTest.java similarity index 94% rename from core/src/test/java/io/substrait/extended/expression/ExtendedExpressionProtoConverterTest.java rename to core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverterTest.java index 20079e24f..fbe3526eb 100644 --- a/core/src/test/java/io/substrait/extended/expression/ExtendedExpressionProtoConverterTest.java +++ b/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverterTest.java @@ -1,4 +1,4 @@ -package io.substrait.extended.expression; +package io.substrait.extendedexpression; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -36,7 +36,7 @@ public void toProtoTest() { ImmutableExpressionReference expressionReference = ImmutableExpressionReference.builder() - .referredExpr(scalarFunctionExpression.get()) + .expression(scalarFunctionExpression.get()) .addOutputNames("new-column") .build(); @@ -59,7 +59,7 @@ public void toProtoTest() { ImmutableExtendedExpression.Builder extendedExpression = ImmutableExtendedExpression.builder() - .referredExpr(expressionReferences) + .referredExpressions(expressionReferences) .baseSchema(namedStruct); // convert POJO extended expression into PROTOBUF extended expression diff --git a/core/src/test/java/io/substrait/extended/expression/ProtoExtendedExpressionConverterTest.java b/core/src/test/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverterTest.java similarity index 90% rename from core/src/test/java/io/substrait/extended/expression/ProtoExtendedExpressionConverterTest.java rename to core/src/test/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverterTest.java index 9ab84f274..69a03a90a 100644 --- a/core/src/test/java/io/substrait/extended/expression/ProtoExtendedExpressionConverterTest.java +++ b/core/src/test/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverterTest.java @@ -1,4 +1,4 @@ -package io.substrait.extended.expression; +package io.substrait.extendedexpression; import io.substrait.TestBase; import io.substrait.expression.Expression; @@ -37,11 +37,11 @@ public void fromTest() throws IOException { ImmutableExpressionReference expressionReference = ImmutableExpressionReference.builder() - .referredExpr(scalarFunctionExpression.get()) + .expression(scalarFunctionExpression.get()) .addOutputNames("new-column") .build(); - List + List expressionReferences = new ArrayList<>(); expressionReferences.add(expressionReference); @@ -62,7 +62,7 @@ public void fromTest() throws IOException { // pojo initial extended expression ImmutableExtendedExpression extendedExpressionPojoInitial = ImmutableExtendedExpression.builder() - .referredExpr(expressionReferences) + .referredExpressions(expressionReferences) .baseSchema(namedStruct) .build(); @@ -71,7 +71,7 @@ public void fromTest() throws IOException { new ExtendedExpressionProtoConverter().toProto(extendedExpressionPojoInitial); // pojo final extended expression - io.substrait.extended.expression.ExtendedExpression extendedExpressionPojoFinal = + io.substrait.extendedexpression.ExtendedExpression extendedExpressionPojoFinal = new ProtoExtendedExpressionConverter().from(extendedExpressionProto); // validate extended expression pojo initial equals to final roundtrip From 3d9b92729445e80d28dfff5b1aab6b5abad07447 Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Wed, 29 Nov 2023 18:45:12 -0500 Subject: [PATCH 16/34] fix: code review core module testing side --- .../ExtendedExpression.java | 15 +++++++- .../ExtendedExpressionProtoConverter.java | 34 +++++++++++------ .../ProtoExtendedExpressionConverter.java | 23 +++++++---- .../ExtendedExpressionProtoConverterTest.java | 38 ++++++++----------- .../ProtoExtendedExpressionConverterTest.java | 30 +++++++-------- 5 files changed, 83 insertions(+), 57 deletions(-) diff --git a/core/src/main/java/io/substrait/extendedexpression/ExtendedExpression.java b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpression.java index de405f9a3..26c1e3803 100644 --- a/core/src/main/java/io/substrait/extendedexpression/ExtendedExpression.java +++ b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpression.java @@ -2,6 +2,7 @@ import io.substrait.expression.Expression; import io.substrait.proto.AdvancedExtension; +import io.substrait.proto.AggregateFunction; import io.substrait.type.NamedStruct; import java.util.List; import java.util.Optional; @@ -21,8 +22,20 @@ public abstract class ExtendedExpression { @Value.Immutable public abstract static class ExpressionReference { - public abstract Expression getExpression(); + public abstract ExpressionTypeReference getExpressionType(); public abstract List getOutputNames(); } + + public abstract static class ExpressionTypeReference {} + + @Value.Immutable + public abstract static class ExpressionType extends ExpressionTypeReference { + public abstract Expression getExpression(); + } + + @Value.Immutable + public abstract static class AggregateFunctionType extends ExpressionTypeReference { + public abstract AggregateFunction getMeasure(); + } } diff --git a/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java index a123e4b9f..ac8d08180 100644 --- a/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java +++ b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java @@ -20,17 +20,29 @@ public ExtendedExpression toProto( for (io.substrait.extendedexpression.ExtendedExpression.ExpressionReference expressionReference : extendedExpression.getReferredExpressions()) { - - io.substrait.proto.Expression expressionProto = - expressionProtoConverter.visit( - (Expression.ScalarFunctionInvocation) expressionReference.getExpression()); - - ExpressionReference.Builder expressionReferenceBuilder = - ExpressionReference.newBuilder() - .setExpression(expressionProto) - .addAllOutputNames(expressionReference.getOutputNames()); - - builder.addReferredExpr(expressionReferenceBuilder); + io.substrait.extendedexpression.ExtendedExpression.ExpressionTypeReference expressionType = + expressionReference.getExpressionType(); + if (expressionType + instanceof io.substrait.extendedexpression.ExtendedExpression.ExpressionType) { + io.substrait.proto.Expression expressionProto = + expressionProtoConverter.visit( + (Expression.ScalarFunctionInvocation) + ((io.substrait.extendedexpression.ExtendedExpression.ExpressionType) + expressionType) + .getExpression()); + ExpressionReference.Builder expressionReferenceBuilder = + ExpressionReference.newBuilder() + .setExpression(expressionProto) + .addAllOutputNames(expressionReference.getOutputNames()); + builder.addReferredExpr(expressionReferenceBuilder); + } else if (expressionType + instanceof io.substrait.extendedexpression.ExtendedExpression.AggregateFunctionType) { + throw new UnsupportedOperationException( + "Aggregate function types are not supported in conversion to proto Extended Expressions for now"); + } else { + throw new UnsupportedOperationException( + "Only Expression or Aggregate Function type are supported in conversion to proto Extended Expressions for now"); + } } builder.setBaseSchema( extendedExpression.getBaseSchema().toProto(new TypeProtoConverter(functionCollector))); diff --git a/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java b/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java index 3af6ee20d..14bbf209e 100644 --- a/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java +++ b/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java @@ -45,13 +45,22 @@ public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExp List expressionReferences = new ArrayList<>(); for (ExpressionReference expressionReference : extendedExpression.getReferredExprList()) { - Expression expressionPojo = - protoExpressionConverter.from(expressionReference.getExpression()); - expressionReferences.add( - ImmutableExpressionReference.builder() - .expression(expressionPojo) - .addAllOutputNames(expressionReference.getOutputNamesList()) - .build()); + if (expressionReference.getExprTypeCase().getNumber() == 1) { // Expression + Expression expressionPojo = + protoExpressionConverter.from(expressionReference.getExpression()); + expressionReferences.add( + ImmutableExpressionReference.builder() + .expressionType( + ImmutableExpressionType.builder().expression(expressionPojo).build()) + .addAllOutputNames(expressionReference.getOutputNamesList()) + .build()); + } else if (expressionReference.getExprTypeCase().getNumber() == 2) { // AggregateFunction + throw new UnsupportedOperationException( + "Aggregate function types are not supported in conversion from proto Extended Expressions for now"); + } else { + throw new UnsupportedOperationException( + "Only Expression or Aggregate Function type are supported in conversion from proto Extended Expressions for now"); + } } ImmutableExtendedExpression.Builder builder = diff --git a/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverterTest.java b/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverterTest.java index fbe3526eb..fa4cd2ac3 100644 --- a/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverterTest.java +++ b/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverterTest.java @@ -3,40 +3,35 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import io.substrait.TestBase; -import io.substrait.expression.Expression; -import io.substrait.expression.ExpressionCreator; -import io.substrait.expression.FieldReference; -import io.substrait.expression.ImmutableFieldReference; +import io.substrait.expression.*; import io.substrait.type.ImmutableNamedStruct; import io.substrait.type.Type; import io.substrait.type.TypeCreator; import java.util.ArrayList; import java.util.List; -import java.util.Optional; import org.junit.jupiter.api.Test; public class ExtendedExpressionProtoConverterTest extends TestBase { + static final String NAMESPACE = "/functions_arithmetic_decimal.yaml"; + @Test public void toProtoTest() { // create predefined POJO extended expression - Optional scalarFunctionExpression = - defaultExtensionCollection.scalarFunctions().stream() - .filter(s -> s.name().equalsIgnoreCase("add")) - .findFirst() - .map( - declaration -> - ExpressionCreator.scalarFunction( - declaration, - TypeCreator.REQUIRED.BOOLEAN, - ImmutableFieldReference.builder() - .addSegments(FieldReference.StructField.of(0)) - .type(TypeCreator.REQUIRED.decimal(10, 2)) - .build(), - ExpressionCreator.i32(false, 183))); + Expression.ScalarFunctionInvocation scalarFunctionInvocation = + b.scalarFn( + NAMESPACE, + "add:dec_dec", + TypeCreator.REQUIRED.BOOLEAN, + ImmutableFieldReference.builder() + .addSegments(FieldReference.StructField.of(0)) + .type(TypeCreator.REQUIRED.decimal(10, 2)) + .build(), + ExpressionCreator.i32(false, 183)); ImmutableExpressionReference expressionReference = ImmutableExpressionReference.builder() - .expression(scalarFunctionExpression.get()) + .expressionType( + ImmutableExpressionType.builder().expression(scalarFunctionInvocation).build()) .addOutputNames("new-column") .build(); @@ -66,8 +61,7 @@ public void toProtoTest() { io.substrait.proto.ExtendedExpression proto = new ExtendedExpressionProtoConverter().toProto(extendedExpression.build()); - assertEquals( - "/functions_arithmetic_decimal.yaml", proto.getExtensionUrisList().get(0).getUri()); + assertEquals(NAMESPACE, proto.getExtensionUrisList().get(0).getUri()); assertEquals("add:dec_dec", proto.getExtensionsList().get(0).getExtensionFunction().getName()); } } diff --git a/core/src/test/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverterTest.java b/core/src/test/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverterTest.java index 69a03a90a..b0f34b783 100644 --- a/core/src/test/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverterTest.java +++ b/core/src/test/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverterTest.java @@ -12,32 +12,30 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; -import java.util.Optional; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; public class ProtoExtendedExpressionConverterTest extends TestBase { + static final String NAMESPACE = "/functions_arithmetic_decimal.yaml"; + @Test public void fromTest() throws IOException { // create predefined POJO extended expression - Optional scalarFunctionExpression = - defaultExtensionCollection.scalarFunctions().stream() - .filter(s -> s.name().equalsIgnoreCase("add")) - .findFirst() - .map( - declaration -> - ExpressionCreator.scalarFunction( - declaration, - TypeCreator.REQUIRED.BOOLEAN, - ImmutableFieldReference.builder() - .addSegments(FieldReference.StructField.of(0)) - .type(TypeCreator.REQUIRED.decimal(10, 2)) - .build(), - ExpressionCreator.i32(false, 183))); + Expression.ScalarFunctionInvocation scalarFunctionInvocation = + b.scalarFn( + NAMESPACE, + "add:dec_dec", + TypeCreator.REQUIRED.BOOLEAN, + ImmutableFieldReference.builder() + .addSegments(FieldReference.StructField.of(0)) + .type(TypeCreator.REQUIRED.decimal(10, 2)) + .build(), + ExpressionCreator.i32(false, 183)); ImmutableExpressionReference expressionReference = ImmutableExpressionReference.builder() - .expression(scalarFunctionExpression.get()) + .expressionType( + ImmutableExpressionType.builder().expression(scalarFunctionInvocation).build()) .addOutputNames("new-column") .build(); From e7904926ad19d31dd452c8fa7df3da5aa3c07088 Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Tue, 5 Dec 2023 21:15:30 -0500 Subject: [PATCH 17/34] feat: support aggregation function in extended expression from/to pojo/proto --- .../ExtendedExpressionProtoConverter.java | 14 +- .../ProtoExtendedExpressionConverter.java | 18 +- .../AggregateFunctionProtoController.java | 43 ++++ .../java/io/substrait/type/NamedStruct.java | 2 +- .../ExtendedExpressionProtoConverterTest.java | 67 ----- .../ExtendedExpressionRoundTripTest.java | 229 ++++++++++++++++++ .../ProtoExtendedExpressionConverterTest.java | 78 ------ 7 files changed, 297 insertions(+), 154 deletions(-) create mode 100644 core/src/main/java/io/substrait/relation/AggregateFunctionProtoController.java delete mode 100644 core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverterTest.java create mode 100644 core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionRoundTripTest.java delete mode 100644 core/src/test/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverterTest.java diff --git a/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java index ac8d08180..6d1f2efcb 100644 --- a/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java +++ b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java @@ -3,6 +3,7 @@ import io.substrait.expression.Expression; import io.substrait.expression.proto.ExpressionProtoConverter; import io.substrait.extension.ExtensionCollector; +import io.substrait.proto.AggregateFunction; import io.substrait.proto.ExpressionReference; import io.substrait.proto.ExtendedExpression; import io.substrait.type.proto.TypeProtoConverter; @@ -37,11 +38,18 @@ public ExtendedExpression toProto( builder.addReferredExpr(expressionReferenceBuilder); } else if (expressionType instanceof io.substrait.extendedexpression.ExtendedExpression.AggregateFunctionType) { - throw new UnsupportedOperationException( - "Aggregate function types are not supported in conversion to proto Extended Expressions for now"); + AggregateFunction measure = + ((io.substrait.extendedexpression.ExtendedExpression.AggregateFunctionType) + expressionType) + .getMeasure(); + ExpressionReference.Builder expressionReferenceBuilder = + ExpressionReference.newBuilder() + .setMeasure(measure.toBuilder()) + .addAllOutputNames(expressionReference.getOutputNames()); + builder.addReferredExpr(expressionReferenceBuilder); } else { throw new UnsupportedOperationException( - "Only Expression or Aggregate Function type are supported in conversion to proto Extended Expressions for now"); + "Only Expression or Aggregate Function type are supported in conversion to proto Extended Expressions"); } } builder.setBaseSchema( diff --git a/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java b/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java index 14bbf209e..8daf41cf0 100644 --- a/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java +++ b/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java @@ -2,7 +2,12 @@ import io.substrait.expression.Expression; import io.substrait.expression.proto.ProtoExpressionConverter; -import io.substrait.extension.*; +import io.substrait.extension.ExtensionCollector; +import io.substrait.extension.ExtensionLookup; +import io.substrait.extension.ImmutableExtensionLookup; +import io.substrait.extension.ImmutableSimpleExtension; +import io.substrait.extension.SimpleExtension; +import io.substrait.proto.AggregateFunction; import io.substrait.proto.ExpressionReference; import io.substrait.proto.NamedStruct; import io.substrait.type.proto.ProtoTypeConverter; @@ -36,8 +41,7 @@ public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExp NamedStruct baseSchemaProto = extendedExpression.getBaseSchema(); io.substrait.type.NamedStruct namedStruct = - io.substrait.type.NamedStruct.convertNamedStructProtoToPojo( - baseSchemaProto, protoTypeConverter); + io.substrait.type.NamedStruct.fromProto(baseSchemaProto, protoTypeConverter); ProtoExpressionConverter protoExpressionConverter = new ProtoExpressionConverter( @@ -55,8 +59,12 @@ public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExp .addAllOutputNames(expressionReference.getOutputNamesList()) .build()); } else if (expressionReference.getExprTypeCase().getNumber() == 2) { // AggregateFunction - throw new UnsupportedOperationException( - "Aggregate function types are not supported in conversion from proto Extended Expressions for now"); + AggregateFunction measure = expressionReference.getMeasure(); + ImmutableExpressionReference.Builder builder = + ImmutableExpressionReference.builder() + .expressionType(ImmutableAggregateFunctionType.builder().measure(measure).build()) + .addAllOutputNames(expressionReference.getOutputNamesList()); + expressionReferences.add(builder.build()); } else { throw new UnsupportedOperationException( "Only Expression or Aggregate Function type are supported in conversion from proto Extended Expressions for now"); diff --git a/core/src/main/java/io/substrait/relation/AggregateFunctionProtoController.java b/core/src/main/java/io/substrait/relation/AggregateFunctionProtoController.java new file mode 100644 index 000000000..7904bfa72 --- /dev/null +++ b/core/src/main/java/io/substrait/relation/AggregateFunctionProtoController.java @@ -0,0 +1,43 @@ +package io.substrait.relation; + +import io.substrait.expression.FunctionArg; +import io.substrait.expression.proto.ExpressionProtoConverter; +import io.substrait.extension.ExtensionCollector; +import io.substrait.proto.AggregateFunction; +import io.substrait.type.proto.TypeProtoConverter; +import java.util.stream.IntStream; + +/** + * Converts from {@link io.substrait.relation.Aggregate.Measure} to {@link + * io.substrait.proto.AggregateFunction} + */ +public class AggregateFunctionProtoController { + + private final ExpressionProtoConverter exprProtoConverter; + private final TypeProtoConverter typeProtoConverter; + private final ExtensionCollector functionCollector; + + public AggregateFunctionProtoController(ExtensionCollector functionCollector) { + this.functionCollector = functionCollector; + this.exprProtoConverter = new ExpressionProtoConverter(functionCollector, null); + this.typeProtoConverter = new TypeProtoConverter(functionCollector); + } + + public AggregateFunction toProto(Aggregate.Measure measure) { + var argVisitor = FunctionArg.toProto(typeProtoConverter, exprProtoConverter); + var args = measure.getFunction().arguments(); + var aggFuncDef = measure.getFunction().declaration(); + + return AggregateFunction.newBuilder() + .setPhase(measure.getFunction().aggregationPhase().toProto()) + .setInvocation(measure.getFunction().invocation().toProto()) + .setOutputType(measure.getFunction().getType().accept(typeProtoConverter)) + .addAllArguments( + IntStream.range(0, args.size()) + .mapToObj(i -> args.get(i).accept(aggFuncDef, i, argVisitor)) + .collect(java.util.stream.Collectors.toList())) + .setFunctionReference( + functionCollector.getFunctionReference(measure.getFunction().declaration())) + .build(); + } +} diff --git a/core/src/main/java/io/substrait/type/NamedStruct.java b/core/src/main/java/io/substrait/type/NamedStruct.java index 11fdd38ad..e38a95fb5 100644 --- a/core/src/main/java/io/substrait/type/NamedStruct.java +++ b/core/src/main/java/io/substrait/type/NamedStruct.java @@ -23,7 +23,7 @@ default io.substrait.proto.NamedStruct toProto(TypeProtoConverter typeProtoConve .build(); } - static io.substrait.type.NamedStruct convertNamedStructProtoToPojo( + static io.substrait.type.NamedStruct fromProto( io.substrait.proto.NamedStruct namedStruct, ProtoTypeConverter protoTypeConverter) { var struct = namedStruct.getStruct(); return ImmutableNamedStruct.builder() diff --git a/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverterTest.java b/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverterTest.java deleted file mode 100644 index fa4cd2ac3..000000000 --- a/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverterTest.java +++ /dev/null @@ -1,67 +0,0 @@ -package io.substrait.extendedexpression; - -import static org.junit.jupiter.api.Assertions.assertEquals; - -import io.substrait.TestBase; -import io.substrait.expression.*; -import io.substrait.type.ImmutableNamedStruct; -import io.substrait.type.Type; -import io.substrait.type.TypeCreator; -import java.util.ArrayList; -import java.util.List; -import org.junit.jupiter.api.Test; - -public class ExtendedExpressionProtoConverterTest extends TestBase { - static final String NAMESPACE = "/functions_arithmetic_decimal.yaml"; - - @Test - public void toProtoTest() { - // create predefined POJO extended expression - Expression.ScalarFunctionInvocation scalarFunctionInvocation = - b.scalarFn( - NAMESPACE, - "add:dec_dec", - TypeCreator.REQUIRED.BOOLEAN, - ImmutableFieldReference.builder() - .addSegments(FieldReference.StructField.of(0)) - .type(TypeCreator.REQUIRED.decimal(10, 2)) - .build(), - ExpressionCreator.i32(false, 183)); - - ImmutableExpressionReference expressionReference = - ImmutableExpressionReference.builder() - .expressionType( - ImmutableExpressionType.builder().expression(scalarFunctionInvocation).build()) - .addOutputNames("new-column") - .build(); - - List expressionReferences = new ArrayList<>(); - expressionReferences.add(expressionReference); - - ImmutableNamedStruct namedStruct = - ImmutableNamedStruct.builder() - .addNames("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT") - .struct( - Type.Struct.builder() - .nullable(false) - .addFields( - TypeCreator.NULLABLE.decimal(10, 2), - TypeCreator.REQUIRED.STRING, - TypeCreator.REQUIRED.decimal(10, 2), - TypeCreator.REQUIRED.STRING) - .build()) - .build(); - - ImmutableExtendedExpression.Builder extendedExpression = - ImmutableExtendedExpression.builder() - .referredExpressions(expressionReferences) - .baseSchema(namedStruct); - - // convert POJO extended expression into PROTOBUF extended expression - io.substrait.proto.ExtendedExpression proto = - new ExtendedExpressionProtoConverter().toProto(extendedExpression.build()); - - assertEquals(NAMESPACE, proto.getExtensionUrisList().get(0).getUri()); - assertEquals("add:dec_dec", proto.getExtensionsList().get(0).getExtensionFunction().getName()); - } -} diff --git a/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionRoundTripTest.java b/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionRoundTripTest.java new file mode 100644 index 000000000..1da2d6195 --- /dev/null +++ b/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionRoundTripTest.java @@ -0,0 +1,229 @@ +package io.substrait.extendedexpression; + +import io.substrait.TestBase; +import io.substrait.expression.*; +import io.substrait.relation.Aggregate; +import io.substrait.relation.AggregateFunctionProtoController; +import io.substrait.relation.ImmutableMeasure; +import io.substrait.type.ImmutableNamedStruct; +import io.substrait.type.Type; +import io.substrait.type.TypeCreator; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +public class ExtendedExpressionRoundTripTest extends TestBase { + static final String NAMESPACE = "/functions_arithmetic_decimal.yaml"; + + @Test + public void expressionRoundTrip() throws IOException { + // create predefined POJO extended expression + Expression.ScalarFunctionInvocation scalarFunctionInvocation = + b.scalarFn( + NAMESPACE, + "add:dec_dec", + TypeCreator.REQUIRED.BOOLEAN, + ImmutableFieldReference.builder() + .addSegments(FieldReference.StructField.of(0)) + .type(TypeCreator.REQUIRED.decimal(10, 2)) + .build(), + ExpressionCreator.i32(false, 183)); + + ImmutableExpressionReference expressionReference = + ImmutableExpressionReference.builder() + .expressionType( + ImmutableExpressionType.builder().expression(scalarFunctionInvocation).build()) + .addOutputNames("new-column") + .build(); + + List + expressionReferences = new ArrayList<>(); + expressionReferences.add(expressionReference); + + ImmutableNamedStruct namedStruct = + ImmutableNamedStruct.builder() + .addNames("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT") + .struct( + Type.Struct.builder() + .nullable(false) + .addFields( + TypeCreator.REQUIRED.decimal(10, 2), + TypeCreator.REQUIRED.STRING, + TypeCreator.REQUIRED.decimal(10, 2), + TypeCreator.REQUIRED.STRING) + .build()) + .build(); + + // pojo initial extended expression + ImmutableExtendedExpression extendedExpressionPojoInitial = + ImmutableExtendedExpression.builder() + .referredExpressions(expressionReferences) + .baseSchema(namedStruct) + .build(); + + // proto extended expression + io.substrait.proto.ExtendedExpression extendedExpressionProto = + new ExtendedExpressionProtoConverter().toProto(extendedExpressionPojoInitial); + + // pojo final extended expression + io.substrait.extendedexpression.ExtendedExpression extendedExpressionPojoFinal = + new ProtoExtendedExpressionConverter().from(extendedExpressionProto); + + // validate extended expression pojo initial equals to final roundtrip + Assertions.assertEquals(extendedExpressionPojoInitial, extendedExpressionPojoFinal); + } + + @Test + public void aggregationRoundTrip() throws IOException { + // create predefined POJO aggregation function + ImmutableMeasure measure = + Aggregate.Measure.builder() + .function( + AggregateFunctionInvocation.builder() + .arguments(Collections.emptyList()) + .declaration(defaultExtensionCollection.aggregateFunctions().get(0)) + .outputType(TypeCreator.of(false).I64) + .aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT) + .invocation(Expression.AggregationInvocation.ALL) + .build()) + .build(); + + ImmutableAggregateFunctionType aggregateFunctionType = + ImmutableAggregateFunctionType.builder() + .measure(new AggregateFunctionProtoController(functionCollector).toProto(measure)) + .build(); + + ImmutableExpressionReference expressionReference = + ImmutableExpressionReference.builder() + .expressionType(aggregateFunctionType) + .addOutputNames("new-column") + .build(); + + List + expressionReferences = new ArrayList<>(); + expressionReferences.add(expressionReference); + + ImmutableNamedStruct namedStruct = + ImmutableNamedStruct.builder() + .addNames("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT") + .struct( + Type.Struct.builder() + .nullable(false) + .addFields( + TypeCreator.REQUIRED.decimal(10, 2), + TypeCreator.REQUIRED.STRING, + TypeCreator.REQUIRED.decimal(10, 2), + TypeCreator.REQUIRED.STRING) + .build()) + .build(); + + // pojo initial aggregation function + ImmutableExtendedExpression extendedExpressionPojoInitial = + ImmutableExtendedExpression.builder() + .referredExpressions(expressionReferences) + .baseSchema(namedStruct) + .build(); + + // proto aggregation function + io.substrait.proto.ExtendedExpression extendedExpressionProto = + new ExtendedExpressionProtoConverter().toProto(extendedExpressionPojoInitial); + + // pojo final aggregation function + io.substrait.extendedexpression.ExtendedExpression extendedExpressionPojoFinal = + new ProtoExtendedExpressionConverter().from(extendedExpressionProto); + + // validate aggregation function pojo initial equals to final roundtrip + Assertions.assertEquals(extendedExpressionPojoInitial, extendedExpressionPojoFinal); + } + + @Test + public void expressionAndAggregationRoundTrip() throws IOException { + // POJO 01 + // create predefined POJO extended expression + Expression.ScalarFunctionInvocation scalarFunctionInvocation = + b.scalarFn( + NAMESPACE, + "add:dec_dec", + TypeCreator.REQUIRED.BOOLEAN, + ImmutableFieldReference.builder() + .addSegments(FieldReference.StructField.of(0)) + .type(TypeCreator.REQUIRED.decimal(10, 2)) + .build(), + ExpressionCreator.i32(false, 183)); + + ImmutableExpressionReference expressionReferenceExpression = + ImmutableExpressionReference.builder() + .expressionType( + ImmutableExpressionType.builder().expression(scalarFunctionInvocation).build()) + .addOutputNames("new-column") + .build(); + + List + expressionReferences = new ArrayList<>(); + + // POJO 02 + // create predefined POJO aggregation function + ImmutableMeasure measure = + Aggregate.Measure.builder() + .function( + AggregateFunctionInvocation.builder() + .arguments(Collections.emptyList()) + .declaration(defaultExtensionCollection.aggregateFunctions().get(0)) + .outputType(TypeCreator.of(false).I64) + .aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT) + .invocation(Expression.AggregationInvocation.ALL) + .build()) + .build(); + + ImmutableAggregateFunctionType aggregateFunctionType = + ImmutableAggregateFunctionType.builder() + .measure(new AggregateFunctionProtoController(functionCollector).toProto(measure)) + .build(); + + ImmutableExpressionReference expressionReferenceAggregation = + ImmutableExpressionReference.builder() + .expressionType(aggregateFunctionType) + .addOutputNames("new-column") + .build(); + + // adding expression + expressionReferences.add(expressionReferenceExpression); + // adding aggregation function + expressionReferences.add(expressionReferenceAggregation); + + ImmutableNamedStruct namedStruct = + ImmutableNamedStruct.builder() + .addNames("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT") + .struct( + Type.Struct.builder() + .nullable(false) + .addFields( + TypeCreator.REQUIRED.decimal(10, 2), + TypeCreator.REQUIRED.STRING, + TypeCreator.REQUIRED.decimal(10, 2), + TypeCreator.REQUIRED.STRING) + .build()) + .build(); + + // pojo initial extended expression + aggregation + ImmutableExtendedExpression extendedExpressionPojoInitial = + ImmutableExtendedExpression.builder() + .referredExpressions(expressionReferences) + .baseSchema(namedStruct) + .build(); + + // proto extended expression + aggregation + io.substrait.proto.ExtendedExpression extendedExpressionProto = + new ExtendedExpressionProtoConverter().toProto(extendedExpressionPojoInitial); + + // pojo final extended expression + aggregation + io.substrait.extendedexpression.ExtendedExpression extendedExpressionPojoFinal = + new ProtoExtendedExpressionConverter().from(extendedExpressionProto); + + // validate extended expression + aggregation pojo initial equals to final roundtrip + Assertions.assertEquals(extendedExpressionPojoInitial, extendedExpressionPojoFinal); + } +} diff --git a/core/src/test/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverterTest.java b/core/src/test/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverterTest.java deleted file mode 100644 index b0f34b783..000000000 --- a/core/src/test/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverterTest.java +++ /dev/null @@ -1,78 +0,0 @@ -package io.substrait.extendedexpression; - -import io.substrait.TestBase; -import io.substrait.expression.Expression; -import io.substrait.expression.ExpressionCreator; -import io.substrait.expression.FieldReference; -import io.substrait.expression.ImmutableFieldReference; -import io.substrait.proto.ExtendedExpression; -import io.substrait.type.ImmutableNamedStruct; -import io.substrait.type.Type; -import io.substrait.type.TypeCreator; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; - -public class ProtoExtendedExpressionConverterTest extends TestBase { - static final String NAMESPACE = "/functions_arithmetic_decimal.yaml"; - - @Test - public void fromTest() throws IOException { - // create predefined POJO extended expression - Expression.ScalarFunctionInvocation scalarFunctionInvocation = - b.scalarFn( - NAMESPACE, - "add:dec_dec", - TypeCreator.REQUIRED.BOOLEAN, - ImmutableFieldReference.builder() - .addSegments(FieldReference.StructField.of(0)) - .type(TypeCreator.REQUIRED.decimal(10, 2)) - .build(), - ExpressionCreator.i32(false, 183)); - - ImmutableExpressionReference expressionReference = - ImmutableExpressionReference.builder() - .expressionType( - ImmutableExpressionType.builder().expression(scalarFunctionInvocation).build()) - .addOutputNames("new-column") - .build(); - - List - expressionReferences = new ArrayList<>(); - expressionReferences.add(expressionReference); - - ImmutableNamedStruct namedStruct = - ImmutableNamedStruct.builder() - .addNames("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT") - .struct( - Type.Struct.builder() - .nullable(false) - .addFields( - TypeCreator.REQUIRED.decimal(10, 2), - TypeCreator.REQUIRED.STRING, - TypeCreator.REQUIRED.decimal(10, 2), - TypeCreator.REQUIRED.STRING) - .build()) - .build(); - - // pojo initial extended expression - ImmutableExtendedExpression extendedExpressionPojoInitial = - ImmutableExtendedExpression.builder() - .referredExpressions(expressionReferences) - .baseSchema(namedStruct) - .build(); - - // proto extended expression - ExtendedExpression extendedExpressionProto = - new ExtendedExpressionProtoConverter().toProto(extendedExpressionPojoInitial); - - // pojo final extended expression - io.substrait.extendedexpression.ExtendedExpression extendedExpressionPojoFinal = - new ProtoExtendedExpressionConverter().from(extendedExpressionProto); - - // validate extended expression pojo initial equals to final roundtrip - Assertions.assertEquals(extendedExpressionPojoInitial, extendedExpressionPojoFinal); - } -} From c26fecd7cf1f6b61e7597e02a440d0af5ae6334b Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Tue, 5 Dec 2023 22:08:35 -0500 Subject: [PATCH 18/34] fix: merge from/to proto/pojo + solve comments on the PR --- .../substrait/isthmus/SqlConverterBase.java | 6 +- .../isthmus/SqlExpressionToSubstrait.java | 120 ++++++++++++++++++ .../io/substrait/isthmus/SqlToSubstrait.java | 83 +----------- .../io/substrait/isthmus/TypeConverter.java | 5 +- .../isthmus/ExtendedExpressionTestBase.java | 9 +- .../ExtendedExpressionIntegrationTest.java | 4 +- 6 files changed, 135 insertions(+), 92 deletions(-) create mode 100644 isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java index 67ae9a8ea..6501316b1 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java @@ -5,7 +5,11 @@ import io.substrait.isthmus.calcite.SubstraitOperatorTable; import io.substrait.type.NamedStruct; import java.io.IOException; -import java.util.*; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; import java.util.function.Function; import org.apache.calcite.config.CalciteConnectionConfig; import org.apache.calcite.config.CalciteConnectionProperty; diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java new file mode 100644 index 000000000..8e31fe30a --- /dev/null +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java @@ -0,0 +1,120 @@ +package io.substrait.isthmus; + +import com.google.common.annotations.VisibleForTesting; +import io.substrait.extendedexpression.ExtendedExpressionProtoConverter; +import io.substrait.extendedexpression.ImmutableExpressionReference; +import io.substrait.extendedexpression.ImmutableExpressionType; +import io.substrait.extendedexpression.ImmutableExtendedExpression; +import io.substrait.isthmus.expression.RexExpressionConverter; +import io.substrait.isthmus.expression.ScalarFunctionConverter; +import io.substrait.proto.ExtendedExpression; +import io.substrait.type.NamedStruct; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import org.apache.calcite.prepare.CalciteCatalogReader; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.parser.SqlParseException; +import org.apache.calcite.sql.parser.SqlParser; +import org.apache.calcite.sql.validate.SqlValidator; +import org.apache.calcite.sql2rel.SqlToRelConverter; +import org.apache.calcite.sql2rel.StandardConvertletTable; + +public class SqlExpressionToSubstrait extends SqlConverterBase { + + public SqlExpressionToSubstrait() { + this(null); + } + + protected SqlExpressionToSubstrait(FeatureBoard features) { + super(features); + } + + private final ScalarFunctionConverter functionConverter = + new ScalarFunctionConverter(EXTENSION_COLLECTION.scalarFunctions(), factory); + + private final RexExpressionConverter rexExpressionConverter = + new RexExpressionConverter(functionConverter); + + /** + * Process to execute an SQL Expression to convert into an Extended expression protobuf message + * + * @param sqlExpression expression defined by the user + * @param tables of names of table needed to consider to load into memory for catalog, schema, + * validate and parse sql + * @return extended expression protobuf message + * @throws SqlParseException + */ + public ExtendedExpression executeSQLExpression(String sqlExpression, List tables) + throws SqlParseException { + var result = registerCreateTablesForExtendedExpression(tables); + return executeInnerSQLExpression( + sqlExpression, + result.validator(), + result.catalogReader(), + result.nameToTypeMap(), + result.nameToNodeMap()); + } + + private ExtendedExpression executeInnerSQLExpression( + String sqlExpression, + SqlValidator validator, + CalciteCatalogReader catalogReader, + Map nameToTypeMap, + Map nameToNodeMap) + throws SqlParseException { + RexNode rexNode = + sqlToRexNode(sqlExpression, validator, catalogReader, nameToTypeMap, nameToNodeMap); + io.substrait.expression.Expression.ScalarFunctionInvocation func = + (io.substrait.expression.Expression.ScalarFunctionInvocation) + rexNode.accept(rexExpressionConverter); + NamedStruct namedStruct = TypeConverter.DEFAULT.toNamedStruct(nameToTypeMap); + + ImmutableExpressionReference expressionReference = + ImmutableExpressionReference.builder() + .expressionType(ImmutableExpressionType.builder().expression(func).build()) + .addOutputNames("new-column") + .build(); + + List + expressionReferences = new ArrayList<>(); + expressionReferences.add(expressionReference); + + ImmutableExtendedExpression.Builder extendedExpression = + ImmutableExtendedExpression.builder() + .referredExpressions(expressionReferences) + .baseSchema(namedStruct); + + return new ExtendedExpressionProtoConverter().toProto(extendedExpression.build()); + } + + private RexNode sqlToRexNode( + String sql, + SqlValidator validator, + CalciteCatalogReader catalogReader, + Map nameToTypeMap, + Map nameToNodeMap) + throws SqlParseException { + SqlParser parser = SqlParser.create(sql, parserConfig); + SqlNode sqlNode = parser.parseExpression(); + SqlNode validSQLNode = validator.validateParameterizedExpression(sqlNode, nameToTypeMap); + SqlToRelConverter converter = createSqlToRelConverter(validator, catalogReader); + return converter.convertExpression(validSQLNode, nameToNodeMap); + } + + @VisibleForTesting + SqlToRelConverter createSqlToRelConverter( + SqlValidator validator, CalciteCatalogReader catalogReader) { + SqlToRelConverter converter = + new SqlToRelConverter( + null, + validator, + catalogReader, + relOptCluster, + StandardConvertletTable.INSTANCE, + converterConfig); + return converter; + } +} diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java index 6ee3f11af..7a850499a 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java @@ -1,27 +1,18 @@ package io.substrait.isthmus; import com.google.common.annotations.VisibleForTesting; -import io.substrait.extendedexpression.ExtendedExpressionProtoConverter; -import io.substrait.extendedexpression.ImmutableExpressionReference; -import io.substrait.extendedexpression.ImmutableExpressionType; -import io.substrait.extendedexpression.ImmutableExtendedExpression; import io.substrait.extension.ExtensionCollector; -import io.substrait.isthmus.expression.RexExpressionConverter; -import io.substrait.isthmus.expression.ScalarFunctionConverter; -import io.substrait.proto.ExtendedExpression; import io.substrait.proto.Plan; import io.substrait.proto.PlanRel; import io.substrait.relation.RelProtoConverter; import io.substrait.type.NamedStruct; -import java.util.*; +import java.util.List; import java.util.function.Function; import org.apache.calcite.plan.hep.HepPlanner; import org.apache.calcite.plan.hep.HepProgram; import org.apache.calcite.prepare.CalciteCatalogReader; import org.apache.calcite.rel.RelRoot; -import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; -import org.apache.calcite.rex.RexNode; import org.apache.calcite.schema.Schema; import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.parser.SqlParseException; @@ -33,12 +24,6 @@ /** Take a SQL statement and a set of table definitions and return a substrait plan. */ public class SqlToSubstrait extends SqlConverterBase { - private final ScalarFunctionConverter functionConverter = - new ScalarFunctionConverter(EXTENSION_COLLECTION.scalarFunctions(), factory); - - private final RexExpressionConverter rexExpressionConverter = - new RexExpressionConverter(functionConverter); - public SqlToSubstrait() { this(null); } @@ -63,26 +48,6 @@ public Plan execute(String sql, String name, Schema schema) throws SqlParseExcep return executeInner(sql, factory, pair.left, pair.right); } - /** - * Process to execute an SQL Expression to convert into an Extended expression protobuf message - * - * @param sqlExpression expression defined by the user - * @param tables of names of table needed to consider to load into memory for catalog, schema, - * validate and parse sql - * @return extended expression protobuf message - * @throws SqlParseException - */ - public ExtendedExpression executeSQLExpression(String sqlExpression, List tables) - throws SqlParseException { - var result = registerCreateTablesForExtendedExpression(tables); - return executeInnerSQLExpression( - sqlExpression, - result.validator(), - result.catalogReader(), - result.nameToTypeMap(), - result.nameToNodeMap()); - } - // Package protected for testing List sqlToRelNode(String sql, List tables) throws SqlParseException { var pair = registerCreateTables(tables); @@ -126,38 +91,6 @@ private Plan executeInner( return plan.build(); } - private ExtendedExpression executeInnerSQLExpression( - String sqlExpression, - SqlValidator validator, - CalciteCatalogReader catalogReader, - Map nameToTypeMap, - Map nameToNodeMap) - throws SqlParseException { - RexNode rexNode = - sqlToRexNode(sqlExpression, validator, catalogReader, nameToTypeMap, nameToNodeMap); - io.substrait.expression.Expression.ScalarFunctionInvocation func = - (io.substrait.expression.Expression.ScalarFunctionInvocation) - rexNode.accept(rexExpressionConverter); - NamedStruct namedStruct = TypeConverter.DEFAULT.toNamedStruct(nameToTypeMap); - - ImmutableExpressionReference expressionReference = - ImmutableExpressionReference.builder() - .expressionType(ImmutableExpressionType.builder().expression(func).build()) - .addOutputNames("new-column") - .build(); - - List - expressionReferences = new ArrayList<>(); - expressionReferences.add(expressionReference); - - ImmutableExtendedExpression.Builder extendedExpression = - ImmutableExtendedExpression.builder() - .referredExpressions(expressionReferences) - .baseSchema(namedStruct); - - return new ExtendedExpressionProtoConverter().toProto(extendedExpression.build()); - } - private List sqlToRelNode( String sql, SqlValidator validator, CalciteCatalogReader catalogReader) throws SqlParseException { @@ -174,20 +107,6 @@ private List sqlToRelNode( return roots; } - private RexNode sqlToRexNode( - String sql, - SqlValidator validator, - CalciteCatalogReader catalogReader, - Map nameToTypeMap, - Map nameToNodeMap) - throws SqlParseException { - SqlParser parser = SqlParser.create(sql, parserConfig); - SqlNode sqlNode = parser.parseExpression(); - SqlNode validSQLNode = validator.validateParameterizedExpression(sqlNode, nameToTypeMap); - SqlToRelConverter converter = createSqlToRelConverter(validator, catalogReader); - return converter.convertExpression(validSQLNode, nameToNodeMap); - } - @VisibleForTesting SqlToRelConverter createSqlToRelConverter( SqlValidator validator, CalciteCatalogReader catalogReader) { diff --git a/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java index 73c846d45..6b69f3fff 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java @@ -3,7 +3,6 @@ import static io.substrait.isthmus.SubstraitTypeSystem.DAY_SECOND_INTERVAL; import static io.substrait.isthmus.SubstraitTypeSystem.YEAR_MONTH_INTERVAL; -import com.google.common.collect.Lists; import io.substrait.function.NullableType; import io.substrait.function.TypeExpression; import io.substrait.type.NamedStruct; @@ -59,8 +58,8 @@ public NamedStruct toNamedStruct(RelDataType type) { } public NamedStruct toNamedStruct(Map nameToTypeMap) { - var names = Lists.newArrayList(); - var types = Lists.newArrayList(); + var names = new ArrayList(); + var types = new ArrayList(); nameToTypeMap.forEach( (k, v) -> { names.add(k); diff --git a/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java index 052e91c9c..4a9dde8b7 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java @@ -25,16 +25,17 @@ public static List tpchSchemaCreateStatements() throws IOException { protected ExtendedExpression assertProtoExtendedExpressionRoundrip(String query) throws IOException, SqlParseException { - return assertProtoExtendedExpressionRoundrip(query, new SqlToSubstrait()); + return assertProtoExtendedExpressionRoundrip(query, new SqlExpressionToSubstrait()); } - protected ExtendedExpression assertProtoExtendedExpressionRoundrip(String query, SqlToSubstrait s) - throws IOException, SqlParseException { + protected ExtendedExpression assertProtoExtendedExpressionRoundrip( + String query, SqlExpressionToSubstrait s) throws IOException, SqlParseException { return assertProtoExtendedExpressionRoundrip(query, s, tpchSchemaCreateStatements()); } protected ExtendedExpression assertProtoExtendedExpressionRoundrip( - String query, SqlToSubstrait s, List creates) throws SqlParseException, IOException { + String query, SqlExpressionToSubstrait s, List creates) + throws SqlParseException, IOException { // proto initial extended expression ExtendedExpression extendedExpressionProtoInitial = s.executeSQLExpression(query, creates); diff --git a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java index 297517ec8..0c25c603b 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java @@ -4,7 +4,7 @@ import com.ibm.icu.impl.ClassLoaderUtil; import io.substrait.isthmus.ExtendedExpressionTestBase; -import io.substrait.isthmus.SqlToSubstrait; +import io.substrait.isthmus.SqlExpressionToSubstrait; import io.substrait.proto.ExtendedExpression; import java.io.IOException; import java.net.URL; @@ -98,7 +98,7 @@ public void projectDataset() throws SqlParseException, IOException { private static ByteBuffer getExtendedExpression(String sqlExpression) throws IOException, SqlParseException { ExtendedExpression extendedExpression = - new SqlToSubstrait() + new SqlExpressionToSubstrait() .executeSQLExpression( sqlExpression, ExtendedExpressionTestBase.tpchSchemaCreateStatements()); byte[] extendedExpressions = From 0fa69c81a8b669e2976c9b120ef377a2ededa22b Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Wed, 6 Dec 2023 18:58:36 -0500 Subject: [PATCH 19/34] fix: code review suggestion Co-authored-by: Dane Pitkin <48041712+danepitkin@users.noreply.github.com> --- .../extendedexpression/ProtoExtendedExpressionConverter.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java b/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java index 8daf41cf0..69020cbee 100644 --- a/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java +++ b/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java @@ -67,7 +67,7 @@ public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExp expressionReferences.add(builder.build()); } else { throw new UnsupportedOperationException( - "Only Expression or Aggregate Function type are supported in conversion from proto Extended Expressions for now"); + "Only Expression or Aggregate Function type are supported in conversion from proto Extended Expressions"); } } From 92d2cc57e46e4b4972e5251477abf6d9baf4300b Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Wed, 6 Dec 2023 17:42:31 -0800 Subject: [PATCH 20/34] refactor: bind instanceof checked variables --- .../ExtendedExpressionProtoConverter.java | 22 +++++++------------ 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java index 6d1f2efcb..e59809a26 100644 --- a/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java +++ b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java @@ -21,27 +21,21 @@ public ExtendedExpression toProto( for (io.substrait.extendedexpression.ExtendedExpression.ExpressionReference expressionReference : extendedExpression.getReferredExpressions()) { - io.substrait.extendedexpression.ExtendedExpression.ExpressionTypeReference expressionType = - expressionReference.getExpressionType(); - if (expressionType - instanceof io.substrait.extendedexpression.ExtendedExpression.ExpressionType) { + io.substrait.extendedexpression.ExtendedExpression.ExpressionTypeReference + expressionTypeReference = expressionReference.getExpressionType(); + if (expressionTypeReference + instanceof io.substrait.extendedexpression.ExtendedExpression.ExpressionType et) { io.substrait.proto.Expression expressionProto = expressionProtoConverter.visit( - (Expression.ScalarFunctionInvocation) - ((io.substrait.extendedexpression.ExtendedExpression.ExpressionType) - expressionType) - .getExpression()); + (Expression.ScalarFunctionInvocation) et.getExpression()); ExpressionReference.Builder expressionReferenceBuilder = ExpressionReference.newBuilder() .setExpression(expressionProto) .addAllOutputNames(expressionReference.getOutputNames()); builder.addReferredExpr(expressionReferenceBuilder); - } else if (expressionType - instanceof io.substrait.extendedexpression.ExtendedExpression.AggregateFunctionType) { - AggregateFunction measure = - ((io.substrait.extendedexpression.ExtendedExpression.AggregateFunctionType) - expressionType) - .getMeasure(); + } else if (expressionTypeReference + instanceof io.substrait.extendedexpression.ExtendedExpression.AggregateFunctionType aft) { + AggregateFunction measure = aft.getMeasure(); ExpressionReference.Builder expressionReferenceBuilder = ExpressionReference.newBuilder() .setMeasure(measure.toBuilder()) From 379b83fe7406fc10b0524f1a7e281e19b913a017 Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Fri, 8 Dec 2023 10:58:37 -0500 Subject: [PATCH 21/34] fix: adding Aggregate.Measure POJO instead of Proto --- .../ExtendedExpression.java | 4 +- .../ExtendedExpressionProtoConverter.java | 12 ++-- .../ProtoExtendedExpressionConverter.java | 9 ++- ...a => AggregateFunctionProtoConverter.java} | 4 +- .../ProtoAggregateFunctionConverter.java | 61 +++++++++++++++++++ .../ExtendedExpressionRoundTripTest.java | 17 ++---- 6 files changed, 86 insertions(+), 21 deletions(-) rename core/src/main/java/io/substrait/relation/{AggregateFunctionProtoController.java => AggregateFunctionProtoConverter.java} (92%) create mode 100644 core/src/main/java/io/substrait/relation/ProtoAggregateFunctionConverter.java diff --git a/core/src/main/java/io/substrait/extendedexpression/ExtendedExpression.java b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpression.java index 26c1e3803..05819bd3d 100644 --- a/core/src/main/java/io/substrait/extendedexpression/ExtendedExpression.java +++ b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpression.java @@ -2,7 +2,7 @@ import io.substrait.expression.Expression; import io.substrait.proto.AdvancedExtension; -import io.substrait.proto.AggregateFunction; +import io.substrait.relation.Aggregate; import io.substrait.type.NamedStruct; import java.util.List; import java.util.Optional; @@ -36,6 +36,6 @@ public abstract static class ExpressionType extends ExpressionTypeReference { @Value.Immutable public abstract static class AggregateFunctionType extends ExpressionTypeReference { - public abstract AggregateFunction getMeasure(); + public abstract Aggregate.Measure getMeasure(); } } diff --git a/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java index e59809a26..434d11c84 100644 --- a/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java +++ b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java @@ -3,12 +3,15 @@ import io.substrait.expression.Expression; import io.substrait.expression.proto.ExpressionProtoConverter; import io.substrait.extension.ExtensionCollector; -import io.substrait.proto.AggregateFunction; import io.substrait.proto.ExpressionReference; import io.substrait.proto.ExtendedExpression; +import io.substrait.relation.AggregateFunctionProtoConverter; import io.substrait.type.proto.TypeProtoConverter; -/** Converts from {@link ExtendedExpression} to {@link ExtendedExpression} */ +/** + * Converts from {@link io.substrait.extendedexpression.ExtendedExpression} to {@link + * io.substrait.proto.ExtendedExpression} + */ public class ExtendedExpressionProtoConverter { public ExtendedExpression toProto( io.substrait.extendedexpression.ExtendedExpression extendedExpression) { @@ -35,10 +38,11 @@ public ExtendedExpression toProto( builder.addReferredExpr(expressionReferenceBuilder); } else if (expressionTypeReference instanceof io.substrait.extendedexpression.ExtendedExpression.AggregateFunctionType aft) { - AggregateFunction measure = aft.getMeasure(); ExpressionReference.Builder expressionReferenceBuilder = ExpressionReference.newBuilder() - .setMeasure(measure.toBuilder()) + .setMeasure( + new AggregateFunctionProtoConverter(functionCollector) + .toProto(aft.getMeasure())) .addAllOutputNames(expressionReference.getOutputNames()); builder.addReferredExpr(expressionReferenceBuilder); } else { diff --git a/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java b/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java index 69020cbee..a23c1e5d0 100644 --- a/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java +++ b/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java @@ -7,9 +7,9 @@ import io.substrait.extension.ImmutableExtensionLookup; import io.substrait.extension.ImmutableSimpleExtension; import io.substrait.extension.SimpleExtension; -import io.substrait.proto.AggregateFunction; import io.substrait.proto.ExpressionReference; import io.substrait.proto.NamedStruct; +import io.substrait.relation.ProtoAggregateFunctionConverter; import io.substrait.type.proto.ProtoTypeConverter; import java.io.IOException; import java.util.ArrayList; @@ -48,6 +48,7 @@ public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExp functionLookup, this.extensionCollection, namedStruct.struct(), null); List expressionReferences = new ArrayList<>(); + for (ExpressionReference expressionReference : extendedExpression.getReferredExprList()) { if (expressionReference.getExprTypeCase().getNumber() == 1) { // Expression Expression expressionPojo = @@ -59,7 +60,10 @@ public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExp .addAllOutputNames(expressionReference.getOutputNamesList()) .build()); } else if (expressionReference.getExprTypeCase().getNumber() == 2) { // AggregateFunction - AggregateFunction measure = expressionReference.getMeasure(); + io.substrait.relation.Aggregate.Measure measure = + new ProtoAggregateFunctionConverter( + functionLookup, extensionCollection, protoExpressionConverter) + .from(expressionReference.getMeasure()); ImmutableExpressionReference.Builder builder = ImmutableExpressionReference.builder() .expressionType(ImmutableAggregateFunctionType.builder().measure(measure).build()) @@ -80,6 +84,7 @@ public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExp ? extendedExpression.getAdvancedExtensions() : null)) .baseSchema(namedStruct); + return builder.build(); } } diff --git a/core/src/main/java/io/substrait/relation/AggregateFunctionProtoController.java b/core/src/main/java/io/substrait/relation/AggregateFunctionProtoConverter.java similarity index 92% rename from core/src/main/java/io/substrait/relation/AggregateFunctionProtoController.java rename to core/src/main/java/io/substrait/relation/AggregateFunctionProtoConverter.java index 7904bfa72..e92752b3f 100644 --- a/core/src/main/java/io/substrait/relation/AggregateFunctionProtoController.java +++ b/core/src/main/java/io/substrait/relation/AggregateFunctionProtoConverter.java @@ -11,13 +11,13 @@ * Converts from {@link io.substrait.relation.Aggregate.Measure} to {@link * io.substrait.proto.AggregateFunction} */ -public class AggregateFunctionProtoController { +public class AggregateFunctionProtoConverter { private final ExpressionProtoConverter exprProtoConverter; private final TypeProtoConverter typeProtoConverter; private final ExtensionCollector functionCollector; - public AggregateFunctionProtoController(ExtensionCollector functionCollector) { + public AggregateFunctionProtoConverter(ExtensionCollector functionCollector) { this.functionCollector = functionCollector; this.exprProtoConverter = new ExpressionProtoConverter(functionCollector, null); this.typeProtoConverter = new TypeProtoConverter(functionCollector); diff --git a/core/src/main/java/io/substrait/relation/ProtoAggregateFunctionConverter.java b/core/src/main/java/io/substrait/relation/ProtoAggregateFunctionConverter.java new file mode 100644 index 000000000..34ba0a1e3 --- /dev/null +++ b/core/src/main/java/io/substrait/relation/ProtoAggregateFunctionConverter.java @@ -0,0 +1,61 @@ +package io.substrait.relation; + +import io.substrait.expression.AggregateFunctionInvocation; +import io.substrait.expression.Expression; +import io.substrait.expression.FunctionArg; +import io.substrait.expression.proto.ProtoExpressionConverter; +import io.substrait.extension.ExtensionLookup; +import io.substrait.extension.SimpleExtension; +import io.substrait.type.proto.ProtoTypeConverter; +import java.io.IOException; +import java.util.List; +import java.util.stream.IntStream; + +/** + * Converts from {@link io.substrait.proto.AggregateFunction} to {@link + * io.substrait.relation.Aggregate.Measure} + */ +public class ProtoAggregateFunctionConverter { + private final ExtensionLookup lookup; + private final SimpleExtension.ExtensionCollection extensions; + private final ProtoTypeConverter protoTypeConverter; + private final ProtoExpressionConverter protoExpressionConverter; + + public ProtoAggregateFunctionConverter( + ExtensionLookup lookup, ProtoExpressionConverter protoExpressionConverter) + throws IOException { + this(lookup, SimpleExtension.loadDefaults(), protoExpressionConverter); + } + + public ProtoAggregateFunctionConverter( + ExtensionLookup lookup, + SimpleExtension.ExtensionCollection extensions, + ProtoExpressionConverter protoExpressionConverter) { + this.lookup = lookup; + this.extensions = extensions; + this.protoTypeConverter = new ProtoTypeConverter(lookup, extensions); + this.protoExpressionConverter = protoExpressionConverter; + } + + public io.substrait.relation.Aggregate.Measure from( + io.substrait.proto.AggregateFunction measure) { + FunctionArg.ProtoFrom protoFrom = + new FunctionArg.ProtoFrom(protoExpressionConverter, protoTypeConverter); + SimpleExtension.AggregateFunctionVariant aggregateFunction = + lookup.getAggregateFunction(measure.getFunctionReference(), extensions); + List functionArgs = + IntStream.range(0, measure.getArgumentsCount()) + .mapToObj(i -> protoFrom.convert(aggregateFunction, i, measure.getArguments(i))) + .collect(java.util.stream.Collectors.toList()); + return Aggregate.Measure.builder() + .function( + AggregateFunctionInvocation.builder() + .arguments(functionArgs) + .declaration(aggregateFunction) + .outputType(protoTypeConverter.from(measure.getOutputType())) + .aggregationPhase(Expression.AggregationPhase.fromProto(measure.getPhase())) + .invocation(Expression.AggregationInvocation.fromProto(measure.getInvocation())) + .build()) + .build(); + } +} diff --git a/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionRoundTripTest.java b/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionRoundTripTest.java index 1da2d6195..be5f523be 100644 --- a/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionRoundTripTest.java +++ b/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionRoundTripTest.java @@ -3,7 +3,6 @@ import io.substrait.TestBase; import io.substrait.expression.*; import io.substrait.relation.Aggregate; -import io.substrait.relation.AggregateFunctionProtoController; import io.substrait.relation.ImmutableMeasure; import io.substrait.type.ImmutableNamedStruct; import io.substrait.type.Type; @@ -58,7 +57,7 @@ public void expressionRoundTrip() throws IOException { .build(); // pojo initial extended expression - ImmutableExtendedExpression extendedExpressionPojoInitial = + io.substrait.extendedexpression.ExtendedExpression extendedExpressionPojoInitial = ImmutableExtendedExpression.builder() .referredExpressions(expressionReferences) .baseSchema(namedStruct) @@ -79,7 +78,7 @@ public void expressionRoundTrip() throws IOException { @Test public void aggregationRoundTrip() throws IOException { // create predefined POJO aggregation function - ImmutableMeasure measure = + io.substrait.relation.Aggregate.Measure measure = Aggregate.Measure.builder() .function( AggregateFunctionInvocation.builder() @@ -92,9 +91,7 @@ public void aggregationRoundTrip() throws IOException { .build(); ImmutableAggregateFunctionType aggregateFunctionType = - ImmutableAggregateFunctionType.builder() - .measure(new AggregateFunctionProtoController(functionCollector).toProto(measure)) - .build(); + ImmutableAggregateFunctionType.builder().measure(measure).build(); ImmutableExpressionReference expressionReference = ImmutableExpressionReference.builder() @@ -121,7 +118,7 @@ public void aggregationRoundTrip() throws IOException { .build(); // pojo initial aggregation function - ImmutableExtendedExpression extendedExpressionPojoInitial = + io.substrait.extendedexpression.ExtendedExpression extendedExpressionPojoInitial = ImmutableExtendedExpression.builder() .referredExpressions(expressionReferences) .baseSchema(namedStruct) @@ -179,9 +176,7 @@ public void expressionAndAggregationRoundTrip() throws IOException { .build(); ImmutableAggregateFunctionType aggregateFunctionType = - ImmutableAggregateFunctionType.builder() - .measure(new AggregateFunctionProtoController(functionCollector).toProto(measure)) - .build(); + ImmutableAggregateFunctionType.builder().measure(measure).build(); ImmutableExpressionReference expressionReferenceAggregation = ImmutableExpressionReference.builder() @@ -209,7 +204,7 @@ public void expressionAndAggregationRoundTrip() throws IOException { .build(); // pojo initial extended expression + aggregation - ImmutableExtendedExpression extendedExpressionPojoInitial = + io.substrait.extendedexpression.ExtendedExpression extendedExpressionPojoInitial = ImmutableExtendedExpression.builder() .referredExpressions(expressionReferences) .baseSchema(namedStruct) From e415785b15c3327029620e18a281e78b58a1a2a3 Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Fri, 8 Dec 2023 13:48:37 -0500 Subject: [PATCH 22/34] fix: simplify extended expression immutable class --- .../ExtendedExpression.java | 15 +++---- .../ExtendedExpressionProtoConverter.java | 14 +++---- .../ProtoExtendedExpressionConverter.java | 21 +++++----- .../ExtendedExpressionRoundTripTest.java | 39 +++++++------------ 4 files changed, 36 insertions(+), 53 deletions(-) diff --git a/core/src/main/java/io/substrait/extendedexpression/ExtendedExpression.java b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpression.java index 05819bd3d..58ab4e001 100644 --- a/core/src/main/java/io/substrait/extendedexpression/ExtendedExpression.java +++ b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpression.java @@ -10,7 +10,7 @@ @Value.Immutable public abstract class ExtendedExpression { - public abstract List getReferredExpressions(); + public abstract List getReferredExpressions(); public abstract NamedStruct getBaseSchema(); @@ -20,22 +20,17 @@ public abstract class ExtendedExpression { public abstract Optional getAdvancedExtension(); - @Value.Immutable - public abstract static class ExpressionReference { - public abstract ExpressionTypeReference getExpressionType(); - - public abstract List getOutputNames(); + public interface ExpressionReferenceBase { + List getOutputNames(); } - public abstract static class ExpressionTypeReference {} - @Value.Immutable - public abstract static class ExpressionType extends ExpressionTypeReference { + public abstract static class ExpressionReference implements ExpressionReferenceBase { public abstract Expression getExpression(); } @Value.Immutable - public abstract static class AggregateFunctionType extends ExpressionTypeReference { + public abstract static class AggregateFunctionReference implements ExpressionReferenceBase { public abstract Aggregate.Measure getMeasure(); } } diff --git a/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java index 434d11c84..083e9328d 100644 --- a/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java +++ b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java @@ -22,12 +22,10 @@ public ExtendedExpression toProto( final ExpressionProtoConverter expressionProtoConverter = new ExpressionProtoConverter(functionCollector, null); - for (io.substrait.extendedexpression.ExtendedExpression.ExpressionReference + for (io.substrait.extendedexpression.ExtendedExpression.ExpressionReferenceBase expressionReference : extendedExpression.getReferredExpressions()) { - io.substrait.extendedexpression.ExtendedExpression.ExpressionTypeReference - expressionTypeReference = expressionReference.getExpressionType(); - if (expressionTypeReference - instanceof io.substrait.extendedexpression.ExtendedExpression.ExpressionType et) { + if (expressionReference + instanceof io.substrait.extendedexpression.ExtendedExpression.ExpressionReference et) { io.substrait.proto.Expression expressionProto = expressionProtoConverter.visit( (Expression.ScalarFunctionInvocation) et.getExpression()); @@ -36,8 +34,10 @@ public ExtendedExpression toProto( .setExpression(expressionProto) .addAllOutputNames(expressionReference.getOutputNames()); builder.addReferredExpr(expressionReferenceBuilder); - } else if (expressionTypeReference - instanceof io.substrait.extendedexpression.ExtendedExpression.AggregateFunctionType aft) { + } else if (expressionReference + instanceof + io.substrait.extendedexpression.ExtendedExpression.AggregateFunctionReference + aft) { ExpressionReference.Builder expressionReferenceBuilder = ExpressionReference.newBuilder() .setMeasure( diff --git a/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java b/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java index a23c1e5d0..67b823c86 100644 --- a/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java +++ b/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java @@ -47,28 +47,29 @@ public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExp new ProtoExpressionConverter( functionLookup, this.extensionCollection, namedStruct.struct(), null); - List expressionReferences = new ArrayList<>(); + List expressionReferences = new ArrayList<>(); for (ExpressionReference expressionReference : extendedExpression.getReferredExprList()) { if (expressionReference.getExprTypeCase().getNumber() == 1) { // Expression Expression expressionPojo = protoExpressionConverter.from(expressionReference.getExpression()); - expressionReferences.add( + ImmutableExpressionReference build = ImmutableExpressionReference.builder() - .expressionType( - ImmutableExpressionType.builder().expression(expressionPojo).build()) + .expression(expressionPojo) .addAllOutputNames(expressionReference.getOutputNamesList()) - .build()); + .build(); + expressionReferences.add(build); } else if (expressionReference.getExprTypeCase().getNumber() == 2) { // AggregateFunction io.substrait.relation.Aggregate.Measure measure = new ProtoAggregateFunctionConverter( functionLookup, extensionCollection, protoExpressionConverter) .from(expressionReference.getMeasure()); - ImmutableExpressionReference.Builder builder = - ImmutableExpressionReference.builder() - .expressionType(ImmutableAggregateFunctionType.builder().measure(measure).build()) - .addAllOutputNames(expressionReference.getOutputNamesList()); - expressionReferences.add(builder.build()); + ImmutableAggregateFunctionReference build = + ImmutableAggregateFunctionReference.builder() + .measure(measure) + .addAllOutputNames(expressionReference.getOutputNamesList()) + .build(); + expressionReferences.add(build); } else { throw new UnsupportedOperationException( "Only Expression or Aggregate Function type are supported in conversion from proto Extended Expressions"); diff --git a/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionRoundTripTest.java b/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionRoundTripTest.java index be5f523be..f083ad404 100644 --- a/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionRoundTripTest.java +++ b/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionRoundTripTest.java @@ -33,8 +33,7 @@ public void expressionRoundTrip() throws IOException { ImmutableExpressionReference expressionReference = ImmutableExpressionReference.builder() - .expressionType( - ImmutableExpressionType.builder().expression(scalarFunctionInvocation).build()) + .expression(scalarFunctionInvocation) .addOutputNames("new-column") .build(); @@ -90,18 +89,11 @@ public void aggregationRoundTrip() throws IOException { .build()) .build(); - ImmutableAggregateFunctionType aggregateFunctionType = - ImmutableAggregateFunctionType.builder().measure(measure).build(); + ImmutableAggregateFunctionReference aggregateFunctionReference = + ImmutableAggregateFunctionReference.builder().measure(measure).build(); - ImmutableExpressionReference expressionReference = - ImmutableExpressionReference.builder() - .expressionType(aggregateFunctionType) - .addOutputNames("new-column") - .build(); - - List - expressionReferences = new ArrayList<>(); - expressionReferences.add(expressionReference); + List expressionReferences = new ArrayList<>(); + expressionReferences.add(aggregateFunctionReference); ImmutableNamedStruct namedStruct = ImmutableNamedStruct.builder() @@ -151,15 +143,13 @@ public void expressionAndAggregationRoundTrip() throws IOException { .build(), ExpressionCreator.i32(false, 183)); - ImmutableExpressionReference expressionReferenceExpression = + ImmutableExpressionReference expressionReference = ImmutableExpressionReference.builder() - .expressionType( - ImmutableExpressionType.builder().expression(scalarFunctionInvocation).build()) + .expression(scalarFunctionInvocation) .addOutputNames("new-column") .build(); - List - expressionReferences = new ArrayList<>(); + List expressionReferences = new ArrayList<>(); // POJO 02 // create predefined POJO aggregation function @@ -175,19 +165,16 @@ public void expressionAndAggregationRoundTrip() throws IOException { .build()) .build(); - ImmutableAggregateFunctionType aggregateFunctionType = - ImmutableAggregateFunctionType.builder().measure(measure).build(); - - ImmutableExpressionReference expressionReferenceAggregation = - ImmutableExpressionReference.builder() - .expressionType(aggregateFunctionType) + ImmutableAggregateFunctionReference aggregateFunctionReference = + ImmutableAggregateFunctionReference.builder() + .measure(measure) .addOutputNames("new-column") .build(); // adding expression - expressionReferences.add(expressionReferenceExpression); + expressionReferences.add(expressionReference); // adding aggregation function - expressionReferences.add(expressionReferenceAggregation); + expressionReferences.add(aggregateFunctionReference); ImmutableNamedStruct namedStruct = ImmutableNamedStruct.builder() From c27dd372f93b6462c98e141fa9f9c176dfbbf906 Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Fri, 8 Dec 2023 15:56:23 -0500 Subject: [PATCH 23/34] fix: clean code --- .../ExtendedExpressionRoundTripTest.java | 194 ++++++------------ 1 file changed, 64 insertions(+), 130 deletions(-) diff --git a/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionRoundTripTest.java b/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionRoundTripTest.java index f083ad404..14abda47a 100644 --- a/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionRoundTripTest.java +++ b/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionRoundTripTest.java @@ -3,7 +3,6 @@ import io.substrait.TestBase; import io.substrait.expression.*; import io.substrait.relation.Aggregate; -import io.substrait.relation.ImmutableMeasure; import io.substrait.type.ImmutableNamedStruct; import io.substrait.type.Type; import io.substrait.type.TypeCreator; @@ -20,118 +19,56 @@ public class ExtendedExpressionRoundTripTest extends TestBase { @Test public void expressionRoundTrip() throws IOException { // create predefined POJO extended expression - Expression.ScalarFunctionInvocation scalarFunctionInvocation = - b.scalarFn( - NAMESPACE, - "add:dec_dec", - TypeCreator.REQUIRED.BOOLEAN, - ImmutableFieldReference.builder() - .addSegments(FieldReference.StructField.of(0)) - .type(TypeCreator.REQUIRED.decimal(10, 2)) - .build(), - ExpressionCreator.i32(false, 183)); + ImmutableExpressionReference expressionReference = getImmutableExpressionReference(); - ImmutableExpressionReference expressionReference = - ImmutableExpressionReference.builder() - .expression(scalarFunctionInvocation) - .addOutputNames("new-column") - .build(); - - List - expressionReferences = new ArrayList<>(); + List expressionReferences = new ArrayList<>(); + // adding expression expressionReferences.add(expressionReference); - ImmutableNamedStruct namedStruct = - ImmutableNamedStruct.builder() - .addNames("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT") - .struct( - Type.Struct.builder() - .nullable(false) - .addFields( - TypeCreator.REQUIRED.decimal(10, 2), - TypeCreator.REQUIRED.STRING, - TypeCreator.REQUIRED.decimal(10, 2), - TypeCreator.REQUIRED.STRING) - .build()) - .build(); - - // pojo initial extended expression - io.substrait.extendedexpression.ExtendedExpression extendedExpressionPojoInitial = - ImmutableExtendedExpression.builder() - .referredExpressions(expressionReferences) - .baseSchema(namedStruct) - .build(); - - // proto extended expression - io.substrait.proto.ExtendedExpression extendedExpressionProto = - new ExtendedExpressionProtoConverter().toProto(extendedExpressionPojoInitial); - - // pojo final extended expression - io.substrait.extendedexpression.ExtendedExpression extendedExpressionPojoFinal = - new ProtoExtendedExpressionConverter().from(extendedExpressionProto); + ImmutableNamedStruct namedStruct = getImmutableNamedStruct(); - // validate extended expression pojo initial equals to final roundtrip - Assertions.assertEquals(extendedExpressionPojoInitial, extendedExpressionPojoFinal); + assertExtendedExpressionOperation(expressionReferences, namedStruct); } @Test public void aggregationRoundTrip() throws IOException { // create predefined POJO aggregation function - io.substrait.relation.Aggregate.Measure measure = - Aggregate.Measure.builder() - .function( - AggregateFunctionInvocation.builder() - .arguments(Collections.emptyList()) - .declaration(defaultExtensionCollection.aggregateFunctions().get(0)) - .outputType(TypeCreator.of(false).I64) - .aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT) - .invocation(Expression.AggregationInvocation.ALL) - .build()) - .build(); - ImmutableAggregateFunctionReference aggregateFunctionReference = - ImmutableAggregateFunctionReference.builder().measure(measure).build(); + getImmutableAggregateFunctionReference(); List expressionReferences = new ArrayList<>(); + // adding aggregation function expressionReferences.add(aggregateFunctionReference); - ImmutableNamedStruct namedStruct = - ImmutableNamedStruct.builder() - .addNames("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT") - .struct( - Type.Struct.builder() - .nullable(false) - .addFields( - TypeCreator.REQUIRED.decimal(10, 2), - TypeCreator.REQUIRED.STRING, - TypeCreator.REQUIRED.decimal(10, 2), - TypeCreator.REQUIRED.STRING) - .build()) - .build(); + ImmutableNamedStruct namedStruct = getImmutableNamedStruct(); - // pojo initial aggregation function - io.substrait.extendedexpression.ExtendedExpression extendedExpressionPojoInitial = - ImmutableExtendedExpression.builder() - .referredExpressions(expressionReferences) - .baseSchema(namedStruct) - .build(); - - // proto aggregation function - io.substrait.proto.ExtendedExpression extendedExpressionProto = - new ExtendedExpressionProtoConverter().toProto(extendedExpressionPojoInitial); - - // pojo final aggregation function - io.substrait.extendedexpression.ExtendedExpression extendedExpressionPojoFinal = - new ProtoExtendedExpressionConverter().from(extendedExpressionProto); - - // validate aggregation function pojo initial equals to final roundtrip - Assertions.assertEquals(extendedExpressionPojoInitial, extendedExpressionPojoFinal); + assertExtendedExpressionOperation(expressionReferences, namedStruct); } @Test public void expressionAndAggregationRoundTrip() throws IOException { // POJO 01 // create predefined POJO extended expression + ImmutableExpressionReference expressionReference = getImmutableExpressionReference(); + + List expressionReferences = new ArrayList<>(); + + // POJO 02 + // create predefined POJO aggregation function + ImmutableAggregateFunctionReference aggregateFunctionReference = + getImmutableAggregateFunctionReference(); + + // adding expression + expressionReferences.add(expressionReference); + // adding aggregation function + expressionReferences.add(aggregateFunctionReference); + + ImmutableNamedStruct namedStruct = getImmutableNamedStruct(); + + assertExtendedExpressionOperation(expressionReferences, namedStruct); + } + + private ImmutableExpressionReference getImmutableExpressionReference() { Expression.ScalarFunctionInvocation scalarFunctionInvocation = b.scalarFn( NAMESPACE, @@ -143,17 +80,14 @@ public void expressionAndAggregationRoundTrip() throws IOException { .build(), ExpressionCreator.i32(false, 183)); - ImmutableExpressionReference expressionReference = - ImmutableExpressionReference.builder() - .expression(scalarFunctionInvocation) - .addOutputNames("new-column") - .build(); - - List expressionReferences = new ArrayList<>(); + return ImmutableExpressionReference.builder() + .expression(scalarFunctionInvocation) + .addOutputNames("new-column") + .build(); + } - // POJO 02 - // create predefined POJO aggregation function - ImmutableMeasure measure = + private static ImmutableAggregateFunctionReference getImmutableAggregateFunctionReference() { + Aggregate.Measure measure = Aggregate.Measure.builder() .function( AggregateFunctionInvocation.builder() @@ -165,47 +99,47 @@ public void expressionAndAggregationRoundTrip() throws IOException { .build()) .build(); - ImmutableAggregateFunctionReference aggregateFunctionReference = - ImmutableAggregateFunctionReference.builder() - .measure(measure) - .addOutputNames("new-column") - .build(); + return ImmutableAggregateFunctionReference.builder() + .measure(measure) + .addOutputNames("new-column") + .build(); + } - // adding expression - expressionReferences.add(expressionReference); - // adding aggregation function - expressionReferences.add(aggregateFunctionReference); + private static ImmutableNamedStruct getImmutableNamedStruct() { + return ImmutableNamedStruct.builder() + .addNames("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT") + .struct( + Type.Struct.builder() + .nullable(false) + .addFields( + TypeCreator.REQUIRED.decimal(10, 2), + TypeCreator.REQUIRED.STRING, + TypeCreator.REQUIRED.decimal(10, 2), + TypeCreator.REQUIRED.STRING) + .build()) + .build(); + } - ImmutableNamedStruct namedStruct = - ImmutableNamedStruct.builder() - .addNames("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT") - .struct( - Type.Struct.builder() - .nullable(false) - .addFields( - TypeCreator.REQUIRED.decimal(10, 2), - TypeCreator.REQUIRED.STRING, - TypeCreator.REQUIRED.decimal(10, 2), - TypeCreator.REQUIRED.STRING) - .build()) - .build(); + private static void assertExtendedExpressionOperation( + List expressionReferences, + ImmutableNamedStruct namedStruct) + throws IOException { - // pojo initial extended expression + aggregation - io.substrait.extendedexpression.ExtendedExpression extendedExpressionPojoInitial = + // initial pojo + ExtendedExpression extendedExpressionPojoInitial = ImmutableExtendedExpression.builder() .referredExpressions(expressionReferences) .baseSchema(namedStruct) .build(); - // proto extended expression + aggregation + // proto io.substrait.proto.ExtendedExpression extendedExpressionProto = new ExtendedExpressionProtoConverter().toProto(extendedExpressionPojoInitial); - // pojo final extended expression + aggregation - io.substrait.extendedexpression.ExtendedExpression extendedExpressionPojoFinal = + // get pojo from proto + ExtendedExpression extendedExpressionPojoFinal = new ProtoExtendedExpressionConverter().from(extendedExpressionProto); - // validate extended expression + aggregation pojo initial equals to final roundtrip Assertions.assertEquals(extendedExpressionPojoInitial, extendedExpressionPojoFinal); } } From a5d81268de3f87dc37e71e3080635c9c2a503eaa Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Sat, 9 Dec 2023 03:49:28 -0500 Subject: [PATCH 24/34] fix: support any kind of expression type on extended expression converter --- .../ExtendedExpressionProtoConverter.java | 5 +- .../ExtendedExpressionRoundTripTest.java | 72 ++++++++----------- 2 files changed, 31 insertions(+), 46 deletions(-) diff --git a/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java index 083e9328d..34deae39c 100644 --- a/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java +++ b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java @@ -1,6 +1,5 @@ package io.substrait.extendedexpression; -import io.substrait.expression.Expression; import io.substrait.expression.proto.ExpressionProtoConverter; import io.substrait.extension.ExtensionCollector; import io.substrait.proto.ExpressionReference; @@ -13,6 +12,7 @@ * io.substrait.proto.ExtendedExpression} */ public class ExtendedExpressionProtoConverter { + public ExtendedExpression toProto( io.substrait.extendedexpression.ExtendedExpression extendedExpression) { @@ -27,8 +27,7 @@ public ExtendedExpression toProto( if (expressionReference instanceof io.substrait.extendedexpression.ExtendedExpression.ExpressionReference et) { io.substrait.proto.Expression expressionProto = - expressionProtoConverter.visit( - (Expression.ScalarFunctionInvocation) et.getExpression()); + et.getExpression().accept(expressionProtoConverter); ExpressionReference.Builder expressionReferenceBuilder = ExpressionReference.newBuilder() .setExpression(expressionProto) diff --git a/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionRoundTripTest.java b/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionRoundTripTest.java index 14abda47a..1341631c0 100644 --- a/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionRoundTripTest.java +++ b/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionRoundTripTest.java @@ -10,65 +10,51 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.stream.Stream; import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; public class ExtendedExpressionRoundTripTest extends TestBase { static final String NAMESPACE = "/functions_arithmetic_decimal.yaml"; - @Test - public void expressionRoundTrip() throws IOException { - // create predefined POJO extended expression - ImmutableExpressionReference expressionReference = getImmutableExpressionReference(); + private static Stream expressionReferenceProvider() { + return Stream.of( + Arguments.of(getI32LiteralExpression()), + Arguments.of(getFieldReferenceExpression()), + Arguments.of(getScalarFunctionExpression()), + Arguments.of(getImmutableAggregateFunctionReference())); + } + @ParameterizedTest + @MethodSource("expressionReferenceProvider") + public void testRoundTrip(ImmutableExpressionReference expressionReference) throws IOException { List expressionReferences = new ArrayList<>(); - // adding expression expressionReferences.add(expressionReference); - ImmutableNamedStruct namedStruct = getImmutableNamedStruct(); - assertExtendedExpressionOperation(expressionReferences, namedStruct); } - @Test - public void aggregationRoundTrip() throws IOException { - // create predefined POJO aggregation function - ImmutableAggregateFunctionReference aggregateFunctionReference = - getImmutableAggregateFunctionReference(); - - List expressionReferences = new ArrayList<>(); - // adding aggregation function - expressionReferences.add(aggregateFunctionReference); - - ImmutableNamedStruct namedStruct = getImmutableNamedStruct(); - - assertExtendedExpressionOperation(expressionReferences, namedStruct); + private static ImmutableExpressionReference getI32LiteralExpression() { + return ImmutableExpressionReference.builder() + .expression(ExpressionCreator.i32(false, 76)) + .addOutputNames("new-column") + .build(); } - @Test - public void expressionAndAggregationRoundTrip() throws IOException { - // POJO 01 - // create predefined POJO extended expression - ImmutableExpressionReference expressionReference = getImmutableExpressionReference(); - - List expressionReferences = new ArrayList<>(); - - // POJO 02 - // create predefined POJO aggregation function - ImmutableAggregateFunctionReference aggregateFunctionReference = - getImmutableAggregateFunctionReference(); - - // adding expression - expressionReferences.add(expressionReference); - // adding aggregation function - expressionReferences.add(aggregateFunctionReference); - - ImmutableNamedStruct namedStruct = getImmutableNamedStruct(); - - assertExtendedExpressionOperation(expressionReferences, namedStruct); + private static ImmutableExpressionReference getFieldReferenceExpression() { + return ImmutableExpressionReference.builder() + .expression( + ImmutableFieldReference.builder() + .addSegments(FieldReference.StructField.of(0)) + .type(TypeCreator.REQUIRED.decimal(10, 2)) + .build()) + .addOutputNames("new-column") + .build(); } - private ImmutableExpressionReference getImmutableExpressionReference() { + private static ImmutableExpressionReference getScalarFunctionExpression() { Expression.ScalarFunctionInvocation scalarFunctionInvocation = b.scalarFn( NAMESPACE, From 50602f2c2a5abd44c06b217c33f9676022a2e206 Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Sat, 9 Dec 2023 04:12:51 -0500 Subject: [PATCH 25/34] fix: error scalar function test case --- .../ExtendedExpressionRoundTripTest.java | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionRoundTripTest.java b/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionRoundTripTest.java index 1341631c0..1c3693087 100644 --- a/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionRoundTripTest.java +++ b/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionRoundTripTest.java @@ -1,6 +1,7 @@ package io.substrait.extendedexpression; import io.substrait.TestBase; +import io.substrait.dsl.SubstraitBuilder; import io.substrait.expression.*; import io.substrait.relation.Aggregate; import io.substrait.type.ImmutableNamedStruct; @@ -29,7 +30,8 @@ private static Stream expressionReferenceProvider() { @ParameterizedTest @MethodSource("expressionReferenceProvider") - public void testRoundTrip(ImmutableExpressionReference expressionReference) throws IOException { + public void testRoundTrip(ExtendedExpression.ExpressionReferenceBase expressionReference) + throws IOException { List expressionReferences = new ArrayList<>(); expressionReferences.add(expressionReference); ImmutableNamedStruct namedStruct = getImmutableNamedStruct(); @@ -56,15 +58,16 @@ private static ImmutableExpressionReference getFieldReferenceExpression() { private static ImmutableExpressionReference getScalarFunctionExpression() { Expression.ScalarFunctionInvocation scalarFunctionInvocation = - b.scalarFn( - NAMESPACE, - "add:dec_dec", - TypeCreator.REQUIRED.BOOLEAN, - ImmutableFieldReference.builder() - .addSegments(FieldReference.StructField.of(0)) - .type(TypeCreator.REQUIRED.decimal(10, 2)) - .build(), - ExpressionCreator.i32(false, 183)); + new SubstraitBuilder(defaultExtensionCollection) + .scalarFn( + NAMESPACE, + "add:dec_dec", + TypeCreator.REQUIRED.BOOLEAN, + ImmutableFieldReference.builder() + .addSegments(FieldReference.StructField.of(0)) + .type(TypeCreator.REQUIRED.decimal(10, 2)) + .build(), + ExpressionCreator.i32(false, 183)); return ImmutableExpressionReference.builder() .expression(scalarFunctionInvocation) From 1c8b8b5c3d14b64c4145993db4dbd3fc482d5a2b Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Tue, 12 Dec 2023 17:34:02 -0500 Subject: [PATCH 26/34] fix: support any kind of expression type on extended expression converter --- .../isthmus/SqlExpressionToSubstrait.java | 22 ++++----- .../SimpleExtendedExpressionsTest.java | 45 ++++++++----------- 2 files changed, 27 insertions(+), 40 deletions(-) diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java index 8e31fe30a..6f61a10d7 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java @@ -3,7 +3,6 @@ import com.google.common.annotations.VisibleForTesting; import io.substrait.extendedexpression.ExtendedExpressionProtoConverter; import io.substrait.extendedexpression.ImmutableExpressionReference; -import io.substrait.extendedexpression.ImmutableExpressionType; import io.substrait.extendedexpression.ImmutableExtendedExpression; import io.substrait.isthmus.expression.RexExpressionConverter; import io.substrait.isthmus.expression.ScalarFunctionConverter; @@ -67,14 +66,11 @@ private ExtendedExpression executeInnerSQLExpression( throws SqlParseException { RexNode rexNode = sqlToRexNode(sqlExpression, validator, catalogReader, nameToTypeMap, nameToNodeMap); - io.substrait.expression.Expression.ScalarFunctionInvocation func = - (io.substrait.expression.Expression.ScalarFunctionInvocation) - rexNode.accept(rexExpressionConverter); NamedStruct namedStruct = TypeConverter.DEFAULT.toNamedStruct(nameToTypeMap); ImmutableExpressionReference expressionReference = ImmutableExpressionReference.builder() - .expressionType(ImmutableExpressionType.builder().expression(func).build()) + .expression(rexNode.accept(rexExpressionConverter)) .addOutputNames("new-column") .build(); @@ -107,14 +103,12 @@ private RexNode sqlToRexNode( @VisibleForTesting SqlToRelConverter createSqlToRelConverter( SqlValidator validator, CalciteCatalogReader catalogReader) { - SqlToRelConverter converter = - new SqlToRelConverter( - null, - validator, - catalogReader, - relOptCluster, - StandardConvertletTable.INSTANCE, - converterConfig); - return converter; + return new SqlToRelConverter( + null, + validator, + catalogReader, + relOptCluster, + StandardConvertletTable.INSTANCE, + converterConfig); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java index d08f588d6..af74b7bf3 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java @@ -2,38 +2,31 @@ import io.substrait.proto.ExtendedExpression; import java.io.IOException; +import java.util.stream.Stream; import org.apache.calcite.sql.parser.SqlParseException; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; public class SimpleExtendedExpressionsTest extends ExtendedExpressionTestBase { - @Test - public void filter() throws IOException, SqlParseException { - ExtendedExpression extendedExpression = - assertProtoExtendedExpressionRoundrip("L_ORDERKEY > 10"); + private static Stream expressionTypeProvider() { + return Stream.of( + Arguments.of("2"), // I32LiteralExpression + Arguments.of("L_ORDERKEY"), // FieldReferenceExpression + Arguments.of("L_ORDERKEY > 10"), // ScalarFunctionExpressionFilter + Arguments.of("L_ORDERKEY + 10"), // ScalarFunctionExpressionProjection + Arguments.of("L_ORDERKEY IN (10, 20)"), // ScalarFunctionExpressionIn + Arguments.of("L_ORDERKEY is not null"), // ScalarFunctionExpressionIsNotNull + Arguments.of("L_ORDERKEY is null") // ScalarFunctionExpressionIsNull + ); } - @Test - public void projection() throws IOException, SqlParseException { - ExtendedExpression extendedExpression = - assertProtoExtendedExpressionRoundrip("L_ORDERKEY + 10"); - } - - @Test - public void in() throws IOException, SqlParseException { - ExtendedExpression extendedExpression = - assertProtoExtendedExpressionRoundrip("L_ORDERKEY IN (10, 20)"); - } - - @Test - public void isNotNull() throws IOException, SqlParseException { - ExtendedExpression extendedExpression = - assertProtoExtendedExpressionRoundrip("L_ORDERKEY is not null"); - } + @ParameterizedTest + @MethodSource("expressionTypeProvider") + public void testExtendedExpressionsRoundTrip(String sqlExpression) + throws SqlParseException, IOException { - @Test - public void isNull() throws IOException, SqlParseException { - ExtendedExpression extendedExpression = - assertProtoExtendedExpressionRoundrip("L_ORDERKEY is null"); + ExtendedExpression extendedExpression = assertProtoExtendedExpressionRoundrip(sqlExpression); } } From 183dcb6b3b737c0181ffa47787bdca0865e1d864 Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Tue, 19 Dec 2023 11:24:13 -0500 Subject: [PATCH 27/34] fix: addressing PR comments --- .../substrait/isthmus/SqlConverterBase.java | 40 --------- .../isthmus/SqlExpressionToSubstrait.java | 82 +++++++++++++++---- .../io/substrait/isthmus/TypeConverter.java | 12 --- .../isthmus/ExtendedExpressionTestBase.java | 2 +- .../ExtendedExpressionIntegrationTest.java | 3 +- 5 files changed, 69 insertions(+), 70 deletions(-) diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java index 6501316b1..ec83bbc82 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java @@ -1,15 +1,11 @@ package io.substrait.isthmus; -import com.github.bsideup.jabel.Desugar; import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.calcite.SubstraitOperatorTable; import io.substrait.type.NamedStruct; import java.io.IOException; import java.util.ArrayList; -import java.util.HashMap; -import java.util.LinkedHashMap; import java.util.List; -import java.util.Map; import java.util.function.Function; import org.apache.calcite.config.CalciteConnectionConfig; import org.apache.calcite.config.CalciteConnectionProperty; @@ -26,10 +22,7 @@ import org.apache.calcite.rel.metadata.RelMetadataQuery; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; -import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.rex.RexBuilder; -import org.apache.calcite.rex.RexInputRef; -import org.apache.calcite.rex.RexNode; import org.apache.calcite.schema.Schema; import org.apache.calcite.schema.Table; import org.apache.calcite.schema.impl.AbstractTable; @@ -110,39 +103,6 @@ Pair registerCreateTables(List table return Pair.of(validator, catalogReader); } - Result registerCreateTablesForExtendedExpression(List tables) throws SqlParseException { - Map nameToTypeMap = new LinkedHashMap<>(); - Map nameToNodeMap = new HashMap<>(); - CalciteSchema rootSchema = CalciteSchema.createRootSchema(false); - CalciteCatalogReader catalogReader = - new CalciteCatalogReader(rootSchema, List.of(), factory, config); - SqlValidator validator = Validator.create(factory, catalogReader, SqlValidator.Config.DEFAULT); - if (tables != null) { - for (String tableDef : tables) { - List tList = parseCreateTable(factory, validator, tableDef); - for (DefinedTable t : tList) { - rootSchema.add(t.getName(), t); - for (RelDataTypeField field : t.type.getFieldList()) { - nameToTypeMap.put( - field.getName(), field.getType()); // to validate the sql expression tree - nameToNodeMap.put( - field.getName(), - new RexInputRef( - field.getIndex(), field.getType())); // to convert sql expression into RexNode - } - } - } - } - return new Result(validator, catalogReader, nameToTypeMap, nameToNodeMap); - } - - @Desugar - public record Result( - SqlValidator validator, - CalciteCatalogReader catalogReader, - Map nameToTypeMap, - Map nameToNodeMap) {} - Pair registerCreateTables( Function, NamedStruct> tableLookup) throws SqlParseException { Function, Table> lookup = diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java index 6f61a10d7..9ee979099 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java @@ -1,18 +1,25 @@ package io.substrait.isthmus; -import com.google.common.annotations.VisibleForTesting; +import com.github.bsideup.jabel.Desugar; import io.substrait.extendedexpression.ExtendedExpressionProtoConverter; import io.substrait.extendedexpression.ImmutableExpressionReference; import io.substrait.extendedexpression.ImmutableExtendedExpression; +import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.expression.RexExpressionConverter; import io.substrait.isthmus.expression.ScalarFunctionConverter; import io.substrait.proto.ExtendedExpression; import io.substrait.type.NamedStruct; +import io.substrait.type.Type; import java.util.ArrayList; +import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import org.apache.calcite.jdbc.CalciteSchema; import org.apache.calcite.prepare.CalciteCatalogReader; import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.parser.SqlParseException; @@ -23,32 +30,39 @@ public class SqlExpressionToSubstrait extends SqlConverterBase { + protected final RexExpressionConverter rexConverter; + public SqlExpressionToSubstrait() { - this(null); + this(FEATURES_DEFAULT, EXTENSION_COLLECTION); } - protected SqlExpressionToSubstrait(FeatureBoard features) { + public SqlExpressionToSubstrait( + FeatureBoard features, SimpleExtension.ExtensionCollection extensions) { super(features); + ScalarFunctionConverter scalarFunctionConverter = + new ScalarFunctionConverter(extensions.scalarFunctions(), factory); + this.rexConverter = new RexExpressionConverter(scalarFunctionConverter); } - private final ScalarFunctionConverter functionConverter = - new ScalarFunctionConverter(EXTENSION_COLLECTION.scalarFunctions(), factory); - - private final RexExpressionConverter rexExpressionConverter = - new RexExpressionConverter(functionConverter); + @Desugar + private record Result( + SqlValidator validator, + CalciteCatalogReader catalogReader, + Map nameToTypeMap, + Map nameToNodeMap) {} /** * Process to execute an SQL Expression to convert into an Extended expression protobuf message * * @param sqlExpression expression defined by the user - * @param tables of names of table needed to consider to load into memory for catalog, schema, - * validate and parse sql + * @param createStatements of names of table needed to consider to load into memory for catalog, + * schema, validate and parse sql * @return extended expression protobuf message * @throws SqlParseException */ - public ExtendedExpression executeSQLExpression(String sqlExpression, List tables) + public ExtendedExpression convert(String sqlExpression, List createStatements) throws SqlParseException { - var result = registerCreateTablesForExtendedExpression(tables); + var result = registerCreateTablesForExtendedExpression(createStatements); return executeInnerSQLExpression( sqlExpression, result.validator(), @@ -66,11 +80,11 @@ private ExtendedExpression executeInnerSQLExpression( throws SqlParseException { RexNode rexNode = sqlToRexNode(sqlExpression, validator, catalogReader, nameToTypeMap, nameToNodeMap); - NamedStruct namedStruct = TypeConverter.DEFAULT.toNamedStruct(nameToTypeMap); + NamedStruct namedStruct = toNamedStruct(nameToTypeMap); ImmutableExpressionReference expressionReference = ImmutableExpressionReference.builder() - .expression(rexNode.accept(rexExpressionConverter)) + .expression(rexNode.accept(this.rexConverter)) .addOutputNames("new-column") .build(); @@ -100,7 +114,6 @@ private RexNode sqlToRexNode( return converter.convertExpression(validSQLNode, nameToNodeMap); } - @VisibleForTesting SqlToRelConverter createSqlToRelConverter( SqlValidator validator, CalciteCatalogReader catalogReader) { return new SqlToRelConverter( @@ -111,4 +124,43 @@ SqlToRelConverter createSqlToRelConverter( StandardConvertletTable.INSTANCE, converterConfig); } + + private Result registerCreateTablesForExtendedExpression(List tables) + throws SqlParseException { + Map nameToTypeMap = new LinkedHashMap<>(); + Map nameToNodeMap = new HashMap<>(); + CalciteSchema rootSchema = CalciteSchema.createRootSchema(false); + CalciteCatalogReader catalogReader = + new CalciteCatalogReader(rootSchema, List.of(), factory, config); + SqlValidator validator = Validator.create(factory, catalogReader, SqlValidator.Config.DEFAULT); + if (tables != null) { + for (String tableDef : tables) { + List tList = parseCreateTable(factory, validator, tableDef); + for (DefinedTable t : tList) { + rootSchema.add(t.getName(), t); + for (RelDataTypeField field : t.getRowType(factory).getFieldList()) { + nameToTypeMap.put( + field.getName(), field.getType()); // to validate the sql expression tree + nameToNodeMap.put( + field.getName(), + new RexInputRef( + field.getIndex(), field.getType())); // to convert sql expression into RexNode + } + } + } + } + return new Result(validator, catalogReader, nameToTypeMap, nameToNodeMap); + } + + private NamedStruct toNamedStruct(Map nameToTypeMap) { + var names = new ArrayList(); + var types = new ArrayList(); + for (Map.Entry entry : nameToTypeMap.entrySet()) { + String k = entry.getKey(); + RelDataType v = entry.getValue(); + names.add(k); + types.add(TypeConverter.DEFAULT.toSubstrait(v)); + } + return NamedStruct.of(names, Type.Struct.builder().fields(types).nullable(false).build()); + } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java index 6b69f3fff..ba68d5cfc 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java @@ -11,7 +11,6 @@ import io.substrait.type.TypeVisitor; import java.util.ArrayList; import java.util.List; -import java.util.Map; import javax.annotation.Nullable; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; @@ -57,17 +56,6 @@ public NamedStruct toNamedStruct(RelDataType type) { return NamedStruct.of(names, struct); } - public NamedStruct toNamedStruct(Map nameToTypeMap) { - var names = new ArrayList(); - var types = new ArrayList(); - nameToTypeMap.forEach( - (k, v) -> { - names.add(k); - types.add(toSubstrait(v, names)); - }); - return NamedStruct.of(names, Type.Struct.builder().fields(types).nullable(false).build()); - } - private Type toSubstrait(RelDataType type, List names) { // Check for user mapped types first as they may re-use SqlTypeNames var userType = userTypeMapper.toSubstrait(type); diff --git a/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java index 4a9dde8b7..9c2d0ce5e 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java @@ -37,7 +37,7 @@ protected ExtendedExpression assertProtoExtendedExpressionRoundrip( String query, SqlExpressionToSubstrait s, List creates) throws SqlParseException, IOException { // proto initial extended expression - ExtendedExpression extendedExpressionProtoInitial = s.executeSQLExpression(query, creates); + ExtendedExpression extendedExpressionProtoInitial = s.convert(query, creates); // pojo final extended expression io.substrait.extendedexpression.ExtendedExpression extendedExpressionPojoFinal = diff --git a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java index 0c25c603b..97f504361 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java @@ -99,8 +99,7 @@ private static ByteBuffer getExtendedExpression(String sqlExpression) throws IOException, SqlParseException { ExtendedExpression extendedExpression = new SqlExpressionToSubstrait() - .executeSQLExpression( - sqlExpression, ExtendedExpressionTestBase.tpchSchemaCreateStatements()); + .convert(sqlExpression, ExtendedExpressionTestBase.tpchSchemaCreateStatements()); byte[] extendedExpressions = Base64.getDecoder() .decode(Base64.getEncoder().encodeToString(extendedExpression.toByteArray())); From 3835a4f35d41ac31dc9dfcd9d0fb4876602a002e Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Wed, 20 Dec 2023 15:37:14 -0800 Subject: [PATCH 28/34] docs: update SqlExpressionToSubstrait#convert docs --- .../io/substrait/isthmus/SqlExpressionToSubstrait.java | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java index 9ee979099..f5b8b6f4e 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java @@ -52,12 +52,11 @@ private record Result( Map nameToNodeMap) {} /** - * Process to execute an SQL Expression to convert into an Extended expression protobuf message + * Converts the given SQL expression string to an {@link io.substrait.proto.ExtendedExpression } * - * @param sqlExpression expression defined by the user - * @param createStatements of names of table needed to consider to load into memory for catalog, - * schema, validate and parse sql - * @return extended expression protobuf message + * @param sqlExpression a SQL expression + * @param createStatements table creation statements defining fields referenced by the expression + * @return a {@link io.substrait.proto.ExtendedExpression } * @throws SqlParseException */ public ExtendedExpression convert(String sqlExpression, List createStatements) From 6327a176cb42d0cc62b9b6eab4bb06527706314f Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Wed, 10 Jan 2024 10:03:07 -0500 Subject: [PATCH 29/34] fix: commit suggestion code Co-authored-by: Vibhatha Lakmal Abeykoon --- .../java/io/substrait/isthmus/SqlExpressionToSubstrait.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java index f5b8b6f4e..1a750d0c8 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java @@ -108,9 +108,9 @@ private RexNode sqlToRexNode( throws SqlParseException { SqlParser parser = SqlParser.create(sql, parserConfig); SqlNode sqlNode = parser.parseExpression(); - SqlNode validSQLNode = validator.validateParameterizedExpression(sqlNode, nameToTypeMap); + SqlNode validSqlNode = validator.validateParameterizedExpression(sqlNode, nameToTypeMap); SqlToRelConverter converter = createSqlToRelConverter(validator, catalogReader); - return converter.convertExpression(validSQLNode, nameToNodeMap); + return converter.convertExpression(validSqlNode, nameToNodeMap); } SqlToRelConverter createSqlToRelConverter( From 8658558bafa0dbc3ff0a1dca40b2f299c4dde7e3 Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Thu, 11 Jan 2024 09:41:20 -0500 Subject: [PATCH 30/34] fix: addressing PR comments --- .../isthmus/SqlExpressionToSubstrait.java | 18 +++++++--- .../isthmus/ExtendedExpressionTestBase.java | 21 +++++++++-- .../SimpleExtendedExpressionsTest.java | 18 ++++++++-- .../src/test/resources/tpch/schema_error.sql | 36 +++++++++++++++++++ 4 files changed, 84 insertions(+), 9 deletions(-) create mode 100644 isthmus/src/test/resources/tpch/schema_error.sql diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java index 1a750d0c8..c850397b6 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java @@ -138,12 +138,20 @@ private Result registerCreateTablesForExtendedExpression(List tables) for (DefinedTable t : tList) { rootSchema.add(t.getName(), t); for (RelDataTypeField field : t.getRowType(factory).getFieldList()) { - nameToTypeMap.put( - field.getName(), field.getType()); // to validate the sql expression tree - nameToNodeMap.put( + nameToTypeMap.merge( // to validate the sql expression tree field.getName(), - new RexInputRef( - field.getIndex(), field.getType())); // to convert sql expression into RexNode + field.getType(), + (v1, v2) -> { + throw new IllegalArgumentException( + "There is no support for duplicate column names: " + field.getName()); + }); + nameToNodeMap.merge( // to convert sql expression into RexNode + field.getName(), + new RexInputRef(field.getIndex(), field.getType()), + (v1, v2) -> { + throw new IllegalArgumentException( + "There is no support for duplicate column names: " + field.getName()); + }); } } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java index 9c2d0ce5e..61ad85e25 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java @@ -16,23 +16,40 @@ public static String asString(String resource) throws IOException { return Resources.toString(Resources.getResource(resource), Charsets.UTF_8); } - public static List tpchSchemaCreateStatements() throws IOException { - String[] values = asString("tpch/schema.sql").split(";"); + public static List tpchSchemaCreateStatements(String schemaToLoad) throws IOException { + String[] values = asString(schemaToLoad).split(";"); return Arrays.stream(values) .filter(t -> !t.trim().isBlank()) .collect(java.util.stream.Collectors.toList()); } + public static List tpchSchemaCreateStatements() throws IOException { + return tpchSchemaCreateStatements("tpch/schema.sql"); + } + protected ExtendedExpression assertProtoExtendedExpressionRoundrip(String query) throws IOException, SqlParseException { return assertProtoExtendedExpressionRoundrip(query, new SqlExpressionToSubstrait()); } + protected ExtendedExpression assertProtoExtendedExpressionRoundrip( + String query, String schemaToLoad) throws IOException, SqlParseException { + return assertProtoExtendedExpressionRoundrip( + query, new SqlExpressionToSubstrait(), schemaToLoad); + } + protected ExtendedExpression assertProtoExtendedExpressionRoundrip( String query, SqlExpressionToSubstrait s) throws IOException, SqlParseException { return assertProtoExtendedExpressionRoundrip(query, s, tpchSchemaCreateStatements()); } + protected ExtendedExpression assertProtoExtendedExpressionRoundrip( + String query, SqlExpressionToSubstrait s, String schemaToLoad) + throws IOException, SqlParseException { + return assertProtoExtendedExpressionRoundrip( + query, s, tpchSchemaCreateStatements(schemaToLoad)); + } + protected ExtendedExpression assertProtoExtendedExpressionRoundrip( String query, SqlExpressionToSubstrait s, List creates) throws SqlParseException, IOException { diff --git a/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java index af74b7bf3..f25349629 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java @@ -1,6 +1,8 @@ package io.substrait.isthmus; -import io.substrait.proto.ExtendedExpression; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + import java.io.IOException; import java.util.stream.Stream; import org.apache.calcite.sql.parser.SqlParseException; @@ -26,7 +28,19 @@ private static Stream expressionTypeProvider() { @MethodSource("expressionTypeProvider") public void testExtendedExpressionsRoundTrip(String sqlExpression) throws SqlParseException, IOException { + assertProtoExtendedExpressionRoundrip(sqlExpression); + } - ExtendedExpression extendedExpression = assertProtoExtendedExpressionRoundrip(sqlExpression); + @ParameterizedTest + @MethodSource("expressionTypeProvider") + public void testExtendedExpressionsRoundTripDuplicateColumnIdentifier(String sqlExpression) { + IllegalArgumentException illegalArgumentException = + assertThrows( + IllegalArgumentException.class, + () -> assertProtoExtendedExpressionRoundrip(sqlExpression, "tpch/schema_error.sql")); + assertTrue( + illegalArgumentException + .getMessage() + .startsWith("There is no support for duplicate column names")); } } diff --git a/isthmus/src/test/resources/tpch/schema_error.sql b/isthmus/src/test/resources/tpch/schema_error.sql new file mode 100644 index 000000000..a4a137223 --- /dev/null +++ b/isthmus/src/test/resources/tpch/schema_error.sql @@ -0,0 +1,36 @@ +CREATE TABLE LINEITEM ( + L_ORDERKEY BIGINT NOT NULL, + L_PARTKEY BIGINT NOT NULL, + L_SUPPKEY BIGINT NOT NULL, + L_LINENUMBER INTEGER, + L_QUANTITY DECIMAL, + L_EXTENDEDPRICE DECIMAL, + L_DISCOUNT DECIMAL, + L_TAX DECIMAL, + L_RETURNFLAG CHAR(1), + L_LINESTATUS CHAR(1), + L_SHIPDATE DATE, + L_COMMITDATE DATE, + L_RECEIPTDATE DATE, + L_SHIPINSTRUCT CHAR(25), + L_SHIPMODE CHAR(10), + L_COMMENT VARCHAR(44) +); +CREATE TABLE LINEITEM_DUPLICATED ( + L_ORDERKEY BIGINT NOT NULL, + L_PARTKEY BIGINT NOT NULL, + L_SUPPKEY BIGINT NOT NULL, + L_LINENUMBER INTEGER, + L_QUANTITY DECIMAL, + L_EXTENDEDPRICE DECIMAL, + L_DISCOUNT DECIMAL, + L_TAX DECIMAL, + L_RETURNFLAG CHAR(1), + L_LINESTATUS CHAR(1), + L_SHIPDATE DATE, + L_COMMITDATE DATE, + L_RECEIPTDATE DATE, + L_SHIPINSTRUCT CHAR(25), + L_SHIPMODE CHAR(10), + L_COMMENT VARCHAR(44) +); \ No newline at end of file From 3ff586ab1c0fb9b1124a93ceeac70c29c0ba134a Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Thu, 11 Jan 2024 12:45:25 -0500 Subject: [PATCH 31/34] fix: delete integration with arrow project --- isthmus/build.gradle.kts | 3 - .../ExtendedExpressionIntegrationTest.java | 110 ------------------ 2 files changed, 113 deletions(-) delete mode 100644 isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java diff --git a/isthmus/build.gradle.kts b/isthmus/build.gradle.kts index e9569f1c0..315d7c251 100644 --- a/isthmus/build.gradle.kts +++ b/isthmus/build.gradle.kts @@ -72,7 +72,6 @@ java { } var CALCITE_VERSION = "1.34.0" -var ARROW_VERSION = "14.0.0" dependencies { implementation(project(":core")) @@ -95,8 +94,6 @@ dependencies { implementation("org.immutables:value-annotations:2.8.8") annotationProcessor("org.immutables:value:2.8.8") testImplementation("org.apache.calcite:calcite-plus:${CALCITE_VERSION}") - testImplementation("org.apache.arrow:arrow-dataset:${ARROW_VERSION}") - testImplementation("org.apache.arrow:arrow-memory-netty:${ARROW_VERSION}") annotationProcessor("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") compileOnly("com.github.bsideup.jabel:jabel-javac-plugin:0.4.2") } diff --git a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java b/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java deleted file mode 100644 index 97f504361..000000000 --- a/isthmus/src/test/java/io/substrait/isthmus/integration/ExtendedExpressionIntegrationTest.java +++ /dev/null @@ -1,110 +0,0 @@ -package io.substrait.isthmus.integration; - -import static org.junit.jupiter.api.Assertions.assertEquals; - -import com.ibm.icu.impl.ClassLoaderUtil; -import io.substrait.isthmus.ExtendedExpressionTestBase; -import io.substrait.isthmus.SqlExpressionToSubstrait; -import io.substrait.proto.ExtendedExpression; -import java.io.IOException; -import java.net.URL; -import java.nio.ByteBuffer; -import java.util.Base64; -import java.util.Optional; -import org.apache.arrow.dataset.file.FileFormat; -import org.apache.arrow.dataset.file.FileSystemDatasetFactory; -import org.apache.arrow.dataset.jni.NativeMemoryPool; -import org.apache.arrow.dataset.scanner.ScanOptions; -import org.apache.arrow.dataset.scanner.Scanner; -import org.apache.arrow.dataset.source.Dataset; -import org.apache.arrow.dataset.source.DatasetFactory; -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.memory.RootAllocator; -import org.apache.arrow.vector.BigIntVector; -import org.apache.arrow.vector.ipc.ArrowReader; -import org.apache.calcite.sql.parser.SqlParseException; -import org.junit.jupiter.api.Test; - -public class ExtendedExpressionIntegrationTest { - - @Test - public void filterDataset() throws SqlParseException, IOException { - URL resource = ClassLoaderUtil.getClassLoader().getResource("./tpch/data/nation.parquet"); - // Make sure you pass appropriate data, for example, if you pass N_NATIONKEY > 20 the engine - // creates an i64 but casts it to i32 = 20, causing casting problems. - String sqlExpression = "N_NATIONKEY > 9223372036854771827 - 9223372036854771807"; - ScanOptions options = - new ScanOptions.Builder(/*batchSize*/ 32768) - .columns(Optional.empty()) - .substraitFilter(getExtendedExpression(sqlExpression)) - .build(); - try (BufferAllocator allocator = new RootAllocator(); - DatasetFactory datasetFactory = - new FileSystemDatasetFactory( - allocator, - NativeMemoryPool.getDefault(), - FileFormat.PARQUET, - resource.toURI().toString()); - Dataset dataset = datasetFactory.finish(); - Scanner scanner = dataset.newScan(options); - ArrowReader reader = scanner.scanBatches()) { - int count = 0; - while (reader.loadNextBatch()) { - count += reader.getVectorSchemaRoot().getRowCount(); - } - assertEquals(4, count); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - @Test - public void projectDataset() throws SqlParseException, IOException { - URL resource = ClassLoaderUtil.getClassLoader().getResource("./tpch/data/nation.parquet"); - // Make sure you pass appropriate data, for example, if you pass N_NATIONKEY + 20 the engine - // creates an i64 but casts it to i32 = 20, causing casting problems. - String sqlExpression = "N_NATIONKEY + 9888486986"; - ScanOptions options = - new ScanOptions.Builder(/*batchSize*/ 32768) - .columns(Optional.empty()) - .substraitProjection(getExtendedExpression(sqlExpression)) - .build(); - try (BufferAllocator allocator = new RootAllocator(); - DatasetFactory datasetFactory = - new FileSystemDatasetFactory( - allocator, - NativeMemoryPool.getDefault(), - FileFormat.PARQUET, - resource.toURI().toString()); - Dataset dataset = datasetFactory.finish(); - Scanner scanner = dataset.newScan(options); - ArrowReader reader = scanner.scanBatches()) { - int count = 0; - Long sum = 0L; - while (reader.loadNextBatch()) { - count += reader.getVectorSchemaRoot().getRowCount(); - BigIntVector bigIntVector = (BigIntVector) reader.getVectorSchemaRoot().getVector(0); - for (int i = 0; i < bigIntVector.getValueCount(); i++) { - sum += bigIntVector.get(i); - } - } - assertEquals(25, count); - assertEquals(24 * 25 / 2 + 9888486986L * count, sum); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - private static ByteBuffer getExtendedExpression(String sqlExpression) - throws IOException, SqlParseException { - ExtendedExpression extendedExpression = - new SqlExpressionToSubstrait() - .convert(sqlExpression, ExtendedExpressionTestBase.tpchSchemaCreateStatements()); - byte[] extendedExpressions = - Base64.getDecoder() - .decode(Base64.getEncoder().encodeToString(extendedExpression.toByteArray())); - ByteBuffer substraitExpression = ByteBuffer.allocateDirect(extendedExpressions.length); - substraitExpression.put(extendedExpressions); - return substraitExpression; - } -} From 6151bca5d8e7ccb78375f2fa59c7798d29aa23cb Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Fri, 12 Jan 2024 07:41:26 -0500 Subject: [PATCH 32/34] fix: apply suggestions from code review Co-authored-by: Vibhatha Lakmal Abeykoon --- .../isthmus/ExtendedExpressionTestBase.java | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java index 61ad85e25..d9dda8437 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java @@ -27,30 +27,30 @@ public static List tpchSchemaCreateStatements() throws IOException { return tpchSchemaCreateStatements("tpch/schema.sql"); } - protected ExtendedExpression assertProtoExtendedExpressionRoundrip(String query) + protected ExtendedExpression assertProtoExtendedExpressionRoundtrip(String query) throws IOException, SqlParseException { - return assertProtoExtendedExpressionRoundrip(query, new SqlExpressionToSubstrait()); + return assertProtoExtendedExpressionRoundtrip(query, new SqlExpressionToSubstrait()); } - protected ExtendedExpression assertProtoExtendedExpressionRoundrip( + protected ExtendedExpression assertProtoExtendedExpressionRoundtrip( String query, String schemaToLoad) throws IOException, SqlParseException { return assertProtoExtendedExpressionRoundrip( query, new SqlExpressionToSubstrait(), schemaToLoad); } - protected ExtendedExpression assertProtoExtendedExpressionRoundrip( + protected ExtendedExpression assertProtoExtendedExpressionRoundtrip( String query, SqlExpressionToSubstrait s) throws IOException, SqlParseException { return assertProtoExtendedExpressionRoundrip(query, s, tpchSchemaCreateStatements()); } - protected ExtendedExpression assertProtoExtendedExpressionRoundrip( + protected ExtendedExpression assertProtoExtendedExpressionRoundtrip( String query, SqlExpressionToSubstrait s, String schemaToLoad) throws IOException, SqlParseException { return assertProtoExtendedExpressionRoundrip( query, s, tpchSchemaCreateStatements(schemaToLoad)); } - protected ExtendedExpression assertProtoExtendedExpressionRoundrip( + protected ExtendedExpression assertProtoExtendedExpressionRoundtrip( String query, SqlExpressionToSubstrait s, List creates) throws SqlParseException, IOException { // proto initial extended expression From 96a2f25d50bd1663c43b00f57c215165abb2987e Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Mon, 15 Jan 2024 12:14:38 -0500 Subject: [PATCH 33/34] fix: addressing PR comments --- .../isthmus/SqlExpressionToSubstrait.java | 23 +++++++++---------- .../isthmus/ExtendedExpressionTestBase.java | 6 ++--- .../SimpleExtendedExpressionsTest.java | 4 ++-- .../src/test/resources/tpch/schema_error.sql | 2 +- 4 files changed, 17 insertions(+), 18 deletions(-) diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java index c850397b6..82ec0998f 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java @@ -109,21 +109,17 @@ private RexNode sqlToRexNode( SqlParser parser = SqlParser.create(sql, parserConfig); SqlNode sqlNode = parser.parseExpression(); SqlNode validSqlNode = validator.validateParameterizedExpression(sqlNode, nameToTypeMap); - SqlToRelConverter converter = createSqlToRelConverter(validator, catalogReader); + SqlToRelConverter converter = + new SqlToRelConverter( + null, + validator, + catalogReader, + relOptCluster, + StandardConvertletTable.INSTANCE, + converterConfig); return converter.convertExpression(validSqlNode, nameToNodeMap); } - SqlToRelConverter createSqlToRelConverter( - SqlValidator validator, CalciteCatalogReader catalogReader) { - return new SqlToRelConverter( - null, - validator, - catalogReader, - relOptCluster, - StandardConvertletTable.INSTANCE, - converterConfig); - } - private Result registerCreateTablesForExtendedExpression(List tables) throws SqlParseException { Map nameToTypeMap = new LinkedHashMap<>(); @@ -155,6 +151,9 @@ private Result registerCreateTablesForExtendedExpression(List tables) } } } + } else { + throw new IllegalArgumentException( + "Information regarding the data and types must be passed."); } return new Result(validator, catalogReader, nameToTypeMap, nameToNodeMap); } diff --git a/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java index d9dda8437..d47abcc77 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java @@ -34,19 +34,19 @@ protected ExtendedExpression assertProtoExtendedExpressionRoundtrip(String query protected ExtendedExpression assertProtoExtendedExpressionRoundtrip( String query, String schemaToLoad) throws IOException, SqlParseException { - return assertProtoExtendedExpressionRoundrip( + return assertProtoExtendedExpressionRoundtrip( query, new SqlExpressionToSubstrait(), schemaToLoad); } protected ExtendedExpression assertProtoExtendedExpressionRoundtrip( String query, SqlExpressionToSubstrait s) throws IOException, SqlParseException { - return assertProtoExtendedExpressionRoundrip(query, s, tpchSchemaCreateStatements()); + return assertProtoExtendedExpressionRoundtrip(query, s, tpchSchemaCreateStatements()); } protected ExtendedExpression assertProtoExtendedExpressionRoundtrip( String query, SqlExpressionToSubstrait s, String schemaToLoad) throws IOException, SqlParseException { - return assertProtoExtendedExpressionRoundrip( + return assertProtoExtendedExpressionRoundtrip( query, s, tpchSchemaCreateStatements(schemaToLoad)); } diff --git a/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java index f25349629..9b6afb457 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java @@ -28,7 +28,7 @@ private static Stream expressionTypeProvider() { @MethodSource("expressionTypeProvider") public void testExtendedExpressionsRoundTrip(String sqlExpression) throws SqlParseException, IOException { - assertProtoExtendedExpressionRoundrip(sqlExpression); + assertProtoExtendedExpressionRoundtrip(sqlExpression); } @ParameterizedTest @@ -37,7 +37,7 @@ public void testExtendedExpressionsRoundTripDuplicateColumnIdentifier(String sql IllegalArgumentException illegalArgumentException = assertThrows( IllegalArgumentException.class, - () -> assertProtoExtendedExpressionRoundrip(sqlExpression, "tpch/schema_error.sql")); + () -> assertProtoExtendedExpressionRoundtrip(sqlExpression, "tpch/schema_error.sql")); assertTrue( illegalArgumentException .getMessage() diff --git a/isthmus/src/test/resources/tpch/schema_error.sql b/isthmus/src/test/resources/tpch/schema_error.sql index a4a137223..7b4096098 100644 --- a/isthmus/src/test/resources/tpch/schema_error.sql +++ b/isthmus/src/test/resources/tpch/schema_error.sql @@ -33,4 +33,4 @@ CREATE TABLE LINEITEM_DUPLICATED ( L_SHIPINSTRUCT CHAR(25), L_SHIPMODE CHAR(10), L_COMMENT VARCHAR(44) -); \ No newline at end of file +); From 3b110f384d2a0fb14256367df954c0f413f2b5e1 Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Thu, 18 Jan 2024 08:54:20 -0800 Subject: [PATCH 34/34] refactor: remove unused nation.parquet data --- .editorconfig | 2 +- .../src/test/resources/tpch/data/nation.parquet | Bin 2319 -> 0 bytes 2 files changed, 1 insertion(+), 1 deletion(-) delete mode 100644 isthmus/src/test/resources/tpch/data/nation.parquet diff --git a/.editorconfig b/.editorconfig index 984db0c67..3d674d593 100644 --- a/.editorconfig +++ b/.editorconfig @@ -10,7 +10,7 @@ trim_trailing_whitespace = true [*.{yaml,yml}] indent_size = 2 -[{**/*.sql,**/OuterReferenceResolver.md,gradlew.bat,**/*.parquet}] +[{**/*.sql,**/OuterReferenceResolver.md,gradlew.bat}] charset = unset end_of_line = unset insert_final_newline = unset diff --git a/isthmus/src/test/resources/tpch/data/nation.parquet b/isthmus/src/test/resources/tpch/data/nation.parquet deleted file mode 100644 index 0189118ce7344297b1ad6f1095a11162ddd960ea..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 2319 zcmdT`U5p#m6~5OQZ#?W6cEDrC$}Zfk*6!}sJH*az2w_7yyLOWy>!0=7K~71-_4wMJ z-SLcfev)+p51>k*Nc(^qP^Cy!C{zmSLsb>=5Gqv(EfJNdih}yI;-L>nZFxu~S^-Kc z&g@Tq#9QAg&1aoC_xztT=en#`G7^$L!SJM|ERX}z07A=SA%svC!wB5^Wmr;cIgO|l zbxsTKk&kQnYEdm<{kckQ3ETz+sui_rK1Yse#Ur^=A)0tFwp3NC`K7IHoV}|a_clZ1 z5QoLU$F3-cQ<%dt=0rik2+smJEsmvwUIhr+2j>3ri1d5$E_=Ux;O4tq1>Oe&q(j?n zlPN@}4r}?Q*(WW-q9$pwp6wc*3xvmXkGN(Z&S;w&3!nx9Eqjf*r03XO!(}_ix^6p7 z!)g=HCSHdThqZl^I)uW3Z+Wgl8n)R4_NvcZFiuU|S^;}t8K|~vG#Rm5o2or#ZCct1 z1VWVF?6^Iq8{oZ1^%dHN03d8a8@BHe{Nw?{u`NS~&>T2Y2>+fqon~O3Q5O=m+R@43 zG;s}+S)NoDr;!u=UABX!TfSr1K4eJj)65C43ZCJ0@bYlW^jj^%?1w)a(t{xJ>)dG; zPCn6)eam-!YT|gzZrY~dg=dkGJTv6CArSuK*s5>nE9U5N!LeNzD#^LuF7O>%Ax&( z;`=Baeh!&ae_c>P&UFb}(|s3eq0?pCtnZQjkSQW%W@9gO$FMvo5!BvMU!VBR0k+n2 z7-&$p-|cyJ^1ITPh!rkrSNW`RJ%kJ`9P8Nn3hUz=xju7T=qqDA(ts2@%pBP(WWi$> zMW9>z&xiQ+sY`?fjbfH=jJ_kx!Ku0|t&cTq+Xcs--An!|bB&+P_~@`C^B)OgAdz20 z3O|9;(pN&;Xn8Y7W%@X<4L#9S|0VuW?G*l7@9?LF9O&}N(Hg!Ifr*;AGQY8b<=j4*&n&s zy~g>j@@Xi1sP-~FTQY6<}b`dhn!_FBpn`kEwf9xj`?$|R}igVMvOg^ z_)7UvXgy;*Ec_5el0QAU=ZWJ`!4HNl+p`_Za0@@yj%VKFHR=d3%rrWN6nr^myLOOa z(~l%?A^O@squsrv>zkg@GYu%SBCTaVM-hv*@qzW9AcKt|`o_QcLGcx#2zq!Oxo^vh zR|@-}t=Gkk)f>@`L;1`FL4iZ+juV(?Lx&=ned(#he@0y{O^nA-Nqt8;d-W``)Mc*2 z>WSogoW=t5XsC4Efv8`WMj`paFbFt6dl-4Tn*^L-wD+oU7C}hIe`*g%Z!sAUWFf^_6Dh zl&g1Gmr4;Ng_0q8&^l#N&w$UK^^@O0k^sDcfY)V7sQI=C?2UUxc z;lx4;4TC2MCGN2)h2q)x*BG8uWLD;aOGXa%JX|~DcfA24&=)rKxBYLbZgBCRbE%g? zv-0lky)EzFzJa76?B#7ZoAa~yG3FJpk)008`~HV} g0m}3L5GE8n4il7$pz(nmOlWQn{TlAWGW-?#1~YkXEC2ui