Skip to content

Commit 750220e

Browse files
davisusanibar, danepitkin, vbarua, vibhatha
authored
feat: enable conversion of SQL expressions to Substrait ExtendedExpressions (#191)
Introduces the SqlExpressionToSubstrait class for converting SQL expressions to Substrait ExtendedExpressions. --------- Co-authored-by: Dane Pitkin <[email protected]> Co-authored-by: Victor Barua <[email protected]> Co-authored-by: Vibhatha Lakmal Abeykoon <[email protected]>
1 parent 733815d commit 750220e

File tree

5 files changed

+328
-2
lines changed

5 files changed

+328
-2
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
repos:
22
- repo: https://github.com/adrienverge/yamllint.git
3-
rev: v1.26.0
3+
rev: v1.33.0
44
hooks:
55
- id: yamllint
66
args: [-c=.yamllint.yaml]
77
- repo: https://github.com/alessandrojcm/commitlint-pre-commit-hook
8-
rev: v8.0.0
8+
rev: v9.9.0
99
hooks:
1010
- id: commitlint
1111
stages: [commit-msg]
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
package io.substrait.isthmus;
2+
3+
import com.github.bsideup.jabel.Desugar;
4+
import io.substrait.extendedexpression.ExtendedExpressionProtoConverter;
5+
import io.substrait.extendedexpression.ImmutableExpressionReference;
6+
import io.substrait.extendedexpression.ImmutableExtendedExpression;
7+
import io.substrait.extension.SimpleExtension;
8+
import io.substrait.isthmus.expression.RexExpressionConverter;
9+
import io.substrait.isthmus.expression.ScalarFunctionConverter;
10+
import io.substrait.proto.ExtendedExpression;
11+
import io.substrait.type.NamedStruct;
12+
import io.substrait.type.Type;
13+
import java.util.ArrayList;
14+
import java.util.HashMap;
15+
import java.util.LinkedHashMap;
16+
import java.util.List;
17+
import java.util.Map;
18+
import org.apache.calcite.jdbc.CalciteSchema;
19+
import org.apache.calcite.prepare.CalciteCatalogReader;
20+
import org.apache.calcite.rel.type.RelDataType;
21+
import org.apache.calcite.rel.type.RelDataTypeField;
22+
import org.apache.calcite.rex.RexInputRef;
23+
import org.apache.calcite.rex.RexNode;
24+
import org.apache.calcite.sql.SqlNode;
25+
import org.apache.calcite.sql.parser.SqlParseException;
26+
import org.apache.calcite.sql.parser.SqlParser;
27+
import org.apache.calcite.sql.validate.SqlValidator;
28+
import org.apache.calcite.sql2rel.SqlToRelConverter;
29+
import org.apache.calcite.sql2rel.StandardConvertletTable;
30+
31+
public class SqlExpressionToSubstrait extends SqlConverterBase {
32+
33+
protected final RexExpressionConverter rexConverter;
34+
35+
public SqlExpressionToSubstrait() {
36+
this(FEATURES_DEFAULT, EXTENSION_COLLECTION);
37+
}
38+
39+
public SqlExpressionToSubstrait(
40+
FeatureBoard features, SimpleExtension.ExtensionCollection extensions) {
41+
super(features);
42+
ScalarFunctionConverter scalarFunctionConverter =
43+
new ScalarFunctionConverter(extensions.scalarFunctions(), factory);
44+
this.rexConverter = new RexExpressionConverter(scalarFunctionConverter);
45+
}
46+
47+
@Desugar
48+
private record Result(
49+
SqlValidator validator,
50+
CalciteCatalogReader catalogReader,
51+
Map<String, RelDataType> nameToTypeMap,
52+
Map<String, RexNode> nameToNodeMap) {}
53+
54+
/**
55+
* Converts the given SQL expression string to an {@link io.substrait.proto.ExtendedExpression }
56+
*
57+
* @param sqlExpression a SQL expression
58+
* @param createStatements table creation statements defining fields referenced by the expression
59+
* @return a {@link io.substrait.proto.ExtendedExpression }
60+
* @throws SqlParseException
61+
*/
62+
public ExtendedExpression convert(String sqlExpression, List<String> createStatements)
63+
throws SqlParseException {
64+
var result = registerCreateTablesForExtendedExpression(createStatements);
65+
return executeInnerSQLExpression(
66+
sqlExpression,
67+
result.validator(),
68+
result.catalogReader(),
69+
result.nameToTypeMap(),
70+
result.nameToNodeMap());
71+
}
72+
73+
private ExtendedExpression executeInnerSQLExpression(
74+
String sqlExpression,
75+
SqlValidator validator,
76+
CalciteCatalogReader catalogReader,
77+
Map<String, RelDataType> nameToTypeMap,
78+
Map<String, RexNode> nameToNodeMap)
79+
throws SqlParseException {
80+
RexNode rexNode =
81+
sqlToRexNode(sqlExpression, validator, catalogReader, nameToTypeMap, nameToNodeMap);
82+
NamedStruct namedStruct = toNamedStruct(nameToTypeMap);
83+
84+
ImmutableExpressionReference expressionReference =
85+
ImmutableExpressionReference.builder()
86+
.expression(rexNode.accept(this.rexConverter))
87+
.addOutputNames("new-column")
88+
.build();
89+
90+
List<io.substrait.extendedexpression.ExtendedExpression.ExpressionReference>
91+
expressionReferences = new ArrayList<>();
92+
expressionReferences.add(expressionReference);
93+
94+
ImmutableExtendedExpression.Builder extendedExpression =
95+
ImmutableExtendedExpression.builder()
96+
.referredExpressions(expressionReferences)
97+
.baseSchema(namedStruct);
98+
99+
return new ExtendedExpressionProtoConverter().toProto(extendedExpression.build());
100+
}
101+
102+
private RexNode sqlToRexNode(
103+
String sql,
104+
SqlValidator validator,
105+
CalciteCatalogReader catalogReader,
106+
Map<String, RelDataType> nameToTypeMap,
107+
Map<String, RexNode> nameToNodeMap)
108+
throws SqlParseException {
109+
SqlParser parser = SqlParser.create(sql, parserConfig);
110+
SqlNode sqlNode = parser.parseExpression();
111+
SqlNode validSqlNode = validator.validateParameterizedExpression(sqlNode, nameToTypeMap);
112+
SqlToRelConverter converter =
113+
new SqlToRelConverter(
114+
null,
115+
validator,
116+
catalogReader,
117+
relOptCluster,
118+
StandardConvertletTable.INSTANCE,
119+
converterConfig);
120+
return converter.convertExpression(validSqlNode, nameToNodeMap);
121+
}
122+
123+
private Result registerCreateTablesForExtendedExpression(List<String> tables)
124+
throws SqlParseException {
125+
Map<String, RelDataType> nameToTypeMap = new LinkedHashMap<>();
126+
Map<String, RexNode> nameToNodeMap = new HashMap<>();
127+
CalciteSchema rootSchema = CalciteSchema.createRootSchema(false);
128+
CalciteCatalogReader catalogReader =
129+
new CalciteCatalogReader(rootSchema, List.of(), factory, config);
130+
SqlValidator validator = Validator.create(factory, catalogReader, SqlValidator.Config.DEFAULT);
131+
if (tables != null) {
132+
for (String tableDef : tables) {
133+
List<DefinedTable> tList = parseCreateTable(factory, validator, tableDef);
134+
for (DefinedTable t : tList) {
135+
rootSchema.add(t.getName(), t);
136+
for (RelDataTypeField field : t.getRowType(factory).getFieldList()) {
137+
nameToTypeMap.merge( // to validate the sql expression tree
138+
field.getName(),
139+
field.getType(),
140+
(v1, v2) -> {
141+
throw new IllegalArgumentException(
142+
"There is no support for duplicate column names: " + field.getName());
143+
});
144+
nameToNodeMap.merge( // to convert sql expression into RexNode
145+
field.getName(),
146+
new RexInputRef(field.getIndex(), field.getType()),
147+
(v1, v2) -> {
148+
throw new IllegalArgumentException(
149+
"There is no support for duplicate column names: " + field.getName());
150+
});
151+
}
152+
}
153+
}
154+
} else {
155+
throw new IllegalArgumentException(
156+
"Information regarding the data and types must be passed.");
157+
}
158+
return new Result(validator, catalogReader, nameToTypeMap, nameToNodeMap);
159+
}
160+
161+
private NamedStruct toNamedStruct(Map<String, RelDataType> nameToTypeMap) {
162+
var names = new ArrayList<String>();
163+
var types = new ArrayList<Type>();
164+
for (Map.Entry<String, RelDataType> entry : nameToTypeMap.entrySet()) {
165+
String k = entry.getKey();
166+
RelDataType v = entry.getValue();
167+
names.add(k);
168+
types.add(TypeConverter.DEFAULT.toSubstrait(v));
169+
}
170+
return NamedStruct.of(names, Type.Struct.builder().fields(types).nullable(false).build());
171+
}
172+
}
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
package io.substrait.isthmus;
2+
3+
import com.google.common.base.Charsets;
4+
import com.google.common.io.Resources;
5+
import io.substrait.extendedexpression.ExtendedExpressionProtoConverter;
6+
import io.substrait.extendedexpression.ProtoExtendedExpressionConverter;
7+
import io.substrait.proto.ExtendedExpression;
8+
import java.io.IOException;
9+
import java.util.Arrays;
10+
import java.util.List;
11+
import org.apache.calcite.sql.parser.SqlParseException;
12+
import org.junit.jupiter.api.Assertions;
13+
14+
public class ExtendedExpressionTestBase {
15+
public static String asString(String resource) throws IOException {
16+
return Resources.toString(Resources.getResource(resource), Charsets.UTF_8);
17+
}
18+
19+
public static List<String> tpchSchemaCreateStatements(String schemaToLoad) throws IOException {
20+
String[] values = asString(schemaToLoad).split(";");
21+
return Arrays.stream(values)
22+
.filter(t -> !t.trim().isBlank())
23+
.collect(java.util.stream.Collectors.toList());
24+
}
25+
26+
public static List<String> tpchSchemaCreateStatements() throws IOException {
27+
return tpchSchemaCreateStatements("tpch/schema.sql");
28+
}
29+
30+
protected ExtendedExpression assertProtoExtendedExpressionRoundtrip(String query)
31+
throws IOException, SqlParseException {
32+
return assertProtoExtendedExpressionRoundtrip(query, new SqlExpressionToSubstrait());
33+
}
34+
35+
protected ExtendedExpression assertProtoExtendedExpressionRoundtrip(
36+
String query, String schemaToLoad) throws IOException, SqlParseException {
37+
return assertProtoExtendedExpressionRoundtrip(
38+
query, new SqlExpressionToSubstrait(), schemaToLoad);
39+
}
40+
41+
protected ExtendedExpression assertProtoExtendedExpressionRoundtrip(
42+
String query, SqlExpressionToSubstrait s) throws IOException, SqlParseException {
43+
return assertProtoExtendedExpressionRoundtrip(query, s, tpchSchemaCreateStatements());
44+
}
45+
46+
protected ExtendedExpression assertProtoExtendedExpressionRoundtrip(
47+
String query, SqlExpressionToSubstrait s, String schemaToLoad)
48+
throws IOException, SqlParseException {
49+
return assertProtoExtendedExpressionRoundtrip(
50+
query, s, tpchSchemaCreateStatements(schemaToLoad));
51+
}
52+
53+
protected ExtendedExpression assertProtoExtendedExpressionRoundtrip(
54+
String query, SqlExpressionToSubstrait s, List<String> creates)
55+
throws SqlParseException, IOException {
56+
// proto initial extended expression
57+
ExtendedExpression extendedExpressionProtoInitial = s.convert(query, creates);
58+
59+
// pojo final extended expression
60+
io.substrait.extendedexpression.ExtendedExpression extendedExpressionPojoFinal =
61+
new ProtoExtendedExpressionConverter().from(extendedExpressionProtoInitial);
62+
63+
// proto final extended expression
64+
ExtendedExpression extendedExpressionProtoFinal =
65+
new ExtendedExpressionProtoConverter().toProto(extendedExpressionPojoFinal);
66+
67+
// round-trip to validate extended expression proto initial equals to final
68+
Assertions.assertEquals(extendedExpressionProtoFinal, extendedExpressionProtoInitial);
69+
70+
return extendedExpressionProtoInitial;
71+
}
72+
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package io.substrait.isthmus;
2+
3+
import static org.junit.jupiter.api.Assertions.assertThrows;
4+
import static org.junit.jupiter.api.Assertions.assertTrue;
5+
6+
import java.io.IOException;
7+
import java.util.stream.Stream;
8+
import org.apache.calcite.sql.parser.SqlParseException;
9+
import org.junit.jupiter.params.ParameterizedTest;
10+
import org.junit.jupiter.params.provider.Arguments;
11+
import org.junit.jupiter.params.provider.MethodSource;
12+
13+
public class SimpleExtendedExpressionsTest extends ExtendedExpressionTestBase {
14+
15+
private static Stream<Arguments> expressionTypeProvider() {
16+
return Stream.of(
17+
Arguments.of("2"), // I32LiteralExpression
18+
Arguments.of("L_ORDERKEY"), // FieldReferenceExpression
19+
Arguments.of("L_ORDERKEY > 10"), // ScalarFunctionExpressionFilter
20+
Arguments.of("L_ORDERKEY + 10"), // ScalarFunctionExpressionProjection
21+
Arguments.of("L_ORDERKEY IN (10, 20)"), // ScalarFunctionExpressionIn
22+
Arguments.of("L_ORDERKEY is not null"), // ScalarFunctionExpressionIsNotNull
23+
Arguments.of("L_ORDERKEY is null") // ScalarFunctionExpressionIsNull
24+
);
25+
}
26+
27+
@ParameterizedTest
28+
@MethodSource("expressionTypeProvider")
29+
public void testExtendedExpressionsRoundTrip(String sqlExpression)
30+
throws SqlParseException, IOException {
31+
assertProtoExtendedExpressionRoundtrip(sqlExpression);
32+
}
33+
34+
@ParameterizedTest
35+
@MethodSource("expressionTypeProvider")
36+
public void testExtendedExpressionsRoundTripDuplicateColumnIdentifier(String sqlExpression) {
37+
IllegalArgumentException illegalArgumentException =
38+
assertThrows(
39+
IllegalArgumentException.class,
40+
() -> assertProtoExtendedExpressionRoundtrip(sqlExpression, "tpch/schema_error.sql"));
41+
assertTrue(
42+
illegalArgumentException
43+
.getMessage()
44+
.startsWith("There is no support for duplicate column names"));
45+
}
46+
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
-- LINEITEM table with sixteen L_-prefixed columns. Only the three key columns are NOT NULL.
CREATE TABLE LINEITEM (
2+
L_ORDERKEY BIGINT NOT NULL,
3+
L_PARTKEY BIGINT NOT NULL,
4+
L_SUPPKEY BIGINT NOT NULL,
5+
L_LINENUMBER INTEGER,
6+
L_QUANTITY DECIMAL,
7+
L_EXTENDEDPRICE DECIMAL,
8+
L_DISCOUNT DECIMAL,
9+
L_TAX DECIMAL,
10+
L_RETURNFLAG CHAR(1),
11+
L_LINESTATUS CHAR(1),
12+
L_SHIPDATE DATE,
13+
L_COMMITDATE DATE,
14+
L_RECEIPTDATE DATE,
15+
L_SHIPINSTRUCT CHAR(25),
16+
L_SHIPMODE CHAR(10),
17+
L_COMMENT VARCHAR(44)
18+
);
19+
-- Same column names and types as LINEITEM, under a different table name.
-- NOTE(review): presumably this is the tpch/schema_error.sql fixture used to trigger the
-- duplicate column name error - confirm against the test resources.
CREATE TABLE LINEITEM_DUPLICATED (
20+
L_ORDERKEY BIGINT NOT NULL,
21+
L_PARTKEY BIGINT NOT NULL,
22+
L_SUPPKEY BIGINT NOT NULL,
23+
L_LINENUMBER INTEGER,
24+
L_QUANTITY DECIMAL,
25+
L_EXTENDEDPRICE DECIMAL,
26+
L_DISCOUNT DECIMAL,
27+
L_TAX DECIMAL,
28+
L_RETURNFLAG CHAR(1),
29+
L_LINESTATUS CHAR(1),
30+
L_SHIPDATE DATE,
31+
L_COMMITDATE DATE,
32+
L_RECEIPTDATE DATE,
33+
L_SHIPINSTRUCT CHAR(25),
34+
L_SHIPMODE CHAR(10),
35+
L_COMMENT VARCHAR(44)
36+
);

0 commit comments

Comments
 (0)