 # _simple_combine.
 DUMMY_AXIS = -2

+# Thresholds below which we will automatically rechunk to blockwise if it makes sense
+# 1. Fractional change in number of chunks after rechunking
+BLOCKWISE_RECHUNK_NUM_CHUNKS_THRESHOLD = 0.25
+# 2. Fractional change in max chunk size after rechunking
+BLOCKWISE_RECHUNK_CHUNK_SIZE_THRESHOLD = 0.15
+
 logger = logging.getLogger("flox")

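These two module-level constants bound how disruptive an automatic rechunk is allowed to be: the switch to blockwise-friendly chunks only happens when the number of chunks changes by less than 25% and the largest chunk size changes by less than 15%. Here is a minimal sketch of that check, pulled out into a standalone helper for illustration; `is_cheap_rechunk` is hypothetical and not part of flox, only the two threshold values come from the diff above.

```python
# Standalone illustration of the threshold check; not flox code.
BLOCKWISE_RECHUNK_NUM_CHUNKS_THRESHOLD = 0.25
BLOCKWISE_RECHUNK_CHUNK_SIZE_THRESHOLD = 0.15


def is_cheap_rechunk(chunks: tuple[int, ...], newchunks: tuple[int, ...]) -> bool:
    """Return True when the proposed rechunk changes the chunking only a little."""
    num_chunks_change = abs(len(newchunks) - len(chunks)) / len(chunks)
    max_chunk_change = abs(max(newchunks) - max(chunks)) / max(chunks)
    return (
        num_chunks_change < BLOCKWISE_RECHUNK_NUM_CHUNKS_THRESHOLD
        and max_chunk_change < BLOCKWISE_RECHUNK_CHUNK_SIZE_THRESHOLD
    )


print(is_cheap_rechunk((10, 10, 10), (11, 10, 9)))  # True: max chunk grows by 10%
print(is_cheap_rechunk((10, 10, 10), (12, 10, 8)))  # False: max chunk grows by 20%
```

The `rechunk_for_blockwise` hunk further down applies this same pair of comparisons inline.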
@@ -230,6 +236,8 @@ def _get_optimal_chunks_for_groups(chunks, labels):
         Δl = abs(c - l)
         if c == 0 or newchunkidx[-1] > l:
             continue
+        f = f.item()  # noqa
+        l = l.item()  # noqa
         if Δf < Δl and f > newchunkidx[-1]:
             newchunkidx.append(f)
         else:
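The two added `.item()` calls convert `f` and `l` from numpy integer scalars to plain Python ints before they are appended as chunk boundaries (the `# noqa` silences the linter's complaint about the single-letter name `l`); presumably this keeps the chunk sizes built from `newchunkidx` as plain ints rather than numpy scalar types. A tiny standalone illustration of the difference, with a made-up index array:

```python
# Standalone illustration; the array and names here are made up.
import numpy as np

first_indexes = np.array([0, 4, 9])
f = first_indexes[1]   # indexing a numpy array yields a numpy scalar
print(type(f))         # a numpy integer type, e.g. <class 'numpy.int64'>
print(type(f.item()))  # <class 'int'>

# Boundaries collected as plain ints give a clean tuple of chunk sizes.
newchunkidx = [0, first_indexes[1].item(), first_indexes[2].item()]
print(tuple(b - a for a, b in zip(newchunkidx[:-1], newchunkidx[1:])))  # (4, 5)
```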
@@ -651,12 +659,22 @@ def rechunk_for_blockwise(array: DaskArray, axis: T_Axis, labels: np.ndarray) ->
     DaskArray
         Rechunked array
     """
-    labels = factorize_((labels,), axes=())[0]
     chunks = array.chunks[axis]
+    if len(chunks) == 1:
+        return array
+
+    labels = factorize_((labels,), axes=())[0]
     newchunks = _get_optimal_chunks_for_groups(chunks, labels)
     if newchunks == chunks:
         return array
-    else:
+
+    Δn = abs(len(newchunks) - len(chunks))
+    if (Δn / len(chunks) < BLOCKWISE_RECHUNK_NUM_CHUNKS_THRESHOLD) and (
+        abs(max(newchunks) - max(chunks)) / max(chunks) < BLOCKWISE_RECHUNK_CHUNK_SIZE_THRESHOLD
+    ):
+        # Small change in both the number of chunks and the max chunk size: rechunking is cheap
         return array.rechunk({axis: newchunks})
+    else:
+        return array

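With those pieces in place, `rechunk_for_blockwise` now bails out early for single-chunk inputs and only rechunks when the proposed chunking clears both thresholds, returning the array unchanged otherwise. A usage sketch with made-up shapes and chunks; the commented output is what the boundary-snapping logic should produce for these inputs, not verified output:

```python
# Usage sketch with made-up sizes; the comment shows the expected, not verified, result.
import dask.array as da
import numpy as np

from flox.core import rechunk_for_blockwise

labels = np.repeat(np.arange(10), 10)                 # sorted: 10 groups of 10 elements
array = da.ones((100,), chunks=((31, 29, 30, 10),))   # boundaries slightly off the group edges

rechunked = rechunk_for_blockwise(array, axis=0, labels=labels)
print(rechunked.chunks)  # expected ((30, 30, 30, 10),): boundaries snapped to group edges
```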
@@ -2468,6 +2484,11 @@ def groupby_reduce(
     has_dask = is_duck_dask_array(array) or is_duck_dask_array(by_)
     has_cubed = is_duck_cubed_array(array) or is_duck_cubed_array(by_)

+    if method is None and is_duck_dask_array(array) and not any_by_dask and by_.ndim == 1 and _issorted(by_):
+        # Let's try rechunking for sorted 1D by; this only applies to dask-backed arrays.
+        (single_axis,) = axis_
+        array = rechunk_for_blockwise(array, single_axis, by_)
+
     if _is_first_last_reduction(func):
         if has_dask and nax != 1:
             raise ValueError(
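In `groupby_reduce`, a sorted 1-D `by`, a dask-backed `array`, and no explicit `method` are now enough to trigger that rechunk automatically before the reduction is set up. A usage sketch with made-up data; the rechunk happens internally, so callers only see that the blockwise path should become available more often:

```python
# Usage sketch with made-up data; the rechunk is applied inside groupby_reduce.
import dask.array as da
import numpy as np

import flox

by = np.repeat(np.arange(10), 10)                    # sorted 1-D labels
array = da.ones((100,), chunks=((31, 29, 30, 10),))

result, groups = flox.groupby_reduce(array, by, func="sum")
print(groups)            # [0 1 2 3 4 5 6 7 8 9]
print(result.compute())  # 10.0 for every group (sum of ten ones)
```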