Skip to content

Commit c7a2896

Browse files
fix(spark): casting date/time requires timezone
When casting a date type to a string, Spark requires a timezone to be specified; otherwise, it will not resolve the logical plan. The timezone is ignored for non-date/time values. Signed-off-by: Andrew Coleman <[email protected]>
1 parent 6bb46ac commit c7a2896

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ import io.substrait.spark.{DefaultExpressionVisitor, HasOutputStack, SparkExtens
2020
import io.substrait.spark.logical.ToLogicalPlan
2121

2222
import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, Expression, In, Literal, MakeDecimal, NamedExpression, ScalarSubquery}
23-
import org.apache.spark.sql.types.Decimal
23+
import org.apache.spark.sql.internal.SQLConf
24+
import org.apache.spark.sql.types.{DateType, Decimal}
2425
import org.apache.spark.substrait.SparkTypeUtil
2526
import org.apache.spark.unsafe.types.UTF8String
2627

@@ -153,7 +154,12 @@ class ToSparkExpression(
153154

154155
override def visit(expr: SExpression.Cast): Expression = {
155156
val childExp = expr.input().accept(this)
156-
Cast(childExp, ToSubstraitType.convert(expr.getType))
157+
val tt = ToSubstraitType.convert(expr.getType)
158+
val tz = childExp.dataType match {
159+
case DateType => Some(SQLConf.get.getConf(SQLConf.SESSION_LOCAL_TIMEZONE))
160+
case _ => None
161+
}
162+
Cast(childExp, tt, tz)
157163
}
158164

159165
override def visit(expr: exp.FieldReference): Expression = {
@@ -197,6 +203,7 @@ class ToSparkExpression(
197203
val list = expr.options().asScala.map(e => e.accept(this))
198204
In(value, list)
199205
}
206+
200207
override def visit(expr: SExpression.ScalarFunctionInvocation): Expression = {
201208
val eArgs = expr.arguments().asScala
202209
val args = eArgs.zipWithIndex.map {

spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class TPCDSPlan extends TPCDSBase with SubstraitPlanTestBase {
3636
"q2", // because round() isn't defined in substrait to work with Decimal. https://github.com/substrait-io/substrait/pull/713
3737
"q9", // requires implementation of named_struct()
3838
"q10", "q35", "q45", // Unsupported join type ExistenceJoin (this is an internal spark type)
39-
"q51", "q83", "q84", // TBD
39+
"q51", "q84", // TBD
4040
"q72" //requires implementation of date_add()
4141
)
4242
// spotless:on

0 commit comments

Comments
 (0)