@@ -18,8 +18,8 @@ package io.substrait.spark.expression
18
18
19
19
import io .substrait .spark .{DefaultExpressionVisitor , HasOutputStack , ToSubstraitType }
20
20
import io .substrait .spark .logical .ToLogicalPlan
21
- import org .apache .spark .sql .catalyst .expressions .{CaseWhen , Cast , Expression , In , Literal , NamedExpression , ScalarSubquery }
22
- import org .apache .spark .sql .types .{ Decimal , NullType }
21
+ import org .apache .spark .sql .catalyst .expressions .{CaseWhen , Cast , Expression , In , Literal , MakeDecimal , NamedExpression , ScalarSubquery }
22
+ import org .apache .spark .sql .types .Decimal
23
23
import org .apache .spark .unsafe .types .UTF8String
24
24
import io .substrait .`type` .{StringTypeVisitor , Type }
25
25
import io .substrait .{expression => exp }
@@ -131,23 +131,32 @@ class ToSparkExpression(
131
131
arg.accept(expr.declaration(), i, this )
132
132
}
133
133
134
- scalarFunctionConverter
135
- .getSparkExpressionFromSubstraitFunc(expr.declaration().key(), expr.outputType())
136
- .flatMap(sig => Option (sig.makeCall(args)))
137
- .getOrElse({
138
- val msg = String .format(
139
- " Unable to convert scalar function %s(%s)." ,
140
- expr.declaration.name,
141
- expr.arguments.asScala
142
- .map {
143
- case ea : exp.EnumArg => ea.value.toString
144
- case e : SExpression => e.getType.accept(new StringTypeVisitor )
145
- case t : Type => t.accept(new StringTypeVisitor )
146
- case a => throw new IllegalStateException (" Unexpected value: " + a)
147
- }
148
- .mkString(" , " )
149
- )
150
- throw new IllegalArgumentException (msg)
151
- })
134
+ expr.declaration.name match {
135
+ case " make_decimal" => expr.outputType match {
136
+ // Need special case handing of this internal function (not nice, I know).
137
+ // Because the precision and scale arguments are extracted from the output type,
138
+ // we can't use the generic scalar function conversion mechanism here.
139
+ case d : Type .Decimal => MakeDecimal (args.head, d.precision, d.scale)
140
+ case _ => throw new IllegalArgumentException (" Output type of MakeDecimal must be a decimal type" )
141
+ }
142
+ case _ => scalarFunctionConverter
143
+ .getSparkExpressionFromSubstraitFunc(expr.declaration().key(), expr.outputType())
144
+ .flatMap(sig => Option (sig.makeCall(args)))
145
+ .getOrElse({
146
+ val msg = String .format(
147
+ " Unable to convert scalar function %s(%s)." ,
148
+ expr.declaration.name,
149
+ expr.arguments.asScala
150
+ .map {
151
+ case ea : exp.EnumArg => ea.value.toString
152
+ case e : SExpression => e.getType.accept(new StringTypeVisitor )
153
+ case t : Type => t.accept(new StringTypeVisitor )
154
+ case a => throw new IllegalStateException (" Unexpected value: " + a)
155
+ }
156
+ .mkString(" , " )
157
+ )
158
+ throw new IllegalArgumentException (msg)
159
+ })
160
+ }
152
161
}
153
162
}
0 commit comments