diff --git a/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala b/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala
index c9e57fa9a..5d11cbe14 100644
--- a/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala
+++ b/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala
@@ -65,9 +65,15 @@ private class ToSparkType
     TimestampType
   }
 
-  override def visit(expr: Type.IntervalDay): DayTimeIntervalType
+  override def visit(expr: Type.IntervalDay): DataType = {
+    if (expr.precision() != Util.MICROSECOND_PRECISION) {
+      throw new UnsupportedOperationException(
+        s"Unsupported precision for intervalDay: ${expr.precision()}")
+    }
+    DayTimeIntervalType.DEFAULT
+  }
 
-  override def visit(expr: Type.IntervalYear): YearMonthIntervalType
+  override def visit(expr: Type.IntervalYear): DataType = YearMonthIntervalType.DEFAULT
 
   override def visit(expr: Type.ListType): DataType =
     ArrayType(expr.elementType().accept(this), containsNull = expr.elementType().nullable())
@@ -113,7 +119,7 @@ class ToSubstraitType {
       case DateType => Some(creator.DATE)
       case TimestampNTZType => Some(creator.precisionTimestamp(Util.MICROSECOND_PRECISION))
       case TimestampType => Some(creator.precisionTimestampTZ(Util.MICROSECOND_PRECISION))
-      case DayTimeIntervalType.DEFAULT => Some(creator.INTERVAL_DAY)
+      case DayTimeIntervalType.DEFAULT => Some(creator.intervalDay(Util.MICROSECOND_PRECISION))
      case YearMonthIntervalType.DEFAULT => Some(creator.INTERVAL_YEAR)
       case ArrayType(elementType, containsNull) =>
         convert(elementType, Seq.empty, containsNull).map(creator.list)
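A quick sanity sketch of the new type mapping (illustrative snippet, not part of this PR; it assumes substrait-java's TypeCreator.REQUIRED factory):

    import io.substrait.`type`.TypeCreator

    val creator = TypeCreator.REQUIRED
    creator.intervalDay(6) // round-trips to Spark's DayTimeIntervalType.DEFAULT
    creator.intervalDay(3) // ToSparkType.visit now rejects this with UnsupportedOperationException
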
diff --git a/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala b/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala
index e546e20c4..ca69e3012 100644
--- a/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala
+++ b/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala
@@ -18,18 +18,20 @@ package io.substrait.spark.expression
 
 import io.substrait.spark.{DefaultExpressionVisitor, HasOutputStack, SparkExtension, ToSubstraitType}
 import io.substrait.spark.logical.ToLogicalPlan
+
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, Expression, In, Literal, MakeDecimal, NamedExpression, ScalarSubquery}
+import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
 import org.apache.spark.sql.types.Decimal
 import org.apache.spark.substrait.SparkTypeUtil
 import org.apache.spark.unsafe.types.UTF8String
+
 import io.substrait.`type`.{StringTypeVisitor, Type}
 import io.substrait.{expression => exp}
 import io.substrait.expression.{Expression => SExpression}
 import io.substrait.util.DecimalUtil
 import io.substrait.utils.Util
-import io.substrait.utils.Util.MICROSECOND_PRECISION
-import org.apache.spark.sql.catalyst.CatalystTypeConverters
-import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
+import io.substrait.utils.Util.SECONDS_PER_DAY
 
 import scala.collection.JavaConverters.{asScalaBufferConverter, mapAsScalaMapConverter}
 import scala.math.pow
@@ -86,11 +88,20 @@ class ToSparkExpression(
   }
 
   override def visit(expr: SExpression.PrecisionTimestampLiteral): Long = {
-    (expr.value() * pow(10, Util.MICROSECOND_PRECISION - expr.precision())).toLong
+    Util.toMicroseconds(expr.value(), expr.precision())
   }
 
   override def visit(expr: SExpression.PrecisionTimestampTZLiteral): Long = {
-    (expr.value() * pow(10, Util.MICROSECOND_PRECISION - expr.precision())).toLong
+    Util.toMicroseconds(expr.value(), expr.precision())
+  }
+
+  override def visit(expr: SExpression.IntervalDayLiteral): Long = {
+    (expr.days() * SECONDS_PER_DAY + expr.seconds()) * Util.MICROS_PER_SECOND + Util
+      .toMicroseconds(expr.subseconds(), expr.precision())
+  }
+
+  override def visit(expr: SExpression.IntervalYearLiteral): Integer = {
+    expr.years() * 12 + expr.months()
   }
 
   override def visit(expr: SExpression.ListLiteral): ArrayData = {
@@ -99,7 +110,8 @@ class ToSparkExpression(
   }
 
   override def visit(expr: SExpression.MapLiteral): MapData = {
-    val map = expr.values().asScala.map { case (key, value) => (key.accept(this), value.accept(this)) }
+    val map =
+      expr.values().asScala.map { case (key, value) => (key.accept(this), value.accept(this)) }
     CatalystTypeConverters.convertToCatalyst(map).asInstanceOf[MapData]
   }
 
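Worked example of the interval decoding above (illustrative values, not taken from the PR):

    // IntervalDayLiteral(days = 1, seconds = 2, subseconds = 500, precision = 3)
    //   (1 * 86400 + 2) * 1000000   = 86402000000 µs
    //   toMicroseconds(500, 3)      = 500 * 10^(6-3) = 500000 µs
    //   total                       = 86402500000 µs -> Spark's Long for DayTimeIntervalType
    // IntervalYearLiteral(years = 2, months = 3)
    //   2 * 12 + 3                  = 27 months      -> Spark's Int for YearMonthIntervalType
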
diff --git a/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitLiteral.scala b/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitLiteral.scala
index 0c0e730df..9fd0d91ef 100644
--- a/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitLiteral.scala
+++ b/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitLiteral.scala
@@ -38,10 +38,17 @@ class ToSubstraitLiteral {
       scale: Int): SExpression.Literal =
     decimal(false, d.toJavaBigDecimal, precision, scale)
 
-  private def sparkArray2Substrait(arrayData: ArrayData, elementType: DataType): SExpression.Literal =
-    list(false, JavaConverters.asJavaIterable(arrayData.array.map(any => apply(Literal(any, elementType)))))
+  private def sparkArray2Substrait(
+      arrayData: ArrayData,
+      elementType: DataType): SExpression.Literal =
+    list(
+      false,
+      JavaConverters.asJavaIterable(arrayData.array.map(any => apply(Literal(any, elementType)))))
 
-  private def sparkMap2Substrait(mapData: MapData, keyType: DataType, valueType: DataType): SExpression.Literal = {
+  private def sparkMap2Substrait(
+      mapData: MapData,
+      keyType: DataType,
+      valueType: DataType): SExpression.Literal = {
     val keys = mapData.keyArray().array.map(any => apply(Literal(any, keyType)))
     val values = mapData.valueArray().array.map(any => apply(Literal(any, valueType)))
     map(false, JavaConverters.mapAsJavaMap(keys.zip(values).toMap))
@@ -56,8 +63,14 @@ class ToSubstraitLiteral {
     val _fp64: Double => SExpression.Literal = fp64(false, _)
     val _decimal: (Decimal, Int, Int) => SExpression.Literal = sparkDecimal2Substrait
     val _date: Int => SExpression.Literal = date(false, _)
-    val _timestamp: Long => SExpression.Literal = precisionTimestamp(false, _, Util.MICROSECOND_PRECISION) // Spark ts is in microseconds
-    val _timestampTz: Long => SExpression.Literal = precisionTimestampTZ(false, _, Util.MICROSECOND_PRECISION) // Spark ts is in microseconds
+    val _timestamp: Long => SExpression.Literal =
+      precisionTimestamp(false, _, Util.MICROSECOND_PRECISION) // Spark ts is in microseconds
+    val _timestampTz: Long => SExpression.Literal =
+      precisionTimestampTZ(false, _, Util.MICROSECOND_PRECISION) // Spark ts is in microseconds
+    val _intervalDay: Long => SExpression.Literal = (micros: Long) =>
+      intervalDay(false, 0, 0, micros, Util.MICROSECOND_PRECISION)
+    val _intervalYear: Int => SExpression.Literal = (months: Int) =>
+      intervalYear(false, months / 12, months % 12)
     val _string: String => SExpression.Literal = string(false, _)
     val _binary: Array[Byte] => SExpression.Literal = binary(false, _)
     val _array: (ArrayData, DataType) => SExpression.Literal = sparkArray2Substrait
@@ -80,10 +93,14 @@ class ToSubstraitLiteral {
       case Literal(d: Integer, DateType) => Nonnull._date(d)
       case Literal(t: Long, TimestampType) => Nonnull._timestampTz(t)
       case Literal(t: Long, TimestampNTZType) => Nonnull._timestamp(t)
+      case Literal(d: Long, DayTimeIntervalType.DEFAULT) => Nonnull._intervalDay(d)
+      case Literal(ym: Integer, YearMonthIntervalType.DEFAULT) => Nonnull._intervalYear(ym)
       case Literal(u: UTF8String, StringType) => Nonnull._string(u.toString)
       case Literal(b: Array[Byte], BinaryType) => Nonnull._binary(b)
-      case Literal(a: ArrayData, ArrayType(et, _)) => Nonnull._array(a, et) // TODO: handle containsNull
-      case Literal(m: MapData, MapType(keyType, valueType, _)) => Nonnull._map(m, keyType, valueType) // TODO: handle containsNull
+      case Literal(a: ArrayData, ArrayType(et, _)) =>
+        Nonnull._array(a, et) // TODO: handle containsNull
+      case Literal(m: MapData, MapType(keyType, valueType, _)) =>
+        Nonnull._map(m, keyType, valueType) // TODO: handle containsNull
       case _ => null
     }
   )
diff --git a/spark/src/main/scala/io/substrait/utils/Util.scala b/spark/src/main/scala/io/substrait/utils/Util.scala
index 8c6b47a85..c5c4eb2f8 100644
--- a/spark/src/main/scala/io/substrait/utils/Util.scala
+++ b/spark/src/main/scala/io/substrait/utils/Util.scala
@@ -21,8 +21,22 @@
 import scala.collection.mutable.ArrayBuffer
 
 object Util {
 
+  val SECONDS_PER_DAY: Long = 24 * 60 * 60;
+  val MICROS_PER_SECOND: Long = 1000 * 1000;
   val MICROSECOND_PRECISION = 6; // for PrecisionTimestamp(TZ) types
 
+  def toMicroseconds(value: Long, precision: Int): Long = {
+    // Spark stores timestamps and day-time intervals as a Long count of microseconds
+    val factor = MICROSECOND_PRECISION - precision
+    if (factor == 0) {
+      value
+    } else if (factor > 0) {
+      value * math.pow(10, factor).toLong
+    } else {
+      value / math.pow(10, -factor).toLong
+    }
+  }
+
   /**
    * Compute the cartesian product for n lists.
    *
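For reference, the scaling behaviour of the new helper (hypothetical REPL output, derived from the definition above):

    Util.toMicroseconds(1500L, 3)       // == 1500000L : millisecond precision scales up by 10^3
    Util.toMicroseconds(42L, 6)         // == 42L      : already at microsecond precision
    Util.toMicroseconds(1234567890L, 9) // == 1234567L : nanoseconds truncate down to microseconds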