diff --git a/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala b/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala
index d393c327f..962864980 100644
--- a/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala
+++ b/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala
@@ -42,11 +42,47 @@ private class ToSparkType
 
   override def visit(expr: Type.Str): DataType = StringType
 
+  override def visit(expr: Type.Binary): DataType = BinaryType
+
   override def visit(expr: Type.FixedChar): DataType = StringType
 
   override def visit(expr: Type.VarChar): DataType = StringType
 
   override def visit(expr: Type.Bool): DataType = BooleanType
+
+  override def visit(expr: Type.PrecisionTimestamp): DataType = {
+    if (expr.precision() != Util.MICROSECOND_PRECISION) {
+      throw new UnsupportedOperationException(
+        s"Unsupported precision for timestamp: ${expr.precision()}")
+    }
+    TimestampNTZType
+  }
+  override def visit(expr: Type.PrecisionTimestampTZ): DataType = {
+    if (expr.precision() != Util.MICROSECOND_PRECISION) {
+      throw new UnsupportedOperationException(
+        s"Unsupported precision for timestampTZ: ${expr.precision()}")
+    }
+    TimestampType
+  }
+
+  override def visit(expr: Type.IntervalDay): DataType = {
+    if (expr.precision() != Util.MICROSECOND_PRECISION) {
+      throw new UnsupportedOperationException(
+        s"Unsupported precision for intervalDay: ${expr.precision()}")
+    }
+    DayTimeIntervalType.DEFAULT
+  }
+
+  override def visit(expr: Type.IntervalYear): DataType = YearMonthIntervalType.DEFAULT
+
+  override def visit(expr: Type.ListType): DataType =
+    ArrayType(expr.elementType().accept(this), containsNull = expr.elementType().nullable())
+
+  override def visit(expr: Type.Map): DataType =
+    MapType(
+      expr.key().accept(this),
+      expr.value().accept(this),
+      valueContainsNull = expr.value().nullable())
 }
 
 class ToSubstraitType {
@@ -79,10 +115,12 @@ class ToSubstraitType {
       case charType: CharType => Some(creator.fixedChar(charType.length))
       case varcharType: VarcharType => Some(creator.varChar(varcharType.length))
      case StringType => Some(creator.STRING)
-      case DateType => Some(creator.DATE)
-      case TimestampType => Some(creator.TIMESTAMP)
-      case TimestampNTZType => Some(creator.TIMESTAMP_TZ)
       case BinaryType => Some(creator.BINARY)
+      case DateType => Some(creator.DATE)
+      case TimestampNTZType => Some(creator.precisionTimestamp(Util.MICROSECOND_PRECISION))
+      case TimestampType => Some(creator.precisionTimestampTZ(Util.MICROSECOND_PRECISION))
+      case DayTimeIntervalType.DEFAULT => Some(creator.intervalDay(Util.MICROSECOND_PRECISION))
+      case YearMonthIntervalType.DEFAULT => Some(creator.INTERVAL_YEAR)
       case ArrayType(elementType, containsNull) =>
         convert(elementType, Seq.empty, containsNull).map(creator.list)
       case MapType(keyType, valueType, valueContainsNull) =>
@@ -128,7 +166,7 @@ class ToSubstraitType {
     )
   }
 
-  def toAttribute(namedStruct: NamedStruct): Seq[AttributeReference] = {
+  def toAttributeSeq(namedStruct: NamedStruct): Seq[AttributeReference] = {
     namedStruct
       .struct()
       .fields()
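
A minimal sketch of the round trip these visitors enable, using the two `convert` overloads exactly as the new TypesAndLiteralsSuite below exercises them (the TimestampNTZType value is just an illustrative pick):

    // Spark -> Substrait: yields a Type.PrecisionTimestamp with precision 6
    val substraitType = ToSubstraitType.convert(TimestampNTZType, nullable = true).get
    // Substrait -> Spark: the PrecisionTimestamp visitor above maps it back
    val sparkType = ToSubstraitType.convert(substraitType) // == TimestampNTZType
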
diff --git a/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala b/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala
index 54b472d3b..8bde45116 100644
--- a/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala
+++ b/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala
@@ -28,8 +28,9 @@ import io.substrait.`type`.{StringTypeVisitor, Type}
 import io.substrait.{expression => exp}
 import io.substrait.expression.{Expression => SExpression}
 import io.substrait.util.DecimalUtil
+import io.substrait.utils.Util
 
-import scala.collection.JavaConverters.asScalaBufferConverter
+import scala.collection.JavaConverters.{asScalaBufferConverter, mapAsScalaMapConverter}
 
 class ToSparkExpression(
     val scalarFunctionConverter: ToScalarFunction,
@@ -52,6 +53,10 @@ class ToSparkExpression(
     Literal(expr.value(), ToSubstraitType.convert(expr.getType))
   }
 
+  override def visit(expr: SExpression.FP32Literal): Literal = {
+    Literal(expr.value(), ToSubstraitType.convert(expr.getType))
+  }
+
   override def visit(expr: SExpression.FP64Literal): Expression = {
     Literal(expr.value(), ToSubstraitType.convert(expr.getType))
   }
@@ -68,15 +73,69 @@ class ToSparkExpression(
     Literal(UTF8String.fromString(expr.value()), ToSubstraitType.convert(expr.getType))
   }
 
+  override def visit(expr: SExpression.BinaryLiteral): Literal = {
+    Literal(expr.value().toByteArray, ToSubstraitType.convert(expr.getType))
+  }
+
   override def visit(expr: SExpression.DecimalLiteral): Expression = {
     val value = expr.value.toByteArray
     val decimal = DecimalUtil.getBigDecimalFromBytes(value, expr.scale, 16)
     Literal(Decimal(decimal), ToSubstraitType.convert(expr.getType))
   }
+
   override def visit(expr: SExpression.DateLiteral): Expression = {
     Literal(expr.value(), ToSubstraitType.convert(expr.getType))
   }
+
+  override def visit(expr: SExpression.PrecisionTimestampLiteral): Literal = {
+    Literal(
+      Util.toMicroseconds(expr.value(), expr.precision()),
+      ToSubstraitType.convert(expr.getType))
+  }
+
+  override def visit(expr: SExpression.PrecisionTimestampTZLiteral): Literal = {
+    Literal(
+      Util.toMicroseconds(expr.value(), expr.precision()),
+      ToSubstraitType.convert(expr.getType))
+  }
+
+  override def visit(expr: SExpression.IntervalDayLiteral): Literal = {
+    // days and seconds scale to whole microseconds; subseconds carry their own precision
+    val micros =
+      (expr.days() * Util.SECONDS_PER_DAY + expr.seconds()) * Util.MICROS_PER_SECOND +
+        Util.toMicroseconds(expr.subseconds(), expr.precision())
+    Literal(micros, ToSubstraitType.convert(expr.getType))
+  }
+
+  override def visit(expr: SExpression.IntervalYearLiteral): Literal = {
+    val months = expr.years() * 12 + expr.months()
+    Literal(months, ToSubstraitType.convert(expr.getType))
+  }
+
+  override def visit(expr: SExpression.ListLiteral): Literal = {
+    val array = expr.values().asScala.map(value => value.accept(this).asInstanceOf[Literal].value)
+    Literal.create(array, ToSubstraitType.convert(expr.getType))
+  }
+
+  override def visit(expr: SExpression.EmptyListLiteral): Expression = {
+    Literal.default(ToSubstraitType.convert(expr.getType))
+  }
+
+  override def visit(expr: SExpression.MapLiteral): Literal = {
+    val map = expr.values().asScala.map {
+      case (key, value) =>
+        (
+          key.accept(this).asInstanceOf[Literal].value,
+          value.accept(this).asInstanceOf[Literal].value
+        )
+    }
+    Literal.create(map, ToSubstraitType.convert(expr.getType))
+  }
+
+  override def visit(expr: SExpression.EmptyMapLiteral): Literal = {
+    Literal.default(ToSubstraitType.convert(expr.getType))
+  }
+
   override def visit(expr: SExpression.NullLiteral): Expression = {
     Literal(null, ToSubstraitType.convert(expr.getType))
   }
@@ -89,6 +148,7 @@ class ToSparkExpression(
   override def visit(expr: exp.FieldReference): Expression = {
     withFieldReference(expr)(i => currentOutput(i).clone())
   }
+
   override def visit(expr: SExpression.IfThen): Expression = {
     val branches = expr
       .ifClauses()
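
To make the interval-day decoding concrete, a hypothetical IntervalDayLiteral of 1 day, 2 seconds and 500 subseconds at precision 3 (milliseconds) lands in Spark's single microseconds Long like this (SECONDS_PER_DAY and MICROS_PER_SECOND are introduced in the Util.scala hunk below):

    // (days * SECONDS_PER_DAY + seconds) * MICROS_PER_SECOND + toMicroseconds(subseconds, precision)
    val micros = (1L * 86400 + 2) * 1000000L + 500L * 1000 // = 86402500000
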
diff --git a/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitLiteral.scala b/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitLiteral.scala
index 73362e982..f5bdfc918 100644
--- a/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitLiteral.scala
+++ b/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitLiteral.scala
@@ -19,11 +19,15 @@ package io.substrait.spark.expression
 
 import io.substrait.spark.ToSubstraitType
 
 import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
 import io.substrait.expression.{Expression => SExpression}
 import io.substrait.expression.ExpressionCreator._
+import io.substrait.utils.Util
+
+import scala.collection.JavaConverters
 
 class ToSubstraitLiteral {
@@ -34,6 +38,34 @@ class ToSubstraitLiteral {
         scale: Int): SExpression.Literal =
       decimal(false, d.toJavaBigDecimal, precision, scale)
 
+    private def sparkArray2Substrait(
+        arrayData: ArrayData,
+        elementType: DataType,
+        containsNull: Boolean): SExpression.Literal = {
+      val elements = arrayData.array.map(any => apply(Literal(any, elementType)))
+      if (elements.isEmpty) {
+        return emptyList(false, ToSubstraitType.convert(elementType, nullable = containsNull).get)
+      }
+      list(false, JavaConverters.asJavaIterable(elements)) // TODO: handle containsNull
+    }
+
+    private def sparkMap2Substrait(
+        mapData: MapData,
+        keyType: DataType,
+        valueType: DataType,
+        valueContainsNull: Boolean): SExpression.Literal = {
+      val keys = mapData.keyArray().array.map(any => apply(Literal(any, keyType)))
+      val values = mapData.valueArray().array.map(any => apply(Literal(any, valueType)))
+      if (keys.isEmpty) {
+        return emptyMap(
+          false,
+          ToSubstraitType.convert(keyType, nullable = false).get,
+          ToSubstraitType.convert(valueType, nullable = valueContainsNull).get)
+      }
+      // TODO: handle valueContainsNull
+      map(false, JavaConverters.mapAsJavaMap(keys.zip(values).toMap))
+    }
+
     val _bool: Boolean => SExpression.Literal = bool(false, _)
     val _i8: Byte => SExpression.Literal = i8(false, _)
     val _i16: Short => SExpression.Literal = i16(false, _)
@@ -43,7 +75,17 @@ class ToSubstraitLiteral {
     val _fp64: Double => SExpression.Literal = fp64(false, _)
     val _decimal: (Decimal, Int, Int) => SExpression.Literal = sparkDecimal2Substrait
     val _date: Int => SExpression.Literal = date(false, _)
+    val _timestamp: Long => SExpression.Literal =
+      precisionTimestamp(false, _, Util.MICROSECOND_PRECISION) // Spark timestamps are in microseconds
+    val _timestampTz: Long => SExpression.Literal =
+      precisionTimestampTZ(false, _, Util.MICROSECOND_PRECISION) // Spark timestamps are in microseconds
+    val _intervalDay: Long => SExpression.Literal = (micros: Long) =>
+      intervalDay(false, 0, 0, micros, Util.MICROSECOND_PRECISION)
+    val _intervalYear: Int => SExpression.Literal = (m: Int) => intervalYear(false, m / 12, m % 12)
     val _string: String => SExpression.Literal = string(false, _)
+    val _binary: Array[Byte] => SExpression.Literal = binary(false, _)
+    val _array: (ArrayData, DataType, Boolean) => SExpression.Literal = sparkArray2Substrait
+    val _map: (MapData, DataType, DataType, Boolean) => SExpression.Literal = sparkMap2Substrait
   }
 
   private def convertWithValue(literal: Literal): Option[SExpression.Literal] = {
@@ -59,7 +101,16 @@ class ToSubstraitLiteral {
       case Literal(d: Decimal, dataType: DecimalType) =>
         Nonnull._decimal(d, dataType.precision, dataType.scale)
       case Literal(d: Integer, DateType) => Nonnull._date(d)
+      case Literal(t: Long, TimestampType) => Nonnull._timestampTz(t)
+      case Literal(t: Long, TimestampNTZType) => Nonnull._timestamp(t)
+      case Literal(d: Long, DayTimeIntervalType.DEFAULT) => Nonnull._intervalDay(d)
+      case Literal(ym: Int, YearMonthIntervalType.DEFAULT) => Nonnull._intervalYear(ym)
       case Literal(u: UTF8String, StringType) => Nonnull._string(u.toString)
+      case Literal(b: Array[Byte], BinaryType) => Nonnull._binary(b)
+      case Literal(a: ArrayData, ArrayType(et, containsNull)) =>
+        Nonnull._array(a, et, containsNull)
+      case Literal(m: MapData, MapType(keyType, valueType, valueContainsNull)) =>
+        Nonnull._map(m, keyType, valueType, valueContainsNull)
       case _ => null
     }
   )
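
A short usage sketch of this converter (the `convert` entry point and the Duration handling are what the test suite below relies on; the three-hour value is illustrative). A Spark DayTimeIntervalType literal already stores its value as a microseconds Long, which is why `_intervalDay` passes days = 0, seconds = 0 and the whole value as subseconds at precision 6:

    val lit = Literal(java.time.Duration.ofHours(3)) // DayTimeIntervalType, value 10800000000L
    val substraitLit = ToSubstraitLiteral.convert(lit).get // intervalDay(0, 0, 10800000000, 6)
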
diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala
index cdc54b2ec..7ec4fe2f4 100644
--- a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala
+++ b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala
@@ -252,7 +252,7 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]
   }
 
   override def visit(emptyScan: relation.EmptyScan): LogicalPlan = {
-    LocalRelation(ToSubstraitType.toAttribute(emptyScan.getInitialSchema))
+    LocalRelation(ToSubstraitType.toAttributeSeq(emptyScan.getInitialSchema))
   }
   override def visit(namedScan: relation.NamedScan): LogicalPlan = {
     resolve(UnresolvedRelation(namedScan.getNames.asScala)) match {
diff --git a/spark/src/main/scala/io/substrait/utils/Util.scala b/spark/src/main/scala/io/substrait/utils/Util.scala
index 165d59953..cfaf12bfd 100644
--- a/spark/src/main/scala/io/substrait/utils/Util.scala
+++ b/spark/src/main/scala/io/substrait/utils/Util.scala
@@ -21,6 +21,23 @@ import scala.collection.mutable.ArrayBuffer
 
 object Util {
 
+  val SECONDS_PER_DAY: Long = 24 * 60 * 60
+  val MICROS_PER_SECOND: Long = 1000 * 1000
+  val MICROSECOND_PRECISION = 6 // for PrecisionTimestamp(TZ) types
+
+  def toMicroseconds(value: Long, precision: Int): Long = {
+    // Spark represents most time values as a Long count of microseconds
+    val factor = MICROSECOND_PRECISION - precision
+    // Branch on the sign so that `value` itself never passes through floating-point math
+    if (factor == 0) {
+      value
+    } else if (factor > 0) {
+      value * math.pow(10, factor).toLong
+    } else {
+      value / math.pow(10, -factor).toLong
+    }
+  }
+
   /**
    * Compute the cartesian product for n lists.
    *
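
A few illustrative calls showing the scaling (the values are hypothetical):

    Util.toMicroseconds(1500L, 3) // millisecond precision: 1500 * 10^3 = 1500000
    Util.toMicroseconds(42L, 6) // already microseconds: returned unchanged
    Util.toMicroseconds(2500000L, 9) // nanosecond precision: 2500000 / 10^3 = 2500
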
diff --git a/spark/src/test/scala/io/substrait/spark/TypesAndLiteralsSuite.scala b/spark/src/test/scala/io/substrait/spark/TypesAndLiteralsSuite.scala
new file mode 100644
index 000000000..61d28201f
--- /dev/null
+++ b/spark/src/test/scala/io/substrait/spark/TypesAndLiteralsSuite.scala
@@ -0,0 +1,116 @@
+package io.substrait.spark
+
+import io.substrait.spark.expression.{ToSparkExpression, ToSubstraitLiteral}
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.util.MapData
+import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, DataType, DayTimeIntervalType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, StringType, TimestampNTZType, TimestampType, YearMonthIntervalType}
+import org.apache.spark.substrait.SparkTypeUtil
+import org.apache.spark.unsafe.types.UTF8String
+
+import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period}
+
+class TypesAndLiteralsSuite extends SparkFunSuite {
+
+  val toSparkExpression = new ToSparkExpression(null, null)
+
+  val types: Seq[DataType] = List(
+    IntegerType,
+    LongType,
+    FloatType,
+    DoubleType,
+    StringType,
+    BinaryType,
+    BooleanType,
+    DecimalType(10, 2),
+    TimestampNTZType,
+    TimestampType,
+    DayTimeIntervalType.DEFAULT,
+    YearMonthIntervalType.DEFAULT,
+    ArrayType(IntegerType, containsNull = false),
+    ArrayType(IntegerType, containsNull = true),
+    MapType(IntegerType, StringType, valueContainsNull = false),
+    MapType(IntegerType, StringType, valueContainsNull = true)
+  )
+
+  types.foreach(
+    t => {
+      test(s"test type: $t") {
+        // Nullability doesn't matter here: in Spark it is not a property of the type
+        val substraitType = ToSubstraitType.convert(t, nullable = true).get
+        val sparkType = ToSubstraitType.convert(substraitType)
+
+        println("Before: " + t)
+        println("After: " + sparkType)
+        println("Substrait: " + substraitType)
+
+        assert(t == sparkType)
+      }
+    })
+
+  val defaultLiterals: Seq[Literal] = types.map(Literal.default)
+
+  val literals: Seq[Literal] = List(
+    Literal(1),
+    Literal(1L),
+    Literal(1.0f),
+    Literal(1.0),
+    Literal("1"),
+    Literal(Array[Byte](1)),
+    Literal(true),
+    Literal(BigDecimal("123.4567890")),
+    Literal(Instant.now()), // Timestamp
+    Literal(LocalDateTime.now()), // TimestampNTZ
+    Literal(LocalDate.now()), // Date
+    Literal(Duration.ofDays(1)), // DayTimeInterval
+    Literal(
+      Duration.ofDays(1).plusHours(2).plusMinutes(3).plusSeconds(4).plusMillis(5)
+    ), // DayTimeInterval
+    Literal(Period.ofYears(1)), // YearMonthInterval
+    Literal(Period.of(1, 2, 0)), // YearMonthInterval, days are ignored
+    Literal.create(Array(1, 2, 3), ArrayType(IntegerType, containsNull = false))
+// Literal.create(Array(1, null, 3), ArrayType(IntegerType, containsNull = true)) // TODO: handle containsNull
+  )
+
+  (defaultLiterals ++ literals).foreach(
+    l => {
+      test(s"test literal: $l (${l.dataType})") {
+        val substraitLiteral = ToSubstraitLiteral.convert(l).get
+        val sparkLiteral = substraitLiteral.accept(toSparkExpression).asInstanceOf[Literal]
+
+        println("Before: " + l + " " + l.dataType)
+        println("After: " + sparkLiteral + " " + sparkLiteral.dataType)
+        println("Substrait: " + substraitLiteral)
+
+        assert(l.dataType == sparkLiteral.dataType) // makes understanding failures easier
+        assert(l == sparkLiteral)
+      }
+    })
+
+  test("test map literal") {
+    val l = Literal.create(
+      Map(1 -> "a", 2 -> "b"),
+      MapType(IntegerType, StringType, valueContainsNull = false))
+
+    val substraitLiteral = ToSubstraitLiteral.convert(l).get
+    val sparkLiteral = substraitLiteral.accept(toSparkExpression).asInstanceOf[Literal]
+
+    println("Before: " + l + " " + l.dataType)
+    println("After: " + sparkLiteral + " " + sparkLiteral.dataType)
+    println("Substrait: " + substraitLiteral)
+
+    assert(l.dataType == sparkLiteral.dataType) // makes understanding failures easier
+    assert(SparkTypeUtil.sameType(l.dataType, sparkLiteral.dataType))
+
+    // MapData doesn't implement equality, so we have to compare the arrays manually
+    val originalKeys = l.value.asInstanceOf[MapData].keyArray().toIntArray().sorted
+    val sparkKeys = sparkLiteral.value.asInstanceOf[MapData].keyArray().toIntArray().sorted
+    assert(originalKeys.sameElements(sparkKeys))
+
+    val originalValues = l.value.asInstanceOf[MapData].valueArray().toArray[UTF8String](StringType)
+    val sparkValues =
+      sparkLiteral.value.asInstanceOf[MapData].valueArray().toArray[UTF8String](StringType)
+    assert(originalValues.sorted.sameElements(sparkValues.sorted))
+  }
+}
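
Assuming the Gradle module layout the file paths suggest (the `:spark` module name is an assumption), the new suite can presumably be run on its own with something like:

    ./gradlew :spark:test --tests "io.substrait.spark.TypesAndLiteralsSuite"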