@@ -61,16 +61,18 @@ impl AggregateHashTable<PartialMarker> {
6161 pub ( in crate :: aggregates) fn next_output_batch (
6262 & mut self ,
6363 ) -> Result < Option < RecordBatch > > {
64- let output_schema = Arc :: clone ( & self . output_schema ) ;
6564 let batch_size = self . batch_size ;
6665 match std:: mem:: replace ( & mut self . state , AggregateHashTableState :: Done ) {
6766 AggregateHashTableState :: Outputting ( state) => {
6867 if state. group_values . is_empty ( ) {
6968 return Ok ( None ) ;
7069 }
7170
72- let output = self . materialize_partial_output ( state, output_schema) ?;
73- Ok ( self . emit_next_materialized_batch ( output, batch_size) )
71+ let batch = self . materialize_partial_batch ( state) ?;
72+ Ok ( self . emit_next_materialized_batch (
73+ MaterializedOutput :: new ( batch) ,
74+ batch_size,
75+ ) )
7476 }
7577 AggregateHashTableState :: OutputtingMaterialized ( output) => {
7678 Ok ( self . emit_next_materialized_batch ( output, batch_size) )
@@ -82,11 +84,34 @@ impl AggregateHashTable<PartialMarker> {
8284 }
8385 }
8486
85- fn materialize_partial_output (
87+ pub ( in crate :: aggregates) fn materialize_output_batch (
88+ & mut self ,
89+ ) -> Result < Option < RecordBatch > > {
90+ match std:: mem:: replace ( & mut self . state , AggregateHashTableState :: Done ) {
91+ AggregateHashTableState :: Outputting ( state) => {
92+ if state. group_values . is_empty ( ) {
93+ return Ok ( None ) ;
94+ }
95+
96+ self . materialize_partial_batch ( state) . map ( Some )
97+ }
98+ AggregateHashTableState :: Done => Ok ( None ) ,
99+ AggregateHashTableState :: Building ( _) => {
100+ internal_err ! (
101+ "materialize_output_batch must be called in the outputting state"
102+ )
103+ }
104+ AggregateHashTableState :: OutputtingMaterialized ( _) => {
105+ internal_err ! ( "partial aggregate output is already materialized" )
106+ }
107+ }
108+ }
109+
110+ fn materialize_partial_batch (
86111 & self ,
87112 mut state : AggregateHashTableBuffer ,
88- output_schema : SchemaRef ,
89- ) -> Result < MaterializedOutput > {
113+ ) -> Result < RecordBatch > {
114+ let output_schema = Arc :: clone ( & self . output_schema ) ;
90115 let emit_to = EmitTo :: All ;
91116 let timer = self . group_by_metrics . emitting_time . timer ( ) ;
92117 let mut output = state. group_values . emit ( emit_to) ?;
@@ -98,7 +123,7 @@ impl AggregateHashTable<PartialMarker> {
98123
99124 let batch = RecordBatch :: try_new ( output_schema, output) ?;
100125 debug_assert ! ( batch. num_rows( ) > 0 ) ;
101- Ok ( MaterializedOutput :: new ( batch) )
126+ Ok ( batch)
102127 }
103128
104129 pub ( in crate :: aggregates) fn can_skip_aggregation ( & self ) -> bool {
@@ -117,14 +142,35 @@ impl AggregateHashTable<PartialMarker> {
117142 . support_partial_repartition ( )
118143 }
119144
120- pub ( in crate :: aggregates) fn append_new_groups_to_partitions (
145+ pub ( in crate :: aggregates) fn append_new_group_partitions (
121146 & self ,
122- partitions : & mut [ Vec < usize > ] ,
147+ group_partitions : & mut Vec < usize > ,
148+ num_partitions : usize ,
123149 ) -> Result < ( ) > {
124- if partitions . is_empty ( ) {
150+ if num_partitions == 0 {
125151 return Ok ( ( ) ) ;
126152 }
127153
154+ if num_partitions. is_power_of_two ( ) {
155+ let mask = num_partitions - 1 ;
156+ self . append_new_groups_with_partition ( group_partitions, |hash| {
157+ ( hash as usize ) & mask
158+ } )
159+ } else {
160+ self . append_new_groups_with_partition ( group_partitions, |hash| {
161+ ( hash as usize ) % num_partitions
162+ } )
163+ }
164+ }
165+
166+ fn append_new_groups_with_partition < F > (
167+ & self ,
168+ group_partitions : & mut Vec < usize > ,
169+ compute_partition : F ,
170+ ) -> Result < ( ) >
171+ where
172+ F : Fn ( u64 ) -> usize ,
173+ {
128174 let state = self . state . building ( ) ;
129175 for & row in & state. new_group_rows {
130176 let Some ( & group_index) = state. batch_group_indices . get ( row) else {
@@ -138,8 +184,10 @@ impl AggregateHashTable<PartialMarker> {
138184 ) ;
139185 } ;
140186
141- let partition = partition_for_hash ( hash, partitions. len ( ) ) ;
142- partitions[ partition] . push ( group_index) ;
187+ if group_index >= group_partitions. len ( ) {
188+ group_partitions. resize ( group_index + 1 , 0 ) ;
189+ }
190+ group_partitions[ group_index] = compute_partition ( hash) ;
143191 }
144192
145193 Ok ( ( ) )
@@ -293,15 +341,6 @@ impl AggregateHashTable<PartialMarker> {
293341 }
294342}
295343
/// Maps `hash` to a partition index in `0..num_partitions`.
///
/// Uses a bitmask when `num_partitions` is a power of two and a modulo
/// otherwise; both yield the same index for power-of-two counts.
///
/// `num_partitions` must be non-zero: this is checked with `debug_assert!`,
/// and a zero count reaches the modulo (division by zero) when debug
/// assertions are disabled.
fn partition_for_hash(hash: u64, num_partitions: usize) -> usize {
    debug_assert!(num_partitions > 0);

    // Truncating cast: on targets where `usize` is narrower than 64 bits only
    // the low bits of the hash take part in the partition choice.
    let hash_bits = hash as usize;

    if !num_partitions.is_power_of_two() {
        return hash_bits % num_partitions;
    }
    hash_bits & (num_partitions - 1)
}
305344impl AggregateHashTable < PartialSkipMarker > {
306345 pub ( in crate :: aggregates) fn convert_batch_to_state (
307346 & mut self ,
0 commit comments