@@ -639,7 +639,9 @@ def rechunk_for_cohorts(
639
639
return array .rechunk ({axis : newchunks })
640
640
641
641
642
- def rechunk_for_blockwise (array : DaskArray , axis : T_Axis , labels : np .ndarray ) -> DaskArray :
642
+ def rechunk_for_blockwise (
643
+ array : DaskArray , axis : T_Axis , labels : np .ndarray , * , force : bool = True
644
+ ) -> DaskArray :
643
645
"""
644
646
Rechunks array so that group boundaries line up with chunk boundaries, allowing
645
647
embarrassingly parallel group reductions.
@@ -672,11 +674,17 @@ def rechunk_for_blockwise(array: DaskArray, axis: T_Axis, labels: np.ndarray) ->
672
674
return array
673
675
674
676
Δn = abs (len (newchunks ) - len (chunks ))
675
- if (Δn / len (chunks ) < BLOCKWISE_RECHUNK_NUM_CHUNKS_THRESHOLD ) and (
676
- abs (max (newchunks ) - max (chunks )) / max (chunks ) < BLOCKWISE_RECHUNK_CHUNK_SIZE_THRESHOLD
677
+ if force or (
678
+ (Δn / len (chunks ) < BLOCKWISE_RECHUNK_NUM_CHUNKS_THRESHOLD )
679
+ and (
680
+ abs (max (newchunks ) - max (chunks )) / max (chunks ) < BLOCKWISE_RECHUNK_CHUNK_SIZE_THRESHOLD
681
+ )
677
682
):
683
+ logger .debug ("Rechunking to enable blockwise." )
678
684
# Rechunk when forced, or when the heuristic says the change is cheap
# (<25% change in number of chunks AND small change in max chunk size).
679
685
return array .rechunk ({axis : newchunks })
686
+ else :
687
+ return array
680
688
681
689
682
690
def reindex_ (
@@ -2496,7 +2504,7 @@ def groupby_reduce(
2496
2504
):
2497
2505
# Let's try rechunking for sorted 1D by.
2498
2506
(single_axis ,) = axis_
2499
- array = rechunk_for_blockwise (array , single_axis , by_ )
2507
+ array = rechunk_for_blockwise (array , single_axis , by_ , force = False )
2500
2508
2501
2509
if _is_first_last_reduction (func ):
2502
2510
if has_dask and nax != 1 :
0 commit comments