Skip to content

Commit 6bb46ac

Browse files
feat(spark): add some numeric function mappings (#317)
Signed-off-by: Andrew Coleman <[email protected]>
1 parent 7a9ac66 commit 6bb46ac

File tree

3 files changed

+70
-12
lines changed

3 files changed

+70
-12
lines changed

spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,26 @@ class FunctionMappings {
3939
s[Subtract]("subtract"),
4040
s[Multiply]("multiply"),
4141
s[Divide]("divide"),
42+
s[Abs]("abs"),
43+
s[Remainder]("modulus"),
44+
s[Pow]("power"),
45+
s[Exp]("exp"),
46+
s[Sqrt]("sqrt"),
47+
s[Sin]("sin"),
48+
s[Cos]("cos"),
49+
s[Tan]("tan"),
50+
s[Asin]("asin"),
51+
s[Acos]("acos"),
52+
s[Atan]("atan"),
53+
s[Atan2]("atan2"),
54+
s[Sinh]("sinh"),
55+
s[Cosh]("cosh"),
56+
s[Tanh]("tanh"),
57+
s[Asinh]("asinh"),
58+
s[Acosh]("acosh"),
59+
s[Atanh]("atanh"),
60+
s[Log]("ln"),
61+
s[Log10]("log10"),
4262
s[And]("and"),
4363
s[Or]("or"),
4464
s[Not]("not"),
@@ -77,7 +97,8 @@ class FunctionMappings {
7797
s[Min]("min"),
7898
s[Max]("max"),
7999
s[First]("any_value"),
80-
s[HyperLogLogPlusPlus]("approx_count_distinct")
100+
s[HyperLogLogPlusPlus]("approx_count_distinct"),
101+
s[StddevSamp]("std_dev")
81102
)
82103

83104
val WINDOW_SIGS: Seq[Sig] = Seq(
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package io.substrait.spark
2+
3+
import org.apache.spark.SparkFunSuite
4+
import org.apache.spark.sql.test.SharedSparkSession
5+
6+
class NumericSuite extends SparkFunSuite with SharedSparkSession with SubstraitPlanTestBase {
7+
8+
override def beforeAll(): Unit = {
9+
super.beforeAll()
10+
sparkContext.setLogLevel("WARN")
11+
}
12+
13+
test("basic") {
14+
assertSqlSubstraitRelRoundTrip(
15+
"select sqrt(abs(num)), mod(num, 2) from (values (-5), (7.4)) as table(num)"
16+
)
17+
}
18+
19+
test("exponentials") {
20+
assertSqlSubstraitRelRoundTrip(
21+
"select power(num, 3), exp(num), ln(num), log10(num) from (values (5), (17)) as table(num)"
22+
)
23+
}
24+
25+
test("trig") {
26+
assertSqlSubstraitRelRoundTrip(
27+
"select sin(num), cos(num), tan(num) from (values (30), (90)) as table(num)"
28+
)
29+
assertSqlSubstraitRelRoundTrip(
30+
"select asin(num), acos(num), atan(num) from (values (0.5), (-0.5)) as table(num)"
31+
)
32+
assertSqlSubstraitRelRoundTrip(
33+
"select sinh(num), cosh(num), tanh(num) from (values (30), (90)) as table(num)"
34+
)
35+
assertSqlSubstraitRelRoundTrip(
36+
"select asinh(num), acosh(num), atanh(num) from (values (0.5), (-0.5)) as table(num)"
37+
)
38+
}
39+
40+
}

spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,21 +32,18 @@ class TPCDSPlan extends TPCDSBase with SubstraitPlanTestBase {
3232
}
3333

3434
// spotless:off
35-
val successfulSQL: Set[String] = Set("q1", "q3", "q4", "q5", "q7", "q8",
36-
"q11", "q12", "q13", "q14a", "q14b", "q15", "q16", "q18", "q19",
37-
"q20", "q21", "q22", "q23a", "q23b", "q24a", "q24b", "q25", "q26", "q27", "q28", "q29",
38-
"q30", "q31", "q32", "q33", "q36", "q37", "q38",
39-
"q40", "q41", "q42", "q43", "q44", "q46", "q48", "q49",
40-
"q50", "q52", "q54", "q55", "q56", "q58", "q59",
41-
"q60", "q61", "q62", "q65", "q66", "q67", "q68", "q69",
42-
"q70", "q71", "q73", "q76", "q77", "q79",
43-
"q80", "q81", "q82", "q85", "q86", "q87", "q88",
44-
"q90", "q91", "q92", "q93", "q94", "q95", "q96", "q97", "q98", "q99")
35+
val failingSQL: Set[String] = Set(
36+
"q2", // because round() isn't defined in substrait to work with Decimal. https://github.com/substrait-io/substrait/pull/713
37+
"q9", // requires implementation of named_struct()
38+
"q10", "q35", "q45", // Unsupported join type ExistenceJoin (this is an internal spark type)
39+
"q51", "q83", "q84", // TBD
40+
"q72" //requires implementation of date_add()
41+
)
4542
// spotless:on
4643

4744
tpcdsQueries.foreach {
4845
q =>
49-
if (runAllQueriesIncludeFailed || successfulSQL.contains(q)) {
46+
if (runAllQueriesIncludeFailed || !failingSQL.contains(q)) {
5047
test(s"check simplified (tpcds-v1.4/$q)") {
5148
testQuery("tpcds", q)
5249
}

0 commit comments

Comments
 (0)