diff --git a/augur/filter.py b/augur/filter.py index f4300505a..fbf109fd2 100644 --- a/augur/filter.py +++ b/augur/filter.py @@ -17,7 +17,7 @@ import sys from tempfile import NamedTemporaryFile import treetime.utils -from typing import Collection +from typing import Collection, List, Tuple from .index import index_sequences, index_vcf from .io import open_file, read_metadata, read_sequences, write_sequences @@ -25,6 +25,9 @@ comment_char = '#' +dummy_group = ['_dummy',] +dummy_group_value = ('_dummy',) + SEQUENCE_ONLY_FILTERS = ( "min_length", "non_nucleotide", @@ -818,6 +821,8 @@ def get_groups_for_subsampling(strains, metadata, group_by=None): Returns ------- + set : + A set of all distinct group values. dict : A mapping of strain names to tuples corresponding to the values of the strain's group. list : @@ -826,7 +831,7 @@ def get_groups_for_subsampling(strains, metadata, group_by=None): >>> strains = ["strain1", "strain2"] >>> metadata = pd.DataFrame([{"strain": "strain1", "date": "2020-01-01", "region": "Africa"}, {"strain": "strain2", "date": "2020-02-01", "region": "Europe"}]).set_index("strain") >>> group_by = ["region"] - >>> group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by) + >>> groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by) >>> group_by_strain {'strain1': ('Africa',), 'strain2': ('Europe',)} >>> skipped_strains @@ -836,13 +841,13 @@ def get_groups_for_subsampling(strains, metadata, group_by=None): string. >>> group_by = ["year", "month"] - >>> group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by) + >>> groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by) >>> group_by_strain - {'strain1': (2020, (2020, 1)), 'strain2': (2020, (2020, 2))} + {'strain1': (2020, 1), 'strain2': (2020, 2)} If we omit the grouping columns, the result will group by a dummy column. - >>> group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata) + >>> groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata) >>> group_by_strain {'strain1': ('_dummy',), 'strain2': ('_dummy',)} @@ -858,16 +863,16 @@ def get_groups_for_subsampling(strains, metadata, group_by=None): grouping to continue and print a warning message to stderr. >>> group_by = ["year", "month", "missing_column"] - >>> group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by) + >>> groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by) >>> group_by_strain - {'strain1': (2020, (2020, 1), 'unknown'), 'strain2': (2020, (2020, 2), 'unknown')} + {'strain1': (2020, 1, 'unknown'), 'strain2': (2020, 2, 'unknown')} If we group by year month and some records don't have that information in their date fields, we should skip those records from the group output and track which records were skipped for which reasons. 
>>> metadata = pd.DataFrame([{"strain": "strain1", "date": "", "region": "Africa"}, {"strain": "strain2", "date": "2020-02-01", "region": "Europe"}]).set_index("strain") - >>> group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, ["year"]) + >>> groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, ["year"]) >>> group_by_strain {'strain2': (2020,)} >>> skipped_strains @@ -877,23 +882,32 @@ def get_groups_for_subsampling(strains, metadata, group_by=None): month information in their date fields. >>> metadata = pd.DataFrame([{"strain": "strain1", "date": "2020", "region": "Africa"}, {"strain": "strain2", "date": "2020-02-01", "region": "Europe"}]).set_index("strain") - >>> group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, ["month"]) + >>> groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, ["month"]) >>> group_by_strain - {'strain2': ((2020, 2),)} + {'strain2': (2,)} >>> skipped_strains [{'strain': 'strain1', 'filter': 'skip_group_by_with_ambiguous_month', 'kwargs': ''}] + Distinct groups are returned as a set. + + >>> metadata = pd.DataFrame([{"strain": "strain1", "date": "2020", "region": "Africa"}, {"strain": "strain2", "date": "2020-02-01", "region": "Europe"}]).set_index("strain") + >>> groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, ["region"]) + >>> list(sorted(groups)) + [('Africa',), ('Europe',)] + """ metadata = metadata.loc[strains] + group_values = set() group_by_strain = {} skipped_strains = [] if metadata.empty: - return group_by_strain, skipped_strains + return group_values, group_by_strain, skipped_strains - if not group_by or group_by == ('_dummy',): - group_by_strain = {strain: ('_dummy',) for strain in strains} - return group_by_strain, skipped_strains + if not group_by or group_by == dummy_group: + group_values = set(dummy_group_value) + group_by_strain = {strain: dummy_group_value for strain in strains} + return group_values, group_by_strain, skipped_strains group_by_set = set(group_by) @@ -912,39 +926,7 @@ def get_groups_for_subsampling(strains, metadata, group_by=None): df_dates = pd.DataFrame({'year': 'unknown', 'month': 'unknown'}, index=metadata.index) metadata = pd.concat([metadata, df_dates], axis=1) else: - # replace date with year/month/day as nullable ints - date_cols = ['year', 'month', 'day'] - df_dates = metadata['date'].str.split('-', n=2, expand=True) - df_dates = df_dates.set_axis(date_cols[:len(df_dates.columns)], axis=1) - missing_date_cols = set(date_cols) - set(df_dates.columns) - for col in missing_date_cols: - df_dates[col] = pd.NA - for col in date_cols: - df_dates[col] = pd.to_numeric(df_dates[col], errors='coerce').astype(pd.Int64Dtype()) - metadata = pd.concat([metadata.drop('date', axis=1), df_dates], axis=1) - if 'year' in group_by_set: - # skip ambiguous years - df_skip = metadata[metadata['year'].isnull()] - metadata.dropna(subset=['year'], inplace=True) - for strain in df_skip.index: - skipped_strains.append({ - "strain": strain, - "filter": "skip_group_by_with_ambiguous_year", - "kwargs": "", - }) - if 'month' in group_by_set: - # skip ambiguous months - df_skip = metadata[metadata['month'].isnull()] - metadata.dropna(subset=['month'], inplace=True) - for strain in df_skip.index: - skipped_strains.append({ - "strain": strain, - "filter": "skip_group_by_with_ambiguous_month", - "kwargs": "", - }) - # month = (year, month) - metadata['month'] = list(zip(metadata['year'], 
metadata['month'])) - # TODO: support group by day + metadata, skipped_strains = expand_date_col(metadata, group_by_set) unknown_groups = group_by_set - set(metadata.columns) if unknown_groups: @@ -953,119 +935,100 @@ def get_groups_for_subsampling(strains, metadata, group_by=None): for group in unknown_groups: metadata[group] = 'unknown' + group_values = set(metadata.groupby(group_by).groups.keys()) + if len(group_by) == 1: + # force tuple for single column group values + group_values = {(x,) for x in group_values} group_by_strain = dict(zip(metadata.index, metadata[group_by].apply(tuple, axis=1))) - return group_by_strain, skipped_strains - - -class PriorityQueue: - """A priority queue implementation that automatically replaces lower priority - items in the heap with incoming higher priority items. - - Add a single record to a heap with a maximum of 2 records. + return group_values, group_by_strain, skipped_strains - >>> queue = PriorityQueue(max_size=2) - >>> queue.add({"strain": "strain1"}, 0.5) - 1 - - Add another record with a higher priority. The queue should be at its maximum - size. - - >>> queue.add({"strain": "strain2"}, 1.0) - 2 - >>> queue.heap - [(0.5, 0, {'strain': 'strain1'}), (1.0, 1, {'strain': 'strain2'})] - >>> list(queue.get_items()) - [{'strain': 'strain1'}, {'strain': 'strain2'}] - - Add a higher priority record that causes the queue to exceed its maximum - size. The resulting queue should contain the two highest priority records - after the lowest priority record is removed. - - >>> queue.add({"strain": "strain3"}, 2.0) - 2 - >>> list(queue.get_items()) - [{'strain': 'strain2'}, {'strain': 'strain3'}] - Add a record with the same priority as another record, forcing the duplicate - to be resolved by removing the oldest entry. +def expand_date_col(metadata: pd.DataFrame, group_by_set: set) -> Tuple[pd.DataFrame, List[dict]]: + """Expand the date column of a DataFrame into nullable integer year, month, and day columns. - >>> queue.add({"strain": "strain4"}, 1.0) - 2 - >>> list(queue.get_items()) - [{'strain': 'strain4'}, {'strain': 'strain3'}] + Parameters + ---------- + metadata : pandas.DataFrame + Metadata containing a date column. + group_by_set : set + A set of metadata columns to group records by. + Returns + ------- + pandas.DataFrame : + The input metadata with expanded date columns. + list : + A list of dictionaries with strains that were skipped from grouping and the reason why (see also: `apply_filters` output). """ - def __init__(self, max_size): - """Create a fixed size heap (priority queue) - - """ - self.max_size = max_size - self.heap = [] - self.counter = itertools.count() - - def add(self, item, priority): - """Add an item to the queue with a given priority. - - If adding the item causes the queue to exceed its maximum size, replace - the lowest priority item with the given item. The queue stores items - with an additional heap id value (a count) to resolve ties between items - with equal priority (favoring the most recently added item). - - """ - heap_id = next(self.counter) - - if len(self.heap) >= self.max_size: - heapq.heappushpop(self.heap, (priority, heap_id, item)) - else: - heapq.heappush(self.heap, (priority, heap_id, item)) - - return len(self.heap) - - def get_items(self): - """Return each item in the queue in order. - - Yields - ------ - Any - Item stored in the queue.
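A rough illustration of what the new expand_date_col returns (an informal doctest-style sketch, not part of the module's test suite; the md frame below is hypothetical):

>>> import pandas as pd
>>> md = pd.DataFrame({"date": ["2020-01-15", "2020", ""]}, index=pd.Index(["s1", "s2", "s3"], name="strain"))
>>> expanded, skipped = expand_date_col(md, {"year", "month"})
>>> list(expanded.index)
['s1']
>>> [int(expanded.loc["s1", col]) for col in ("year", "month", "day")]
[2020, 1, 15]
>>> [(s["strain"], s["filter"]) for s in skipped]
[('s3', 'skip_group_by_with_ambiguous_year'), ('s2', 'skip_group_by_with_ambiguous_month')]

Records with no parseable year are skipped before records that have a year but no month, which is why s3 precedes s2 in skipped.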
+ metadata_new = metadata.copy() + skipped_strains = [] + # replace date with year/month/day as nullable ints + date_cols = ['year', 'month', 'day'] + df_dates = metadata['date'].str.split('-', n=2, expand=True) + df_dates = df_dates.set_axis(date_cols[:len(df_dates.columns)], axis=1) + missing_date_cols = set(date_cols) - set(df_dates.columns) + for col in missing_date_cols: + df_dates[col] = pd.NA + for col in date_cols: + df_dates[col] = pd.to_numeric(df_dates[col], errors='coerce').astype(pd.Int64Dtype()) + metadata_new = pd.concat([metadata_new.drop('date', axis=1), df_dates], axis=1) + if 'year' in group_by_set: + # skip ambiguous years + df_skip = metadata_new[metadata_new['year'].isnull()] + metadata_new.dropna(subset=['year'], inplace=True) + for strain in df_skip.index: + skipped_strains.append({ + "strain": strain, + "filter": "skip_group_by_with_ambiguous_year", + "kwargs": "", + }) + if 'month' in group_by_set: + # skip ambiguous months + df_skip = metadata_new[metadata_new['month'].isnull()] + metadata_new.dropna(subset=['month'], inplace=True) + for strain in df_skip.index: + skipped_strains.append({ + "strain": strain, + "filter": "skip_group_by_with_ambiguous_month", + "kwargs": "", + }) - """ - for priority, heap_id, item in self.heap: - yield item + # TODO: support group by day + return metadata_new, skipped_strains -def create_queues_by_group(groups, max_size, max_attempts=100, random_seed=None): - """Create a dictionary of priority queues per group for the given maximum size. +def create_sizes_per_group(groups, max_size, max_attempts=100, random_seed=None): + """Create a dictionary of sizes per group for the given maximum size. When the maximum size is fractional, probabilistically sample the maximum size from a Poisson distribution. Make at least the given number of maximum - attempts to create queues for which the sum of their maximum sizes is + attempts to create groups for which the sum of their maximum sizes is greater than zero. - Create queues for two groups with a fixed maximum size. + Create sizes for two groups with a fixed maximum size. >>> groups = ("2015", "2016") - >>> queues = create_queues_by_group(groups, 2) - >>> sum(queue.max_size for queue in queues.values()) + >>> sizes = create_sizes_per_group(groups, 2) + >>> sum(sizes.values()) 4 - Create queues for two groups with a fractional maximum size. Their total max + Create sizes for two groups with a fractional maximum size. Their total max size should still be an integer value greater than zero. >>> seed = 314159 - >>> queues = create_queues_by_group(groups, 0.1, random_seed=seed) - >>> int(sum(queue.max_size for queue in queues.values())) > 0 + >>> sizes = create_sizes_per_group(groups, 0.1, random_seed=seed) + >>> int(sum(sizes.values())) > 0 True A subsequent run of this function with the same groups and random seed - should produce the same queues and queue sizes. + should produce the same sizes. 
- >>> more_queues = create_queues_by_group(groups, 0.1, random_seed=seed) - >>> [queue.max_size for queue in queues.values()] == [queue.max_size for queue in more_queues.values()] + >>> more_sizes = create_sizes_per_group(groups, 0.1, random_seed=seed) + >>> list(sizes.values()) == list(more_sizes.values()) True """ - queues_by_group = {} + size_per_group = {} total_max_size = 0 attempts = 0 @@ -1073,22 +1036,22 @@ def create_queues_by_group(groups, max_size, max_attempts=100, random_seed=None) random_generator = np.random.default_rng(random_seed) # For small fractional maximum sizes, it is possible to randomly select - # maximum queue sizes that all equal zero. When this happens, filtering - # fails unexpectedly. We make multiple attempts to create queues with - # maximum sizes greater than zero for at least one queue. + # maximum sizes that all equal zero. When this happens, filtering + # fails unexpectedly. We make multiple attempts to create sizes with + # maximum sizes greater than zero for at least one group. while total_max_size == 0 and attempts < max_attempts: - for group in sorted(groups): + for group in groups: if max_size < 1.0: - queue_max_size = random_generator.poisson(max_size) + group_max_size = random_generator.poisson(max_size) else: - queue_max_size = max_size + group_max_size = max_size - queues_by_group[group] = PriorityQueue(queue_max_size) + size_per_group[group] = group_max_size - total_max_size = sum(queue.max_size for queue in queues_by_group.values()) + total_max_size = sum(size_per_group.values()) attempts += 1 - return queues_by_group + return size_per_group def register_arguments(parser): @@ -1287,27 +1250,26 @@ def run(args): # group such that we select at most the requested maximum number of # sequences in a single pass through the metadata. # - # Each case relies on a priority queue to track the highest priority records + # Each case relies on a priority DataFrame to track the highest priority records # per group. In the best case, we can track these records in a single pass # through the metadata. In the worst case, we don't know how many sequences # per group to use, so we need to calculate this number after the first pass - # and use a second pass to add records to the queue. + # and use a second pass to add records to the priority DataFrame. group_by = args.group_by + groups = set() sequences_per_group = args.sequences_per_group records_per_group = None - if group_by and args.subsample_max_sequences: - # In this case, we need two passes through the metadata with the first - # pass used to count the number of records per group. + if args.subsample_max_sequences: records_per_group = defaultdict(int) - elif not group_by and args.subsample_max_sequences: - group_by = ("_dummy",) - sequences_per_group = args.subsample_max_sequences + if not group_by: + group_by = dummy_group + sequences_per_group = args.subsample_max_sequences - # If we are grouping data, use queues to store the highest priority strains + # If we are grouping data, use DataFrame to store the highest priority strains # for each group. When no priorities are provided, they will be randomly # generated. - queues_by_group = None + prioritized_metadata = pd.DataFrame() if group_by: # Use user-defined priorities, if possible. Otherwise, setup a # corresponding dictionary that returns a random float for each strain. 
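The fallback priorities mentioned in the comment above are a defaultdict backed by a seeded NumPy generator, so each strain receives a stable pseudo-random priority on first lookup. A minimal doctest-style sketch of that pattern (illustrative only; the strain name is made up):

>>> import numpy as np
>>> from collections import defaultdict
>>> random_generator = np.random.default_rng(314159)
>>> priorities = defaultdict(random_generator.random)
>>> first = priorities["strain1"]
>>> 0.0 <= first < 1.0
True
>>> priorities["strain1"] == first  # the value is cached after the first lookup
True

With a fixed --subsample-seed, two runs that look strains up in the same order assign identical priorities, which keeps subsampling reproducible.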
@@ -1317,6 +1279,11 @@ def run(args): random_generator = np.random.default_rng(args.subsample_seed) priorities = defaultdict(random_generator.random) + # When grouping by month, we implicitly group by year and month to avoid + # grouping across years meaninglessly. + if len(group_by) == 1 and group_by[0] == "month": + group_by = ["year", "month"] + # Setup metadata output. We track whether any records have been written to # disk yet through the following variables, to control whether we write the # metadata's header and open a new file for writing. @@ -1394,12 +1361,14 @@ def run(args): # count the number of records per group. First, we need to get # the groups for the given records. try: - group_by_strain, skipped_strains = get_groups_for_subsampling( + chunk_groups, group_by_strain, skipped_strains = get_groups_for_subsampling( seq_keep, metadata, group_by, ) + groups.update(chunk_groups) + # Track strains skipped during grouping, so users know why those # strains were excluded from the analysis. for skipped_strain in skipped_strains: @@ -1409,32 +1378,13 @@ def run(args): if args.output_log: output_log_writer.writerow(skipped_strain) - if args.subsample_max_sequences and records_per_group is not None: + if args.subsample_max_sequences and group_by: # Count the number of records per group. We will use this # information to calculate the number of sequences per group # for the given maximum number of requested sequences. + # TODO: use pandas logic for group in group_by_strain.values(): records_per_group[group] += 1 - else: - # Track the highest priority records, when we already - # know the number of sequences allowed per group. - if queues_by_group is None: - queues_by_group = {} - - for strain in sorted(group_by_strain.keys()): - # During this first pass, we do not know all possible - # groups will be, so we need to build each group's queue - # as we first encounter the group. - group = group_by_strain[strain] - if group not in queues_by_group: - queues_by_group[group] = PriorityQueue( - max_size=sequences_per_group, - ) - - queues_by_group[group].add( - metadata.loc[strain], - priorities[strain], - ) except FilterException as error: # When we cannot group by the requested columns, we print a # warning to the user and continue without subsampling or @@ -1470,11 +1420,12 @@ def run(args): for strain in strains_to_write: output_strains.write(f"{strain}\n") + probabilistic_used = False # In the worst case, we need to calculate sequences per group from the # requested maximum number of sequences and the number of sequences per # group. Then, we need to make a second pass through the metadata to find # the requested number of records. - if args.subsample_max_sequences and records_per_group is not None: + if args.subsample_max_sequences and group_by: # Calculate sequences per group. If there are more groups than maximum # sequences requested, sequences per group will be a floating point # value and subsampling will be probabilistic. @@ -1490,76 +1441,98 @@ def run(args): if (probabilistic_used): print(f"Sampling probabilistically at {sequences_per_group:0.4f} sequences per group, meaning it is possible to have more than the requested maximum of {args.subsample_max_sequences} sequences after filtering.") - else: + elif group_by != dummy_group: print(f"Sampling at {sequences_per_group} per group.") - if queues_by_group is None: - # We know all of the possible groups now from the first pass through - # the metadata, so we can create queues for all groups at once. 
- queues_by_group = create_queues_by_group( - records_per_group.keys(), - sequences_per_group, - random_seed=args.subsample_seed, - ) + if group_by and valid_strains: + if probabilistic_used: + # sort groups to eliminate set order randomness + sizes_per_group = create_sizes_per_group(sorted(groups), sequences_per_group, random_seed=args.subsample_seed) - # Make a second pass through the metadata, only considering records that - # have passed filters. + # Make a second pass through the metadata. metadata_reader = read_metadata( args.metadata, id_columns=args.metadata_id_columns, chunk_size=args.metadata_chunk_size, ) for metadata in metadata_reader: - # Recalculate groups for subsampling as we loop through the - # metadata a second time. TODO: We could store these in memory - # during the first pass, but we want to minimize overall memory - # usage at the moment. seq_keep = set(metadata.index.values) & valid_strains - group_by_strain, skipped_strains = get_groups_for_subsampling( - seq_keep, - metadata, - group_by, + if not seq_keep: + continue + # create a copy of metadata, only considering records that have passed filters + seq_keep_ordered = [seq for seq in metadata.index if seq in seq_keep] + metadata_copy = metadata.loc[seq_keep_ordered].copy() + # add columns for priority and expanded date + metadata_copy['priority'] = metadata_copy.index.to_series().apply(lambda x: priorities[x]) + metadata_with_dates, _ = expand_date_col(metadata_copy, set(group_by)) # TODO: don't drop date col in this function so we don't need this intermediate var + metadata_copy = (pd.concat([metadata_copy, metadata_with_dates[['year','month','day']]], axis=1) + .reindex(metadata_copy.index)) # not needed after pandas>=1.2.0 https://pandas.pydata.org/docs/whatsnew/v1.2.0.html#index-column-name-preservation-when-aggregating + # add column for dummy group + if group_by == dummy_group: + metadata_copy[dummy_group[0]] = dummy_group_value[0] + # add columns for unknown groups + unknown_groups = set(group_by) - set(metadata_copy.columns) + if unknown_groups: + for group in unknown_groups: + metadata_copy[group] = 'unknown' + # apply priorities + int_group_size = sequences_per_group + if probabilistic_used: + int_group_size = max(sizes_per_group.values()) + # get index of dataframe with top n priority per group, breaking ties by last occurrence in metadata + chunk_top_priorities_index = ( + metadata_copy.groupby(group_by, sort=False)['priority'] + .nlargest(int_group_size, keep='last') + .index.get_level_values('strain') ) + # get prioritized strains + chunk_prioritized_metadata = metadata_copy.loc[chunk_top_priorities_index] - for strain in sorted(group_by_strain.keys()): - group = group_by_strain[strain] - queues_by_group[group].add( - metadata.loc[strain], - priorities[strain], + # update global priority + if prioritized_metadata.empty: + prioritized_metadata = chunk_prioritized_metadata + else: + prioritized_metadata = prioritized_metadata.append(chunk_prioritized_metadata) + # get index of dataframe with top n priority per group, breaking ties by last occurrence in metadata + global_top_priorities_index = ( + prioritized_metadata.groupby(group_by, sort=False)['priority'] + .nlargest(int_group_size, keep='last') + .index.get_level_values('strain') ) + # get prioritized strains + prioritized_metadata = prioritized_metadata.loc[global_top_priorities_index] - # If we have any records in queues, we have grouped results and need to - # stream the highest priority records to the requested outputs. 
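The probabilistic branch added just below caps each group at its Poisson-sampled size by ranking rows within a group and keeping only the top group_size of them. A minimal sketch of that cumcount trick on a toy frame (hypothetical data; the group_size column stands in for the per-group Poisson draw):

>>> import pandas as pd
>>> df = pd.DataFrame({
...     "group": ["A", "A", "A", "B"],
...     "priority": [0.2, 0.9, 0.5, 0.7],
...     "group_size": [2, 2, 2, 1],
... }, index=["s1", "s2", "s3", "s4"])
>>> df["rank"] = df.sort_values("priority", ascending=False).groupby("group").cumcount()
>>> df[df["rank"] < df["group_size"]].index.tolist()
['s2', 's3', 's4']

Sorting by priority first means cumcount assigns 0 to the highest-priority row of each group, so the comparison keeps the best group_size rows per group and drops the rest.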
+ if probabilistic_used: + prioritized_metadata['group'] = list(zip(*[prioritized_metadata[col] for col in group_by])) + prioritized_metadata['group_size'] = prioritized_metadata['group'].map(sizes_per_group) + prioritized_metadata['group_cumcount'] = prioritized_metadata.sort_values('priority', ascending=False).groupby(group_by + ['group_size']).cumcount() + prioritized_metadata = prioritized_metadata[prioritized_metadata['group_cumcount'] < prioritized_metadata['group_size']] + + # drop intermediate columns before writing to output + prioritized_metadata = prioritized_metadata[metadata.columns] + + # If we have any sequences in the prioritized metadata, we have grouped results and need to + # stream the highest priority sequences to the requested outputs. num_excluded_subsamp = 0 - if queues_by_group: - # Populate the set of strains to keep from the records in queues. - subsampled_strains = set() - for group, queue in queues_by_group.items(): - records = [] - for record in queue.get_items(): - # Each record is a pandas.Series instance. Track the name of the - # record, so we can output its sequences later. - subsampled_strains.add(record.name) - - # Construct a data frame of records to simplify metadata output. - records.append(record) - - if args.output_strains: - # TODO: Output strains will no longer be ordered. This is a - # small breaking change. - output_strains.write(f"{record.name}\n") - - # Write records to metadata output, if requested. - if args.output_metadata and len(records) > 0: - records = pd.DataFrame(records) - records.to_csv( - args.output_metadata, - sep="\t", - header=metadata_header, - mode=metadata_mode, - ) - metadata_header = False - metadata_mode = "a" + if not prioritized_metadata.empty: + subsampled_strains = set(prioritized_metadata.index) + + if args.output_strains: + # TODO: Output strains will no longer be ordered. This is a + # small breaking change. + for strain in prioritized_metadata.index: + output_strains.write(f"{strain}\n") + + # Write records to metadata output, if requested. + if args.output_metadata: + prioritized_metadata.to_csv( + args.output_metadata, + sep="\t", + header=metadata_header, + mode=metadata_mode, + ) + metadata_header = False + metadata_mode = "a" # Count and optionally log strains that were not included due to # subsampling. 
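Taken together, the filter.py changes replace the per-group PriorityQueue with a single pandas expression: group the records, take the n largest priorities per group, and select the matching rows. A self-contained sketch of that core selection (toy data; n=2 stands in for sequences_per_group):

>>> import pandas as pd
>>> df = pd.DataFrame({
...     "region": ["Africa", "Africa", "Africa", "Europe"],
...     "priority": [0.9, 0.5, 0.7, 0.3],
... }, index=pd.Index(["s1", "s2", "s3", "s4"], name="strain"))
>>> keep = (
...     df.groupby("region", sort=False)["priority"]
...     .nlargest(2, keep="last")
...     .index.get_level_values("strain")
... )
>>> df.loc[keep].index.tolist()
['s1', 's3', 's4']

keep="last" breaks ties by the last occurrence in the chunk, matching the comments in the patch, and re-applying the same expression after concatenating per-chunk winners yields the global top-n per group.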
diff --git a/tests/functional/filter/metadata_ambiguous_months.tsv b/tests/functional/filter/metadata_ambiguous_months.tsv new file mode 100644 index 000000000..821f0c529 --- /dev/null +++ b/tests/functional/filter/metadata_ambiguous_months.tsv @@ -0,0 +1,11 @@ +strain accession date region country location db authors cluster paper_url title +G22670 10155 1991-XX-XX north_america canada village_k genbank Lee et al Mj-I http://www.pnas.org/content/112/44/13609 Population Genomics of Mycobacterium tuberculosis in the Inuit +G22671 10223 1992-XX-XX north_america canada village_d genbank Lee et al Mj-V.a http://www.pnas.org/content/112/44/13609 Population Genomics of Mycobacterium tuberculosis in the Inuit +G22672 11011 1991-XX-XX north_america canada village_k genbank Lee et al Mj-I http://www.pnas.org/content/112/44/13609 Population Genomics of Mycobacterium tuberculosis in the Inuit +G22673 11234 1992-XX-XX north_america canada village_k genbank Lee et al Mj-I http://www.pnas.org/content/112/44/13609 Population Genomics of Mycobacterium tuberculosis in the Inuit +G22674 14069 1993-XX-XX north_america canada village_k genbank Lee et al Mj-I http://www.pnas.org/content/112/44/13609 Population Genomics of Mycobacterium tuberculosis in the Inuit +G22675 14508 1993-XX-XX north_america canada village_c genbank Lee et al Mj-V.a http://www.pnas.org/content/112/44/13609 Population Genomics of Mycobacterium tuberculosis in the Inuit +G22676 15613 1994-XX-XX north_america canada village_e genbank Lee et al Mj-V.d http://www.pnas.org/content/112/44/13609 Population Genomics of Mycobacterium tuberculosis in the Inuit +G22677 16490 1995-XX-XX north_america canada village_e genbank Lee et al Mj-IV.b http://www.pnas.org/content/112/44/13609 Population Genomics of Mycobacterium tuberculosis in the Inuit +G22678 16493 1995-XX-XX north_america canada village_k genbank Lee et al Mj-I http://www.pnas.org/content/112/44/13609 Population Genomics of Mycobacterium tuberculosis in the Inuit +G22679 18421 1996-XX-XX north_america canada village_k genbank Lee et al Mj-I http://www.pnas.org/content/112/44/13609 Population Genomics of Mycobacterium tuberculosis in the Inuit diff --git a/tests/functional/filter_groupby.t b/tests/functional/filter_groupby.t new file mode 100644 index 000000000..a9e5087b2 --- /dev/null +++ b/tests/functional/filter_groupby.t @@ -0,0 +1,137 @@ +Integration tests for grouping features in augur filter. + + $ pushd "$TESTDIR" > /dev/null + $ export AUGUR="../../bin/augur" + +Try simple grouping. + + $ ${AUGUR} filter \ + > --metadata filter/metadata.tsv \ + > --group-by country year month \ + > --subsample-max-sequences 10 \ + > --subsample-seed 314159 \ + > --output-strains "$TMP/filtered_strains.txt" \ + > --output-metadata "$TMP/filtered_metadata.tsv" + Sampling at 10 per group. 
+ 2 strains were dropped during filtering + \t1 were dropped during grouping due to ambiguous year information (esc) + \t1 were dropped during grouping due to ambiguous month information (esc) + \t0 of these were dropped because of subsampling criteria, using seed 314159 (esc) + 10 strains passed all filters + $ wc -l "$TMP/filtered_strains.txt" + \s*10 .* (re) + $ head -n 2 "$TMP/filtered_metadata.tsv" + strain\tvirus\taccession\tdate\tregion\tcountry\tdivision\tcity\tdb\tsegment\tauthors\turl\ttitle\tjournal\tpaper_url (esc) + PRVABC59\tzika\tKU501215\t2015-12-XX\tNorth America\tPuerto Rico\tPuerto Rico\tPuerto Rico\tgenbank\tgenome\tLanciotti et al\thttps://www.ncbi.nlm.nih.gov/nuccore/KU501215\tPhylogeny of Zika Virus in Western Hemisphere, 2015\tEmerging Infect. Dis. 22 (5), 933-935 (2016)\thttps://www.ncbi.nlm.nih.gov/pubmed/27088323 (esc) + +Try subsample without any groups. + + $ ${AUGUR} filter \ + > --metadata filter/metadata.tsv \ + > --subsample-max-sequences 10 \ + > --subsample-seed 314159 \ + > --output-strains "$TMP/filtered_strains.txt" \ + > --output-metadata "$TMP/filtered_metadata.tsv" + 2 strains were dropped during filtering + \t2 of these were dropped because of subsampling criteria, using seed 314159 (esc) + 10 strains passed all filters + $ wc -l "$TMP/filtered_strains.txt" + \s*10 .* (re) + $ head -n 2 "$TMP/filtered_metadata.tsv" + strain\tvirus\taccession\tdate\tregion\tcountry\tdivision\tcity\tdb\tsegment\tauthors\turl\ttitle\tjournal\tpaper_url (esc) + COL/FLR_00024/2015\tzika\tMF574569\t\tSouth America\tColombia\tColombia\tColombia\tgenbank\tgenome\tPickett et al\thttps://www.ncbi.nlm.nih.gov/nuccore/MF574569\tDirect Submission\tSubmitted (28-JUL-2017) J. Craig Venter Institute, 9704 Medical Center Drive, Rockville, MD 20850, USA\thttps://www.ncbi.nlm.nih.gov/pubmed/ (esc) + +Try grouping by an unknown column. +This should warn then continue without grouping. + + $ ${AUGUR} filter \ + > --metadata filter/metadata.tsv \ + > --exclude-where "region=South America" "region=North America" "region=Southeast Asia" \ + > --include-where "country=Ecuador" \ + > --group-by invalid \ + > --subsample-max-sequences 10 \ + > --subsample-seed 314159 \ + > --output-metadata "$TMP/filtered_metadata.tsv" + WARNING: The specified group-by categories (['invalid']) were not found. No sequences-per-group sampling will be done. + 10 strains were dropped during filtering + \t6 of these were dropped because of 'region=South America' (esc) + \t4 of these were dropped because of 'region=North America' (esc) + \t1 of these were dropped because of 'region=Southeast Asia' (esc) + \t1 sequences were added back because of 'country=Ecuador' (esc) + \t0 of these were dropped because of subsampling criteria, using seed 314159 (esc) + 2 strains passed all filters + $ wc -l "$TMP/filtered_metadata.tsv" + \s*3 .* (re) + +Try grouping by an unknown column and a valid column. +This should warn then continue with grouping by valid column only. + + $ ${AUGUR} filter \ + > --metadata filter/metadata.tsv \ + > --group-by country invalid \ + > --subsample-max-sequences 10 \ + > --subsample-seed 314159 \ + > --output-metadata "$TMP/filtered_metadata.tsv" + WARNING: Some of the specified group-by categories couldn't be found: invalid + Filtering by group may behave differently than expected! + Sampling at 1 per group. 
+ 3 strains were dropped during filtering + \t3 of these were dropped because of subsampling criteria, using seed 314159 (esc) + 9 strains passed all filters + $ wc -l "$TMP/filtered_metadata.tsv" + \s*10 .* (re) + +Try grouping with no probabilistic sampling + + $ ${AUGUR} filter \ + > --metadata filter/metadata.tsv \ + > --group-by country year month \ + > --sequences-per-group 1 \ + > --subsample-seed 314159 \ + > --no-probabilistic-sampling \ + > --output-strains "$TMP/filtered_strains.txt" \ + > --output-metadata "$TMP/filtered_metadata.tsv" > /dev/null + $ wc -l "$TMP/filtered_strains.txt" + \s*9 .* (re) + $ wc -l "$TMP/filtered_metadata.tsv" + \s*10 .* (re) + +Try grouping with year only + + $ ${AUGUR} filter \ + > --metadata filter/metadata_ambiguous_months.tsv \ + > --group-by year \ + > --sequences-per-group 10 \ + > --subsample-seed 314159 \ + > --output-strains "$TMP/filtered_strains.txt" + 0 strains were dropped during filtering + \t0 of these were dropped because of subsampling criteria, using seed 314159 (esc) + 10 strains passed all filters + +Try grouping with month only + + $ ${AUGUR} filter \ + > --metadata filter/metadata.tsv \ + > --group-by month \ + > --sequences-per-group 10 \ + > --subsample-seed 314159 \ + > --output-strains "$TMP/filtered_strains.txt" + 2 strains were dropped during filtering + \t1 were dropped during grouping due to ambiguous year information (esc) + \t1 were dropped during grouping due to ambiguous month information (esc) + \t0 of these were dropped because of subsampling criteria, using seed 314159 (esc) + 10 strains passed all filters + +Try grouping with year, month + + $ ${AUGUR} filter \ + > --metadata filter/metadata_ambiguous_months.tsv \ + > --group-by year month \ + > --sequences-per-group 10 \ + > --subsample-seed 314159 \ + > --output-strains "$TMP/filtered_strains.txt" + ERROR: All samples have been dropped! Check filter rules and metadata file format. 
+ 10 strains were dropped during filtering + \t10 were dropped during grouping due to ambiguous month information (esc) + \t0 of these were dropped because of subsampling criteria, using seed 314159 (esc) + [1] diff --git a/tests/test_filter_groupby.py b/tests/test_filter_groupby.py index 2793722c8..f04b37f98 100644 --- a/tests/test_filter_groupby.py +++ b/tests/test_filter_groupby.py @@ -18,7 +18,7 @@ class TestFilterGroupBy: def test_filter_groupby_strain_subset(self, valid_metadata: pd.DataFrame): metadata = valid_metadata.copy() strains = ['SEQ_1', 'SEQ_3', 'SEQ_5'] - group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata) + _, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata) assert group_by_strain == { 'SEQ_1': ('_dummy',), 'SEQ_3': ('_dummy',), @@ -29,7 +29,7 @@ def test_filter_groupby_strain_subset(self, valid_metadata: pd.DataFrame): def test_filter_groupby_dummy(self, valid_metadata: pd.DataFrame): metadata = valid_metadata.copy() strains = metadata.index.tolist() - group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata) + _, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata) assert group_by_strain == { 'SEQ_1': ('_dummy',), 'SEQ_2': ('_dummy',), @@ -51,13 +51,13 @@ def test_filter_groupby_invalid_warn(self, valid_metadata: pd.DataFrame, capsys) groups = ['country', 'year', 'month', 'invalid'] metadata = valid_metadata.copy() strains = metadata.index.tolist() - group_by_strain, _ = get_groups_for_subsampling(strains, metadata, group_by=groups) + _, group_by_strain, _ = get_groups_for_subsampling(strains, metadata, group_by=groups) assert group_by_strain == { - 'SEQ_1': ('A', 2020, (2020, 1), 'unknown'), - 'SEQ_2': ('A', 2020, (2020, 2), 'unknown'), - 'SEQ_3': ('B', 2020, (2020, 3), 'unknown'), - 'SEQ_4': ('B', 2020, (2020, 4), 'unknown'), - 'SEQ_5': ('B', 2020, (2020, 5), 'unknown') + 'SEQ_1': ('A', 2020, 1, 'unknown'), + 'SEQ_2': ('A', 2020, 2, 'unknown'), + 'SEQ_3': ('B', 2020, 3, 'unknown'), + 'SEQ_4': ('B', 2020, 4, 'unknown'), + 'SEQ_5': ('B', 2020, 5, 'unknown') } captured = capsys.readouterr() assert captured.err == "WARNING: Some of the specified group-by categories couldn't be found: invalid\nFiltering by group may behave differently than expected!\n" @@ -67,12 +67,12 @@ def test_filter_groupby_skip_ambiguous_year(self, valid_metadata: pd.DataFrame): metadata = valid_metadata.copy() metadata.at["SEQ_2", "date"] = "XXXX-02-01" strains = metadata.index.tolist() - group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) + _, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) assert group_by_strain == { - 'SEQ_1': ('A', 2020, (2020, 1)), - 'SEQ_3': ('B', 2020, (2020, 3)), - 'SEQ_4': ('B', 2020, (2020, 4)), - 'SEQ_5': ('B', 2020, (2020, 5)) + 'SEQ_1': ('A', 2020, 1), + 'SEQ_3': ('B', 2020, 3), + 'SEQ_4': ('B', 2020, 4), + 'SEQ_5': ('B', 2020, 5) } assert skipped_strains == [{'strain': 'SEQ_2', 'filter': 'skip_group_by_with_ambiguous_year', 'kwargs': ''}] @@ -81,12 +81,12 @@ def test_filter_groupby_skip_missing_date(self, valid_metadata: pd.DataFrame): metadata = valid_metadata.copy() metadata.at["SEQ_2", "date"] = None strains = metadata.index.tolist() - group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) + _, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) assert group_by_strain == { - 
'SEQ_1': ('A', 2020, (2020, 1)), - 'SEQ_3': ('B', 2020, (2020, 3)), - 'SEQ_4': ('B', 2020, (2020, 4)), - 'SEQ_5': ('B', 2020, (2020, 5)) + 'SEQ_1': ('A', 2020, 1), + 'SEQ_3': ('B', 2020, 3), + 'SEQ_4': ('B', 2020, 4), + 'SEQ_5': ('B', 2020, 5) } assert skipped_strains == [{'strain': 'SEQ_2', 'filter': 'skip_group_by_with_ambiguous_year', 'kwargs': ''}] @@ -95,12 +95,12 @@ def test_filter_groupby_skip_ambiguous_month(self, valid_metadata: pd.DataFrame) metadata = valid_metadata.copy() metadata.at["SEQ_2", "date"] = "2020-XX-01" strains = metadata.index.tolist() - group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) + _, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) assert group_by_strain == { - 'SEQ_1': ('A', 2020, (2020, 1)), - 'SEQ_3': ('B', 2020, (2020, 3)), - 'SEQ_4': ('B', 2020, (2020, 4)), - 'SEQ_5': ('B', 2020, (2020, 5)) + 'SEQ_1': ('A', 2020, 1), + 'SEQ_3': ('B', 2020, 3), + 'SEQ_4': ('B', 2020, 4), + 'SEQ_5': ('B', 2020, 5) } assert skipped_strains == [{'strain': 'SEQ_2', 'filter': 'skip_group_by_with_ambiguous_month', 'kwargs': ''}] @@ -109,12 +109,12 @@ def test_filter_groupby_skip_missing_month(self, valid_metadata: pd.DataFrame): metadata = valid_metadata.copy() metadata.at["SEQ_2", "date"] = "2020" strains = metadata.index.tolist() - group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) + _, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) assert group_by_strain == { - 'SEQ_1': ('A', 2020, (2020, 1)), - 'SEQ_3': ('B', 2020, (2020, 3)), - 'SEQ_4': ('B', 2020, (2020, 4)), - 'SEQ_5': ('B', 2020, (2020, 5)) + 'SEQ_1': ('A', 2020, 1), + 'SEQ_3': ('B', 2020, 3), + 'SEQ_4': ('B', 2020, 4), + 'SEQ_5': ('B', 2020, 5) } assert skipped_strains == [{'strain': 'SEQ_2', 'filter': 'skip_group_by_with_ambiguous_month', 'kwargs': ''}] @@ -150,7 +150,7 @@ def test_filter_groupby_missing_date_warn(self, valid_metadata: pd.DataFrame, ca metadata = valid_metadata.copy() metadata = metadata.drop('date', axis='columns') strains = metadata.index.tolist() - group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) + _, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) assert group_by_strain == { 'SEQ_1': ('A', 'unknown', 'unknown'), 'SEQ_2': ('A', 'unknown', 'unknown'), @@ -166,7 +166,7 @@ def test_filter_groupby_no_strains(self, valid_metadata: pd.DataFrame): groups = ['country', 'year', 'month'] metadata = valid_metadata.copy() strains = [] - group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) + _, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) assert group_by_strain == {} assert skipped_strains == [] @@ -175,7 +175,7 @@ def test_filter_groupby_only_year_provided(self, valid_metadata: pd.DataFrame): metadata = valid_metadata.copy() metadata['date'] = '2020' strains = metadata.index.tolist() - group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) + _, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) assert group_by_strain == { 'SEQ_1': ('A', 2020), 'SEQ_2': ('A', 2020), @@ -190,7 +190,7 @@ def test_filter_groupby_month_with_only_year_provided(self, valid_metadata: pd.D metadata = valid_metadata.copy() metadata['date'] = '2020' strains = 
metadata.index.tolist() - group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) + _, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) assert group_by_strain == {} assert skipped_strains == [ {'strain': 'SEQ_1', 'filter': 'skip_group_by_with_ambiguous_month', 'kwargs': ''}, @@ -205,12 +205,12 @@ def test_filter_groupby_only_year_month_provided(self, valid_metadata: pd.DataFr metadata = valid_metadata.copy() metadata['date'] = '2020-01' strains = metadata.index.tolist() - group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) + _, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups) assert group_by_strain == { - 'SEQ_1': ('A', 2020, (2020, 1)), - 'SEQ_2': ('A', 2020, (2020, 1)), - 'SEQ_3': ('B', 2020, (2020, 1)), - 'SEQ_4': ('B', 2020, (2020, 1)), - 'SEQ_5': ('B', 2020, (2020, 1)) + 'SEQ_1': ('A', 2020, 1), + 'SEQ_2': ('A', 2020, 1), + 'SEQ_3': ('B', 2020, 1), + 'SEQ_4': ('B', 2020, 1), + 'SEQ_5': ('B', 2020, 1) } assert skipped_strains == []
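A hedged companion check that could sit alongside these tests (a sketch only, reusing this module's valid_metadata fixture and imports; not part of the patch): the new first return value should always equal the set of distinct group tuples in the strain-to-group mapping.

    def test_filter_groupby_groups_match_mapping(self, valid_metadata: pd.DataFrame):
        groups = ['country', 'year']
        metadata = valid_metadata.copy()
        strains = metadata.index.tolist()
        group_values, group_by_strain, _ = get_groups_for_subsampling(strains, metadata, group_by=groups)
        # every distinct value in the mapping is reported, and nothing else
        assert group_values == set(group_by_strain.values())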