From 1b106d5e25f96170ee8c1ec8c3a09b0057e18a21 Mon Sep 17 00:00:00 2001 From: Victor Lin <13424970+victorlin@users.noreply.github.com> Date: Wed, 8 Dec 2021 13:47:57 -0800 Subject: [PATCH 1/8] move date logic to new function expand_date_col() --- augur/filter.py | 84 +++++++++++++++++++++++++++++++------------------ 1 file changed, 53 insertions(+), 31 deletions(-) diff --git a/augur/filter.py b/augur/filter.py index ef3005a2f..1797bfb0a 100644 --- a/augur/filter.py +++ b/augur/filter.py @@ -17,7 +17,7 @@ import sys from tempfile import NamedTemporaryFile import treetime.utils -from typing import Collection +from typing import Collection, List, Tuple from .index import index_sequences, index_vcf from .io import open_file, read_metadata, read_sequences, write_sequences @@ -912,36 +912,7 @@ def get_groups_for_subsampling(strains, metadata, group_by=None): df_dates = pd.DataFrame({'year': 'unknown', 'month': 'unknown'}, index=metadata.index) metadata = pd.concat([metadata, df_dates], axis=1) else: - # replace date with year/month/day as nullable ints - date_cols = ['year', 'month', 'day'] - df_dates = (metadata['date'].str.split('-', n=2, expand=True) - .set_axis(date_cols, axis=1)) - for col in date_cols: - df_dates[col] = pd.to_numeric(df_dates[col], errors='coerce').astype(pd.Int64Dtype()) - metadata = pd.concat([metadata.drop('date', axis=1), df_dates], axis=1) - if 'year' in group_by_set: - # skip ambiguous years - df_skip = metadata[metadata['year'].isnull()] - metadata.dropna(subset=['year'], inplace=True) - for strain in df_skip.index: - skipped_strains.append({ - "strain": strain, - "filter": "skip_group_by_with_ambiguous_year", - "kwargs": "", - }) - if 'month' in group_by_set: - # skip ambiguous months - df_skip = metadata[metadata['month'].isnull()] - metadata.dropna(subset=['month'], inplace=True) - for strain in df_skip.index: - skipped_strains.append({ - "strain": strain, - "filter": "skip_group_by_with_ambiguous_month", - "kwargs": "", - }) - # month = (year, month) - metadata['month'] = list(zip(metadata['year'], metadata['month'])) - # TODO: support group by day + metadata, skipped_strains = expand_date_col(metadata, group_by_set) unknown_groups = group_by_set - set(metadata.columns) if unknown_groups: @@ -954,6 +925,57 @@ def get_groups_for_subsampling(strains, metadata, group_by=None): return group_by_strain, skipped_strains +def expand_date_col(metadata: pd.DataFrame, group_by_set: set) -> Tuple[pd.DataFrame, List[dict]]: + """Expand the date column of a DataFrame to year=year, month=(year, month). + + Parameters + ---------- + metadata : pandas.DataFrame + Metadata containing date column. + group_by_set : set + A set of metadata columns to group records by. + + Returns + ------- + pandas.DataFrame : + The input metadata with expanded date columns. + list : + A list of dictionaries with strains that were skipped from grouping and the reason why (see also: `apply_filters` output). 
+ """ + metadata_new = metadata.copy() + skipped_strains = [] + # replace date with year/month/day as nullable ints + date_cols = ['year', 'month', 'day'] + df_dates = (metadata_new['date'].str.split('-', n=2, expand=True) + .set_axis(date_cols, axis=1)) + for col in date_cols: + df_dates[col] = pd.to_numeric(df_dates[col], errors='coerce').astype(pd.Int64Dtype()) + metadata_new = pd.concat([metadata_new.drop('date', axis=1), df_dates], axis=1) + if 'year' in group_by_set: + # skip ambiguous years + df_skip = metadata_new[metadata_new['year'].isnull()] + metadata_new.dropna(subset=['year'], inplace=True) + for strain in df_skip.index: + skipped_strains.append({ + "strain": strain, + "filter": "skip_group_by_with_ambiguous_year", + "kwargs": "", + }) + if 'month' in group_by_set: + # skip ambiguous months + df_skip = metadata_new[metadata_new['month'].isnull()] + metadata_new.dropna(subset=['month'], inplace=True) + for strain in df_skip.index: + skipped_strains.append({ + "strain": strain, + "filter": "skip_group_by_with_ambiguous_month", + "kwargs": "", + }) + # month = (year, month) + metadata_new['month'] = list(zip(metadata_new['year'], metadata_new['month'])) + # TODO: support group by day + return metadata_new, skipped_strains + class PriorityQueue: """A priority queue implementation that automatically replaces lower priority items in the heap with incoming higher priority items. From dc0eda25933f13da87fc0b5dead78f827bef15ba Mon Sep 17 00:00:00 2001 From: Victor Lin <13424970+victorlin@users.noreply.github.com> Date: Fri, 10 Dec 2021 00:02:35 -0800 Subject: [PATCH 2/8] rewrite PriorityQueue logic with pandas functions - remove `class PriorityQueue` - use `prioritized_metadata` DataFrame in place of `queues_per_group` - repurpose `create_queues_per_group` to `create_sizes_per_group` - other logical refactoring: - use global dummy group key and value - key is `list`: pd.DataFrame.groupby does not take a tuple as grouping key, also our `--group-by` is stored as list already. - value is `tuple: `get_groups_for_subsampling` currently returns group values in this form. 
- use records_per_group for _dummy - replace conditional logic of `records_per_group is not None` with `group_by` - add functional tests --- augur/filter.py | 306 +++++++----------- .../filter/metadata_ambiguous_months.tsv | 11 + tests/functional/filter_groupby.t | 123 +++++++ 3 files changed, 251 insertions(+), 189 deletions(-) create mode 100644 tests/functional/filter/metadata_ambiguous_months.tsv create mode 100644 tests/functional/filter_groupby.t diff --git a/augur/filter.py b/augur/filter.py index 1797bfb0a..e27916d33 100644 --- a/augur/filter.py +++ b/augur/filter.py @@ -25,6 +25,9 @@ comment_char = '#' +dummy_group = ['_dummy',] +dummy_group_value = ('_dummy',) + SEQUENCE_ONLY_FILTERS = ( "min_length", "non_nucleotide", @@ -891,8 +894,8 @@ def get_groups_for_subsampling(strains, metadata, group_by=None): if metadata.empty: return group_by_strain, skipped_strains - if not group_by or group_by == ('_dummy',): - group_by_strain = {strain: ('_dummy',) for strain in strains} + if not group_by or group_by == dummy_group: + group_by_strain = {strain: dummy_group_value for strain in strains} return group_by_strain, skipped_strains group_by_set = set(group_by) @@ -976,115 +979,39 @@ def expand_date_col(metadata: pd.DataFrame, group_by_set: set) -> Tuple[pd.DataF # TODO: support group by day return metadata_new, skipped_strains -class PriorityQueue: - """A priority queue implementation that automatically replaces lower priority - items in the heap with incoming higher priority items. - - Add a single record to a heap with a maximum of 2 records. - - >>> queue = PriorityQueue(max_size=2) - >>> queue.add({"strain": "strain1"}, 0.5) - 1 - - Add another record with a higher priority. The queue should be at its maximum - size. - - >>> queue.add({"strain": "strain2"}, 1.0) - 2 - >>> queue.heap - [(0.5, 0, {'strain': 'strain1'}), (1.0, 1, {'strain': 'strain2'})] - >>> list(queue.get_items()) - [{'strain': 'strain1'}, {'strain': 'strain2'}] - - Add a higher priority record that causes the queue to exceed its maximum - size. The resulting queue should contain the two highest priority records - after the lowest priority record is removed. - - >>> queue.add({"strain": "strain3"}, 2.0) - 2 - >>> list(queue.get_items()) - [{'strain': 'strain2'}, {'strain': 'strain3'}] - - Add a record with the same priority as another record, forcing the duplicate - to be resolved by removing the oldest entry. - - >>> queue.add({"strain": "strain4"}, 1.0) - 2 - >>> list(queue.get_items()) - [{'strain': 'strain4'}, {'strain': 'strain3'}] - - """ - def __init__(self, max_size): - """Create a fixed size heap (priority queue) - - """ - self.max_size = max_size - self.heap = [] - self.counter = itertools.count() - - def add(self, item, priority): - """Add an item to the queue with a given priority. - - If adding the item causes the queue to exceed its maximum size, replace - the lowest priority item with the given item. The queue stores items - with an additional heap id value (a count) to resolve ties between items - with equal priority (favoring the most recently added item). - - """ - heap_id = next(self.counter) - - if len(self.heap) >= self.max_size: - heapq.heappushpop(self.heap, (priority, heap_id, item)) - else: - heapq.heappush(self.heap, (priority, heap_id, item)) - - return len(self.heap) - - def get_items(self): - """Return each item in the queue in order. - - Yields - ------ - Any - Item stored in the queue. 
- """ - for priority, heap_id, item in self.heap: - yield item - - -def create_queues_by_group(groups, max_size, max_attempts=100, random_seed=None): - """Create a dictionary of priority queues per group for the given maximum size. +def create_sizes_per_group(groups, max_size, max_attempts=100, random_seed=None): + """Create a dictionary of sizes per group for the given maximum size. When the maximum size is fractional, probabilistically sample the maximum size from a Poisson distribution. Make at least the given number of maximum - attempts to create queues for which the sum of their maximum sizes is + attempts to create groups for which the sum of their maximum sizes is greater than zero. - Create queues for two groups with a fixed maximum size. + Create sizes for two groups with a fixed maximum size. >>> groups = ("2015", "2016") - >>> queues = create_queues_by_group(groups, 2) - >>> sum(queue.max_size for queue in queues.values()) + >>> sizes = create_sizes_per_group(groups, 2) + >>> sum(sizes.values()) 4 - Create queues for two groups with a fractional maximum size. Their total max + Create sizes for two groups with a fractional maximum size. Their total max size should still be an integer value greater than zero. >>> seed = 314159 - >>> queues = create_queues_by_group(groups, 0.1, random_seed=seed) - >>> int(sum(queue.max_size for queue in queues.values())) > 0 + >>> sizes = create_sizes_per_group(groups, 0.1, random_seed=seed) + >>> int(sum(sizes.values())) > 0 True A subsequent run of this function with the same groups and random seed - should produce the same queues and queue sizes. + should produce the same sizes. - >>> more_queues = create_queues_by_group(groups, 0.1, random_seed=seed) - >>> [queue.max_size for queue in queues.values()] == [queue.max_size for queue in more_queues.values()] + >>> more_sizes = create_sizes_per_group(groups, 0.1, random_seed=seed) + >>> list(sizes.values()) == list(more_sizes.values()) True """ - queues_by_group = {} + size_per_group = {} total_max_size = 0 attempts = 0 @@ -1092,22 +1019,22 @@ def create_queues_by_group(groups, max_size, max_attempts=100, random_seed=None) random_generator = np.random.default_rng(random_seed) # For small fractional maximum sizes, it is possible to randomly select - # maximum queue sizes that all equal zero. When this happens, filtering - # fails unexpectedly. We make multiple attempts to create queues with - # maximum sizes greater than zero for at least one queue. + # maximum sizes that all equal zero. When this happens, filtering + # fails unexpectedly. We make multiple attempts to create sizes with + # maximum sizes greater than zero for at least one group. while total_max_size == 0 and attempts < max_attempts: - for group in sorted(groups): + for group in groups: if max_size < 1.0: - queue_max_size = random_generator.poisson(max_size) + group_max_size = random_generator.poisson(max_size) else: - queue_max_size = max_size + group_max_size = max_size - queues_by_group[group] = PriorityQueue(queue_max_size) + size_per_group[group] = group_max_size - total_max_size = sum(queue.max_size for queue in queues_by_group.values()) + total_max_size = sum(size_per_group.values()) attempts += 1 - return queues_by_group + return size_per_group def register_arguments(parser): @@ -1306,27 +1233,25 @@ def run(args): # group such that we select at most the requested maximum number of # sequences in a single pass through the metadata. 
# - # Each case relies on a priority queue to track the highest priority records + # Each case relies on a priority DataFrame to track the highest priority records # per group. In the best case, we can track these records in a single pass # through the metadata. In the worst case, we don't know how many sequences # per group to use, so we need to calculate this number after the first pass - # and use a second pass to add records to the queue. + # and use a second pass to add records to the priority DataFrame. group_by = args.group_by sequences_per_group = args.sequences_per_group records_per_group = None - if group_by and args.subsample_max_sequences: - # In this case, we need two passes through the metadata with the first - # pass used to count the number of records per group. + if args.subsample_max_sequences: records_per_group = defaultdict(int) - elif not group_by and args.subsample_max_sequences: - group_by = ("_dummy",) - sequences_per_group = args.subsample_max_sequences + if not group_by: + group_by = dummy_group + sequences_per_group = args.subsample_max_sequences - # If we are grouping data, use queues to store the highest priority strains + # If we are grouping data, use DataFrame to store the highest priority strains # for each group. When no priorities are provided, they will be randomly # generated. - queues_by_group = None + prioritized_metadata = pd.DataFrame() if group_by: # Use user-defined priorities, if possible. Otherwise, setup a # corresponding dictionary that returns a random float for each strain. @@ -1428,32 +1353,13 @@ def run(args): if args.output_log: output_log_writer.writerow(skipped_strain) - if args.subsample_max_sequences and records_per_group is not None: + if args.subsample_max_sequences and group_by: # Count the number of records per group. We will use this # information to calculate the number of sequences per group # for the given maximum number of requested sequences. + # TODO: use pandas logic for group in group_by_strain.values(): records_per_group[group] += 1 - else: - # Track the highest priority records, when we already - # know the number of sequences allowed per group. - if queues_by_group is None: - queues_by_group = {} - - for strain in sorted(group_by_strain.keys()): - # During this first pass, we do not know all possible - # groups will be, so we need to build each group's queue - # as we first encounter the group. - group = group_by_strain[strain] - if group not in queues_by_group: - queues_by_group[group] = PriorityQueue( - max_size=sequences_per_group, - ) - - queues_by_group[group].add( - metadata.loc[strain], - priorities[strain], - ) except FilterException as error: # When we cannot group by the requested columns, we print a # warning to the user and continue without subsampling or @@ -1493,7 +1399,7 @@ def run(args): # requested maximum number of sequences and the number of sequences per # group. Then, we need to make a second pass through the metadata to find # the requested number of records. - if args.subsample_max_sequences and records_per_group is not None: + if args.subsample_max_sequences and group_by: # Calculate sequences per group. If there are more groups than maximum # sequences requested, sequences per group will be a floating point # value and subsampling will be probabilistic. 
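Reviewer note (outside the diff): the fractional sequences-per-group case described above is what `create_sizes_per_group` handles by drawing each group's target from a Poisson distribution, so the expected total still matches the request. A minimal sketch of that idea under toy inputs; `poisson_sizes` is an illustrative name, not augur's API:

    import numpy as np

    def poisson_sizes(groups, max_size, random_seed=None):
        # When max_size is fractional, draw each group's target from
        # Poisson(max_size); the expected sum is len(groups) * max_size.
        rng = np.random.default_rng(random_seed)
        if max_size >= 1.0:
            return {group: max_size for group in groups}
        return {group: int(rng.poisson(max_size)) for group in groups}

    # e.g. 10 requested sequences spread across 25 groups -> max_size = 0.4
    sizes = poisson_sizes([("2015",), ("2016",)], 0.4, random_seed=314159)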
@@ -1509,76 +1415,98 @@ def run(args): if (probabilistic_used): print(f"Sampling probabilistically at {sequences_per_group:0.4f} sequences per group, meaning it is possible to have more than the requested maximum of {args.subsample_max_sequences} sequences after filtering.") - else: + elif group_by != dummy_group: print(f"Sampling at {sequences_per_group} per group.") - if queues_by_group is None: - # We know all of the possible groups now from the first pass through - # the metadata, so we can create queues for all groups at once. - queues_by_group = create_queues_by_group( - records_per_group.keys(), - sequences_per_group, - random_seed=args.subsample_seed, - ) - - # Make a second pass through the metadata, only considering records that - # have passed filters. + if group_by and valid_strains: + # Make a second pass through the metadata. metadata_reader = read_metadata( args.metadata, id_columns=args.metadata_id_columns, chunk_size=args.metadata_chunk_size, ) for metadata in metadata_reader: - # Recalculate groups for subsampling as we loop through the - # metadata a second time. TODO: We could store these in memory - # during the first pass, but we want to minimize overall memory - # usage at the moment. seq_keep = set(metadata.index.values) & valid_strains - group_by_strain, skipped_strains = get_groups_for_subsampling( - seq_keep, - metadata, - group_by, + if not seq_keep: + continue + # create a copy of metadata, only considering records that have passed filters + seq_keep_ordered = [seq for seq in metadata.index if seq in seq_keep] + metadata_copy = metadata.loc[seq_keep_ordered].copy() + # add columns for priority and expanded date + metadata_copy['priority'] = metadata_copy.index.to_series().apply(lambda x: priorities[x]) + metadata_with_dates, _ = expand_date_col(metadata_copy, set(group_by)) # TODO: don't drop date col in this function so we don't need this intermediate var + metadata_copy = (pd.concat([metadata_copy, metadata_with_dates[['year','month','day']]], axis=1) + .reindex(metadata_copy.index)) # not needed after pandas>=1.2.0 https://pandas.pydata.org/docs/whatsnew/v1.2.0.html#index-column-name-preservation-when-aggregating + # add column for dummy group + if group_by == dummy_group: + metadata_copy[dummy_group[0]] = dummy_group_value[0] + # add columns for unknown groups + unknown_groups = set(group_by) - set(metadata_copy.columns) + if unknown_groups: + for group in unknown_groups: + metadata_copy[group] = 'unknown' + # apply priorities + poisson_max = 5 # TODO: get this from first pass along with sequences_per_group + int_group_size = sequences_per_group if sequences_per_group >= 1 else poisson_max + # get index of dataframe with top n priority per group, breaking ties by last occurrence in metadata + chunk_top_priorities_index = ( + metadata_copy.groupby(group_by, sort=False)['priority'] + .nlargest(int_group_size, keep='last') + .index.get_level_values('strain') ) + # get prioritized strains + chunk_prioritized_metadata = metadata_copy.loc[chunk_top_priorities_index] - for strain in sorted(group_by_strain.keys()): - group = group_by_strain[strain] - queues_by_group[group].add( - metadata.loc[strain], - priorities[strain], + # update global priority + if prioritized_metadata.empty: + prioritized_metadata = chunk_prioritized_metadata + else: + prioritized_metadata = prioritized_metadata.append(chunk_prioritized_metadata) + # get index of dataframe with top n priority per group, breaking ties by last occurrence in metadata + global_top_priorities_index = ( + 
                    prioritized_metadata.groupby(group_by, sort=False)['priority']
+                    .nlargest(int_group_size, keep='last')
+                    .index.get_level_values('strain')
+                )
-
-    # If we have any records in queues, we have grouped results and need to
-    # stream the highest priority records to the requested outputs.
+                # get prioritized strains
+                prioritized_metadata = prioritized_metadata.loc[global_top_priorities_index]
+
+        # probabilistic subsampling
+        if sequences_per_group < 1:
+            groups = prioritized_metadata.groupby(group_by).groups.keys()
+            if len(group_by) == 1:
+                groups = [(x,) for x in groups]
+            sizes_per_group = create_sizes_per_group(groups, sequences_per_group, random_seed=args.subsample_seed)
+            prioritized_metadata['group'] = list(zip(*[prioritized_metadata[col] for col in group_by]))
+            prioritized_metadata['group_size'] = prioritized_metadata['group'].map(sizes_per_group)
+            prioritized_metadata['group_cumcount'] = prioritized_metadata.sort_values('priority', ascending=False).groupby(group_by + ['group_size']).cumcount()
+            prioritized_metadata = prioritized_metadata[prioritized_metadata['group_cumcount'] < prioritized_metadata['group_size']]
+
+        # drop intermediate columns before writing to output
+        prioritized_metadata = prioritized_metadata[metadata.columns]
+
+    # If we have any sequences in the prioritized metadata, we have grouped results
+    # and need to stream the highest priority sequences to the requested outputs.
     num_excluded_subsamp = 0
-    if queues_by_group:
-        # Populate the set of strains to keep from the records in queues.
-        subsampled_strains = set()
-        for group, queue in queues_by_group.items():
-            records = []
-            for record in queue.get_items():
-                # Each record is a pandas.Series instance. Track the name of the
-                # record, so we can output its sequences later.
-                subsampled_strains.add(record.name)
-
-                # Construct a data frame of records to simplify metadata output.
-                records.append(record)
-
-                if args.output_strains:
-                    # TODO: Output strains will no longer be ordered. This is a
-                    # small breaking change.
-                    output_strains.write(f"{record.name}\n")
-
-            # Write records to metadata output, if requested.
-            if args.output_metadata and len(records) > 0:
-                records = pd.DataFrame(records)
-                records.to_csv(
-                    args.output_metadata,
-                    sep="\t",
-                    header=metadata_header,
-                    mode=metadata_mode,
-                )
-                metadata_header = False
-                metadata_mode = "a"
+    if not prioritized_metadata.empty:
+        subsampled_strains = set(prioritized_metadata.index)
+
+        if args.output_strains:
+            # TODO: Output strains will no longer be ordered. This is a
+            # small breaking change.
+            for strain in prioritized_metadata.index:
+                output_strains.write(f"{strain}\n")
+
+        # Write records to metadata output, if requested.
+        if args.output_metadata:
+            prioritized_metadata.to_csv(
+                args.output_metadata,
+                sep="\t",
+                header=metadata_header,
+                mode=metadata_mode,
+            )
+            metadata_header = False
+            metadata_mode = "a"
 
     # Count and optionally log strains that were not included due to
     # subsampling.
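Reviewer note (outside the diff): the core replacement for `PriorityQueue` above is pandas' top-n-per-group idiom. A self-contained sketch with toy data; the column names are illustrative, not augur's metadata schema:

    import pandas as pd

    df = pd.DataFrame(
        {"country": ["A", "A", "A", "B"], "priority": [0.1, 0.9, 0.5, 0.7]},
        index=pd.Index(["s1", "s2", "s3", "s4"], name="strain"),
    )
    # Top 2 priorities per country; keep='last' resolves ties in favor of the
    # most recently seen row, mirroring the old queue's tie-breaking behavior.
    top = df.groupby(["country"], sort=False)["priority"].nlargest(2, keep="last")
    kept = df.loc[top.index.get_level_values("strain")]
    # kept -> s2 and s3 for country A, s4 for country B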
diff --git a/tests/functional/filter/metadata_ambiguous_months.tsv b/tests/functional/filter/metadata_ambiguous_months.tsv new file mode 100644 index 000000000..821f0c529 --- /dev/null +++ b/tests/functional/filter/metadata_ambiguous_months.tsv @@ -0,0 +1,11 @@ +strain accession date region country location db authors cluster paper_url title +G22670 10155 1991-XX-XX north_america canada village_k genbank Lee et al Mj-I http://www.pnas.org/content/112/44/13609 Population Genomics of Mycobacterium tuberculosis in the Inuit +G22671 10223 1992-XX-XX north_america canada village_d genbank Lee et al Mj-V.a http://www.pnas.org/content/112/44/13609 Population Genomics of Mycobacterium tuberculosis in the Inuit +G22672 11011 1991-XX-XX north_america canada village_k genbank Lee et al Mj-I http://www.pnas.org/content/112/44/13609 Population Genomics of Mycobacterium tuberculosis in the Inuit +G22673 11234 1992-XX-XX north_america canada village_k genbank Lee et al Mj-I http://www.pnas.org/content/112/44/13609 Population Genomics of Mycobacterium tuberculosis in the Inuit +G22674 14069 1993-XX-XX north_america canada village_k genbank Lee et al Mj-I http://www.pnas.org/content/112/44/13609 Population Genomics of Mycobacterium tuberculosis in the Inuit +G22675 14508 1993-XX-XX north_america canada village_c genbank Lee et al Mj-V.a http://www.pnas.org/content/112/44/13609 Population Genomics of Mycobacterium tuberculosis in the Inuit +G22676 15613 1994-XX-XX north_america canada village_e genbank Lee et al Mj-V.d http://www.pnas.org/content/112/44/13609 Population Genomics of Mycobacterium tuberculosis in the Inuit +G22677 16490 1995-XX-XX north_america canada village_e genbank Lee et al Mj-IV.b http://www.pnas.org/content/112/44/13609 Population Genomics of Mycobacterium tuberculosis in the Inuit +G22678 16493 1995-XX-XX north_america canada village_k genbank Lee et al Mj-I http://www.pnas.org/content/112/44/13609 Population Genomics of Mycobacterium tuberculosis in the Inuit +G22679 18421 1996-XX-XX north_america canada village_k genbank Lee et al Mj-I http://www.pnas.org/content/112/44/13609 Population Genomics of Mycobacterium tuberculosis in the Inuit diff --git a/tests/functional/filter_groupby.t b/tests/functional/filter_groupby.t new file mode 100644 index 000000000..0aec6637b --- /dev/null +++ b/tests/functional/filter_groupby.t @@ -0,0 +1,123 @@ +Integration tests for grouping features in augur filter. + + $ pushd "$TESTDIR" > /dev/null + $ export AUGUR="../../bin/augur" + +Try simple grouping. + + $ ${AUGUR} filter \ + > --metadata filter/metadata.tsv \ + > --group-by country year month \ + > --subsample-max-sequences 10 \ + > --subsample-seed 314159 \ + > --output-strains "$TMP/filtered_strains.txt" \ + > --output-metadata "$TMP/filtered_metadata.tsv" + Sampling at 10 per group. 
+ 2 strains were dropped during filtering + \t1 were dropped during grouping due to ambiguous year information (esc) + \t1 were dropped during grouping due to ambiguous month information (esc) + \t0 of these were dropped because of subsampling criteria, using seed 314159 (esc) + 10 strains passed all filters + $ wc -l "$TMP/filtered_strains.txt" + \s*10 .* (re) + $ head -n 2 "$TMP/filtered_metadata.tsv" + strain\tvirus\taccession\tdate\tregion\tcountry\tdivision\tcity\tdb\tsegment\tauthors\turl\ttitle\tjournal\tpaper_url (esc) + PRVABC59\tzika\tKU501215\t2015-12-XX\tNorth America\tPuerto Rico\tPuerto Rico\tPuerto Rico\tgenbank\tgenome\tLanciotti et al\thttps://www.ncbi.nlm.nih.gov/nuccore/KU501215\tPhylogeny of Zika Virus in Western Hemisphere, 2015\tEmerging Infect. Dis. 22 (5), 933-935 (2016)\thttps://www.ncbi.nlm.nih.gov/pubmed/27088323 (esc) + +Try subsample without any groups. + + $ ${AUGUR} filter \ + > --metadata filter/metadata.tsv \ + > --subsample-max-sequences 10 \ + > --subsample-seed 314159 \ + > --output-strains "$TMP/filtered_strains.txt" \ + > --output-metadata "$TMP/filtered_metadata.tsv" + 2 strains were dropped during filtering + \t2 of these were dropped because of subsampling criteria, using seed 314159 (esc) + 10 strains passed all filters + $ wc -l "$TMP/filtered_strains.txt" + \s*10 .* (re) + $ head -n 2 "$TMP/filtered_metadata.tsv" + strain\tvirus\taccession\tdate\tregion\tcountry\tdivision\tcity\tdb\tsegment\tauthors\turl\ttitle\tjournal\tpaper_url (esc) + COL/FLR_00024/2015\tzika\tMF574569\t\tSouth America\tColombia\tColombia\tColombia\tgenbank\tgenome\tPickett et al\thttps://www.ncbi.nlm.nih.gov/nuccore/MF574569\tDirect Submission\tSubmitted (28-JUL-2017) J. Craig Venter Institute, 9704 Medical Center Drive, Rockville, MD 20850, USA\thttps://www.ncbi.nlm.nih.gov/pubmed/ (esc) + +Try grouping by an unknown column. +This should warn then continue without grouping. + + $ ${AUGUR} filter \ + > --metadata filter/metadata.tsv \ + > --exclude-where "region=South America" "region=North America" "region=Southeast Asia" \ + > --include-where "country=Ecuador" \ + > --group-by invalid \ + > --subsample-max-sequences 10 \ + > --subsample-seed 314159 \ + > --output-metadata "$TMP/filtered_metadata.tsv" + WARNING: The specified group-by categories (['invalid']) were not found. No sequences-per-group sampling will be done. + 10 strains were dropped during filtering + \t6 of these were dropped because of 'region=South America' (esc) + \t4 of these were dropped because of 'region=North America' (esc) + \t1 of these were dropped because of 'region=Southeast Asia' (esc) + \t1 sequences were added back because of 'country=Ecuador' (esc) + \t0 of these were dropped because of subsampling criteria, using seed 314159 (esc) + 2 strains passed all filters + $ wc -l "$TMP/filtered_metadata.tsv" + \s*3 .* (re) + +Try grouping by an unknown column and a valid column. +This should warn then continue with grouping by valid column only. + + $ ${AUGUR} filter \ + > --metadata filter/metadata.tsv \ + > --group-by country invalid \ + > --subsample-max-sequences 10 \ + > --subsample-seed 314159 \ + > --output-metadata "$TMP/filtered_metadata.tsv" + WARNING: Some of the specified group-by categories couldn't be found: invalid + Filtering by group may behave differently than expected! + Sampling at 1 per group. 
  3 strains were dropped during filtering
+  \t3 of these were dropped because of subsampling criteria, using seed 314159 (esc)
+  9 strains passed all filters
+  $ wc -l "$TMP/filtered_metadata.tsv"
+  \s*10 .* (re)
+
+Try grouping with no probabilistic sampling
+
+  $ ${AUGUR} filter \
+  > --metadata filter/metadata.tsv \
+  > --group-by country year month \
+  > --sequences-per-group 1 \
+  > --subsample-seed 314159 \
+  > --no-probabilistic-sampling \
+  > --output-strains "$TMP/filtered_strains.txt" \
+  > --output-metadata "$TMP/filtered_metadata.tsv" > /dev/null
+  $ wc -l "$TMP/filtered_strains.txt"
+  \s*9 .* (re)
+  $ wc -l "$TMP/filtered_metadata.tsv"
+  \s*10 .* (re)
+
+Try grouping with year only
+
+  $ ${AUGUR} filter \
+  > --metadata filter/metadata_ambiguous_months.tsv \
+  > --group-by year \
+  > --sequences-per-group 10 \
+  > --subsample-seed 314159 \
+  > --output-strains "$TMP/filtered_strains.txt"
+  0 strains were dropped during filtering
+  \t0 of these were dropped because of subsampling criteria, using seed 314159 (esc)
+  10 strains passed all filters
+
+Try grouping with year, month
+
+  $ ${AUGUR} filter \
+  > --metadata filter/metadata_ambiguous_months.tsv \
+  > --group-by year month \
+  > --sequences-per-group 10 \
+  > --subsample-seed 314159 \
+  > --output-strains "$TMP/filtered_strains.txt"
+  ERROR: All samples have been dropped! Check filter rules and metadata file format.
+  10 strains were dropped during filtering
+  \t10 were dropped during grouping due to ambiguous month information (esc)
+  \t0 of these were dropped because of subsampling criteria, using seed 314159 (esc)
+  [1]

From 4e3e155ab1dfd8d686917787e0e52d9e2112efa2 Mon Sep 17 00:00:00 2001
From: Victor Lin <13424970+victorlin@users.noreply.github.com>
Date: Fri, 10 Dec 2021 01:15:16 -0800
Subject: [PATCH 3/8] store unique group values globally to improve
 probabilistic subsampling logic

---
 augur/filter.py              | 48 ++++++++++++++++++++++--------------
 tests/test_filter_groupby.py | 18 +++++++-------
 2 files changed, 39 insertions(+), 27 deletions(-)

diff --git a/augur/filter.py b/augur/filter.py
index e27916d33..6d6c659d4 100644
--- a/augur/filter.py
+++ b/augur/filter.py
@@ -821,6 +821,8 @@ def get_groups_for_subsampling(strains, metadata, group_by=None):
 
     Returns
     -------
+    set :
+        A set of all distinct group values.
     dict :
         A mapping of strain names to tuples corresponding to the values of the strain's group.
     list :
@@ -829,7 +831,7 @@ def get_groups_for_subsampling(strains, metadata, group_by=None):
     >>> strains = ["strain1", "strain2"]
     >>> metadata = pd.DataFrame([{"strain": "strain1", "date": "2020-01-01", "region": "Africa"}, {"strain": "strain2", "date": "2020-02-01", "region": "Europe"}]).set_index("strain")
     >>> group_by = ["region"]
-    >>> group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by)
+    >>> groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by)
     >>> group_by_strain
     {'strain1': ('Africa',), 'strain2': ('Europe',)}
     >>> skipped_strains
     []
@@ -839,13 +841,13 @@ def get_groups_for_subsampling(strains, metadata, group_by=None):
     string.
 
     >>> group_by = ["year", "month"]
-    >>> group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by)
+    >>> groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by)
     >>> group_by_strain
     {'strain1': (2020, (2020, 1)), 'strain2': (2020, (2020, 2))}
 
     If we omit the grouping columns, the result will group by a dummy column.
- >>> group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata) + >>> groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata) >>> group_by_strain {'strain1': ('_dummy',), 'strain2': ('_dummy',)} @@ -861,7 +863,7 @@ def get_groups_for_subsampling(strains, metadata, group_by=None): grouping to continue and print a warning message to stderr. >>> group_by = ["year", "month", "missing_column"] - >>> group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by) + >>> groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by) >>> group_by_strain {'strain1': (2020, (2020, 1), 'unknown'), 'strain2': (2020, (2020, 2), 'unknown')} @@ -870,7 +872,7 @@ def get_groups_for_subsampling(strains, metadata, group_by=None): track which records were skipped for which reasons. >>> metadata = pd.DataFrame([{"strain": "strain1", "date": "", "region": "Africa"}, {"strain": "strain2", "date": "2020-02-01", "region": "Europe"}]).set_index("strain") - >>> group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, ["year"]) + >>> groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, ["year"]) >>> group_by_strain {'strain2': (2020,)} >>> skipped_strains @@ -880,7 +882,7 @@ def get_groups_for_subsampling(strains, metadata, group_by=None): month information in their date fields. >>> metadata = pd.DataFrame([{"strain": "strain1", "date": "2020", "region": "Africa"}, {"strain": "strain2", "date": "2020-02-01", "region": "Europe"}]).set_index("strain") - >>> group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, ["month"]) + >>> groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, ["month"]) >>> group_by_strain {'strain2': ((2020, 2),)} >>> skipped_strains @@ -888,15 +890,17 @@ def get_groups_for_subsampling(strains, metadata, group_by=None): """ metadata = metadata.loc[strains] + group_values = set() group_by_strain = {} skipped_strains = [] if metadata.empty: - return group_by_strain, skipped_strains + return group_values, group_by_strain, skipped_strains if not group_by or group_by == dummy_group: + group_values = set(dummy_group_value) group_by_strain = {strain: dummy_group_value for strain in strains} - return group_by_strain, skipped_strains + return group_values, group_by_strain, skipped_strains group_by_set = set(group_by) @@ -924,8 +928,12 @@ def get_groups_for_subsampling(strains, metadata, group_by=None): for group in unknown_groups: metadata[group] = 'unknown' + group_values = set(metadata.groupby(group_by).groups.keys()) + if len(group_by) == 1: + # force tuple for single column group values + group_values = {(x,) for x in group_values} group_by_strain = dict(zip(metadata.index, metadata[group_by].apply(tuple, axis=1))) - return group_by_strain, skipped_strains + return group_values, group_by_strain, skipped_strains def expand_date_col(metadata: pd.DataFrame, group_by_set: set) -> Tuple[pd.DataFrame, List[dict]]: @@ -1239,6 +1247,7 @@ def run(args): # per group to use, so we need to calculate this number after the first pass # and use a second pass to add records to the priority DataFrame. group_by = args.group_by + groups = set() sequences_per_group = args.sequences_per_group records_per_group = None @@ -1338,12 +1347,14 @@ def run(args): # count the number of records per group. First, we need to get # the groups for the given records. 
            try:
-                group_by_strain, skipped_strains = get_groups_for_subsampling(
+                chunk_groups, group_by_strain, skipped_strains = get_groups_for_subsampling(
                     seq_keep,
                     metadata,
                     group_by,
                 )
 
+                groups.update(chunk_groups)
+
                 # Track strains skipped during grouping, so users know why those
                 # strains were excluded from the analysis.
                 for skipped_strain in skipped_strains:
@@ -1395,6 +1406,7 @@ def run(args):
             for strain in strains_to_write:
                 output_strains.write(f"{strain}\n")
 
+    probabilistic_used = False
     # In the worst case, we need to calculate sequences per group from the
     # requested maximum number of sequences and the number of sequences per
     # group. Then, we need to make a second pass through the metadata to find
@@ -1419,6 +1431,10 @@ def run(args):
             print(f"Sampling at {sequences_per_group} per group.")
 
     if group_by and valid_strains:
+        if probabilistic_used:
+            # sort groups to eliminate set order randomness
+            sizes_per_group = create_sizes_per_group(sorted(groups), sequences_per_group, random_seed=args.subsample_seed)
+
         # Make a second pass through the metadata.
         metadata_reader = read_metadata(
             args.metadata,
             id_columns=args.metadata_id_columns,
             chunk_size=args.metadata_chunk_size,
         )
@@ -1446,8 +1462,9 @@ def run(args):
                 for group in unknown_groups:
                     metadata_copy[group] = 'unknown'
             # apply priorities
-            poisson_max = 5  # TODO: get this from first pass along with sequences_per_group
-            int_group_size = sequences_per_group if sequences_per_group >= 1 else poisson_max
+            int_group_size = sequences_per_group
+            if probabilistic_used:
+                int_group_size = max(sizes_per_group.values())
             # get index of dataframe with top n priority per group, breaking ties by last occurrence in metadata
             chunk_top_priorities_index = (
                 metadata_copy.groupby(group_by, sort=False)['priority']
@@ -1471,12 +1488,7 @@ def run(args):
                 # get prioritized strains
                 prioritized_metadata = prioritized_metadata.loc[global_top_priorities_index]
 
-        # probabilistic subsampling
-        if sequences_per_group < 1:
-            groups = prioritized_metadata.groupby(group_by).groups.keys()
-            if len(group_by) == 1:
-                groups = [(x,) for x in groups]
-            sizes_per_group = create_sizes_per_group(groups, sequences_per_group, random_seed=args.subsample_seed)
+        if probabilistic_used:
             prioritized_metadata['group'] = list(zip(*[prioritized_metadata[col] for col in group_by]))
             prioritized_metadata['group_size'] = prioritized_metadata['group'].map(sizes_per_group)
             prioritized_metadata['group_cumcount'] = prioritized_metadata.sort_values('priority', ascending=False).groupby(group_by + ['group_size']).cumcount()

diff --git a/tests/test_filter_groupby.py b/tests/test_filter_groupby.py
index df98fc358..60ddc1961 100644
--- a/tests/test_filter_groupby.py
+++ b/tests/test_filter_groupby.py
@@ -18,7 +18,7 @@ class TestFilterGroupBy:
     def test_filter_groupby_strain_subset(self, valid_metadata: pd.DataFrame):
         metadata = valid_metadata.copy()
         strains = ['SEQ_1', 'SEQ_3', 'SEQ_5']
-        group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata)
+        groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata)
         assert group_by_strain == {
             'SEQ_1': ('_dummy',),
             'SEQ_3': ('_dummy',),
@@ -29,7 +29,7 @@ class TestFilterGroupBy:
     def test_filter_groupby_dummy(self, valid_metadata: pd.DataFrame):
         metadata = valid_metadata.copy()
         strains = metadata.index.tolist()
-        group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata)
+        groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata)
         assert group_by_strain == {
             'SEQ_1': ('_dummy',),
'SEQ_2': ('_dummy',), @@ -51,7 +51,7 @@ def test_filter_groupby_invalid_warn(self, valid_metadata: pd.DataFrame, capsys) groups = ['country', 'year', 'month', 'invalid'] metadata = valid_metadata.copy() strains = metadata.index.tolist() - group_by_strain, _ = get_groups_for_subsampling(strains, metadata, group_by=groups) + groups, group_by_strain, _ = get_groups_for_subsampling(strains, metadata, group_by=groups) assert group_by_strain == { 'SEQ_1': ('A', 2020, (2020, 1), 'unknown'), 'SEQ_2': ('A', 2020, (2020, 2), 'unknown'), @@ -67,7 +67,7 @@ def test_filter_groupby_skip_ambiguous_year(self, valid_metadata: pd.DataFrame): metadata = valid_metadata.copy() metadata.at["SEQ_2", "date"] = "XXXX-02-01" strains = metadata.index.tolist() - group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) + groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) assert group_by_strain == { 'SEQ_1': ('A', 2020, (2020, 1)), 'SEQ_3': ('B', 2020, (2020, 3)), @@ -81,7 +81,7 @@ def test_filter_groupby_skip_missing_date(self, valid_metadata: pd.DataFrame): metadata = valid_metadata.copy() metadata.at["SEQ_2", "date"] = None strains = metadata.index.tolist() - group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) + groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) assert group_by_strain == { 'SEQ_1': ('A', 2020, (2020, 1)), 'SEQ_3': ('B', 2020, (2020, 3)), @@ -95,7 +95,7 @@ def test_filter_groupby_skip_ambiguous_month(self, valid_metadata: pd.DataFrame) metadata = valid_metadata.copy() metadata.at["SEQ_2", "date"] = "2020-XX-01" strains = metadata.index.tolist() - group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) + groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) assert group_by_strain == { 'SEQ_1': ('A', 2020, (2020, 1)), 'SEQ_3': ('B', 2020, (2020, 3)), @@ -109,7 +109,7 @@ def test_filter_groupby_skip_missing_month(self, valid_metadata: pd.DataFrame): metadata = valid_metadata.copy() metadata.at["SEQ_2", "date"] = "2020" strains = metadata.index.tolist() - group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) + groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) assert group_by_strain == { 'SEQ_1': ('A', 2020, (2020, 1)), 'SEQ_3': ('B', 2020, (2020, 3)), @@ -150,7 +150,7 @@ def test_filter_groupby_missing_date_warn(self, valid_metadata: pd.DataFrame, ca metadata = valid_metadata.copy() metadata = metadata.drop('date', axis='columns') strains = metadata.index.tolist() - group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) + groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) assert group_by_strain == { 'SEQ_1': ('A', 'unknown', 'unknown'), 'SEQ_2': ('A', 'unknown', 'unknown'), @@ -166,6 +166,6 @@ def test_filter_groupby_no_strains(self, valid_metadata: pd.DataFrame): groups = ['country', 'year', 'month'] metadata = valid_metadata.copy() strains = [] - group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) + groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) assert group_by_strain == {} assert skipped_strains == [] From 
9cf22641cff214299a817049cd0d341460fc300a Mon Sep 17 00:00:00 2001 From: Victor Lin <13424970+victorlin@users.noreply.github.com> Date: Fri, 10 Dec 2021 16:19:37 -0800 Subject: [PATCH 4/8] update tests for new value returned from get_groups_for_subsampling --- augur/filter.py | 6 ++++++ tests/test_filter_groupby.py | 18 +++++++++--------- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/augur/filter.py b/augur/filter.py index 1ee94d1c3..03b1c37df 100644 --- a/augur/filter.py +++ b/augur/filter.py @@ -888,6 +888,12 @@ def get_groups_for_subsampling(strains, metadata, group_by=None): >>> skipped_strains [{'strain': 'strain1', 'filter': 'skip_group_by_with_ambiguous_month', 'kwargs': ''}] + Distinct groups are returned as a set. + + >>> metadata = pd.DataFrame([{"strain": "strain1", "date": "2020", "region": "Africa"}, {"strain": "strain2", "date": "2020-02-01", "region": "Europe"}]).set_index("strain") + >>> groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, ["region"]) + >>> list(sorted(groups)) + [('Africa',), ('Europe',)] """ metadata = metadata.loc[strains] group_values = set() diff --git a/tests/test_filter_groupby.py b/tests/test_filter_groupby.py index cc91518ee..5f355ebce 100644 --- a/tests/test_filter_groupby.py +++ b/tests/test_filter_groupby.py @@ -18,7 +18,7 @@ class TestFilterGroupBy: def test_filter_groupby_strain_subset(self, valid_metadata: pd.DataFrame): metadata = valid_metadata.copy() strains = ['SEQ_1', 'SEQ_3', 'SEQ_5'] - groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata) + _, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata) assert group_by_strain == { 'SEQ_1': ('_dummy',), 'SEQ_3': ('_dummy',), @@ -29,7 +29,7 @@ def test_filter_groupby_strain_subset(self, valid_metadata: pd.DataFrame): def test_filter_groupby_dummy(self, valid_metadata: pd.DataFrame): metadata = valid_metadata.copy() strains = metadata.index.tolist() - groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata) + _, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata) assert group_by_strain == { 'SEQ_1': ('_dummy',), 'SEQ_2': ('_dummy',), @@ -51,7 +51,7 @@ def test_filter_groupby_invalid_warn(self, valid_metadata: pd.DataFrame, capsys) groups = ['country', 'year', 'month', 'invalid'] metadata = valid_metadata.copy() strains = metadata.index.tolist() - groups, group_by_strain, _ = get_groups_for_subsampling(strains, metadata, group_by=groups) + _, group_by_strain, _ = get_groups_for_subsampling(strains, metadata, group_by=groups) assert group_by_strain == { 'SEQ_1': ('A', 2020, (2020, 1), 'unknown'), 'SEQ_2': ('A', 2020, (2020, 2), 'unknown'), @@ -67,7 +67,7 @@ def test_filter_groupby_skip_ambiguous_year(self, valid_metadata: pd.DataFrame): metadata = valid_metadata.copy() metadata.at["SEQ_2", "date"] = "XXXX-02-01" strains = metadata.index.tolist() - groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) + _, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) assert group_by_strain == { 'SEQ_1': ('A', 2020, (2020, 1)), 'SEQ_3': ('B', 2020, (2020, 3)), @@ -81,7 +81,7 @@ def test_filter_groupby_skip_missing_date(self, valid_metadata: pd.DataFrame): metadata = valid_metadata.copy() metadata.at["SEQ_2", "date"] = None strains = metadata.index.tolist() - groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, 
metadata, group_by=groups) + _, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) assert group_by_strain == { 'SEQ_1': ('A', 2020, (2020, 1)), 'SEQ_3': ('B', 2020, (2020, 3)), @@ -95,7 +95,7 @@ def test_filter_groupby_skip_ambiguous_month(self, valid_metadata: pd.DataFrame) metadata = valid_metadata.copy() metadata.at["SEQ_2", "date"] = "2020-XX-01" strains = metadata.index.tolist() - groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) + _, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) assert group_by_strain == { 'SEQ_1': ('A', 2020, (2020, 1)), 'SEQ_3': ('B', 2020, (2020, 3)), @@ -109,7 +109,7 @@ def test_filter_groupby_skip_missing_month(self, valid_metadata: pd.DataFrame): metadata = valid_metadata.copy() metadata.at["SEQ_2", "date"] = "2020" strains = metadata.index.tolist() - groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) + _, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) assert group_by_strain == { 'SEQ_1': ('A', 2020, (2020, 1)), 'SEQ_3': ('B', 2020, (2020, 3)), @@ -150,7 +150,7 @@ def test_filter_groupby_missing_date_warn(self, valid_metadata: pd.DataFrame, ca metadata = valid_metadata.copy() metadata = metadata.drop('date', axis='columns') strains = metadata.index.tolist() - groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) + _, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) assert group_by_strain == { 'SEQ_1': ('A', 'unknown', 'unknown'), 'SEQ_2': ('A', 'unknown', 'unknown'), @@ -166,7 +166,7 @@ def test_filter_groupby_no_strains(self, valid_metadata: pd.DataFrame): groups = ['country', 'year', 'month'] metadata = valid_metadata.copy() strains = [] - groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) + _, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) assert group_by_strain == {} assert skipped_strains == [] From e01d302c21ea2f761a57f6506f6efebba771d04a Mon Sep 17 00:00:00 2001 From: Victor Lin <13424970+victorlin@users.noreply.github.com> Date: Fri, 10 Dec 2021 17:09:32 -0800 Subject: [PATCH 5/8] remove accidentally committed files with merge commit --- filtered_strains.txt | 0 tmp/filtered_strains-1.txt | 4 ---- tmp/filtered_strains.txt | 4 ---- 3 files changed, 8 deletions(-) delete mode 100644 filtered_strains.txt delete mode 100644 tmp/filtered_strains-1.txt delete mode 100644 tmp/filtered_strains.txt diff --git a/filtered_strains.txt b/filtered_strains.txt deleted file mode 100644 index e69de29bb..000000000 diff --git a/tmp/filtered_strains-1.txt b/tmp/filtered_strains-1.txt deleted file mode 100644 index 6d8ae11eb..000000000 --- a/tmp/filtered_strains-1.txt +++ /dev/null @@ -1,4 +0,0 @@ -PRVABC59 -COL/FLR_00008/2015 -ZKC2/2016 -VEN/UF_1/2016 diff --git a/tmp/filtered_strains.txt b/tmp/filtered_strains.txt deleted file mode 100644 index f321ba4fd..000000000 --- a/tmp/filtered_strains.txt +++ /dev/null @@ -1,4 +0,0 @@ -PRVABC59 -ZKC2/2016 -VEN/UF_1/2016 -BRA/2016/FC_6706 From 897d00e46d72dce874818f44303d3c5aed85bcc4 Mon Sep 17 00:00:00 2001 From: John Huddleston Date: Fri, 17 Dec 2021 11:59:26 -0800 Subject: [PATCH 6/8] Add test for grouping by month alone This test currently fails with a pandas-specific 
index error.
---
 tests/functional/filter_groupby.t | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/tests/functional/filter_groupby.t b/tests/functional/filter_groupby.t
index 0aec6637b..a9e5087b2 100644
--- a/tests/functional/filter_groupby.t
+++ b/tests/functional/filter_groupby.t
@@ -108,6 +108,20 @@ Try grouping with year only
   \t0 of these were dropped because of subsampling criteria, using seed 314159 (esc)
   10 strains passed all filters
 
+Try grouping with month only
+
+  $ ${AUGUR} filter \
+  > --metadata filter/metadata.tsv \
+  > --group-by month \
+  > --sequences-per-group 10 \
+  > --subsample-seed 314159 \
+  > --output-strains "$TMP/filtered_strains.txt"
+  2 strains were dropped during filtering
+  \t1 were dropped during grouping due to ambiguous year information (esc)
+  \t1 were dropped during grouping due to ambiguous month information (esc)
+  \t0 of these were dropped because of subsampling criteria, using seed 314159 (esc)
+  10 strains passed all filters
+
 Try grouping with year, month
 
   $ ${AUGUR} filter \

From 966da1d07e2e7809fe34e8e709680fcdc609176a Mon Sep 17 00:00:00 2001
From: John Huddleston
Date: Fri, 17 Dec 2021 12:24:12 -0800
Subject: [PATCH 7/8] Implicitly group by year and month for month group

Instead of calculating a new (year, month) tuple when users group by
month, add a "year" key to the list of group fields. This fixes a pandas
indexing bug where calling `nlargest` on a SeriesGroupBy object whose
month keys are (year, month) tuples causes pandas to treat the single
month key as a MultiIndex that should be a list.

Although this commit is motivated by that pandas issue, the resulting
year/month disambiguation is also simpler and more idiomatic pandas than
what was possible in the original augur filter implementation (before we
switched to pandas for metadata parsing).
---
 augur/filter.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/augur/filter.py b/augur/filter.py
index 03b1c37df..bdbc8b6c6 100644
--- a/augur/filter.py
+++ b/augur/filter.py
@@ -991,8 +991,7 @@ def expand_date_col(metadata: pd.DataFrame, group_by_set: set) -> Tuple[pd.DataF
                 "filter": "skip_group_by_with_ambiguous_month",
                 "kwargs": "",
             })
-    # month = (year, month)
-    metadata_new['month'] = list(zip(metadata_new['year'], metadata_new['month']))
+
     # TODO: support group by day
     return metadata_new, skipped_strains
 
@@ -1279,6 +1278,11 @@ def run(args):
         random_generator = np.random.default_rng(args.subsample_seed)
         priorities = defaultdict(random_generator.random)
 
+    # When grouping by month, we implicitly group by year and month to avoid
+    # grouping across years meaninglessly.
+    if len(group_by) == 1 and group_by[0] == "month":
+        group_by = ["year", "month"]
+
     # Setup metadata output. We track whether any records have been written to
     # disk yet through the following variables, to control whether we write the
     # metadata's header and open a new file for writing.

From eea96fb4d85e1dd9554ee788591c848577e97690 Mon Sep 17 00:00:00 2001
From: John Huddleston
Date: Fri, 17 Dec 2021 12:48:49 -0800
Subject: [PATCH 8/8] Update unit and doc tests to match new month group

Simplifies unit tests and doctests by expecting a single value for each
month instead of a tuple.
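For illustration (toy data, not taken from the test suite): with this change, grouping by month produces scalar keys alongside an implicit year column instead of nested tuples:

    import pandas as pd

    meta = pd.DataFrame(
        {"year": [2019, 2020], "month": [1, 1], "priority": [0.2, 0.9]},
        index=pd.Index(["s1", "s2"], name="strain"),
    )
    # Before: month carried a (year, month) tuple, e.g. (2020, (2020, 1)), and
    # tuple-valued keys tripped up SeriesGroupBy.nlargest (see patch 7).
    # After: a bare "month" grouping expands to ["year", "month"], so January
    # 2019 and January 2020 stay distinct while the keys remain scalars.
    group_by = ["month"]
    if len(group_by) == 1 and group_by[0] == "month":
        group_by = ["year", "month"]
    top = meta.groupby(group_by, sort=False)["priority"].nlargest(1)
    # top's index -> (2019, 1, 's1') and (2020, 1, 's2')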
--- augur/filter.py | 7 ++--- tests/test_filter_groupby.py | 52 ++++++++++++++++++------------------ 2 files changed, 30 insertions(+), 29 deletions(-) diff --git a/augur/filter.py b/augur/filter.py index bdbc8b6c6..fbf109fd2 100644 --- a/augur/filter.py +++ b/augur/filter.py @@ -843,7 +843,7 @@ def get_groups_for_subsampling(strains, metadata, group_by=None): >>> group_by = ["year", "month"] >>> groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by) >>> group_by_strain - {'strain1': (2020, (2020, 1)), 'strain2': (2020, (2020, 2))} + {'strain1': (2020, 1), 'strain2': (2020, 2)} If we omit the grouping columns, the result will group by a dummy column. @@ -865,7 +865,7 @@ def get_groups_for_subsampling(strains, metadata, group_by=None): >>> group_by = ["year", "month", "missing_column"] >>> groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by) >>> group_by_strain - {'strain1': (2020, (2020, 1), 'unknown'), 'strain2': (2020, (2020, 2), 'unknown')} + {'strain1': (2020, 1, 'unknown'), 'strain2': (2020, 2, 'unknown')} If we group by year month and some records don't have that information in their date fields, we should skip those records from the group output and @@ -884,7 +884,7 @@ def get_groups_for_subsampling(strains, metadata, group_by=None): >>> metadata = pd.DataFrame([{"strain": "strain1", "date": "2020", "region": "Africa"}, {"strain": "strain2", "date": "2020-02-01", "region": "Europe"}]).set_index("strain") >>> groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, ["month"]) >>> group_by_strain - {'strain2': ((2020, 2),)} + {'strain2': (2,)} >>> skipped_strains [{'strain': 'strain1', 'filter': 'skip_group_by_with_ambiguous_month', 'kwargs': ''}] @@ -894,6 +894,7 @@ def get_groups_for_subsampling(strains, metadata, group_by=None): >>> groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, ["region"]) >>> list(sorted(groups)) [('Africa',), ('Europe',)] + """ metadata = metadata.loc[strains] group_values = set() diff --git a/tests/test_filter_groupby.py b/tests/test_filter_groupby.py index 5f355ebce..f04b37f98 100644 --- a/tests/test_filter_groupby.py +++ b/tests/test_filter_groupby.py @@ -53,11 +53,11 @@ def test_filter_groupby_invalid_warn(self, valid_metadata: pd.DataFrame, capsys) strains = metadata.index.tolist() _, group_by_strain, _ = get_groups_for_subsampling(strains, metadata, group_by=groups) assert group_by_strain == { - 'SEQ_1': ('A', 2020, (2020, 1), 'unknown'), - 'SEQ_2': ('A', 2020, (2020, 2), 'unknown'), - 'SEQ_3': ('B', 2020, (2020, 3), 'unknown'), - 'SEQ_4': ('B', 2020, (2020, 4), 'unknown'), - 'SEQ_5': ('B', 2020, (2020, 5), 'unknown') + 'SEQ_1': ('A', 2020, 1, 'unknown'), + 'SEQ_2': ('A', 2020, 2, 'unknown'), + 'SEQ_3': ('B', 2020, 3, 'unknown'), + 'SEQ_4': ('B', 2020, 4, 'unknown'), + 'SEQ_5': ('B', 2020, 5, 'unknown') } captured = capsys.readouterr() assert captured.err == "WARNING: Some of the specified group-by categories couldn't be found: invalid\nFiltering by group may behave differently than expected!\n" @@ -69,10 +69,10 @@ def test_filter_groupby_skip_ambiguous_year(self, valid_metadata: pd.DataFrame): strains = metadata.index.tolist() _, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) assert group_by_strain == { - 'SEQ_1': ('A', 2020, (2020, 1)), - 'SEQ_3': ('B', 2020, (2020, 3)), - 'SEQ_4': ('B', 2020, (2020, 4)), - 'SEQ_5': ('B', 2020, (2020, 5)) + 
'SEQ_1': ('A', 2020, 1), + 'SEQ_3': ('B', 2020, 3), + 'SEQ_4': ('B', 2020, 4), + 'SEQ_5': ('B', 2020, 5) } assert skipped_strains == [{'strain': 'SEQ_2', 'filter': 'skip_group_by_with_ambiguous_year', 'kwargs': ''}] @@ -83,10 +83,10 @@ def test_filter_groupby_skip_missing_date(self, valid_metadata: pd.DataFrame): strains = metadata.index.tolist() _, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) assert group_by_strain == { - 'SEQ_1': ('A', 2020, (2020, 1)), - 'SEQ_3': ('B', 2020, (2020, 3)), - 'SEQ_4': ('B', 2020, (2020, 4)), - 'SEQ_5': ('B', 2020, (2020, 5)) + 'SEQ_1': ('A', 2020, 1), + 'SEQ_3': ('B', 2020, 3), + 'SEQ_4': ('B', 2020, 4), + 'SEQ_5': ('B', 2020, 5) } assert skipped_strains == [{'strain': 'SEQ_2', 'filter': 'skip_group_by_with_ambiguous_year', 'kwargs': ''}] @@ -97,10 +97,10 @@ def test_filter_groupby_skip_ambiguous_month(self, valid_metadata: pd.DataFrame) strains = metadata.index.tolist() _, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) assert group_by_strain == { - 'SEQ_1': ('A', 2020, (2020, 1)), - 'SEQ_3': ('B', 2020, (2020, 3)), - 'SEQ_4': ('B', 2020, (2020, 4)), - 'SEQ_5': ('B', 2020, (2020, 5)) + 'SEQ_1': ('A', 2020, 1), + 'SEQ_3': ('B', 2020, 3), + 'SEQ_4': ('B', 2020, 4), + 'SEQ_5': ('B', 2020, 5) } assert skipped_strains == [{'strain': 'SEQ_2', 'filter': 'skip_group_by_with_ambiguous_month', 'kwargs': ''}] @@ -111,10 +111,10 @@ def test_filter_groupby_skip_missing_month(self, valid_metadata: pd.DataFrame): strains = metadata.index.tolist() _, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) assert group_by_strain == { - 'SEQ_1': ('A', 2020, (2020, 1)), - 'SEQ_3': ('B', 2020, (2020, 3)), - 'SEQ_4': ('B', 2020, (2020, 4)), - 'SEQ_5': ('B', 2020, (2020, 5)) + 'SEQ_1': ('A', 2020, 1), + 'SEQ_3': ('B', 2020, 3), + 'SEQ_4': ('B', 2020, 4), + 'SEQ_5': ('B', 2020, 5) } assert skipped_strains == [{'strain': 'SEQ_2', 'filter': 'skip_group_by_with_ambiguous_month', 'kwargs': ''}] @@ -207,10 +207,10 @@ def test_filter_groupby_only_year_month_provided(self, valid_metadata: pd.DataFr strains = metadata.index.tolist() _, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) assert group_by_strain == { - 'SEQ_1': ('A', 2020, (2020, 1)), - 'SEQ_2': ('A', 2020, (2020, 1)), - 'SEQ_3': ('B', 2020, (2020, 1)), - 'SEQ_4': ('B', 2020, (2020, 1)), - 'SEQ_5': ('B', 2020, (2020, 1)) + 'SEQ_1': ('A', 2020, 1), + 'SEQ_2': ('A', 2020, 1), + 'SEQ_3': ('B', 2020, 1), + 'SEQ_4': ('B', 2020, 1), + 'SEQ_5': ('B', 2020, 1) } assert skipped_strains == []
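Reviewer note (outside the patches): a closing sanity check on the series' central claim, namely that pandas `nlargest` selects the same top-priority strains the removed heap-based `PriorityQueue` did. Toy data, illustrative only:

    import heapq
    import pandas as pd

    priorities = {"s1": 0.1, "s2": 0.9, "s3": 0.5}
    # The removed PriorityQueue kept a fixed-size min-heap of the top items;
    # heapq.nlargest yields the same top-k selection for this comparison.
    heap_top2 = {s for _, s in heapq.nlargest(2, ((p, s) for s, p in priorities.items()))}
    # New approach: vectorized selection over a pandas Series.
    pandas_top2 = set(pd.Series(priorities).nlargest(2, keep="last").index)
    assert heap_top2 == pandas_top2 == {"s2", "s3"}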