|
| 1 | +package io.substrait.isthmus; |
| 2 | + |
| 3 | +import com.github.bsideup.jabel.Desugar; |
| 4 | +import io.substrait.extendedexpression.ExtendedExpressionProtoConverter; |
| 5 | +import io.substrait.extendedexpression.ImmutableExpressionReference; |
| 6 | +import io.substrait.extendedexpression.ImmutableExtendedExpression; |
| 7 | +import io.substrait.extension.SimpleExtension; |
| 8 | +import io.substrait.isthmus.expression.RexExpressionConverter; |
| 9 | +import io.substrait.isthmus.expression.ScalarFunctionConverter; |
| 10 | +import io.substrait.proto.ExtendedExpression; |
| 11 | +import io.substrait.type.NamedStruct; |
| 12 | +import io.substrait.type.Type; |
| 13 | +import java.util.ArrayList; |
| 14 | +import java.util.HashMap; |
| 15 | +import java.util.LinkedHashMap; |
| 16 | +import java.util.List; |
| 17 | +import java.util.Map; |
| 18 | +import org.apache.calcite.jdbc.CalciteSchema; |
| 19 | +import org.apache.calcite.prepare.CalciteCatalogReader; |
| 20 | +import org.apache.calcite.rel.type.RelDataType; |
| 21 | +import org.apache.calcite.rel.type.RelDataTypeField; |
| 22 | +import org.apache.calcite.rex.RexInputRef; |
| 23 | +import org.apache.calcite.rex.RexNode; |
| 24 | +import org.apache.calcite.sql.SqlNode; |
| 25 | +import org.apache.calcite.sql.parser.SqlParseException; |
| 26 | +import org.apache.calcite.sql.parser.SqlParser; |
| 27 | +import org.apache.calcite.sql.validate.SqlValidator; |
| 28 | +import org.apache.calcite.sql2rel.SqlToRelConverter; |
| 29 | +import org.apache.calcite.sql2rel.StandardConvertletTable; |
| 30 | + |
| 31 | +public class SqlExpressionToSubstrait extends SqlConverterBase { |
| 32 | + |
| 33 | + protected final RexExpressionConverter rexConverter; |
| 34 | + |
| 35 | + public SqlExpressionToSubstrait() { |
| 36 | + this(FEATURES_DEFAULT, EXTENSION_COLLECTION); |
| 37 | + } |
| 38 | + |
| 39 | + public SqlExpressionToSubstrait( |
| 40 | + FeatureBoard features, SimpleExtension.ExtensionCollection extensions) { |
| 41 | + super(features); |
| 42 | + ScalarFunctionConverter scalarFunctionConverter = |
| 43 | + new ScalarFunctionConverter(extensions.scalarFunctions(), factory); |
| 44 | + this.rexConverter = new RexExpressionConverter(scalarFunctionConverter); |
| 45 | + } |
| 46 | + |
| 47 | + @Desugar |
| 48 | + private record Result( |
| 49 | + SqlValidator validator, |
| 50 | + CalciteCatalogReader catalogReader, |
| 51 | + Map<String, RelDataType> nameToTypeMap, |
| 52 | + Map<String, RexNode> nameToNodeMap) {} |
| 53 | + |
| 54 | + /** |
| 55 | + * Converts the given SQL expression string to an {@link io.substrait.proto.ExtendedExpression } |
| 56 | + * |
| 57 | + * @param sqlExpression a SQL expression |
| 58 | + * @param createStatements table creation statements defining fields referenced by the expression |
| 59 | + * @return a {@link io.substrait.proto.ExtendedExpression } |
| 60 | + * @throws SqlParseException |
| 61 | + */ |
| 62 | + public ExtendedExpression convert(String sqlExpression, List<String> createStatements) |
| 63 | + throws SqlParseException { |
| 64 | + var result = registerCreateTablesForExtendedExpression(createStatements); |
| 65 | + return executeInnerSQLExpression( |
| 66 | + sqlExpression, |
| 67 | + result.validator(), |
| 68 | + result.catalogReader(), |
| 69 | + result.nameToTypeMap(), |
| 70 | + result.nameToNodeMap()); |
| 71 | + } |
| 72 | + |
| 73 | + private ExtendedExpression executeInnerSQLExpression( |
| 74 | + String sqlExpression, |
| 75 | + SqlValidator validator, |
| 76 | + CalciteCatalogReader catalogReader, |
| 77 | + Map<String, RelDataType> nameToTypeMap, |
| 78 | + Map<String, RexNode> nameToNodeMap) |
| 79 | + throws SqlParseException { |
| 80 | + RexNode rexNode = |
| 81 | + sqlToRexNode(sqlExpression, validator, catalogReader, nameToTypeMap, nameToNodeMap); |
| 82 | + NamedStruct namedStruct = toNamedStruct(nameToTypeMap); |
| 83 | + |
| 84 | + ImmutableExpressionReference expressionReference = |
| 85 | + ImmutableExpressionReference.builder() |
| 86 | + .expression(rexNode.accept(this.rexConverter)) |
| 87 | + .addOutputNames("new-column") |
| 88 | + .build(); |
| 89 | + |
| 90 | + List<io.substrait.extendedexpression.ExtendedExpression.ExpressionReference> |
| 91 | + expressionReferences = new ArrayList<>(); |
| 92 | + expressionReferences.add(expressionReference); |
| 93 | + |
| 94 | + ImmutableExtendedExpression.Builder extendedExpression = |
| 95 | + ImmutableExtendedExpression.builder() |
| 96 | + .referredExpressions(expressionReferences) |
| 97 | + .baseSchema(namedStruct); |
| 98 | + |
| 99 | + return new ExtendedExpressionProtoConverter().toProto(extendedExpression.build()); |
| 100 | + } |
| 101 | + |
| 102 | + private RexNode sqlToRexNode( |
| 103 | + String sql, |
| 104 | + SqlValidator validator, |
| 105 | + CalciteCatalogReader catalogReader, |
| 106 | + Map<String, RelDataType> nameToTypeMap, |
| 107 | + Map<String, RexNode> nameToNodeMap) |
| 108 | + throws SqlParseException { |
| 109 | + SqlParser parser = SqlParser.create(sql, parserConfig); |
| 110 | + SqlNode sqlNode = parser.parseExpression(); |
| 111 | + SqlNode validSqlNode = validator.validateParameterizedExpression(sqlNode, nameToTypeMap); |
| 112 | + SqlToRelConverter converter = |
| 113 | + new SqlToRelConverter( |
| 114 | + null, |
| 115 | + validator, |
| 116 | + catalogReader, |
| 117 | + relOptCluster, |
| 118 | + StandardConvertletTable.INSTANCE, |
| 119 | + converterConfig); |
| 120 | + return converter.convertExpression(validSqlNode, nameToNodeMap); |
| 121 | + } |
| 122 | + |
| 123 | + private Result registerCreateTablesForExtendedExpression(List<String> tables) |
| 124 | + throws SqlParseException { |
| 125 | + Map<String, RelDataType> nameToTypeMap = new LinkedHashMap<>(); |
| 126 | + Map<String, RexNode> nameToNodeMap = new HashMap<>(); |
| 127 | + CalciteSchema rootSchema = CalciteSchema.createRootSchema(false); |
| 128 | + CalciteCatalogReader catalogReader = |
| 129 | + new CalciteCatalogReader(rootSchema, List.of(), factory, config); |
| 130 | + SqlValidator validator = Validator.create(factory, catalogReader, SqlValidator.Config.DEFAULT); |
| 131 | + if (tables != null) { |
| 132 | + for (String tableDef : tables) { |
| 133 | + List<DefinedTable> tList = parseCreateTable(factory, validator, tableDef); |
| 134 | + for (DefinedTable t : tList) { |
| 135 | + rootSchema.add(t.getName(), t); |
| 136 | + for (RelDataTypeField field : t.getRowType(factory).getFieldList()) { |
| 137 | + nameToTypeMap.merge( // to validate the sql expression tree |
| 138 | + field.getName(), |
| 139 | + field.getType(), |
| 140 | + (v1, v2) -> { |
| 141 | + throw new IllegalArgumentException( |
| 142 | + "There is no support for duplicate column names: " + field.getName()); |
| 143 | + }); |
| 144 | + nameToNodeMap.merge( // to convert sql expression into RexNode |
| 145 | + field.getName(), |
| 146 | + new RexInputRef(field.getIndex(), field.getType()), |
| 147 | + (v1, v2) -> { |
| 148 | + throw new IllegalArgumentException( |
| 149 | + "There is no support for duplicate column names: " + field.getName()); |
| 150 | + }); |
| 151 | + } |
| 152 | + } |
| 153 | + } |
| 154 | + } else { |
| 155 | + throw new IllegalArgumentException( |
| 156 | + "Information regarding the data and types must be passed."); |
| 157 | + } |
| 158 | + return new Result(validator, catalogReader, nameToTypeMap, nameToNodeMap); |
| 159 | + } |
| 160 | + |
| 161 | + private NamedStruct toNamedStruct(Map<String, RelDataType> nameToTypeMap) { |
| 162 | + var names = new ArrayList<String>(); |
| 163 | + var types = new ArrayList<Type>(); |
| 164 | + for (Map.Entry<String, RelDataType> entry : nameToTypeMap.entrySet()) { |
| 165 | + String k = entry.getKey(); |
| 166 | + RelDataType v = entry.getValue(); |
| 167 | + names.add(k); |
| 168 | + types.add(TypeConverter.DEFAULT.toSubstrait(v)); |
| 169 | + } |
| 170 | + return NamedStruct.of(names, Type.Struct.builder().fields(types).nullable(false).build()); |
| 171 | + } |
| 172 | +} |
0 commit comments