Skip to content

Commit 5ae5dbf

Browse files
committed
feat(spark): add support for more types and literals (binary, list, array, intervals, timestamps)
1 parent 45d9387 commit 5ae5dbf

File tree

6 files changed

+274
-7
lines changed

6 files changed

+274
-7
lines changed

spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,47 @@ private class ToSparkType
4242

4343
// Direct Substrait-to-Spark type mappings. All Substrait string-like types
// (str, fixedchar, varchar) collapse to Spark's StringType; any declared
// length information is dropped in the process.
override def visit(expr: Type.Str): DataType = StringType

override def visit(expr: Type.Binary): DataType = BinaryType

override def visit(expr: Type.FixedChar): DataType = StringType

override def visit(expr: Type.VarChar): DataType = StringType

override def visit(expr: Type.Bool): DataType = BooleanType
52+
53+
/**
 * Maps a Substrait precision timestamp (without time zone) to Spark's
 * TimestampNTZType. Spark timestamps are fixed at microsecond resolution,
 * so any other Substrait precision is rejected.
 */
override def visit(expr: Type.PrecisionTimestamp): DataType =
  expr.precision() match {
    case Util.MICROSECOND_PRECISION => TimestampNTZType
    case p =>
      throw new UnsupportedOperationException(s"Unsupported precision for timestamp: $p")
  }
60+
/**
 * Maps a Substrait precision timestamp with time zone to Spark's
 * TimestampType. Spark timestamps are fixed at microsecond resolution,
 * so any other Substrait precision is rejected.
 */
override def visit(expr: Type.PrecisionTimestampTZ): DataType =
  expr.precision() match {
    case Util.MICROSECOND_PRECISION => TimestampType
    case p =>
      throw new UnsupportedOperationException(s"Unsupported precision for timestamp: $p")
  }
67+
68+
/**
 * Maps a Substrait day-time interval to Spark's default DayTimeIntervalType.
 * Only microsecond precision is accepted, mirroring the timestamp handling
 * above.
 */
override def visit(expr: Type.IntervalDay): DataType =
  expr.precision() match {
    case Util.MICROSECOND_PRECISION => DayTimeIntervalType.DEFAULT
    case p =>
      throw new UnsupportedOperationException(s"Unsupported precision for intervalDay: $p")
  }
75+
76+
// Year-month intervals carry no precision component, so they map directly
// to Spark's default YearMonthIntervalType.
override def visit(expr: Type.IntervalYear): DataType = YearMonthIntervalType.DEFAULT
77+
78+
// A Substrait list becomes a Spark ArrayType; element nullability is taken
// from the Substrait element type.
override def visit(expr: Type.ListType): DataType = {
  val element = expr.elementType()
  ArrayType(element.accept(this), containsNull = element.nullable())
}
80+
81+
// A Substrait map becomes a Spark MapType. Only value nullability is
// carried over; key nullability has no Spark-side representation here.
override def visit(expr: Type.Map): DataType = {
  val value = expr.value()
  MapType(
    expr.key().accept(this),
    value.accept(this),
    valueContainsNull = value.nullable())
}
5086
}
5187
class ToSubstraitType {
5288

@@ -79,10 +115,12 @@ class ToSubstraitType {
79115
case charType: CharType => Some(creator.fixedChar(charType.length))
80116
case varcharType: VarcharType => Some(creator.varChar(varcharType.length))
81117
case StringType => Some(creator.STRING)
82-
case DateType => Some(creator.DATE)
83-
case TimestampType => Some(creator.TIMESTAMP)
84-
case TimestampNTZType => Some(creator.TIMESTAMP_TZ)
85118
case BinaryType => Some(creator.BINARY)
119+
case DateType => Some(creator.DATE)
120+
case TimestampNTZType => Some(creator.precisionTimestamp(Util.MICROSECOND_PRECISION))
121+
case TimestampType => Some(creator.precisionTimestampTZ(Util.MICROSECOND_PRECISION))
122+
case DayTimeIntervalType.DEFAULT => Some(creator.intervalDay(Util.MICROSECOND_PRECISION))
123+
case YearMonthIntervalType.DEFAULT => Some(creator.INTERVAL_YEAR)
86124
case ArrayType(elementType, containsNull) =>
87125
convert(elementType, Seq.empty, containsNull).map(creator.list)
88126
case MapType(keyType, valueType, valueContainsNull) =>
@@ -91,6 +129,10 @@ class ToSubstraitType {
91129
keyT =>
92130
convert(valueType, Seq.empty, valueContainsNull)
93131
.map(valueT => creator.map(keyT, valueT)))
132+
case StructType(fields) =>
133+
Util
134+
.seqToOption(fields.map(f => convert(f.dataType, f.nullable)))
135+
.map(l => creator.struct(JavaConverters.asJavaIterable(l)))
94136
case _ =>
95137
None
96138
}
@@ -128,7 +170,7 @@ class ToSubstraitType {
128170
)
129171
}
130172

131-
def toAttribute(namedStruct: NamedStruct): Seq[AttributeReference] = {
173+
def toAttributeSeq(namedStruct: NamedStruct): Seq[AttributeReference] = {
132174
namedStruct
133175
.struct()
134176
.fields()

spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ package io.substrait.spark.expression
1919
import io.substrait.spark.{DefaultExpressionVisitor, HasOutputStack, SparkExtension, ToSubstraitType}
2020
import io.substrait.spark.logical.ToLogicalPlan
2121

22+
import org.apache.spark.sql.catalyst.CatalystTypeConverters
2223
import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, Expression, In, Literal, MakeDecimal, NamedExpression, ScalarSubquery}
24+
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
2325
import org.apache.spark.sql.types.Decimal
2426
import org.apache.spark.substrait.SparkTypeUtil
2527
import org.apache.spark.unsafe.types.UTF8String
@@ -28,8 +30,10 @@ import io.substrait.`type`.{StringTypeVisitor, Type}
2830
import io.substrait.{expression => exp}
2931
import io.substrait.expression.{Expression => SExpression}
3032
import io.substrait.util.DecimalUtil
33+
import io.substrait.utils.Util
34+
import io.substrait.utils.Util.SECONDS_PER_DAY
3135

32-
import scala.collection.JavaConverters.asScalaBufferConverter
36+
import scala.collection.JavaConverters.{asScalaBufferConverter, mapAsScalaMapConverter}
3337

3438
class ToSparkExpression(
3539
val scalarFunctionConverter: ToScalarFunction,
@@ -42,7 +46,7 @@ class ToSparkExpression(
4246
Literal.TrueLiteral
4347
} else {
4448
Literal.FalseLiteral
45-
}
49+
}
4650
}
4751
override def visit(expr: SExpression.I32Literal): Expression = {
4852
Literal(expr.value(), ToSubstraitType.convert(expr.getType))
@@ -52,6 +56,10 @@ class ToSparkExpression(
5256
Literal(expr.value(), ToSubstraitType.convert(expr.getType))
5357
}
5458

59+
/** Converts a Substrait 32-bit float literal into a Spark float Literal. */
override def visit(expr: SExpression.FP32Literal): Literal =
  Literal(expr.value(), ToSubstraitType.convert(expr.getType))

/** Converts a Substrait 64-bit float literal into a Spark double Literal. */
override def visit(expr: SExpression.FP64Literal): Expression =
  Literal(expr.value(), ToSubstraitType.convert(expr.getType))
@@ -68,15 +76,58 @@ class ToSparkExpression(
6876
Literal(UTF8String.fromString(expr.value()), ToSubstraitType.convert(expr.getType))
6977
}
7078

79+
/** Converts a Substrait binary literal into a Spark byte-array Literal. */
override def visit(expr: SExpression.BinaryLiteral): Literal =
  Literal(expr.value().toByteArray, ToSubstraitType.convert(expr.getType))
82+
7183
/**
 * Converts a Substrait decimal literal into a Spark Decimal Literal. The
 * value is decoded from its 16-byte integer representation using the
 * literal's scale.
 */
override def visit(expr: SExpression.DecimalLiteral): Expression = {
  val bytes = expr.value.toByteArray
  val bigDecimal = DecimalUtil.getBigDecimalFromBytes(bytes, expr.scale, 16)
  Literal(Decimal(bigDecimal), ToSubstraitType.convert(expr.getType))
}

/** Converts a Substrait date literal (a day count) into a Spark date Literal. */
override def visit(expr: SExpression.DateLiteral): Expression =
  Literal(expr.value(), ToSubstraitType.convert(expr.getType))
7992

93+
/**
 * Converts a Substrait precision timestamp (no time zone) literal into a
 * Spark Literal, rescaling the value to microseconds.
 */
override def visit(expr: SExpression.PrecisionTimestampLiteral): Literal = {
  val micros = Util.toMicroseconds(expr.value(), expr.precision())
  Literal(micros, ToSubstraitType.convert(expr.getType))
}

/**
 * Converts a Substrait precision timestamp-with-time-zone literal into a
 * Spark Literal, rescaling the value to microseconds.
 */
override def visit(expr: SExpression.PrecisionTimestampTZLiteral): Literal = {
  val micros = Util.toMicroseconds(expr.value(), expr.precision())
  Literal(micros, ToSubstraitType.convert(expr.getType))
}
101+
102+
/**
 * Converts a Substrait day-time interval literal into a Spark Literal.
 * The days and whole-seconds components are scaled to microseconds and the
 * sub-second component is rescaled from the literal's precision, producing
 * a single Long microsecond count (Spark's day-time interval value).
 */
override def visit(expr: SExpression.IntervalDayLiteral): Literal = {
  // BUG FIX: the previous code multiplied whole seconds by
  // Util.MICROSECOND_PRECISION (the digit count, 6) instead of the number
  // of microseconds per second (1,000,000), drastically under-counting
  // the interval.
  val microsPerSecond = 1000000L
  val wholeSeconds = expr.days().toLong * SECONDS_PER_DAY + expr.seconds()
  val micros = wholeSeconds * microsPerSecond +
    Util.toMicroseconds(expr.subseconds(), expr.precision())
  Literal(micros, ToSubstraitType.convert(expr.getType))
}
107+
108+
/**
 * Converts a Substrait year-month interval literal into a Spark Literal
 * holding the total month count.
 */
override def visit(expr: SExpression.IntervalYearLiteral): Literal = {
  val totalMonths = expr.years() * 12 + expr.months()
  Literal(totalMonths, ToSubstraitType.convert(expr.getType))
}
112+
113+
/**
 * Converts a Substrait list literal into a Spark array Literal by
 * recursively converting each element literal.
 */
override def visit(expr: SExpression.ListLiteral): Literal = {
  val elements = expr.values().asScala.map(v => v.accept(this).asInstanceOf[Literal].value)
  Literal.create(elements, ToSubstraitType.convert(expr.getType))
}

/** An empty Substrait list literal becomes the default (empty) Literal for its type. */
override def visit(expr: SExpression.EmptyListLiteral): Expression =
  Literal.default(ToSubstraitType.convert(expr.getType))
121+
122+
/**
 * Converts a Substrait map literal into a Spark map Literal by recursively
 * converting each key and value literal.
 */
override def visit(expr: SExpression.MapLiteral): Literal = {
  val entries = expr.values().asScala.map {
    case (k, v) =>
      (k.accept(this).asInstanceOf[Literal].value, v.accept(this).asInstanceOf[Literal].value)
  }
  Literal.create(entries, ToSubstraitType.convert(expr.getType))
}

/** An empty Substrait map literal becomes the default (empty) Literal for its type. */
override def visit(expr: SExpression.EmptyMapLiteral): Literal =
  Literal.default(ToSubstraitType.convert(expr.getType))
130+
80131
/** A typed Substrait null literal becomes a Spark null Literal of the matching type. */
override def visit(expr: SExpression.NullLiteral): Expression =
  Literal(null, ToSubstraitType.convert(expr.getType))
@@ -89,6 +140,7 @@ class ToSparkExpression(
89140
/**
 * Resolves a Substrait field reference against the current output and
 * returns a clone of the matching attribute.
 */
override def visit(expr: exp.FieldReference): Expression =
  withFieldReference(expr)(i => currentOutput(i).clone())
143+
92144
override def visit(expr: SExpression.IfThen): Expression = {
93145
val branches = expr
94146
.ifClauses()

spark/src/main/scala/io/substrait/spark/expression/ToSubstraitLiteral.scala

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,15 @@ package io.substrait.spark.expression
1919
import io.substrait.spark.ToSubstraitType
2020

2121
import org.apache.spark.sql.catalyst.expressions.Literal
22+
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
2223
import org.apache.spark.sql.types._
2324
import org.apache.spark.unsafe.types.UTF8String
2425

2526
import io.substrait.expression.{Expression => SExpression}
2627
import io.substrait.expression.ExpressionCreator._
28+
import io.substrait.utils.Util
29+
30+
import scala.collection.JavaConverters
2731

2832
class ToSubstraitLiteral {
2933

@@ -34,6 +38,33 @@ class ToSubstraitLiteral {
3438
scale: Int): SExpression.Literal =
3539
decimal(false, d.toJavaBigDecimal, precision, scale)
3640

41+
/**
 * Converts a Spark ArrayData value into a Substrait list literal by
 * converting each element as a Literal of the element type. An empty array
 * is emitted as an emptyList literal so its element type is preserved.
 */
private def sparkArray2Substrait(
    arrayData: ArrayData,
    elementType: DataType,
    containsNull: Boolean): SExpression.Literal = {
  val elements = arrayData.array.map(any => apply(Literal(any, elementType)))
  if (elements.isEmpty) {
    emptyList(false, ToSubstraitType.convert(elementType, nullable = containsNull).get)
  } else {
    list(false, JavaConverters.asJavaIterable(elements)) // TODO: handle containsNull
  }
}
51+
52+
/**
 * Converts a Spark MapData value into a Substrait map literal by converting
 * the key and value arrays element-wise and zipping them into entries. An
 * empty map is emitted as an emptyMap literal so its key/value types are
 * preserved.
 */
private def sparkMap2Substrait(
    mapData: MapData,
    keyType: DataType,
    valueType: DataType,
    valueContainsNull: Boolean): SExpression.Literal = {
  val keys = mapData.keyArray().array.map(any => apply(Literal(any, keyType)))
  val values = mapData.valueArray().array.map(any => apply(Literal(any, valueType)))
  if (keys.isEmpty) {
    emptyMap(
      false,
      ToSubstraitType.convert(keyType, nullable = false).get,
      ToSubstraitType.convert(valueType, nullable = valueContainsNull).get)
  } else {
    // TODO: handle valueContainsNull
    map(false, JavaConverters.mapAsJavaMap(keys.zip(values).toMap))
  }
}
67+
3768
val _bool: Boolean => SExpression.Literal = bool(false, _)
3869
val _i8: Byte => SExpression.Literal = i8(false, _)
3970
val _i16: Short => SExpression.Literal = i16(false, _)
@@ -43,7 +74,17 @@ class ToSubstraitLiteral {
4374
val _fp64: Double => SExpression.Literal = fp64(false, _)
4475
val _decimal: (Decimal, Int, Int) => SExpression.Literal = sparkDecimal2Substrait
4576
val _date: Int => SExpression.Literal = date(false, _)
77+
// Spark timestamps are Long microsecond counts, so both precision
// timestamp literal forms are emitted at microsecond precision.
val _timestamp: Long => SExpression.Literal =
  precisionTimestamp(false, _, Util.MICROSECOND_PRECISION) // Spark ts is in microseconds
val _timestampTz: Long => SExpression.Literal =
  precisionTimestampTZ(false, _, Util.MICROSECOND_PRECISION) // Spark ts is in microseconds
// Spark's DayTimeIntervalType value is a Long microsecond count; it is
// carried entirely in the subseconds component (days/seconds left at 0).
val _intervalDay: Long => SExpression.Literal = (micros: Long) =>
  intervalDay(false, 0, 0, micros, Util.MICROSECOND_PRECISION)
// BUG FIX: the years and months arguments were swapped. Spark's
// YearMonthIntervalType value is a total month count, and
// ExpressionCreator.intervalYear takes (nullable, years, months), so the
// whole years are m / 12 and the remaining months are m % 12.
val _intervalYear: Int => SExpression.Literal = (m: Int) => intervalYear(false, m / 12, m % 12)
4684
val _string: String => SExpression.Literal = string(false, _)
85+
val _binary: Array[Byte] => SExpression.Literal = binary(false, _)
86+
val _array: (ArrayData, DataType, Boolean) => SExpression.Literal = sparkArray2Substrait
87+
val _map: (MapData, DataType, DataType, Boolean) => SExpression.Literal = sparkMap2Substrait
4788
}
4889

4990
private def convertWithValue(literal: Literal): Option[SExpression.Literal] = {
@@ -59,7 +100,16 @@ class ToSubstraitLiteral {
59100
case Literal(d: Decimal, dataType: DecimalType) =>
60101
Nonnull._decimal(d, dataType.precision, dataType.scale)
61102
case Literal(d: Integer, DateType) => Nonnull._date(d)
103+
case Literal(t: Long, TimestampType) => Nonnull._timestampTz(t)
104+
case Literal(t: Long, TimestampNTZType) => Nonnull._timestamp(t)
105+
case Literal(d: Long, DayTimeIntervalType.DEFAULT) => Nonnull._intervalDay(d)
106+
case Literal(ym: Int, YearMonthIntervalType.DEFAULT) => Nonnull._intervalYear(ym)
62107
case Literal(u: UTF8String, StringType) => Nonnull._string(u.toString)
108+
case Literal(b: Array[Byte], BinaryType) => Nonnull._binary(b)
109+
case Literal(a: ArrayData, ArrayType(et, containsNull)) =>
110+
Nonnull._array(a, et, containsNull)
111+
case Literal(m: MapData, MapType(keyType, valueType, valueContainsNull)) =>
112+
Nonnull._map(m, keyType, valueType, valueContainsNull)
63113
case _ => null
64114
}
65115
)

spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]
252252
}
253253

254254
/** An empty scan becomes a LocalRelation carrying the schema's attributes and no rows. */
override def visit(emptyScan: relation.EmptyScan): LogicalPlan = {
  LocalRelation(ToSubstraitType.toAttributeSeq(emptyScan.getInitialSchema))
}
257257
override def visit(namedScan: relation.NamedScan): LogicalPlan = {
258258
resolve(UnresolvedRelation(namedScan.getNames.asScala)) match {

spark/src/main/scala/io/substrait/utils/Util.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,21 @@ import scala.collection.mutable.ArrayBuffer
2121

2222
object Util {
2323

24+
// Number of seconds in one day, used when flattening day-time intervals.
val SECONDS_PER_DAY: Long = 24 * 60 * 60
// Decimal digits of sub-second resolution used by Spark's microsecond-based
// time values; for PrecisionTimestamp(TZ) types.
val MICROSECOND_PRECISION = 6
26+
27+
/**
 * Rescales a time value expressed at the given decimal precision (digits of
 * sub-second resolution) to Spark's microsecond resolution (precision 6).
 *
 * Coarser-than-microsecond values are multiplied up; finer values are
 * divided down, truncating towards zero. Uses pure integer exponentiation
 * instead of math.pow so the scale factor never passes through a Double.
 */
def toMicroseconds(value: Long, precision: Int): Long = {
  // Spark uses microseconds as a Long value for most time things
  val diff = MICROSECOND_PRECISION - precision
  // 10^|diff| computed in Long arithmetic (diff is small in practice).
  val scale = (1 to math.abs(diff)).foldLeft(1L)((acc, _) => acc * 10)
  if (diff >= 0) value * scale else value / scale
}
38+
2439
/**
2540
* Compute the cartesian product for n lists.
2641
*

0 commit comments

Comments
 (0)