diff --git a/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala b/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala
index 962864980..1fd76204a 100644
--- a/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala
+++ b/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala
@@ -29,6 +29,8 @@ import scala.collection.JavaConverters.asScalaBufferConverter
 
 private class ToSparkType
   extends TypeVisitor.TypeThrowsVisitor[DataType, RuntimeException]("Unknown expression type.") {
+  override def visit(expr: Type.I8): DataType = ByteType
+  override def visit(expr: Type.I16): DataType = ShortType
   override def visit(expr: Type.I32): DataType = IntegerType
   override def visit(expr: Type.I64): DataType = LongType
 
diff --git a/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala b/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala
index 8bde45116..c315bda8c 100644
--- a/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala
+++ b/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala
@@ -45,6 +45,15 @@
       Literal.FalseLiteral
     }
   }
+
+  override def visit(expr: SExpression.I8Literal): Expression = {
+    Literal(expr.value().toByte, ToSubstraitType.convert(expr.getType))
+  }
+
+  override def visit(expr: SExpression.I16Literal): Expression = {
+    Literal(expr.value().toShort, ToSubstraitType.convert(expr.getType))
+  }
+
   override def visit(expr: SExpression.I32Literal): Expression = {
     Literal(expr.value(), ToSubstraitType.convert(expr.getType))
   }
diff --git a/spark/src/test/scala/io/substrait/spark/TypesAndLiteralsSuite.scala b/spark/src/test/scala/io/substrait/spark/TypesAndLiteralsSuite.scala
index 61d28201f..5246c0069 100644
--- a/spark/src/test/scala/io/substrait/spark/TypesAndLiteralsSuite.scala
+++ b/spark/src/test/scala/io/substrait/spark/TypesAndLiteralsSuite.scala
@@ -5,7 +5,7 @@
 import io.substrait.spark.expression.{ToSparkExpression, ToSubstraitLiteral}
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.catalyst.util.MapData
-import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, DataType, DayTimeIntervalType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, StringType, TimestampNTZType, TimestampType, YearMonthIntervalType}
+import org.apache.spark.sql.types._
 import org.apache.spark.substrait.SparkTypeUtil
 import org.apache.spark.unsafe.types.UTF8String
@@ -16,6 +16,8 @@ class TypesAndLiteralsSuite extends SparkFunSuite {
   val toSparkExpression = new ToSparkExpression(null, null)
 
   val types: Seq[DataType] = List(
+    ByteType,
+    ShortType,
     IntegerType,
     LongType,
     FloatType,
@@ -52,6 +54,8 @@ class TypesAndLiteralsSuite extends SparkFunSuite {
   val defaultLiterals: Seq[Literal] = types.map(Literal.default)
 
   val literals: Seq[Literal] = List(
+    Literal(1.toByte),
+    Literal(1.toShort),
     Literal(1),
     Literal(1L),
     Literal(1.0f),