Commit 9938575
feat: add MakeDecimal support to spark module

The Spark query optimiser injects an internal function (MakeDecimal) when numeric literals appear in a query. This commit adds support for it, which drastically improves the pass rate of the TPC-DS test suite.

Signed-off-by: Andrew Coleman <[email protected]>
Parent: f22b3d0

4 files changed: +49 -21 lines
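For context (not part of the commit): Spark's MakeDecimal reinterprets an unscaled Long as a Decimal of a given precision and scale. A minimal sketch, assuming the three-argument MakeDecimal factory that the ToSparkExpression diff below also uses:

    import org.apache.spark.sql.catalyst.expressions.{Literal, MakeDecimal}

    // The unscaled value 12345 read at precision 9, scale 2 is 123.45.
    val expr = MakeDecimal(Literal(12345L), 9, 2)
    println(expr.eval()) // 123.45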

spark/src/main/resources/spark.yml

Lines changed: 9 additions & 0 deletions
@@ -32,3 +32,12 @@ scalar_functions:
       - args:
           - value: DECIMAL<P,S>
         return: i64
+  -
+    name: make_decimal
+    description: >-
+      Return the Decimal value of an unscaled Long.
+      Note: this expression is internal and created only by the optimizer,
+    impls:
+      - args:
+          - value: i64
+        return: DECIMAL<P,S>

spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala

Lines changed: 1 addition & 0 deletions
@@ -58,6 +58,7 @@ class FunctionMappings {
     s[Year]("year"),

     // internal
+    s[MakeDecimal]("make_decimal"),
     s[UnscaledValue]("unscaled")
   )
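The s[...] helper pairs a Spark expression class with its Substrait function name; the Substrait side of make_decimal is declared in spark.yml above. A purely illustrative reduction of the helper (the real Sig in this module may carry more, such as a call builder):

    import scala.reflect.ClassTag
    import org.apache.spark.sql.catalyst.expressions.Expression

    // Illustrative shape only: map an expression class to a function name.
    case class Sig(expClass: Class[_], name: String)

    def s[T <: Expression: ClassTag](name: String): Sig =
      Sig(implicitly[ClassTag[T]].runtimeClass, name)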

spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala

Lines changed: 29 additions & 20 deletions
@@ -18,8 +18,8 @@ package io.substrait.spark.expression

 import io.substrait.spark.{DefaultExpressionVisitor, HasOutputStack, ToSubstraitType}
 import io.substrait.spark.logical.ToLogicalPlan
-import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, Expression, In, Literal, NamedExpression, ScalarSubquery}
-import org.apache.spark.sql.types.{Decimal, NullType}
+import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, Expression, In, Literal, MakeDecimal, NamedExpression, ScalarSubquery}
+import org.apache.spark.sql.types.Decimal
 import org.apache.spark.unsafe.types.UTF8String
 import io.substrait.`type`.{StringTypeVisitor, Type}
 import io.substrait.{expression => exp}

@@ -131,23 +131,32 @@ class ToSparkExpression(
       arg.accept(expr.declaration(), i, this)
     }

-    scalarFunctionConverter
-      .getSparkExpressionFromSubstraitFunc(expr.declaration().key(), expr.outputType())
-      .flatMap(sig => Option(sig.makeCall(args)))
-      .getOrElse({
-        val msg = String.format(
-          "Unable to convert scalar function %s(%s).",
-          expr.declaration.name,
-          expr.arguments.asScala
-            .map {
-              case ea: exp.EnumArg => ea.value.toString
-              case e: SExpression => e.getType.accept(new StringTypeVisitor)
-              case t: Type => t.accept(new StringTypeVisitor)
-              case a => throw new IllegalStateException("Unexpected value: " + a)
-            }
-            .mkString(", ")
-        )
-        throw new IllegalArgumentException(msg)
-      })
+    expr.declaration.name match {
+      case "make_decimal" => expr.outputType match {
+          // Need special-case handling of this internal function (not nice, I know).
+          // Because the precision and scale arguments are extracted from the output type,
+          // we can't use the generic scalar function conversion mechanism here.
+          case d: Type.Decimal => MakeDecimal(args.head, d.precision, d.scale)
+          case _ => throw new IllegalArgumentException("Output type of MakeDecimal must be a decimal type")
+        }
+      case _ => scalarFunctionConverter
+          .getSparkExpressionFromSubstraitFunc(expr.declaration().key(), expr.outputType())
+          .flatMap(sig => Option(sig.makeCall(args)))
+          .getOrElse({
+            val msg = String.format(
+              "Unable to convert scalar function %s(%s).",
+              expr.declaration.name,
+              expr.arguments.asScala
+                .map {
+                  case ea: exp.EnumArg => ea.value.toString
+                  case e: SExpression => e.getType.accept(new StringTypeVisitor)
+                  case t: Type => t.accept(new StringTypeVisitor)
+                  case a => throw new IllegalStateException("Unexpected value: " + a)
+                }
+                .mkString(", ")
+            )
+            throw new IllegalArgumentException(msg)
+          })
+    }
   }
 }
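The special case exists because Substrait carries precision and scale on the declared output type rather than in the argument list, so the generic signature lookup cannot recover them. A rough sketch with hypothetical values:

    import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, MakeDecimal}

    // A Substrait call make_decimal(x) declared with output DECIMAL(15, 2):
    // generic conversion only sees the one argument (the unscaled i64)...
    val args: Seq[Expression] = Seq(Literal(1234567L))

    // ...so the converter pulls precision and scale from the output type.
    val converted = MakeDecimal(args.head, 15, 2) // reads as 12345.67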

spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala

Lines changed: 10 additions & 1 deletion
@@ -32,7 +32,16 @@ class TPCDSPlan extends TPCDSBase with SubstraitPlanTestBase {
   }

   // "q9" failed in spark 3.3
-  val successfulSQL: Set[String] = Set("q4", "q7", "q18", "q22", "q26", "q28", "q29", "q37", "q41", "q48", "q50", "q62", "q69", "q82", "q85", "q88", "q90", "q93", "q96", "q97", "q99")
+  val successfulSQL: Set[String] = Set("q1", "q3", "q4", "q7",
+    "q11", "q13", "q15", "q16", "q18", "q19",
+    "q22", "q25", "q26", "q28", "q29",
+    "q30", "q31", "q32", "q37",
+    "q41", "q42", "q43", "q46", "q48",
+    "q50", "q52", "q55", "q58", "q59",
+    "q61", "q62", "q65", "q68", "q69",
+    "q79",
+    "q81", "q82", "q85", "q88",
+    "q90", "q91", "q92", "q93", "q94", "q95", "q96", "q97", "q99")

   tpcdsQueries.foreach {
     q =>
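For orientation, a simplified and assumed shape of the harness this set feeds (tpcdsQueries and the test body are taken from the diff's context or are illustrative, not verbatim from the repo):

    // Assumed shape: queries in successfulSQL are asserted to round-trip
    // through Substrait; the others are skipped until they pass.
    tpcdsQueries.foreach { q =>
      if (successfulSQL.contains(q)) {
        test(s"TPC-DS $q round-trips through Substrait") {
          testQuery("tpcds", q) // hypothetical helper from SubstraitPlanTestBase
        }
      }
    }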
