Skip to content

Commit

Permalink
feat(spark): additional type and literal support (#311)
Browse files Browse the repository at this point in the history
Support for binary, list, map, intervals and precision timestamps literals and types

BREAKING CHANGE: Spark TimestampNTZType is now emitted as Substrait PrecisionTimestamp
BREAKING CHANGE: Spark TimestampType is now emitted as Substrait PrecisionTimestampTZ

feat(core): added support for Expression.EmptyMapLiteral
  • Loading branch information
Blizzara authored Oct 25, 2024
1 parent 75ebac2 commit 513a049
Show file tree
Hide file tree
Showing 15 changed files with 353 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,11 @@ public OUTPUT visit(Expression.MapLiteral expr) throws EXCEPTION {
return visitFallback(expr);
}

@Override
public OUTPUT visit(Expression.EmptyMapLiteral expr) throws EXCEPTION {
return visitFallback(expr);
}

@Override
public OUTPUT visit(Expression.ListLiteral expr) throws EXCEPTION {
return visitFallback(expr);
Expand Down
19 changes: 19 additions & 0 deletions core/src/main/java/io/substrait/expression/Expression.java
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,25 @@ public <R, E extends Throwable> R accept(ExpressionVisitor<R, E> visitor) throws
}
}

@Value.Immutable
abstract static class EmptyMapLiteral implements Literal {
public abstract Type keyType();

public abstract Type valueType();

public Type getType() {
return Type.withNullability(nullable()).map(keyType(), valueType());
}

public static ImmutableExpression.EmptyMapLiteral.Builder builder() {
return ImmutableExpression.EmptyMapLiteral.builder();
}

public <R, E extends Throwable> R accept(ExpressionVisitor<R, E> visitor) throws E {
return visitor.visit(this);
}
}

@Value.Immutable
abstract static class ListLiteral implements Literal {
public abstract List<Literal> values();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,15 @@ public static Expression.MapLiteral map(
return Expression.MapLiteral.builder().nullable(nullable).putAllValues(values).build();
}

public static Expression.EmptyMapLiteral emptyMap(
boolean nullable, Type keyType, Type valueType) {
return Expression.EmptyMapLiteral.builder()
.keyType(keyType)
.valueType(valueType)
.nullable(nullable)
.build();
}

public static Expression.ListLiteral list(boolean nullable, Expression.Literal... values) {
return Expression.ListLiteral.builder().nullable(nullable).addValues(values).build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ public interface ExpressionVisitor<R, E extends Throwable> {

R visit(Expression.MapLiteral expr) throws E;

R visit(Expression.EmptyMapLiteral expr) throws E;

R visit(Expression.ListLiteral expr) throws E;

R visit(Expression.EmptyListLiteral expr) throws E;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,21 @@ public Expression visit(io.substrait.expression.Expression.MapLiteral expr) {
});
}

@Override
public Expression visit(io.substrait.expression.Expression.EmptyMapLiteral expr) {
return lit(
bldr -> {
var protoMapType = expr.getType().accept(typeProtoConverter);
bldr.setEmptyMap(protoMapType.getMap())
// For empty maps, the Literal message's own nullable field should be ignored
// in favor of the nullability of the Type.Map in the literal's
// empty_map field. But for safety we set the literal's nullable field
// to match in case any readers either look in the wrong location
// or want to verify that they are consistent.
.setNullable(expr.nullable());
});
}

@Override
public Expression visit(io.substrait.expression.Expression.ListLiteral expr) {
return lit(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,12 @@ public Expression.Literal from(io.substrait.proto.Expression.Literal literal) {
literal.getNullable(),
literal.getMap().getKeyValuesList().stream()
.collect(Collectors.toMap(kv -> from(kv.getKey()), kv -> from(kv.getValue()))));
case EMPTY_MAP -> {
// literal.getNullable() is intentionally ignored in favor of the nullability
// specified in the literal.getEmptyMap() type.
var mapType = protoTypeConverter.fromMap(literal.getEmptyMap());
yield ExpressionCreator.emptyMap(mapType.nullable(), mapType.key(), mapType.value());
}
case UUID -> ExpressionCreator.uuid(literal.getNullable(), literal.getUuid());
case NULL -> ExpressionCreator.typedNull(protoTypeConverter.from(literal.getNull()));
case LIST -> ExpressionCreator.list(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,11 @@ public Optional<Expression> visit(Expression.MapLiteral expr) throws EXCEPTION {
return visitLiteral(expr);
}

@Override
public Optional<Expression> visit(Expression.EmptyMapLiteral expr) throws EXCEPTION {
return visitLiteral(expr);
}

@Override
public Optional<Expression> visit(Expression.ListLiteral expr) throws EXCEPTION {
return visitLiteral(expr);
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/java/io/substrait/type/TypeCreator.java
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ public Type.ListType list(Type type) {
return Type.ListType.builder().nullable(nullable).elementType(type).build();
}

public Type map(Type key, Type value) {
public Type.Map map(Type key, Type value) {
return Type.Map.builder().nullable(nullable).key(key).value(value).build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ public Type from(io.substrait.proto.Type type) {
.map(this::from)
.collect(java.util.stream.Collectors.toList()));
case LIST -> fromList(type.getList());
case MAP -> n(type.getMap().getNullability())
.map(from(type.getMap().getKey()), from(type.getMap().getValue()));
case MAP -> fromMap(type.getMap());
case USER_DEFINED -> {
var userDefined = type.getUserDefined();
var t = lookup.getType(userDefined.getTypeReference(), extensions);
Expand All @@ -74,6 +73,10 @@ public Type.ListType fromList(io.substrait.proto.Type.List list) {
return n(list.getNullability()).list(from(list.getType()));
}

public Type.Map fromMap(io.substrait.proto.Type.Map map) {
return n(map.getNullability()).map(from(map.getKey()), from(map.getValue()));
}

public static boolean isNullable(io.substrait.proto.Type.Nullability nullability) {
return io.substrait.proto.Type.Nullability.NULLABILITY_NULLABLE == nullability;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,8 @@ class ExpressionToString extends DefaultExpressionVisitor[String] {
override def visit(expr: Expression.UserDefinedLiteral): String = {
expr.toString
}

override def visit(expr: Expression.EmptyMapLiteral): String = {
expr.toString
}
}
35 changes: 32 additions & 3 deletions spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,38 @@ private class ToSparkType

override def visit(expr: Type.Str): DataType = StringType

override def visit(expr: Type.Binary): DataType = BinaryType

override def visit(expr: Type.FixedChar): DataType = StringType

override def visit(expr: Type.VarChar): DataType = StringType

override def visit(expr: Type.Bool): DataType = BooleanType

override def visit(expr: Type.PrecisionTimestamp): DataType = {
Util.assertMicroseconds(expr.precision())
TimestampNTZType
}
override def visit(expr: Type.PrecisionTimestampTZ): DataType = {
Util.assertMicroseconds(expr.precision())
TimestampType
}

override def visit(expr: Type.IntervalDay): DataType = {
Util.assertMicroseconds(expr.precision())
DayTimeIntervalType.DEFAULT
}

override def visit(expr: Type.IntervalYear): DataType = YearMonthIntervalType.DEFAULT

override def visit(expr: Type.ListType): DataType =
ArrayType(expr.elementType().accept(this), containsNull = expr.elementType().nullable())

override def visit(expr: Type.Map): DataType =
MapType(
expr.key().accept(this),
expr.value().accept(this),
valueContainsNull = expr.value().nullable())
}
class ToSubstraitType {

Expand Down Expand Up @@ -81,10 +108,12 @@ class ToSubstraitType {
case charType: CharType => Some(creator.fixedChar(charType.length))
case varcharType: VarcharType => Some(creator.varChar(varcharType.length))
case StringType => Some(creator.STRING)
case DateType => Some(creator.DATE)
case TimestampType => Some(creator.TIMESTAMP)
case TimestampNTZType => Some(creator.TIMESTAMP_TZ)
case BinaryType => Some(creator.BINARY)
case DateType => Some(creator.DATE)
case TimestampNTZType => Some(creator.precisionTimestamp(Util.MICROSECOND_PRECISION))
case TimestampType => Some(creator.precisionTimestampTZ(Util.MICROSECOND_PRECISION))
case DayTimeIntervalType.DEFAULT => Some(creator.intervalDay(Util.MICROSECOND_PRECISION))
case YearMonthIntervalType.DEFAULT => Some(creator.INTERVAL_YEAR)
case ArrayType(elementType, containsNull) =>
convert(elementType, Seq.empty, containsNull).map(creator.list)
case MapType(keyType, valueType, valueContainsNull) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ import io.substrait.`type`.{StringTypeVisitor, Type}
import io.substrait.{expression => exp}
import io.substrait.expression.{Expression => SExpression}
import io.substrait.util.DecimalUtil
import io.substrait.utils.Util

import scala.collection.JavaConverters.asScalaBufferConverter
import scala.collection.JavaConverters.{asScalaBufferConverter, mapAsScalaMapConverter}

class ToSparkExpression(
val scalarFunctionConverter: ToScalarFunction,
Expand Down Expand Up @@ -61,6 +62,10 @@ class ToSparkExpression(
Literal(expr.value(), ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.FP32Literal): Literal = {
Literal(expr.value(), ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.FP64Literal): Expression = {
Literal(expr.value(), ToSubstraitType.convert(expr.getType))
}
Expand All @@ -77,15 +82,71 @@ class ToSparkExpression(
Literal(UTF8String.fromString(expr.value()), ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.BinaryLiteral): Literal = {
Literal(expr.value().toByteArray, ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.DecimalLiteral): Expression = {
val value = expr.value.toByteArray
val decimal = DecimalUtil.getBigDecimalFromBytes(value, expr.scale, 16)
Literal(Decimal(decimal), ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.DateLiteral): Expression = {
Literal(expr.value(), ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.PrecisionTimestampLiteral): Literal = {
// Spark timestamps are stored as a microseconds Long
Util.assertMicroseconds(expr.precision())
Literal(expr.value(), ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.PrecisionTimestampTZLiteral): Literal = {
// Spark timestamps are stored as a microseconds Long
Util.assertMicroseconds(expr.precision())
Literal(expr.value(), ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.IntervalDayLiteral): Literal = {
Util.assertMicroseconds(expr.precision())
// Spark uses a single microseconds Long as the "physical" type for DayTimeInterval
val micros =
(expr.days() * Util.SECONDS_PER_DAY + expr.seconds()) * Util.MICROS_PER_SECOND +
expr.subseconds()
Literal(micros, ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.IntervalYearLiteral): Literal = {
// Spark uses a single months Int as the "physical" type for YearMonthInterval
val months = expr.years() * 12 + expr.months()
Literal(months, ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.ListLiteral): Literal = {
val array = expr.values().asScala.map(value => value.accept(this).asInstanceOf[Literal].value)
Literal.create(array, ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.EmptyListLiteral): Expression = {
Literal.default(ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.MapLiteral): Literal = {
val map = expr.values().asScala.map {
case (key, value) =>
(
key.accept(this).asInstanceOf[Literal].value,
value.accept(this).asInstanceOf[Literal].value
)
}
Literal.create(map, ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.EmptyMapLiteral): Literal = {
Literal.default(ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.NullLiteral): Expression = {
Literal(null, ToSubstraitType.convert(expr.getType))
}
Expand All @@ -98,6 +159,7 @@ class ToSparkExpression(
override def visit(expr: exp.FieldReference): Expression = {
withFieldReference(expr)(i => currentOutput(i).clone())
}

override def visit(expr: SExpression.IfThen): Expression = {
val branches = expr
.ifClauses()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,15 @@ package io.substrait.spark.expression
import io.substrait.spark.ToSubstraitType

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import io.substrait.expression.{Expression => SExpression}
import io.substrait.expression.ExpressionCreator._
import io.substrait.utils.Util

import scala.collection.JavaConverters

class ToSubstraitLiteral {

Expand All @@ -34,6 +38,34 @@ class ToSubstraitLiteral {
scale: Int): SExpression.Literal =
decimal(false, d.toJavaBigDecimal, precision, scale)

private def sparkArray2Substrait(
arrayData: ArrayData,
elementType: DataType,
containsNull: Boolean): SExpression.Literal = {
val elements = arrayData.array.map(any => apply(Literal(any, elementType)))
if (elements.isEmpty) {
return emptyList(false, ToSubstraitType.convert(elementType, nullable = containsNull).get)
}
list(false, JavaConverters.asJavaIterable(elements)) // TODO: handle containsNull
}

private def sparkMap2Substrait(
mapData: MapData,
keyType: DataType,
valueType: DataType,
valueContainsNull: Boolean): SExpression.Literal = {
val keys = mapData.keyArray().array.map(any => apply(Literal(any, keyType)))
val values = mapData.valueArray().array.map(any => apply(Literal(any, valueType)))
if (keys.isEmpty) {
return emptyMap(
false,
ToSubstraitType.convert(keyType, nullable = false).get,
ToSubstraitType.convert(valueType, nullable = valueContainsNull).get)
}
// TODO: handle valueContainsNull
map(false, JavaConverters.mapAsJavaMap(keys.zip(values).toMap))
}

val _bool: Boolean => SExpression.Literal = bool(false, _)
val _i8: Byte => SExpression.Literal = i8(false, _)
val _i16: Short => SExpression.Literal = i16(false, _)
Expand All @@ -43,7 +75,21 @@ class ToSubstraitLiteral {
val _fp64: Double => SExpression.Literal = fp64(false, _)
val _decimal: (Decimal, Int, Int) => SExpression.Literal = sparkDecimal2Substrait
val _date: Int => SExpression.Literal = date(false, _)
val _timestamp: Long => SExpression.Literal =
precisionTimestamp(false, _, Util.MICROSECOND_PRECISION) // Spark ts is in microseconds
val _timestampTz: Long => SExpression.Literal =
precisionTimestampTZ(false, _, Util.MICROSECOND_PRECISION) // Spark ts is in microseconds
val _intervalDay: Long => SExpression.Literal = (ms: Long) => {
val days = (ms / Util.MICROS_PER_SECOND / Util.SECONDS_PER_DAY).toInt
val seconds = (ms / Util.MICROS_PER_SECOND % Util.SECONDS_PER_DAY).toInt
val micros = ms % Util.MICROS_PER_SECOND
intervalDay(false, days, seconds, micros, Util.MICROSECOND_PRECISION)
}
val _intervalYear: Int => SExpression.Literal = (m: Int) => intervalYear(false, m / 12, m % 12)
val _string: String => SExpression.Literal = string(false, _)
val _binary: Array[Byte] => SExpression.Literal = binary(false, _)
val _array: (ArrayData, DataType, Boolean) => SExpression.Literal = sparkArray2Substrait
val _map: (MapData, DataType, DataType, Boolean) => SExpression.Literal = sparkMap2Substrait
}

private def convertWithValue(literal: Literal): Option[SExpression.Literal] = {
Expand All @@ -59,7 +105,16 @@ class ToSubstraitLiteral {
case Literal(d: Decimal, dataType: DecimalType) =>
Nonnull._decimal(d, dataType.precision, dataType.scale)
case Literal(d: Integer, DateType) => Nonnull._date(d)
case Literal(t: Long, TimestampType) => Nonnull._timestampTz(t)
case Literal(t: Long, TimestampNTZType) => Nonnull._timestamp(t)
case Literal(d: Long, DayTimeIntervalType.DEFAULT) => Nonnull._intervalDay(d)
case Literal(ym: Int, YearMonthIntervalType.DEFAULT) => Nonnull._intervalYear(ym)
case Literal(u: UTF8String, StringType) => Nonnull._string(u.toString)
case Literal(b: Array[Byte], BinaryType) => Nonnull._binary(b)
case Literal(a: ArrayData, ArrayType(et, containsNull)) =>
Nonnull._array(a, et, containsNull)
case Literal(m: MapData, MapType(keyType, valueType, valueContainsNull)) =>
Nonnull._map(m, keyType, valueType, valueContainsNull)
case _ => null
}
)
Expand Down
Loading

0 comments on commit 513a049

Please sign in to comment.