Skip to content

Commit 960f44e

Browse files
authored
fix(query): fix incorrect order of group by items with CTE or subquery (#18692)
* fix(query): fix incorrect order of group by items with CTE or subquery
1 parent 846b83c commit 960f44e

File tree

4 files changed

+89
-16
lines changed

4 files changed

+89
-16
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -80,4 +80,5 @@ benchmark/clickbench/results
8080
# lychee
8181
.lycheecache
8282

83-
tests/nox/cache
83+
# tmp
84+
tmp

src/query/expression/src/aggregate/aggregate_hashtable.rs

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@ use std::sync::atomic::Ordering;
1818
use std::sync::Arc;
1919

2020
use bumpalo::Bump;
21+
use databend_common_exception::ErrorCode;
2122
use databend_common_exception::Result;
2223

2324
use super::partitioned_payload::PartitionedPayload;
@@ -176,6 +177,20 @@ impl AggregateHashTable {
176177
agg_states: ProjectedBlock,
177178
row_count: usize,
178179
) -> Result<usize> {
180+
#[cfg(debug_assertions)]
181+
{
182+
for (i, group_column) in group_columns.iter().enumerate() {
183+
if group_column.data_type() != self.payload.group_types[i] {
184+
return Err(ErrorCode::UnknownException(format!(
185+
"group_column type not match in index {}, expect: {:?}, actual: {:?}",
186+
i,
187+
self.payload.group_types[i],
188+
group_column.data_type()
189+
)));
190+
}
191+
}
192+
}
193+
179194
state.row_count = row_count;
180195
group_hash_columns(group_columns, &mut state.group_hashes);
181196

src/query/sql/src/planner/optimizer/optimizers/rule/agg_rules/rule_hierarchical_grouping_sets.rs

Lines changed: 37 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -66,10 +66,16 @@ const ID: RuleID = RuleID::HierarchicalGroupingSetsToUnion;
6666
pub struct RuleHierarchicalGroupingSetsToUnion {
6767
id: RuleID,
6868
matchers: Vec<Matcher>,
69+
cte_channel_size: usize,
6970
}
7071

7172
impl RuleHierarchicalGroupingSetsToUnion {
72-
pub fn new(_ctx: Arc<OptimizerContext>) -> Self {
73+
pub fn new(ctx: Arc<OptimizerContext>) -> Self {
74+
let cte_channel_size = ctx
75+
.get_table_ctx()
76+
.get_settings()
77+
.get_grouping_sets_channel_size()
78+
.unwrap();
7379
Self {
7480
id: ID,
7581
matchers: vec![Matcher::MatchOp {
@@ -79,21 +85,36 @@ impl RuleHierarchicalGroupingSetsToUnion {
7985
children: vec![Matcher::Leaf],
8086
}],
8187
}],
88+
cte_channel_size: cte_channel_size as usize,
8289
}
8390
}
8491

8592
/// Analyzes grouping sets to build a true hierarchical dependency DAG
86-
fn build_hierarchy_dag(&self, grouping_sets: &[Vec<IndexType>]) -> HierarchyDAG {
93+
fn build_hierarchy_dag(
94+
&self,
95+
grouping_sets: &[Vec<IndexType>],
96+
agg: &Aggregate,
97+
) -> HierarchyDAG {
8798
let mut levels: Vec<GroupingLevel> = grouping_sets
8899
.iter()
89100
.enumerate()
90-
.map(|(idx, set)| GroupingLevel {
91-
set_index: idx,
92-
columns: set.clone(),
93-
direct_children: Vec::new(),
94-
possible_parents: Vec::new(),
95-
chosen_parent: None,
96-
level: set.len(),
101+
.map(|(idx, set)| {
102+
// Sort columns according to their order in group_items for consistent schema ordering
103+
let mut sorted_columns = set.clone();
104+
sorted_columns.sort_by_key(|&col_idx| {
105+
agg.group_items
106+
.iter()
107+
.position(|item| item.index == col_idx)
108+
.unwrap_or(usize::MAX) // Put unknown columns at the end
109+
});
110+
GroupingLevel {
111+
set_index: idx,
112+
columns: sorted_columns,
113+
direct_children: Vec::new(),
114+
possible_parents: Vec::new(),
115+
chosen_parent: None,
116+
level: set.len(),
117+
}
97118
})
98119
.collect();
99120

@@ -406,7 +427,7 @@ impl RuleHierarchicalGroupingSetsToUnion {
406427
cte_name: cte_name.to_string(),
407428
cte_output_columns: None,
408429
ref_count: 1,
409-
channel_size: None,
430+
channel_size: Some(self.cte_channel_size),
410431
}
411432
.into(),
412433
),
@@ -457,7 +478,7 @@ impl RuleHierarchicalGroupingSetsToUnion {
457478
cte_name: cte_name.to_string(),
458479
cte_output_columns: None,
459480
ref_count: 1,
460-
channel_size: None,
481+
channel_size: Some(self.cte_channel_size),
461482
}
462483
.into(),
463484
),
@@ -497,7 +518,7 @@ impl RuleHierarchicalGroupingSetsToUnion {
497518
cte_name: cte_name.to_string(),
498519
cte_output_columns: None,
499520
ref_count: 1,
500-
channel_size: None,
521+
channel_size: Some(self.cte_channel_size),
501522
}
502523
.into(),
503524
),
@@ -632,10 +653,11 @@ impl RuleHierarchicalGroupingSetsToUnion {
632653
// Create parent CTE consumer
633654
let parent_output_columns: Vec<IndexType> = {
634655
let mut output_cols = Vec::new();
635-
// Then: aggregate function output columns
656+
// First: aggregate function output columns
636657
for agg_item in &agg.aggregate_functions {
637658
output_cols.push(agg_item.index);
638659
}
660+
// Then: parent level columns (already sorted from build_hierarchy_dag)
639661
for &col_idx in &parent_level.columns {
640662
output_cols.push(col_idx);
641663
}
@@ -666,7 +688,7 @@ impl RuleHierarchicalGroupingSetsToUnion {
666688
cte_name: cte_name.to_string(),
667689
cte_output_columns: None,
668690
ref_count: 1,
669-
channel_size: None,
691+
channel_size: Some(self.cte_channel_size),
670692
}
671693
.into(),
672694
),
@@ -850,7 +872,7 @@ impl Rule for RuleHierarchicalGroupingSetsToUnion {
850872
}
851873

852874
// Build hierarchy DAG
853-
let hierarchy = self.build_hierarchy_dag(&grouping_sets.sets);
875+
let hierarchy = self.build_hierarchy_dag(&grouping_sets.sets, &agg);
854876
// Check if we have meaningful hierarchical structure
855877
let hierarchical_levels = hierarchy
856878
.levels

tests/sqllogictests/suites/duckdb/sql/aggregate/group/group_by_grouping_sets.test

Lines changed: 35 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -261,6 +261,41 @@ a A 1 4 NULL A
261261
a B 1 5 NULL B
262262
a A 1 5 NULL NULL
263263

264+
265+
## group with CTE
266+
query ?TTT
267+
WITH cte_0 AS
268+
(
269+
SELECT try_cast(number % 10 AS String) AS media_source,
270+
try_cast(number % 12 AS String) AS site_name,
271+
try_cast(number AS Float64) AS bi_cost,
272+
try_cast((today() + ( number % 11)) AS Date) AS created_at
273+
FROM numbers(1000)
274+
),
275+
cte_1 AS
276+
(
277+
SELECT sum(bi_cost) AS bi_cost_agg,
278+
media_source,
279+
created_at,
280+
site_name
281+
FROM cte_0
282+
GROUP BY cube (media_source, created_at, site_name)
283+
HAVING 1 = 1)SELECT *
284+
FROM cte_1 order by bi_cost_agg desc LIMIT 10;
285+
---
286+
----
287+
499500.0 NULL NULL NULL
288+
50400.0 9 NULL NULL
289+
50300.0 8 NULL NULL
290+
50200.0 7 NULL NULL
291+
50100.0 6 NULL NULL
292+
50000.0 5 NULL NULL
293+
49900.0 4 NULL NULL
294+
49800.0 3 NULL NULL
295+
49700.0 2 NULL NULL
296+
49600.0 1 NULL NULL
297+
298+
264299
statement ok
265300
drop table t all;
266301

0 commit comments

Comments (0)