
Commit f8f34b9

Optimize for-loop merging of cohorts. (#378)
* Optimize for-loop merging of cohorts. Do this by skipping perfect cohorts that we already know about.
* Add new benchmark
* Fix
* Cleanup print statements
* minimize diff
* cleanup
* Update snapshot
1 parent cb3fc1f commit f8f34b9

File tree

3 files changed (+7945, -15876 lines)


asv_bench/benchmarks/cohorts.py

+5 -1
@@ -47,6 +47,9 @@ def time_find_group_cohorts(self):
         except AttributeError:
             pass
 
+    def track_num_cohorts(self):
+        return len(self.chunks_cohorts())
+
     def time_graph_construct(self):
         flox.groupby_reduce(self.array, self.by, func="sum", axis=self.axis)
 

@@ -60,10 +63,11 @@ def track_num_tasks_optimized(self):
     def track_num_layers(self):
         return len(self.result.dask.layers)
 
+    track_num_cohorts.unit = "cohorts"  # type: ignore[attr-defined] # Lazy
     track_num_tasks.unit = "tasks"  # type: ignore[attr-defined] # Lazy
     track_num_tasks_optimized.unit = "tasks"  # type: ignore[attr-defined] # Lazy
     track_num_layers.unit = "layers"  # type: ignore[attr-defined] # Lazy
-    for f in [track_num_tasks, track_num_tasks_optimized, track_num_layers]:
+    for f in [track_num_tasks, track_num_tasks_optimized, track_num_layers, track_num_cohorts]:
         f.repeat = 1  # type: ignore[attr-defined] # Lazy
         f.rounds = 1  # type: ignore[attr-defined] # Lazy
         f.number = 1  # type: ignore[attr-defined] # Lazy
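
The new track_num_cohorts benchmark follows asv's convention for tracking benchmarks: the value returned by a track_* method is recorded as the result, and the unit attribute labels it in reports. A rough sketch of what the tracked number means, assuming self.chunks_cohorts() returns the cohort mapping computed by flox.core.find_group_cohorts (keys are tuples of chunk indices, values are the labels merged into that cohort; the exact return shape is an assumption here):

    # Hypothetical cohort mapping, for illustration only: each key is a tuple of
    # chunk indices, each value is the list of group labels assigned to that cohort.
    chunks_cohorts = {
        (0, 1): [1, 2],  # labels 1 and 2 are reduced together over chunks 0 and 1
        (2, 3): [3],     # label 3 is reduced over chunks 2 and 3
    }
    num_cohorts = len(chunks_cohorts)  # -> 2; the number asv records, in "cohorts"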

flox/core.py

+24 -8
@@ -403,7 +403,7 @@ def find_group_cohorts(
     # Invert the label_chunks mapping so we know which labels occur together.
     def invert(x) -> tuple[np.ndarray, ...]:
         arr = label_chunks[x]
-        return tuple(arr)
+        return tuple(arr.tolist())
 
     chunks_cohorts = tlz.groupby(invert, label_chunks.keys())
 
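invert builds the grouping key handed to tlz.groupby: each label maps to the tuple of chunks it occurs in, so labels with identical chunk sets land in the same "exact" cohort. Switching to arr.tolist() makes that key a tuple of plain Python ints rather than NumPy scalars; a plausible motivation (not stated in the commit) is that such tuples hash more cheaply while grouping identically. One way to picture the grouping, with a toy label_chunks rather than flox's real data:

    import numpy as np
    import toolz as tlz

    # Toy stand-in for label_chunks: the chunks each label appears in.
    label_chunks = {1: np.array([0, 1]), 2: np.array([0, 1]), 3: np.array([2])}

    def invert(label):
        # tuple of Python ints, not NumPy scalars, so keys hash cheaply
        return tuple(label_chunks[label].tolist())

    chunks_cohorts = tlz.groupby(invert, label_chunks.keys())
    # {(0, 1): [1, 2], (2,): [3]} -- labels 1 and 2 form an "exact" cohort
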
@@ -477,22 +477,37 @@ def invert(x) -> tuple[np.ndarray, ...]:
             containment.nnz / math.prod(containment.shape)
         )
     )
-    # Use a threshold to force some merging. We do not use the filtered
-    # containment matrix for estimating "sparsity" because it is a bit
-    # hard to reason about.
+
+    # Next we for-loop over groups and merge those that are quite similar.
+    # Use a threshold on containment to always force some merging.
+    # Note that we do not use the filtered containment matrix for estimating "sparsity"
+    # because it is a bit hard to reason about.
     MIN_CONTAINMENT = 0.75  # arbitrary
     mask = containment.data < MIN_CONTAINMENT
+
+    # Now we also know "exact cohorts" -- cohorts whose constituent groups
+    # occur in exactly the same chunks. We only need examine one member of each group.
+    # Skip the others by first looping over the exact cohorts, and zero out those rows.
+    repeated = np.concatenate([v[1:] for v in chunks_cohorts.values()]).astype(int)
+    repeated_idx = np.searchsorted(present_labels, repeated)
+    for i in repeated_idx:
+        mask[containment.indptr[i] : containment.indptr[i + 1]] = True
     containment.data[mask] = 0
     containment.eliminate_zeros()
 
-    # Iterate over labels, beginning with those with most chunks
+    # Figure out all the labels we need to loop over later
+    n_overlapping_labels = containment.astype(bool).sum(axis=1)
+    order = np.argsort(n_overlapping_labels, kind="stable")[::-1]
+    # Order is such that we iterate over labels, beginning with those with most overlaps
+    # Also filter out any "exact" cohorts
+    order = order[n_overlapping_labels[order] > 0]
+
     logger.debug("find_group_cohorts: merging cohorts")
-    order = np.argsort(containment.sum(axis=LABEL_AXIS), kind="stable")[::-1]
     merged_cohorts = {}
     merged_keys = set()
-    # TODO: we can optimize this to loop over chunk_cohorts instead
-    # by zeroing out rows that are already in a cohort
     for rowidx in order:
+        if present_labels[rowidx] in merged_keys:
+            continue
         cohidx = containment.indices[
             slice(containment.indptr[rowidx], containment.indptr[rowidx + 1])
         ]
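
The skip relies on the CSR layout of the containment matrix: the stored entries of row i live in containment.data[containment.indptr[i] : containment.indptr[i + 1]], so flipping the mask over that slice and then calling eliminate_zeros() drops the whole row. repeated collects every member of an exact cohort except the first (v[1:]), and searchsorted maps those labels to row indices, which assumes present_labels is sorted. A standalone sketch of the row-zeroing trick on a toy scipy.sparse matrix (not flox's actual containment matrix):

    import numpy as np
    from scipy.sparse import csr_matrix

    # Toy 3x3 "containment"-like matrix.
    mat = csr_matrix(np.array([
        [1.0, 0.8, 0.0],
        [0.8, 1.0, 0.0],
        [0.0, 0.0, 1.0],
    ]))

    # Zero out row 1 (say, a repeated member of an exact cohort) by masking the
    # slice of .data that belongs to that row in CSR storage.
    mask = np.zeros(mat.data.shape, dtype=bool)
    i = 1
    mask[mat.indptr[i] : mat.indptr[i + 1]] = True
    mat.data[mask] = 0
    mat.eliminate_zeros()

    print(mat.toarray())
    # [[1.  0.8 0. ]
    #  [0.  0.  0. ]
    #  [0.  0.  1. ]]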
@@ -507,6 +522,7 @@ def invert(x) -> tuple[np.ndarray, ...]:
 
     actual_ngroups = np.concatenate(tuple(merged_cohorts.values())).size
     expected_ngroups = present_labels.size
+    assert len(merged_keys) == actual_ngroups
     assert expected_ngroups == actual_ngroups, (expected_ngroups, actual_ngroups)
 
     # sort by first label in cohort
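
The extra assertion checks the invariant that makes the continue above safe: merged_keys holds every label already placed in a cohort, so each present label ends up in exactly one merged cohort. In toy terms (made-up labels, not real data):

    # Made-up result, for illustration: three labels, two merged cohorts.
    merged_cohorts = {(0, 1): [1, 2], (2,): [3]}
    merged_keys = {1, 2, 3}
    actual_ngroups = sum(len(v) for v in merged_cohorts.values())  # 3
    assert len(merged_keys) == actual_ngroups  # every label placed exactly once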
