From bccdaf0940425f9d5415069a44645f46a131f6cd Mon Sep 17 00:00:00 2001
From: Arttu Voutilainen
Date: Wed, 23 Oct 2024 19:03:53 +0200
Subject: [PATCH] feat(spark): add support for more types and literals
 (binary, list, map, struct, intervals, timestamps)

Adds Spark <-> Substrait conversions for more types and literals:

- types: binary, interval day, interval year, list, map and struct
- literals: binary, timestamp, timestamp_tz, list and map

Spark timestamps hold microseconds since the epoch, so they map to
Substrait PrecisionTimestamp(TZ) with microsecond precision; other
precisions are rejected. This also fixes the previously swapped
TimestampType / TimestampNTZType mapping and renames
ToSubstraitType.toAttribute to toAttributeSeq to reflect its return
type.
---
 .../io/substrait/spark/ToSubstraitType.scala  | 45 +++++++++++++++----
 .../spark/expression/ToSparkExpression.scala  | 49 ++++++++++++++++++-
 .../spark/expression/ToSubstraitLiteral.scala | 34 ++++++++++++++
 .../spark/logical/ToLogicalPlan.scala         |  2 +-
 .../main/scala/io/substrait/utils/Util.scala  |  2 +
 5 files changed, 124 insertions(+), 8 deletions(-)

diff --git a/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala b/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala
index d393c327f..c9e57fa9a 100644
--- a/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala
+++ b/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala
@@ -42,9 +42,40 @@ private class ToSparkType
 
   override def visit(expr: Type.Str): DataType = StringType
 
+  override def visit(expr: Type.Binary): DataType = BinaryType
+
   override def visit(expr: Type.FixedChar): DataType = StringType
 
   override def visit(expr: Type.VarChar): DataType = StringType
 
   override def visit(expr: Type.Bool): DataType = BooleanType
+
+  override def visit(expr: Type.PrecisionTimestamp): DataType = {
+    if (expr.precision() != Util.MICROSECOND_PRECISION) {
+      throw new UnsupportedOperationException(
+        s"Unsupported precision for timestamp: ${expr.precision()}")
+    }
+    TimestampNTZType
+  }
+
+  override def visit(expr: Type.PrecisionTimestampTZ): DataType = {
+    if (expr.precision() != Util.MICROSECOND_PRECISION) {
+      throw new UnsupportedOperationException(
+        s"Unsupported precision for timestamp: ${expr.precision()}")
+    }
+    TimestampType
+  }
+
+  override def visit(expr: Type.IntervalDay): DataType = DayTimeIntervalType.DEFAULT
+
+  override def visit(expr: Type.IntervalYear): DataType = YearMonthIntervalType.DEFAULT
+
+  override def visit(expr: Type.ListType): DataType =
+    ArrayType(expr.elementType().accept(this), containsNull = expr.elementType().nullable())
+
+  override def visit(expr: Type.Map): DataType =
+    MapType(
+      expr.key().accept(this),
+      expr.value().accept(this),
+      valueContainsNull = expr.value().nullable())
 }
@@ -79,10 +110,12 @@ class ToSubstraitType {
       case charType: CharType => Some(creator.fixedChar(charType.length))
       case varcharType: VarcharType => Some(creator.varChar(varcharType.length))
       case StringType => Some(creator.STRING)
-      case DateType => Some(creator.DATE)
-      case TimestampType => Some(creator.TIMESTAMP)
-      case TimestampNTZType => Some(creator.TIMESTAMP_TZ)
       case BinaryType => Some(creator.BINARY)
+      case DateType => Some(creator.DATE)
+      case TimestampNTZType => Some(creator.precisionTimestamp(Util.MICROSECOND_PRECISION))
+      case TimestampType => Some(creator.precisionTimestampTZ(Util.MICROSECOND_PRECISION))
+      case DayTimeIntervalType.DEFAULT => Some(creator.INTERVAL_DAY)
+      case YearMonthIntervalType.DEFAULT => Some(creator.INTERVAL_YEAR)
       case ArrayType(elementType, containsNull) =>
         convert(elementType, Seq.empty, containsNull).map(creator.list)
       case MapType(keyType, valueType, valueContainsNull) =>
@@ -91,6 +124,10 @@ class ToSubstraitType {
             keyT =>
               convert(valueType, Seq.empty, valueContainsNull)
                 .map(valueT => creator.map(keyT, valueT)))
+      case StructType(fields) =>
+        Util
+          .seqToOption(fields.map(f => convert(f.dataType, Seq.empty, f.nullable)))
+          .map(l => creator.struct(JavaConverters.asJavaIterable(l)))
       case _ => None
     }
 
@@ -128,7 +165,7 @@ class ToSubstraitType {
     )
   }
 
-  def toAttribute(namedStruct: NamedStruct): Seq[AttributeReference] = {
+  def toAttributeSeq(namedStruct: NamedStruct): Seq[AttributeReference] = {
     namedStruct
       .struct()
       .fields()
diff --git a/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala b/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala
index 54b472d3b..e546e20c4 100644
--- a/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala
+++ b/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala
@@ -18,18 +18,19 @@ package io.substrait.spark.expression
 
 import io.substrait.spark.{DefaultExpressionVisitor, HasOutputStack, SparkExtension, ToSubstraitType}
 import io.substrait.spark.logical.ToLogicalPlan
-
 import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, Expression, In, Literal, MakeDecimal, NamedExpression, ScalarSubquery}
 import org.apache.spark.sql.types.Decimal
 import org.apache.spark.substrait.SparkTypeUtil
 import org.apache.spark.unsafe.types.UTF8String
-
 import io.substrait.`type`.{StringTypeVisitor, Type}
 import io.substrait.{expression => exp}
 import io.substrait.expression.{Expression => SExpression}
 import io.substrait.util.DecimalUtil
+import io.substrait.utils.Util
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
 
-import scala.collection.JavaConverters.asScalaBufferConverter
+import scala.collection.JavaConverters.{asScalaBufferConverter, mapAsScalaMapConverter}
 
 class ToSparkExpression(
   val scalarFunctionConverter: ToScalarFunction,
@@ -68,15 +69,56 @@ class ToSparkExpression(
     Literal(UTF8String.fromString(expr.value()), ToSubstraitType.convert(expr.getType))
   }
 
+  override def visit(expr: SExpression.BinaryLiteral): Expression = {
+    Literal(expr.value().toByteArray, ToSubstraitType.convert(expr.getType))
+  }
+
   override def visit(expr: SExpression.DecimalLiteral): Expression = {
     val value = expr.value.toByteArray
     val decimal = DecimalUtil.getBigDecimalFromBytes(value, expr.scale, 16)
     Literal(Decimal(decimal), ToSubstraitType.convert(expr.getType))
   }
+
   override def visit(expr: SExpression.DateLiteral): Expression = {
     Literal(expr.value(), ToSubstraitType.convert(expr.getType))
   }
 
+  // Spark timestamps are microseconds since the epoch; rescale from the literal's
+  // precision with integer arithmetic to avoid Double rounding on large values.
+  private def toMicros(value: Long, precision: Int): Long = {
+    val diff = Util.MICROSECOND_PRECISION - precision
+    if (diff >= 0) value * math.pow(10, diff).toLong
+    else value / math.pow(10, -diff).toLong
+  }
+
+  override def visit(expr: SExpression.PrecisionTimestampLiteral): Expression = {
+    Literal(toMicros(expr.value(), expr.precision()), ToSubstraitType.convert(expr.getType))
+  }
+
+  override def visit(expr: SExpression.PrecisionTimestampTZLiteral): Expression = {
+    Literal(toMicros(expr.value(), expr.precision()), ToSubstraitType.convert(expr.getType))
+  }
+
+  override def visit(expr: SExpression.ListLiteral): Expression = {
+    // Unwrap each element literal to its Catalyst value, then let
+    // CatalystTypeConverters assemble the ArrayData.
+    val values = expr.values().asScala.map(v => v.accept(this).asInstanceOf[Literal].value)
+    Literal(
+      CatalystTypeConverters.convertToCatalyst(values).asInstanceOf[ArrayData],
+      ToSubstraitType.convert(expr.getType))
+  }
+
+  override def visit(expr: SExpression.MapLiteral): Expression = {
+    val entries = expr.values().asScala.map {
+      case (key, value) =>
+        key.accept(this).asInstanceOf[Literal].value ->
+          value.accept(this).asInstanceOf[Literal].value
+    }.toMap
+    Literal(
+      CatalystTypeConverters.convertToCatalyst(entries).asInstanceOf[MapData],
+      ToSubstraitType.convert(expr.getType))
+  }
+
   override def visit(expr: SExpression.NullLiteral): Expression = {
     Literal(null, ToSubstraitType.convert(expr.getType))
   }
@@ -89,6 +131,7 @@ class ToSparkExpression(
   override def visit(expr: exp.FieldReference): Expression = {
     withFieldReference(expr)(i => currentOutput(i).clone())
   }
+
   override def visit(expr: SExpression.IfThen): Expression = {
     val branches = expr
       .ifClauses()
diff --git a/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitLiteral.scala b/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitLiteral.scala
index 73362e982..0c0e730df 100644
--- a/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitLiteral.scala
+++ b/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitLiteral.scala
@@ -19,11 +19,15 @@ package io.substrait.spark.expression
 
 import io.substrait.spark.ToSubstraitType
 
 import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
 import io.substrait.expression.{Expression => SExpression}
 import io.substrait.expression.ExpressionCreator._
+import io.substrait.utils.Util
+
+import scala.collection.JavaConverters
 
 class ToSubstraitLiteral {
@@ -34,6 +38,22 @@ class ToSubstraitLiteral {
         scale: Int): SExpression.Literal =
       decimal(false, d.toJavaBigDecimal, precision, scale)
 
+    private def sparkArray2Substrait(
+        arrayData: ArrayData,
+        elementType: DataType): SExpression.Literal =
+      list(
+        false,
+        JavaConverters.asJavaIterable(arrayData.array.map(any => apply(Literal(any, elementType)))))
+
+    private def sparkMap2Substrait(
+        mapData: MapData,
+        keyType: DataType,
+        valueType: DataType): SExpression.Literal = {
+      val keys = mapData.keyArray().array.map(any => apply(Literal(any, keyType)))
+      val values = mapData.valueArray().array.map(any => apply(Literal(any, valueType)))
+      map(false, JavaConverters.mapAsJavaMap(keys.zip(values).toMap))
+    }
+
     val _bool: Boolean => SExpression.Literal = bool(false, _)
     val _i8: Byte => SExpression.Literal = i8(false, _)
     val _i16: Short => SExpression.Literal = i16(false, _)
@@ -43,5 +63,13 @@ class ToSubstraitLiteral {
     val _fp64: Double => SExpression.Literal = fp64(false, _)
     val _decimal: (Decimal, Int, Int) => SExpression.Literal = sparkDecimal2Substrait
     val _date: Int => SExpression.Literal = date(false, _)
+    // Spark timestamps are in microseconds, matching MICROSECOND_PRECISION.
+    val _timestamp: Long => SExpression.Literal =
+      precisionTimestamp(false, _, Util.MICROSECOND_PRECISION)
+    val _timestampTz: Long => SExpression.Literal =
+      precisionTimestampTZ(false, _, Util.MICROSECOND_PRECISION)
     val _string: String => SExpression.Literal = string(false, _)
+    val _binary: Array[Byte] => SExpression.Literal = binary(false, _)
+    val _array: (ArrayData, DataType) => SExpression.Literal = sparkArray2Substrait
+    val _map: (MapData, DataType, DataType) => SExpression.Literal = sparkMap2Substrait
   }
@@ -59,7 +87,13 @@ class ToSubstraitLiteral {
         case Literal(d: Decimal, dataType: DecimalType) =>
           Nonnull._decimal(d, dataType.precision, dataType.scale)
         case Literal(d: Integer, DateType) => Nonnull._date(d)
+        case Literal(t: Long, TimestampType) => Nonnull._timestampTz(t)
+        case Literal(t: Long, TimestampNTZType) => Nonnull._timestamp(t)
         case Literal(u: UTF8String, StringType) => Nonnull._string(u.toString)
+        case Literal(b: Array[Byte], BinaryType) => Nonnull._binary(b)
+        case Literal(a: ArrayData, ArrayType(et, _)) => Nonnull._array(a, et) // TODO: handle containsNull
+        case Literal(m: MapData, MapType(keyType, valueType, _)) =>
+          Nonnull._map(m, keyType, valueType) // TODO: handle valueContainsNull
         case _ => null
       }
     )
diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala
index cdc54b2ec..7ec4fe2f4 100644
--- a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala
+++ b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala
@@ -252,8 +252,8 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]
   }
 
   override def visit(emptyScan: relation.EmptyScan): LogicalPlan = {
-    LocalRelation(ToSubstraitType.toAttribute(emptyScan.getInitialSchema))
+    LocalRelation(ToSubstraitType.toAttributeSeq(emptyScan.getInitialSchema))
   }
 
   override def visit(namedScan: relation.NamedScan): LogicalPlan = {
     resolve(UnresolvedRelation(namedScan.getNames.asScala)) match {
diff --git a/spark/src/main/scala/io/substrait/utils/Util.scala b/spark/src/main/scala/io/substrait/utils/Util.scala
index 165d59953..8c6b47a85 100644
--- a/spark/src/main/scala/io/substrait/utils/Util.scala
+++ b/spark/src/main/scala/io/substrait/utils/Util.scala
@@ -21,6 +21,8 @@ import scala.collection.mutable.ArrayBuffer
 
 object Util {
 
+  val MICROSECOND_PRECISION = 6 // Spark timestamps are in microseconds; used for PrecisionTimestamp(TZ)
+
   /**
    * Compute the cartesian product for n lists.
    *
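
Usage sketch (not part of the patch): with these conversions in place, a
Spark array literal should round-trip into a Substrait list literal roughly
as below. This assumes ToSubstraitLiteral exposes the apply(Literal) that
sparkArray2Substrait calls internally; names and visibility may differ in
the surrounding module.

    import io.substrait.spark.expression.ToSubstraitLiteral
    import org.apache.spark.sql.catalyst.expressions.Literal
    import org.apache.spark.sql.types.{ArrayType, IntegerType}

    val toSubstrait = new ToSubstraitLiteral
    // Literal.create converts Seq(1, 2, 3) into Catalyst ArrayData, so the
    // Literal(a: ArrayData, ArrayType(et, _)) case above matches.
    val sparkLit = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
    val substraitLit = toSubstrait.apply(sparkLit) // SExpression.Literal (list of i32)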