feat(spark): add support for more types and literals (binary, list, array, intervals, timestamps)
Blizzara committed Oct 24, 2024
1 parent 45d9387 commit 5ae5dbf
Showing 6 changed files with 274 additions and 7 deletions.
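
For orientation (not part of the commit): after this change, Spark literals of the newly supported types can be handled by the converters touched below. A minimal, hypothetical sketch of such literals, assuming only standard Spark Catalyst APIs; names here are illustrative:

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types._

// Values use Spark's internal representations: binary as Array[Byte], timestamps and
// day-time intervals as microsecond Longs, year-month intervals as total month counts.
val newlySupportedLiterals = Seq(
  Literal(Array[Byte](1, 2, 3), BinaryType),                                 // binary
  Literal(1729728000000000L, TimestampType),                                 // timestamp with TZ
  Literal(1729728000000000L, TimestampNTZType),                              // timestamp without TZ
  Literal(86400000000L, DayTimeIntervalType.DEFAULT),                        // 1-day interval
  Literal(14, YearMonthIntervalType.DEFAULT),                                // 1 year 2 months
  Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) // list
)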
50 changes: 46 additions & 4 deletions spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala
@@ -42,11 +42,47 @@ private class ToSparkType

override def visit(expr: Type.Str): DataType = StringType

override def visit(expr: Type.Binary): DataType = BinaryType

override def visit(expr: Type.FixedChar): DataType = StringType

override def visit(expr: Type.VarChar): DataType = StringType

override def visit(expr: Type.Bool): DataType = BooleanType

override def visit(expr: Type.PrecisionTimestamp): DataType = {
if (expr.precision() != Util.MICROSECOND_PRECISION) {
throw new UnsupportedOperationException(
s"Unsupported precision for timestamp: ${expr.precision()}")
}
TimestampNTZType
}
override def visit(expr: Type.PrecisionTimestampTZ): DataType = {
if (expr.precision() != Util.MICROSECOND_PRECISION) {
throw new UnsupportedOperationException(
s"Unsupported precision for timestamp: ${expr.precision()}")
}
TimestampType
}

override def visit(expr: Type.IntervalDay): DataType = {
if (expr.precision() != Util.MICROSECOND_PRECISION) {
throw new UnsupportedOperationException(
s"Unsupported precision for intervalDay: ${expr.precision()}")
}
DayTimeIntervalType.DEFAULT
}
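  // Note: the precision checks above reflect that Spark stores timestamps and day-time
  // intervals with microsecond resolution, so only Substrait precision 6 maps losslessly.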

override def visit(expr: Type.IntervalYear): DataType = YearMonthIntervalType.DEFAULT

override def visit(expr: Type.ListType): DataType =
ArrayType(expr.elementType().accept(this), containsNull = expr.elementType().nullable())

override def visit(expr: Type.Map): DataType =
MapType(
expr.key().accept(this),
expr.value().accept(this),
valueContainsNull = expr.value().nullable())
}
class ToSubstraitType {

@@ -79,10 +115,12 @@ class ToSubstraitType {
case charType: CharType => Some(creator.fixedChar(charType.length))
case varcharType: VarcharType => Some(creator.varChar(varcharType.length))
case StringType => Some(creator.STRING)
case DateType => Some(creator.DATE)
case TimestampType => Some(creator.TIMESTAMP)
case TimestampNTZType => Some(creator.TIMESTAMP_TZ)
case BinaryType => Some(creator.BINARY)
case DateType => Some(creator.DATE)
case TimestampNTZType => Some(creator.precisionTimestamp(Util.MICROSECOND_PRECISION))
case TimestampType => Some(creator.precisionTimestampTZ(Util.MICROSECOND_PRECISION))
case DayTimeIntervalType.DEFAULT => Some(creator.intervalDay(Util.MICROSECOND_PRECISION))
case YearMonthIntervalType.DEFAULT => Some(creator.INTERVAL_YEAR)
case ArrayType(elementType, containsNull) =>
convert(elementType, Seq.empty, containsNull).map(creator.list)
case MapType(keyType, valueType, valueContainsNull) =>
@@ -91,6 +129,10 @@
keyT =>
convert(valueType, Seq.empty, valueContainsNull)
.map(valueT => creator.map(keyT, valueT)))
case StructType(fields) =>
Util
.seqToOption(fields.map(f => convert(f.dataType, f.nullable)))
.map(l => creator.struct(JavaConverters.asJavaIterable(l)))
case _ =>
None
}
@@ -128,7 +170,7 @@ class ToSubstraitType {
)
}

def toAttribute(namedStruct: NamedStruct): Seq[AttributeReference] = {
def toAttributeSeq(namedStruct: NamedStruct): Seq[AttributeReference] = {
namedStruct
.struct()
.fields()
spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala
@@ -19,7 +19,9 @@ package io.substrait.spark.expression
import io.substrait.spark.{DefaultExpressionVisitor, HasOutputStack, SparkExtension, ToSubstraitType}
import io.substrait.spark.logical.ToLogicalPlan

import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, Expression, In, Literal, MakeDecimal, NamedExpression, ScalarSubquery}
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types.Decimal
import org.apache.spark.substrait.SparkTypeUtil
import org.apache.spark.unsafe.types.UTF8String
@@ -28,8 +30,10 @@ import io.substrait.`type`.{StringTypeVisitor, Type}
import io.substrait.{expression => exp}
import io.substrait.expression.{Expression => SExpression}
import io.substrait.util.DecimalUtil
import io.substrait.utils.Util
import io.substrait.utils.Util.SECONDS_PER_DAY

import scala.collection.JavaConverters.asScalaBufferConverter
import scala.collection.JavaConverters.{asScalaBufferConverter, mapAsScalaMapConverter}

class ToSparkExpression(
val scalarFunctionConverter: ToScalarFunction,
@@ -42,7 +46,7 @@ class ToSparkExpression(
Literal.TrueLiteral
} else {
Literal.FalseLiteral
}
}
}
override def visit(expr: SExpression.I32Literal): Expression = {
Literal(expr.value(), ToSubstraitType.convert(expr.getType))
@@ -52,6 +56,10 @@
Literal(expr.value(), ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.FP32Literal): Literal = {
Literal(expr.value(), ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.FP64Literal): Expression = {
Literal(expr.value(), ToSubstraitType.convert(expr.getType))
}
@@ -68,15 +76,58 @@
Literal(UTF8String.fromString(expr.value()), ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.BinaryLiteral): Literal = {
Literal(expr.value().toByteArray, ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.DecimalLiteral): Expression = {
val value = expr.value.toByteArray
val decimal = DecimalUtil.getBigDecimalFromBytes(value, expr.scale, 16)
Literal(Decimal(decimal), ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.DateLiteral): Expression = {
Literal(expr.value(), ToSubstraitType.convert(expr.getType))
}

  override def visit(expr: SExpression.PrecisionTimestampLiteral): Literal = {
    Literal(
      Util.toMicroseconds(expr.value(), expr.precision()),
      ToSubstraitType.convert(expr.getType))
  }

  override def visit(expr: SExpression.PrecisionTimestampTZLiteral): Literal = {
    Literal(
      Util.toMicroseconds(expr.value(), expr.precision()),
      ToSubstraitType.convert(expr.getType))
  }

  override def visit(expr: SExpression.IntervalDayLiteral): Literal = {
    // Spark's DayTimeIntervalType value is a microsecond count, so scale days and seconds
    // by micros-per-second (not by the precision constant) and add the sub-second part.
    val micros = (expr.days() * SECONDS_PER_DAY + expr.seconds()) * Util.MICROS_PER_SECOND +
      Util.toMicroseconds(expr.subseconds(), expr.precision())
    Literal(micros, ToSubstraitType.convert(expr.getType))
  }

override def visit(expr: SExpression.IntervalYearLiteral): Literal = {
    val months = expr.years() * 12 + expr.months()
Literal(months, ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.ListLiteral): Literal = {
val array = expr.values().asScala.map(value => value.accept(this).asInstanceOf[Literal].value)
Literal.create(array, ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.EmptyListLiteral): Expression = {
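    // Literal.default(ArrayType(...)) is expected to yield an empty Spark array literal,
    // matching Substrait's empty-list semantics.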
Literal.default(ToSubstraitType.convert(expr.getType))
}

  override def visit(expr: SExpression.MapLiteral): Literal = {
    val map = expr.values().asScala.map {
      case (key, value) =>
        (key.accept(this).asInstanceOf[Literal].value, value.accept(this).asInstanceOf[Literal].value)
    }
    Literal.create(map, ToSubstraitType.convert(expr.getType))
  }

override def visit(expr: SExpression.EmptyMapLiteral): Literal = {
Literal.default(ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.NullLiteral): Expression = {
Literal(null, ToSubstraitType.convert(expr.getType))
}
@@ -89,6 +140,7 @@
override def visit(expr: exp.FieldReference): Expression = {
withFieldReference(expr)(i => currentOutput(i).clone())
}

override def visit(expr: SExpression.IfThen): Expression = {
val branches = expr
.ifClauses()
spark/src/main/scala/io/substrait/spark/expression/ToSubstraitLiteral.scala
@@ -19,11 +19,15 @@ package io.substrait.spark.expression
import io.substrait.spark.ToSubstraitType

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import io.substrait.expression.{Expression => SExpression}
import io.substrait.expression.ExpressionCreator._
import io.substrait.utils.Util

import scala.collection.JavaConverters

class ToSubstraitLiteral {

@@ -34,6 +38,33 @@
scale: Int): SExpression.Literal =
decimal(false, d.toJavaBigDecimal, precision, scale)

private def sparkArray2Substrait(
arrayData: ArrayData,
elementType: DataType,
containsNull: Boolean): SExpression.Literal = {
val elements = arrayData.array.map(any => apply(Literal(any, elementType)))
if (elements.isEmpty) {
return emptyList(false, ToSubstraitType.convert(elementType, nullable = containsNull).get)
}
list(false, JavaConverters.asJavaIterable(elements)) // TODO: handle containsNull
}

private def sparkMap2Substrait(
mapData: MapData,
keyType: DataType,
valueType: DataType,
valueContainsNull: Boolean): SExpression.Literal = {
val keys = mapData.keyArray().array.map(any => apply(Literal(any, keyType)))
val values = mapData.valueArray().array.map(any => apply(Literal(any, valueType)))
if (keys.isEmpty) {
return emptyMap(
false,
ToSubstraitType.convert(keyType, nullable = false).get,
ToSubstraitType.convert(valueType, nullable = valueContainsNull).get)
}
map(false, JavaConverters.mapAsJavaMap(keys.zip(values).toMap)) // TODO: handle valueContainsNull
}

val _bool: Boolean => SExpression.Literal = bool(false, _)
val _i8: Byte => SExpression.Literal = i8(false, _)
val _i16: Short => SExpression.Literal = i16(false, _)
@@ -43,7 +74,17 @@
val _fp64: Double => SExpression.Literal = fp64(false, _)
val _decimal: (Decimal, Int, Int) => SExpression.Literal = sparkDecimal2Substrait
val _date: Int => SExpression.Literal = date(false, _)
val _timestamp: Long => SExpression.Literal =
precisionTimestamp(false, _, Util.MICROSECOND_PRECISION) // Spark ts is in microseconds
val _timestampTz: Long => SExpression.Literal =
precisionTimestampTZ(false, _, Util.MICROSECOND_PRECISION) // Spark ts is in microseconds
    val _intervalDay: Long => SExpression.Literal = (micros: Long) =>
      intervalDay(false, 0, 0, micros, Util.MICROSECOND_PRECISION)
    val _intervalYear: Int => SExpression.Literal = (months: Int) =>
      intervalYear(false, months / 12, months % 12)
val _string: String => SExpression.Literal = string(false, _)
val _binary: Array[Byte] => SExpression.Literal = binary(false, _)
val _array: (ArrayData, DataType, Boolean) => SExpression.Literal = sparkArray2Substrait
val _map: (MapData, DataType, DataType, Boolean) => SExpression.Literal = sparkMap2Substrait
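    // Spark stores timestamps and day-time intervals internally as microsecond Longs, and
    // year-month intervals as a total month count (Int), hence the Long/Int inputs above.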
}

private def convertWithValue(literal: Literal): Option[SExpression.Literal] = {
@@ -59,7 +100,16 @@
case Literal(d: Decimal, dataType: DecimalType) =>
Nonnull._decimal(d, dataType.precision, dataType.scale)
case Literal(d: Integer, DateType) => Nonnull._date(d)
case Literal(t: Long, TimestampType) => Nonnull._timestampTz(t)
case Literal(t: Long, TimestampNTZType) => Nonnull._timestamp(t)
case Literal(d: Long, DayTimeIntervalType.DEFAULT) => Nonnull._intervalDay(d)
case Literal(ym: Int, YearMonthIntervalType.DEFAULT) => Nonnull._intervalYear(ym)
case Literal(u: UTF8String, StringType) => Nonnull._string(u.toString)
case Literal(b: Array[Byte], BinaryType) => Nonnull._binary(b)
case Literal(a: ArrayData, ArrayType(et, containsNull)) =>
Nonnull._array(a, et, containsNull)
case Literal(m: MapData, MapType(keyType, valueType, valueContainsNull)) =>
Nonnull._map(m, keyType, valueType, valueContainsNull)
case _ => null
}
)
spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala
@@ -252,7 +252,7 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]
}

override def visit(emptyScan: relation.EmptyScan): LogicalPlan = {
LocalRelation(ToSubstraitType.toAttribute(emptyScan.getInitialSchema))
LocalRelation(ToSubstraitType.toAttributeSeq(emptyScan.getInitialSchema))
}
override def visit(namedScan: relation.NamedScan): LogicalPlan = {
resolve(UnresolvedRelation(namedScan.getNames.asScala)) match {
15 changes: 15 additions & 0 deletions spark/src/main/scala/io/substrait/utils/Util.scala
@@ -21,6 +21,21 @@ import scala.collection.mutable.ArrayBuffer

object Util {

  val SECONDS_PER_DAY: Long = 24 * 60 * 60;
  val MICROS_PER_SECOND: Long = 1000 * 1000; // micros in one second, for interval/timestamp math
  val MICROSECOND_PRECISION = 6; // for PrecisionTimestamp(TZ) types

def toMicroseconds(value: Long, precision: Int): Long = {
// Spark uses microseconds as a Long value for most time things
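    // e.g. toMicroseconds(1234L, 3) == 1234000L (millis -> micros);
    //      toMicroseconds(987654321L, 9) == 987654L (nanos truncated to micros)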
val factor = MICROSECOND_PRECISION - precision
if (factor == 0) {
value
} else if (factor > 0) {
value * math.pow(10, factor).toLong
} else {
value / math.pow(10, -factor).toLong
}
}

/**
* Compute the cartesian product for n lists.
*