Skip to content

Commit 13a67ae

Browse files
committed
Emit partial repartition batches
1 parent 1583cd4 commit 13a67ae

4 files changed

Lines changed: 308 additions & 35 deletions

File tree

datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs

Lines changed: 60 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -61,16 +61,18 @@ impl AggregateHashTable<PartialMarker> {
6161
pub(in crate::aggregates) fn next_output_batch(
6262
&mut self,
6363
) -> Result<Option<RecordBatch>> {
64-
let output_schema = Arc::clone(&self.output_schema);
6564
let batch_size = self.batch_size;
6665
match std::mem::replace(&mut self.state, AggregateHashTableState::Done) {
6766
AggregateHashTableState::Outputting(state) => {
6867
if state.group_values.is_empty() {
6968
return Ok(None);
7069
}
7170

72-
let output = self.materialize_partial_output(state, output_schema)?;
73-
Ok(self.emit_next_materialized_batch(output, batch_size))
71+
let batch = self.materialize_partial_batch(state)?;
72+
Ok(self.emit_next_materialized_batch(
73+
MaterializedOutput::new(batch),
74+
batch_size,
75+
))
7476
}
7577
AggregateHashTableState::OutputtingMaterialized(output) => {
7678
Ok(self.emit_next_materialized_batch(output, batch_size))
@@ -82,11 +84,34 @@ impl AggregateHashTable<PartialMarker> {
8284
}
8385
}
8486

85-
fn materialize_partial_output(
87+
pub(in crate::aggregates) fn materialize_output_batch(
88+
&mut self,
89+
) -> Result<Option<RecordBatch>> {
90+
match std::mem::replace(&mut self.state, AggregateHashTableState::Done) {
91+
AggregateHashTableState::Outputting(state) => {
92+
if state.group_values.is_empty() {
93+
return Ok(None);
94+
}
95+
96+
self.materialize_partial_batch(state).map(Some)
97+
}
98+
AggregateHashTableState::Done => Ok(None),
99+
AggregateHashTableState::Building(_) => {
100+
internal_err!(
101+
"materialize_output_batch must be called in the outputting state"
102+
)
103+
}
104+
AggregateHashTableState::OutputtingMaterialized(_) => {
105+
internal_err!("partial aggregate output is already materialized")
106+
}
107+
}
108+
}
109+
110+
fn materialize_partial_batch(
86111
&self,
87112
mut state: AggregateHashTableBuffer,
88-
output_schema: SchemaRef,
89-
) -> Result<MaterializedOutput> {
113+
) -> Result<RecordBatch> {
114+
let output_schema = Arc::clone(&self.output_schema);
90115
let emit_to = EmitTo::All;
91116
let timer = self.group_by_metrics.emitting_time.timer();
92117
let mut output = state.group_values.emit(emit_to)?;
@@ -98,7 +123,7 @@ impl AggregateHashTable<PartialMarker> {
98123

99124
let batch = RecordBatch::try_new(output_schema, output)?;
100125
debug_assert!(batch.num_rows() > 0);
101-
Ok(MaterializedOutput::new(batch))
126+
Ok(batch)
102127
}
103128

104129
pub(in crate::aggregates) fn can_skip_aggregation(&self) -> bool {
@@ -117,14 +142,35 @@ impl AggregateHashTable<PartialMarker> {
117142
.support_partial_repartition()
118143
}
119144

120-
pub(in crate::aggregates) fn append_new_groups_to_partitions(
145+
pub(in crate::aggregates) fn append_new_group_partitions(
121146
&self,
122-
partitions: &mut [Vec<usize>],
147+
group_partitions: &mut Vec<usize>,
148+
num_partitions: usize,
123149
) -> Result<()> {
124-
if partitions.is_empty() {
150+
if num_partitions == 0 {
125151
return Ok(());
126152
}
127153

154+
if num_partitions.is_power_of_two() {
155+
let mask = num_partitions - 1;
156+
self.append_new_groups_with_partition(group_partitions, |hash| {
157+
(hash as usize) & mask
158+
})
159+
} else {
160+
self.append_new_groups_with_partition(group_partitions, |hash| {
161+
(hash as usize) % num_partitions
162+
})
163+
}
164+
}
165+
166+
fn append_new_groups_with_partition<F>(
167+
&self,
168+
group_partitions: &mut Vec<usize>,
169+
compute_partition: F,
170+
) -> Result<()>
171+
where
172+
F: Fn(u64) -> usize,
173+
{
128174
let state = self.state.building();
129175
for &row in &state.new_group_rows {
130176
let Some(&group_index) = state.batch_group_indices.get(row) else {
@@ -138,8 +184,10 @@ impl AggregateHashTable<PartialMarker> {
138184
);
139185
};
140186

141-
let partition = partition_for_hash(hash, partitions.len());
142-
partitions[partition].push(group_index);
187+
if group_index >= group_partitions.len() {
188+
group_partitions.resize(group_index + 1, 0);
189+
}
190+
group_partitions[group_index] = compute_partition(hash);
143191
}
144192

145193
Ok(())
@@ -293,15 +341,6 @@ impl AggregateHashTable<PartialMarker> {
293341
}
294342
}
295343

296-
fn partition_for_hash(hash: u64, num_partitions: usize) -> usize {
297-
debug_assert!(num_partitions > 0);
298-
if num_partitions.is_power_of_two() {
299-
(hash as usize) & (num_partitions - 1)
300-
} else {
301-
(hash as usize) % num_partitions
302-
}
303-
}
304-
305344
impl AggregateHashTable<PartialSkipMarker> {
306345
pub(in crate::aggregates) fn convert_batch_to_state(
307346
&mut self,

0 commit comments

Comments
 (0)