Skip to content

Commit b3ac2c2

Browse files
committed
Auto rechunk to enable blockwise reduction
Done when (1) `method` is None, and (2) grouping and reducing by a 1D array. We gate this on the fractional change in the number of chunks and the fractional change in the size of the largest chunk. Closes #359
1 parent f8f34b9 commit b3ac2c2

File tree

1 file changed

+20
-1
lines changed

1 file changed

+20
-1
lines changed

flox/core.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,12 @@
113113
# _simple_combine.
114114
DUMMY_AXIS = -2
115115

116+
# Thresholds below which we automatically rechunk to enable a blockwise reduction (both must be satisfied):
117+
# 1. Fractional change in number of chunks after rechunking
118+
BLOCKWISE_RECHUNK_NUM_CHUNKS_THRESHOLD = 0.25
119+
# 2. Fractional change in max chunk size after rechunking
120+
BLOCKWISE_RECHUNK_CHUNK_SIZE_THRESHOLD = 0.15
121+
116122
logger = logging.getLogger("flox")
117123

118124

@@ -230,6 +236,8 @@ def _get_optimal_chunks_for_groups(chunks, labels):
230236
Δl = abs(c - l)
231237
if c == 0 or newchunkidx[-1] > l:
232238
continue
239+
f = f.item() # noqa
240+
l = l.item() # noqa
233241
if Δf < Δl and f > newchunkidx[-1]:
234242
newchunkidx.append(f)
235243
else:
@@ -654,9 +662,15 @@ def rechunk_for_blockwise(array: DaskArray, axis: T_Axis, labels: np.ndarray) ->
654662
labels = factorize_((labels,), axes=())[0]
655663
chunks = array.chunks[axis]
656664
newchunks = _get_optimal_chunks_for_groups(chunks, labels)
665+
657666
if newchunks == chunks:
658667
return array
659-
else:
668+
669+
Δn = abs(len(newchunks) - len(chunks))
670+
if (Δn / len(chunks) < BLOCKWISE_RECHUNK_NUM_CHUNKS_THRESHOLD) and (
671+
abs(max(newchunks) - max(chunks)) / max(chunks) < BLOCKWISE_RECHUNK_CHUNK_SIZE_THRESHOLD
672+
):
673+
# Less than 25% change in number of chunks and less than 15% change in max chunk size, let's do it
660674
return array.rechunk({axis: newchunks})
661675

662676

@@ -2468,6 +2482,11 @@ def groupby_reduce(
24682482
has_dask = is_duck_dask_array(array) or is_duck_dask_array(by_)
24692483
has_cubed = is_duck_cubed_array(array) or is_duck_cubed_array(by_)
24702484

2485+
if method is None and nax == 1 and not any_by_dask and by_.ndim == 1 and _issorted(by_):
2486+
# Let's try rechunking for sorted 1D by.
2487+
(single_axis,) = axis_
2488+
array = rechunk_for_blockwise(array, single_axis, by_)
2489+
24712490
if _is_first_last_reduction(func):
24722491
if has_dask and nax != 1:
24732492
raise ValueError(

0 commit comments

Comments
 (0)