diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index cc263dfe3e619..ffa7ea6392af6 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -912,6 +912,10 @@ config_namespace! { /// aggregation ratio check and trying to switch to skipping aggregation mode pub skip_partial_aggregation_probe_rows_threshold: usize, default = 100_000 + /// Should partial hash aggregation repartition and coalesce output locally + /// before sending it to the upstream repartition operator. + pub enable_partial_aggregation_local_repartition: bool, default = true + /// Should DataFusion use row number estimates at the input to decide /// whether increasing parallelism is beneficial or not. By default, /// only exact row numbers (not estimates) are used for this decision. diff --git a/datafusion/physical-plan/benches/dictionary_group_values.rs b/datafusion/physical-plan/benches/dictionary_group_values.rs index ded52aebd1100..f853105e2aa71 100644 --- a/datafusion/physical-plan/benches/dictionary_group_values.rs +++ b/datafusion/physical-plan/benches/dictionary_group_values.rs @@ -109,10 +109,17 @@ fn bench_intern_emit(c: &mut Criterion) { new_group_values(schema.clone(), &GroupOrdering::None) .unwrap(), Vec::::with_capacity(size), + Vec::::with_capacity(size), ) }, - |(gv, groups)| { - gv.intern(std::slice::from_ref(&array), groups).unwrap(); + |(gv, groups, hashes)| { + gv.intern( + std::slice::from_ref(&array), + groups, + hashes, + &mut vec![], + ) + .unwrap(); black_box(&*groups); black_box(gv.emit(EmitTo::All).unwrap()); }, @@ -154,11 +161,18 @@ fn bench_repeated_intern_emit(c: &mut Criterion) { new_group_values(schema.clone(), &GroupOrdering::None) .unwrap(), Vec::::with_capacity(size), + Vec::::with_capacity(size), ) }, - |(gv, groups)| { + |(gv, groups, hashes)| { for arr in &batches { - gv.intern(std::slice::from_ref(arr), groups).unwrap(); + gv.intern( + std::slice::from_ref(arr), + groups, + hashes, + &mut vec![], + ) + .unwrap(); black_box(&*groups); } black_box(gv.emit(EmitTo::All).unwrap()); diff --git a/datafusion/physical-plan/benches/multi_group_by.rs b/datafusion/physical-plan/benches/multi_group_by.rs index 92d0448775599..71769d68b83ab 100644 --- a/datafusion/physical-plan/benches/multi_group_by.rs +++ b/datafusion/physical-plan/benches/multi_group_by.rs @@ -95,9 +95,10 @@ fn bench_intern( batches: &[Vec], groups: &mut Vec, ) { + let mut hashes = vec![]; for batch in batches { groups.clear(); - gv.intern(batch, groups).unwrap(); + gv.intern(batch, groups, &mut hashes, &mut vec![]).unwrap(); } black_box(&*groups); } diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs index 719fbe93e5416..a1f7fb9f0d102 100644 --- a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs +++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs @@ -134,6 +134,9 @@ impl AggregateHashTable { group_by: Arc::clone(&agg.group_by), group_values, batch_group_indices: Default::default(), + batch_hashes: Default::default(), + group_hashes: Default::default(), + new_group_rows: Default::default(), accumulators, }), _mode: PhantomData, @@ -181,8 +184,11 @@ impl AggregateHashTable { acc + state.group_values.size() + state.batch_group_indices.allocated_size() + + state.batch_hashes.allocated_size() + + state.group_hashes.allocated_size() + + state.new_group_rows.allocated_size() } - AggregateHashTableState::OutputtingMaterializedFinal(output) => { + AggregateHashTableState::OutputtingMaterialized(output) => { output.memory_size() } AggregateHashTableState::Done => 0, @@ -212,14 +218,19 @@ impl AggregateHashTable { state.batch_group_indices = Vec::new(); self.state = AggregateHashTableState::Outputting(state); } -} -pub(super) fn emit_to_for_batch_size(batch_size: usize, group_count: usize) -> EmitTo { - debug_assert!(batch_size > 0); - if group_count <= batch_size { - EmitTo::All - } else { - EmitTo::First(batch_size) + pub(super) fn emit_next_materialized_batch( + &mut self, + mut output: MaterializedOutput, + batch_size: usize, + ) -> Option { + let batch = output.next_batch(batch_size); + if output.is_exhausted() { + self.state = AggregateHashTableState::Done; + } else { + self.state = AggregateHashTableState::OutputtingMaterialized(output); + } + batch } } @@ -292,6 +303,15 @@ pub(super) struct AggregateHashTableBuffer { /// accumulator to update that group's aggregate state. pub(super) batch_group_indices: Vec, + /// Hash for each row in the current input batch. + pub(super) batch_hashes: Vec, + + /// Hash for each accumulated group. + pub(super) group_hashes: Vec, + + /// Input rows that created new groups in the current input batch. + pub(super) new_group_rows: Vec, + /// One item per aggregate expression. /// /// Example: `COUNT(x), SUM(y)` creates two items. Each item owns the input @@ -304,24 +324,21 @@ pub(super) enum AggregateHashTableState { Building(AggregateHashTableBuffer), /// Emitting results directly from group keys and aggregate state. Outputting(AggregateHashTableBuffer), - /// Materialize all the output results, and then incrementally output in the `OutputtingMaterializedFinal` state. - /// - /// Note this is a temporary solution until the `GroupValues` issue is solved: - /// Issue: - OutputtingMaterializedFinal(MaterializedFinalOutput), + /// Materialized output rows sliced across output polls. + OutputtingMaterialized(MaterializedOutput), Done, } -/// Fully evaluated final aggregate output and the next row offset to emit. +/// Materialized aggregate output and the next row offset to emit. /// -/// Final aggregate evaluation consumes accumulator state, so final output is -/// materialized once and then sliced to honor `batch_size` across output polls. -pub(super) struct MaterializedFinalOutput { +/// Some output paths can only emit all rows from their backing state at once. +/// The materialized batch is sliced to honor `batch_size` across output polls. +pub(super) struct MaterializedOutput { batch: RecordBatch, offset: usize, } -impl MaterializedFinalOutput { +impl MaterializedOutput { pub(super) fn new(batch: RecordBatch) -> Self { Self { batch, offset: 0 } } @@ -495,8 +512,10 @@ mod tests { use super::*; + // Covers materialized output slicing until all rows are emitted. + // Example: a five-row batch with batch size two emits 2, 2, then 1 row. #[test] - fn materialized_final_output_slices_batches_until_exhausted() -> Result<()> { + fn test_materialized_output_slices_batches_until_exhausted() -> Result<()> { let schema = Arc::new(Schema::new(vec![Field::new( "group_col", DataType::Int32, @@ -506,7 +525,7 @@ mod tests { schema, vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))], )?; - let mut output = MaterializedFinalOutput::new(batch); + let mut output = MaterializedOutput::new(batch); assert_eq!(int32_values(&output.next_batch(2).unwrap(), 0), vec![1, 2]); assert_eq!(int32_values(&output.next_batch(2).unwrap(), 0), vec![3, 4]); diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/final_table.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/final_table.rs index c3e4f831c4bbf..57682d838b68e 100644 --- a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/final_table.rs +++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/final_table.rs @@ -26,7 +26,7 @@ use crate::aggregates::AggregateExec; use super::common::{ AggregateHashTable, AggregateHashTableBuffer, AggregateHashTableState, FinalMarker, - MaterializedFinalOutput, + MaterializedOutput, }; /// Methods specific to the aggregate hash table used in the final aggregation stage. @@ -68,7 +68,7 @@ impl AggregateHashTable { let output = self.materialize_final_output(state, output_schema)?; Ok(self.emit_next_materialized_batch(output, batch_size)) } - AggregateHashTableState::OutputtingMaterializedFinal(output) => { + AggregateHashTableState::OutputtingMaterialized(output) => { Ok(self.emit_next_materialized_batch(output, batch_size)) } AggregateHashTableState::Done => Ok(None), @@ -82,7 +82,7 @@ impl AggregateHashTable { &self, mut state: AggregateHashTableBuffer, output_schema: SchemaRef, - ) -> Result { + ) -> Result { // Final aggregate evaluation consumes accumulator state. Evaluate all // groups once, then slice the materialized batch on subsequent polls. let emit_to = EmitTo::All; @@ -96,21 +96,7 @@ impl AggregateHashTable { let batch = RecordBatch::try_new(output_schema, output)?; debug_assert!(batch.num_rows() > 0); - Ok(MaterializedFinalOutput::new(batch)) - } - - fn emit_next_materialized_batch( - &mut self, - mut output: MaterializedFinalOutput, - batch_size: usize, - ) -> Option { - let batch = output.next_batch(batch_size); - if output.is_exhausted() { - self.state = AggregateHashTableState::Done; - } else { - self.state = AggregateHashTableState::OutputtingMaterializedFinal(output); - } - batch + Ok(MaterializedOutput::new(batch)) } pub(in crate::aggregates) fn aggregate_batch( @@ -122,9 +108,12 @@ impl AggregateHashTable { let timer = self.group_by_metrics.aggregation_time.timer(); for group_values in &evaluated_batch.grouping_set_args { - state - .group_values - .intern(group_values, &mut state.batch_group_indices)?; + state.group_values.intern( + group_values, + &mut state.batch_group_indices, + &mut state.batch_hashes, + &mut state.new_group_rows, + )?; let group_indices = &state.batch_group_indices; let total_num_groups = state.group_values.len(); diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs index 9d226aa28b35f..805c5c709de98 100644 --- a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs +++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs @@ -22,7 +22,9 @@ use std::sync::Arc; use arrow::array::{ArrayRef, BooleanArray, new_null_array}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; +use datafusion_common::hash_utils::create_hashes; use datafusion_common::{Result, assert_eq_or_internal_err, internal_err}; +use datafusion_expr::EmitTo; use crate::aggregates::group_values::new_group_values; use crate::aggregates::order::GroupOrdering; @@ -30,8 +32,8 @@ use crate::aggregates::{AggregateExec, group_id_array, max_duplicate_ordinal}; use super::common::{ AggregateHashTable, AggregateHashTableBuffer, AggregateHashTableState, - EvaluatedAccumulatorArgs, HashAggregateAccumulator, PartialMarker, PartialSkipMarker, - emit_to_for_batch_size, + EvaluatedAccumulatorArgs, EvaluatedAggregateBatch, HashAggregateAccumulator, + MaterializedOutput, PartialMarker, PartialSkipMarker, }; /// Methods specific to the aggregate hash table used in the partial aggregation stage. @@ -60,48 +62,76 @@ impl AggregateHashTable { pub(in crate::aggregates) fn next_output_batch( &mut self, ) -> Result> { - let output_schema = Arc::clone(&self.output_schema); let batch_size = self.batch_size; - match &mut self.state { + match std::mem::replace(&mut self.state, AggregateHashTableState::Done) { AggregateHashTableState::Outputting(state) => { if state.group_values.is_empty() { - self.state = AggregateHashTableState::Done; return Ok(None); } - let emit_to = - emit_to_for_batch_size(batch_size, state.group_values.len()); - let timer = self.group_by_metrics.emitting_time.timer(); - let mut output = state.group_values.emit(emit_to)?; + let batch = self.materialize_partial_batch(state)?; + Ok(self.emit_next_materialized_batch( + MaterializedOutput::new(batch), + batch_size, + )) + } + AggregateHashTableState::OutputtingMaterialized(output) => { + Ok(self.emit_next_materialized_batch(output, batch_size)) + } + AggregateHashTableState::Done => Ok(None), + AggregateHashTableState::Building(_) => { + internal_err!("next_output_batch must be called in the outputting state") + } + } + } - for acc in state.accumulators.iter_mut() { - output.extend(acc.state(emit_to)?); + pub(in crate::aggregates) fn materialize_output_batch_with_hashes( + &mut self, + ) -> Result)>> { + match std::mem::replace(&mut self.state, AggregateHashTableState::Done) { + AggregateHashTableState::Outputting(state) => { + if state.group_values.is_empty() { + return Ok(None); } - let done = state.group_values.is_empty(); - drop(timer); - let batch = RecordBatch::try_new(output_schema, output)?; - debug_assert!(batch.num_rows() > 0); - if done { - self.state = AggregateHashTableState::Done; - } - Ok(Some(batch)) + let hashes = state.group_hashes.clone(); + self.materialize_partial_batch(state) + .map(|batch| Some((batch, hashes))) } AggregateHashTableState::Done => Ok(None), AggregateHashTableState::Building(_) => { - internal_err!("next_output_batch must be called in the outputting state") - } - AggregateHashTableState::OutputtingMaterializedFinal(_) => { internal_err!( - "partial aggregate output should not materialize final output" + "materialize_output_batch_with_hashes must be called in the outputting state" ) } + AggregateHashTableState::OutputtingMaterialized(_) => { + internal_err!("partial aggregate output is already materialized") + } } } + fn materialize_partial_batch( + &self, + mut state: AggregateHashTableBuffer, + ) -> Result { + let output_schema = Arc::clone(&self.output_schema); + let emit_to = EmitTo::All; + let timer = self.group_by_metrics.emitting_time.timer(); + let mut output = state.group_values.emit(emit_to)?; + + for acc in state.accumulators.iter_mut() { + output.extend(acc.state(emit_to)?); + } + drop(timer); + + let batch = RecordBatch::try_new(output_schema, output)?; + debug_assert!(batch.num_rows() > 0); + Ok(batch) + } + pub(in crate::aggregates) fn can_skip_aggregation(&self) -> bool { - self.state - .building() + let state = self.state.building(); + state .accumulators .iter() .all(|acc| acc.supports_convert_to_state()) @@ -131,6 +161,9 @@ impl AggregateHashTable { group_by: Arc::clone(&state.group_by), group_values, batch_group_indices: Default::default(), + batch_hashes: Default::default(), + group_hashes: Default::default(), + new_group_rows: Default::default(), accumulators, }), _mode: PhantomData, @@ -146,11 +179,19 @@ impl AggregateHashTable { let _timer = self.group_by_metrics.aggregation_time.timer(); for group_values in &evaluated_batch.grouping_set_args { - state - .group_values - .intern(group_values, &mut state.batch_group_indices)?; + state.group_values.intern( + group_values, + &mut state.batch_group_indices, + &mut state.batch_hashes, + &mut state.new_group_rows, + )?; let group_indices = &state.batch_group_indices; let total_num_groups = state.group_values.len(); + state.group_hashes.resize(total_num_groups, 0); + for &row in &state.new_group_rows { + let group_index = group_indices[row]; + state.group_hashes[group_index] = state.batch_hashes[row]; + } for (acc, values) in state .accumulators @@ -216,9 +257,17 @@ impl AggregateHashTable { .collect(); cols.push(group_id_array(group, ordinal, max_ordinal, 1)?); - state - .group_values - .intern(&cols, &mut state.batch_group_indices)?; + state.group_values.intern( + &cols, + &mut state.batch_group_indices, + &mut state.batch_hashes, + &mut state.new_group_rows, + )?; + state.group_hashes.resize(state.group_values.len(), 0); + for &row in &state.new_group_rows { + let group_index = state.batch_group_indices[row]; + state.group_hashes[group_index] = state.batch_hashes[row]; + } any_interned = true; } @@ -245,17 +294,7 @@ impl AggregateHashTable { batch: &RecordBatch, ) -> Result { let evaluated_batch = self.evaluate_batch(batch)?; - - assert_eq_or_internal_err!( - evaluated_batch.grouping_set_args.len(), - 1, - "group_values expected to have single element" - ); - let mut output = evaluated_batch - .grouping_set_args - .into_iter() - .next() - .unwrap_or_default(); + let mut output = Self::output_group_values(&evaluated_batch)?; let state = self.state.building_mut(); for (acc, values) in state @@ -271,4 +310,50 @@ impl AggregateHashTable { output, )?) } + + pub(in crate::aggregates) fn convert_batch_to_state_with_hashes<'a>( + &'a mut self, + batch: &RecordBatch, + ) -> Result<(RecordBatch, &'a [u64])> { + let evaluated_batch = self.evaluate_batch(batch)?; + let output_group_values = Self::output_group_values(&evaluated_batch)?; + + let state = self.state.building_mut(); + state.batch_hashes.clear(); + state.batch_hashes.resize(batch.num_rows(), 0); + create_hashes( + &output_group_values, + &crate::aggregates::AGGREGATION_HASH_SEED, + &mut state.batch_hashes, + )?; + + let mut output = output_group_values; + for (acc, values) in state + .accumulators + .iter_mut() + .zip(evaluated_batch.accumulator_args.iter()) + { + output.extend(acc.convert_to_state(values)?); + } + + Ok(( + RecordBatch::try_new(Arc::clone(&self.output_schema), output)?, + &state.batch_hashes, + )) + } + + fn output_group_values( + evaluated_batch: &EvaluatedAggregateBatch, + ) -> Result> { + assert_eq_or_internal_err!( + evaluated_batch.grouping_set_args.len(), + 1, + "group_values expected to have single element" + ); + Ok(evaluated_batch + .grouping_set_args + .first() + .cloned() + .unwrap_or_default()) + } } diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index ee253e5d7afdd..a49ad87676505 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -91,13 +91,21 @@ pub trait GroupValues: Send { /// Calculates the group id for each input row of `cols`, assigning new /// group ids as necessary. /// - /// When the function returns, `groups` must contain the group id for each - /// row in `cols`. + /// When the function returns, `groups` must contain the group id for each + /// row in `cols`, and `hashes` must contain the hash for each row in `cols`. + /// `new_group_rows` is filled with the input row index that first created + /// each new group in this call. /// /// If a row has the same value as a previous row, the same group id is /// assigned. If a row has a new value, the next available group id is /// assigned. - fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()>; + fn intern( + &mut self, + cols: &[ArrayRef], + groups: &mut Vec, + hashes: &mut Vec, + new_group_rows: &mut Vec, + ) -> Result<()>; /// Returns the number of bytes of memory used by this [`GroupValues`] fn size(&self) -> usize; diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs index f275d777c3279..abfe70b53a7c6 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs @@ -344,23 +344,20 @@ impl GroupValuesColumn { /// /// `Group indices` order are against with their input order, and this will lead to error /// in `streaming aggregation`. - fn scalarized_intern( + fn scalarized_intern_impl( &mut self, cols: &[ArrayRef], + hashes: &[u64], groups: &mut Vec, + new_group_rows: &mut Vec, ) -> Result<()> { - let n_rows = cols[0].len(); + debug_assert_eq!(hashes.len(), cols.first().map_or(0, |array| array.len())); // tracks to which group each of the input rows belongs groups.clear(); + new_group_rows.clear(); - // 1.1 Calculate the group keys for the group values - let batch_hashes = &mut self.hashes_buffer; - batch_hashes.clear(); - batch_hashes.resize(n_rows, 0); - create_hashes(cols, &self.random_state, batch_hashes)?; - - for (row, &target_hash) in batch_hashes.iter().enumerate() { + for (row, &target_hash) in hashes.iter().enumerate() { let entry = self .map .find_mut(target_hash, |(exist_hash, group_idx_view)| { @@ -425,6 +422,7 @@ impl GroupValuesColumn { |(hash, _group_index)| *hash, &mut self.map_size, ); + new_group_rows.push(row); group_idx } }; @@ -445,21 +443,20 @@ impl GroupValuesColumn { /// /// The vectorized approach can offer higher performance for avoiding row by row /// downcast for `cols` and being able to implement even more optimizations(like simd). - fn vectorized_intern( + fn vectorized_intern_impl( &mut self, cols: &[ArrayRef], + hashes: &[u64], groups: &mut Vec, + new_group_rows: &mut Vec, ) -> Result<()> { let n_rows = cols[0].len(); + debug_assert_eq!(hashes.len(), n_rows); // tracks to which group each of the input rows belongs groups.clear(); groups.resize(n_rows, usize::MAX); - - let mut batch_hashes = mem::take(&mut self.hashes_buffer); - batch_hashes.clear(); - batch_hashes.resize(n_rows, 0); - create_hashes(cols, &self.random_state, &mut batch_hashes)?; + new_group_rows.clear(); // General steps for one round `vectorized equal_to & append`: // 1. Collect vectorized context by checking hash values of `cols` in `map`, @@ -482,7 +479,7 @@ impl GroupValuesColumn { // // 1. Collect vectorized context by checking hash values of `cols` in `map` - self.collect_vectorized_process_context(&batch_hashes, groups); + self.collect_vectorized_process_context(hashes, groups, new_group_rows); // 2. Perform `vectorized_append` self.vectorized_append(cols)?; @@ -492,9 +489,7 @@ impl GroupValuesColumn { // 4. Perform scalarized inter for remaining rows // (about remaining rows, can see comments for `remaining_row_indices`) - self.scalarized_intern_remaining(cols, &batch_hashes, groups)?; - - self.hashes_buffer = batch_hashes; + self.scalarized_intern_remaining(cols, hashes, groups, new_group_rows)?; Ok(()) } @@ -514,8 +509,9 @@ impl GroupValuesColumn { /// Otherwise get all group indices from `group_index_lists`, and add them. fn collect_vectorized_process_context( &mut self, - batch_hashes: &[u64], + hashes: &[u64], groups: &mut [usize], + new_group_rows: &mut Vec, ) { self.vectorized_operation_buffers.append_row_indices.clear(); self.vectorized_operation_buffers @@ -525,7 +521,7 @@ impl GroupValuesColumn { .equal_to_group_indices .clear(); - for (row, &target_hash) in batch_hashes.iter().enumerate() { + for (row, &target_hash) in hashes.iter().enumerate() { let entry = self .map .find(target_hash, |(exist_hash, _)| target_hash == *exist_hash); @@ -553,6 +549,7 @@ impl GroupValuesColumn { // Set group index to row in `groups` groups[row] = current_group_idx; + new_group_rows.push(row); continue; }; @@ -741,8 +738,9 @@ impl GroupValuesColumn { fn scalarized_intern_remaining( &mut self, cols: &[ArrayRef], - batch_hashes: &[u64], + hashes: &[u64], groups: &mut [usize], + new_group_rows: &mut Vec, ) -> Result<()> { if self .vectorized_operation_buffers @@ -755,7 +753,7 @@ impl GroupValuesColumn { let mut map = mem::take(&mut self.map); for &row in &self.vectorized_operation_buffers.remaining_row_indices { - let target_hash = batch_hashes[row]; + let target_hash = hashes[row]; let entry = map.find_mut(target_hash, |(exist_hash, _)| { // Somewhat surprisingly, this closure can be called even if the // hash doesn't match, so check the hash first with an integer @@ -816,6 +814,7 @@ impl GroupValuesColumn { } groups[row] = group_idx; + new_group_rows.push(row); } self.map = map; @@ -1078,14 +1077,21 @@ fn make_group_column(field: &Field) -> Result> { } impl GroupValues for GroupValuesColumn { - fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { - // `try_new` and the reset points in `emit` / `clear_shrink` keep - // `self.group_values` populated with one builder per schema field, - // so no lazy initialization is needed here. + fn intern( + &mut self, + cols: &[ArrayRef], + groups: &mut Vec, + hashes: &mut Vec, + new_group_rows: &mut Vec, + ) -> Result<()> { + let n_rows = cols[0].len(); + hashes.clear(); + hashes.resize(n_rows, 0); + create_hashes(cols, &self.random_state, hashes)?; if !STREAMING { - self.vectorized_intern(cols, groups) + self.vectorized_intern_impl(cols, hashes, groups, new_group_rows) } else { - self.scalarized_intern(cols, groups) + self.scalarized_intern_impl(cols, hashes, groups, new_group_rows) } } @@ -1878,7 +1884,11 @@ mod tests { fn load_to_group_values(&self, group_values: &mut impl GroupValues) { for batch in self.test_batches.iter() { - group_values.intern(batch, &mut vec![]).unwrap(); + let mut groups = vec![]; + let mut hashes = vec![]; + group_values + .intern(batch, &mut groups, &mut hashes, &mut vec![]) + .unwrap(); } } diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs index 4976a098ecee5..9350daf1de2c2 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/row.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -116,7 +116,13 @@ impl GroupValuesRows { } impl GroupValues for GroupValuesRows { - fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { + fn intern( + &mut self, + cols: &[ArrayRef], + groups: &mut Vec, + _hashes: &mut Vec, + _new_group_rows: &mut Vec, + ) -> Result<()> { // Normalize -0.0 → +0.0 so RowConverter (IEEE 754 totalOrder) and // primitive hashing both group ±0 together. No-op for non-float // columns. diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs index e993c0c53d199..edcb5013d240c 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs @@ -42,7 +42,13 @@ impl GroupValuesBoolean { } impl GroupValues for GroupValuesBoolean { - fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { + fn intern( + &mut self, + cols: &[ArrayRef], + groups: &mut Vec, + _hashes: &mut Vec, + _new_group_rows: &mut Vec, + ) -> Result<()> { let array = cols[0].as_boolean(); groups.clear(); diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs index b881a51b25474..73664970fdd9c 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs @@ -45,29 +45,29 @@ impl GroupValuesBytes { } impl GroupValues for GroupValuesBytes { - fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { + fn intern( + &mut self, + cols: &[ArrayRef], + groups: &mut Vec, + _hashes: &mut Vec, + _new_group_rows: &mut Vec, + ) -> Result<()> { assert_eq!(cols.len(), 1); - - // look up / add entries in the table let arr = &cols[0]; groups.clear(); self.map.insert_if_new( arr, - // called for each new group |_value| { - // assign new group index on each insert let group_idx = self.num_groups; self.num_groups += 1; group_idx }, - // called for each group |group_idx| { groups.push(group_idx); }, ); - // ensure we assigned a group to for each row assert_eq!(groups.len(), arr.len()); Ok(()) } @@ -108,7 +108,13 @@ impl GroupValues for GroupValuesBytes { self.num_groups = 0; let mut group_indexes = vec![]; - self.intern(&[remaining_group_values], &mut group_indexes)?; + let mut hashes = vec![]; + self.intern( + &[remaining_group_values], + &mut group_indexes, + &mut hashes, + &mut vec![], + )?; // Verify that the group indexes were assigned in the correct order assert_eq!(0, group_indexes[0]); diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs index 7a56f7c52c11a..ab66dcd4a0938 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs @@ -47,29 +47,25 @@ impl GroupValues for GroupValuesBytesView { &mut self, cols: &[ArrayRef], groups: &mut Vec, + _hashes: &mut Vec, + _new_group_rows: &mut Vec, ) -> datafusion_common::Result<()> { assert_eq!(cols.len(), 1); - - // look up / add entries in the table let arr = &cols[0]; groups.clear(); self.map.insert_if_new( arr, - // called for each new group |_value| { - // assign new group index on each insert let group_idx = self.num_groups; self.num_groups += 1; group_idx }, - // called for each group |group_idx| { groups.push(group_idx); }, ); - // ensure we assigned a group to for each row assert_eq!(groups.len(), arr.len()); Ok(()) } @@ -110,7 +106,13 @@ impl GroupValues for GroupValuesBytesView { self.num_groups = 0; let mut group_indexes = vec![]; - self.intern(&[remaining_group_values], &mut group_indexes)?; + let mut hashes = vec![]; + self.intern( + &[remaining_group_values], + &mut group_indexes, + &mut hashes, + &mut vec![], + )?; // Verify that the group indexes were assigned in the correct order assert_eq!(0, group_indexes[0]); diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs index e254aebcfd7ce..889304aad6f7e 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs @@ -135,7 +135,13 @@ impl GroupValues for GroupValuesPrimitive where T::Native: HashValue, { - fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { + fn intern( + &mut self, + cols: &[ArrayRef], + groups: &mut Vec, + _hashes: &mut Vec, + _new_group_rows: &mut Vec, + ) -> Result<()> { assert_eq!(cols.len(), 1); groups.clear(); @@ -273,7 +279,8 @@ mod tests { // Intern 20 distinct values; `new()` pre-allocates capacity 128 for `values`. let arr: ArrayRef = Arc::new(Int32Array::from_iter_values(0..20i32)); let mut groups = vec![]; - gv.intern(&[arr], &mut groups)?; + let mut hashes = vec![]; + gv.intern(&[arr], &mut groups, &mut hashes, &mut vec![])?; let capacity_before = gv.values.capacity(); // 128 // n=4, n*2=8 <= len=20 -> drain branch diff --git a/datafusion/physical-plan/src/aggregates/hash_aggregate.rs b/datafusion/physical-plan/src/aggregates/hash_aggregate.rs index 4c8756c0e865c..5c0c2d0f8e7d3 100644 --- a/datafusion/physical-plan/src/aggregates/hash_aggregate.rs +++ b/datafusion/physical-plan/src/aggregates/hash_aggregate.rs @@ -25,14 +25,17 @@ //! //! See issue for details: +use std::collections::VecDeque; use std::ops::ControlFlow; use std::sync::Arc; use std::task::{Context, Poll}; +use arrow::array::{RecordBatch, UInt32Builder}; +use arrow::compute::take_arrays; use arrow::datatypes::SchemaRef; -use arrow::record_batch::RecordBatch; -use datafusion_common::Result; +use datafusion_common::{Result, internal_err}; use datafusion_execution::TaskContext; +use datafusion_execution::memory_pool::proxy::VecAllocExt; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use futures::stream::{Stream, StreamExt}; @@ -41,9 +44,13 @@ use super::aggregate_hash_table::{ AggregateHashTable, FinalMarker, PartialMarker, PartialSkipMarker, }; use super::skip_partial::SkipAggregationProbe; +use crate::coalesce::LimitedBatchCoalescer; use crate::metrics::{ BaselineMetrics, MetricBuilder, MetricCategory, RecordOutput, SpillMetrics, }; +use crate::repartition::{ + PARTITIONED_AGGREGATION_NUM_PARTITIONS_KEY, PARTITIONED_AGGREGATION_PARTITION_KEY, +}; use crate::stream::EmptyRecordBatchStream; use crate::{InputOrderMode, RecordBatchStream, SendableRecordBatchStream, metrics}; @@ -126,6 +133,21 @@ pub(crate) struct PartialHashAggregateStream { /// Tracks whether partial aggregation should switch to direct state conversion. skip_aggregation_probe: Option, + /// Local repartition and coalesce state for partial output. + repartition_state: Option>, + + /// Materialized partial output being sliced into local repartition. + repartition_source_batch: Option, + + /// Hashes for rows in [`Self::repartition_source_batch`]. + repartition_source_hashes: Vec, + + /// Next row offset to push from [`Self::repartition_source_batch`]. + repartition_source_offset: usize, + + /// Target output batch size. + batch_size: usize, + /// Optional soft limit on the number of groups to accumulate before output. /// /// Invariant: when this is `Some(..)`, the accumulators inside `hash_table` must @@ -134,7 +156,162 @@ pub(crate) struct PartialHashAggregateStream { /// Tracks the high-level stream lifecycle. The hash table owns the lower-level /// state for emitting output batches. - state: Option, + state: Option>, +} + +struct PartialRepartitionState { + coalescers: Vec, + pending: VecDeque, + partition_indices: Vec>, + num_partitions: usize, + finished: bool, +} + +impl PartialRepartitionState { + fn new(schema: &SchemaRef, num_partitions: usize, batch_size: usize) -> Self { + let coalescers = (0..num_partitions) + .map(|_| LimitedBatchCoalescer::new(Arc::clone(schema), batch_size, None)) + .collect(); + Self { + coalescers, + pending: VecDeque::new(), + partition_indices: vec![vec![]; num_partitions], + num_partitions, + finished: false, + } + } + + fn push_batch(&mut self, batch: &RecordBatch, hashes: &[u64]) -> Result<()> { + if batch.num_rows() != hashes.len() { + return internal_err!( + "partial aggregate output has {} rows, but {} hashes", + batch.num_rows(), + hashes.len() + ); + } + + if hashes.is_empty() { + return Ok(()); + } + + for indices in &mut self.partition_indices { + indices.clear(); + } + + if self.num_partitions.is_power_of_two() { + let mask = self.num_partitions - 1; + Self::push_partition_indices(hashes, &mut self.partition_indices, |hash| { + (hash as usize) & mask + })?; + } else { + let num_partitions = self.num_partitions; + Self::push_partition_indices(hashes, &mut self.partition_indices, |hash| { + (hash as usize) % num_partitions + })?; + } + + let mut indices_builder = UInt32Builder::with_capacity(hashes.len()); + for indices in &self.partition_indices { + indices_builder.append_slice(indices); + } + let indices = indices_builder.finish(); + let columns = take_arrays(batch.columns(), &indices, None)?; + let reordered_batch = RecordBatch::try_new(Arc::clone(&batch.schema()), columns)?; + + let mut offset = 0; + for partition in 0..self.num_partitions { + let len = self.partition_indices[partition].len(); + if len == 0 { + continue; + } + + let partition_batch = reordered_batch.slice(offset, len); + offset += len; + self.coalescers[partition].push_batch(partition_batch)?; + while let Some(batch) = self.coalescers[partition].next_completed_batch() { + self.pending.push_back(add_partitioned_aggregation_metadata( + batch, + partition, + self.num_partitions, + )); + } + } + + Ok(()) + } + + fn push_partition_indices( + hashes: &[u64], + partition_indices: &mut [Vec], + compute_partition: F, + ) -> Result<()> + where + F: Fn(u64) -> usize, + { + for (row, hash) in hashes.iter().enumerate() { + let row = u32::try_from(row).map_err(|_| { + datafusion_common::internal_datafusion_err!( + "partitioned aggregate row index exceeds u32::MAX" + ) + })?; + partition_indices[compute_partition(*hash)].push(row); + } + Ok(()) + } + + fn finish(&mut self) -> Result<()> { + if self.finished { + return Ok(()); + } + + for partition in 0..self.num_partitions { + self.coalescers[partition].finish()?; + while let Some(batch) = self.coalescers[partition].next_completed_batch() { + self.pending.push_back(add_partitioned_aggregation_metadata( + batch, + partition, + self.num_partitions, + )); + } + } + self.finished = true; + Ok(()) + } + + fn is_done(&self) -> bool { + self.pending.is_empty() && self.finished + } + + fn memory_size(&self) -> usize { + self.partition_indices.allocated_size() + + self + .partition_indices + .iter() + .map(VecAllocExt::allocated_size) + .sum::() + + self + .pending + .iter() + .map(RecordBatch::get_array_memory_size) + .sum::() + } +} + +fn add_partitioned_aggregation_metadata( + mut batch: RecordBatch, + partition: usize, + num_partitions: usize, +) -> RecordBatch { + let metadata = batch.schema_metadata_mut(); + metadata.insert( + PARTITIONED_AGGREGATION_PARTITION_KEY.to_string(), + partition.to_string(), + ); + metadata.insert( + PARTITIONED_AGGREGATION_NUM_PARTITIONS_KEY.to_string(), + num_partitions.to_string(), + ); + batch } /// States for partial hash aggregation processing. @@ -148,7 +325,7 @@ enum PartialHashAggregateState { /// finish in `Done`. If `Some`, partial skip has triggered and the /// stream will move to `SkippingAggregation` after these accumulated /// groups are emitted. - skip_hash_table: Option>, + skip_hash_table: Option>>, }, SkippingAggregation { hash_table: AggregateHashTable, @@ -293,10 +470,22 @@ impl PartialHashAggregateStream { Arc::clone(&schema), batch_size, )?; - let can_skip_aggregation = - agg.group_by.is_single() && hash_table.can_skip_aggregation(); + let options = &context.session_config().options().execution; + let repartition_state = if options.enable_partial_aggregation_local_repartition + && !agg.group_by.is_empty() + && context.session_config().repartition_aggregations() + && context.session_config().target_partitions() > 1 + { + Some(Box::new(PartialRepartitionState::new( + &schema, + context.session_config().target_partitions(), + batch_size, + ))) + } else { + None + }; + let can_skip_aggregation = hash_table.can_skip_aggregation(); let skip_aggregation_probe = if can_skip_aggregation { - let options = &context.session_config().options().execution; let probe_ratio_threshold = options.skip_partial_aggregation_probe_ratio_threshold; // A threshold >= 1.0 means the ratio (num_groups / input_rows) can @@ -328,8 +517,15 @@ impl PartialHashAggregateStream { reservation, reduction_factor, skip_aggregation_probe, + repartition_state, + repartition_source_batch: None, + repartition_source_hashes: vec![], + repartition_source_offset: 0, + batch_size, group_values_soft_limit: agg.limit_options().map(|config| config.limit()), - state: Some(PartialHashAggregateState::ReadingInput { hash_table }), + state: Some(Box::new(PartialHashAggregateState::ReadingInput { + hash_table, + })), }) } @@ -357,6 +553,113 @@ impl PartialHashAggregateStream { .is_some_and(|probe| probe.should_skip()) } + fn memory_size(&self, hash_table: &AggregateHashTable) -> usize { + hash_table.memory_size() + + self + .repartition_state + .as_ref() + .map_or(0, |state| state.memory_size()) + + self + .repartition_source_batch + .as_ref() + .map_or(0, RecordBatch::get_array_memory_size) + + self.repartition_source_hashes.allocated_size() + } + + fn resize_reservation( + &self, + hash_table: &AggregateHashTable, + ) -> Result<()> { + self.reservation.try_resize(self.memory_size(hash_table)) + } + + fn next_output_batch( + &mut self, + hash_table: &mut AggregateHashTable, + finish_when_source_done: bool, + ) -> Result> { + if self.repartition_state.is_none() { + return hash_table.next_output_batch(); + } + + if self.repartition_source_batch.is_none() + && let Some((batch, hashes)) = + hash_table.materialize_output_batch_with_hashes()? + { + self.repartition_source_batch = Some(batch); + self.repartition_source_hashes = hashes; + self.repartition_source_offset = 0; + } + + self.next_repartition_batch(finish_when_source_done) + } + + fn next_repartition_batch( + &mut self, + finish_when_source_done: bool, + ) -> Result> { + if let Some(batch) = self.next_repartition_pending_batch() { + return Ok(Some(batch)); + } + + while self.push_next_repartition_source_batch()? { + if let Some(batch) = self.next_repartition_pending_batch() { + return Ok(Some(batch)); + } + } + + if finish_when_source_done + && let Some(repartition_state) = self.repartition_state.as_mut() + { + repartition_state.finish()?; + } + + Ok(self.next_repartition_pending_batch()) + } + + fn push_next_repartition_source_batch(&mut self) -> Result { + let Some(source_batch) = self.repartition_source_batch.as_ref() else { + return Ok(false); + }; + if self.repartition_source_offset >= source_batch.num_rows() { + self.repartition_source_batch = None; + self.repartition_source_hashes.clear(); + self.repartition_source_offset = 0; + return Ok(false); + } + + let offset = self.repartition_source_offset; + let len = self.batch_size.min(source_batch.num_rows() - offset); + let batch = source_batch.slice(offset, len); + self.repartition_source_offset += len; + + let source_hashes = std::mem::take(&mut self.repartition_source_hashes); + let result = match self.repartition_state.as_mut() { + Some(repartition_state) => { + repartition_state.push_batch(&batch, &source_hashes[offset..offset + len]) + } + None => Ok(()), + }; + self.repartition_source_hashes = source_hashes; + result?; + Ok(true) + } + + fn next_repartition_pending_batch(&mut self) -> Option { + self.repartition_state + .as_mut() + .and_then(|state| state.pending.pop_front()) + } + + fn is_output_done(&self, hash_table: &AggregateHashTable) -> bool { + hash_table.is_done() + && self.repartition_source_batch.is_none() + && self + .repartition_state + .as_ref() + .is_none_or(|state| state.is_done()) + } + fn start_output( &mut self, hash_table: &mut AggregateHashTable, @@ -457,7 +760,7 @@ impl PartialHashAggregateStream { return ControlFlow::Continue( PartialHashAggregateState::ProducingOutput { hash_table, - skip_hash_table: Some(skip_hash_table), + skip_hash_table: Some(Box::new(skip_hash_table)), }, ); } @@ -472,10 +775,7 @@ impl PartialHashAggregateStream { // TODO: impl memory-limited aggr, when OOM directly send // partial state to final aggregate stage - if let Err(e) = self - .reservation - .try_resize(original_state.hash_table().memory_size()) - { + if let Err(e) = self.resize_reservation(original_state.hash_table()) { return ControlFlow::Break(( Poll::Ready(Some(Err(e))), original_state, @@ -532,24 +832,30 @@ impl PartialHashAggregateStream { let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); let timer = elapsed_compute.timer(); - let result = original_state.hash_table_mut().next_output_batch(); + let finish_when_source_done = !matches!( + &original_state, + PartialHashAggregateState::ProducingOutput { + skip_hash_table: Some(_), + .. + } + ); + let result = self + .next_output_batch(original_state.hash_table_mut(), finish_when_source_done); timer.done(); match result { Ok(Some(batch)) => { - let _ = self - .reservation - .try_resize(original_state.hash_table().memory_size()); + let _ = self.resize_reservation(original_state.hash_table()); self.reduction_factor.add_part(batch.num_rows()); debug_assert!(batch.num_rows() > 0); - let next_state = if original_state.hash_table().is_done() { + let next_state = if self.is_output_done(original_state.hash_table()) { match original_state { PartialHashAggregateState::ProducingOutput { skip_hash_table: Some(hash_table), .. - } => { - PartialHashAggregateState::SkippingAggregation { hash_table } - } + } => PartialHashAggregateState::SkippingAggregation { + hash_table: *hash_table, + }, PartialHashAggregateState::ProducingOutput { skip_hash_table: None, .. @@ -573,7 +879,9 @@ impl PartialHashAggregateStream { PartialHashAggregateState::ProducingOutput { skip_hash_table: Some(hash_table), .. - } => PartialHashAggregateState::SkippingAggregation { hash_table }, + } => PartialHashAggregateState::SkippingAggregation { + hash_table: *hash_table, + }, PartialHashAggregateState::ProducingOutput { skip_hash_table: None, .. @@ -586,6 +894,26 @@ impl PartialHashAggregateStream { } } + fn push_skip_batch( + &mut self, + batch: RecordBatch, + hashes: &[u64], + ) -> Result> { + let Some(repartition_state) = self.repartition_state.as_mut() else { + return Ok(Some(batch)); + }; + repartition_state.push_batch(&batch, hashes)?; + Ok(self.next_repartition_pending_batch()) + } + + fn finish_repartition(&mut self) -> Result> { + let Some(repartition_state) = self.repartition_state.as_mut() else { + return Ok(None); + }; + repartition_state.finish()?; + Ok(self.next_repartition_pending_batch()) + } + /// Handle SkippingAggregation state - convert raw input directly to partial states. /// /// See comments at `poll_next()` for details. @@ -610,23 +938,56 @@ impl PartialHashAggregateStream { let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); let timer = elapsed_compute.timer(); - let result = match &mut original_state { - PartialHashAggregateState::SkippingAggregation { hash_table } => { - hash_table.convert_batch_to_state(&batch) + if self.repartition_state.is_some() { + let result = match &mut original_state { + PartialHashAggregateState::SkippingAggregation { hash_table } => { + hash_table.convert_batch_to_state_with_hashes(&batch) + } + _ => unreachable!("expected skipping aggregation state"), + }; + timer.done(); + + match result { + Ok((batch, hashes)) => { + match self.push_skip_batch(batch, hashes) { + Ok(Some(batch)) => ControlFlow::Break(( + Poll::Ready(Some(Ok( + batch.record_output(&self.baseline_metrics) + ))), + original_state, + )), + Ok(None) => ControlFlow::Continue(original_state), + Err(e) => ControlFlow::Break(( + Poll::Ready(Some(Err(e))), + original_state, + )), + } + } + Err(e) => ControlFlow::Break(( + Poll::Ready(Some(Err(e))), + original_state, + )), } - _ => unreachable!("expected skipping aggregation state"), - }; - timer.done(); + } else { + let result = match &mut original_state { + PartialHashAggregateState::SkippingAggregation { hash_table } => { + hash_table.convert_batch_to_state(&batch) + } + _ => unreachable!("expected skipping aggregation state"), + }; + timer.done(); - match result { - Ok(batch) => ControlFlow::Break(( - Poll::Ready(Some( - Ok(batch.record_output(&self.baseline_metrics)), + match result { + Ok(batch) => ControlFlow::Break(( + Poll::Ready(Some(Ok( + batch.record_output(&self.baseline_metrics) + ))), + original_state, + )), + Err(e) => ControlFlow::Break(( + Poll::Ready(Some(Err(e))), + original_state, )), - original_state, - )), - Err(e) => { - ControlFlow::Break((Poll::Ready(Some(Err(e))), original_state)) } } } @@ -636,7 +997,18 @@ impl PartialHashAggregateStream { Poll::Ready(None) => { let input_schema = self.input.schema(); self.input = Box::pin(EmptyRecordBatchStream::new(input_schema)); - ControlFlow::Continue(PartialHashAggregateState::Done) + match self.finish_repartition() { + Ok(Some(batch)) => ControlFlow::Break(( + Poll::Ready(Some( + Ok(batch.record_output(&self.baseline_metrics)), + )), + original_state, + )), + Ok(None) => ControlFlow::Continue(PartialHashAggregateState::Done), + Err(e) => { + ControlFlow::Break((Poll::Ready(Some(Err(e))), original_state)) + } + } } } } @@ -700,7 +1072,7 @@ impl Stream for PartialHashAggregateStream { cx: &mut Context<'_>, ) -> Poll> { loop { - let cur_state = self + let cur_state = *self .state .take() .expect("PartialHashAggregateStream state should not be None"); @@ -717,18 +1089,18 @@ impl Stream for PartialHashAggregateStream { } state @ PartialHashAggregateState::Done => { let _ = self.reservation.try_resize(0); - self.state = Some(state); + self.state = Some(Box::new(state)); return Poll::Ready(None); } }; match next_state { ControlFlow::Continue(next_state) => { - self.state = Some(next_state); + self.state = Some(Box::new(next_state)); continue; } ControlFlow::Break((poll, next_state)) => { - self.state = Some(next_state); + self.state = Some(Box::new(next_state)); return poll; } } diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 4f5b893578d74..2297b63f6f7c3 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -519,7 +519,6 @@ impl PartialEq for PhysicalGroupBy { /// partial state: [g, AVG(x) state columns, e.g. sum/count] /// final result: [g, AVG(x)] /// ``` -#[expect(clippy::large_enum_variant)] enum StreamType { /// Single group (no group by) aggregate stream. /// Input output scheme: initial input -> final result @@ -536,7 +535,7 @@ enum StreamType { /// [`StreamType::PartialHash`] and [`StreamType::FinalHash`] /// /// See issue for details: - GroupedHash(GroupedHashAggregateStream), + GroupedHash(Box), /// Grouped TopK aggregate stream. /// Input output scheme: initial input -> final result /// @@ -551,7 +550,7 @@ impl From for SendableRecordBatchStream { StreamType::AggregateStream(stream) => Box::pin(stream), StreamType::PartialHash(stream) => Box::pin(stream), StreamType::FinalHash(stream) => Box::pin(stream), - StreamType::GroupedHash(stream) => Box::pin(stream), + StreamType::GroupedHash(stream) => Box::pin(*stream), StreamType::GroupedPriorityQueue(stream) => Box::pin(stream), } } @@ -1036,9 +1035,9 @@ impl AggregateExec { } // Execution paths that have not been migrated use the fallback implementation - Ok(StreamType::GroupedHash(GroupedHashAggregateStream::new( - self, context, partition, - )?)) + Ok(StreamType::GroupedHash(Box::new( + GroupedHashAggregateStream::new(self, context, partition)?, + ))) } fn should_use_partial_hash_stream(&self, context: &TaskContext) -> bool { diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index a4d19b0f7d18a..90f73555c2153 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -334,6 +334,12 @@ pub(crate) struct GroupedHashAggregateStream { /// processed. Reused across batches here to avoid reallocations current_group_indices: Vec, + /// Hash for each row in the current input batch. + current_hashes: Vec, + + /// Input rows that created new groups in the current input batch. + new_group_rows: Vec, + /// Accumulators, one for each `AggregateFunctionExpr` in the query /// /// For example, if the query has aggregates, `SUM(x)`, @@ -601,6 +607,8 @@ impl GroupedHashAggregateStream { oom_mode, group_values, current_group_indices: Default::default(), + current_hashes: Default::default(), + new_group_rows: Default::default(), exec_state, baseline_metrics, group_by_metrics, @@ -883,8 +891,12 @@ impl GroupedHashAggregateStream { // calculate the group indices for each input row let starting_num_groups = self.group_values.len(); - self.group_values - .intern(group_values, &mut self.current_group_indices)?; + self.group_values.intern( + group_values, + &mut self.current_group_indices, + &mut self.current_hashes, + &mut self.new_group_rows, + )?; let group_indices = &self.current_group_indices; // Update ordering information if necessary @@ -991,7 +1003,9 @@ impl GroupedHashAggregateStream { let groups_and_acc_size = acc + self.group_values.size() + self.group_ordering.size() - + self.current_group_indices.allocated_size(); + + self.current_group_indices.allocated_size() + + self.current_hashes.allocated_size() + + self.new_group_rows.allocated_size(); // Reserve extra headroom for sorting during potential spill. // When OOM triggers, group_aggregate_batch has already processed the @@ -1103,8 +1117,12 @@ impl GroupedHashAggregateStream { cols.push(group_id_array(group, ordinal, max_ordinal, 1)?); let starting_groups = self.group_values.len(); - self.group_values - .intern(&cols, &mut self.current_group_indices)?; + self.group_values.intern( + &cols, + &mut self.current_group_indices, + &mut self.current_hashes, + &mut self.new_group_rows, + )?; let total_groups = self.group_values.len(); if total_groups > starting_groups { self.group_ordering.new_groups( @@ -1238,6 +1256,10 @@ impl GroupedHashAggregateStream { self.group_values.clear_shrink(num_rows); self.current_group_indices.clear(); self.current_group_indices.shrink_to(num_rows); + self.current_hashes.clear(); + self.current_hashes.shrink_to(num_rows); + self.new_group_rows.clear(); + self.new_group_rows.shrink_to(num_rows); } /// Clear memory and shrink capacities to zero. diff --git a/datafusion/physical-plan/src/recursive_query.rs b/datafusion/physical-plan/src/recursive_query.rs index 7289ac43e510c..1413f58ed1e05 100644 --- a/datafusion/physical-plan/src/recursive_query.rs +++ b/datafusion/physical-plan/src/recursive_query.rs @@ -436,6 +436,7 @@ struct DistinctDeduplicator { group_values: Box, reservation: MemoryReservation, intern_output_buffer: Vec, + hash_buffer: Vec, } impl DistinctDeduplicator { @@ -447,6 +448,7 @@ impl DistinctDeduplicator { group_values, reservation, intern_output_buffer: Vec::new(), + hash_buffer: Vec::new(), }) } @@ -466,8 +468,12 @@ impl DistinctDeduplicator { "failed to reserve {additional} recursive query group ids: {e}" ) })?; - self.group_values - .intern(batch.columns(), &mut self.intern_output_buffer)?; + self.group_values.intern( + batch.columns(), + &mut self.intern_output_buffer, + &mut self.hash_buffer, + &mut vec![], + )?; let mask = new_groups_mask(&self.intern_output_buffer, size_before); self.intern_output_buffer.clear(); // We update the reservation to reflect the new size of the hash table. diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index 2298183485f55..7022bc101b7c5 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -81,6 +81,11 @@ use distributor_channels::{ DistributionReceiver, DistributionSender, channels, partition_aware_channels, }; +pub(crate) const PARTITIONED_AGGREGATION_PARTITION_KEY: &str = + "datafusion.partitioned_aggregation.partition"; +pub(crate) const PARTITIONED_AGGREGATION_NUM_PARTITIONS_KEY: &str = + "datafusion.partitioned_aggregation.num_partitions"; + /// A batch in the repartition queue - either in memory or spilled to disk. /// /// This enum represents the two states a batch can be in during repartitioning. @@ -152,6 +157,46 @@ type MaybeBatch = Option>; type InputPartitionsToCurrentPartitionSender = Vec>; type InputPartitionsToCurrentPartitionReceiver = Vec>; +fn partitioned_aggregation_partition( + batch: &RecordBatch, + expected_num_partitions: usize, +) -> Result> { + let metadata = batch.schema_ref().metadata(); + let Some(partition) = metadata.get(PARTITIONED_AGGREGATION_PARTITION_KEY) else { + return Ok(None); + }; + let Some(num_partitions) = metadata.get(PARTITIONED_AGGREGATION_NUM_PARTITIONS_KEY) + else { + return Ok(None); + }; + + let partition = partition.parse::().map_err(|err| { + DataFusionError::Internal(format!( + "Invalid partitioned aggregation partition metadata: {err}" + )) + })?; + let num_partitions = num_partitions.parse::().map_err(|err| { + DataFusionError::Internal(format!( + "Invalid partitioned aggregation partition count metadata: {err}" + )) + })?; + + assert_or_internal_err!( + num_partitions == expected_num_partitions, + "Partitioned aggregation partition count {} does not match repartition count {}", + num_partitions, + expected_num_partitions + ); + assert_or_internal_err!( + partition < expected_num_partitions, + "Partitioned aggregation partition {} is out of range for {} partitions", + partition, + expected_num_partitions + ); + + Ok(Some(partition)) +} + /// Output channel with its associated memory reservation and spill writer. /// /// `coalescer` is `None` for preserve-order mode, where downstream @@ -1636,6 +1681,7 @@ impl RepartitionExec { input_partition: usize, num_input_partitions: usize, ) -> Result<()> { + let num_output_partitions = partitioning.partition_count(); let mut partitioner = match &partitioning { Partitioning::Hash(exprs, num_partitions) => { BatchPartitioner::new_hash_partitioner( @@ -1683,6 +1729,20 @@ impl RepartitionExec { continue; } + if let Partitioning::Hash(_, _) = &partitioning + && let Some(partition) = + partitioned_aggregation_partition(&batch, num_output_partitions)? + { + let timer = metrics.send_time[partition].timer(); + if let Some(output_channel) = output_channels.get_mut(&partition) + && output_channel.send(batch).await.is_err() + { + output_channels.remove(&partition); + } + timer.done(); + continue; + } + for res in partitioner.partition_iter(batch)? { let (partition, batch) = res?;