@@ -639,7 +639,9 @@ def rechunk_for_cohorts(
639639 return array .rechunk ({axis : newchunks })
640640
641641
642- def rechunk_for_blockwise (array : DaskArray , axis : T_Axis , labels : np .ndarray ) -> DaskArray :
642+ def rechunk_for_blockwise (
643+ array : DaskArray , axis : T_Axis , labels : np .ndarray , * , force : bool = True
644+ ) -> DaskArray :
643645 """
644646 Rechunks array so that group boundaries line up with chunk boundaries, allowing
645647 embarrassingly parallel group reductions.
@@ -672,11 +674,17 @@ def rechunk_for_blockwise(array: DaskArray, axis: T_Axis, labels: np.ndarray) ->
672674 return array
673675
674676 Δn = abs (len (newchunks ) - len (chunks ))
675- if (Δn / len (chunks ) < BLOCKWISE_RECHUNK_NUM_CHUNKS_THRESHOLD ) and (
676- abs (max (newchunks ) - max (chunks )) / max (chunks ) < BLOCKWISE_RECHUNK_CHUNK_SIZE_THRESHOLD
677+ if force or (
678+ (Δn / len (chunks ) < BLOCKWISE_RECHUNK_NUM_CHUNKS_THRESHOLD )
679+ and (
680+ abs (max (newchunks ) - max (chunks )) / max (chunks ) < BLOCKWISE_RECHUNK_CHUNK_SIZE_THRESHOLD
681+ )
677682 ):
683+ logger .debug ("Rechunking to enable blockwise." )
678684 # Less than 25% change in number of chunks, let's do it
679685 return array .rechunk ({axis : newchunks })
686+ else :
687+ return array
680688
681689
682690def reindex_ (
@@ -2496,7 +2504,7 @@ def groupby_reduce(
24962504 ):
24972505 # Let's try rechunking for sorted 1D by.
24982506 (single_axis ,) = axis_
2499- array = rechunk_for_blockwise (array , single_axis , by_ )
2507+ array = rechunk_for_blockwise (array , single_axis , by_ , force = False )
25002508
25012509 if _is_first_last_reduction (func ):
25022510 if has_dask and nax != 1 :
0 commit comments