
Commit 82cb729

feat(spark): add Window support
To support the OVER clause in SQL.

Signed-off-by: Andrew Coleman <[email protected]>
1 parent b8ccd8b commit 82cb729

12 files changed: +358 −23 lines
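
For context, this is the shape of query the new conversion paths target. A minimal, self-contained sketch (the local session and table are illustrative only, not part of this commit):

import org.apache.spark.sql.SparkSession

object OverClauseDemo extends App {
  // Build a tiny local session and table to run an OVER clause against.
  val spark = SparkSession.builder().master("local[*]").appName("over-demo").getOrCreate()
  spark.range(10).toDF("id").createOrReplaceTempView("t")

  // row_number() OVER (...) becomes a WindowExpression in Spark's logical plan;
  // the converters in this commit translate that to and from a Substrait
  // ConsistentPartitionWindow relation.
  spark
    .sql("SELECT id, row_number() OVER (PARTITION BY id % 2 ORDER BY id) AS rn FROM t")
    .show()

  spark.stop()
}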

spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala

Lines changed: 13 additions & 0 deletions
@@ -152,6 +152,19 @@ class RelToVerboseString(addSuffix: Boolean) extends DefaultRelVisitor[String] {
       })
   }

+  override def visit(window: ConsistentPartitionWindow): String = {
+    withBuilder(window, 10)(
+      builder => {
+        builder
+          .append("functions=")
+          .append(window.getWindowFunctions)
+          .append("partitions=")
+          .append(window.getPartitionExpressions)
+          .append("sorts=")
+          .append(window.getSorts)
+      })
+  }
+
   override def visit(localFiles: LocalFiles): String = {
     withBuilder(localFiles, 10)(
       builder => {

spark/src/main/scala/io/substrait/spark/SparkExtension.scala

Lines changed: 5 additions & 1 deletion
@@ -16,7 +16,7 @@
  */
 package io.substrait.spark

-import io.substrait.spark.expression.ToAggregateFunction
+import io.substrait.spark.expression.{ToAggregateFunction, ToWindowFunction}

 import io.substrait.extension.SimpleExtension

@@ -43,4 +43,8 @@ object SparkExtension {

   val toAggregateFunction: ToAggregateFunction = ToAggregateFunction(
     JavaConverters.asScalaBuffer(EXTENSION_COLLECTION.aggregateFunctions()))
+
+  val toWindowFunction: ToWindowFunction = ToWindowFunction(
+    JavaConverters.asScalaBuffer(EXTENSION_COLLECTION.windowFunctions())
+  )
 }

spark/src/main/scala/io/substrait/spark/expression/FunctionConverter.scala

Lines changed: 12 additions & 9 deletions
@@ -21,7 +21,7 @@ import io.substrait.spark.ToSubstraitType
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.SQLConfHelper
 import org.apache.spark.sql.catalyst.analysis.{AnsiTypeCoercion, TypeCoercion}
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Expression, WindowExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.types.DataType

@@ -238,7 +238,6 @@ class FunctionFinder[F <: SimpleExtension.Function, T](
     val parent: FunctionConverter[F, T]) {

   def attemptMatch(expression: Expression, operands: Seq[SExpression]): Option[T] = {
-
     val opTypes = operands.map(_.getType)
     val outputType = ToSubstraitType.apply(expression.dataType, expression.nullable)
     val opTypesStr = opTypes.map(t => t.accept(ToTypeString.INSTANCE))
@@ -250,17 +249,23 @@ class FunctionFinder[F <: SimpleExtension.Function, T](
       .map(name + ":" + _)
       .find(k => directMap.contains(k))

-    if (directMatchKey.isDefined) {
+    if (operands.isEmpty) {
+      val variant = directMap(name + ":")
+      variant.validateOutputType(JavaConverters.bufferAsJavaList(operands.toBuffer), outputType)
+      Option(parent.generateBinding(expression, variant, operands, outputType))
+    } else if (directMatchKey.isDefined) {
       val variant = directMap(directMatchKey.get)
       variant.validateOutputType(JavaConverters.bufferAsJavaList(operands.toBuffer), outputType)
       val funcArgs: Seq[FunctionArg] = operands
       Option(parent.generateBinding(expression, variant, funcArgs, outputType))
     } else if (singularInputType.isDefined) {
-      val types = expression match {
-        case agg: AggregateExpression => agg.aggregateFunction.children.map(_.dataType)
-        case other => other.children.map(_.dataType)
+      val children = expression match {
+        case agg: AggregateExpression => agg.aggregateFunction.children
+        case win: WindowExpression => win.windowFunction.children
+        case other => other.children
       }
-      val nullable = expression.children.exists(e => e.nullable)
+      val types = children.map(_.dataType)
+      val nullable = children.exists(e => e.nullable)
       FunctionFinder
         .leastRestrictive(types)
         .flatMap(
@@ -298,6 +303,4 @@ class FunctionFinder[F <: SimpleExtension.Function, T](
         }
       })
   }
-
-  def allowedArgCount(count: Int): Boolean = true
 }
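
A note on the new operands.isEmpty branch: most window functions (row_number, rank, dense_rank, ...) take no arguments, so they produce an empty operand-type list and the direct-match key degenerates to the bare function name plus a colon. A sketch of the assumed key format (the helper below is hypothetical, for illustration only):

// Direct-match keys join the function name with the operand type strings;
// with no operands nothing follows the colon, hence directMap(name + ":").
def directKey(name: String, opTypeStrings: Seq[String]): String =
  name + ":" + opTypeStrings.mkString("_")

assert(directKey("row_number", Nil) == "row_number:")
assert(directKey("add", Seq("i32", "i32")) == "add:i32_i32")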

spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala

Lines changed: 13 additions & 0 deletions
@@ -80,9 +80,22 @@ class FunctionMappings {
     s[HyperLogLogPlusPlus]("approx_count_distinct")
   )

+  val WINDOW_SIGS: Seq[Sig] = Seq(
+    s[RowNumber]("row_number"),
+    s[Rank]("rank"),
+    s[DenseRank]("dense_rank"),
+    s[PercentRank]("percent_rank"),
+    s[CumeDist]("cume_dist"),
+    s[NTile]("ntile"),
+    s[Lead]("lead"),
+    s[Lag]("lag"),
+    s[NthValue]("nth_value")
+  )
+
   lazy val scalar_functions_map: Map[Class[_], Sig] = SCALAR_SIGS.map(s => (s.expClass, s)).toMap
   lazy val aggregate_functions_map: Map[Class[_], Sig] =
     AGGREGATE_SIGS.map(s => (s.expClass, s)).toMap
+  lazy val window_functions_map: Map[Class[_], Sig] = WINDOW_SIGS.map(s => (s.expClass, s)).toMap
 }

 object FunctionMappings extends FunctionMappings
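
For illustration, these maps are keyed by the Spark expression class, so resolution is a plain lookup. A hypothetical REPL session (assuming Sig exposes the registered name, as the s[...](...) calls suggest):

import org.apache.spark.sql.catalyst.expressions.RowNumber

// Spark's RowNumber expression class resolves to the Substrait function
// name "row_number" registered in WINDOW_SIGS above.
val sig = FunctionMappings.window_functions_map(classOf[RowNumber])
// sig.name == "row_number"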

spark/src/main/scala/io/substrait/spark/expression/ToAggregateFunction.scala

Lines changed: 0 additions & 1 deletion
@@ -53,7 +53,6 @@ abstract class ToAggregateFunction(functions: Seq[SimpleExtension.AggregateFunct
       expression: AggregateExpression,
       operands: Seq[SExpression]): Option[AggregateFunctionInvocation] = {
     Option(signatures.get(expression.aggregateFunction.getClass))
-      .filter(m => m.allowedArgCount(2))
       .flatMap(m => m.attemptMatch(expression, operands))
   }

spark/src/main/scala/io/substrait/spark/expression/ToScalarFunction.scala

Lines changed: 0 additions & 1 deletion
@@ -42,7 +42,6 @@ abstract class ToScalarFunction(functions: Seq[SimpleExtension.ScalarFunctionVar

   def convert(expression: Expression, operands: Seq[SExpression]): Option[SExpression] = {
     Option(signatures.get(expression.getClass))
-      .filter(m => m.allowedArgCount(2))
       .flatMap(m => m.attemptMatch(expression, operands))
   }
 }

spark/src/main/scala/io/substrait/spark/expression/ToWindowFunction.scala

Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.substrait.spark.expression
+
+import io.substrait.spark.expression.ToWindowFunction.fromSpark
+
+import org.apache.spark.sql.catalyst.expressions.{CurrentRow, Expression, FrameType, Literal, OffsetWindowFunction, RangeFrame, RowFrame, SpecifiedWindowFrame, UnboundedFollowing, UnboundedPreceding, UnspecifiedFrame, WindowExpression, WindowFrame, WindowSpecDefinition}
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.types.{IntegerType, LongType}
+
+import io.substrait.`type`.Type
+import io.substrait.expression.{Expression => SExpression, ExpressionCreator, FunctionArg, WindowBound}
+import io.substrait.expression.Expression.WindowBoundsType
+import io.substrait.expression.WindowBound.{CURRENT_ROW, UNBOUNDED, WindowBoundVisitor}
+import io.substrait.extension.SimpleExtension
+import io.substrait.relation.ConsistentPartitionWindow.WindowRelFunctionInvocation
+
+import scala.collection.JavaConverters
+
+abstract class ToWindowFunction(functions: Seq[SimpleExtension.WindowFunctionVariant])
+  extends FunctionConverter[SimpleExtension.WindowFunctionVariant, WindowRelFunctionInvocation](
+    functions) {
+
+  override def generateBinding(
+      sparkExp: Expression,
+      function: SimpleExtension.WindowFunctionVariant,
+      arguments: Seq[FunctionArg],
+      outputType: Type): WindowRelFunctionInvocation = {
+
+    val (frameType, lower, upper) = sparkExp match {
+      case WindowExpression(_: OffsetWindowFunction, _) =>
+        (WindowBoundsType.ROWS, UNBOUNDED, CURRENT_ROW)
+      case WindowExpression(
+            _,
+            WindowSpecDefinition(_, _, SpecifiedWindowFrame(frameType, lower, upper))) =>
+        (fromSpark(frameType), fromSpark(lower), fromSpark(upper))
+      case WindowExpression(_, WindowSpecDefinition(_, orderSpec, UnspecifiedFrame)) =>
+        if (orderSpec.isEmpty) {
+          (WindowBoundsType.ROWS, UNBOUNDED, UNBOUNDED)
+        } else {
+          (WindowBoundsType.RANGE, UNBOUNDED, CURRENT_ROW)
+        }
+
+      case _ => throw new UnsupportedOperationException(s"Unsupported window expression: $sparkExp")
+    }
+
+    ExpressionCreator.windowRelFunction(
+      function,
+      outputType,
+      SExpression.AggregationPhase.INITIAL_TO_RESULT, // use defaults...
+      SExpression.AggregationInvocation.ALL, // Spark doesn't define these
+      frameType,
+      lower,
+      upper,
+      JavaConverters.asJavaIterable(arguments)
+    )
+  }
+
+  def convert(
+      expression: WindowExpression,
+      operands: Seq[SExpression]): Option[WindowRelFunctionInvocation] = {
+    val cls = expression.windowFunction match {
+      case agg: AggregateExpression => agg.aggregateFunction.getClass
+      case other => other.getClass
+    }
+
+    Option(signatures.get(cls))
+      .flatMap(m => m.attemptMatch(expression, operands))
+  }
+
+  def apply(
+      expression: WindowExpression,
+      operands: Seq[SExpression]): WindowRelFunctionInvocation = {
+    convert(expression, operands).getOrElse(throw new UnsupportedOperationException(
+      s"Unable to find binding for call ${expression.windowFunction} -- $operands -- $expression"))
+  }
+}
+
+object ToWindowFunction {
+  def fromSpark(frameType: FrameType): WindowBoundsType = frameType match {
+    case RowFrame => WindowBoundsType.ROWS
+    case RangeFrame => WindowBoundsType.RANGE
+    case other => throw new UnsupportedOperationException(s"Unsupported bounds type: $other.")
+  }
+
+  def fromSpark(bound: Expression): WindowBound = bound match {
+    case UnboundedPreceding => WindowBound.UNBOUNDED
+    case UnboundedFollowing => WindowBound.UNBOUNDED
+    case CurrentRow => WindowBound.CURRENT_ROW
+    case e: Literal =>
+      e.dataType match {
+        case IntegerType | LongType =>
+          val offset = e.eval().asInstanceOf[Int]
+          if (offset < 0) WindowBound.Preceding.of(-offset)
+          else if (offset == 0) WindowBound.CURRENT_ROW
+          else WindowBound.Following.of(offset)
+      }
+    case _ => throw new UnsupportedOperationException(s"Unexpected bound: $bound")
+  }
+
+  def toSparkFrame(
+      boundsType: WindowBoundsType,
+      lowerBound: WindowBound,
+      upperBound: WindowBound): WindowFrame = {
+    val frameType = boundsType match {
+      case WindowBoundsType.ROWS => RowFrame
+      case WindowBoundsType.RANGE => RangeFrame
+      case WindowBoundsType.UNSPECIFIED => return UnspecifiedFrame
+    }
+    SpecifiedWindowFrame(
+      frameType,
+      toSparkBound(lowerBound, isLower = true),
+      toSparkBound(upperBound, isLower = false))
+  }
+
+  private def toSparkBound(bound: WindowBound, isLower: Boolean): Expression = {
+    bound.accept(new WindowBoundVisitor[Expression, Exception] {
+
+      override def visit(preceding: WindowBound.Preceding): Expression =
+        Literal(-preceding.offset().intValue())
+
+      override def visit(following: WindowBound.Following): Expression =
+        Literal(following.offset().intValue())
+
+      override def visit(currentRow: WindowBound.CurrentRow): Expression = CurrentRow
+
+      override def visit(unbounded: WindowBound.Unbounded): Expression =
+        if (isLower) UnboundedPreceding else UnboundedFollowing
+    })
+  }
+
+  def apply(functions: Seq[SimpleExtension.WindowFunctionVariant]): ToWindowFunction = {
+    new ToWindowFunction(functions) {
+      override def getSigs: Seq[Sig] =
+        FunctionMappings.WINDOW_SIGS ++ FunctionMappings.AGGREGATE_SIGS
+    }
+  }
+
+}
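
The sign convention in fromSpark(bound) deserves a worked example: Spark encodes frame offsets as signed literals, negative for PRECEDING, so ROWS BETWEEN 2 PRECEDING AND 3 FOLLOWING arrives (after constant folding) as Literal(-2) and Literal(3). A hypothetical REPL session against the code above:

import org.apache.spark.sql.catalyst.expressions.Literal

ToWindowFunction.fromSpark(Literal(-2)) // WindowBound.Preceding.of(2)
ToWindowFunction.fromSpark(Literal(0))  // WindowBound.CURRENT_ROW
ToWindowFunction.fromSpark(Literal(3))  // WindowBound.Following.of(3)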

spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala

Lines changed: 52 additions & 0 deletions
@@ -119,6 +119,56 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]
     }
   }

+  override def visit(window: relation.ConsistentPartitionWindow): LogicalPlan = {
+    val child = window.getInput.accept(this)
+    withChild(child) {
+      val partitions = window.getPartitionExpressions.asScala
+        .map(expr => expr.accept(expressionConverter))
+      val sortOrders = window.getSorts.asScala.map(toSortOrder)
+      val windowExpressions = window.getWindowFunctions.asScala
+        .map(
+          func => {
+            val arguments = func.arguments().asScala.zipWithIndex.map {
+              case (arg, i) =>
+                arg.accept(func.declaration(), i, expressionConverter)
+            }
+            val windowFunction = SparkExtension.toWindowFunction
+              .getSparkExpressionFromSubstraitFunc(func.declaration.key, func.outputType)
+              .map(sig => sig.makeCall(arguments))
+              .map {
+                case win: WindowFunction => win
+                case agg: AggregateFunction =>
+                  AggregateExpression(
+                    agg,
+                    ToAggregateFunction.toSpark(func.aggregationPhase()),
+                    ToAggregateFunction.toSpark(func.invocation()),
+                    None)
+              }
+              .getOrElse({
+                val msg = String.format(
+                  "Unable to convert Window function %s(%s).",
+                  func.declaration.name,
+                  func.arguments.asScala
+                    .map {
+                      case ea: exp.EnumArg => ea.value.toString
+                      case e: SExpression => e.getType.accept(new StringTypeVisitor)
+                      case t: Type => t.accept(new StringTypeVisitor)
+                      case a => throw new IllegalStateException("Unexpected value: " + a)
+                    }
+                    .mkString(", ")
+                )
+                throw new IllegalArgumentException(msg)
+              })
+            val frame =
+              ToWindowFunction.toSparkFrame(func.boundsType(), func.lowerBound(), func.upperBound())
+            val spec = WindowSpecDefinition(partitions, sortOrders, frame)
+            WindowExpression(windowFunction, spec)
+          })
+        .map(toNamedExpression)
+      Window(windowExpressions, partitions, sortOrders, child)
+    }
+  }
+
   override def visit(join: relation.Join): LogicalPlan = {
     val left = join.getLeft.accept(this)
     val right = join.getRight.accept(this)

@@ -162,6 +212,7 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]
     }
     SortOrder(expression, direction, nullOrdering, Seq.empty)
   }
+
   override def visit(fetch: relation.Fetch): LogicalPlan = {
     val child = fetch.getInput.accept(this)
     val limit = fetch.getCount.getAsLong.intValue()

@@ -180,6 +231,7 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]
       Offset(toLiteral(offset), child)
     }
   }
+
   override def visit(sort: relation.Sort): LogicalPlan = {
     val child = sort.getInput.accept(this)
     withChild(child) {
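
For orientation, the Window(windowExpressions, partitions, sortOrders, child) node assembled here is the same logical operator Spark builds for a DataFrame window query. A minimal sketch (a DataFrame df with dept and salary columns is assumed):

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, row_number}

// Equivalent plan built from the DataFrame API side:
val ranked = df.withColumn(
  "rn",
  row_number().over(Window.partitionBy(col("dept")).orderBy(col("salary"))))
// ranked.queryExecution.logical contains a Window node wrapping one
// WindowExpression, mirroring what this visitor reconstructs from a
// Substrait ConsistentPartitionWindow.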
