@@ -403,7 +403,7 @@ def find_group_cohorts(
     # Invert the label_chunks mapping so we know which labels occur together.
     def invert(x) -> tuple[np.ndarray, ...]:
         arr = label_chunks[x]
-        return tuple(arr)
+        return tuple(arr.tolist())
 
     chunks_cohorts = tlz.groupby(invert, label_chunks.keys())
 
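Aside: the `.tolist()` change matters because these tuples become dictionary keys in `tlz.groupby`. `tuple(arr)` yields NumPy scalar objects, while `tuple(arr.tolist())` yields plain Python ints, which hash and compare more cheaply. A minimal sketch of the difference (toy values, not flox code):

    import numpy as np

    arr = np.array([0, 3, 7])
    key_np = tuple(arr)           # tuple of np.int64 scalars
    key_py = tuple(arr.tolist())  # tuple of plain Python ints

    # Both are valid dict keys and compare (and hash) equal, but hashing
    # Python ints is cheaper, which adds up when grouping many labels.
    assert key_np == key_py
    assert hash(key_np) == hash(key_py)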
@@ -477,22 +477,37 @@ def invert(x) -> tuple[np.ndarray, ...]:
             containment.nnz / math.prod(containment.shape)
         )
     )
-    # Use a threshold to force some merging. We do not use the filtered
-    # containment matrix for estimating "sparsity" because it is a bit
-    # hard to reason about.
+
+    # Next we for-loop over groups and merge those that are quite similar.
+    # Use a threshold on containment to always force some merging.
+    # Note that we do not use the filtered containment matrix for estimating
+    # "sparsity" because it is a bit hard to reason about.
     MIN_CONTAINMENT = 0.75  # arbitrary
     mask = containment.data < MIN_CONTAINMENT
+
+    # Now we also know the "exact cohorts" -- cohorts whose constituent groups
+    # occur in exactly the same chunks. We need only examine one member of each.
+    # Skip the others by looping over the exact cohorts and zeroing out those rows.
+    repeated = np.concatenate([v[1:] for v in chunks_cohorts.values()]).astype(int)
+    repeated_idx = np.searchsorted(present_labels, repeated)
+    for i in repeated_idx:
+        mask[containment.indptr[i] : containment.indptr[i + 1]] = True
     containment.data[mask] = 0
     containment.eliminate_zeros()
 
-    # Iterate over labels, beginning with those with most chunks
+    # Figure out all the labels we need to loop over later
+    n_overlapping_labels = containment.astype(bool).sum(axis=1)
+    order = np.argsort(n_overlapping_labels, kind="stable")[::-1]
+    # Order is such that we iterate over labels, beginning with those with the
+    # most overlaps. Also filter out any "exact" cohorts.
+    order = order[n_overlapping_labels[order] > 0]
+
     logger.debug("find_group_cohorts: merging cohorts")
-    order = np.argsort(containment.sum(axis=LABEL_AXIS), kind="stable")[::-1]
     merged_cohorts = {}
     merged_keys = set()
-    # TODO: we can optimize this to loop over chunk_cohorts instead
-    # by zeroing out rows that are already in a cohort
     for rowidx in order:
+        if present_labels[rowidx] in merged_keys:
+            continue
         cohidx = containment.indices[
             slice(containment.indptr[rowidx], containment.indptr[rowidx + 1])
         ]
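For reference, here is the CSR trick the hunk above relies on, as a minimal standalone sketch (toy matrix and threshold; `rows_to_skip` is illustrative, standing in for `repeated_idx`). In CSR format the stored values for row i live in `data[indptr[i]:indptr[i+1]]`, so a whole row can be zeroed by masking that slice, and the surviving rows can then be ordered by how many nonzeros they retain:

    import numpy as np
    from scipy import sparse

    containment = sparse.random(5, 5, density=0.6, format="csr", random_state=0)

    # Threshold small entries, like MIN_CONTAINMENT above.
    mask = containment.data < 0.5
    # Zero out entire rows by masking their data slice, like repeated_idx above.
    rows_to_skip = (1, 3)
    for i in rows_to_skip:
        mask[containment.indptr[i] : containment.indptr[i + 1]] = True
    containment.data[mask] = 0
    containment.eliminate_zeros()

    # Visit rows with the most surviving nonzeros first, dropping empty rows,
    # mirroring how `order` is built above.
    n_overlapping = np.asarray(containment.astype(bool).sum(axis=1)).ravel()
    order = np.argsort(n_overlapping, kind="stable")[::-1]
    order = order[n_overlapping[order] > 0]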
@@ -507,6 +522,7 @@ def invert(x) -> tuple[np.ndarray, ...]:
 
     actual_ngroups = np.concatenate(tuple(merged_cohorts.values())).size
     expected_ngroups = present_labels.size
+    assert len(merged_keys) == actual_ngroups
     assert expected_ngroups == actual_ngroups, (expected_ngroups, actual_ngroups)
 
     # sort by first label in cohort