|
113 | 113 | # _simple_combine.
|
114 | 114 | DUMMY_AXIS = -2
|
115 | 115 |
|
| 116 | +# Thresholds below which we will automatically rechunk to blockwise if it makes sense |
| 117 | +# 1. Fractional change in number of chunks after rechunking |
| 118 | +BLOCKWISE_RECHUNK_NUM_CHUNKS_THRESHOLD = 0.25 |
| 119 | +# 2. Fractional change in max chunk size after rechunking |
| 120 | +BLOCKWISE_RECHUNK_CHUNK_SIZE_THRESHOLD = 0.15 |
| 121 | + |
116 | 122 | logger = logging.getLogger("flox")
|
117 | 123 |
|
118 | 124 |
|
@@ -230,6 +236,8 @@ def _get_optimal_chunks_for_groups(chunks, labels):
|
230 | 236 | Δl = abs(c - l)
|
231 | 237 | if c == 0 or newchunkidx[-1] > l:
|
232 | 238 | continue
|
| 239 | + f = f.item() # noqa |
| 240 | + l = l.item() # noqa |
233 | 241 | if Δf < Δl and f > newchunkidx[-1]:
|
234 | 242 | newchunkidx.append(f)
|
235 | 243 | else:
|
@@ -654,9 +662,15 @@ def rechunk_for_blockwise(array: DaskArray, axis: T_Axis, labels: np.ndarray) ->
|
654 | 662 | labels = factorize_((labels,), axes=())[0]
|
655 | 663 | chunks = array.chunks[axis]
|
656 | 664 | newchunks = _get_optimal_chunks_for_groups(chunks, labels)
|
| 665 | + |
657 | 666 | if newchunks == chunks:
|
658 | 667 | return array
|
659 |
| - else: |
| 668 | + |
| 669 | + Δn = abs(len(newchunks) - len(chunks)) |
| 670 | + if (Δn / len(chunks) < BLOCKWISE_RECHUNK_NUM_CHUNKS_THRESHOLD) and ( |
| 671 | + abs(max(newchunks) - max(chunks)) / max(chunks) < BLOCKWISE_RECHUNK_CHUNK_SIZE_THRESHOLD |
| 672 | + ): |
| 673 | + # Less than 25% change in number of chunks, let's do it |
660 | 674 | return array.rechunk({axis: newchunks})
|
661 | 675 |
|
662 | 676 |
|
@@ -2468,6 +2482,11 @@ def groupby_reduce(
|
2468 | 2482 | has_dask = is_duck_dask_array(array) or is_duck_dask_array(by_)
|
2469 | 2483 | has_cubed = is_duck_cubed_array(array) or is_duck_cubed_array(by_)
|
2470 | 2484 |
|
| 2485 | + if method is None and nax == 1 and not any_by_dask and by_.ndim == 1 and _issorted(by_): |
| 2486 | + # Let's try rechunking for sorted 1D by. |
| 2487 | + (single_axis,) = axis_ |
| 2488 | + array = rechunk_for_blockwise(array, single_axis, by_) |
| 2489 | + |
2471 | 2490 | if _is_first_last_reduction(func):
|
2472 | 2491 | if has_dask and nax != 1:
|
2473 | 2492 | raise ValueError(
|
|
0 commit comments