@@ -365,16 +365,6 @@ def find_group_cohorts(
365
365
if not is_duck_array (labels ):
366
366
labels = np .asarray (labels )
367
367
368
- if is_duck_dask_array (labels ):
369
- import dask
370
-
371
- ((bitmask , nlabels , ilabels ),) = dask .compute (
372
- dask .delayed (_compute_label_chunk_bitmask )(labels , chunks )
373
- )
374
- else :
375
- bitmask , nlabels , ilabels = _compute_label_chunk_bitmask (labels , chunks )
376
-
377
- shape = tuple (sum (c ) for c in chunks )
378
368
nchunks = math .prod (len (c ) for c in chunks )
379
369
380
370
# assumes that `labels` are factorized
@@ -387,8 +377,14 @@ def find_group_cohorts(
387
377
if nchunks == 1 :
388
378
return "blockwise" , {(0 ,): list (range (nlabels ))}
389
379
390
- labels = np .broadcast_to (labels , shape [- labels .ndim :])
391
- bitmask = _compute_label_chunk_bitmask (labels , chunks , nlabels )
380
+ if is_duck_dask_array (labels ):
381
+ import dask
382
+
383
+ ((bitmask , nlabels , ilabels ),) = dask .compute (
384
+ dask .delayed (_compute_label_chunk_bitmask )(labels , chunks , nlabels )
385
+ )
386
+ else :
387
+ bitmask , nlabels , ilabels = _compute_label_chunk_bitmask (labels , chunks , nlabels )
392
388
393
389
CHUNK_AXIS , LABEL_AXIS = 0 , 1
394
390
chunks_per_label = bitmask .sum (axis = CHUNK_AXIS )
0 commit comments