From 03c7079ef1cf061058153e89cec2f5cfe8af22a6 Mon Sep 17 00:00:00 2001 From: kamille Date: Mon, 29 Jun 2026 09:07:46 +0800 Subject: [PATCH 01/17] Expose group hash feedback from GroupValues --- .../physical-expr-common/src/binary_map.rs | 72 +++++++++++--- .../src/binary_view_map.rs | 66 ++++++++++--- .../benches/dictionary_group_values.rs | 22 ++++- .../physical-plan/benches/multi_group_by.rs | 3 +- .../aggregates/aggregate_hash_table/common.rs | 10 ++ .../aggregate_hash_table/final_table.rs | 9 +- .../aggregate_hash_table/partial_table.rs | 20 ++-- .../src/aggregates/group_values/mod.rs | 14 ++- .../group_values/multi_group_by/mod.rs | 70 +++++++------ .../src/aggregates/group_values/row.rs | 60 ++++++------ .../group_values/single_group_by/boolean.rs | 19 +++- .../group_values/single_group_by/bytes.rs | 39 ++++++-- .../single_group_by/bytes_view.rs | 33 +++++-- .../group_values/single_group_by/primitive.rs | 97 +++++++++++++------ .../physical-plan/src/aggregates/row_hash.rs | 32 +++++- .../physical-plan/src/recursive_query.rs | 10 +- 16 files changed, 416 insertions(+), 160 deletions(-) diff --git a/datafusion/physical-expr-common/src/binary_map.rs b/datafusion/physical-expr-common/src/binary_map.rs index ad184d6500d56..40a2c14269c16 100644 --- a/datafusion/physical-expr-common/src/binary_map.rs +++ b/datafusion/physical-expr-common/src/binary_map.rs @@ -306,8 +306,62 @@ where values.data_type(), DataType::Binary | DataType::LargeBinary )); + self.hashes_buffer.clear(); + self.hashes_buffer.resize(values.len(), 0); + create_hashes([values], &self.random_state, &mut self.hashes_buffer) + .unwrap(); + let hashes = std::mem::take(&mut self.hashes_buffer); self.insert_if_new_inner::>( values, + &hashes, + make_payload_fn, + observe_payload_fn, + ); + self.hashes_buffer = hashes + } + OutputType::Utf8 => { + assert!(matches!( + values.data_type(), + DataType::Utf8 | DataType::LargeUtf8 + )); + self.hashes_buffer.clear(); + self.hashes_buffer.resize(values.len(), 0); + create_hashes([values], &self.random_state, &mut self.hashes_buffer) + .unwrap(); + let hashes = std::mem::take(&mut self.hashes_buffer); + self.insert_if_new_inner::>( + values, + &hashes, + make_payload_fn, + observe_payload_fn, + ); + self.hashes_buffer = hashes + } + _ => unreachable!("View types should use `ArrowBytesViewMap`"), + }; + } + + /// Inserts each value from `values` into the map using precomputed hashes. + pub fn insert_if_new_with_hashes( + &mut self, + values: &ArrayRef, + hashes: &[u64], + make_payload_fn: MP, + observe_payload_fn: OP, + ) where + MP: FnMut(Option<&[u8]>) -> V, + OP: FnMut(V), + { + assert_eq!(values.len(), hashes.len()); + match self.output_type { + OutputType::Binary => { + assert!(matches!( + values.data_type(), + DataType::Binary | DataType::LargeBinary + )); + self.insert_if_new_inner::>( + values, + hashes, make_payload_fn, observe_payload_fn, ) @@ -319,6 +373,7 @@ where )); self.insert_if_new_inner::>( values, + hashes, make_payload_fn, observe_payload_fn, ) @@ -338,6 +393,7 @@ where fn insert_if_new_inner( &mut self, values: &ArrayRef, + hashes: &[u64], mut make_payload_fn: MP, mut observe_payload_fn: OP, ) where @@ -345,22 +401,10 @@ where OP: FnMut(V), B: ByteArrayType, { - // step 1: compute hashes - let batch_hashes = &mut self.hashes_buffer; - batch_hashes.clear(); - batch_hashes.resize(values.len(), 0); - create_hashes([values], &self.random_state, batch_hashes) - // hash is supported for all types and create_hashes only - // returns errors for unsupported types - .unwrap(); - - // step 2: insert each value into the set, if not already present let values = values.as_bytes::(); + assert_eq!(values.len(), hashes.len()); - // Ensure lengths are equivalent - assert_eq!(values.len(), batch_hashes.len()); - - for (value, &hash) in values.iter().zip(batch_hashes.iter()) { + for (value, &hash) in values.iter().zip(hashes.iter()) { // handle null value let Some(value) = value else { let payload = if let Some(&(payload, _offset)) = self.null.as_ref() { diff --git a/datafusion/physical-expr-common/src/binary_view_map.rs b/datafusion/physical-expr-common/src/binary_view_map.rs index 9d4b556393a24..6d901e8753743 100644 --- a/datafusion/physical-expr-common/src/binary_view_map.rs +++ b/datafusion/physical-expr-common/src/binary_view_map.rs @@ -216,8 +216,56 @@ where match self.output_type { OutputType::BinaryView => { assert!(matches!(values.data_type(), DataType::BinaryView)); + self.hashes_buffer.clear(); + self.hashes_buffer.resize(values.len(), 0); + create_hashes([values], &self.random_state, &mut self.hashes_buffer) + .unwrap(); + let hashes = std::mem::take(&mut self.hashes_buffer); self.insert_if_new_inner::( values, + &hashes, + make_payload_fn, + observe_payload_fn, + ); + self.hashes_buffer = hashes + } + OutputType::Utf8View => { + assert!(matches!(values.data_type(), DataType::Utf8View)); + self.hashes_buffer.clear(); + self.hashes_buffer.resize(values.len(), 0); + create_hashes([values], &self.random_state, &mut self.hashes_buffer) + .unwrap(); + let hashes = std::mem::take(&mut self.hashes_buffer); + self.insert_if_new_inner::( + values, + &hashes, + make_payload_fn, + observe_payload_fn, + ); + self.hashes_buffer = hashes + } + _ => unreachable!("Utf8/Binary should use `ArrowBytesSet`"), + }; + } + + /// Inserts each value from `values` into the map using precomputed hashes. + pub fn insert_if_new_with_hashes( + &mut self, + values: &ArrayRef, + hashes: &[u64], + make_payload_fn: MP, + observe_payload_fn: OP, + ) where + MP: FnMut(Option<&[u8]>) -> V, + OP: FnMut(V), + { + assert_eq!(values.len(), hashes.len()); + match self.output_type { + OutputType::BinaryView => { + assert!(matches!(values.data_type(), DataType::BinaryView)); + self.insert_if_new_inner::( + values, + hashes, make_payload_fn, observe_payload_fn, ) @@ -226,6 +274,7 @@ where assert!(matches!(values.data_type(), DataType::Utf8View)); self.insert_if_new_inner::( values, + hashes, make_payload_fn, observe_payload_fn, ) @@ -245,6 +294,7 @@ where fn insert_if_new_inner( &mut self, values: &ArrayRef, + hashes: &[u64], mut make_payload_fn: MP, mut observe_payload_fn: OP, ) where @@ -252,27 +302,15 @@ where OP: FnMut(V), B: ByteViewType, { - // step 1: compute hashes - let batch_hashes = &mut self.hashes_buffer; - batch_hashes.clear(); - batch_hashes.resize(values.len(), 0); - create_hashes([values], &self.random_state, batch_hashes) - // hash is supported for all types and create_hashes only - // returns errors for unsupported types - .unwrap(); - - // step 2: insert each value into the set, if not already present let values = values.as_byte_view::(); + assert_eq!(values.len(), hashes.len()); // Get raw views buffer for direct comparison let input_views = values.views(); - // Ensure lengths are equivalent - assert_eq!(values.len(), self.hashes_buffer.len()); - for i in 0..values.len() { let view_u128 = input_views[i]; - let hash = self.hashes_buffer[i]; + let hash = hashes[i]; // handle null value via validity bitmap check if values.is_null(i) { diff --git a/datafusion/physical-plan/benches/dictionary_group_values.rs b/datafusion/physical-plan/benches/dictionary_group_values.rs index ded52aebd1100..f853105e2aa71 100644 --- a/datafusion/physical-plan/benches/dictionary_group_values.rs +++ b/datafusion/physical-plan/benches/dictionary_group_values.rs @@ -109,10 +109,17 @@ fn bench_intern_emit(c: &mut Criterion) { new_group_values(schema.clone(), &GroupOrdering::None) .unwrap(), Vec::::with_capacity(size), + Vec::::with_capacity(size), ) }, - |(gv, groups)| { - gv.intern(std::slice::from_ref(&array), groups).unwrap(); + |(gv, groups, hashes)| { + gv.intern( + std::slice::from_ref(&array), + groups, + hashes, + &mut vec![], + ) + .unwrap(); black_box(&*groups); black_box(gv.emit(EmitTo::All).unwrap()); }, @@ -154,11 +161,18 @@ fn bench_repeated_intern_emit(c: &mut Criterion) { new_group_values(schema.clone(), &GroupOrdering::None) .unwrap(), Vec::::with_capacity(size), + Vec::::with_capacity(size), ) }, - |(gv, groups)| { + |(gv, groups, hashes)| { for arr in &batches { - gv.intern(std::slice::from_ref(arr), groups).unwrap(); + gv.intern( + std::slice::from_ref(arr), + groups, + hashes, + &mut vec![], + ) + .unwrap(); black_box(&*groups); } black_box(gv.emit(EmitTo::All).unwrap()); diff --git a/datafusion/physical-plan/benches/multi_group_by.rs b/datafusion/physical-plan/benches/multi_group_by.rs index 92d0448775599..71769d68b83ab 100644 --- a/datafusion/physical-plan/benches/multi_group_by.rs +++ b/datafusion/physical-plan/benches/multi_group_by.rs @@ -95,9 +95,10 @@ fn bench_intern( batches: &[Vec], groups: &mut Vec, ) { + let mut hashes = vec![]; for batch in batches { groups.clear(); - gv.intern(batch, groups).unwrap(); + gv.intern(batch, groups, &mut hashes, &mut vec![]).unwrap(); } black_box(&*groups); } diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs index 719fbe93e5416..3bda4623adf3b 100644 --- a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs +++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs @@ -134,6 +134,8 @@ impl AggregateHashTable { group_by: Arc::clone(&agg.group_by), group_values, batch_group_indices: Default::default(), + batch_hashes: Default::default(), + new_group_rows: Default::default(), accumulators, }), _mode: PhantomData, @@ -181,6 +183,8 @@ impl AggregateHashTable { acc + state.group_values.size() + state.batch_group_indices.allocated_size() + + state.batch_hashes.allocated_size() + + state.new_group_rows.allocated_size() } AggregateHashTableState::OutputtingMaterializedFinal(output) => { output.memory_size() @@ -292,6 +296,12 @@ pub(super) struct AggregateHashTableBuffer { /// accumulator to update that group's aggregate state. pub(super) batch_group_indices: Vec, + /// Hash for each row in the current input batch. + pub(super) batch_hashes: Vec, + + /// Input rows that created new groups in the current input batch. + pub(super) new_group_rows: Vec, + /// One item per aggregate expression. /// /// Example: `COUNT(x), SUM(y)` creates two items. Each item owns the input diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/final_table.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/final_table.rs index c3e4f831c4bbf..ccba78d15ba16 100644 --- a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/final_table.rs +++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/final_table.rs @@ -122,9 +122,12 @@ impl AggregateHashTable { let timer = self.group_by_metrics.aggregation_time.timer(); for group_values in &evaluated_batch.grouping_set_args { - state - .group_values - .intern(group_values, &mut state.batch_group_indices)?; + state.group_values.intern( + group_values, + &mut state.batch_group_indices, + &mut state.batch_hashes, + &mut state.new_group_rows, + )?; let group_indices = &state.batch_group_indices; let total_num_groups = state.group_values.len(); diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs index 9d226aa28b35f..ed18b8ba8ec2a 100644 --- a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs +++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs @@ -131,6 +131,8 @@ impl AggregateHashTable { group_by: Arc::clone(&state.group_by), group_values, batch_group_indices: Default::default(), + batch_hashes: Default::default(), + new_group_rows: Default::default(), accumulators, }), _mode: PhantomData, @@ -146,9 +148,12 @@ impl AggregateHashTable { let _timer = self.group_by_metrics.aggregation_time.timer(); for group_values in &evaluated_batch.grouping_set_args { - state - .group_values - .intern(group_values, &mut state.batch_group_indices)?; + state.group_values.intern( + group_values, + &mut state.batch_group_indices, + &mut state.batch_hashes, + &mut state.new_group_rows, + )?; let group_indices = &state.batch_group_indices; let total_num_groups = state.group_values.len(); @@ -216,9 +221,12 @@ impl AggregateHashTable { .collect(); cols.push(group_id_array(group, ordinal, max_ordinal, 1)?); - state - .group_values - .intern(&cols, &mut state.batch_group_indices)?; + state.group_values.intern( + &cols, + &mut state.batch_group_indices, + &mut state.batch_hashes, + &mut state.new_group_rows, + )?; any_interned = true; } diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index ee253e5d7afdd..a49ad87676505 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -91,13 +91,21 @@ pub trait GroupValues: Send { /// Calculates the group id for each input row of `cols`, assigning new /// group ids as necessary. /// - /// When the function returns, `groups` must contain the group id for each - /// row in `cols`. + /// When the function returns, `groups` must contain the group id for each + /// row in `cols`, and `hashes` must contain the hash for each row in `cols`. + /// `new_group_rows` is filled with the input row index that first created + /// each new group in this call. /// /// If a row has the same value as a previous row, the same group id is /// assigned. If a row has a new value, the next available group id is /// assigned. - fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()>; + fn intern( + &mut self, + cols: &[ArrayRef], + groups: &mut Vec, + hashes: &mut Vec, + new_group_rows: &mut Vec, + ) -> Result<()>; /// Returns the number of bytes of memory used by this [`GroupValues`] fn size(&self) -> usize; diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs index f275d777c3279..abfe70b53a7c6 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs @@ -344,23 +344,20 @@ impl GroupValuesColumn { /// /// `Group indices` order are against with their input order, and this will lead to error /// in `streaming aggregation`. - fn scalarized_intern( + fn scalarized_intern_impl( &mut self, cols: &[ArrayRef], + hashes: &[u64], groups: &mut Vec, + new_group_rows: &mut Vec, ) -> Result<()> { - let n_rows = cols[0].len(); + debug_assert_eq!(hashes.len(), cols.first().map_or(0, |array| array.len())); // tracks to which group each of the input rows belongs groups.clear(); + new_group_rows.clear(); - // 1.1 Calculate the group keys for the group values - let batch_hashes = &mut self.hashes_buffer; - batch_hashes.clear(); - batch_hashes.resize(n_rows, 0); - create_hashes(cols, &self.random_state, batch_hashes)?; - - for (row, &target_hash) in batch_hashes.iter().enumerate() { + for (row, &target_hash) in hashes.iter().enumerate() { let entry = self .map .find_mut(target_hash, |(exist_hash, group_idx_view)| { @@ -425,6 +422,7 @@ impl GroupValuesColumn { |(hash, _group_index)| *hash, &mut self.map_size, ); + new_group_rows.push(row); group_idx } }; @@ -445,21 +443,20 @@ impl GroupValuesColumn { /// /// The vectorized approach can offer higher performance for avoiding row by row /// downcast for `cols` and being able to implement even more optimizations(like simd). - fn vectorized_intern( + fn vectorized_intern_impl( &mut self, cols: &[ArrayRef], + hashes: &[u64], groups: &mut Vec, + new_group_rows: &mut Vec, ) -> Result<()> { let n_rows = cols[0].len(); + debug_assert_eq!(hashes.len(), n_rows); // tracks to which group each of the input rows belongs groups.clear(); groups.resize(n_rows, usize::MAX); - - let mut batch_hashes = mem::take(&mut self.hashes_buffer); - batch_hashes.clear(); - batch_hashes.resize(n_rows, 0); - create_hashes(cols, &self.random_state, &mut batch_hashes)?; + new_group_rows.clear(); // General steps for one round `vectorized equal_to & append`: // 1. Collect vectorized context by checking hash values of `cols` in `map`, @@ -482,7 +479,7 @@ impl GroupValuesColumn { // // 1. Collect vectorized context by checking hash values of `cols` in `map` - self.collect_vectorized_process_context(&batch_hashes, groups); + self.collect_vectorized_process_context(hashes, groups, new_group_rows); // 2. Perform `vectorized_append` self.vectorized_append(cols)?; @@ -492,9 +489,7 @@ impl GroupValuesColumn { // 4. Perform scalarized inter for remaining rows // (about remaining rows, can see comments for `remaining_row_indices`) - self.scalarized_intern_remaining(cols, &batch_hashes, groups)?; - - self.hashes_buffer = batch_hashes; + self.scalarized_intern_remaining(cols, hashes, groups, new_group_rows)?; Ok(()) } @@ -514,8 +509,9 @@ impl GroupValuesColumn { /// Otherwise get all group indices from `group_index_lists`, and add them. fn collect_vectorized_process_context( &mut self, - batch_hashes: &[u64], + hashes: &[u64], groups: &mut [usize], + new_group_rows: &mut Vec, ) { self.vectorized_operation_buffers.append_row_indices.clear(); self.vectorized_operation_buffers @@ -525,7 +521,7 @@ impl GroupValuesColumn { .equal_to_group_indices .clear(); - for (row, &target_hash) in batch_hashes.iter().enumerate() { + for (row, &target_hash) in hashes.iter().enumerate() { let entry = self .map .find(target_hash, |(exist_hash, _)| target_hash == *exist_hash); @@ -553,6 +549,7 @@ impl GroupValuesColumn { // Set group index to row in `groups` groups[row] = current_group_idx; + new_group_rows.push(row); continue; }; @@ -741,8 +738,9 @@ impl GroupValuesColumn { fn scalarized_intern_remaining( &mut self, cols: &[ArrayRef], - batch_hashes: &[u64], + hashes: &[u64], groups: &mut [usize], + new_group_rows: &mut Vec, ) -> Result<()> { if self .vectorized_operation_buffers @@ -755,7 +753,7 @@ impl GroupValuesColumn { let mut map = mem::take(&mut self.map); for &row in &self.vectorized_operation_buffers.remaining_row_indices { - let target_hash = batch_hashes[row]; + let target_hash = hashes[row]; let entry = map.find_mut(target_hash, |(exist_hash, _)| { // Somewhat surprisingly, this closure can be called even if the // hash doesn't match, so check the hash first with an integer @@ -816,6 +814,7 @@ impl GroupValuesColumn { } groups[row] = group_idx; + new_group_rows.push(row); } self.map = map; @@ -1078,14 +1077,21 @@ fn make_group_column(field: &Field) -> Result> { } impl GroupValues for GroupValuesColumn { - fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { - // `try_new` and the reset points in `emit` / `clear_shrink` keep - // `self.group_values` populated with one builder per schema field, - // so no lazy initialization is needed here. + fn intern( + &mut self, + cols: &[ArrayRef], + groups: &mut Vec, + hashes: &mut Vec, + new_group_rows: &mut Vec, + ) -> Result<()> { + let n_rows = cols[0].len(); + hashes.clear(); + hashes.resize(n_rows, 0); + create_hashes(cols, &self.random_state, hashes)?; if !STREAMING { - self.vectorized_intern(cols, groups) + self.vectorized_intern_impl(cols, hashes, groups, new_group_rows) } else { - self.scalarized_intern(cols, groups) + self.scalarized_intern_impl(cols, hashes, groups, new_group_rows) } } @@ -1878,7 +1884,11 @@ mod tests { fn load_to_group_values(&self, group_values: &mut impl GroupValues) { for batch in self.test_batches.iter() { - group_values.intern(batch, &mut vec![]).unwrap(); + let mut groups = vec![]; + let mut hashes = vec![]; + group_values + .intern(batch, &mut groups, &mut hashes, &mut vec![]) + .unwrap(); } } diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs index 4976a098ecee5..5344da39cf8bb 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/row.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -113,65 +113,48 @@ impl GroupValuesRows { random_state: crate::aggregates::AGGREGATION_HASH_SEED, }) } -} -impl GroupValues for GroupValuesRows { - fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { - // Normalize -0.0 → +0.0 so RowConverter (IEEE 754 totalOrder) and - // primitive hashing both group ±0 together. No-op for non-float - // columns. - let normalized_cols: Vec = - cols.iter().map(normalize_float_zero).collect(); - let cols = normalized_cols.as_slice(); + fn intern_impl( + &mut self, + cols: &[ArrayRef], + groups: &mut Vec, + hashes: &mut Vec, + new_group_rows: &mut Vec, + ) -> Result<()> { + hashes.clear(); + hashes.resize(cols.first().map_or(0, |array| array.len()), 0); + create_hashes(cols, &self.random_state, hashes)?; - // Convert the group keys into the row format let group_rows = &mut self.rows_buffer; group_rows.clear(); self.row_converter.append(group_rows, cols)?; - let n_rows = group_rows.num_rows(); + debug_assert_eq!(hashes.len(), group_rows.num_rows()); let mut group_values = match self.group_values.take() { Some(group_values) => group_values, None => self.row_converter.empty_rows(0, 0), }; - // tracks to which group each of the input rows belongs groups.clear(); + new_group_rows.clear(); - // 1.1 Calculate the group keys for the group values - let batch_hashes = &mut self.hashes_buffer; - batch_hashes.clear(); - batch_hashes.resize(n_rows, 0); - create_hashes(cols, &self.random_state, batch_hashes)?; - - for (row, &target_hash) in batch_hashes.iter().enumerate() { + for (row, &target_hash) in hashes.iter().enumerate() { let entry = self.map.find_mut(target_hash, |(exist_hash, group_idx)| { - // Somewhat surprisingly, this closure can be called even if the - // hash doesn't match, so check the hash first with an integer - // comparison first avoid the more expensive comparison with - // group value. https://github.com/apache/datafusion/pull/11718 target_hash == *exist_hash - // verify that the group that we are inserting with hash is - // actually the same key value as the group in - // existing_idx (aka group_values @ row) && group_rows.row(row) == group_values.row(*group_idx) }); let group_idx = match entry { - // Existing group_index for this group value Some((_hash, group_idx)) => *group_idx, - // 1.2 Need to create new entry for the group None => { - // Add new entry to aggr_state and save newly created index let group_idx = group_values.num_rows(); group_values.push(group_rows.row(row)); - - // for hasher function, use precomputed hash value self.map.insert_accounted( (target_hash, group_idx), |(hash, _group_index)| *hash, &mut self.map_size, ); + new_group_rows.push(row); group_idx } }; @@ -182,6 +165,21 @@ impl GroupValues for GroupValuesRows { Ok(()) } +} + +impl GroupValues for GroupValuesRows { + fn intern( + &mut self, + cols: &[ArrayRef], + groups: &mut Vec, + hashes: &mut Vec, + new_group_rows: &mut Vec, + ) -> Result<()> { + let normalized_cols: Vec = + cols.iter().map(normalize_float_zero).collect(); + let cols = normalized_cols.as_slice(); + self.intern_impl(cols, groups, hashes, new_group_rows) + } fn size(&self) -> usize { let group_values_size = self.group_values.as_ref().map(|v| v.size()).unwrap_or(0); diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs index e993c0c53d199..fed93cc2f67de 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs @@ -16,6 +16,7 @@ // under the License. use crate::aggregates::group_values::GroupValues; +use datafusion_common::hash_utils::create_hashes; use arrow::array::{ ArrayRef, AsArray as _, BooleanArray, BooleanBufferBuilder, NullBufferBuilder, @@ -42,11 +43,22 @@ impl GroupValuesBoolean { } impl GroupValues for GroupValuesBoolean { - fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { + fn intern( + &mut self, + cols: &[ArrayRef], + groups: &mut Vec, + hashes: &mut Vec, + new_group_rows: &mut Vec, + ) -> Result<()> { + hashes.clear(); + hashes.resize(cols.first().map_or(0, |array| array.len()), 0); + create_hashes(cols, &crate::aggregates::AGGREGATION_HASH_SEED, hashes)?; + let array = cols[0].as_boolean(); groups.clear(); + new_group_rows.clear(); - for value in array.iter() { + for (row, value) in array.iter().enumerate() { let index = match value { Some(false) => { if let Some(index) = self.false_group { @@ -54,6 +66,7 @@ impl GroupValues for GroupValuesBoolean { } else { let index = self.len(); self.false_group = Some(index); + new_group_rows.push(row); index } } @@ -63,6 +76,7 @@ impl GroupValues for GroupValuesBoolean { } else { let index = self.len(); self.true_group = Some(index); + new_group_rows.push(row); index } } @@ -72,6 +86,7 @@ impl GroupValues for GroupValuesBoolean { } else { let index = self.len(); self.null_group = Some(index); + new_group_rows.push(row); index } } diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs index b881a51b25474..5b0d88512dfcd 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs @@ -45,29 +45,44 @@ impl GroupValuesBytes { } impl GroupValues for GroupValuesBytes { - fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { + fn intern( + &mut self, + cols: &[ArrayRef], + groups: &mut Vec, + hashes: &mut Vec, + new_group_rows: &mut Vec, + ) -> Result<()> { assert_eq!(cols.len(), 1); - - // look up / add entries in the table let arr = &cols[0]; + hashes.clear(); + hashes.resize(arr.len(), 0); + datafusion_common::hash_utils::create_hashes( + cols, + &crate::aggregates::AGGREGATION_HASH_SEED, + hashes, + )?; + groups.clear(); - self.map.insert_if_new( + new_group_rows.clear(); + let mut next_new_group = self.num_groups; + self.map.insert_if_new_with_hashes( arr, - // called for each new group + hashes, |_value| { - // assign new group index on each insert let group_idx = self.num_groups; self.num_groups += 1; group_idx }, - // called for each group |group_idx| { + if group_idx == next_new_group { + new_group_rows.push(groups.len()); + next_new_group += 1; + } groups.push(group_idx); }, ); - // ensure we assigned a group to for each row assert_eq!(groups.len(), arr.len()); Ok(()) } @@ -108,7 +123,13 @@ impl GroupValues for GroupValuesBytes { self.num_groups = 0; let mut group_indexes = vec![]; - self.intern(&[remaining_group_values], &mut group_indexes)?; + let mut hashes = vec![]; + self.intern( + &[remaining_group_values], + &mut group_indexes, + &mut hashes, + &mut vec![], + )?; // Verify that the group indexes were assigned in the correct order assert_eq!(0, group_indexes[0]); diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs index 7a56f7c52c11a..8d4df1d347f87 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs @@ -47,29 +47,40 @@ impl GroupValues for GroupValuesBytesView { &mut self, cols: &[ArrayRef], groups: &mut Vec, + hashes: &mut Vec, + new_group_rows: &mut Vec, ) -> datafusion_common::Result<()> { assert_eq!(cols.len(), 1); - - // look up / add entries in the table let arr = &cols[0]; + hashes.clear(); + hashes.resize(arr.len(), 0); + datafusion_common::hash_utils::create_hashes( + cols, + &crate::aggregates::AGGREGATION_HASH_SEED, + hashes, + )?; + groups.clear(); - self.map.insert_if_new( + new_group_rows.clear(); + let mut next_new_group = self.num_groups; + self.map.insert_if_new_with_hashes( arr, - // called for each new group + hashes, |_value| { - // assign new group index on each insert let group_idx = self.num_groups; self.num_groups += 1; group_idx }, - // called for each group |group_idx| { + if group_idx == next_new_group { + new_group_rows.push(groups.len()); + next_new_group += 1; + } groups.push(group_idx); }, ); - // ensure we assigned a group to for each row assert_eq!(groups.len(), arr.len()); Ok(()) } @@ -110,7 +121,13 @@ impl GroupValues for GroupValuesBytesView { self.num_groups = 0; let mut group_indexes = vec![]; - self.intern(&[remaining_group_values], &mut group_indexes)?; + let mut hashes = vec![]; + self.intern( + &[remaining_group_values], + &mut group_indexes, + &mut hashes, + &mut vec![], + )?; // Verify that the group indexes were assigned in the correct order assert_eq!(0, group_indexes[0]); diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs index e254aebcfd7ce..5b2c591815ea3 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs @@ -116,6 +116,8 @@ pub struct GroupValuesPrimitive { values: Vec, /// The random state used to generate hashes random_state: RandomState, + /// Reused buffer to store hashes + hashes_buffer: Vec, } impl GroupValuesPrimitive { @@ -127,47 +129,66 @@ impl GroupValuesPrimitive { values: Vec::with_capacity(128), null_group: None, random_state: crate::aggregates::AGGREGATION_HASH_SEED, + hashes_buffer: Default::default(), } } -} -impl GroupValues for GroupValuesPrimitive -where - T::Native: HashValue, -{ - fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { + fn intern_impl( + &mut self, + cols: &[ArrayRef], + groups: &mut Vec, + hashes: &mut Vec, + new_group_rows: &mut Vec, + ) -> Result<()> + where + T::Native: HashValue, + { assert_eq!(cols.len(), 1); + let array = cols[0].as_primitive::(); + hashes.clear(); + hashes.resize(array.len(), 0); + for (row, value) in array.iter().enumerate() { + hashes[row] = value + .map(|key| key.canonicalize().hash(&self.random_state)) + .unwrap_or(0); + } + groups.clear(); + new_group_rows.clear(); - for v in cols[0].as_primitive::() { - let group_id = match v { - None => *self.null_group.get_or_insert_with(|| { - let group_id = self.values.len(); - self.values.push(Default::default()); - group_id - }), + for (row, value) in array.iter().enumerate() { + let group_id = match value { + None => { + if let Some(group_id) = self.null_group { + group_id + } else { + let group_id = self.values.len(); + self.null_group = Some(group_id); + self.values.push(Default::default()); + new_group_rows.push(row); + group_id + } + } Some(key) => { - // Fold equivalence-class duplicates (e.g. `-0.0` → `+0.0`) - // so the bit-equal `is_eq` matches and the stored value is - // the canonical representative. let key = key.canonicalize(); - let state = &self.random_state; - let hash = key.hash(state); + let hash = hashes[row]; let insert = self.map.entry( hash, - |&(g, h)| unsafe { - hash == h && self.values.get_unchecked(g).is_eq(key) + |&(group_id, exist_hash)| unsafe { + hash == exist_hash + && self.values.get_unchecked(group_id).is_eq(key) }, - |&(_, h)| h, + |&(_, exist_hash)| exist_hash, ); match insert { - hashbrown::hash_table::Entry::Occupied(o) => o.get().0, - hashbrown::hash_table::Entry::Vacant(v) => { - let g = self.values.len(); - v.insert((g, hash)); + hashbrown::hash_table::Entry::Occupied(entry) => entry.get().0, + hashbrown::hash_table::Entry::Vacant(entry) => { + let group_id = self.values.len(); + entry.insert((group_id, hash)); self.values.push(key); - g + new_group_rows.push(row); + group_id } } } @@ -176,9 +197,26 @@ where } Ok(()) } +} + +impl GroupValues for GroupValuesPrimitive +where + T::Native: HashValue, +{ + fn intern( + &mut self, + cols: &[ArrayRef], + groups: &mut Vec, + hashes: &mut Vec, + new_group_rows: &mut Vec, + ) -> Result<()> { + self.intern_impl(cols, groups, hashes, new_group_rows) + } fn size(&self) -> usize { - self.map.capacity() * size_of::<(usize, u64)>() + self.values.allocated_size() + self.map.capacity() * size_of::<(usize, u64)>() + + self.values.allocated_size() + + self.hashes_buffer.allocated_size() } fn is_empty(&self) -> bool { @@ -244,6 +282,8 @@ where self.values.shrink_to(num_rows); self.map.clear(); self.map.shrink_to(num_rows, |_| 0); // hasher does not matter since the map is cleared + self.hashes_buffer.clear(); + self.hashes_buffer.shrink_to(num_rows); } } @@ -273,7 +313,8 @@ mod tests { // Intern 20 distinct values; `new()` pre-allocates capacity 128 for `values`. let arr: ArrayRef = Arc::new(Int32Array::from_iter_values(0..20i32)); let mut groups = vec![]; - gv.intern(&[arr], &mut groups)?; + let mut hashes = vec![]; + gv.intern(&[arr], &mut groups, &mut hashes, &mut vec![])?; let capacity_before = gv.values.capacity(); // 128 // n=4, n*2=8 <= len=20 -> drain branch diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index a4d19b0f7d18a..90f73555c2153 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -334,6 +334,12 @@ pub(crate) struct GroupedHashAggregateStream { /// processed. Reused across batches here to avoid reallocations current_group_indices: Vec, + /// Hash for each row in the current input batch. + current_hashes: Vec, + + /// Input rows that created new groups in the current input batch. + new_group_rows: Vec, + /// Accumulators, one for each `AggregateFunctionExpr` in the query /// /// For example, if the query has aggregates, `SUM(x)`, @@ -601,6 +607,8 @@ impl GroupedHashAggregateStream { oom_mode, group_values, current_group_indices: Default::default(), + current_hashes: Default::default(), + new_group_rows: Default::default(), exec_state, baseline_metrics, group_by_metrics, @@ -883,8 +891,12 @@ impl GroupedHashAggregateStream { // calculate the group indices for each input row let starting_num_groups = self.group_values.len(); - self.group_values - .intern(group_values, &mut self.current_group_indices)?; + self.group_values.intern( + group_values, + &mut self.current_group_indices, + &mut self.current_hashes, + &mut self.new_group_rows, + )?; let group_indices = &self.current_group_indices; // Update ordering information if necessary @@ -991,7 +1003,9 @@ impl GroupedHashAggregateStream { let groups_and_acc_size = acc + self.group_values.size() + self.group_ordering.size() - + self.current_group_indices.allocated_size(); + + self.current_group_indices.allocated_size() + + self.current_hashes.allocated_size() + + self.new_group_rows.allocated_size(); // Reserve extra headroom for sorting during potential spill. // When OOM triggers, group_aggregate_batch has already processed the @@ -1103,8 +1117,12 @@ impl GroupedHashAggregateStream { cols.push(group_id_array(group, ordinal, max_ordinal, 1)?); let starting_groups = self.group_values.len(); - self.group_values - .intern(&cols, &mut self.current_group_indices)?; + self.group_values.intern( + &cols, + &mut self.current_group_indices, + &mut self.current_hashes, + &mut self.new_group_rows, + )?; let total_groups = self.group_values.len(); if total_groups > starting_groups { self.group_ordering.new_groups( @@ -1238,6 +1256,10 @@ impl GroupedHashAggregateStream { self.group_values.clear_shrink(num_rows); self.current_group_indices.clear(); self.current_group_indices.shrink_to(num_rows); + self.current_hashes.clear(); + self.current_hashes.shrink_to(num_rows); + self.new_group_rows.clear(); + self.new_group_rows.shrink_to(num_rows); } /// Clear memory and shrink capacities to zero. diff --git a/datafusion/physical-plan/src/recursive_query.rs b/datafusion/physical-plan/src/recursive_query.rs index 7289ac43e510c..1413f58ed1e05 100644 --- a/datafusion/physical-plan/src/recursive_query.rs +++ b/datafusion/physical-plan/src/recursive_query.rs @@ -436,6 +436,7 @@ struct DistinctDeduplicator { group_values: Box, reservation: MemoryReservation, intern_output_buffer: Vec, + hash_buffer: Vec, } impl DistinctDeduplicator { @@ -447,6 +448,7 @@ impl DistinctDeduplicator { group_values, reservation, intern_output_buffer: Vec::new(), + hash_buffer: Vec::new(), }) } @@ -466,8 +468,12 @@ impl DistinctDeduplicator { "failed to reserve {additional} recursive query group ids: {e}" ) })?; - self.group_values - .intern(batch.columns(), &mut self.intern_output_buffer)?; + self.group_values.intern( + batch.columns(), + &mut self.intern_output_buffer, + &mut self.hash_buffer, + &mut vec![], + )?; let mask = new_groups_mask(&self.intern_output_buffer, size_before); self.intern_output_buffer.clear(); // We update the reservation to reflect the new size of the hash table. From 0e09f71a4d18de90ab7b82b39be997df199e5f43 Mon Sep 17 00:00:00 2001 From: kamille Date: Mon, 29 Jun 2026 16:23:43 +0800 Subject: [PATCH 02/17] Optimize skip partial aggregation group values --- .../aggregates/aggregate_hash_table/common.rs | 9 --- .../aggregate_hash_table/partial_table.rs | 34 ++++++++--- .../src/aggregates/group_values/mod.rs | 12 +++- .../group_values/multi_group_by/mod.rs | 57 ++++++++++++++++++- 4 files changed, 92 insertions(+), 20 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs index 3bda4623adf3b..a1c31f39c5e53 100644 --- a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs +++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs @@ -457,15 +457,6 @@ impl HashAggregateAccumulator { self.accumulator.supports_convert_to_state() } - pub(super) fn convert_to_state( - &mut self, - values: &EvaluatedAccumulatorArgs, - ) -> Result> { - let opt_filter = values.filter.as_ref().map(|filter| filter.as_boolean()); - self.accumulator - .convert_to_state(&values.arguments, opt_filter) - } - pub(super) fn null_arguments( &self, input_schema: &SchemaRef, diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs index ed18b8ba8ec2a..85ca6a5e37276 100644 --- a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs +++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs @@ -23,6 +23,7 @@ use arrow::array::{ArrayRef, BooleanArray, new_null_array}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use datafusion_common::{Result, assert_eq_or_internal_err, internal_err}; +use datafusion_expr::EmitTo; use crate::aggregates::group_values::new_group_values; use crate::aggregates::order::GroupOrdering; @@ -100,11 +101,12 @@ impl AggregateHashTable { } pub(in crate::aggregates) fn can_skip_aggregation(&self) -> bool { - self.state - .building() - .accumulators - .iter() - .all(|acc| acc.supports_convert_to_state()) + let state = self.state.building(); + state.group_values.support_partial_repartition() + && state + .accumulators + .iter() + .all(|acc| acc.supports_convert_to_state()) } /// In skip-partial-aggregation optimization, when a decision has been made to skip @@ -115,7 +117,8 @@ impl AggregateHashTable { ) -> Result> { let state = self.state.building(); let group_schema = state.group_by.group_schema(&self.input_schema)?; - let group_values = new_group_values(group_schema, &GroupOrdering::None)?; + let mut group_values = new_group_values(group_schema, &GroupOrdering::None)?; + group_values.skip_hash_group_by()?; let accumulators = state .accumulators .iter() @@ -259,19 +262,32 @@ impl AggregateHashTable { 1, "group_values expected to have single element" ); - let mut output = evaluated_batch + let state = self.state.building_mut(); + let group_values = evaluated_batch .grouping_set_args .into_iter() .next() .unwrap_or_default(); + state.group_values.intern( + &group_values, + &mut state.batch_group_indices, + &mut state.batch_hashes, + &mut state.new_group_rows, + )?; - let state = self.state.building_mut(); + let group_indices = &state.batch_group_indices; + let total_num_groups = state.group_values.len(); for (acc, values) in state .accumulators .iter_mut() .zip(evaluated_batch.accumulator_args.iter()) { - output.extend(acc.convert_to_state(values)?); + acc.update_batch(values, group_indices, total_num_groups)?; + } + + let mut output = state.group_values.emit(EmitTo::All)?; + for acc in state.accumulators.iter_mut() { + output.extend(acc.state(EmitTo::All)?); } Ok(RecordBatch::try_new( diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index a49ad87676505..73c2c1321cf65 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -24,7 +24,7 @@ use arrow::array::types::{ }; use arrow::array::{ArrayRef, downcast_primitive}; use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; -use datafusion_common::Result; +use datafusion_common::{Result, internal_err}; use datafusion_expr::EmitTo; @@ -119,6 +119,16 @@ pub trait GroupValues: Send { /// Emits the group values fn emit(&mut self, emit_to: EmitTo) -> Result>; + /// Returns true if this group value storage supports partial repartition. + fn support_partial_repartition(&self) -> bool { + false + } + + /// Enable append-only grouping for partial repartition. + fn skip_hash_group_by(&mut self) -> Result<()> { + internal_err!("GroupValues does not support skip hash group by") + } + /// Clear the contents and shrink the capacity to the size of the batch (free up memory usage) fn clear_shrink(&mut self, num_rows: usize); } diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs index abfe70b53a7c6..12c69dbe87890 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs @@ -41,7 +41,7 @@ use arrow::datatypes::{ }; use datafusion_common::hash_utils::RandomState; use datafusion_common::hash_utils::create_hashes; -use datafusion_common::{Result, internal_datafusion_err, not_impl_err}; +use datafusion_common::{Result, internal_datafusion_err, internal_err, not_impl_err}; use datafusion_execution::memory_pool::proxy::{HashTableAllocExt, VecAllocExt}; use datafusion_expr::EmitTo; use datafusion_physical_expr::binary_map::OutputType; @@ -220,6 +220,9 @@ pub struct GroupValuesColumn { /// Random state for creating hashes random_state: RandomState, + + /// Whether each input row should be appended as a new group directly. + skip_hash_group_by: bool, } /// Buffers to store intermediate results in `vectorized_append` @@ -283,6 +286,7 @@ impl GroupValuesColumn { group_values, hashes_buffer: Default::default(), random_state: crate::aggregates::AGGREGATION_HASH_SEED, + skip_hash_group_by: false, }) } @@ -301,6 +305,40 @@ impl GroupValuesColumn { Ok(v) } + fn append_all_group_values( + &mut self, + cols: &[ArrayRef], + groups: &mut Vec, + hashes: &mut Vec, + new_group_rows: &mut Vec, + ) -> Result<()> { + let num_rows = cols.first().map_or(0, |array| array.len()); + let first_group_idx = self.len(); + + groups.clear(); + groups.extend(first_group_idx..first_group_idx + num_rows); + + hashes.clear(); + hashes.resize(num_rows, 0); + create_hashes(cols, &self.random_state, hashes)?; + + new_group_rows.clear(); + new_group_rows.extend(0..num_rows); + + self.vectorized_operation_buffers.append_row_indices.clear(); + self.vectorized_operation_buffers + .append_row_indices + .extend(0..num_rows); + for (group_value, col) in self.group_values.iter_mut().zip(cols.iter()) { + group_value.vectorized_append( + col, + &self.vectorized_operation_buffers.append_row_indices, + )?; + } + + Ok(()) + } + // ======================================================================== // Scalarized intern // ======================================================================== @@ -1084,6 +1122,10 @@ impl GroupValues for GroupValuesColumn { hashes: &mut Vec, new_group_rows: &mut Vec, ) -> Result<()> { + if self.skip_hash_group_by { + return self.append_all_group_values(cols, groups, hashes, new_group_rows); + } + let n_rows = cols[0].len(); hashes.clear(); hashes.resize(n_rows, 0); @@ -1113,6 +1155,10 @@ impl GroupValues for GroupValuesColumn { } fn emit(&mut self, emit_to: EmitTo) -> Result> { + if self.skip_hash_group_by && matches!(emit_to, EmitTo::First(_)) { + return internal_err!("skip hash group by does not support EmitTo::First"); + } + let mut output = match emit_to { EmitTo::All => { // Replace the column builders with a fresh set so the @@ -1240,6 +1286,15 @@ impl GroupValues for GroupValuesColumn { self.vectorized_operation_buffers.clear(); } } + + fn support_partial_repartition(&self) -> bool { + !self.group_values.is_empty() + } + + fn skip_hash_group_by(&mut self) -> Result<()> { + self.skip_hash_group_by = true; + Ok(()) + } } /// Returns true if [`GroupValuesColumn`] supported for the specified schema From c6f89333d0df664d495cc05a3cc3644b57c3be9d Mon Sep 17 00:00:00 2001 From: kamille Date: Mon, 29 Jun 2026 17:17:07 +0800 Subject: [PATCH 03/17] Reset feedback in non-column group values --- .../src/aggregates/group_values/row.rs | 29 ++++------ .../group_values/single_group_by/boolean.rs | 15 +---- .../group_values/single_group_by/bytes.rs | 21 +------ .../single_group_by/bytes_view.rs | 21 +------ .../group_values/single_group_by/primitive.rs | 55 +++++-------------- 5 files changed, 33 insertions(+), 108 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs index 5344da39cf8bb..ce8fc4472be20 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/row.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -114,21 +114,11 @@ impl GroupValuesRows { }) } - fn intern_impl( - &mut self, - cols: &[ArrayRef], - groups: &mut Vec, - hashes: &mut Vec, - new_group_rows: &mut Vec, - ) -> Result<()> { - hashes.clear(); - hashes.resize(cols.first().map_or(0, |array| array.len()), 0); - create_hashes(cols, &self.random_state, hashes)?; - + fn intern_impl(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { let group_rows = &mut self.rows_buffer; group_rows.clear(); self.row_converter.append(group_rows, cols)?; - debug_assert_eq!(hashes.len(), group_rows.num_rows()); + let n_rows = group_rows.num_rows(); let mut group_values = match self.group_values.take() { Some(group_values) => group_values, @@ -136,9 +126,13 @@ impl GroupValuesRows { }; groups.clear(); - new_group_rows.clear(); - for (row, &target_hash) in hashes.iter().enumerate() { + let batch_hashes = &mut self.hashes_buffer; + batch_hashes.clear(); + batch_hashes.resize(n_rows, 0); + create_hashes(cols, &self.random_state, batch_hashes)?; + + for (row, &target_hash) in batch_hashes.iter().enumerate() { let entry = self.map.find_mut(target_hash, |(exist_hash, group_idx)| { target_hash == *exist_hash && group_rows.row(row) == group_values.row(*group_idx) @@ -154,7 +148,6 @@ impl GroupValuesRows { |(hash, _group_index)| *hash, &mut self.map_size, ); - new_group_rows.push(row); group_idx } }; @@ -172,13 +165,13 @@ impl GroupValues for GroupValuesRows { &mut self, cols: &[ArrayRef], groups: &mut Vec, - hashes: &mut Vec, - new_group_rows: &mut Vec, + _hashes: &mut Vec, + _new_group_rows: &mut Vec, ) -> Result<()> { let normalized_cols: Vec = cols.iter().map(normalize_float_zero).collect(); let cols = normalized_cols.as_slice(); - self.intern_impl(cols, groups, hashes, new_group_rows) + self.intern_impl(cols, groups) } fn size(&self) -> usize { diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs index fed93cc2f67de..edcb5013d240c 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs @@ -16,7 +16,6 @@ // under the License. use crate::aggregates::group_values::GroupValues; -use datafusion_common::hash_utils::create_hashes; use arrow::array::{ ArrayRef, AsArray as _, BooleanArray, BooleanBufferBuilder, NullBufferBuilder, @@ -47,18 +46,13 @@ impl GroupValues for GroupValuesBoolean { &mut self, cols: &[ArrayRef], groups: &mut Vec, - hashes: &mut Vec, - new_group_rows: &mut Vec, + _hashes: &mut Vec, + _new_group_rows: &mut Vec, ) -> Result<()> { - hashes.clear(); - hashes.resize(cols.first().map_or(0, |array| array.len()), 0); - create_hashes(cols, &crate::aggregates::AGGREGATION_HASH_SEED, hashes)?; - let array = cols[0].as_boolean(); groups.clear(); - new_group_rows.clear(); - for (row, value) in array.iter().enumerate() { + for value in array.iter() { let index = match value { Some(false) => { if let Some(index) = self.false_group { @@ -66,7 +60,6 @@ impl GroupValues for GroupValuesBoolean { } else { let index = self.len(); self.false_group = Some(index); - new_group_rows.push(row); index } } @@ -76,7 +69,6 @@ impl GroupValues for GroupValuesBoolean { } else { let index = self.len(); self.true_group = Some(index); - new_group_rows.push(row); index } } @@ -86,7 +78,6 @@ impl GroupValues for GroupValuesBoolean { } else { let index = self.len(); self.null_group = Some(index); - new_group_rows.push(row); index } } diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs index 5b0d88512dfcd..73664970fdd9c 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs @@ -49,36 +49,21 @@ impl GroupValues for GroupValuesBytes { &mut self, cols: &[ArrayRef], groups: &mut Vec, - hashes: &mut Vec, - new_group_rows: &mut Vec, + _hashes: &mut Vec, + _new_group_rows: &mut Vec, ) -> Result<()> { assert_eq!(cols.len(), 1); let arr = &cols[0]; - hashes.clear(); - hashes.resize(arr.len(), 0); - datafusion_common::hash_utils::create_hashes( - cols, - &crate::aggregates::AGGREGATION_HASH_SEED, - hashes, - )?; - groups.clear(); - new_group_rows.clear(); - let mut next_new_group = self.num_groups; - self.map.insert_if_new_with_hashes( + self.map.insert_if_new( arr, - hashes, |_value| { let group_idx = self.num_groups; self.num_groups += 1; group_idx }, |group_idx| { - if group_idx == next_new_group { - new_group_rows.push(groups.len()); - next_new_group += 1; - } groups.push(group_idx); }, ); diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs index 8d4df1d347f87..ab66dcd4a0938 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs @@ -47,36 +47,21 @@ impl GroupValues for GroupValuesBytesView { &mut self, cols: &[ArrayRef], groups: &mut Vec, - hashes: &mut Vec, - new_group_rows: &mut Vec, + _hashes: &mut Vec, + _new_group_rows: &mut Vec, ) -> datafusion_common::Result<()> { assert_eq!(cols.len(), 1); let arr = &cols[0]; - hashes.clear(); - hashes.resize(arr.len(), 0); - datafusion_common::hash_utils::create_hashes( - cols, - &crate::aggregates::AGGREGATION_HASH_SEED, - hashes, - )?; - groups.clear(); - new_group_rows.clear(); - let mut next_new_group = self.num_groups; - self.map.insert_if_new_with_hashes( + self.map.insert_if_new( arr, - hashes, |_value| { let group_idx = self.num_groups; self.num_groups += 1; group_idx }, |group_idx| { - if group_idx == next_new_group { - new_group_rows.push(groups.len()); - next_new_group += 1; - } groups.push(group_idx); }, ); diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs index 5b2c591815ea3..e907d9f03155d 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs @@ -116,8 +116,6 @@ pub struct GroupValuesPrimitive { values: Vec, /// The random state used to generate hashes random_state: RandomState, - /// Reused buffer to store hashes - hashes_buffer: Vec, } impl GroupValuesPrimitive { @@ -129,49 +127,27 @@ impl GroupValuesPrimitive { values: Vec::with_capacity(128), null_group: None, random_state: crate::aggregates::AGGREGATION_HASH_SEED, - hashes_buffer: Default::default(), } } - fn intern_impl( - &mut self, - cols: &[ArrayRef], - groups: &mut Vec, - hashes: &mut Vec, - new_group_rows: &mut Vec, - ) -> Result<()> + fn intern_impl(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> where T::Native: HashValue, { assert_eq!(cols.len(), 1); - let array = cols[0].as_primitive::(); - hashes.clear(); - hashes.resize(array.len(), 0); - for (row, value) in array.iter().enumerate() { - hashes[row] = value - .map(|key| key.canonicalize().hash(&self.random_state)) - .unwrap_or(0); - } - groups.clear(); - new_group_rows.clear(); - for (row, value) in array.iter().enumerate() { + for value in cols[0].as_primitive::() { let group_id = match value { - None => { - if let Some(group_id) = self.null_group { - group_id - } else { - let group_id = self.values.len(); - self.null_group = Some(group_id); - self.values.push(Default::default()); - new_group_rows.push(row); - group_id - } - } + None => *self.null_group.get_or_insert_with(|| { + let group_id = self.values.len(); + self.values.push(Default::default()); + group_id + }), Some(key) => { let key = key.canonicalize(); - let hash = hashes[row]; + let state = &self.random_state; + let hash = key.hash(state); let insert = self.map.entry( hash, |&(group_id, exist_hash)| unsafe { @@ -187,7 +163,6 @@ impl GroupValuesPrimitive { let group_id = self.values.len(); entry.insert((group_id, hash)); self.values.push(key); - new_group_rows.push(row); group_id } } @@ -207,16 +182,14 @@ where &mut self, cols: &[ArrayRef], groups: &mut Vec, - hashes: &mut Vec, - new_group_rows: &mut Vec, + _hashes: &mut Vec, + _new_group_rows: &mut Vec, ) -> Result<()> { - self.intern_impl(cols, groups, hashes, new_group_rows) + self.intern_impl(cols, groups) } fn size(&self) -> usize { - self.map.capacity() * size_of::<(usize, u64)>() - + self.values.allocated_size() - + self.hashes_buffer.allocated_size() + self.map.capacity() * size_of::<(usize, u64)>() + self.values.allocated_size() } fn is_empty(&self) -> bool { @@ -282,8 +255,6 @@ where self.values.shrink_to(num_rows); self.map.clear(); self.map.shrink_to(num_rows, |_| 0); // hasher does not matter since the map is cleared - self.hashes_buffer.clear(); - self.hashes_buffer.shrink_to(num_rows); } } From 25229c2ade9f2535eed43c6de813188d249430e0 Mon Sep 17 00:00:00 2001 From: kamille Date: Mon, 29 Jun 2026 17:26:51 +0800 Subject: [PATCH 04/17] Revert unused binary map hash feedback --- .../physical-expr-common/src/binary_map.rs | 72 ++++--------------- .../src/binary_view_map.rs | 66 ++++------------- 2 files changed, 28 insertions(+), 110 deletions(-) diff --git a/datafusion/physical-expr-common/src/binary_map.rs b/datafusion/physical-expr-common/src/binary_map.rs index 40a2c14269c16..ad184d6500d56 100644 --- a/datafusion/physical-expr-common/src/binary_map.rs +++ b/datafusion/physical-expr-common/src/binary_map.rs @@ -306,62 +306,8 @@ where values.data_type(), DataType::Binary | DataType::LargeBinary )); - self.hashes_buffer.clear(); - self.hashes_buffer.resize(values.len(), 0); - create_hashes([values], &self.random_state, &mut self.hashes_buffer) - .unwrap(); - let hashes = std::mem::take(&mut self.hashes_buffer); self.insert_if_new_inner::>( values, - &hashes, - make_payload_fn, - observe_payload_fn, - ); - self.hashes_buffer = hashes - } - OutputType::Utf8 => { - assert!(matches!( - values.data_type(), - DataType::Utf8 | DataType::LargeUtf8 - )); - self.hashes_buffer.clear(); - self.hashes_buffer.resize(values.len(), 0); - create_hashes([values], &self.random_state, &mut self.hashes_buffer) - .unwrap(); - let hashes = std::mem::take(&mut self.hashes_buffer); - self.insert_if_new_inner::>( - values, - &hashes, - make_payload_fn, - observe_payload_fn, - ); - self.hashes_buffer = hashes - } - _ => unreachable!("View types should use `ArrowBytesViewMap`"), - }; - } - - /// Inserts each value from `values` into the map using precomputed hashes. - pub fn insert_if_new_with_hashes( - &mut self, - values: &ArrayRef, - hashes: &[u64], - make_payload_fn: MP, - observe_payload_fn: OP, - ) where - MP: FnMut(Option<&[u8]>) -> V, - OP: FnMut(V), - { - assert_eq!(values.len(), hashes.len()); - match self.output_type { - OutputType::Binary => { - assert!(matches!( - values.data_type(), - DataType::Binary | DataType::LargeBinary - )); - self.insert_if_new_inner::>( - values, - hashes, make_payload_fn, observe_payload_fn, ) @@ -373,7 +319,6 @@ where )); self.insert_if_new_inner::>( values, - hashes, make_payload_fn, observe_payload_fn, ) @@ -393,7 +338,6 @@ where fn insert_if_new_inner( &mut self, values: &ArrayRef, - hashes: &[u64], mut make_payload_fn: MP, mut observe_payload_fn: OP, ) where @@ -401,10 +345,22 @@ where OP: FnMut(V), B: ByteArrayType, { + // step 1: compute hashes + let batch_hashes = &mut self.hashes_buffer; + batch_hashes.clear(); + batch_hashes.resize(values.len(), 0); + create_hashes([values], &self.random_state, batch_hashes) + // hash is supported for all types and create_hashes only + // returns errors for unsupported types + .unwrap(); + + // step 2: insert each value into the set, if not already present let values = values.as_bytes::(); - assert_eq!(values.len(), hashes.len()); - for (value, &hash) in values.iter().zip(hashes.iter()) { + // Ensure lengths are equivalent + assert_eq!(values.len(), batch_hashes.len()); + + for (value, &hash) in values.iter().zip(batch_hashes.iter()) { // handle null value let Some(value) = value else { let payload = if let Some(&(payload, _offset)) = self.null.as_ref() { diff --git a/datafusion/physical-expr-common/src/binary_view_map.rs b/datafusion/physical-expr-common/src/binary_view_map.rs index 6d901e8753743..9d4b556393a24 100644 --- a/datafusion/physical-expr-common/src/binary_view_map.rs +++ b/datafusion/physical-expr-common/src/binary_view_map.rs @@ -216,56 +216,8 @@ where match self.output_type { OutputType::BinaryView => { assert!(matches!(values.data_type(), DataType::BinaryView)); - self.hashes_buffer.clear(); - self.hashes_buffer.resize(values.len(), 0); - create_hashes([values], &self.random_state, &mut self.hashes_buffer) - .unwrap(); - let hashes = std::mem::take(&mut self.hashes_buffer); self.insert_if_new_inner::( values, - &hashes, - make_payload_fn, - observe_payload_fn, - ); - self.hashes_buffer = hashes - } - OutputType::Utf8View => { - assert!(matches!(values.data_type(), DataType::Utf8View)); - self.hashes_buffer.clear(); - self.hashes_buffer.resize(values.len(), 0); - create_hashes([values], &self.random_state, &mut self.hashes_buffer) - .unwrap(); - let hashes = std::mem::take(&mut self.hashes_buffer); - self.insert_if_new_inner::( - values, - &hashes, - make_payload_fn, - observe_payload_fn, - ); - self.hashes_buffer = hashes - } - _ => unreachable!("Utf8/Binary should use `ArrowBytesSet`"), - }; - } - - /// Inserts each value from `values` into the map using precomputed hashes. - pub fn insert_if_new_with_hashes( - &mut self, - values: &ArrayRef, - hashes: &[u64], - make_payload_fn: MP, - observe_payload_fn: OP, - ) where - MP: FnMut(Option<&[u8]>) -> V, - OP: FnMut(V), - { - assert_eq!(values.len(), hashes.len()); - match self.output_type { - OutputType::BinaryView => { - assert!(matches!(values.data_type(), DataType::BinaryView)); - self.insert_if_new_inner::( - values, - hashes, make_payload_fn, observe_payload_fn, ) @@ -274,7 +226,6 @@ where assert!(matches!(values.data_type(), DataType::Utf8View)); self.insert_if_new_inner::( values, - hashes, make_payload_fn, observe_payload_fn, ) @@ -294,7 +245,6 @@ where fn insert_if_new_inner( &mut self, values: &ArrayRef, - hashes: &[u64], mut make_payload_fn: MP, mut observe_payload_fn: OP, ) where @@ -302,15 +252,27 @@ where OP: FnMut(V), B: ByteViewType, { + // step 1: compute hashes + let batch_hashes = &mut self.hashes_buffer; + batch_hashes.clear(); + batch_hashes.resize(values.len(), 0); + create_hashes([values], &self.random_state, batch_hashes) + // hash is supported for all types and create_hashes only + // returns errors for unsupported types + .unwrap(); + + // step 2: insert each value into the set, if not already present let values = values.as_byte_view::(); - assert_eq!(values.len(), hashes.len()); // Get raw views buffer for direct comparison let input_views = values.views(); + // Ensure lengths are equivalent + assert_eq!(values.len(), self.hashes_buffer.len()); + for i in 0..values.len() { let view_u128 = input_views[i]; - let hash = hashes[i]; + let hash = self.hashes_buffer[i]; // handle null value via validity bitmap check if values.is_null(i) { From 49eabfa39258813ac9368925f427f66c359d9849 Mon Sep 17 00:00:00 2001 From: kamille Date: Mon, 29 Jun 2026 17:35:02 +0800 Subject: [PATCH 05/17] Minimize primitive group values intern change --- .../group_values/single_group_by/primitive.rs | 55 +++++++++---------- 1 file changed, 25 insertions(+), 30 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs index e907d9f03155d..889304aad6f7e 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs @@ -129,41 +129,51 @@ impl GroupValuesPrimitive { random_state: crate::aggregates::AGGREGATION_HASH_SEED, } } +} - fn intern_impl(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> - where - T::Native: HashValue, - { +impl GroupValues for GroupValuesPrimitive +where + T::Native: HashValue, +{ + fn intern( + &mut self, + cols: &[ArrayRef], + groups: &mut Vec, + _hashes: &mut Vec, + _new_group_rows: &mut Vec, + ) -> Result<()> { assert_eq!(cols.len(), 1); groups.clear(); - for value in cols[0].as_primitive::() { - let group_id = match value { + for v in cols[0].as_primitive::() { + let group_id = match v { None => *self.null_group.get_or_insert_with(|| { let group_id = self.values.len(); self.values.push(Default::default()); group_id }), Some(key) => { + // Fold equivalence-class duplicates (e.g. `-0.0` → `+0.0`) + // so the bit-equal `is_eq` matches and the stored value is + // the canonical representative. let key = key.canonicalize(); let state = &self.random_state; let hash = key.hash(state); let insert = self.map.entry( hash, - |&(group_id, exist_hash)| unsafe { - hash == exist_hash - && self.values.get_unchecked(group_id).is_eq(key) + |&(g, h)| unsafe { + hash == h && self.values.get_unchecked(g).is_eq(key) }, - |&(_, exist_hash)| exist_hash, + |&(_, h)| h, ); match insert { - hashbrown::hash_table::Entry::Occupied(entry) => entry.get().0, - hashbrown::hash_table::Entry::Vacant(entry) => { - let group_id = self.values.len(); - entry.insert((group_id, hash)); + hashbrown::hash_table::Entry::Occupied(o) => o.get().0, + hashbrown::hash_table::Entry::Vacant(v) => { + let g = self.values.len(); + v.insert((g, hash)); self.values.push(key); - group_id + g } } } @@ -172,21 +182,6 @@ impl GroupValuesPrimitive { } Ok(()) } -} - -impl GroupValues for GroupValuesPrimitive -where - T::Native: HashValue, -{ - fn intern( - &mut self, - cols: &[ArrayRef], - groups: &mut Vec, - _hashes: &mut Vec, - _new_group_rows: &mut Vec, - ) -> Result<()> { - self.intern_impl(cols, groups) - } fn size(&self) -> usize { self.map.capacity() * size_of::<(usize, u64)>() + self.values.allocated_size() From e0135800ad2ab40d085b06a77678796db31a273a Mon Sep 17 00:00:00 2001 From: kamille Date: Mon, 29 Jun 2026 18:09:59 +0800 Subject: [PATCH 06/17] Always support partial repartition in column groups --- .../src/aggregates/group_values/multi_group_by/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs index 12c69dbe87890..668c4292076d1 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs @@ -1288,7 +1288,7 @@ impl GroupValues for GroupValuesColumn { } fn support_partial_repartition(&self) -> bool { - !self.group_values.is_empty() + true } fn skip_hash_group_by(&mut self) -> Result<()> { From da15c1eced30e5d6bd6d67a555803f78c22baad3 Mon Sep 17 00:00:00 2001 From: kamille Date: Mon, 29 Jun 2026 18:23:48 +0800 Subject: [PATCH 07/17] Allow skip partial through either path --- .../aggregates/aggregate_hash_table/common.rs | 9 ++++++++ .../aggregate_hash_table/partial_table.rs | 23 +++++++++++++++++-- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs index a1c31f39c5e53..3bda4623adf3b 100644 --- a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs +++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs @@ -457,6 +457,15 @@ impl HashAggregateAccumulator { self.accumulator.supports_convert_to_state() } + pub(super) fn convert_to_state( + &mut self, + values: &EvaluatedAccumulatorArgs, + ) -> Result> { + let opt_filter = values.filter.as_ref().map(|filter| filter.as_boolean()); + self.accumulator + .convert_to_state(&values.arguments, opt_filter) + } + pub(super) fn null_arguments( &self, input_schema: &SchemaRef, diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs index 85ca6a5e37276..3ddf946c2e8bd 100644 --- a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs +++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs @@ -103,7 +103,7 @@ impl AggregateHashTable { pub(in crate::aggregates) fn can_skip_aggregation(&self) -> bool { let state = self.state.building(); state.group_values.support_partial_repartition() - && state + || state .accumulators .iter() .all(|acc| acc.supports_convert_to_state()) @@ -118,7 +118,9 @@ impl AggregateHashTable { let state = self.state.building(); let group_schema = state.group_by.group_schema(&self.input_schema)?; let mut group_values = new_group_values(group_schema, &GroupOrdering::None)?; - group_values.skip_hash_group_by()?; + if group_values.support_partial_repartition() { + group_values.skip_hash_group_by()?; + } let accumulators = state .accumulators .iter() @@ -268,6 +270,23 @@ impl AggregateHashTable { .into_iter() .next() .unwrap_or_default(); + + if !state.group_values.support_partial_repartition() { + let mut output = group_values; + for (acc, values) in state + .accumulators + .iter_mut() + .zip(evaluated_batch.accumulator_args.iter()) + { + output.extend(acc.convert_to_state(values)?); + } + + return Ok(RecordBatch::try_new( + Arc::clone(&self.output_schema), + output, + )?); + } + state.group_values.intern( &group_values, &mut state.batch_group_indices, From 0432f057027336042be1ed2f1fafccf5396f8a14 Mon Sep 17 00:00:00 2001 From: kamille Date: Mon, 29 Jun 2026 18:54:41 +0800 Subject: [PATCH 08/17] Track partial aggregation repartition groups --- .../aggregate_hash_table/partial_table.rs | 44 +++++++++ .../src/aggregates/hash_aggregate.rs | 90 +++++++++++++++++-- 2 files changed, 127 insertions(+), 7 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs index 3ddf946c2e8bd..b3e85e41889b0 100644 --- a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs +++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs @@ -109,6 +109,41 @@ impl AggregateHashTable { .all(|acc| acc.supports_convert_to_state()) } + pub(in crate::aggregates) fn can_repartition_in_partial(&self) -> bool { + self.state + .building() + .group_values + .support_partial_repartition() + } + + pub(in crate::aggregates) fn append_new_groups_to_partitions( + &self, + partitions: &mut [Vec], + ) -> Result<()> { + if partitions.is_empty() { + return Ok(()); + } + + let state = self.state.building(); + for &row in &state.new_group_rows { + let Some(&group_index) = state.batch_group_indices.get(row) else { + return internal_err!( + "new group row index {row} does not have a group index" + ); + }; + let Some(&hash) = state.batch_hashes.get(row) else { + return internal_err!( + "new group row index {row} does not have a hash value" + ); + }; + + let partition = partition_for_hash(hash, partitions.len()); + partitions[partition].push(group_index); + } + + Ok(()) + } + /// In skip-partial-aggregation optimization, when a decision has been made to skip /// partial stage, build a typed hash table only for aggregation state conversion /// row-by-row. @@ -252,6 +287,15 @@ impl AggregateHashTable { } } +fn partition_for_hash(hash: u64, num_partitions: usize) -> usize { + debug_assert!(num_partitions > 0); + if num_partitions.is_power_of_two() { + (hash as usize) & (num_partitions - 1) + } else { + (hash as usize) % num_partitions + } +} + impl AggregateHashTable { pub(in crate::aggregates) fn convert_batch_to_state( &mut self, diff --git a/datafusion/physical-plan/src/aggregates/hash_aggregate.rs b/datafusion/physical-plan/src/aggregates/hash_aggregate.rs index 4c8756c0e865c..0163c7dc9c7a1 100644 --- a/datafusion/physical-plan/src/aggregates/hash_aggregate.rs +++ b/datafusion/physical-plan/src/aggregates/hash_aggregate.rs @@ -33,6 +33,7 @@ use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use datafusion_common::Result; use datafusion_execution::TaskContext; +use datafusion_execution::memory_pool::proxy::VecAllocExt; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use futures::stream::{Stream, StreamExt}; @@ -126,6 +127,9 @@ pub(crate) struct PartialHashAggregateStream { /// Tracks whether partial aggregation should switch to direct state conversion. skip_aggregation_probe: Option, + /// Target final partition for newly created groups. + repartition_state: Option, + /// Optional soft limit on the number of groups to accumulate before output. /// /// Invariant: when this is `Some(..)`, the accumulators inside `hash_table` must @@ -137,6 +141,27 @@ pub(crate) struct PartialHashAggregateStream { state: Option, } +struct PartialRepartitionState { + partitions: Vec>, +} + +impl PartialRepartitionState { + fn new(num_partitions: usize) -> Self { + Self { + partitions: vec![vec![]; num_partitions], + } + } + + fn memory_size(&self) -> usize { + self.partitions.allocated_size() + + self + .partitions + .iter() + .map(VecAllocExt::allocated_size) + .sum::() + } +} + /// States for partial hash aggregation processing. enum PartialHashAggregateState { ReadingInput { @@ -184,6 +209,18 @@ impl PartialHashAggregateState { } } +fn can_repartition_in_partial( + agg: &AggregateExec, + context: &TaskContext, + hash_table: &AggregateHashTable, +) -> bool { + agg.group_by.is_single() + && !agg.group_by.is_empty() + && context.session_config().repartition_aggregations() + && context.session_config().target_partitions() > 1 + && hash_table.can_repartition_in_partial() +} + /// Hash aggregation is implemented in two stages: partial and final. This /// stream implements the final stage. /// @@ -293,6 +330,14 @@ impl PartialHashAggregateStream { Arc::clone(&schema), batch_size, )?; + let repartition_state = + if can_repartition_in_partial(agg, context.as_ref(), &hash_table) { + Some(PartialRepartitionState::new( + context.session_config().target_partitions(), + )) + } else { + None + }; let can_skip_aggregation = agg.group_by.is_single() && hash_table.can_skip_aggregation(); let skip_aggregation_probe = if can_skip_aggregation { @@ -328,6 +373,7 @@ impl PartialHashAggregateStream { reservation, reduction_factor, skip_aggregation_probe, + repartition_state, group_values_soft_limit: agg.limit_options().map(|config| config.limit()), state: Some(PartialHashAggregateState::ReadingInput { hash_table }), }) @@ -357,6 +403,32 @@ impl PartialHashAggregateStream { .is_some_and(|probe| probe.should_skip()) } + fn append_new_groups_to_partitions( + &mut self, + hash_table: &AggregateHashTable, + ) -> Result<()> { + let Some(repartition_state) = self.repartition_state.as_mut() else { + return Ok(()); + }; + + hash_table.append_new_groups_to_partitions(&mut repartition_state.partitions) + } + + fn memory_size(&self, hash_table: &AggregateHashTable) -> usize { + hash_table.memory_size() + + self + .repartition_state + .as_ref() + .map_or(0, PartialRepartitionState::memory_size) + } + + fn resize_reservation( + &self, + hash_table: &AggregateHashTable, + ) -> Result<()> { + self.reservation.try_resize(self.memory_size(hash_table)) + } + fn start_output( &mut self, hash_table: &mut AggregateHashTable, @@ -402,6 +474,15 @@ impl PartialHashAggregateStream { )); } + if let Err(e) = + self.append_new_groups_to_partitions(original_state.hash_table()) + { + return ControlFlow::Break(( + Poll::Ready(Some(Err(e))), + original_state, + )); + } + if self.hit_soft_group_limit(original_state.hash_table()) { let timer = elapsed_compute.timer(); let result = self.start_output(original_state.hash_table_mut(), true); @@ -472,10 +553,7 @@ impl PartialHashAggregateStream { // TODO: impl memory-limited aggr, when OOM directly send // partial state to final aggregate stage - if let Err(e) = self - .reservation - .try_resize(original_state.hash_table().memory_size()) - { + if let Err(e) = self.resize_reservation(original_state.hash_table()) { return ControlFlow::Break(( Poll::Ready(Some(Err(e))), original_state, @@ -537,9 +615,7 @@ impl PartialHashAggregateStream { match result { Ok(Some(batch)) => { - let _ = self - .reservation - .try_resize(original_state.hash_table().memory_size()); + let _ = self.resize_reservation(original_state.hash_table()); self.reduction_factor.add_part(batch.num_rows()); debug_assert!(batch.num_rows() > 0); let next_state = if original_state.hash_table().is_done() { From 510d7b33edb398ee6903ffaa40095b2737307e6c Mon Sep 17 00:00:00 2001 From: kamille Date: Mon, 29 Jun 2026 18:58:50 +0800 Subject: [PATCH 09/17] Relax partial repartition grouping check --- datafusion/physical-plan/src/aggregates/hash_aggregate.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/hash_aggregate.rs b/datafusion/physical-plan/src/aggregates/hash_aggregate.rs index 0163c7dc9c7a1..959d9cf3fd2c4 100644 --- a/datafusion/physical-plan/src/aggregates/hash_aggregate.rs +++ b/datafusion/physical-plan/src/aggregates/hash_aggregate.rs @@ -214,8 +214,7 @@ fn can_repartition_in_partial( context: &TaskContext, hash_table: &AggregateHashTable, ) -> bool { - agg.group_by.is_single() - && !agg.group_by.is_empty() + !agg.group_by.is_empty() && context.session_config().repartition_aggregations() && context.session_config().target_partitions() > 1 && hash_table.can_repartition_in_partial() From 98d0a6707f8916a94cbbdec40a734ddb634fa0f4 Mon Sep 17 00:00:00 2001 From: kamille Date: Mon, 29 Jun 2026 20:36:54 +0800 Subject: [PATCH 10/17] Minimize row group values intern change --- .../src/aggregates/group_values/row.rs | 47 ++++++++++++------- 1 file changed, 31 insertions(+), 16 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs index ce8fc4472be20..9350daf1de2c2 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/row.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -113,8 +113,24 @@ impl GroupValuesRows { random_state: crate::aggregates::AGGREGATION_HASH_SEED, }) } +} + +impl GroupValues for GroupValuesRows { + fn intern( + &mut self, + cols: &[ArrayRef], + groups: &mut Vec, + _hashes: &mut Vec, + _new_group_rows: &mut Vec, + ) -> Result<()> { + // Normalize -0.0 → +0.0 so RowConverter (IEEE 754 totalOrder) and + // primitive hashing both group ±0 together. No-op for non-float + // columns. + let normalized_cols: Vec = + cols.iter().map(normalize_float_zero).collect(); + let cols = normalized_cols.as_slice(); - fn intern_impl(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { + // Convert the group keys into the row format let group_rows = &mut self.rows_buffer; group_rows.clear(); self.row_converter.append(group_rows, cols)?; @@ -125,8 +141,10 @@ impl GroupValuesRows { None => self.row_converter.empty_rows(0, 0), }; + // tracks to which group each of the input rows belongs groups.clear(); + // 1.1 Calculate the group keys for the group values let batch_hashes = &mut self.hashes_buffer; batch_hashes.clear(); batch_hashes.resize(n_rows, 0); @@ -134,15 +152,27 @@ impl GroupValuesRows { for (row, &target_hash) in batch_hashes.iter().enumerate() { let entry = self.map.find_mut(target_hash, |(exist_hash, group_idx)| { + // Somewhat surprisingly, this closure can be called even if the + // hash doesn't match, so check the hash first with an integer + // comparison first avoid the more expensive comparison with + // group value. https://github.com/apache/datafusion/pull/11718 target_hash == *exist_hash + // verify that the group that we are inserting with hash is + // actually the same key value as the group in + // existing_idx (aka group_values @ row) && group_rows.row(row) == group_values.row(*group_idx) }); let group_idx = match entry { + // Existing group_index for this group value Some((_hash, group_idx)) => *group_idx, + // 1.2 Need to create new entry for the group None => { + // Add new entry to aggr_state and save newly created index let group_idx = group_values.num_rows(); group_values.push(group_rows.row(row)); + + // for hasher function, use precomputed hash value self.map.insert_accounted( (target_hash, group_idx), |(hash, _group_index)| *hash, @@ -158,21 +188,6 @@ impl GroupValuesRows { Ok(()) } -} - -impl GroupValues for GroupValuesRows { - fn intern( - &mut self, - cols: &[ArrayRef], - groups: &mut Vec, - _hashes: &mut Vec, - _new_group_rows: &mut Vec, - ) -> Result<()> { - let normalized_cols: Vec = - cols.iter().map(normalize_float_zero).collect(); - let cols = normalized_cols.as_slice(); - self.intern_impl(cols, groups) - } fn size(&self) -> usize { let group_values_size = self.group_values.as_ref().map(|v| v.size()).unwrap_or(0); From 5db2863d8aac025907ef6077d598a6ca88eb8260 Mon Sep 17 00:00:00 2001 From: kamille Date: Tue, 30 Jun 2026 08:17:59 +0800 Subject: [PATCH 11/17] Use append-only group values for skip partial --- .../aggregate_hash_table/partial_table.rs | 51 +++++-------------- .../src/aggregates/hash_aggregate.rs | 23 +++++++++ 2 files changed, 36 insertions(+), 38 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs index b3e85e41889b0..572c24c91be05 100644 --- a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs +++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs @@ -23,7 +23,6 @@ use arrow::array::{ArrayRef, BooleanArray, new_null_array}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use datafusion_common::{Result, assert_eq_or_internal_err, internal_err}; -use datafusion_expr::EmitTo; use crate::aggregates::group_values::new_group_values; use crate::aggregates::order::GroupOrdering; @@ -144,6 +143,15 @@ impl AggregateHashTable { Ok(()) } + pub(in crate::aggregates) fn skip_hash_group_by(&mut self) -> Result<()> { + self.state + .building_mut() + .group_values + .skip_hash_group_by()?; + self.batch_size = usize::MAX; + Ok(()) + } + /// In skip-partial-aggregation optimization, when a decision has been made to skip /// partial stage, build a typed hash table only for aggregation state conversion /// row-by-row. @@ -152,10 +160,7 @@ impl AggregateHashTable { ) -> Result> { let state = self.state.building(); let group_schema = state.group_by.group_schema(&self.input_schema)?; - let mut group_values = new_group_values(group_schema, &GroupOrdering::None)?; - if group_values.support_partial_repartition() { - group_values.skip_hash_group_by()?; - } + let group_values = new_group_values(group_schema, &GroupOrdering::None)?; let accumulators = state .accumulators .iter() @@ -308,49 +313,19 @@ impl AggregateHashTable { 1, "group_values expected to have single element" ); - let state = self.state.building_mut(); - let group_values = evaluated_batch + let mut output = evaluated_batch .grouping_set_args .into_iter() .next() .unwrap_or_default(); - if !state.group_values.support_partial_repartition() { - let mut output = group_values; - for (acc, values) in state - .accumulators - .iter_mut() - .zip(evaluated_batch.accumulator_args.iter()) - { - output.extend(acc.convert_to_state(values)?); - } - - return Ok(RecordBatch::try_new( - Arc::clone(&self.output_schema), - output, - )?); - } - - state.group_values.intern( - &group_values, - &mut state.batch_group_indices, - &mut state.batch_hashes, - &mut state.new_group_rows, - )?; - - let group_indices = &state.batch_group_indices; - let total_num_groups = state.group_values.len(); + let state = self.state.building_mut(); for (acc, values) in state .accumulators .iter_mut() .zip(evaluated_batch.accumulator_args.iter()) { - acc.update_batch(values, group_indices, total_num_groups)?; - } - - let mut output = state.group_values.emit(EmitTo::All)?; - for acc in state.accumulators.iter_mut() { - output.extend(acc.state(EmitTo::All)?); + output.extend(acc.convert_to_state(values)?); } Ok(RecordBatch::try_new( diff --git a/datafusion/physical-plan/src/aggregates/hash_aggregate.rs b/datafusion/physical-plan/src/aggregates/hash_aggregate.rs index 959d9cf3fd2c4..9bf68b5546749 100644 --- a/datafusion/physical-plan/src/aggregates/hash_aggregate.rs +++ b/datafusion/physical-plan/src/aggregates/hash_aggregate.rs @@ -515,6 +515,29 @@ impl PartialHashAggregateStream { // True branch: a decision has been made to skip partial aggregation. if self.should_skip_aggregation() { let timer = elapsed_compute.timer(); + if original_state.hash_table().can_repartition_in_partial() { + let result = original_state.hash_table_mut().skip_hash_group_by(); + timer.done(); + + if let Err(e) = result { + return ControlFlow::Break(( + Poll::Ready(Some(Err(e))), + original_state, + )); + } + + if let Err(e) = + self.resize_reservation(original_state.hash_table()) + { + return ControlFlow::Break(( + Poll::Ready(Some(Err(e))), + original_state, + )); + } + + return ControlFlow::Continue(original_state); + } + let result = match original_state.hash_table().partial_skip_table() { Ok(skip_hash_table) => self .start_output(original_state.hash_table_mut(), false) From 1583cd413b4af372274fb639c9c39f974b154e8e Mon Sep 17 00:00:00 2001 From: kamille Date: Tue, 30 Jun 2026 08:44:03 +0800 Subject: [PATCH 12/17] Materialize append-only partial output --- .../aggregates/aggregate_hash_table/common.rs | 44 ++++++++------- .../aggregate_hash_table/final_table.rs | 22 ++------ .../aggregate_hash_table/partial_table.rs | 55 ++++++++++--------- .../group_values/multi_group_by/mod.rs | 1 + 4 files changed, 57 insertions(+), 65 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs index 3bda4623adf3b..513b40538efba 100644 --- a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs +++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs @@ -186,7 +186,7 @@ impl AggregateHashTable { + state.batch_hashes.allocated_size() + state.new_group_rows.allocated_size() } - AggregateHashTableState::OutputtingMaterializedFinal(output) => { + AggregateHashTableState::OutputtingMaterialized(output) => { output.memory_size() } AggregateHashTableState::Done => 0, @@ -216,14 +216,19 @@ impl AggregateHashTable { state.batch_group_indices = Vec::new(); self.state = AggregateHashTableState::Outputting(state); } -} -pub(super) fn emit_to_for_batch_size(batch_size: usize, group_count: usize) -> EmitTo { - debug_assert!(batch_size > 0); - if group_count <= batch_size { - EmitTo::All - } else { - EmitTo::First(batch_size) + pub(super) fn emit_next_materialized_batch( + &mut self, + mut output: MaterializedOutput, + batch_size: usize, + ) -> Option { + let batch = output.next_batch(batch_size); + if output.is_exhausted() { + self.state = AggregateHashTableState::Done; + } else { + self.state = AggregateHashTableState::OutputtingMaterialized(output); + } + batch } } @@ -314,24 +319,21 @@ pub(super) enum AggregateHashTableState { Building(AggregateHashTableBuffer), /// Emitting results directly from group keys and aggregate state. Outputting(AggregateHashTableBuffer), - /// Materialize all the output results, and then incrementally output in the `OutputtingMaterializedFinal` state. - /// - /// Note this is a temporary solution until the `GroupValues` issue is solved: - /// Issue: - OutputtingMaterializedFinal(MaterializedFinalOutput), + /// Materialized output rows sliced across output polls. + OutputtingMaterialized(MaterializedOutput), Done, } -/// Fully evaluated final aggregate output and the next row offset to emit. +/// Materialized aggregate output and the next row offset to emit. /// -/// Final aggregate evaluation consumes accumulator state, so final output is -/// materialized once and then sliced to honor `batch_size` across output polls. -pub(super) struct MaterializedFinalOutput { +/// Some output paths can only emit all rows from their backing state at once. +/// The materialized batch is sliced to honor `batch_size` across output polls. +pub(super) struct MaterializedOutput { batch: RecordBatch, offset: usize, } -impl MaterializedFinalOutput { +impl MaterializedOutput { pub(super) fn new(batch: RecordBatch) -> Self { Self { batch, offset: 0 } } @@ -505,8 +507,10 @@ mod tests { use super::*; + // Covers materialized output slicing until all rows are emitted. + // Example: a five-row batch with batch size two emits 2, 2, then 1 row. #[test] - fn materialized_final_output_slices_batches_until_exhausted() -> Result<()> { + fn test_materialized_output_slices_batches_until_exhausted() -> Result<()> { let schema = Arc::new(Schema::new(vec![Field::new( "group_col", DataType::Int32, @@ -516,7 +520,7 @@ mod tests { schema, vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))], )?; - let mut output = MaterializedFinalOutput::new(batch); + let mut output = MaterializedOutput::new(batch); assert_eq!(int32_values(&output.next_batch(2).unwrap(), 0), vec![1, 2]); assert_eq!(int32_values(&output.next_batch(2).unwrap(), 0), vec![3, 4]); diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/final_table.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/final_table.rs index ccba78d15ba16..57682d838b68e 100644 --- a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/final_table.rs +++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/final_table.rs @@ -26,7 +26,7 @@ use crate::aggregates::AggregateExec; use super::common::{ AggregateHashTable, AggregateHashTableBuffer, AggregateHashTableState, FinalMarker, - MaterializedFinalOutput, + MaterializedOutput, }; /// Methods specific to the aggregate hash table used in the final aggregation stage. @@ -68,7 +68,7 @@ impl AggregateHashTable { let output = self.materialize_final_output(state, output_schema)?; Ok(self.emit_next_materialized_batch(output, batch_size)) } - AggregateHashTableState::OutputtingMaterializedFinal(output) => { + AggregateHashTableState::OutputtingMaterialized(output) => { Ok(self.emit_next_materialized_batch(output, batch_size)) } AggregateHashTableState::Done => Ok(None), @@ -82,7 +82,7 @@ impl AggregateHashTable { &self, mut state: AggregateHashTableBuffer, output_schema: SchemaRef, - ) -> Result { + ) -> Result { // Final aggregate evaluation consumes accumulator state. Evaluate all // groups once, then slice the materialized batch on subsequent polls. let emit_to = EmitTo::All; @@ -96,21 +96,7 @@ impl AggregateHashTable { let batch = RecordBatch::try_new(output_schema, output)?; debug_assert!(batch.num_rows() > 0); - Ok(MaterializedFinalOutput::new(batch)) - } - - fn emit_next_materialized_batch( - &mut self, - mut output: MaterializedFinalOutput, - batch_size: usize, - ) -> Option { - let batch = output.next_batch(batch_size); - if output.is_exhausted() { - self.state = AggregateHashTableState::Done; - } else { - self.state = AggregateHashTableState::OutputtingMaterializedFinal(output); - } - batch + Ok(MaterializedOutput::new(batch)) } pub(in crate::aggregates) fn aggregate_batch( diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs index 572c24c91be05..56dbdf540660b 100644 --- a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs +++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs @@ -23,6 +23,7 @@ use arrow::array::{ArrayRef, BooleanArray, new_null_array}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use datafusion_common::{Result, assert_eq_or_internal_err, internal_err}; +use datafusion_expr::EmitTo; use crate::aggregates::group_values::new_group_values; use crate::aggregates::order::GroupOrdering; @@ -30,8 +31,8 @@ use crate::aggregates::{AggregateExec, group_id_array, max_duplicate_ordinal}; use super::common::{ AggregateHashTable, AggregateHashTableBuffer, AggregateHashTableState, - EvaluatedAccumulatorArgs, HashAggregateAccumulator, PartialMarker, PartialSkipMarker, - emit_to_for_batch_size, + EvaluatedAccumulatorArgs, HashAggregateAccumulator, MaterializedOutput, + PartialMarker, PartialSkipMarker, }; /// Methods specific to the aggregate hash table used in the partial aggregation stage. @@ -62,43 +63,44 @@ impl AggregateHashTable { ) -> Result> { let output_schema = Arc::clone(&self.output_schema); let batch_size = self.batch_size; - match &mut self.state { + match std::mem::replace(&mut self.state, AggregateHashTableState::Done) { AggregateHashTableState::Outputting(state) => { if state.group_values.is_empty() { - self.state = AggregateHashTableState::Done; return Ok(None); } - let emit_to = - emit_to_for_batch_size(batch_size, state.group_values.len()); - let timer = self.group_by_metrics.emitting_time.timer(); - let mut output = state.group_values.emit(emit_to)?; - - for acc in state.accumulators.iter_mut() { - output.extend(acc.state(emit_to)?); - } - let done = state.group_values.is_empty(); - drop(timer); - - let batch = RecordBatch::try_new(output_schema, output)?; - debug_assert!(batch.num_rows() > 0); - if done { - self.state = AggregateHashTableState::Done; - } - Ok(Some(batch)) + let output = self.materialize_partial_output(state, output_schema)?; + Ok(self.emit_next_materialized_batch(output, batch_size)) + } + AggregateHashTableState::OutputtingMaterialized(output) => { + Ok(self.emit_next_materialized_batch(output, batch_size)) } AggregateHashTableState::Done => Ok(None), AggregateHashTableState::Building(_) => { internal_err!("next_output_batch must be called in the outputting state") } - AggregateHashTableState::OutputtingMaterializedFinal(_) => { - internal_err!( - "partial aggregate output should not materialize final output" - ) - } } } + fn materialize_partial_output( + &self, + mut state: AggregateHashTableBuffer, + output_schema: SchemaRef, + ) -> Result { + let emit_to = EmitTo::All; + let timer = self.group_by_metrics.emitting_time.timer(); + let mut output = state.group_values.emit(emit_to)?; + + for acc in state.accumulators.iter_mut() { + output.extend(acc.state(emit_to)?); + } + drop(timer); + + let batch = RecordBatch::try_new(output_schema, output)?; + debug_assert!(batch.num_rows() > 0); + Ok(MaterializedOutput::new(batch)) + } + pub(in crate::aggregates) fn can_skip_aggregation(&self) -> bool { let state = self.state.building(); state.group_values.support_partial_repartition() @@ -148,7 +150,6 @@ impl AggregateHashTable { .building_mut() .group_values .skip_hash_group_by()?; - self.batch_size = usize::MAX; Ok(()) } diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs index 668c4292076d1..b10e0f359a26c 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs @@ -1292,6 +1292,7 @@ impl GroupValues for GroupValuesColumn { } fn skip_hash_group_by(&mut self) -> Result<()> { + self.map.clear(); self.skip_hash_group_by = true; Ok(()) } From 1f6edfdcb6b3671dd57c9c07c4e667e5befab7f4 Mon Sep 17 00:00:00 2001 From: kamille Date: Tue, 30 Jun 2026 09:43:05 +0800 Subject: [PATCH 13/17] Emit partial repartition batches --- .../aggregate_hash_table/partial_table.rs | 87 ++++++++--- .../src/aggregates/hash_aggregate.rs | 145 ++++++++++++++++-- .../physical-plan/src/aggregates/mod.rs | 1 - .../physical-plan/src/repartition/mod.rs | 63 ++++++++ 4 files changed, 263 insertions(+), 33 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs index 56dbdf540660b..e32a29e8c20d2 100644 --- a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs +++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs @@ -22,7 +22,9 @@ use std::sync::Arc; use arrow::array::{ArrayRef, BooleanArray, new_null_array}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; -use datafusion_common::{Result, assert_eq_or_internal_err, internal_err}; +use datafusion_common::{ + Result, assert_eq_or_internal_err, internal_datafusion_err, internal_err, +}; use datafusion_expr::EmitTo; use crate::aggregates::group_values::new_group_values; @@ -61,7 +63,6 @@ impl AggregateHashTable { pub(in crate::aggregates) fn next_output_batch( &mut self, ) -> Result> { - let output_schema = Arc::clone(&self.output_schema); let batch_size = self.batch_size; match std::mem::replace(&mut self.state, AggregateHashTableState::Done) { AggregateHashTableState::Outputting(state) => { @@ -69,8 +70,11 @@ impl AggregateHashTable { return Ok(None); } - let output = self.materialize_partial_output(state, output_schema)?; - Ok(self.emit_next_materialized_batch(output, batch_size)) + let batch = self.materialize_partial_batch(state)?; + Ok(self.emit_next_materialized_batch( + MaterializedOutput::new(batch), + batch_size, + )) } AggregateHashTableState::OutputtingMaterialized(output) => { Ok(self.emit_next_materialized_batch(output, batch_size)) @@ -82,11 +86,34 @@ impl AggregateHashTable { } } - fn materialize_partial_output( + pub(in crate::aggregates) fn materialize_output_batch( + &mut self, + ) -> Result> { + match std::mem::replace(&mut self.state, AggregateHashTableState::Done) { + AggregateHashTableState::Outputting(state) => { + if state.group_values.is_empty() { + return Ok(None); + } + + self.materialize_partial_batch(state).map(Some) + } + AggregateHashTableState::Done => Ok(None), + AggregateHashTableState::Building(_) => { + internal_err!( + "materialize_output_batch must be called in the outputting state" + ) + } + AggregateHashTableState::OutputtingMaterialized(_) => { + internal_err!("partial aggregate output is already materialized") + } + } + } + + fn materialize_partial_batch( &self, mut state: AggregateHashTableBuffer, - output_schema: SchemaRef, - ) -> Result { + ) -> Result { + let output_schema = Arc::clone(&self.output_schema); let emit_to = EmitTo::All; let timer = self.group_by_metrics.emitting_time.timer(); let mut output = state.group_values.emit(emit_to)?; @@ -98,7 +125,7 @@ impl AggregateHashTable { let batch = RecordBatch::try_new(output_schema, output)?; debug_assert!(batch.num_rows() > 0); - Ok(MaterializedOutput::new(batch)) + Ok(batch) } pub(in crate::aggregates) fn can_skip_aggregation(&self) -> bool { @@ -117,14 +144,35 @@ impl AggregateHashTable { .support_partial_repartition() } - pub(in crate::aggregates) fn append_new_groups_to_partitions( + pub(in crate::aggregates) fn append_new_group_partitions( &self, - partitions: &mut [Vec], + partition_group_indices: &mut [Vec], ) -> Result<()> { - if partitions.is_empty() { + let num_partitions = partition_group_indices.len(); + if num_partitions == 0 { return Ok(()); } + if num_partitions.is_power_of_two() { + let mask = num_partitions - 1; + self.append_new_groups_with_partition(partition_group_indices, |hash| { + (hash as usize) & mask + }) + } else { + self.append_new_groups_with_partition(partition_group_indices, |hash| { + (hash as usize) % num_partitions + }) + } + } + + fn append_new_groups_with_partition( + &self, + partition_group_indices: &mut [Vec], + compute_partition: F, + ) -> Result<()> + where + F: Fn(u64) -> usize, + { let state = self.state.building(); for &row in &state.new_group_rows { let Some(&group_index) = state.batch_group_indices.get(row) else { @@ -138,8 +186,12 @@ impl AggregateHashTable { ); }; - let partition = partition_for_hash(hash, partitions.len()); - partitions[partition].push(group_index); + let group_index = u32::try_from(group_index).map_err(|_| { + internal_datafusion_err!( + "partitioned aggregate output index exceeds u32::MAX" + ) + })?; + partition_group_indices[compute_partition(hash)].push(group_index); } Ok(()) @@ -293,15 +345,6 @@ impl AggregateHashTable { } } -fn partition_for_hash(hash: u64, num_partitions: usize) -> usize { - debug_assert!(num_partitions > 0); - if num_partitions.is_power_of_two() { - (hash as usize) & (num_partitions - 1) - } else { - (hash as usize) % num_partitions - } -} - impl AggregateHashTable { pub(in crate::aggregates) fn convert_batch_to_state( &mut self, diff --git a/datafusion/physical-plan/src/aggregates/hash_aggregate.rs b/datafusion/physical-plan/src/aggregates/hash_aggregate.rs index 9bf68b5546749..adbdb585cd299 100644 --- a/datafusion/physical-plan/src/aggregates/hash_aggregate.rs +++ b/datafusion/physical-plan/src/aggregates/hash_aggregate.rs @@ -29,8 +29,10 @@ use std::ops::ControlFlow; use std::sync::Arc; use std::task::{Context, Poll}; +use arrow::array::{PrimitiveArray, RecordBatch}; +use arrow::compute::take_arrays; use arrow::datatypes::SchemaRef; -use arrow::record_batch::RecordBatch; +use arrow::datatypes::UInt32Type; use datafusion_common::Result; use datafusion_execution::TaskContext; use datafusion_execution::memory_pool::proxy::VecAllocExt; @@ -45,6 +47,9 @@ use super::skip_partial::SkipAggregationProbe; use crate::metrics::{ BaselineMetrics, MetricBuilder, MetricCategory, RecordOutput, SpillMetrics, }; +use crate::repartition::{ + PARTITIONED_AGGREGATION_NUM_PARTITIONS_KEY, PARTITIONED_AGGREGATION_PARTITION_KEY, +}; use crate::stream::EmptyRecordBatchStream; use crate::{InputOrderMode, RecordBatchStream, SendableRecordBatchStream, metrics}; @@ -127,7 +132,7 @@ pub(crate) struct PartialHashAggregateStream { /// Tracks whether partial aggregation should switch to direct state conversion. skip_aggregation_probe: Option, - /// Target final partition for newly created groups. + /// Group indices for each target final partition. repartition_state: Option, /// Optional soft limit on the number of groups to accumulate before output. @@ -142,26 +147,118 @@ pub(crate) struct PartialHashAggregateStream { } struct PartialRepartitionState { - partitions: Vec>, + partition_group_indices: Vec>, + partition_batches: Vec, + offsets: Vec, + remaining_rows: usize, + batch_size: usize, + num_partitions: usize, + next_partition: usize, } impl PartialRepartitionState { - fn new(num_partitions: usize) -> Self { + fn new(num_partitions: usize, batch_size: usize) -> Self { Self { - partitions: vec![vec![]; num_partitions], + partition_group_indices: vec![vec![]; num_partitions], + partition_batches: Vec::with_capacity(num_partitions), + offsets: vec![0; num_partitions], + remaining_rows: 0, + batch_size, + num_partitions, + next_partition: 0, + } + } + + fn start_output(&mut self, batch: &RecordBatch) -> Result<()> { + self.partition_batches.clear(); + self.remaining_rows = 0; + + for group_indices in &self.partition_group_indices { + let indices: PrimitiveArray = group_indices.clone().into(); + let columns = take_arrays(batch.columns(), &indices, None)?; + let partition_batch = + RecordBatch::try_new(Arc::clone(&batch.schema()), columns)?; + self.remaining_rows += partition_batch.num_rows(); + self.partition_batches.push(partition_batch); + } + + self.offsets.clear(); + self.offsets.resize(self.num_partitions, 0); + self.next_partition = 0; + Ok(()) + } + + fn next_batch(&mut self) -> Option { + if self.partition_batches.is_empty() { + return None; } + + for _ in 0..self.num_partitions { + let partition = self.next_partition; + self.next_partition = (self.next_partition + 1) % self.num_partitions; + + let batch = &self.partition_batches[partition]; + let offset = self.offsets[partition]; + let partition_rows = batch.num_rows(); + if offset >= partition_rows { + continue; + } + + let len = self.batch_size.min(partition_rows - offset); + let output = add_partitioned_aggregation_metadata( + batch.slice(offset, len), + partition, + self.num_partitions, + ); + self.offsets[partition] += len; + self.remaining_rows -= len; + if self.remaining_rows == 0 { + self.partition_batches.clear(); + } + return Some(output); + } + + self.partition_batches.clear(); + None + } + + fn is_done(&self) -> bool { + self.partition_batches.is_empty() } fn memory_size(&self) -> usize { - self.partitions.allocated_size() + self.partition_group_indices.allocated_size() + self - .partitions + .partition_group_indices .iter() .map(VecAllocExt::allocated_size) .sum::() + + self + .partition_batches + .iter() + .map(RecordBatch::get_array_memory_size) + .sum::() + + self.offsets.allocated_size() } } +fn add_partitioned_aggregation_metadata( + mut batch: RecordBatch, + partition: usize, + num_partitions: usize, +) -> RecordBatch { + let metadata = batch.schema_metadata_mut(); + metadata.insert( + PARTITIONED_AGGREGATION_PARTITION_KEY.to_string(), + partition.to_string(), + ); + metadata.insert( + PARTITIONED_AGGREGATION_NUM_PARTITIONS_KEY.to_string(), + num_partitions.to_string(), + ); + batch +} + /// States for partial hash aggregation processing. enum PartialHashAggregateState { ReadingInput { @@ -333,6 +430,7 @@ impl PartialHashAggregateStream { if can_repartition_in_partial(agg, context.as_ref(), &hash_table) { Some(PartialRepartitionState::new( context.session_config().target_partitions(), + batch_size, )) } else { None @@ -410,7 +508,8 @@ impl PartialHashAggregateStream { return Ok(()); }; - hash_table.append_new_groups_to_partitions(&mut repartition_state.partitions) + hash_table + .append_new_group_partitions(&mut repartition_state.partition_group_indices) } fn memory_size(&self, hash_table: &AggregateHashTable) -> usize { @@ -428,6 +527,32 @@ impl PartialHashAggregateStream { self.reservation.try_resize(self.memory_size(hash_table)) } + fn next_output_batch( + &mut self, + hash_table: &mut AggregateHashTable, + ) -> Result> { + let Some(repartition_state) = self.repartition_state.as_mut() else { + return hash_table.next_output_batch(); + }; + + if repartition_state.is_done() { + let Some(batch) = hash_table.materialize_output_batch()? else { + return Ok(None); + }; + repartition_state.start_output(&batch)?; + } + + Ok(repartition_state.next_batch()) + } + + fn is_output_done(&self, hash_table: &AggregateHashTable) -> bool { + hash_table.is_done() + && self + .repartition_state + .as_ref() + .is_none_or(PartialRepartitionState::is_done) + } + fn start_output( &mut self, hash_table: &mut AggregateHashTable, @@ -632,7 +757,7 @@ impl PartialHashAggregateStream { let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); let timer = elapsed_compute.timer(); - let result = original_state.hash_table_mut().next_output_batch(); + let result = self.next_output_batch(original_state.hash_table_mut()); timer.done(); match result { @@ -640,7 +765,7 @@ impl PartialHashAggregateStream { let _ = self.resize_reservation(original_state.hash_table()); self.reduction_factor.add_part(batch.num_rows()); debug_assert!(batch.num_rows() > 0); - let next_state = if original_state.hash_table().is_done() { + let next_state = if self.is_output_done(original_state.hash_table()) { match original_state { PartialHashAggregateState::ProducingOutput { skip_hash_table: Some(hash_table), diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 4f5b893578d74..f927a50badfb5 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -519,7 +519,6 @@ impl PartialEq for PhysicalGroupBy { /// partial state: [g, AVG(x) state columns, e.g. sum/count] /// final result: [g, AVG(x)] /// ``` -#[expect(clippy::large_enum_variant)] enum StreamType { /// Single group (no group by) aggregate stream. /// Input output scheme: initial input -> final result diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index 2298183485f55..e648e4f536d07 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -81,6 +81,11 @@ use distributor_channels::{ DistributionReceiver, DistributionSender, channels, partition_aware_channels, }; +pub(crate) const PARTITIONED_AGGREGATION_PARTITION_KEY: &str = + "datafusion.partitioned_aggregation.partition"; +pub(crate) const PARTITIONED_AGGREGATION_NUM_PARTITIONS_KEY: &str = + "datafusion.partitioned_aggregation.num_partitions"; + /// A batch in the repartition queue - either in memory or spilled to disk. /// /// This enum represents the two states a batch can be in during repartitioning. @@ -152,6 +157,46 @@ type MaybeBatch = Option>; type InputPartitionsToCurrentPartitionSender = Vec>; type InputPartitionsToCurrentPartitionReceiver = Vec>; +fn partitioned_aggregation_partition( + batch: &RecordBatch, + expected_num_partitions: usize, +) -> Result> { + let metadata = batch.schema_ref().metadata(); + let Some(partition) = metadata.get(PARTITIONED_AGGREGATION_PARTITION_KEY) else { + return Ok(None); + }; + let Some(num_partitions) = metadata.get(PARTITIONED_AGGREGATION_NUM_PARTITIONS_KEY) + else { + return Ok(None); + }; + + let partition = partition.parse::().map_err(|err| { + DataFusionError::Internal(format!( + "Invalid partitioned aggregation partition metadata: {err}" + )) + })?; + let num_partitions = num_partitions.parse::().map_err(|err| { + DataFusionError::Internal(format!( + "Invalid partitioned aggregation partition count metadata: {err}" + )) + })?; + + assert_or_internal_err!( + num_partitions == expected_num_partitions, + "Partitioned aggregation partition count {} does not match repartition count {}", + num_partitions, + expected_num_partitions + ); + assert_or_internal_err!( + partition < expected_num_partitions, + "Partitioned aggregation partition {} is out of range for {} partitions", + partition, + expected_num_partitions + ); + + Ok(Some(partition)) +} + /// Output channel with its associated memory reservation and spill writer. /// /// `coalescer` is `None` for preserve-order mode, where downstream @@ -1636,6 +1681,7 @@ impl RepartitionExec { input_partition: usize, num_input_partitions: usize, ) -> Result<()> { + let num_output_partitions = partitioning.partition_count(); let mut partitioner = match &partitioning { Partitioning::Hash(exprs, num_partitions) => { BatchPartitioner::new_hash_partitioner( @@ -1683,6 +1729,23 @@ impl RepartitionExec { continue; } + if let Partitioning::Hash(_, _) = &partitioning + && let Some(partition) = + partitioned_aggregation_partition(&batch, num_output_partitions)? + { + let timer = metrics.send_time[partition].timer(); + if let Some(output_channel) = output_channels.get_mut(&partition) { + for batch in output_channel.coalesce(batch)? { + if output_channel.send(batch).await.is_err() { + output_channels.remove(&partition); + break; + } + } + } + timer.done(); + continue; + } + for res in partitioner.partition_iter(batch)? { let (partition, batch) = res?; From 7d3846cd57294df0ce9ba1deb62dbbead647c8fe Mon Sep 17 00:00:00 2001 From: kamille Date: Tue, 30 Jun 2026 18:17:11 +0800 Subject: [PATCH 14/17] Bypass coalesce for partitioned aggregation --- datafusion/physical-plan/src/repartition/mod.rs | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index e648e4f536d07..7022bc101b7c5 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -1734,13 +1734,10 @@ impl RepartitionExec { partitioned_aggregation_partition(&batch, num_output_partitions)? { let timer = metrics.send_time[partition].timer(); - if let Some(output_channel) = output_channels.get_mut(&partition) { - for batch in output_channel.coalesce(batch)? { - if output_channel.send(batch).await.is_err() { - output_channels.remove(&partition); - break; - } - } + if let Some(output_channel) = output_channels.get_mut(&partition) + && output_channel.send(batch).await.is_err() + { + output_channels.remove(&partition); } timer.done(); continue; From 8cfa8750e352e90ab3d0033e8afb0fc7d7e31b4e Mon Sep 17 00:00:00 2001 From: kamille Date: Tue, 30 Jun 2026 20:08:41 +0800 Subject: [PATCH 15/17] Move partial aggregation repartition into emit --- .../aggregates/aggregate_hash_table/common.rs | 5 + .../aggregate_hash_table/partial_table.rs | 154 +++---- .../group_values/multi_group_by/mod.rs | 22 +- .../src/aggregates/hash_aggregate.rs | 419 ++++++++++++------ .../physical-plan/src/aggregates/mod.rs | 10 +- 5 files changed, 365 insertions(+), 245 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs index 513b40538efba..a1f7fb9f0d102 100644 --- a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs +++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs @@ -135,6 +135,7 @@ impl AggregateHashTable { group_values, batch_group_indices: Default::default(), batch_hashes: Default::default(), + group_hashes: Default::default(), new_group_rows: Default::default(), accumulators, }), @@ -184,6 +185,7 @@ impl AggregateHashTable { acc + state.group_values.size() + state.batch_group_indices.allocated_size() + state.batch_hashes.allocated_size() + + state.group_hashes.allocated_size() + state.new_group_rows.allocated_size() } AggregateHashTableState::OutputtingMaterialized(output) => { @@ -304,6 +306,9 @@ pub(super) struct AggregateHashTableBuffer { /// Hash for each row in the current input batch. pub(super) batch_hashes: Vec, + /// Hash for each accumulated group. + pub(super) group_hashes: Vec, + /// Input rows that created new groups in the current input batch. pub(super) new_group_rows: Vec, diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs index e32a29e8c20d2..921fe429c7823 100644 --- a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs +++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs @@ -22,9 +22,7 @@ use std::sync::Arc; use arrow::array::{ArrayRef, BooleanArray, new_null_array}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; -use datafusion_common::{ - Result, assert_eq_or_internal_err, internal_datafusion_err, internal_err, -}; +use datafusion_common::{Result, assert_eq_or_internal_err, internal_err}; use datafusion_expr::EmitTo; use crate::aggregates::group_values::new_group_values; @@ -33,8 +31,8 @@ use crate::aggregates::{AggregateExec, group_id_array, max_duplicate_ordinal}; use super::common::{ AggregateHashTable, AggregateHashTableBuffer, AggregateHashTableState, - EvaluatedAccumulatorArgs, HashAggregateAccumulator, MaterializedOutput, - PartialMarker, PartialSkipMarker, + EvaluatedAccumulatorArgs, EvaluatedAggregateBatch, HashAggregateAccumulator, + MaterializedOutput, PartialMarker, PartialSkipMarker, }; /// Methods specific to the aggregate hash table used in the partial aggregation stage. @@ -86,21 +84,23 @@ impl AggregateHashTable { } } - pub(in crate::aggregates) fn materialize_output_batch( + pub(in crate::aggregates) fn materialize_output_batch_with_hashes( &mut self, - ) -> Result> { + ) -> Result)>> { match std::mem::replace(&mut self.state, AggregateHashTableState::Done) { AggregateHashTableState::Outputting(state) => { if state.group_values.is_empty() { return Ok(None); } - self.materialize_partial_batch(state).map(Some) + let hashes = state.group_hashes.clone(); + self.materialize_partial_batch(state) + .map(|batch| Some((batch, hashes))) } AggregateHashTableState::Done => Ok(None), AggregateHashTableState::Building(_) => { internal_err!( - "materialize_output_batch must be called in the outputting state" + "materialize_output_batch_with_hashes must be called in the outputting state" ) } AggregateHashTableState::OutputtingMaterialized(_) => { @@ -144,76 +144,19 @@ impl AggregateHashTable { .support_partial_repartition() } - pub(in crate::aggregates) fn append_new_group_partitions( - &self, - partition_group_indices: &mut [Vec], - ) -> Result<()> { - let num_partitions = partition_group_indices.len(); - if num_partitions == 0 { - return Ok(()); - } - - if num_partitions.is_power_of_two() { - let mask = num_partitions - 1; - self.append_new_groups_with_partition(partition_group_indices, |hash| { - (hash as usize) & mask - }) - } else { - self.append_new_groups_with_partition(partition_group_indices, |hash| { - (hash as usize) % num_partitions - }) - } - } - - fn append_new_groups_with_partition( - &self, - partition_group_indices: &mut [Vec], - compute_partition: F, - ) -> Result<()> - where - F: Fn(u64) -> usize, - { - let state = self.state.building(); - for &row in &state.new_group_rows { - let Some(&group_index) = state.batch_group_indices.get(row) else { - return internal_err!( - "new group row index {row} does not have a group index" - ); - }; - let Some(&hash) = state.batch_hashes.get(row) else { - return internal_err!( - "new group row index {row} does not have a hash value" - ); - }; - - let group_index = u32::try_from(group_index).map_err(|_| { - internal_datafusion_err!( - "partitioned aggregate output index exceeds u32::MAX" - ) - })?; - partition_group_indices[compute_partition(hash)].push(group_index); - } - - Ok(()) - } - - pub(in crate::aggregates) fn skip_hash_group_by(&mut self) -> Result<()> { - self.state - .building_mut() - .group_values - .skip_hash_group_by()?; - Ok(()) - } - /// In skip-partial-aggregation optimization, when a decision has been made to skip /// partial stage, build a typed hash table only for aggregation state conversion /// row-by-row. pub(in crate::aggregates) fn partial_skip_table( &self, + skip_hash_group_by: bool, ) -> Result> { let state = self.state.building(); let group_schema = state.group_by.group_schema(&self.input_schema)?; - let group_values = new_group_values(group_schema, &GroupOrdering::None)?; + let mut group_values = new_group_values(group_schema, &GroupOrdering::None)?; + if skip_hash_group_by && group_values.support_partial_repartition() { + group_values.skip_hash_group_by()?; + } let accumulators = state .accumulators .iter() @@ -230,6 +173,7 @@ impl AggregateHashTable { group_values, batch_group_indices: Default::default(), batch_hashes: Default::default(), + group_hashes: Default::default(), new_group_rows: Default::default(), accumulators, }), @@ -254,6 +198,11 @@ impl AggregateHashTable { )?; let group_indices = &state.batch_group_indices; let total_num_groups = state.group_values.len(); + state.group_hashes.resize(total_num_groups, 0); + for &row in &state.new_group_rows { + let group_index = group_indices[row]; + state.group_hashes[group_index] = state.batch_hashes[row]; + } for (acc, values) in state .accumulators @@ -325,6 +274,11 @@ impl AggregateHashTable { &mut state.batch_hashes, &mut state.new_group_rows, )?; + state.group_hashes.resize(state.group_values.len(), 0); + for &row in &state.new_group_rows { + let group_index = state.batch_group_indices[row]; + state.group_hashes[group_index] = state.batch_hashes[row]; + } any_interned = true; } @@ -351,17 +305,7 @@ impl AggregateHashTable { batch: &RecordBatch, ) -> Result { let evaluated_batch = self.evaluate_batch(batch)?; - - assert_eq_or_internal_err!( - evaluated_batch.grouping_set_args.len(), - 1, - "group_values expected to have single element" - ); - let mut output = evaluated_batch - .grouping_set_args - .into_iter() - .next() - .unwrap_or_default(); + let mut output = Self::output_group_values(&evaluated_batch)?; let state = self.state.building_mut(); for (acc, values) in state @@ -377,4 +321,50 @@ impl AggregateHashTable { output, )?) } + + pub(in crate::aggregates) fn convert_batch_to_state_with_hashes( + &mut self, + batch: &RecordBatch, + ) -> Result<(RecordBatch, Vec)> { + let evaluated_batch = self.evaluate_batch(batch)?; + let output_group_values = Self::output_group_values(&evaluated_batch)?; + + let state = self.state.building_mut(); + state.group_values.intern( + &output_group_values, + &mut state.batch_group_indices, + &mut state.batch_hashes, + &mut state.new_group_rows, + )?; + let hashes = state.batch_hashes.clone(); + + let mut output = output_group_values; + for (acc, values) in state + .accumulators + .iter_mut() + .zip(evaluated_batch.accumulator_args.iter()) + { + output.extend(acc.convert_to_state(values)?); + } + + Ok(( + RecordBatch::try_new(Arc::clone(&self.output_schema), output)?, + hashes, + )) + } + + fn output_group_values( + evaluated_batch: &EvaluatedAggregateBatch, + ) -> Result> { + assert_eq_or_internal_err!( + evaluated_batch.grouping_set_args.len(), + 1, + "group_values expected to have single element" + ); + Ok(evaluated_batch + .grouping_set_args + .first() + .cloned() + .unwrap_or_default()) + } } diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs index b10e0f359a26c..c3e4968acf6e8 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs @@ -305,36 +305,20 @@ impl GroupValuesColumn { Ok(v) } - fn append_all_group_values( - &mut self, + fn compute_hashes( + &self, cols: &[ArrayRef], groups: &mut Vec, hashes: &mut Vec, new_group_rows: &mut Vec, ) -> Result<()> { let num_rows = cols.first().map_or(0, |array| array.len()); - let first_group_idx = self.len(); groups.clear(); - groups.extend(first_group_idx..first_group_idx + num_rows); - hashes.clear(); hashes.resize(num_rows, 0); create_hashes(cols, &self.random_state, hashes)?; - new_group_rows.clear(); - new_group_rows.extend(0..num_rows); - - self.vectorized_operation_buffers.append_row_indices.clear(); - self.vectorized_operation_buffers - .append_row_indices - .extend(0..num_rows); - for (group_value, col) in self.group_values.iter_mut().zip(cols.iter()) { - group_value.vectorized_append( - col, - &self.vectorized_operation_buffers.append_row_indices, - )?; - } Ok(()) } @@ -1123,7 +1107,7 @@ impl GroupValues for GroupValuesColumn { new_group_rows: &mut Vec, ) -> Result<()> { if self.skip_hash_group_by { - return self.append_all_group_values(cols, groups, hashes, new_group_rows); + return self.compute_hashes(cols, groups, hashes, new_group_rows); } let n_rows = cols[0].len(); diff --git a/datafusion/physical-plan/src/aggregates/hash_aggregate.rs b/datafusion/physical-plan/src/aggregates/hash_aggregate.rs index adbdb585cd299..a729c0a169077 100644 --- a/datafusion/physical-plan/src/aggregates/hash_aggregate.rs +++ b/datafusion/physical-plan/src/aggregates/hash_aggregate.rs @@ -25,15 +25,16 @@ //! //! See issue for details: +use std::collections::VecDeque; use std::ops::ControlFlow; use std::sync::Arc; use std::task::{Context, Poll}; -use arrow::array::{PrimitiveArray, RecordBatch}; +use arrow::array::{PrimitiveArray, RecordBatch, UInt32Builder}; use arrow::compute::take_arrays; use arrow::datatypes::SchemaRef; use arrow::datatypes::UInt32Type; -use datafusion_common::Result; +use datafusion_common::{Result, internal_err}; use datafusion_execution::TaskContext; use datafusion_execution::memory_pool::proxy::VecAllocExt; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; @@ -44,6 +45,7 @@ use super::aggregate_hash_table::{ AggregateHashTable, FinalMarker, PartialMarker, PartialSkipMarker, }; use super::skip_partial::SkipAggregationProbe; +use crate::coalesce::LimitedBatchCoalescer; use crate::metrics::{ BaselineMetrics, MetricBuilder, MetricCategory, RecordOutput, SpillMetrics, }; @@ -132,8 +134,8 @@ pub(crate) struct PartialHashAggregateStream { /// Tracks whether partial aggregation should switch to direct state conversion. skip_aggregation_probe: Option, - /// Group indices for each target final partition. - repartition_state: Option, + /// Local repartition and coalesce state for partial output. + repartition_state: Option>, /// Optional soft limit on the number of groups to accumulate before output. /// @@ -143,102 +145,206 @@ pub(crate) struct PartialHashAggregateStream { /// Tracks the high-level stream lifecycle. The hash table owns the lower-level /// state for emitting output batches. - state: Option, + state: Option>, } struct PartialRepartitionState { - partition_group_indices: Vec>, - partition_batches: Vec, - offsets: Vec, - remaining_rows: usize, + coalescers: Vec, + pending: VecDeque, + partition_indices: Vec>, + source_batch: Option, + source_hashes: Vec, + source_offset: usize, batch_size: usize, num_partitions: usize, - next_partition: usize, + finished: bool, } impl PartialRepartitionState { - fn new(num_partitions: usize, batch_size: usize) -> Self { + fn new(schema: &SchemaRef, num_partitions: usize, batch_size: usize) -> Self { + let coalescers = (0..num_partitions) + .map(|_| LimitedBatchCoalescer::new(Arc::clone(schema), batch_size, None)) + .collect(); Self { - partition_group_indices: vec![vec![]; num_partitions], - partition_batches: Vec::with_capacity(num_partitions), - offsets: vec![0; num_partitions], - remaining_rows: 0, + coalescers, + pending: VecDeque::new(), + partition_indices: vec![vec![]; num_partitions], + source_batch: None, + source_hashes: vec![], + source_offset: 0, batch_size, num_partitions, - next_partition: 0, + finished: false, } } - fn start_output(&mut self, batch: &RecordBatch) -> Result<()> { - self.partition_batches.clear(); - self.remaining_rows = 0; + fn start_output(&mut self, batch: RecordBatch, hashes: Vec) -> Result<()> { + if batch.num_rows() != hashes.len() { + return internal_err!( + "partial aggregate output has {} rows, but {} hashes", + batch.num_rows(), + hashes.len() + ); + } + + self.source_batch = Some(batch); + self.source_hashes = hashes; + self.source_offset = 0; + Ok(()) + } + + fn push_batch(&mut self, batch: &RecordBatch, hashes: &[u64]) -> Result<()> { + if batch.num_rows() != hashes.len() { + return internal_err!( + "partial aggregate output has {} rows, but {} hashes", + batch.num_rows(), + hashes.len() + ); + } + + for indices in &mut self.partition_indices { + indices.clear(); + } + + if self.num_partitions.is_power_of_two() { + let mask = self.num_partitions - 1; + self.push_partition_indices(hashes, |hash| (hash as usize) & mask)?; + } else { + let num_partitions = self.num_partitions; + self.push_partition_indices(hashes, |hash| (hash as usize) % num_partitions)?; + } + + for partition in 0..self.num_partitions { + let indices = &self.partition_indices[partition]; + if indices.is_empty() { + continue; + } - for group_indices in &self.partition_group_indices { - let indices: PrimitiveArray = group_indices.clone().into(); + let mut indices_builder = UInt32Builder::with_capacity(indices.len()); + indices_builder.append_slice(indices); + let indices: PrimitiveArray = indices_builder.finish(); let columns = take_arrays(batch.columns(), &indices, None)?; let partition_batch = RecordBatch::try_new(Arc::clone(&batch.schema()), columns)?; - self.remaining_rows += partition_batch.num_rows(); - self.partition_batches.push(partition_batch); + self.coalescers[partition].push_batch(partition_batch)?; + while let Some(batch) = self.coalescers[partition].next_completed_batch() { + self.pending.push_back(add_partitioned_aggregation_metadata( + batch, + partition, + self.num_partitions, + )); + } } - self.offsets.clear(); - self.offsets.resize(self.num_partitions, 0); - self.next_partition = 0; Ok(()) } - fn next_batch(&mut self) -> Option { - if self.partition_batches.is_empty() { - return None; + fn push_partition_indices( + &mut self, + hashes: &[u64], + compute_partition: F, + ) -> Result<()> + where + F: Fn(u64) -> usize, + { + for (row, hash) in hashes.iter().enumerate() { + let row = u32::try_from(row).map_err(|_| { + datafusion_common::internal_datafusion_err!( + "partitioned aggregate row index exceeds u32::MAX" + ) + })?; + self.partition_indices[compute_partition(*hash)].push(row); } + Ok(()) + } - for _ in 0..self.num_partitions { - let partition = self.next_partition; - self.next_partition = (self.next_partition + 1) % self.num_partitions; + fn push_next_source_batch(&mut self) -> Result { + let Some(source_batch) = self.source_batch.as_ref() else { + return Ok(false); + }; + if self.source_offset >= source_batch.num_rows() { + self.source_batch = None; + self.source_hashes.clear(); + self.source_offset = 0; + return Ok(false); + } - let batch = &self.partition_batches[partition]; - let offset = self.offsets[partition]; - let partition_rows = batch.num_rows(); - if offset >= partition_rows { - continue; + let len = self + .batch_size + .min(source_batch.num_rows() - self.source_offset); + let batch = source_batch.slice(self.source_offset, len); + let hashes = + self.source_hashes[self.source_offset..self.source_offset + len].to_vec(); + self.source_offset += len; + self.push_batch(&batch, &hashes)?; + Ok(true) + } + + fn next_batch( + &mut self, + finish_when_source_done: bool, + ) -> Result> { + if let Some(batch) = self.pending.pop_front() { + return Ok(Some(batch)); + } + + while self.push_next_source_batch()? { + if let Some(batch) = self.pending.pop_front() { + return Ok(Some(batch)); } + } - let len = self.batch_size.min(partition_rows - offset); - let output = add_partitioned_aggregation_metadata( - batch.slice(offset, len), - partition, - self.num_partitions, - ); - self.offsets[partition] += len; - self.remaining_rows -= len; - if self.remaining_rows == 0 { - self.partition_batches.clear(); + if finish_when_source_done { + self.finish()?; + } + + Ok(self.pending.pop_front()) + } + + fn finish(&mut self) -> Result<()> { + if self.finished { + return Ok(()); + } + + for partition in 0..self.num_partitions { + self.coalescers[partition].finish()?; + while let Some(batch) = self.coalescers[partition].next_completed_batch() { + self.pending.push_back(add_partitioned_aggregation_metadata( + batch, + partition, + self.num_partitions, + )); } - return Some(output); } + self.finished = true; + Ok(()) + } - self.partition_batches.clear(); - None + fn has_source(&self) -> bool { + self.source_batch.is_some() } fn is_done(&self) -> bool { - self.partition_batches.is_empty() + self.pending.is_empty() && self.source_batch.is_none() && self.finished } fn memory_size(&self) -> usize { - self.partition_group_indices.allocated_size() + self.partition_indices.allocated_size() + self - .partition_group_indices + .partition_indices .iter() .map(VecAllocExt::allocated_size) .sum::() + self - .partition_batches + .source_batch + .as_ref() + .map_or(0, RecordBatch::get_array_memory_size) + + self + .pending .iter() .map(RecordBatch::get_array_memory_size) .sum::() - + self.offsets.allocated_size() + + self.source_hashes.allocated_size() } } @@ -270,7 +376,7 @@ enum PartialHashAggregateState { /// finish in `Done`. If `Some`, partial skip has triggered and the /// stream will move to `SkippingAggregation` after these accumulated /// groups are emitted. - skip_hash_table: Option>, + skip_hash_table: Option>>, }, SkippingAggregation { hash_table: AggregateHashTable, @@ -428,15 +534,15 @@ impl PartialHashAggregateStream { )?; let repartition_state = if can_repartition_in_partial(agg, context.as_ref(), &hash_table) { - Some(PartialRepartitionState::new( + Some(Box::new(PartialRepartitionState::new( + &schema, context.session_config().target_partitions(), batch_size, - )) + ))) } else { None }; - let can_skip_aggregation = - agg.group_by.is_single() && hash_table.can_skip_aggregation(); + let can_skip_aggregation = hash_table.can_skip_aggregation(); let skip_aggregation_probe = if can_skip_aggregation { let options = &context.session_config().options().execution; let probe_ratio_threshold = @@ -472,7 +578,9 @@ impl PartialHashAggregateStream { skip_aggregation_probe, repartition_state, group_values_soft_limit: agg.limit_options().map(|config| config.limit()), - state: Some(PartialHashAggregateState::ReadingInput { hash_table }), + state: Some(Box::new(PartialHashAggregateState::ReadingInput { + hash_table, + })), }) } @@ -500,24 +608,12 @@ impl PartialHashAggregateStream { .is_some_and(|probe| probe.should_skip()) } - fn append_new_groups_to_partitions( - &mut self, - hash_table: &AggregateHashTable, - ) -> Result<()> { - let Some(repartition_state) = self.repartition_state.as_mut() else { - return Ok(()); - }; - - hash_table - .append_new_group_partitions(&mut repartition_state.partition_group_indices) - } - fn memory_size(&self, hash_table: &AggregateHashTable) -> usize { hash_table.memory_size() + self .repartition_state .as_ref() - .map_or(0, PartialRepartitionState::memory_size) + .map_or(0, |state| state.memory_size()) } fn resize_reservation( @@ -530,19 +626,20 @@ impl PartialHashAggregateStream { fn next_output_batch( &mut self, hash_table: &mut AggregateHashTable, + finish_when_source_done: bool, ) -> Result> { let Some(repartition_state) = self.repartition_state.as_mut() else { return hash_table.next_output_batch(); }; - if repartition_state.is_done() { - let Some(batch) = hash_table.materialize_output_batch()? else { - return Ok(None); - }; - repartition_state.start_output(&batch)?; + if !repartition_state.has_source() + && let Some((batch, hashes)) = + hash_table.materialize_output_batch_with_hashes()? + { + repartition_state.start_output(batch, hashes)?; } - Ok(repartition_state.next_batch()) + repartition_state.next_batch(finish_when_source_done) } fn is_output_done(&self, hash_table: &AggregateHashTable) -> bool { @@ -550,7 +647,7 @@ impl PartialHashAggregateStream { && self .repartition_state .as_ref() - .is_none_or(PartialRepartitionState::is_done) + .is_none_or(|state| state.is_done()) } fn start_output( @@ -598,15 +695,6 @@ impl PartialHashAggregateStream { )); } - if let Err(e) = - self.append_new_groups_to_partitions(original_state.hash_table()) - { - return ControlFlow::Break(( - Poll::Ready(Some(Err(e))), - original_state, - )); - } - if self.hit_soft_group_limit(original_state.hash_table()) { let timer = elapsed_compute.timer(); let result = self.start_output(original_state.hash_table_mut(), true); @@ -640,30 +728,10 @@ impl PartialHashAggregateStream { // True branch: a decision has been made to skip partial aggregation. if self.should_skip_aggregation() { let timer = elapsed_compute.timer(); - if original_state.hash_table().can_repartition_in_partial() { - let result = original_state.hash_table_mut().skip_hash_group_by(); - timer.done(); - - if let Err(e) = result { - return ControlFlow::Break(( - Poll::Ready(Some(Err(e))), - original_state, - )); - } - - if let Err(e) = - self.resize_reservation(original_state.hash_table()) - { - return ControlFlow::Break(( - Poll::Ready(Some(Err(e))), - original_state, - )); - } - - return ControlFlow::Continue(original_state); - } - - let result = match original_state.hash_table().partial_skip_table() { + let result = match original_state + .hash_table() + .partial_skip_table(self.repartition_state.is_some()) + { Ok(skip_hash_table) => self .start_output(original_state.hash_table_mut(), false) .map(|()| skip_hash_table), @@ -685,7 +753,7 @@ impl PartialHashAggregateStream { return ControlFlow::Continue( PartialHashAggregateState::ProducingOutput { hash_table, - skip_hash_table: Some(skip_hash_table), + skip_hash_table: Some(Box::new(skip_hash_table)), }, ); } @@ -757,7 +825,15 @@ impl PartialHashAggregateStream { let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); let timer = elapsed_compute.timer(); - let result = self.next_output_batch(original_state.hash_table_mut()); + let finish_when_source_done = !matches!( + &original_state, + PartialHashAggregateState::ProducingOutput { + skip_hash_table: Some(_), + .. + } + ); + let result = self + .next_output_batch(original_state.hash_table_mut(), finish_when_source_done); timer.done(); match result { @@ -770,9 +846,9 @@ impl PartialHashAggregateStream { PartialHashAggregateState::ProducingOutput { skip_hash_table: Some(hash_table), .. - } => { - PartialHashAggregateState::SkippingAggregation { hash_table } - } + } => PartialHashAggregateState::SkippingAggregation { + hash_table: *hash_table, + }, PartialHashAggregateState::ProducingOutput { skip_hash_table: None, .. @@ -796,7 +872,9 @@ impl PartialHashAggregateStream { PartialHashAggregateState::ProducingOutput { skip_hash_table: Some(hash_table), .. - } => PartialHashAggregateState::SkippingAggregation { hash_table }, + } => PartialHashAggregateState::SkippingAggregation { + hash_table: *hash_table, + }, PartialHashAggregateState::ProducingOutput { skip_hash_table: None, .. @@ -809,6 +887,26 @@ impl PartialHashAggregateStream { } } + fn push_skip_batch( + &mut self, + batch: RecordBatch, + hashes: &[u64], + ) -> Result> { + let Some(repartition_state) = self.repartition_state.as_mut() else { + return Ok(Some(batch)); + }; + repartition_state.push_batch(&batch, hashes)?; + repartition_state.next_batch(false) + } + + fn finish_repartition(&mut self) -> Result> { + let Some(repartition_state) = self.repartition_state.as_mut() else { + return Ok(None); + }; + repartition_state.finish()?; + repartition_state.next_batch(true) + } + /// Handle SkippingAggregation state - convert raw input directly to partial states. /// /// See comments at `poll_next()` for details. @@ -833,23 +931,55 @@ impl PartialHashAggregateStream { let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); let timer = elapsed_compute.timer(); - let result = match &mut original_state { - PartialHashAggregateState::SkippingAggregation { hash_table } => { - hash_table.convert_batch_to_state(&batch) + if self.repartition_state.is_some() { + let result = match &mut original_state { + PartialHashAggregateState::SkippingAggregation { hash_table } => { + hash_table.convert_batch_to_state_with_hashes(&batch) + } + _ => unreachable!("expected skipping aggregation state"), + }; + timer.done(); + + match result { + Ok((batch, hashes)) => match self.push_skip_batch(batch, &hashes) + { + Ok(Some(batch)) => ControlFlow::Break(( + Poll::Ready(Some(Ok( + batch.record_output(&self.baseline_metrics) + ))), + original_state, + )), + Ok(None) => ControlFlow::Continue(original_state), + Err(e) => ControlFlow::Break(( + Poll::Ready(Some(Err(e))), + original_state, + )), + }, + Err(e) => ControlFlow::Break(( + Poll::Ready(Some(Err(e))), + original_state, + )), } - _ => unreachable!("expected skipping aggregation state"), - }; - timer.done(); + } else { + let result = match &mut original_state { + PartialHashAggregateState::SkippingAggregation { hash_table } => { + hash_table.convert_batch_to_state(&batch) + } + _ => unreachable!("expected skipping aggregation state"), + }; + timer.done(); - match result { - Ok(batch) => ControlFlow::Break(( - Poll::Ready(Some( - Ok(batch.record_output(&self.baseline_metrics)), + match result { + Ok(batch) => ControlFlow::Break(( + Poll::Ready(Some(Ok( + batch.record_output(&self.baseline_metrics) + ))), + original_state, + )), + Err(e) => ControlFlow::Break(( + Poll::Ready(Some(Err(e))), + original_state, )), - original_state, - )), - Err(e) => { - ControlFlow::Break((Poll::Ready(Some(Err(e))), original_state)) } } } @@ -859,7 +989,18 @@ impl PartialHashAggregateStream { Poll::Ready(None) => { let input_schema = self.input.schema(); self.input = Box::pin(EmptyRecordBatchStream::new(input_schema)); - ControlFlow::Continue(PartialHashAggregateState::Done) + match self.finish_repartition() { + Ok(Some(batch)) => ControlFlow::Break(( + Poll::Ready(Some( + Ok(batch.record_output(&self.baseline_metrics)), + )), + original_state, + )), + Ok(None) => ControlFlow::Continue(PartialHashAggregateState::Done), + Err(e) => { + ControlFlow::Break((Poll::Ready(Some(Err(e))), original_state)) + } + } } } } @@ -923,7 +1064,7 @@ impl Stream for PartialHashAggregateStream { cx: &mut Context<'_>, ) -> Poll> { loop { - let cur_state = self + let cur_state = *self .state .take() .expect("PartialHashAggregateStream state should not be None"); @@ -940,18 +1081,18 @@ impl Stream for PartialHashAggregateStream { } state @ PartialHashAggregateState::Done => { let _ = self.reservation.try_resize(0); - self.state = Some(state); + self.state = Some(Box::new(state)); return Poll::Ready(None); } }; match next_state { ControlFlow::Continue(next_state) => { - self.state = Some(next_state); + self.state = Some(Box::new(next_state)); continue; } ControlFlow::Break((poll, next_state)) => { - self.state = Some(next_state); + self.state = Some(Box::new(next_state)); return poll; } } diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index f927a50badfb5..2297b63f6f7c3 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -535,7 +535,7 @@ enum StreamType { /// [`StreamType::PartialHash`] and [`StreamType::FinalHash`] /// /// See issue for details: - GroupedHash(GroupedHashAggregateStream), + GroupedHash(Box), /// Grouped TopK aggregate stream. /// Input output scheme: initial input -> final result /// @@ -550,7 +550,7 @@ impl From for SendableRecordBatchStream { StreamType::AggregateStream(stream) => Box::pin(stream), StreamType::PartialHash(stream) => Box::pin(stream), StreamType::FinalHash(stream) => Box::pin(stream), - StreamType::GroupedHash(stream) => Box::pin(stream), + StreamType::GroupedHash(stream) => Box::pin(*stream), StreamType::GroupedPriorityQueue(stream) => Box::pin(stream), } } @@ -1035,9 +1035,9 @@ impl AggregateExec { } // Execution paths that have not been migrated use the fallback implementation - Ok(StreamType::GroupedHash(GroupedHashAggregateStream::new( - self, context, partition, - )?)) + Ok(StreamType::GroupedHash(Box::new( + GroupedHashAggregateStream::new(self, context, partition)?, + ))) } fn should_use_partial_hash_stream(&self, context: &TaskContext) -> bool { From d82f05bdbf6939cee9bde932eb427ff1d08f4899 Mon Sep 17 00:00:00 2001 From: kamille Date: Tue, 30 Jun 2026 20:26:39 +0800 Subject: [PATCH 16/17] Optimize local partial repartition --- .../aggregate_hash_table/partial_table.rs | 23 +++-- .../src/aggregates/hash_aggregate.rs | 89 +++++++++++-------- 2 files changed, 60 insertions(+), 52 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs index 921fe429c7823..1156bec41f651 100644 --- a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs +++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs @@ -22,6 +22,7 @@ use std::sync::Arc; use arrow::array::{ArrayRef, BooleanArray, new_null_array}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; +use datafusion_common::hash_utils::create_hashes; use datafusion_common::{Result, assert_eq_or_internal_err, internal_err}; use datafusion_expr::EmitTo; @@ -149,14 +150,10 @@ impl AggregateHashTable { /// row-by-row. pub(in crate::aggregates) fn partial_skip_table( &self, - skip_hash_group_by: bool, ) -> Result> { let state = self.state.building(); let group_schema = state.group_by.group_schema(&self.input_schema)?; - let mut group_values = new_group_values(group_schema, &GroupOrdering::None)?; - if skip_hash_group_by && group_values.support_partial_repartition() { - group_values.skip_hash_group_by()?; - } + let group_values = new_group_values(group_schema, &GroupOrdering::None)?; let accumulators = state .accumulators .iter() @@ -322,21 +319,21 @@ impl AggregateHashTable { )?) } - pub(in crate::aggregates) fn convert_batch_to_state_with_hashes( - &mut self, + pub(in crate::aggregates) fn convert_batch_to_state_with_hashes<'a>( + &'a mut self, batch: &RecordBatch, - ) -> Result<(RecordBatch, Vec)> { + ) -> Result<(RecordBatch, &'a [u64])> { let evaluated_batch = self.evaluate_batch(batch)?; let output_group_values = Self::output_group_values(&evaluated_batch)?; let state = self.state.building_mut(); - state.group_values.intern( + state.batch_hashes.clear(); + state.batch_hashes.resize(batch.num_rows(), 0); + create_hashes( &output_group_values, - &mut state.batch_group_indices, + &crate::aggregates::AGGREGATION_HASH_SEED, &mut state.batch_hashes, - &mut state.new_group_rows, )?; - let hashes = state.batch_hashes.clone(); let mut output = output_group_values; for (acc, values) in state @@ -349,7 +346,7 @@ impl AggregateHashTable { Ok(( RecordBatch::try_new(Arc::clone(&self.output_schema), output)?, - hashes, + &state.batch_hashes, )) } diff --git a/datafusion/physical-plan/src/aggregates/hash_aggregate.rs b/datafusion/physical-plan/src/aggregates/hash_aggregate.rs index a729c0a169077..9b61eb7d3b540 100644 --- a/datafusion/physical-plan/src/aggregates/hash_aggregate.rs +++ b/datafusion/physical-plan/src/aggregates/hash_aggregate.rs @@ -30,10 +30,9 @@ use std::ops::ControlFlow; use std::sync::Arc; use std::task::{Context, Poll}; -use arrow::array::{PrimitiveArray, RecordBatch, UInt32Builder}; +use arrow::array::{RecordBatch, UInt32Builder}; use arrow::compute::take_arrays; use arrow::datatypes::SchemaRef; -use arrow::datatypes::UInt32Type; use datafusion_common::{Result, internal_err}; use datafusion_execution::TaskContext; use datafusion_execution::memory_pool::proxy::VecAllocExt; @@ -202,30 +201,43 @@ impl PartialRepartitionState { ); } + if hashes.is_empty() { + return Ok(()); + } + for indices in &mut self.partition_indices { indices.clear(); } if self.num_partitions.is_power_of_two() { let mask = self.num_partitions - 1; - self.push_partition_indices(hashes, |hash| (hash as usize) & mask)?; + Self::push_partition_indices(hashes, &mut self.partition_indices, |hash| { + (hash as usize) & mask + })?; } else { let num_partitions = self.num_partitions; - self.push_partition_indices(hashes, |hash| (hash as usize) % num_partitions)?; + Self::push_partition_indices(hashes, &mut self.partition_indices, |hash| { + (hash as usize) % num_partitions + })?; + } + + let mut indices_builder = UInt32Builder::with_capacity(hashes.len()); + for indices in &self.partition_indices { + indices_builder.append_slice(indices); } + let indices = indices_builder.finish(); + let columns = take_arrays(batch.columns(), &indices, None)?; + let reordered_batch = RecordBatch::try_new(Arc::clone(&batch.schema()), columns)?; + let mut offset = 0; for partition in 0..self.num_partitions { - let indices = &self.partition_indices[partition]; - if indices.is_empty() { + let len = self.partition_indices[partition].len(); + if len == 0 { continue; } - let mut indices_builder = UInt32Builder::with_capacity(indices.len()); - indices_builder.append_slice(indices); - let indices: PrimitiveArray = indices_builder.finish(); - let columns = take_arrays(batch.columns(), &indices, None)?; - let partition_batch = - RecordBatch::try_new(Arc::clone(&batch.schema()), columns)?; + let partition_batch = reordered_batch.slice(offset, len); + offset += len; self.coalescers[partition].push_batch(partition_batch)?; while let Some(batch) = self.coalescers[partition].next_completed_batch() { self.pending.push_back(add_partitioned_aggregation_metadata( @@ -240,8 +252,8 @@ impl PartialRepartitionState { } fn push_partition_indices( - &mut self, hashes: &[u64], + partition_indices: &mut [Vec], compute_partition: F, ) -> Result<()> where @@ -253,7 +265,7 @@ impl PartialRepartitionState { "partitioned aggregate row index exceeds u32::MAX" ) })?; - self.partition_indices[compute_partition(*hash)].push(row); + partition_indices[compute_partition(*hash)].push(row); } Ok(()) } @@ -269,14 +281,15 @@ impl PartialRepartitionState { return Ok(false); } - let len = self - .batch_size - .min(source_batch.num_rows() - self.source_offset); - let batch = source_batch.slice(self.source_offset, len); - let hashes = - self.source_hashes[self.source_offset..self.source_offset + len].to_vec(); + let offset = self.source_offset; + let len = self.batch_size.min(source_batch.num_rows() - offset); + let batch = source_batch.slice(offset, len); self.source_offset += len; - self.push_batch(&batch, &hashes)?; + + let source_hashes = std::mem::take(&mut self.source_hashes); + let result = self.push_batch(&batch, &source_hashes[offset..offset + len]); + self.source_hashes = source_hashes; + result?; Ok(true) } @@ -728,10 +741,7 @@ impl PartialHashAggregateStream { // True branch: a decision has been made to skip partial aggregation. if self.should_skip_aggregation() { let timer = elapsed_compute.timer(); - let result = match original_state - .hash_table() - .partial_skip_table(self.repartition_state.is_some()) - { + let result = match original_state.hash_table().partial_skip_table() { Ok(skip_hash_table) => self .start_output(original_state.hash_table_mut(), false) .map(|()| skip_hash_table), @@ -941,20 +951,21 @@ impl PartialHashAggregateStream { timer.done(); match result { - Ok((batch, hashes)) => match self.push_skip_batch(batch, &hashes) - { - Ok(Some(batch)) => ControlFlow::Break(( - Poll::Ready(Some(Ok( - batch.record_output(&self.baseline_metrics) - ))), - original_state, - )), - Ok(None) => ControlFlow::Continue(original_state), - Err(e) => ControlFlow::Break(( - Poll::Ready(Some(Err(e))), - original_state, - )), - }, + Ok((batch, hashes)) => { + match self.push_skip_batch(batch, hashes) { + Ok(Some(batch)) => ControlFlow::Break(( + Poll::Ready(Some(Ok( + batch.record_output(&self.baseline_metrics) + ))), + original_state, + )), + Ok(None) => ControlFlow::Continue(original_state), + Err(e) => ControlFlow::Break(( + Poll::Ready(Some(Err(e))), + original_state, + )), + } + } Err(e) => ControlFlow::Break(( Poll::Ready(Some(Err(e))), original_state, From bc70b9de3d9ff5695b50fc6fff3b02e883b6fe0b Mon Sep 17 00:00:00 2001 From: kamille Date: Tue, 30 Jun 2026 20:53:02 +0800 Subject: [PATCH 17/17] Simplify partial local repartition --- datafusion/common/src/config.rs | 4 + .../aggregate_hash_table/partial_table.rs | 16 +- .../src/aggregates/group_values/mod.rs | 12 +- .../group_values/multi_group_by/mod.rs | 42 +--- .../src/aggregates/hash_aggregate.rs | 209 +++++++++--------- 5 files changed, 113 insertions(+), 170 deletions(-) diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index cc263dfe3e619..ffa7ea6392af6 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -912,6 +912,10 @@ config_namespace! { /// aggregation ratio check and trying to switch to skipping aggregation mode pub skip_partial_aggregation_probe_rows_threshold: usize, default = 100_000 + /// Should partial hash aggregation repartition and coalesce output locally + /// before sending it to the upstream repartition operator. + pub enable_partial_aggregation_local_repartition: bool, default = true + /// Should DataFusion use row number estimates at the input to decide /// whether increasing parallelism is beneficial or not. By default, /// only exact row numbers (not estimates) are used for this decision. diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs index 1156bec41f651..805c5c709de98 100644 --- a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs +++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs @@ -131,18 +131,10 @@ impl AggregateHashTable { pub(in crate::aggregates) fn can_skip_aggregation(&self) -> bool { let state = self.state.building(); - state.group_values.support_partial_repartition() - || state - .accumulators - .iter() - .all(|acc| acc.supports_convert_to_state()) - } - - pub(in crate::aggregates) fn can_repartition_in_partial(&self) -> bool { - self.state - .building() - .group_values - .support_partial_repartition() + state + .accumulators + .iter() + .all(|acc| acc.supports_convert_to_state()) } /// In skip-partial-aggregation optimization, when a decision has been made to skip diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index 73c2c1321cf65..a49ad87676505 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -24,7 +24,7 @@ use arrow::array::types::{ }; use arrow::array::{ArrayRef, downcast_primitive}; use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; -use datafusion_common::{Result, internal_err}; +use datafusion_common::Result; use datafusion_expr::EmitTo; @@ -119,16 +119,6 @@ pub trait GroupValues: Send { /// Emits the group values fn emit(&mut self, emit_to: EmitTo) -> Result>; - /// Returns true if this group value storage supports partial repartition. - fn support_partial_repartition(&self) -> bool { - false - } - - /// Enable append-only grouping for partial repartition. - fn skip_hash_group_by(&mut self) -> Result<()> { - internal_err!("GroupValues does not support skip hash group by") - } - /// Clear the contents and shrink the capacity to the size of the batch (free up memory usage) fn clear_shrink(&mut self, num_rows: usize); } diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs index c3e4968acf6e8..abfe70b53a7c6 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs @@ -41,7 +41,7 @@ use arrow::datatypes::{ }; use datafusion_common::hash_utils::RandomState; use datafusion_common::hash_utils::create_hashes; -use datafusion_common::{Result, internal_datafusion_err, internal_err, not_impl_err}; +use datafusion_common::{Result, internal_datafusion_err, not_impl_err}; use datafusion_execution::memory_pool::proxy::{HashTableAllocExt, VecAllocExt}; use datafusion_expr::EmitTo; use datafusion_physical_expr::binary_map::OutputType; @@ -220,9 +220,6 @@ pub struct GroupValuesColumn { /// Random state for creating hashes random_state: RandomState, - - /// Whether each input row should be appended as a new group directly. - skip_hash_group_by: bool, } /// Buffers to store intermediate results in `vectorized_append` @@ -286,7 +283,6 @@ impl GroupValuesColumn { group_values, hashes_buffer: Default::default(), random_state: crate::aggregates::AGGREGATION_HASH_SEED, - skip_hash_group_by: false, }) } @@ -305,24 +301,6 @@ impl GroupValuesColumn { Ok(v) } - fn compute_hashes( - &self, - cols: &[ArrayRef], - groups: &mut Vec, - hashes: &mut Vec, - new_group_rows: &mut Vec, - ) -> Result<()> { - let num_rows = cols.first().map_or(0, |array| array.len()); - - groups.clear(); - hashes.clear(); - hashes.resize(num_rows, 0); - create_hashes(cols, &self.random_state, hashes)?; - new_group_rows.clear(); - - Ok(()) - } - // ======================================================================== // Scalarized intern // ======================================================================== @@ -1106,10 +1084,6 @@ impl GroupValues for GroupValuesColumn { hashes: &mut Vec, new_group_rows: &mut Vec, ) -> Result<()> { - if self.skip_hash_group_by { - return self.compute_hashes(cols, groups, hashes, new_group_rows); - } - let n_rows = cols[0].len(); hashes.clear(); hashes.resize(n_rows, 0); @@ -1139,10 +1113,6 @@ impl GroupValues for GroupValuesColumn { } fn emit(&mut self, emit_to: EmitTo) -> Result> { - if self.skip_hash_group_by && matches!(emit_to, EmitTo::First(_)) { - return internal_err!("skip hash group by does not support EmitTo::First"); - } - let mut output = match emit_to { EmitTo::All => { // Replace the column builders with a fresh set so the @@ -1270,16 +1240,6 @@ impl GroupValues for GroupValuesColumn { self.vectorized_operation_buffers.clear(); } } - - fn support_partial_repartition(&self) -> bool { - true - } - - fn skip_hash_group_by(&mut self) -> Result<()> { - self.map.clear(); - self.skip_hash_group_by = true; - Ok(()) - } } /// Returns true if [`GroupValuesColumn`] supported for the specified schema diff --git a/datafusion/physical-plan/src/aggregates/hash_aggregate.rs b/datafusion/physical-plan/src/aggregates/hash_aggregate.rs index 9b61eb7d3b540..5c0c2d0f8e7d3 100644 --- a/datafusion/physical-plan/src/aggregates/hash_aggregate.rs +++ b/datafusion/physical-plan/src/aggregates/hash_aggregate.rs @@ -136,6 +136,18 @@ pub(crate) struct PartialHashAggregateStream { /// Local repartition and coalesce state for partial output. repartition_state: Option>, + /// Materialized partial output being sliced into local repartition. + repartition_source_batch: Option, + + /// Hashes for rows in [`Self::repartition_source_batch`]. + repartition_source_hashes: Vec, + + /// Next row offset to push from [`Self::repartition_source_batch`]. + repartition_source_offset: usize, + + /// Target output batch size. + batch_size: usize, + /// Optional soft limit on the number of groups to accumulate before output. /// /// Invariant: when this is `Some(..)`, the accumulators inside `hash_table` must @@ -151,10 +163,6 @@ struct PartialRepartitionState { coalescers: Vec, pending: VecDeque, partition_indices: Vec>, - source_batch: Option, - source_hashes: Vec, - source_offset: usize, - batch_size: usize, num_partitions: usize, finished: bool, } @@ -168,30 +176,11 @@ impl PartialRepartitionState { coalescers, pending: VecDeque::new(), partition_indices: vec![vec![]; num_partitions], - source_batch: None, - source_hashes: vec![], - source_offset: 0, - batch_size, num_partitions, finished: false, } } - fn start_output(&mut self, batch: RecordBatch, hashes: Vec) -> Result<()> { - if batch.num_rows() != hashes.len() { - return internal_err!( - "partial aggregate output has {} rows, but {} hashes", - batch.num_rows(), - hashes.len() - ); - } - - self.source_batch = Some(batch); - self.source_hashes = hashes; - self.source_offset = 0; - Ok(()) - } - fn push_batch(&mut self, batch: &RecordBatch, hashes: &[u64]) -> Result<()> { if batch.num_rows() != hashes.len() { return internal_err!( @@ -270,50 +259,6 @@ impl PartialRepartitionState { Ok(()) } - fn push_next_source_batch(&mut self) -> Result { - let Some(source_batch) = self.source_batch.as_ref() else { - return Ok(false); - }; - if self.source_offset >= source_batch.num_rows() { - self.source_batch = None; - self.source_hashes.clear(); - self.source_offset = 0; - return Ok(false); - } - - let offset = self.source_offset; - let len = self.batch_size.min(source_batch.num_rows() - offset); - let batch = source_batch.slice(offset, len); - self.source_offset += len; - - let source_hashes = std::mem::take(&mut self.source_hashes); - let result = self.push_batch(&batch, &source_hashes[offset..offset + len]); - self.source_hashes = source_hashes; - result?; - Ok(true) - } - - fn next_batch( - &mut self, - finish_when_source_done: bool, - ) -> Result> { - if let Some(batch) = self.pending.pop_front() { - return Ok(Some(batch)); - } - - while self.push_next_source_batch()? { - if let Some(batch) = self.pending.pop_front() { - return Ok(Some(batch)); - } - } - - if finish_when_source_done { - self.finish()?; - } - - Ok(self.pending.pop_front()) - } - fn finish(&mut self) -> Result<()> { if self.finished { return Ok(()); @@ -333,12 +278,8 @@ impl PartialRepartitionState { Ok(()) } - fn has_source(&self) -> bool { - self.source_batch.is_some() - } - fn is_done(&self) -> bool { - self.pending.is_empty() && self.source_batch.is_none() && self.finished + self.pending.is_empty() && self.finished } fn memory_size(&self) -> usize { @@ -348,16 +289,11 @@ impl PartialRepartitionState { .iter() .map(VecAllocExt::allocated_size) .sum::() - + self - .source_batch - .as_ref() - .map_or(0, RecordBatch::get_array_memory_size) + self .pending .iter() .map(RecordBatch::get_array_memory_size) .sum::() - + self.source_hashes.allocated_size() } } @@ -425,17 +361,6 @@ impl PartialHashAggregateState { } } -fn can_repartition_in_partial( - agg: &AggregateExec, - context: &TaskContext, - hash_table: &AggregateHashTable, -) -> bool { - !agg.group_by.is_empty() - && context.session_config().repartition_aggregations() - && context.session_config().target_partitions() > 1 - && hash_table.can_repartition_in_partial() -} - /// Hash aggregation is implemented in two stages: partial and final. This /// stream implements the final stage. /// @@ -545,19 +470,22 @@ impl PartialHashAggregateStream { Arc::clone(&schema), batch_size, )?; - let repartition_state = - if can_repartition_in_partial(agg, context.as_ref(), &hash_table) { - Some(Box::new(PartialRepartitionState::new( - &schema, - context.session_config().target_partitions(), - batch_size, - ))) - } else { - None - }; + let options = &context.session_config().options().execution; + let repartition_state = if options.enable_partial_aggregation_local_repartition + && !agg.group_by.is_empty() + && context.session_config().repartition_aggregations() + && context.session_config().target_partitions() > 1 + { + Some(Box::new(PartialRepartitionState::new( + &schema, + context.session_config().target_partitions(), + batch_size, + ))) + } else { + None + }; let can_skip_aggregation = hash_table.can_skip_aggregation(); let skip_aggregation_probe = if can_skip_aggregation { - let options = &context.session_config().options().execution; let probe_ratio_threshold = options.skip_partial_aggregation_probe_ratio_threshold; // A threshold >= 1.0 means the ratio (num_groups / input_rows) can @@ -590,6 +518,10 @@ impl PartialHashAggregateStream { reduction_factor, skip_aggregation_probe, repartition_state, + repartition_source_batch: None, + repartition_source_hashes: vec![], + repartition_source_offset: 0, + batch_size, group_values_soft_limit: agg.limit_options().map(|config| config.limit()), state: Some(Box::new(PartialHashAggregateState::ReadingInput { hash_table, @@ -627,6 +559,11 @@ impl PartialHashAggregateStream { .repartition_state .as_ref() .map_or(0, |state| state.memory_size()) + + self + .repartition_source_batch + .as_ref() + .map_or(0, RecordBatch::get_array_memory_size) + + self.repartition_source_hashes.allocated_size() } fn resize_reservation( @@ -641,22 +578,82 @@ impl PartialHashAggregateStream { hash_table: &mut AggregateHashTable, finish_when_source_done: bool, ) -> Result> { - let Some(repartition_state) = self.repartition_state.as_mut() else { + if self.repartition_state.is_none() { return hash_table.next_output_batch(); - }; + } - if !repartition_state.has_source() + if self.repartition_source_batch.is_none() && let Some((batch, hashes)) = hash_table.materialize_output_batch_with_hashes()? { - repartition_state.start_output(batch, hashes)?; + self.repartition_source_batch = Some(batch); + self.repartition_source_hashes = hashes; + self.repartition_source_offset = 0; + } + + self.next_repartition_batch(finish_when_source_done) + } + + fn next_repartition_batch( + &mut self, + finish_when_source_done: bool, + ) -> Result> { + if let Some(batch) = self.next_repartition_pending_batch() { + return Ok(Some(batch)); + } + + while self.push_next_repartition_source_batch()? { + if let Some(batch) = self.next_repartition_pending_batch() { + return Ok(Some(batch)); + } + } + + if finish_when_source_done + && let Some(repartition_state) = self.repartition_state.as_mut() + { + repartition_state.finish()?; + } + + Ok(self.next_repartition_pending_batch()) + } + + fn push_next_repartition_source_batch(&mut self) -> Result { + let Some(source_batch) = self.repartition_source_batch.as_ref() else { + return Ok(false); + }; + if self.repartition_source_offset >= source_batch.num_rows() { + self.repartition_source_batch = None; + self.repartition_source_hashes.clear(); + self.repartition_source_offset = 0; + return Ok(false); } - repartition_state.next_batch(finish_when_source_done) + let offset = self.repartition_source_offset; + let len = self.batch_size.min(source_batch.num_rows() - offset); + let batch = source_batch.slice(offset, len); + self.repartition_source_offset += len; + + let source_hashes = std::mem::take(&mut self.repartition_source_hashes); + let result = match self.repartition_state.as_mut() { + Some(repartition_state) => { + repartition_state.push_batch(&batch, &source_hashes[offset..offset + len]) + } + None => Ok(()), + }; + self.repartition_source_hashes = source_hashes; + result?; + Ok(true) + } + + fn next_repartition_pending_batch(&mut self) -> Option { + self.repartition_state + .as_mut() + .and_then(|state| state.pending.pop_front()) } fn is_output_done(&self, hash_table: &AggregateHashTable) -> bool { hash_table.is_done() + && self.repartition_source_batch.is_none() && self .repartition_state .as_ref() @@ -906,7 +903,7 @@ impl PartialHashAggregateStream { return Ok(Some(batch)); }; repartition_state.push_batch(&batch, hashes)?; - repartition_state.next_batch(false) + Ok(self.next_repartition_pending_batch()) } fn finish_repartition(&mut self) -> Result> { @@ -914,7 +911,7 @@ impl PartialHashAggregateStream { return Ok(None); }; repartition_state.finish()?; - repartition_state.next_batch(true) + Ok(self.next_repartition_pending_batch()) } /// Handle SkippingAggregation state - convert raw input directly to partial states.