
feat(spark): add support for more types and literals (binary, list, array, intervals, timestamps)

Blizzara committed Oct 23, 2024
1 parent ac0b7d1 commit bccdaf0
Showing 5 changed files with 97 additions and 8 deletions.
44 changes: 40 additions & 4 deletions spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala
@@ -42,11 +42,41 @@ private class ToSparkType

override def visit(expr: Type.Str): DataType = StringType

+override def visit(expr: Type.Binary): DataType = BinaryType
+
override def visit(expr: Type.FixedChar): DataType = StringType

override def visit(expr: Type.VarChar): DataType = StringType

override def visit(expr: Type.Bool): DataType = BooleanType

+override def visit(expr: Type.PrecisionTimestamp): DataType = {
+if (expr.precision() != Util.MICROSECOND_PRECISION) {
+throw new UnsupportedOperationException(
+s"Unsupported precision for timestamp: ${expr.precision()}")
+}
+TimestampNTZType
+}
+
+override def visit(expr: Type.PrecisionTimestampTZ): DataType = {
+if (expr.precision() != Util.MICROSECOND_PRECISION) {
+throw new UnsupportedOperationException(
+s"Unsupported precision for timestamp: ${expr.precision()}")
+}
+TimestampType
+}
+
+override def visit(expr: Type.IntervalDay): DataType = DayTimeIntervalType.DEFAULT
+
+override def visit(expr: Type.IntervalYear): DataType = YearMonthIntervalType.DEFAULT
+
+override def visit(expr: Type.ListType): DataType =
+ArrayType(expr.elementType().accept(this), containsNull = expr.elementType().nullable())
+
+override def visit(expr: Type.Map): DataType =
+MapType(
+expr.key().accept(this),
+expr.value().accept(this),
+valueContainsNull = expr.value().nullable())
}
class ToSubstraitType {

@@ -79,10 +109,12 @@ class ToSubstraitType {
case charType: CharType => Some(creator.fixedChar(charType.length))
case varcharType: VarcharType => Some(creator.varChar(varcharType.length))
case StringType => Some(creator.STRING)
-case DateType => Some(creator.DATE)
-case TimestampType => Some(creator.TIMESTAMP)
-case TimestampNTZType => Some(creator.TIMESTAMP_TZ)
+case BinaryType => Some(creator.BINARY)
+case DateType => Some(creator.DATE)
+case TimestampNTZType => Some(creator.precisionTimestamp(Util.MICROSECOND_PRECISION))
+case TimestampType => Some(creator.precisionTimestampTZ(Util.MICROSECOND_PRECISION))
+case DayTimeIntervalType.DEFAULT => Some(creator.INTERVAL_DAY)
+case YearMonthIntervalType.DEFAULT => Some(creator.INTERVAL_YEAR)
case ArrayType(elementType, containsNull) =>
convert(elementType, Seq.empty, containsNull).map(creator.list)
case MapType(keyType, valueType, valueContainsNull) =>
@@ -91,6 +123,10 @@
keyT =>
convert(valueType, Seq.empty, valueContainsNull)
.map(valueT => creator.map(keyT, valueT)))
+case StructType(fields) =>
+Util
+.seqToOption(fields.map(f => convert(f.dataType, f.nullable)))
+.map(l => creator.struct(JavaConverters.asJavaIterable(l)))
case _ =>
None
}
Expand Down Expand Up @@ -128,7 +164,7 @@ class ToSubstraitType {
)
}

-def toAttribute(namedStruct: NamedStruct): Seq[AttributeReference] = {
+def toAttributeSeq(namedStruct: NamedStruct): Seq[AttributeReference] = {
namedStruct
.struct()
.fields()
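With these changes the two directions finally agree on timestamp semantics: the zone-less Spark type pairs with precision_timestamp<6> and the zoned type with precision_timestamp_tz<6>. A minimal round-trip sketch, assuming ToSubstraitType's companion object exposes the two convert overloads used elsewhere in this commit (Spark-to-Substrait returning Option[Type], Substrait-to-Spark returning DataType):

import io.substrait.spark.ToSubstraitType
import org.apache.spark.sql.types.{TimestampNTZType, TimestampType}

// Spark -> Substrait: previously TimestampType mapped to the zone-less
// Substrait TIMESTAMP and TimestampNTZType to TIMESTAMP_TZ; the new mapping
// fixes that inversion and carries microsecond precision explicitly.
val ntz = ToSubstraitType.convert(TimestampNTZType, false) // Some(precision_timestamp<6>)
val tz = ToSubstraitType.convert(TimestampType, false) // Some(precision_timestamp_tz<6>)

// Substrait -> Spark: only precision 6 round-trips; ToSparkType throws
// UnsupportedOperationException for any other precision.
ntz.foreach(t => assert(ToSubstraitType.convert(t) == TimestampNTZType))
tz.foreach(t => assert(ToSubstraitType.convert(t) == TimestampType))
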
spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala
@@ -18,18 +18,21 @@ package io.substrait.spark.expression

import io.substrait.spark.{DefaultExpressionVisitor, HasOutputStack, SparkExtension, ToSubstraitType}
import io.substrait.spark.logical.ToLogicalPlan

import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, Expression, In, Literal, MakeDecimal, NamedExpression, ScalarSubquery}
import org.apache.spark.sql.types.Decimal
import org.apache.spark.substrait.SparkTypeUtil
import org.apache.spark.unsafe.types.UTF8String

import io.substrait.`type`.{StringTypeVisitor, Type}
import io.substrait.{expression => exp}
import io.substrait.expression.{Expression => SExpression}
import io.substrait.util.DecimalUtil
+import io.substrait.utils.Util
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}

-import scala.collection.JavaConverters.asScalaBufferConverter
+import scala.collection.JavaConverters.{asScalaBufferConverter, mapAsScalaMapConverter}
+import scala.math.pow

class ToSparkExpression(
val scalarFunctionConverter: ToScalarFunction,
@@ -68,15 +71,38 @@ class ToSparkExpression(
Literal(UTF8String.fromString(expr.value()), ToSubstraitType.convert(expr.getType))
}

+override def visit(expr: SExpression.BinaryLiteral): Expression = {
+Literal(expr.value().toByteArray, ToSubstraitType.convert(expr.getType))
+}
+
override def visit(expr: SExpression.DecimalLiteral): Expression = {
val value = expr.value.toByteArray
val decimal = DecimalUtil.getBigDecimalFromBytes(value, expr.scale, 16)
Literal(Decimal(decimal), ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.DateLiteral): Expression = {
Literal(expr.value(), ToSubstraitType.convert(expr.getType))
}

+override def visit(expr: SExpression.PrecisionTimestampLiteral): Expression = {
+// Spark timestamps are in microseconds; rescale from the literal's precision
+Literal(
+(expr.value() * pow(10, Util.MICROSECOND_PRECISION - expr.precision())).toLong,
+ToSubstraitType.convert(expr.getType))
+}
+
+override def visit(expr: SExpression.PrecisionTimestampTZLiteral): Expression = {
+Literal(
+(expr.value() * pow(10, Util.MICROSECOND_PRECISION - expr.precision())).toLong,
+ToSubstraitType.convert(expr.getType))
+}
+
+override def visit(expr: SExpression.ListLiteral): Expression = {
+val values = expr.values().asScala.map(value => value.accept(this).asInstanceOf[Literal].value)
+Literal(
+CatalystTypeConverters.convertToCatalyst(values).asInstanceOf[ArrayData],
+ToSubstraitType.convert(expr.getType))
+}
+
+override def visit(expr: SExpression.MapLiteral): Expression = {
+val map = expr.values().asScala.map {
+case (key, value) =>
+(key.accept(this).asInstanceOf[Literal].value, value.accept(this).asInstanceOf[Literal].value)
+}.toMap
+Literal(
+CatalystTypeConverters.convertToCatalyst(map).asInstanceOf[MapData],
+ToSubstraitType.convert(expr.getType))
+}
+
override def visit(expr: SExpression.NullLiteral): Expression = {
Literal(null, ToSubstraitType.convert(expr.getType))
}
@@ -89,6 +115,7 @@
override def visit(expr: exp.FieldReference): Expression = {
withFieldReference(expr)(i => currentOutput(i).clone())
}

override def visit(expr: SExpression.IfThen): Expression = {
val branches = expr
.ifClauses()
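Both precision-timestamp literal visitors rescale the Substrait value into Spark's microsecond resolution, multiplying by 10^(6 - p) for a literal of precision p. A worked example of the arithmetic (the numbers are illustrative, not from the commit):

import scala.math.pow

val precision = 3 // a Substrait literal stored in milliseconds
val value = 1700000000123L // 2023-11-14T22:13:20.123 UTC, in milliseconds

// 10^(6 - 3) = 1000, so the millisecond count becomes a microsecond count
val micros = (value * pow(10, 6 - precision)).toLong
assert(micros == 1700000000123000L)
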
spark/src/main/scala/io/substrait/spark/expression/ToSubstraitLiteral.scala
@@ -19,11 +19,15 @@ package io.substrait.spark.expression
import io.substrait.spark.ToSubstraitType

import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import io.substrait.expression.{Expression => SExpression}
import io.substrait.expression.ExpressionCreator._
+import io.substrait.utils.Util

+import scala.collection.JavaConverters

class ToSubstraitLiteral {

@@ -34,6 +38,15 @@
scale: Int): SExpression.Literal =
decimal(false, d.toJavaBigDecimal, precision, scale)

+private def sparkArray2Substrait(arrayData: ArrayData, elementType: DataType): SExpression.Literal =
+list(false, JavaConverters.asJavaIterable(arrayData.array.map(any => apply(Literal(any, elementType)))))
+
+private def sparkMap2Substrait(mapData: MapData, keyType: DataType, valueType: DataType): SExpression.Literal = {
+val keys = mapData.keyArray().array.map(any => apply(Literal(any, keyType)))
+val values = mapData.valueArray().array.map(any => apply(Literal(any, valueType)))
+map(false, JavaConverters.mapAsJavaMap(keys.zip(values).toMap))
+}
+
val _bool: Boolean => SExpression.Literal = bool(false, _)
val _i8: Byte => SExpression.Literal = i8(false, _)
val _i16: Short => SExpression.Literal = i16(false, _)
Expand All @@ -43,7 +56,13 @@ class ToSubstraitLiteral {
val _fp64: Double => SExpression.Literal = fp64(false, _)
val _decimal: (Decimal, Int, Int) => SExpression.Literal = sparkDecimal2Substrait
val _date: Int => SExpression.Literal = date(false, _)
+val _timestamp: Long => SExpression.Literal = precisionTimestamp(false, _, Util.MICROSECOND_PRECISION) // Spark ts is in microseconds
+val _timestampTz: Long => SExpression.Literal = precisionTimestampTZ(false, _, Util.MICROSECOND_PRECISION) // Spark ts is in microseconds
val _string: String => SExpression.Literal = string(false, _)
+val _binary: Array[Byte] => SExpression.Literal = binary(false, _)
+val _array: (ArrayData, DataType) => SExpression.Literal = sparkArray2Substrait
+val _map: (MapData, DataType, DataType) => SExpression.Literal = sparkMap2Substrait

}

private def convertWithValue(literal: Literal): Option[SExpression.Literal] = {
@@ -59,7 +78,12 @@
case Literal(d: Decimal, dataType: DecimalType) =>
Nonnull._decimal(d, dataType.precision, dataType.scale)
case Literal(d: Integer, DateType) => Nonnull._date(d)
+case Literal(t: Long, TimestampType) => Nonnull._timestampTz(t)
+case Literal(t: Long, TimestampNTZType) => Nonnull._timestamp(t)
case Literal(u: UTF8String, StringType) => Nonnull._string(u.toString)
+case Literal(b: Array[Byte], BinaryType) => Nonnull._binary(b)
+case Literal(a: ArrayData, ArrayType(et, _)) => Nonnull._array(a, et) // TODO: handle containsNull
+case Literal(m: MapData, MapType(keyType, valueType, _)) => Nonnull._map(m, keyType, valueType) // TODO: handle containsNull
case _ => null
}
)
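In the opposite direction, the new _array and _map converters unpack Catalyst's ArrayData and MapData element by element, re-wrapping each value as a Spark Literal and recursing through apply. A usage sketch, assuming a ToSubstraitLiteral companion object exposing the apply(Literal) entry point that the private helpers above rely on:

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.{ArrayType, IntegerType}

// Literal.create converts the Scala Seq into Catalyst ArrayData, so the
// literal matches the Literal(a: ArrayData, ArrayType(et, _)) case above.
val arr = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
val substraitList = ToSubstraitLiteral.apply(arr) // a Substrait list<i32> literal
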
spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala
@@ -252,7 +252,7 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]
}

override def visit(emptyScan: relation.EmptyScan): LogicalPlan = {
-LocalRelation(ToSubstraitType.toAttribute(emptyScan.getInitialSchema))
+LocalRelation(ToSubstraitType.toAttributeSeq(emptyScan.getInitialSchema))
}
override def visit(namedScan: relation.NamedScan): LogicalPlan = {
resolve(UnresolvedRelation(namedScan.getNames.asScala)) match {
2 changes: 2 additions & 0 deletions spark/src/main/scala/io/substrait/utils/Util.scala
@@ -21,6 +21,8 @@ import scala.collection.mutable.ArrayBuffer

object Util {

+val MICROSECOND_PRECISION = 6 // for PrecisionTimestamp(TZ) types
+
/**
* Compute the cartesian product for n lists.
*
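Why 6: Spark stores both TimestampType and TimestampNTZType internally as a Long counting microseconds since the Unix epoch, so precision 6 is the only lossless target for the PrecisionTimestamp(TZ) types. A small illustration using a standard Spark utility (not part of this commit):

import java.time.Instant
import org.apache.spark.sql.catalyst.util.DateTimeUtils

// instantToMicros yields exactly the precision-6 Long that Spark stores
// for a timestamp value.
val micros: Long = DateTimeUtils.instantToMicros(Instant.parse("2024-10-23T00:00:00Z"))
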
