Skip to content

Commit 9fdae45

Browse files
committed
fix
1 parent a5f208a commit 9fdae45

File tree

1 file changed

+8
-12
lines changed

1 file changed

+8
-12
lines changed

flox/core.py

+8-12
Original file line numberDiff line numberDiff line change
@@ -365,16 +365,6 @@ def find_group_cohorts(
365365
if not is_duck_array(labels):
366366
labels = np.asarray(labels)
367367

368-
if is_duck_dask_array(labels):
369-
import dask
370-
371-
((bitmask, nlabels, ilabels),) = dask.compute(
372-
dask.delayed(_compute_label_chunk_bitmask)(labels, chunks)
373-
)
374-
else:
375-
bitmask, nlabels, ilabels = _compute_label_chunk_bitmask(labels, chunks)
376-
377-
shape = tuple(sum(c) for c in chunks)
378368
nchunks = math.prod(len(c) for c in chunks)
379369

380370
# assumes that `labels` are factorized
@@ -387,8 +377,14 @@ def find_group_cohorts(
387377
if nchunks == 1:
388378
return "blockwise", {(0,): list(range(nlabels))}
389379

390-
labels = np.broadcast_to(labels, shape[-labels.ndim :])
391-
bitmask = _compute_label_chunk_bitmask(labels, chunks, nlabels)
380+
if is_duck_dask_array(labels):
381+
import dask
382+
383+
((bitmask, nlabels, ilabels),) = dask.compute(
384+
dask.delayed(_compute_label_chunk_bitmask)(labels, chunks, nlabels)
385+
)
386+
else:
387+
bitmask, nlabels, ilabels = _compute_label_chunk_bitmask(labels, chunks, nlabels)
392388

393389
CHUNK_AXIS, LABEL_AXIS = 0, 1
394390
chunks_per_label = bitmask.sum(axis=CHUNK_AXIS)

0 commit comments

Comments
 (0)