From 0173f803ba4997788e1c969755414ad71fc806a7 Mon Sep 17 00:00:00 2001 From: Yongting You <2010youy01@gmail.com> Date: Wed, 24 Jun 2026 17:53:11 +0800 Subject: [PATCH 1/3] refactor(hash-aggr): Migrate ordered partial/final aggregation --- .../core/tests/fuzz_cases/aggregate_fuzz.rs | 12 +- .../aggregates/aggregate_hash_table/common.rs | 13 +- .../aggregate_hash_table/common_ordered.rs | 214 ++++++++++ .../aggregate_hash_table/final_table.rs | 8 +- .../aggregates/aggregate_hash_table/mod.rs | 6 + .../ordered_final_table.rs | 127 ++++++ .../ordered_partial_table.rs | 145 +++++++ .../aggregate_hash_table/partial_table.rs | 8 +- .../physical-plan/src/aggregates/mod.rs | 296 +++++++++++++- .../src/aggregates/ordered_final_stream.rs | 348 ++++++++++++++++ .../src/aggregates/ordered_partial_stream.rs | 380 ++++++++++++++++++ .../physical-plan/src/aggregates/row_hash.rs | 5 +- 12 files changed, 1548 insertions(+), 14 deletions(-) create mode 100644 datafusion/physical-plan/src/aggregates/aggregate_hash_table/common_ordered.rs create mode 100644 datafusion/physical-plan/src/aggregates/aggregate_hash_table/ordered_final_table.rs create mode 100644 datafusion/physical-plan/src/aggregates/aggregate_hash_table/ordered_partial_table.rs create mode 100644 datafusion/physical-plan/src/aggregates/ordered_final_stream.rs create mode 100644 datafusion/physical-plan/src/aggregates/ordered_partial_stream.rs diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 4726e7c4aca5c..f597c9c59a099 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -350,7 +350,12 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str schema.clone(), ) .unwrap(), - ) as Arc; + ); + assert_ne!( + aggregate_exec_running.input_order_mode(), + &InputOrderMode::Linear, + "running aggregate should observe ordered input for group_by: {group_by:?}" + ); let 
aggregate_exec_usual = Arc::new( AggregateExec::try_new( @@ -362,7 +367,7 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str schema.clone(), ) .unwrap(), - ) as Arc; + ); let task_ctx = ctx.task_ctx(); let collected_usual = collect(aggregate_exec_usual.clone(), task_ctx.clone()) @@ -373,9 +378,6 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str .await .unwrap(); assert!(collected_running.len() > 2); - // Running should produce more chunk than the usual AggregateExec. - // Otherwise it means that we cannot generate result in running mode. - assert!(collected_running.len() > collected_usual.len()); // compare let usual_formatted = pretty_format_batches(&collected_usual).unwrap().to_string(); let running_formatted = pretty_format_batches(&collected_running) diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs index 90039e70a654e..3d36bd9fdfaff 100644 --- a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs +++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs @@ -243,6 +243,8 @@ pub(super) struct HashAggregateAccumulator { accumulator: Box, } +pub(super) type AggregateAccumulator = HashAggregateAccumulator; + /// Evaluated aggregate arguments and filter for one input batch. /// /// For example, `AVG(x + 1) FILTER (WHERE x > 0)` evaluates both `x + 1` @@ -303,7 +305,7 @@ pub(super) enum AggregateHashTableState { } impl HashAggregateAccumulator { - fn new( + pub(super) fn new( aggregate_expr: Arc, arguments: Vec>, filter: Option>, @@ -335,7 +337,10 @@ impl HashAggregateAccumulator { /// and `x > 0`. /// /// These arrays can be passed directly to [`GroupsAccumulator`] next. 
- fn evaluate_acc_args(&self, batch: &RecordBatch) -> Result { + pub(super) fn evaluate_acc_args( + &self, + batch: &RecordBatch, + ) -> Result { let arguments = self .arguments .iter() @@ -358,6 +363,10 @@ impl HashAggregateAccumulator { Ok(EvaluatedAccumulatorArgs { arguments, filter }) } + pub(super) fn size(&self) -> usize { + self.accumulator.size() + } + pub(super) fn update_batch( &mut self, values: &EvaluatedAccumulatorArgs, diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common_ordered.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common_ordered.rs new file mode 100644 index 0000000000000..9f8e116ed6b42 --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common_ordered.rs @@ -0,0 +1,214 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Common utilities for aggregate tables used in aggregations whose inputs are ordered +//! by the group keys.
+ +use std::marker::PhantomData; +use std::sync::Arc; + +use arrow::datatypes::SchemaRef; +use arrow::record_batch::RecordBatch; +use datafusion_common::Result; +use datafusion_execution::memory_pool::proxy::VecAllocExt; +use datafusion_expr::EmitTo; + +use crate::InputOrderMode; +use crate::PhysicalExpr; +use crate::aggregates::group_values::{GroupByMetrics, GroupValues, new_group_values}; +use crate::aggregates::order::GroupOrdering; +use crate::aggregates::row_hash::create_group_accumulator; +use crate::aggregates::{ + AggregateExec, AggregateMode, PhysicalGroupBy, aggregate_expressions, + evaluate_group_by, +}; + +use super::common::{AggregateAccumulator, EvaluatedAggregateBatch}; + +/// Marker for raw rows -> partial state aggregation on ordered input. +pub(in crate::aggregates) struct OrderedPartialMarker; +/// Marker for partial state -> final value aggregation on ordered input. +pub(in crate::aggregates) struct OrderedFinalMarker; + +/// Aggregate table shared by the ordered partial and final paths. +/// +/// The table consumes input batches while `GroupOrdering` tracks which groups +/// are proven complete. Completed groups can be emitted before the input stream +/// ends, which keeps memory bounded by the active ordered key range. +/// +/// # Marker Type +/// +/// `OrderedAggrMode` selects the aggregate semantics. For example, +/// `OrderedAggregateTable::<OrderedPartialMarker>::new(...)` consumes raw rows +/// and emits partial states, while +/// `OrderedAggregateTable::<OrderedFinalMarker>::new_with_input_order(...)` +/// consumes partial states and emits final values. +/// +/// Shared methods live on `impl<OrderedAggrMode>`; partial/final behavior lives on +/// marker-specific impls. +pub(in crate::aggregates) struct OrderedAggregateTable { + /// Output schema: group columns followed by aggregate state or final values. + pub(super) output_schema: SchemaRef, + + /// Grouping and accumulator-specific timing metrics.
+ pub(super) group_by_metrics: GroupByMetrics, + + /// Group keys, ordering state, and accumulator states. + pub(super) buffer: OrderedAggregateTableBuffer, + + _mode: PhantomData, +} + +/// Buffer for the ordered aggregate table's group keys and accumulator states. +/// +/// It accumulates input during aggregation and emits output rows as soon as the +/// input ordering proves those groups are complete. +/// +/// [`GroupOrdering`] tracks when and how to do early emit. +/// [`GroupValues`] stores the physical group-key layout, while +/// [`datafusion_expr::GroupsAccumulator`] stores per-group aggregate state. +pub(super) struct OrderedAggregateTableBuffer { + /// GROUP BY expressions evaluated against input batches. + pub(super) group_by: Arc, + + /// Tracks how far ordered input allows this table to drain safely. + pub(super) group_ordering: GroupOrdering, + + /// Interned group keys, in the same group-id order used by accumulators. + pub(super) group_values: Box, + + /// Scratch group id vector for the current input batch. + pub(super) group_indices: Vec, + + /// One item per aggregate expression. + /// + /// Example: `COUNT(x), SUM(y)` creates two items. Each item owns the input + /// expressions, optional filter, and accumulator state for all groups. 
+ pub(super) accumulators: Vec, +} + +/// Methods shared by all aggregate modes +impl OrderedAggregateTable { + pub(super) fn new_for_mode( + agg: &AggregateExec, + partition: usize, + input_schema: &SchemaRef, + output_schema: SchemaRef, + input_order_mode: &InputOrderMode, + aggregate_mode: &AggregateMode, + filters: Vec>>, + ) -> Result { + let group_ordering = GroupOrdering::try_new(input_order_mode)?; + let group_schema = agg.group_by.group_schema(input_schema)?; + let group_values = new_group_values(group_schema, &group_ordering)?; + let aggregate_arguments = aggregate_expressions( + &agg.aggr_expr, + aggregate_mode, + agg.group_by.num_group_exprs(), + )?; + let accumulators = agg + .aggr_expr + .iter() + .zip(aggregate_arguments) + .zip(filters) + .map(|((agg_expr, arguments), filter)| { + let accumulator = create_group_accumulator(agg_expr)?; + Ok(AggregateAccumulator::new( + Arc::clone(agg_expr), + arguments, + filter, + accumulator, + )) + }) + .collect::>()?; + + Ok(Self { + output_schema, + group_by_metrics: GroupByMetrics::new(&agg.metrics, partition), + buffer: OrderedAggregateTableBuffer { + group_by: Arc::clone(&agg.group_by), + group_ordering, + group_values, + group_indices: vec![], + accumulators, + }, + _mode: PhantomData, + }) + } + + /// Evaluates all group by keys and accumulator args. + /// + /// e.g., `select k+1, sum(v*v) from t group by (k+1)`, this function + /// evaluates `k+1`, `v*v`. 
+ pub(super) fn evaluate_batch( + &self, + batch: &RecordBatch, + ) -> Result { + let timer = self.group_by_metrics.time_calculating_group_ids.timer(); + let grouping_set_args = evaluate_group_by(&self.buffer.group_by, batch)?; + drop(timer); + + let timer = self.group_by_metrics.aggregate_arguments_time.timer(); + let accumulator_args = self + .buffer + .accumulators + .iter() + .map(|acc| acc.evaluate_acc_args(batch)) + .collect::>>()?; + drop(timer); + + Ok(EvaluatedAggregateBatch { + grouping_set_args, + accumulator_args, + }) + } + + /// Called after the input stream is exhausted and the last batch has been + /// aggregated. + /// + /// Updates the internal `GroupOrdering` so it can continue emitting until + /// the buffer is empty. + pub(in crate::aggregates) fn input_done(&mut self) { + self.buffer.group_ordering.input_done(); + } + + /// Returns `true` if no groups have been accumulated so far. + pub(in crate::aggregates) fn is_empty(&self) -> bool { + self.buffer.group_values.is_empty() + } + + /// Total memory size of all internal buffers. + pub(in crate::aggregates) fn memory_size(&self) -> usize { + self.buffer + .accumulators + .iter() + .map(|acc| acc.size()) + .sum::() + + self.buffer.group_values.size() + + self.buffer.group_ordering.size() + + self.buffer.group_indices.allocated_size() + } +} + +pub(super) fn remove_emitted_groups(group_ordering: &mut GroupOrdering, emit_to: EmitTo) { + match emit_to { + EmitTo::First(n) => group_ordering.remove_groups(n), + // `EmitTo::All` is only used after `input_done`, when all buffered groups + // are known complete and the ordering state is no longer needed.
+ EmitTo::All => {} + } +} diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/final_table.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/final_table.rs index 415694d8c2f59..7368d476a4d8f 100644 --- a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/final_table.rs +++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/final_table.rs @@ -27,7 +27,13 @@ use super::common::{ AggregateHashTable, AggregateHashTableState, FinalMarker, emit_to_for_batch_size, }; -/// Methods specific to the aggregate hash table used in the final aggregation stage. +/// Implementation specific to final aggregation, where the table stores partial +/// aggregate states and the input rows are also partial states. +/// +/// Example: `AVG(x) GROUP BY k` +/// +/// - Aggregate table stores: `k, sum(x), count(x)` +/// - Input rows: `k, sum(x), count(x)` impl AggregateHashTable { pub(in crate::aggregates) fn new( agg: &AggregateExec, diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/mod.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/mod.rs index eb152f4128896..b0f8d4e9b062f 100644 --- a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/mod.rs +++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/mod.rs @@ -16,9 +16,15 @@ // under the License. 
mod common; +mod common_ordered; mod final_table; +mod ordered_final_table; +mod ordered_partial_table; mod partial_table; pub(super) use common::{ AggregateHashTable, FinalMarker, PartialMarker, PartialSkipMarker, }; +pub(super) use common_ordered::{ + OrderedAggregateTable, OrderedFinalMarker, OrderedPartialMarker, +}; diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/ordered_final_table.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/ordered_final_table.rs new file mode 100644 index 0000000000000..91c54888c44f7 --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/ordered_final_table.rs @@ -0,0 +1,127 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Aggregate table for final aggregation when partial-state input is ordered. +//! +//! See comments in [`super::ordered_partial_table`] for details. 
+ +use std::sync::Arc; + +use arrow::datatypes::SchemaRef; +use arrow::record_batch::RecordBatch; +use datafusion_common::Result; + +use crate::InputOrderMode; +use crate::aggregates::{AggregateExec, AggregateMode}; + +use super::common_ordered::{ + OrderedAggregateTable, OrderedFinalMarker, remove_emitted_groups, +}; + +/// Implementation specific to final aggregation, where the table stores partial +/// aggregate states and the input rows are also partial states. +/// +/// Example: `AVG(x) GROUP BY k` +/// +/// - Aggregate table stores: `k, sum(x), count(x)` +/// - Input rows: `k, sum(x), count(x)` +impl OrderedAggregateTable { + pub(in crate::aggregates) fn new_with_input_order( + agg: &AggregateExec, + partition: usize, + input_schema: &SchemaRef, + output_schema: SchemaRef, + input_order_mode: &InputOrderMode, + ) -> Result { + Self::new_for_mode( + agg, + partition, + input_schema, + output_schema, + input_order_mode, + &AggregateMode::Final, + vec![None; agg.aggr_expr.len()], + ) + } + + /// Merges one partial-state input batch and updates ordering information for + /// any newly observed groups. + pub(in crate::aggregates) fn aggregate_batch( + &mut self, + batch: &RecordBatch, + ) -> Result<()> { + let evaluated_batch = self.evaluate_batch(batch)?; + // `PhysicalGroupBy::as_final()` ensures it removes grouping set when + // planning final aggregate, so it's safe to reuse here. 
+ debug_assert_eq!(evaluated_batch.grouping_set_args.len(), 1); + + for group_values in &evaluated_batch.grouping_set_args { + let starting_num_groups = self.buffer.group_values.len(); + self.buffer + .group_values + .intern(group_values, &mut self.buffer.group_indices)?; + let total_num_groups = self.buffer.group_values.len(); + if total_num_groups > starting_num_groups { + self.buffer.group_ordering.new_groups( + group_values, + &self.buffer.group_indices, + total_num_groups, + )?; + } + + let timer = self.group_by_metrics.aggregation_time.timer(); + for (acc, values) in self + .buffer + .accumulators + .iter_mut() + .zip(evaluated_batch.accumulator_args.iter()) + { + acc.merge_batch(values, &self.buffer.group_indices, total_num_groups)?; + } + drop(timer); + } + + Ok(()) + } + + /// See comments in `ordered_partial_stream::next_output_batch` + pub(in crate::aggregates) fn next_output_batch( + &mut self, + ) -> Result> { + if self.buffer.group_values.is_empty() { + return Ok(None); + } + + let Some(emit_to) = self.buffer.group_ordering.emit_to() else { + return Ok(None); + }; + + let timer = self.group_by_metrics.emitting_time.timer(); + let mut output = self.buffer.group_values.emit(emit_to)?; + remove_emitted_groups(&mut self.buffer.group_ordering, emit_to); + + for acc in &mut self.buffer.accumulators { + output.push(acc.evaluate(emit_to)?); + } + drop(timer); + + let batch = RecordBatch::try_new(Arc::clone(&self.output_schema), output)?; + debug_assert!(batch.num_rows() > 0); + + Ok(Some(batch)) + } +} diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/ordered_partial_table.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/ordered_partial_table.rs new file mode 100644 index 0000000000000..c4fa380853be7 --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/ordered_partial_table.rs @@ -0,0 +1,145 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license 
agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Aggregate table for partial aggregation when input is ordered by group keys. +//! +//! See the [`super::common_ordered`] comments for the high-level ideas. +//! +//! This operator handles input that is ordered by group keys: +//! - Fully ordered: `GROUP BY a, b`, input is `ORDER BY a, b` +//! - Partially ordered: `GROUP BY a, b`, input is `ORDER BY a` +//! +//! When a group key combination is exhausted, this table eagerly flushes the +//! completed groups to improve memory efficiency. +//! +//! The implementation is separated from other aggregate tables because this +//! execution path is likely to be optimized further in the future. + +use std::sync::Arc; + +use arrow::datatypes::SchemaRef; +use arrow::record_batch::RecordBatch; +use datafusion_common::Result; + +use crate::aggregates::{AggregateExec, AggregateMode}; + +use super::common_ordered::{ + OrderedAggregateTable, OrderedPartialMarker, remove_emitted_groups, +}; + +/// Implementation specific to partial aggregation, where the table stores +/// partial aggregate states and the input rows are raw rows. 
+/// +/// Example: `AVG(x) GROUP BY k` +/// +/// - Aggregate table stores: `k, sum(x), count(x)` +/// - Input rows: `k, x` +impl OrderedAggregateTable { + pub(in crate::aggregates) fn new( + agg: &AggregateExec, + partition: usize, + output_schema: SchemaRef, + ) -> Result { + let input_schema = agg.input().schema(); + Self::new_for_mode( + agg, + partition, + &input_schema, + output_schema, + &agg.input_order_mode, + &AggregateMode::Partial, + agg.filter_expr.iter().cloned().collect(), + ) + } + + /// Aggregates one raw input batch and updates ordering information for any + /// newly observed groups. + pub(in crate::aggregates) fn aggregate_batch( + &mut self, + batch: &RecordBatch, + ) -> Result<()> { + let evaluated_batch = self.evaluate_batch(batch)?; + + for group_values in &evaluated_batch.grouping_set_args { + let starting_num_groups = self.buffer.group_values.len(); + self.buffer + .group_values + .intern(group_values, &mut self.buffer.group_indices)?; + let total_num_groups = self.buffer.group_values.len(); + if total_num_groups > starting_num_groups { + self.buffer.group_ordering.new_groups( + group_values, + &self.buffer.group_indices, + total_num_groups, + )?; + } + + let timer = self.group_by_metrics.aggregation_time.timer(); + for (acc, values) in self + .buffer + .accumulators + .iter_mut() + .zip(evaluated_batch.accumulator_args.iter()) + { + acc.update_batch(values, &self.buffer.group_indices, total_num_groups)?; + } + drop(timer); + } + + Ok(()) + } + + /// Emits the next batch of partial state rows for groups proven complete by + /// the input ordering. + /// + /// For example, when the query is `GROUP BY a` and the input is ordered by + /// `a`, seeing a latest input row with `a = 3` means all groups with `a < 3` + /// are complete and safe to emit. + /// + /// Key steps: + /// 1. Ask `group_ordering` to decide how many groups can be emitted eagerly. + /// 2. 
Remove the emitted groups from `group_ordering`, `GroupValues`, and + /// all `GroupsAccumulator`s. + /// + /// This may output small batches. Avoiding tiny batches is left to future + /// ordered-aggregation optimizations. + pub(in crate::aggregates) fn next_output_batch( + &mut self, + ) -> Result> { + if self.buffer.group_values.is_empty() { + return Ok(None); + } + + let Some(emit_to) = self.buffer.group_ordering.emit_to() else { + return Ok(None); + }; + + let timer = self.group_by_metrics.emitting_time.timer(); + let mut output = self.buffer.group_values.emit(emit_to)?; + remove_emitted_groups(&mut self.buffer.group_ordering, emit_to); + + for acc in &mut self.buffer.accumulators { + output.extend(acc.state(emit_to)?); + } + drop(timer); + + let batch = RecordBatch::try_new(Arc::clone(&self.output_schema), output)?; + debug_assert!(batch.num_rows() > 0); + + Ok(Some(batch)) + } +} diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs index fd3cf801cfe57..81e959a3d9bbc 100644 --- a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs +++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs @@ -34,7 +34,13 @@ use super::common::{ emit_to_for_batch_size, }; -/// Methods specific to the aggregate hash table used in the partial aggregation stage. +/// Implementation specific to partial aggregation, where the table stores +/// partial aggregate states and the input rows are raw rows. 
+/// +/// Example: `AVG(x) GROUP BY k` +/// +/// - Aggregate table stores: `k, sum(x), count(x)` +/// - Input rows: `k, x` impl AggregateHashTable { pub(in crate::aggregates) fn new( agg: &AggregateExec, diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 08468bffc0dd9..8fba0e6bd9fd5 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -24,6 +24,8 @@ use super::{DisplayAs, ExecutionPlanProperties, PlanProperties}; use crate::aggregates::{ hash_aggregate::{FinalHashAggregateStream, PartialHashAggregateStream}, no_grouping::AggregateStream, + ordered_final_stream::OrderedFinalAggregateStream, + ordered_partial_stream::OrderedPartialAggregateStream, row_hash::GroupedHashAggregateStream, topk_stream::GroupedTopKAggregateStream, }; @@ -77,6 +79,8 @@ pub mod group_values; mod hash_aggregate; mod no_grouping; pub mod order; +mod ordered_final_stream; +mod ordered_partial_stream; mod row_hash; mod skip_partial; mod topk; @@ -530,10 +534,16 @@ enum StreamType { /// Final stage of the hash aggregation /// Input output scheme: partial state -> final result FinalHash(FinalHashAggregateStream), + /// Partial stage of aggregation for ordered input. + OrderedPartialAggregate(OrderedPartialAggregateStream), + /// Final stage of aggregation for ordered input. 
+ OrderedFinalAggregate(OrderedFinalAggregateStream), /// Hash aggregation reused for multiple stages /// /// Note this is being incrementally migrated to dedicated streams like - /// [`StreamType::PartialHash`] and [`StreamType::FinalHash`] + /// [`StreamType::PartialHash`], [`StreamType::FinalHash`], + /// [`StreamType::OrderedPartialAggregate`], and + /// [`StreamType::OrderedFinalAggregate`] /// /// See issue for details: GroupedHash(GroupedHashAggregateStream), @@ -551,6 +561,8 @@ impl From for SendableRecordBatchStream { StreamType::AggregateStream(stream) => Box::pin(stream), StreamType::PartialHash(stream) => Box::pin(stream), StreamType::FinalHash(stream) => Box::pin(stream), + StreamType::OrderedPartialAggregate(stream) => Box::pin(stream), + StreamType::OrderedFinalAggregate(stream) => Box::pin(stream), StreamType::GroupedHash(stream) => Box::pin(stream), StreamType::GroupedPriorityQueue(stream) => Box::pin(stream), } @@ -1015,12 +1027,24 @@ impl AggregateExec { .execution .enable_migration_aggregate { + if self.should_use_ordered_partial_aggregate_stream(context) { + return Ok(StreamType::OrderedPartialAggregate( + OrderedPartialAggregateStream::new(self, context, partition)?, + )); + } + if self.should_use_partial_hash_stream(context) { return Ok(StreamType::PartialHash(PartialHashAggregateStream::new( self, context, partition, )?)); } + if self.should_use_ordered_final_aggregate_stream(context) { + return Ok(StreamType::OrderedFinalAggregate( + OrderedFinalAggregateStream::new(self, context, partition)?, + )); + } + if self.should_use_final_hash_stream(context) { return Ok(StreamType::FinalHash(FinalHashAggregateStream::new( self, context, partition, @@ -1047,6 +1071,19 @@ impl AggregateExec { && self.limit_options_supported_by_hash_stream() } + fn should_use_ordered_partial_aggregate_stream(&self, context: &TaskContext) -> bool { + // TODO: implement memory-limited path and remove this limitation + if matches!(context.memory_pool().memory_limit(), 
MemoryLimit::Finite(_)) { + return false; + } + + self.mode == AggregateMode::Partial + && self.input_order_mode != InputOrderMode::Linear + && !self.group_by.is_true_no_grouping() + && self.group_by.is_single() + && self.limit_options_supported_by_hash_stream() + } + fn should_use_final_hash_stream(&self, context: &TaskContext) -> bool { // TODO: implement memory-limited path and remove this limitation if matches!(context.memory_pool().memory_limit(), MemoryLimit::Finite(_)) { @@ -1062,6 +1099,21 @@ impl AggregateExec { && self.group_by.is_single() } + fn should_use_ordered_final_aggregate_stream(&self, context: &TaskContext) -> bool { + // TODO: implement memory-limited path and remove this limitation + if matches!(context.memory_pool().memory_limit(), MemoryLimit::Finite(_)) { + return false; + } + + matches!( + self.mode, + AggregateMode::Final | AggregateMode::FinalPartitioned + ) && self.limit_options_supported_by_hash_stream() + && self.input_order_mode != InputOrderMode::Linear + && !self.group_by.is_true_no_grouping() + && self.group_by.is_single() + } + /// See comments in `PartialHashAggregateStream` limit optimization section fn limit_options_supported_by_hash_stream(&self) -> bool { self.limit_options.is_none() || self.is_unordered_unfiltered_group_by_distinct() @@ -2454,7 +2506,7 @@ mod tests { use crate::projection::ProjectionExec; use datafusion_physical_expr::projection::ProjectionExpr; - use futures::{FutureExt, Stream}; + use futures::{FutureExt, Stream, StreamExt}; use insta::{allow_duplicates, assert_snapshot}; // Generate a schema which consists of 5 columns (a, b, c, d, e) @@ -2564,6 +2616,16 @@ mod tests { Arc::new(task_ctx) } + fn new_migrated_hash_ctx(batch_size: usize) -> Arc { + Arc::new( + TaskContext::default().with_session_config( + SessionConfig::new() + .with_batch_size(batch_size) + .set_bool("datafusion.execution.enable_migration_aggregate", true), + ), + ) + } + async fn check_grouping_sets( input: Arc, spill: bool, @@ -3319,6 
+3381,236 @@ mod tests { Ok(()) } + /// Ensures for ordered input, `OrderedPartialAggregateStream` is used. + #[tokio::test] + async fn ordered_partial_aggregate_planning() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("sort_col", DataType::Int32, false), + Field::new("group_col", DataType::Int32, false), + Field::new("value_col", DataType::Int64, false), + ])); + + let input_batches = vec![ + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 1, 1])), + Arc::new(Int32Array::from(vec![10, 11, 10])), + Arc::new(Int64Array::from(vec![1, 1, 1])), + ], + )?, + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![2, 2])), + Arc::new(Int32Array::from(vec![20, 21])), + Arc::new(Int64Array::from(vec![1, 1])), + ], + )?, + ]; + let ordering = LexOrdering::new([PhysicalSortExpr::new_default(Arc::new( + Column::new("sort_col", 0), + ))]) + .unwrap(); + let input = TestMemoryExec::try_new(&[input_batches], Arc::clone(&schema), None)?
+ .try_with_sort_information(vec![ordering])?; + let input = Arc::new(TestMemoryExec::update_cache(&Arc::new(input))); + + let group_by = PhysicalGroupBy::new_single(vec![ + (col("sort_col", &schema)?, "sort_col".to_string()), + (col("group_col", &schema)?, "group_col".to_string()), + ]); + let aggr_expr = vec![Arc::new( + AggregateExprBuilder::new(count_udaf(), vec![col("value_col", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("COUNT(value_col)") + .build()?, + )]; + let aggregate = AggregateExec::try_new( + AggregateMode::Partial, + group_by, + aggr_expr, + vec![None], + input, + Arc::clone(&schema), + )?; + assert!(matches!( + aggregate.input_order_mode(), + InputOrderMode::PartiallySorted(_) + )); + + let task_ctx = new_migrated_hash_ctx(2); + let stream = aggregate.execute_typed(0, &task_ctx)?; + assert!(matches!(stream, StreamType::OrderedPartialAggregate(_))); + + let stream: SendableRecordBatchStream = stream.into(); + let output = collect(stream).await?; + assert_snapshot!(batches_to_sort_string(&output), @r" ++----------+-----------+-------------------------+ +| sort_col | group_col | COUNT(value_col)[count] | ++----------+-----------+-------------------------+ +| 1 | 10 | 2 | +| 1 | 11 | 1 | +| 2 | 20 | 1 | +| 2 | 21 | 1 | ++----------+-----------+-------------------------+ +"); + + Ok(()) + } + + /// Ensures for ordered input, `OrderedFinalAggregateStream` is used. 
+ #[tokio::test] + async fn ordered_final_aggregate_planning() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("key", DataType::Int32, false), + Field::new("value", DataType::Int64, false), + ])); + let group_by = + PhysicalGroupBy::new_single(vec![(col("key", &schema)?, "key".to_string())]); + let aggr_expr = vec![Arc::new( + AggregateExprBuilder::new(count_udaf(), vec![col("value", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("COUNT(value)") + .build()?, + )]; + + let empty_input = + TestMemoryExec::try_new_exec(&[vec![]], Arc::clone(&schema), None)?; + let partial_aggregate = AggregateExec::try_new( + AggregateMode::Partial, + group_by.clone(), + aggr_expr.clone(), + vec![None], + empty_input, + Arc::clone(&schema), + )?; + let partial_schema = partial_aggregate.schema(); + let partial_state_batch = RecordBatch::try_new( + Arc::clone(&partial_schema), + vec![ + Arc::new(Int32Array::from(vec![1, 1, 2, 3])), + Arc::new(Int64Array::from(vec![2, 3, 5, 7])), + ], + )?; + let ordering = LexOrdering::new([PhysicalSortExpr::new_default(Arc::new( + Column::new("key", 0), + ))]) + .unwrap(); + let final_input = + TestMemoryExec::try_new(&[vec![partial_state_batch]], partial_schema, None)? 
+ .try_with_sort_information(vec![ordering])?; + let final_input = Arc::new(TestMemoryExec::update_cache(&Arc::new(final_input))); + + let final_aggregate = AggregateExec::try_new( + AggregateMode::Final, + group_by.as_final(), + aggr_expr, + vec![None], + final_input, + Arc::clone(&schema), + )?; + assert_eq!(final_aggregate.input_order_mode(), &InputOrderMode::Sorted); + + let task_ctx = new_migrated_hash_ctx(2); + let stream = final_aggregate.execute_typed(0, &task_ctx)?; + assert!(matches!(stream, StreamType::OrderedFinalAggregate(_))); + + let stream: SendableRecordBatchStream = stream.into(); + let output = collect(stream).await?; + assert_snapshot!(batches_to_sort_string(&output), @r" ++-----+--------------+ +| key | COUNT(value) | ++-----+--------------+ +| 1 | 5 | +| 2 | 5 | +| 3 | 7 | ++-----+--------------+ +"); + + Ok(()) + } + + #[tokio::test] + async fn ordered_partial_aggregate_partially_sorted_no_emit_panic() -> Result<()> { + // Reproducer for #20445: emitting from PartiallySorted input must not + // drain more groups than the completed sort boundary allows. + let schema = Arc::new(Schema::new(vec![ + Field::new("sort_col", DataType::Int32, false), + Field::new("group_col", DataType::Int32, false), + Field::new("value_col", DataType::Int64, false), + ])); + + // All rows share sort_col=1, so there is no completed sort boundary + // inside this batch even though there are many distinct groups. + let n = 256; + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1; n])), + Arc::new(Int32Array::from((0..n as i32).collect::>())), + Arc::new(Int64Array::from(vec![1; n])), + ], + )?; + + let ordering = LexOrdering::new([PhysicalSortExpr::new_default(Arc::new( + Column::new("sort_col", 0), + ))]) + .unwrap(); + let input = TestMemoryExec::try_new(&[vec![batch]], Arc::clone(&schema), None)? 
+ .try_with_sort_information(vec![ordering])?; + let input = Arc::new(TestMemoryExec::update_cache(&Arc::new(input))); + + let aggregate = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::new_single(vec![ + (col("sort_col", &schema)?, "sort_col".to_string()), + (col("group_col", &schema)?, "group_col".to_string()), + ]), + vec![Arc::new( + AggregateExprBuilder::new(count_udaf(), vec![col("value_col", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("count_value") + .build()?, + )], + vec![None], + input, + Arc::clone(&schema), + )?; + assert!(matches!( + aggregate.input_order_mode(), + InputOrderMode::PartiallySorted(_) + )); + + let runtime = RuntimeEnvBuilder::default() + .with_memory_limit(4096, 1.0) + .build_arc()?; + let session_config = SessionConfig::new().with_batch_size(128).set( + "datafusion.execution.skip_partial_aggregation_probe_rows_threshold", + &ScalarValue::UInt64(Some(u64::MAX)), + ); + let task_ctx = Arc::new( + TaskContext::default() + .with_runtime(runtime) + .with_session_config(session_config), + ); + + let mut stream: SendableRecordBatchStream = Box::pin( + OrderedPartialAggregateStream::new(&aggregate, &task_ctx, 0)?, + ); + + while let Some(result) = stream.next().await { + if let Err(e) = result { + if e.to_string().contains("Resources exhausted") { + break; + } + return Err(e); + } + } + + Ok(()) + } + #[tokio::test] async fn test_drop_cancel_without_groups() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); diff --git a/datafusion/physical-plan/src/aggregates/ordered_final_stream.rs b/datafusion/physical-plan/src/aggregates/ordered_final_stream.rs new file mode 100644 index 0000000000000..d76b020a8e2ac --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/ordered_final_stream.rs @@ -0,0 +1,348 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Final aggregate stream for ordered partial-state input. + +use std::ops::ControlFlow; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use arrow::datatypes::SchemaRef; +use arrow::record_batch::RecordBatch; +use datafusion_common::Result; +use datafusion_execution::TaskContext; +use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; +use futures::stream::{Stream, StreamExt}; + +use super::AggregateExec; +use super::aggregate_hash_table::{OrderedAggregateTable, OrderedFinalMarker}; +use crate::aggregates::AggregateMode; +use crate::metrics::{BaselineMetrics, RecordOutput, SpillMetrics}; +use crate::stream::EmptyRecordBatchStream; +use crate::{InputOrderMode, RecordBatchStream, SendableRecordBatchStream}; + +/// Final aggregate stream for `InputOrderMode::Sorted` and +/// `InputOrderMode::PartiallySorted`. +/// +/// See comments at [`super::ordered_partial_stream`] for details. +pub(crate) struct OrderedFinalAggregateStream { + schema: SchemaRef, + input: SendableRecordBatchStream, + reservation: MemoryReservation, + baseline_metrics: BaselineMetrics, + state: Option, +} + +/// See comments at `poll_next()` for details. 
+enum OrderedFinalAggregateState { + ReadingInput { + table: OrderedAggregateTable, + }, + DrainingFinal { + table: OrderedAggregateTable, + }, + Done, +} + +type OrderedFinalAggregatePoll = Poll>>; +type OrderedFinalAggregateStateTransition = ControlFlow< + (OrderedFinalAggregatePoll, OrderedFinalAggregateState), + OrderedFinalAggregateState, +>; + +impl OrderedFinalAggregateStream { + pub fn new( + agg: &AggregateExec, + context: &Arc, + partition: usize, + ) -> Result { + debug_assert!(matches!( + agg.mode, + AggregateMode::Final | AggregateMode::FinalPartitioned + )); + debug_assert_ne!(agg.input_order_mode, InputOrderMode::Linear); + + let input = agg.input.execute(partition, Arc::clone(context))?; + Self::new_with_input(agg, context, partition, input, &agg.input_order_mode) + } + + pub(in crate::aggregates) fn new_with_input( + agg: &AggregateExec, + context: &Arc, + partition: usize, + input: SendableRecordBatchStream, + input_order_mode: &InputOrderMode, + ) -> Result { + debug_assert!(matches!( + agg.mode, + AggregateMode::Final | AggregateMode::FinalPartitioned + )); + debug_assert_ne!(*input_order_mode, InputOrderMode::Linear); + + let schema = Arc::clone(&agg.schema); + let input_schema = input.schema(); + let baseline_metrics = BaselineMetrics::new(&agg.metrics, partition); + + // Preserve the existing aggregate metric surface for this plan node. 
+ let _spill_metrics = SpillMetrics::new(&agg.metrics, partition); + + let table = OrderedAggregateTable::::new_with_input_order( + agg, + partition, + &input_schema, + Arc::clone(&schema), + input_order_mode, + )?; + let reservation = + MemoryConsumer::new(format!("OrderedFinalAggregateStream[{partition}]")) + .register(context.memory_pool()); + + Ok(Self { + schema, + input, + reservation, + baseline_metrics, + state: Some(OrderedFinalAggregateState::ReadingInput { table }), + }) + } + + fn close_input(&mut self) { + let input_schema = self.input.schema(); + self.input = Box::pin(EmptyRecordBatchStream::new(input_schema)); + } + + /// Consumes one ordered partial-state input batch, then immediately emits + /// finalized groups if the ordering proves any group is ready. + /// + /// See comments at `poll_next()` for details. + /// + /// Returns the next operator state with control flow decision. + fn handle_reading_input( + &mut self, + cx: &mut Context<'_>, + original_state: OrderedFinalAggregateState, + ) -> OrderedFinalAggregateStateTransition { + let OrderedFinalAggregateState::ReadingInput { mut table } = original_state + else { + unreachable!("expected reading input state") + }; + + match self.input.poll_next_unpin(cx) { + Poll::Pending => ControlFlow::Break(( + Poll::Pending, + OrderedFinalAggregateState::ReadingInput { table }, + )), + Poll::Ready(Some(Ok(batch))) => { + let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); + let timer = elapsed_compute.timer(); + let result = table.aggregate_batch(&batch); + timer.done(); + + if let Err(e) = result { + return ControlFlow::Break(( + Poll::Ready(Some(Err(e))), + OrderedFinalAggregateState::ReadingInput { table }, + )); + } + + let timer = elapsed_compute.timer(); + let result = table.next_output_batch(); + timer.done(); + + match result { + // Some finalized groups can be emitted. Yield them, then + // continue aggregating input in the current state. 
+ Ok(Some(batch)) => { + let next_state = + OrderedFinalAggregateState::ReadingInput { table }; + self.resize_reservation_for_state(&next_state); + + ControlFlow::Break(( + Poll::Ready(Some(Ok( + batch.record_output(&self.baseline_metrics) + ))), + next_state, + )) + } + Ok(None) => { + // Ordered variant doesn't support memory-limited + // execution, so it errors when memory reservation fails. + if let Err(e) = self.reservation.try_resize(table.memory_size()) { + return ControlFlow::Break(( + Poll::Ready(Some(Err(e))), + OrderedFinalAggregateState::ReadingInput { table }, + )); + } + + // Can't do early emit, continue aggregating. + ControlFlow::Continue(OrderedFinalAggregateState::ReadingInput { + table, + }) + } + Err(e) => ControlFlow::Break(( + Poll::Ready(Some(Err(e))), + OrderedFinalAggregateState::ReadingInput { table }, + )), + } + } + Poll::Ready(Some(Err(e))) => ControlFlow::Break(( + Poll::Ready(Some(Err(e))), + OrderedFinalAggregateState::ReadingInput { table }, + )), + Poll::Ready(None) => { + self.close_input(); + table.input_done(); + ControlFlow::Continue(OrderedFinalAggregateState::DrainingFinal { table }) + } + } + } + + /// Emits one batch after input is exhausted. + /// + /// `table.input_done()` has already made every remaining group safe to emit, + /// so this state keeps draining until the table is empty. + /// + /// See comments at `poll_next()` for details. + /// + /// Returns the next operator state with control flow decision. 
+ fn handle_draining_final( + &mut self, + original_state: OrderedFinalAggregateState, + ) -> OrderedFinalAggregateStateTransition { + let OrderedFinalAggregateState::DrainingFinal { table } = original_state else { + unreachable!("expected draining final state") + }; + + let mut table = table; + let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); + let timer = elapsed_compute.timer(); + let result = table.next_output_batch(); + timer.done(); + + match result { + Ok(Some(batch)) => { + let next_state = if table.is_empty() { + OrderedFinalAggregateState::Done + } else { + OrderedFinalAggregateState::DrainingFinal { table } + }; + self.resize_reservation_for_state(&next_state); + + ControlFlow::Break(( + Poll::Ready(Some(Ok(batch.record_output(&self.baseline_metrics)))), + next_state, + )) + } + Err(e) => ControlFlow::Break(( + Poll::Ready(Some(Err(e))), + OrderedFinalAggregateState::DrainingFinal { table }, + )), + Ok(None) => { + let next_state = OrderedFinalAggregateState::Done; + self.resize_reservation_for_state(&next_state); + ControlFlow::Continue(next_state) + } + } + } + + fn resize_reservation_for_state(&mut self, state: &OrderedFinalAggregateState) { + let new_size = match state { + OrderedFinalAggregateState::ReadingInput { table } + | OrderedFinalAggregateState::DrainingFinal { table } => table.memory_size(), + OrderedFinalAggregateState::Done => 0, + }; + let _ = self.reservation.try_resize(new_size); + } +} + +impl Stream for OrderedFinalAggregateStream { + type Item = Result; + + /// Entry point for the ordered final aggregate state machine. + /// + /// See comments in [`OrderedFinalAggregateStream`] for high-level ideas. + /// + /// State transition graph: + /// + /// ```text + /// (start) + /// -> ReadingInput + /// The stream starts by polling ordered partial-state input and merging + /// those states into the ordered final aggregate table. + /// + /// ReadingInput + /// -> ReadingInput + /// Merge one input batch. 
If the ordering proves some groups are + /// complete, yield one final aggregate batch immediately, then continue + /// reading input. Otherwise continue directly with the next input batch. + /// -> DrainingFinal + /// Input was exhausted. Mark the table input as done so every remaining + /// group is safe to emit. + /// + /// DrainingFinal + /// -> DrainingFinal + /// One remaining final aggregate batch was yielded; repeat to continue + /// draining the table. + /// -> Done + /// All remaining groups were emitted. + /// + /// Done + /// -> (end) + /// ``` + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + loop { + let cur_state = self + .state + .take() + .expect("OrderedFinalAggregateStream state should not be None"); + + let next_state = match cur_state { + state @ OrderedFinalAggregateState::ReadingInput { .. } => { + self.handle_reading_input(cx, state) + } + state @ OrderedFinalAggregateState::DrainingFinal { .. } => { + self.handle_draining_final(state) + } + state @ OrderedFinalAggregateState::Done => { + let _ = self.reservation.try_resize(0); + self.state = Some(state); + return Poll::Ready(None); + } + }; + + match next_state { + ControlFlow::Continue(next_state) => { + self.state = Some(next_state); + continue; + } + ControlFlow::Break((poll, next_state)) => { + self.state = Some(next_state); + return poll; + } + } + } + } +} + +impl RecordBatchStream for OrderedFinalAggregateStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } +} diff --git a/datafusion/physical-plan/src/aggregates/ordered_partial_stream.rs b/datafusion/physical-plan/src/aggregates/ordered_partial_stream.rs new file mode 100644 index 0000000000000..dc8b2a3668661 --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/ordered_partial_stream.rs @@ -0,0 +1,380 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Partial aggregate stream for ordered group input. + +use std::ops::ControlFlow; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use arrow::datatypes::SchemaRef; +use arrow::record_batch::RecordBatch; +use datafusion_common::Result; +use datafusion_execution::TaskContext; +use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; +use futures::stream::{Stream, StreamExt}; + +use super::AggregateExec; +use super::aggregate_hash_table::{OrderedAggregateTable, OrderedPartialMarker}; +use crate::aggregates::AggregateMode; +use crate::metrics::{BaselineMetrics, MetricBuilder, RecordOutput, SpillMetrics}; +use crate::stream::EmptyRecordBatchStream; +use crate::{InputOrderMode, RecordBatchStream, SendableRecordBatchStream, metrics}; + +/// Partial aggregate stream for `InputOrderMode::Sorted` and +/// `InputOrderMode::PartiallySorted`. 
+/// +/// # Example +/// +/// SELECT k, AVG(v) FROM t GROUP BY k; +/// +/// If the input is ordered by `k`, the aggregate can use ordered partial and +/// final stages: +/// +/// ## Plan +/// AggregateExec(stage=final, ordered) +/// -- RepartitionExec(hash(k), preserves_order=true) +/// ---- AggregateExec(stage=partial, ordered) +/// +/// ## Partial Stage Behavior +/// Input: raw rows +/// Output: partial states for all groups (for example, `AVG(x)` emits `SUM(x)` +/// and `COUNT(x)`) +/// +/// ## Final Stage Behavior +/// Input: partial states +/// Output: results for all groups (for example, `AVG(x)` calculated from the +/// state) +/// +/// # Order-based Optimization +/// +/// For the aggregation work, the hash aggregation implementation is reused. +/// +/// After each input batch, check whether any groups can be emitted eagerly to +/// improve memory efficiency. For example, if the last group key seen is +/// `k = 100`, it is safe to emit all groups with keys less than 100 because the +/// input is ordered. +/// +/// ## Implementation Note +/// +/// This is intentionally kept simple and closely maps to +/// `GroupedHashAggregateStream` to finish the refactor sooner. +/// +/// See issue for details: +/// +/// More applicable optimizations are left to future work. +pub(crate) struct OrderedPartialAggregateStream { + schema: SchemaRef, + input: SendableRecordBatchStream, + reservation: MemoryReservation, + baseline_metrics: BaselineMetrics, + reduction_factor: metrics::RatioMetrics, + state: Option, +} + +/// See comments at `poll_next()` for details. 
+enum OrderedPartialAggregateState { + ReadingInput { + table: OrderedAggregateTable, + }, + DrainingFinal { + table: OrderedAggregateTable, + }, + Done, +} + +type OrderedPartialAggregatePoll = Poll>>; +type OrderedPartialAggregateStateTransition = ControlFlow< + (OrderedPartialAggregatePoll, OrderedPartialAggregateState), + OrderedPartialAggregateState, +>; + +impl OrderedPartialAggregateStream { + pub fn new( + agg: &AggregateExec, + context: &Arc, + partition: usize, + ) -> Result { + debug_assert_eq!(agg.mode, AggregateMode::Partial); + debug_assert_ne!(agg.input_order_mode, InputOrderMode::Linear); + + let schema = Arc::clone(&agg.schema); + let input = agg.input.execute(partition, Arc::clone(context))?; + let baseline_metrics = BaselineMetrics::new(&agg.metrics, partition); + + // Preserve the existing aggregate metric surface for this plan node. + let _spill_metrics = SpillMetrics::new(&agg.metrics, partition); + let reduction_factor = MetricBuilder::new(&agg.metrics) + .with_type(metrics::MetricType::Summary) + .ratio_metrics("reduction_factor", partition); + + let table = OrderedAggregateTable::::new( + agg, + partition, + Arc::clone(&schema), + )?; + let reservation = + MemoryConsumer::new(format!("OrderedPartialAggregateStream[{partition}]")) + .register(context.memory_pool()); + + Ok(Self { + schema, + input, + reservation, + baseline_metrics, + reduction_factor, + state: Some(OrderedPartialAggregateState::ReadingInput { table }), + }) + } + + fn close_input(&mut self) { + let input_schema = self.input.schema(); + self.input = Box::pin(EmptyRecordBatchStream::new(input_schema)); + } + + /// Consumes one ordered input batch, then immediately emits completed groups + /// if the ordering proves any group is ready. + /// + /// See comments at `poll_next()` for details. + /// + /// Returns the next operator state with control flow decision. 
+ fn handle_reading_input( + &mut self, + cx: &mut Context<'_>, + original_state: OrderedPartialAggregateState, + ) -> OrderedPartialAggregateStateTransition { + let OrderedPartialAggregateState::ReadingInput { mut table } = original_state + else { + unreachable!("expected reading input state") + }; + + match self.input.poll_next_unpin(cx) { + Poll::Pending => ControlFlow::Break(( + Poll::Pending, + OrderedPartialAggregateState::ReadingInput { table }, + )), + Poll::Ready(Some(Ok(batch))) => { + let input_rows = batch.num_rows(); + self.reduction_factor.add_total(input_rows); + + let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); + let timer = elapsed_compute.timer(); + let result = table.aggregate_batch(&batch); + timer.done(); + + if let Err(e) = result { + return ControlFlow::Break(( + Poll::Ready(Some(Err(e))), + OrderedPartialAggregateState::ReadingInput { table }, + )); + } + + let timer = elapsed_compute.timer(); + let result = table.next_output_batch(); + timer.done(); + + match result { + // Some completed groups can be emitted: emit them, then + // continue aggregating input (looping in the current + // state). + Ok(Some(batch)) => { + self.reduction_factor.add_part(batch.num_rows()); + let next_state = + OrderedPartialAggregateState::ReadingInput { table }; + self.resize_reservation_for_state(&next_state); + + ControlFlow::Break(( + Poll::Ready(Some(Ok( + batch.record_output(&self.baseline_metrics) + ))), + next_state, + )) + } + Ok(None) => { + // The ordered variant doesn't support memory-limited execution, + // so it errors when the memory reservation fails. + if let Err(e) = self.reservation.try_resize(table.memory_size()) { + return ControlFlow::Break(( + Poll::Ready(Some(Err(e))), + OrderedPartialAggregateState::ReadingInput { table }, + )); + } + + // Can't do early emit, continue aggregating. 
+ ControlFlow::Continue( + OrderedPartialAggregateState::ReadingInput { table }, + ) + } + Err(e) => ControlFlow::Break(( + Poll::Ready(Some(Err(e))), + OrderedPartialAggregateState::ReadingInput { table }, + )), + } + } + Poll::Ready(Some(Err(e))) => ControlFlow::Break(( + Poll::Ready(Some(Err(e))), + OrderedPartialAggregateState::ReadingInput { table }, + )), + // Input is exhausted; move to the final draining stage. + Poll::Ready(None) => { + self.close_input(); + table.input_done(); + ControlFlow::Continue(OrderedPartialAggregateState::DrainingFinal { + table, + }) + } + } + } + + /// Emits one batch after input is exhausted. + /// + /// `table.input_done()` has already made every remaining group safe to emit, + /// so this state keeps draining until the table is empty. + /// + /// See comments at `poll_next()` for details. + /// + /// Returns the next operator state together with the control-flow decision. + fn handle_draining_final( + &mut self, + original_state: OrderedPartialAggregateState, + ) -> OrderedPartialAggregateStateTransition { + let OrderedPartialAggregateState::DrainingFinal { table } = original_state else { + unreachable!("expected draining final state") + }; + + let mut table = table; + let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); + let timer = elapsed_compute.timer(); + let result = table.next_output_batch(); + timer.done(); + + match result { + Ok(Some(batch)) => { + self.reduction_factor.add_part(batch.num_rows()); + let next_state = if table.is_empty() { + OrderedPartialAggregateState::Done + } else { + OrderedPartialAggregateState::DrainingFinal { table } + }; + self.resize_reservation_for_state(&next_state); + + ControlFlow::Break(( + Poll::Ready(Some(Ok(batch.record_output(&self.baseline_metrics)))), + next_state, + )) + } + Err(e) => ControlFlow::Break(( + Poll::Ready(Some(Err(e))), + OrderedPartialAggregateState::DrainingFinal { table }, + )), + Ok(None) => { + let next_state = OrderedPartialAggregateState::Done; + 
self.resize_reservation_for_state(&next_state); + ControlFlow::Continue(next_state) + } + } + } + + fn resize_reservation_for_state(&mut self, state: &OrderedPartialAggregateState) { + let new_size = match state { + OrderedPartialAggregateState::ReadingInput { table } + | OrderedPartialAggregateState::DrainingFinal { table } => { + table.memory_size() + } + OrderedPartialAggregateState::Done => 0, + }; + let _ = self.reservation.try_resize(new_size); + } +} + +impl Stream for OrderedPartialAggregateStream { + type Item = Result; + + /// Entry point for the ordered partial aggregate state machine. + /// + /// See comments in [`OrderedPartialAggregateStream`] for high-level ideas. + /// + /// State transition graph: + /// + /// ```text + /// (start) + /// -> ReadingInput + /// The stream starts by polling ordered input and aggregating batches + /// into the ordered partial aggregate table. + /// + /// ReadingInput + /// -> ReadingInput + /// Aggregate one input batch. If the ordering proves some groups are + /// complete, yield one partial-state batch immediately, then continue + /// reading input. Otherwise continue directly with the next input batch. + /// -> DrainingFinal + /// Input was exhausted. Mark the table input as done so every remaining + /// group is safe to emit. + /// + /// DrainingFinal + /// -> DrainingFinal + /// One remaining partial-state batch was yielded; repeat to continue + /// draining the table. + /// -> Done + /// All remaining groups were emitted. + /// + /// Done + /// -> (end) + /// ``` + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + loop { + let cur_state = self + .state + .take() + .expect("OrderedPartialAggregateStream state should not be None"); + + let next_state = match cur_state { + state @ OrderedPartialAggregateState::ReadingInput { .. } => { + self.handle_reading_input(cx, state) + } + state @ OrderedPartialAggregateState::DrainingFinal { .. 
} => { + self.handle_draining_final(state) + } + state @ OrderedPartialAggregateState::Done => { + let _ = self.reservation.try_resize(0); + self.state = Some(state); + return Poll::Ready(None); + } + }; + + match next_state { + ControlFlow::Continue(next_state) => { + self.state = Some(next_state); + continue; + } + ControlFlow::Break((poll, next_state)) => { + self.state = Some(next_state); + return poll; + } + } + } + } +} + +impl RecordBatchStream for OrderedPartialAggregateStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } +} diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index d46faf9acc14a..acd5a96d88ab9 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -1512,9 +1512,8 @@ mod tests { Ok(()) } - // TODO: migrate to PartialHashAggregateStream when it supports - // InputOrderMode::PartiallySorted; kept here for the legacy - // GroupedHashAggregateStream implementation. + // Migrated to OrderedPartialAggregateStream coverage in aggregates/mod.rs; + // kept here for the legacy GroupedHashAggregateStream implementation. 
#[tokio::test] async fn test_emit_early_with_partially_sorted() -> Result<()> { // Reproducer for #20445: EmitEarly with PartiallySorted panics in From ce800d5725f435a7f66366dfd472945daeff7917 Mon Sep 17 00:00:00 2001 From: Yongting You <2010youy01@gmail.com> Date: Mon, 29 Jun 2026 16:36:20 +0800 Subject: [PATCH 2/3] remove memory limit check, ordered case uses bounded memory --- datafusion/physical-plan/src/aggregates/mod.rs | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 8fba0e6bd9fd5..2bc1457bdc23e 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -1072,11 +1072,6 @@ impl AggregateExec { } fn should_use_ordered_partial_aggregate_stream(&self, context: &TaskContext) -> bool { - // TODO: implement memory-limited path and remove this limitation - if matches!(context.memory_pool().memory_limit(), MemoryLimit::Finite(_)) { - return false; - } - self.mode == AggregateMode::Partial && self.input_order_mode != InputOrderMode::Linear && !self.group_by.is_true_no_grouping() @@ -1100,11 +1095,6 @@ impl AggregateExec { } fn should_use_ordered_final_aggregate_stream(&self, context: &TaskContext) -> bool { - // TODO: implement memory-limited path and remove this limitation - if matches!(context.memory_pool().memory_limit(), MemoryLimit::Finite(_)) { - return false; - } - matches!( self.mode, AggregateMode::Final | AggregateMode::FinalPartitioned From 59864f2b751ce94ec786ca7104411ca12a95b044 Mon Sep 17 00:00:00 2001 From: Yongting You <2010youy01@gmail.com> Date: Mon, 29 Jun 2026 17:07:35 +0800 Subject: [PATCH 3/3] fix clippy --- datafusion/physical-plan/src/aggregates/mod.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 19858b9db1161..7c50a7af56bb0 100644 --- 
a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -1034,7 +1034,7 @@ impl AggregateExec { .execution .enable_migration_aggregate { - if self.should_use_ordered_partial_aggregate_stream(context) { + if self.should_use_ordered_partial_aggregate_stream() { return Ok(StreamType::OrderedPartialAggregate( OrderedPartialAggregateStream::new(self, context, partition)?, )); @@ -1046,7 +1046,7 @@ impl AggregateExec { )?)); } - if self.should_use_ordered_final_aggregate_stream(context) { + if self.should_use_ordered_final_aggregate_stream() { return Ok(StreamType::OrderedFinalAggregate( OrderedFinalAggregateStream::new(self, context, partition)?, )); @@ -1078,7 +1078,7 @@ impl AggregateExec { && self.limit_options_supported_by_hash_stream() } - fn should_use_ordered_partial_aggregate_stream(&self, context: &TaskContext) -> bool { + fn should_use_ordered_partial_aggregate_stream(&self) -> bool { self.mode == AggregateMode::Partial && self.input_order_mode != InputOrderMode::Linear && !self.group_by.is_true_no_grouping() @@ -1101,7 +1101,7 @@ impl AggregateExec { && self.group_by.is_single() } - fn should_use_ordered_final_aggregate_stream(&self, context: &TaskContext) -> bool { + fn should_use_ordered_final_aggregate_stream(&self) -> bool { matches!( self.mode, AggregateMode::Final | AggregateMode::FinalPartitioned