Skip to content

Commit 1291abc

Browse files
committed
feat: Support PERCENTILE_CONT planning
1 parent dcf3e4a commit 1291abc

25 files changed

+557
-91
lines changed

Cargo.lock

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion-cli/Cargo.lock

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/common/Cargo.toml

+1-1
Original file line number | Diff line number | Diff line change
@@ -44,4 +44,4 @@ cranelift-module = { version = "0.82.0", optional = true }
4444
ordered-float = "2.10"
4545
parquet = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "a03d4eef5640e05dddf99fc2357ad6d58b5337cb", features = ["arrow"], optional = true }
4646
pyo3 = { version = "0.16", optional = true }
47-
sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "6a54d27d3b75a04b9f9cbe309a83078aa54b32fd" }
47+
sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "5fe1b77d1a91b80529a0b7af0b89411d3cba5137" }

datafusion/core/Cargo.toml

+1-1
Original file line number | Diff line number | Diff line change
@@ -79,7 +79,7 @@ pin-project-lite= "^0.2.7"
7979
pyo3 = { version = "0.16", optional = true }
8080
rand = "0.8"
8181
smallvec = { version = "1.6", features = ["union"] }
82-
sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "6a54d27d3b75a04b9f9cbe309a83078aa54b32fd" }
82+
sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "5fe1b77d1a91b80529a0b7af0b89411d3cba5137" }
8383
tempfile = "3"
8484
tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "fs", "parking_lot"] }
8585
tokio-stream = "0.1"

datafusion/core/src/logical_plan/expr_rewriter.rs

+13-5
Original file line number | Diff line number | Diff line change
@@ -252,11 +252,19 @@ impl ExprRewritable for Expr {
252252
args,
253253
fun,
254254
distinct,
255-
} => Expr::AggregateFunction {
256-
args: rewrite_vec(args, rewriter)?,
257-
fun,
258-
distinct,
259-
},
255+
within_group,
256+
} => {
257+
let within_group = match within_group {
258+
Some(within_group) => Some(rewrite_vec(within_group, rewriter)?),
259+
None => None,
260+
};
261+
Expr::AggregateFunction {
262+
args: rewrite_vec(args, rewriter)?,
263+
fun,
264+
distinct,
265+
within_group,
266+
}
267+
}
260268
Expr::GroupingSet(grouping_set) => match grouping_set {
261269
GroupingSet::Rollup(exprs) => {
262270
Expr::GroupingSet(GroupingSet::Rollup(rewrite_vec(exprs, rewriter)?))

datafusion/core/src/logical_plan/expr_schema.rs

+13-2
Original file line number | Diff line number | Diff line change
@@ -92,12 +92,23 @@ impl ExprSchemable for Expr {
9292
.collect::<Result<Vec<_>>>()?;
9393
window_function::return_type(fun, &data_types)
9494
}
95-
Expr::AggregateFunction { fun, args, .. } => {
95+
Expr::AggregateFunction {
96+
fun,
97+
args,
98+
within_group,
99+
..
100+
} => {
96101
let data_types = args
97102
.iter()
98103
.map(|e| e.get_type(schema))
99104
.collect::<Result<Vec<_>>>()?;
100-
aggregate_function::return_type(fun, &data_types)
105+
let within_group = within_group
106+
.as_ref()
107+
.unwrap_or(&vec![])
108+
.iter()
109+
.map(|e| e.get_type(schema))
110+
.collect::<Result<Vec<_>>>()?;
111+
aggregate_function::return_type(fun, &data_types, &within_group)
101112
}
102113
Expr::AggregateUDF { fun, args, .. } => {
103114
let data_types = args

datafusion/core/src/logical_plan/expr_visitor.rs

+15-1
Original file line number | Diff line number | Diff line change
@@ -179,10 +179,24 @@ impl ExprVisitable for Expr {
179179
Expr::ScalarFunction { args, .. }
180180
| Expr::ScalarUDF { args, .. }
181181
| Expr::TableUDF { args, .. }
182-
| Expr::AggregateFunction { args, .. }
183182
| Expr::AggregateUDF { args, .. } => args
184183
.iter()
185184
.try_fold(visitor, |visitor, arg| arg.accept(visitor)),
185+
Expr::AggregateFunction {
186+
args, within_group, ..
187+
} => {
188+
let visitor = args
189+
.iter()
190+
.try_fold(visitor, |visitor, arg| arg.accept(visitor))?;
191+
let visitor = if let Some(within_group) = within_group.as_ref() {
192+
within_group
193+
.iter()
194+
.try_fold(visitor, |visitor, arg| arg.accept(visitor))?
195+
} else {
196+
visitor
197+
};
198+
Ok(visitor)
199+
}
186200
Expr::WindowFunction {
187201
args,
188202
partition_by,

datafusion/core/src/optimizer/single_distinct_to_groupby.rs

+12-2
Original file line number | Diff line number | Diff line change
@@ -80,6 +80,7 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
8080
fun: fun.clone(),
8181
args: vec![col(SINGLE_DISTINCT_ALIAS)],
8282
distinct: false,
83+
within_group: None,
8384
}
8485
}
8586
_ => agg_expr.clone(),
@@ -168,13 +169,21 @@ fn is_single_distinct_agg(plan: &LogicalPlan) -> bool {
168169
.iter()
169170
.filter(|expr| {
170171
let mut is_distinct = false;
171-
if let Expr::AggregateFunction { distinct, args, .. } = expr {
172+
let mut is_within_group = false;
173+
if let Expr::AggregateFunction {
174+
distinct,
175+
args,
176+
within_group,
177+
..
178+
} = expr
179+
{
172180
is_distinct = *distinct;
181+
is_within_group = within_group.is_some();
173182
args.iter().for_each(|expr| {
174183
fields_set.insert(expr.name(input.schema()).unwrap());
175184
})
176185
}
177-
is_distinct
186+
is_distinct && !is_within_group
178187
})
179188
.count()
180189
== aggr_expr.len()
@@ -314,6 +323,7 @@ mod tests {
314323
fun: aggregates::AggregateFunction::Max,
315324
distinct: true,
316325
args: vec![col("b")],
326+
within_group: None,
317327
},
318328
],
319329
)?

datafusion/core/src/optimizer/utils.rs

+26-6
Original file line number | Diff line number | Diff line change
@@ -339,8 +339,14 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result<Vec<Expr>> {
339339
Expr::ScalarFunction { args, .. }
340340
| Expr::ScalarUDF { args, .. }
341341
| Expr::TableUDF { args, .. }
342-
| Expr::AggregateFunction { args, .. }
343342
| Expr::AggregateUDF { args, .. } => Ok(args.clone()),
343+
Expr::AggregateFunction {
344+
args, within_group, ..
345+
} => Ok(args
346+
.iter()
347+
.chain(within_group.as_ref().unwrap_or(&vec![]))
348+
.cloned()
349+
.collect()),
344350
Expr::GroupingSet(grouping_set) => match grouping_set {
345351
GroupingSet::Rollup(exprs) => Ok(exprs.clone()),
346352
GroupingSet::Cube(exprs) => Ok(exprs.clone()),
@@ -517,11 +523,25 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result<Expr> {
517523
})
518524
}
519525
}
520-
Expr::AggregateFunction { fun, distinct, .. } => Ok(Expr::AggregateFunction {
521-
fun: fun.clone(),
522-
args: expressions.to_vec(),
523-
distinct: *distinct,
524-
}),
526+
Expr::AggregateFunction {
527+
fun,
528+
distinct,
529+
args,
530+
..
531+
} => {
532+
let args_limit = args.len();
533+
let within_group = if expressions.len() > args_limit {
534+
Some(expressions[args_limit..].to_vec())
535+
} else {
536+
None
537+
};
538+
Ok(Expr::AggregateFunction {
539+
fun: fun.clone(),
540+
args: expressions[..args_limit].to_vec(),
541+
distinct: *distinct,
542+
within_group,
543+
})
544+
}
525545
Expr::AggregateUDF { fun, .. } => Ok(Expr::AggregateUDF {
526546
fun: fun.clone(),
527547
args: expressions.to_vec(),

0 commit comments

Comments
 (0)