diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 5ae4679af..c4d95238d 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -348,6 +348,8 @@ pub enum Expr { ListAgg(ListAgg), /// The `ARRAY_AGG` function `SELECT ARRAY_AGG(... ORDER BY ...)` ArrayAgg(ArrayAgg), + /// The `WITHIN GROUP` expr `... WITHIN GROUP (ORDER BY ...)` + WithinGroup(WithinGroup), /// The `GROUPING SETS` expr. GroupingSets(Vec>), /// The `CUBE` expr. @@ -549,6 +551,7 @@ impl fmt::Display for Expr { Expr::ArraySubquery(s) => write!(f, "ARRAY({})", s), Expr::ListAgg(listagg) => write!(f, "{}", listagg), Expr::ArrayAgg(arrayagg) => write!(f, "{}", arrayagg), + Expr::WithinGroup(withingroup) => write!(f, "{}", withingroup), Expr::GroupingSets(sets) => { write!(f, "GROUPING SETS (")?; let mut sep = ""; @@ -2420,7 +2423,6 @@ pub struct ListAgg { pub expr: Box, pub separator: Option>, pub on_overflow: Option, - pub within_group: Vec, } impl fmt::Display for ListAgg { @@ -2438,13 +2440,6 @@ impl fmt::Display for ListAgg { write!(f, "{}", on_overflow)?; } write!(f, ")")?; - if !self.within_group.is_empty() { - write!( - f, - " WITHIN GROUP (ORDER BY {})", - display_comma_separated(&self.within_group) - )?; - } Ok(()) } } @@ -2494,7 +2489,6 @@ pub struct ArrayAgg { pub expr: Box, pub order_by: Option>, pub limit: Option>, - pub within_group: bool, // order by is used inside a within group or not } impl fmt::Display for ArrayAgg { @@ -2505,20 +2499,33 @@ impl fmt::Display for ArrayAgg { if self.distinct { "DISTINCT " } else { "" }, self.expr )?; - if !self.within_group { - if let Some(order_by) = &self.order_by { - write!(f, " ORDER BY {}", order_by)?; - } - if let Some(limit) = &self.limit { - write!(f, " LIMIT {}", limit)?; - } + if let Some(order_by) = &self.order_by { + write!(f, " ORDER BY {}", order_by)?; } - write!(f, ")")?; - if self.within_group { - if let Some(order_by) = &self.order_by { - write!(f, " WITHIN GROUP (ORDER BY {})", order_by)?; - } + if let Some(limit) = &self.limit { + write!(f, " LIMIT {}", limit)?; } + write!(f, ")")?; + Ok(()) + } +} + +/// A `WITHIN GROUP` invocation ` WITHIN GROUP (ORDER BY )` +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct WithinGroup { + pub expr: Box, + pub order_by: Vec, +} + +impl fmt::Display for WithinGroup { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "{} WITHIN GROUP (ORDER BY {})", + self.expr, + display_comma_separated(&self.order_by), + )?; Ok(()) } } diff --git a/src/parser.rs b/src/parser.rs index 753c6a11d..2204b7821 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -627,14 +627,33 @@ impl<'a> Parser<'a> { None }; - Ok(Expr::Function(Function { + let within_group = if self.parse_keywords(&[Keyword::WITHIN, Keyword::GROUP]) { + self.expect_token(&Token::LParen)?; + self.expect_keywords(&[Keyword::ORDER, Keyword::BY])?; + let order_by_expr = self.parse_comma_separated(Parser::parse_order_by_expr)?; + self.expect_token(&Token::RParen)?; + Some(order_by_expr) + } else { + None + }; + + let function = Expr::Function(Function { name, args, over, distinct, special: false, approximate: false, - })) + }); + + Ok(if let Some(within_group) = within_group { + Expr::WithinGroup(WithinGroup { + expr: Box::new(function), + order_by: within_group, + }) + } else { + function + }) } pub fn parse_time_functions(&mut self, name: ObjectName) -> Result { @@ -995,17 +1014,24 @@ impl<'a> Parser<'a> { self.expect_keywords(&[Keyword::ORDER, Keyword::BY])?; let order_by_expr = self.parse_comma_separated(Parser::parse_order_by_expr)?; self.expect_token(&Token::RParen)?; - order_by_expr + Some(order_by_expr) } else { - vec![] + None }; - Ok(Expr::ListAgg(ListAgg { + let list_agg = Expr::ListAgg(ListAgg { distinct, expr, separator, on_overflow, - within_group, - })) + }); + Ok(if let Some(within_group) = within_group { + Expr::WithinGroup(WithinGroup { + expr: Box::new(list_agg), + order_by: within_group, + }) + } else { + list_agg + }) } pub fn parse_array_agg_expr(&mut self) -> Result { @@ -1031,7 +1057,6 @@ impl<'a> Parser<'a> { expr, order_by, limit, - within_group: false, })); } // Snowflake defines ORDERY BY in within group instead of inside the function like @@ -1042,18 +1067,25 @@ impl<'a> Parser<'a> { self.expect_keywords(&[Keyword::ORDER, Keyword::BY])?; let order_by_expr = self.parse_order_by_expr()?; self.expect_token(&Token::RParen)?; - Some(Box::new(order_by_expr)) + Some(order_by_expr) } else { None }; - Ok(Expr::ArrayAgg(ArrayAgg { + let array_agg = Expr::ArrayAgg(ArrayAgg { distinct, expr, - order_by: within_group, + order_by: None, limit: None, - within_group: true, - })) + }); + Ok(if let Some(within_group) = within_group { + Expr::WithinGroup(WithinGroup { + expr: Box::new(array_agg), + order_by: vec![within_group], + }) + } else { + array_agg + }) } // This function parses date/time fields for both the EXTRACT function-like diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 96071f1f2..f0b90a09d 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -1702,14 +1702,16 @@ fn parse_listagg() { }, ]; assert_eq!( - &Expr::ListAgg(ListAgg { - distinct: true, - expr, - separator: Some(Box::new(Expr::Value(Value::SingleQuotedString( - ", ".to_string() - )))), - on_overflow, - within_group + &Expr::WithinGroup(WithinGroup { + expr: Box::new(Expr::ListAgg(ListAgg { + distinct: true, + expr, + separator: Some(Box::new(Expr::Value(Value::SingleQuotedString( + ", ".to_string() + )))), + on_overflow, + })), + order_by: within_group }), expr_from_projection(only(&select.projection)) ); @@ -1736,6 +1738,41 @@ fn parse_array_agg_func() { } } +#[test] +fn parse_within_group() { + let sql = "SELECT PERCENTILE_CONT(0.0) WITHIN GROUP (ORDER BY name ASC NULLS FIRST)"; + let select = verified_only_select(sql); + + #[cfg(feature = "bigdecimal")] + let value = bigdecimal::BigDecimal::from(0); + #[cfg(not(feature = "bigdecimal"))] + let value = "0.0".to_string(); + let expr = Expr::Value(Value::Number(value, false)); + let function = Expr::Function(Function { + name: ObjectName(vec![Ident::new("PERCENTILE_CONT")]), + args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(expr))], + over: None, + distinct: false, + special: false, + approximate: false, + }); + let within_group = vec![OrderByExpr { + expr: Expr::Identifier(Ident { + value: "name".to_string(), + quote_style: None, + }), + asc: Some(true), + nulls_first: Some(true), + }]; + assert_eq!( + &Expr::WithinGroup(WithinGroup { + expr: Box::new(function), + order_by: within_group + }), + expr_from_projection(only(&select.projection)) + ); +} + #[test] fn parse_create_table() { let sql = "CREATE TABLE uk_cities (\