store unique group values globally to improve probabilistic subsampling logic
victorlin committed Dec 10, 2021
1 parent dc0eda2 commit 4e3e155
Showing 2 changed files with 39 additions and 27 deletions.
48 changes: 30 additions & 18 deletions augur/filter.py
@@ -821,6 +821,8 @@ def get_groups_for_subsampling(strains, metadata, group_by=None):
Returns
-------
set :
A set of all distinct group values.
dict :
A mapping of strain names to tuples corresponding to the values of the strain's group.
list :
@@ -829,7 +831,7 @@ def get_groups_for_subsampling(strains, metadata, group_by=None):
>>> strains = ["strain1", "strain2"]
>>> metadata = pd.DataFrame([{"strain": "strain1", "date": "2020-01-01", "region": "Africa"}, {"strain": "strain2", "date": "2020-02-01", "region": "Europe"}]).set_index("strain")
>>> group_by = ["region"]
>>> group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by)
>>> groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by)
>>> group_by_strain
{'strain1': ('Africa',), 'strain2': ('Europe',)}
>>> skipped_strains
@@ -839,13 +841,13 @@ def get_groups_for_subsampling(strains, metadata, group_by=None):
string.
>>> group_by = ["year", "month"]
>>> group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by)
>>> groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by)
>>> group_by_strain
{'strain1': (2020, (2020, 1)), 'strain2': (2020, (2020, 2))}
If we omit the grouping columns, the result will group by a dummy column.
>>> group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata)
>>> groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata)
>>> group_by_strain
{'strain1': ('_dummy',), 'strain2': ('_dummy',)}
@@ -861,7 +863,7 @@ def get_groups_for_subsampling(strains, metadata, group_by=None):
grouping to continue and print a warning message to stderr.
>>> group_by = ["year", "month", "missing_column"]
>>> group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by)
>>> groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by)
>>> group_by_strain
{'strain1': (2020, (2020, 1), 'unknown'), 'strain2': (2020, (2020, 2), 'unknown')}
@@ -870,7 +872,7 @@ def get_groups_for_subsampling(strains, metadata, group_by=None):
track which records were skipped for which reasons.
>>> metadata = pd.DataFrame([{"strain": "strain1", "date": "", "region": "Africa"}, {"strain": "strain2", "date": "2020-02-01", "region": "Europe"}]).set_index("strain")
>>> group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, ["year"])
>>> groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, ["year"])
>>> group_by_strain
{'strain2': (2020,)}
>>> skipped_strains
@@ -880,23 +882,25 @@ def get_groups_for_subsampling(strains, metadata, group_by=None):
month information in their date fields.
>>> metadata = pd.DataFrame([{"strain": "strain1", "date": "2020", "region": "Africa"}, {"strain": "strain2", "date": "2020-02-01", "region": "Europe"}]).set_index("strain")
>>> group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, ["month"])
>>> groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, ["month"])
>>> group_by_strain
{'strain2': ((2020, 2),)}
>>> skipped_strains
[{'strain': 'strain1', 'filter': 'skip_group_by_with_ambiguous_month', 'kwargs': ''}]
"""
metadata = metadata.loc[strains]
group_values = set()
group_by_strain = {}
skipped_strains = []

if metadata.empty:
return group_by_strain, skipped_strains
return group_values, group_by_strain, skipped_strains

if not group_by or group_by == dummy_group:
group_values = set(dummy_group_value)
group_by_strain = {strain: dummy_group_value for strain in strains}
return group_by_strain, skipped_strains
return group_values, group_by_strain, skipped_strains

group_by_set = set(group_by)

@@ -924,8 +928,12 @@ def get_groups_for_subsampling(strains, metadata, group_by=None):
for group in unknown_groups:
metadata[group] = 'unknown'

group_values = set(metadata.groupby(group_by).groups.keys())
if len(group_by) == 1:
# force tuple for single column group values
group_values = {(x,) for x in group_values}
group_by_strain = dict(zip(metadata.index, metadata[group_by].apply(tuple, axis=1)))
return group_by_strain, skipped_strains
return group_values, group_by_strain, skipped_strains


def expand_date_col(metadata: pd.DataFrame, group_by_set: set) -> Tuple[pd.DataFrame, List[dict]]:
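Aside (not part of this commit): the "force tuple for single column group values" branch above exists because pandas, at least in the versions current for this commit, keys single-column groups by scalars while multi-column groups are keyed by tuples. A minimal sketch of that distinction, using a toy metadata frame like the one in the doctests:

import pandas as pd

metadata = pd.DataFrame(
    [{"strain": "strain1", "region": "Africa"},
     {"strain": "strain2", "region": "Europe"}]
).set_index("strain")

# Single grouping column: keys come back as scalars, e.g. {'Africa', 'Europe'}.
single_column_keys = set(metadata.groupby(["region"]).groups.keys())

# Normalize to 1-tuples so they match the multi-column case, mirroring the
# `{(x,) for x in group_values}` step in the diff above.
normalized_keys = {(key,) for key in single_column_keys}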
@@ -1239,6 +1247,7 @@ def run(args):
# per group to use, so we need to calculate this number after the first pass
# and use a second pass to add records to the priority DataFrame.
group_by = args.group_by
groups = set()
sequences_per_group = args.sequences_per_group
records_per_group = None

@@ -1338,12 +1347,14 @@ def run(args):
# count the number of records per group. First, we need to get
# the groups for the given records.
try:
group_by_strain, skipped_strains = get_groups_for_subsampling(
chunk_groups, group_by_strain, skipped_strains = get_groups_for_subsampling(
seq_keep,
metadata,
group_by,
)

groups.update(chunk_groups)

# Track strains skipped during grouping, so users know why those
# strains were excluded from the analysis.
for skipped_strain in skipped_strains:
@@ -1395,6 +1406,7 @@ def run(args):
for strain in strains_to_write:
output_strains.write(f"{strain}\n")

probabilistic_used = False
# In the worst case, we need to calculate sequences per group from the
# requested maximum number of sequences and the number of sequences per
# group. Then, we need to make a second pass through the metadata to find
@@ -1419,6 +1431,10 @@ def run(args):
print(f"Sampling at {sequences_per_group} per group.")

if group_by and valid_strains:
if probabilistic_used:
# sort groups to eliminate set order randomness
sizes_per_group = create_sizes_per_group(sorted(groups), sequences_per_group, random_seed=args.subsample_seed)

# Make a second pass through the metadata.
metadata_reader = read_metadata(
args.metadata,
@@ -1446,8 +1462,9 @@ def run(args):
for group in unknown_groups:
metadata_copy[group] = 'unknown'
# apply priorities
poisson_max = 5 # TODO: get this from first pass along with sequences_per_group
int_group_size = sequences_per_group if sequences_per_group >= 1 else poisson_max
int_group_size = sequences_per_group
if probabilistic_used:
int_group_size = max(sizes_per_group.values())
# get index of dataframe with top n priority per group, breaking ties by last occurrence in metadata
chunk_top_priorities_index = (
metadata_copy.groupby(group_by, sort=False)['priority']
@@ -1471,12 +1488,7 @@ def run(args):
# get prioritized strains
prioritized_metadata = prioritized_metadata.loc[global_top_priorities_index]

# probabilistic subsampling
if sequences_per_group < 1:
groups = prioritized_metadata.groupby(group_by).groups.keys()
if len(group_by) == 1:
groups = [(x,) for x in groups]
sequences_per_group_map = create_sizes_per_group(groups, sequences_per_group, random_seed=args.subsample_seed)
if probabilistic_used:
prioritized_metadata['group'] = list(zip(*[prioritized_metadata[col] for col in group_by]))
prioritized_metadata['group_size'] = prioritized_metadata['group'].map(sizes_per_group)
prioritized_metadata['group_cumcount'] = prioritized_metadata.sort_values('priority', ascending=False).groupby(group_by + ['group_size']).cumcount()
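Note on `create_sizes_per_group`: the diff calls it with `sorted(groups)`, a possibly fractional `sequences_per_group`, and a `random_seed`, and uses the result as a dict mapping each group tuple to an integer target size (via `.values()` and `.map(sizes_per_group)`). Its body is not shown in this commit; the following is only a hypothetical sketch of such a function, assuming a Poisson draw per group (the usual mechanism behind probabilistic subsampling) and a numpy dependency:

import numpy as np

def create_sizes_per_group(groups, sequences_per_group, random_seed=None):
    """Map each group to an integer target size.

    Under probabilistic subsampling, sequences_per_group may be fractional,
    so each group's size is drawn from a Poisson distribution with that mean.
    Callers pass sorted(groups) so the draws are reproducible for a given
    random_seed regardless of set iteration order.
    """
    rng = np.random.default_rng(random_seed)
    return {group: int(rng.poisson(sequences_per_group)) for group in groups}

With that shape, both `max(sizes_per_group.values())` and `prioritized_metadata['group'].map(sizes_per_group)` in the second pass work as used in the diff.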
18 changes: 9 additions & 9 deletions tests/test_filter_groupby.py
@@ -18,7 +18,7 @@ class TestFilterGroupBy:
def test_filter_groupby_strain_subset(self, valid_metadata: pd.DataFrame):
metadata = valid_metadata.copy()
strains = ['SEQ_1', 'SEQ_3', 'SEQ_5']
group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata)
groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata)
assert group_by_strain == {
'SEQ_1': ('_dummy',),
'SEQ_3': ('_dummy',),
@@ -29,7 +29,7 @@ def test_filter_groupby_dummy(self, valid_metadata: pd.DataFrame):
def test_filter_groupby_dummy(self, valid_metadata: pd.DataFrame):
metadata = valid_metadata.copy()
strains = metadata.index.tolist()
group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata)
groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata)
assert group_by_strain == {
'SEQ_1': ('_dummy',),
'SEQ_2': ('_dummy',),
@@ -51,7 +51,7 @@ def test_filter_groupby_invalid_warn(self, valid_metadata: pd.DataFrame, capsys)
groups = ['country', 'year', 'month', 'invalid']
metadata = valid_metadata.copy()
strains = metadata.index.tolist()
group_by_strain, _ = get_groups_for_subsampling(strains, metadata, group_by=groups)
groups, group_by_strain, _ = get_groups_for_subsampling(strains, metadata, group_by=groups)
assert group_by_strain == {
'SEQ_1': ('A', 2020, (2020, 1), 'unknown'),
'SEQ_2': ('A', 2020, (2020, 2), 'unknown'),
@@ -67,7 +67,7 @@ def test_filter_groupby_skip_ambiguous_year(self, valid_metadata: pd.DataFrame):
metadata = valid_metadata.copy()
metadata.at["SEQ_2", "date"] = "XXXX-02-01"
strains = metadata.index.tolist()
group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups)
groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups)
assert group_by_strain == {
'SEQ_1': ('A', 2020, (2020, 1)),
'SEQ_3': ('B', 2020, (2020, 3)),
@@ -81,7 +81,7 @@ def test_filter_groupby_skip_missing_date(self, valid_metadata: pd.DataFrame):
metadata = valid_metadata.copy()
metadata.at["SEQ_2", "date"] = None
strains = metadata.index.tolist()
group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups)
groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups)
assert group_by_strain == {
'SEQ_1': ('A', 2020, (2020, 1)),
'SEQ_3': ('B', 2020, (2020, 3)),
@@ -95,7 +95,7 @@ def test_filter_groupby_skip_ambiguous_month(self, valid_metadata: pd.DataFrame)
metadata = valid_metadata.copy()
metadata.at["SEQ_2", "date"] = "2020-XX-01"
strains = metadata.index.tolist()
group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups)
groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups)
assert group_by_strain == {
'SEQ_1': ('A', 2020, (2020, 1)),
'SEQ_3': ('B', 2020, (2020, 3)),
@@ -109,7 +109,7 @@ def test_filter_groupby_skip_missing_month(self, valid_metadata: pd.DataFrame):
metadata = valid_metadata.copy()
metadata.at["SEQ_2", "date"] = "2020"
strains = metadata.index.tolist()
group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups)
groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups)
assert group_by_strain == {
'SEQ_1': ('A', 2020, (2020, 1)),
'SEQ_3': ('B', 2020, (2020, 3)),
@@ -150,7 +150,7 @@ def test_filter_groupby_missing_date_warn(self, valid_metadata: pd.DataFrame, ca
metadata = valid_metadata.copy()
metadata = metadata.drop('date', axis='columns')
strains = metadata.index.tolist()
group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups)
groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups)
assert group_by_strain == {
'SEQ_1': ('A', 'unknown', 'unknown'),
'SEQ_2': ('A', 'unknown', 'unknown'),
@@ -166,6 +166,6 @@ def test_filter_groupby_no_strains(self, valid_metadata: pd.DataFrame):
groups = ['country', 'year', 'month']
metadata = valid_metadata.copy()
strains = []
group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups)
groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups)
assert group_by_strain == {}
assert skipped_strains == []
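A possible follow-up test for the new first return value, sketched here rather than taken from the commit; it assumes the same valid_metadata fixture and the 'country' column already used by the existing tests:

    def test_filter_groupby_returns_group_values(self, valid_metadata: pd.DataFrame):
        metadata = valid_metadata.copy()
        strains = metadata.index.tolist()
        groups, group_by_strain, _ = get_groups_for_subsampling(
            strains, metadata, group_by=['country'])
        # Group values are always tuples, even for a single grouping column,
        # and should cover exactly the groups assigned to strains.
        assert all(isinstance(group, tuple) for group in groups)
        assert groups == set(group_by_strain.values())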
