Skip to content

Commit 9022ac3

Browse files
committed
feat: Support PERCENTILE_CONT planning
1 parent bba28d6 commit 9022ac3

File tree

13 files changed

+298
-12
lines changed

13 files changed

+298
-12
lines changed

Cargo.lock

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion-cli/Cargo.lock

+3-3
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/common/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,4 @@ cranelift-module = { version = "0.82.0", optional = true }
4444
ordered-float = "2.10"
4545
parquet = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "8fd2aa80114d5c0d4e6a0c370729507a4424e7b3", features = ["arrow"], optional = true }
4646
pyo3 = { version = "0.16", optional = true }
47-
sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "6a54d27d3b75a04b9f9cbe309a83078aa54b32fd" }
47+
sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "f1a97af2f22d9bfe057c77d5c9673038d8cf295b" }

datafusion/core/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ pin-project-lite= "^0.2.7"
7979
pyo3 = { version = "0.16", optional = true }
8080
rand = "0.8"
8181
smallvec = { version = "1.6", features = ["union"] }
82-
sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "6a54d27d3b75a04b9f9cbe309a83078aa54b32fd" }
82+
sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "f1a97af2f22d9bfe057c77d5c9673038d8cf295b" }
8383
tempfile = "3"
8484
tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "fs", "parking_lot"] }
8585
tokio-stream = "0.1"

datafusion/core/src/physical_plan/aggregates.rs

+8
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,14 @@ pub fn create_aggregate_expr(
239239
.to_string(),
240240
));
241241
}
242+
(AggregateFunction::PercentileCont, _) => {
243+
Arc::new(expressions::PercentileCont::new(
244+
// The percentile expr travels inside coerced_phy_exprs: [percentile, value, asc, nulls_first]
245+
name,
246+
coerced_phy_exprs,
247+
return_type,
248+
)?)
249+
}
242250
(AggregateFunction::ApproxMedian, false) => {
243251
Arc::new(expressions::ApproxMedian::new(
244252
coerced_phy_exprs[0].clone(),

datafusion/core/src/sql/planner.rs

+46-5
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ use datafusion_expr::expr::GroupingSet;
5656
use sqlparser::ast::{
5757
ArrayAgg, BinaryOperator, DataType as SQLDataType, DateTimeField, Expr as SQLExpr,
5858
Fetch, FunctionArg, FunctionArgExpr, Ident, Join, JoinConstraint, JoinOperator,
59-
ObjectName, Offset as SQLOffset, Query, Select, SelectItem, SetExpr, SetOperator,
59+
ObjectName, Offset as SQLOffset, PercentileCont, Query, Select, SelectItem, SetExpr, SetOperator,
6060
ShowStatementFilter, TableFactor, TableWithJoins, TrimWhereField, UnaryOperator,
6161
Value, Values as SQLValues,
6262
};
@@ -1440,22 +1440,22 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
14401440

14411441
let order_by_rex = order_by
14421442
.into_iter()
1443-
.map(|e| self.order_by_to_sort_expr(e, plan.schema()))
1443+
.map(|e| self.order_by_to_sort_expr(e, plan.schema(), true))
14441444
.collect::<Result<Vec<_>>>()?;
14451445

14461446
LogicalPlanBuilder::from(plan).sort(order_by_rex)?.build()
14471447
}
14481448

14491449
/// Convert a SQL OrderByExpr to Expr::Sort; a numeric literal is parsed as a field index only when `parse_indexes` is true
1450-
fn order_by_to_sort_expr(&self, e: OrderByExpr, schema: &DFSchema) -> Result<Expr> {
1450+
fn order_by_to_sort_expr(&self, e: OrderByExpr, schema: &DFSchema, parse_indexes: bool) -> Result<Expr> {
14511451
let OrderByExpr {
14521452
asc,
14531453
expr,
14541454
nulls_first,
14551455
} = e;
14561456

14571457
let expr = match expr {
1458-
SQLExpr::Value(Value::Number(v, _)) => {
1458+
SQLExpr::Value(Value::Number(v, _)) if parse_indexes => {
14591459
let field_index = v
14601460
.parse::<usize>()
14611461
.map_err(|err| DataFusionError::Plan(err.to_string()))?;
@@ -2313,7 +2313,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
23132313
let order_by = window
23142314
.order_by
23152315
.into_iter()
2316-
.map(|e| self.order_by_to_sort_expr(e, schema))
2316+
.map(|e| self.order_by_to_sort_expr(e, schema, true))
23172317
.collect::<Result<Vec<_>>>()?;
23182318
let window_frame = window
23192319
.window_frame
@@ -2441,6 +2441,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
24412441

24422442
SQLExpr::ArrayAgg(array_agg) => self.parse_array_agg(array_agg, schema),
24432443

2444+
SQLExpr::PercentileCont(percentile_cont) => self.parse_percentile_cont(percentile_cont, schema),
2445+
24442446
_ => Err(DataFusionError::NotImplemented(format!(
24452447
"Unsupported ast node {:?} in sqltorel",
24462448
sql
@@ -2494,6 +2496,36 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
24942496
})
24952497
}
24962498

2499+
fn parse_percentile_cont(
2500+
&self,
2501+
percentile_cont: PercentileCont,
2502+
input_schema: &DFSchema,
2503+
) -> Result<Expr> {
2504+
let PercentileCont {
2505+
expr,
2506+
within_group,
2507+
} = percentile_cont;
2508+
2509+
// `PERCENTILE_CONT(p) WITHIN GROUP (ORDER BY x)` is planned as a plain aggregate call with args [p, x, asc, nulls_first].
2510+
let expr = self.sql_expr_to_logical_expr(*expr, input_schema)?;
2511+
let (order_by_expr, asc, nulls_first) = match self.order_by_to_sort_expr(*within_group, input_schema, false)? {
2512+
Expr::Sort { expr, asc, nulls_first } => (expr, asc, nulls_first),
2513+
_ => return Err(DataFusionError::Internal("PercentileCont expected Sort expression in ORDER BY".to_string())),
2514+
};
2515+
let asc_expr = Expr::Literal(ScalarValue::Boolean(Some(asc)));
2516+
let nulls_first_expr = Expr::Literal(ScalarValue::Boolean(Some(nulls_first)));
2517+
2518+
let args = vec![expr, *order_by_expr, asc_expr, nulls_first_expr];
2519+
// Plan the call as the built-in PercentileCont aggregate
2520+
let fun = aggregates::AggregateFunction::PercentileCont;
2521+
2522+
Ok(Expr::AggregateFunction {
2523+
fun,
2524+
distinct: false,
2525+
args,
2526+
})
2527+
}
2528+
24972529
fn function_args_to_expr(
24982530
&self,
24992531
args: Vec<FunctionArg>,
@@ -4133,6 +4165,15 @@ mod tests {
41334165
quick_test(sql, expected);
41344166
}
41354167

4168+
#[test]
4169+
fn select_percentile_cont() {
4170+
let sql = "SELECT percentile_cont(0.5) WITHIN GROUP (ORDER BY age) FROM person";
4171+
let expected = "Projection: #PERCENTILECONT(Float64(0.5),person.age,Boolean(true),Boolean(false))\
4172+
\n Aggregate: groupBy=[[]], aggr=[[PERCENTILECONT(Float64(0.5), #person.age, Boolean(true), Boolean(false))]]\
4173+
\n TableScan: person projection=None";
4174+
quick_test(sql, expected);
4175+
}
4176+
41364177
#[test]
41374178
fn select_scalar_func() {
41384179
let sql = "SELECT sqrt(age) FROM person";

datafusion/expr/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,4 @@ path = "src/lib.rs"
3838
ahash = { version = "0.7", default-features = false }
3939
arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "8fd2aa80114d5c0d4e6a0c370729507a4424e7b3", features = ["prettyprint"] }
4040
datafusion-common = { path = "../common", version = "7.0.0" }
41-
sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "6a54d27d3b75a04b9f9cbe309a83078aa54b32fd" }
41+
sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "f1a97af2f22d9bfe057c77d5c9673038d8cf295b" }

datafusion/expr/src/aggregate_function.rs

+39
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ pub enum AggregateFunction {
8484
ApproxPercentileCont,
8585
/// Approximate continuous percentile function with weight
8686
ApproxPercentileContWithWeight,
87+
/// Continuous percentile function
88+
PercentileCont,
8789
/// ApproxMedian
8890
ApproxMedian,
8991
/// BoolAnd
@@ -124,6 +126,7 @@ impl FromStr for AggregateFunction {
124126
"approx_percentile_cont_with_weight" => {
125127
AggregateFunction::ApproxPercentileContWithWeight
126128
}
129+
"percentile_cont" => AggregateFunction::PercentileCont,
127130
"approx_median" => AggregateFunction::ApproxMedian,
128131
"bool_and" => AggregateFunction::BoolAnd,
129132
"bool_or" => AggregateFunction::BoolOr,
@@ -178,6 +181,7 @@ pub fn return_type(
178181
AggregateFunction::ApproxPercentileContWithWeight => {
179182
Ok(coerced_data_types[0].clone())
180183
}
184+
AggregateFunction::PercentileCont => Ok(coerced_data_types[1].clone()),
181185
AggregateFunction::ApproxMedian => Ok(coerced_data_types[0].clone()),
182186
AggregateFunction::BoolAnd | AggregateFunction::BoolOr => Ok(DataType::Boolean),
183187
}
@@ -324,6 +328,33 @@ pub fn coerce_types(
324328
}
325329
Ok(input_types.to_vec())
326330
}
331+
AggregateFunction::PercentileCont => {
332+
if !matches!(input_types[0], DataType::Float64) {
333+
return Err(DataFusionError::Plan(format!(
334+
"The percentile argument for {:?} must be Float64, not {:?}.",
335+
agg_fun, input_types[0]
336+
)));
337+
}
338+
if !is_approx_percentile_cont_supported_arg_type(&input_types[1]) {
339+
return Err(DataFusionError::Plan(format!(
340+
"The function {:?} does not support inputs of type {:?}.",
341+
agg_fun, input_types[1]
342+
)));
343+
}
344+
if !matches!(input_types[2], DataType::Boolean) {
345+
return Err(DataFusionError::Plan(format!(
346+
"The asc argument for {:?} must be Boolean, not {:?}.",
347+
agg_fun, input_types[2]
348+
)));
349+
}
350+
if !matches!(input_types[3], DataType::Boolean) {
351+
return Err(DataFusionError::Plan(format!(
352+
"The nulls_first argument for {:?} must be Boolean, not {:?}.",
353+
agg_fun, input_types[3]
354+
)));
355+
}
356+
Ok(input_types.to_vec())
357+
}
327358
AggregateFunction::ApproxMedian => {
328359
if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) {
329360
return Err(DataFusionError::Plan(format!(
@@ -395,6 +426,14 @@ pub fn signature(fun: &AggregateFunction) -> Signature {
395426
.collect(),
396427
Volatility::Immutable,
397428
),
429+
AggregateFunction::PercentileCont => Signature::one_of(
430+
// Accept a Float64 percentile paired with any numeric value, followed by the Boolean asc and nulls_first flags
431+
NUMERICS
432+
.iter()
433+
.map(|t| TypeSignature::Exact(vec![DataType::Float64, t.clone(), DataType::Boolean, DataType::Boolean]))
434+
.collect(),
435+
Volatility::Immutable,
436+
),
398437
AggregateFunction::BoolAnd | AggregateFunction::BoolOr => {
399438
Signature::exact(vec![DataType::Boolean], Volatility::Immutable)
400439
}

datafusion/physical-expr/src/expressions/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ mod not;
4949
mod nth_value;
5050
mod nullif;
5151
mod outer_column;
52+
mod percentile_cont;
5253
mod rank;
5354
mod row_number;
5455
mod stats;
@@ -95,6 +96,7 @@ pub use not::{not, NotExpr};
9596
pub use nth_value::NthValue;
9697
pub use nullif::nullif_func;
9798
pub use outer_column::OuterColumn;
99+
pub use percentile_cont::PercentileCont;
98100
pub use rank::{dense_rank, percent_rank, rank};
99101
pub use row_number::RowNumber;
100102
pub use stats::StatsType;

0 commit comments

Comments
 (0)