From d4c8f49f864dd1570392e1cf8acb8b827feb7a2e Mon Sep 17 00:00:00 2001 From: Yan Wong Date: Wed, 4 Sep 2024 18:04:56 +0100 Subject: [PATCH] Properly treat blank ancestral allele, and set "N" as the default "unknown" state Also document the class --- tests/test_variantdata.py | 34 ++++++++++++++++++++-- tsinfer/formats.py | 61 ++++++++++++++++++++++++++++++++++----- 2 files changed, 85 insertions(+), 10 deletions(-) diff --git a/tests/test_variantdata.py b/tests/test_variantdata.py index a353c4c0..18aa94ff 100644 --- a/tests/test_variantdata.py +++ b/tests/test_variantdata.py @@ -22,6 +22,7 @@ import json import sys import tempfile +import warnings import msprime import numcodecs @@ -626,12 +627,40 @@ def test_missing_ancestral_allele(tmp_path): tsinfer.VariantData(str(zarr_path) + ".tmp", "variant_ancestral_allele") +@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on Windows") +def test_deliberate_ancestral_missingness(tmp_path): + ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path) + ds = sgkit.load_dataset(zarr_path) + ancestral_allele = ds.variant_ancestral_allele.values + ancestral_allele[0] = "N" + ancestral_allele[1] = "n" + ds = ds.drop_vars(["variant_ancestral_allele"]) + sgkit.save_dataset(ds, str(zarr_path) + ".tmp") + tsutil.add_array_to_dataset( + "variant_ancestral_allele", + ancestral_allele, + str(zarr_path) + ".tmp", + ["variants"], + ) + ds = sgkit.load_dataset(str(zarr_path) + ".tmp") + with warnings.catch_warnings(): + warnings.simplefilter("error") # No warning raised if AA deliberately missing + sd = tsinfer.VariantData(str(zarr_path) + ".tmp", "variant_ancestral_allele") + inf_ts = tsinfer.infer(sd) + for i, (inf_var, var) in enumerate(zip(inf_ts.variants(), ts.variants())): + if i in [0, 1]: + assert inf_var.site.metadata == {"inference_type": "parsimony"} + else: + assert inf_var.site.ancestral_state == var.site.ancestral_state + + @pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on Windows") def test_ancestral_missingness(tmp_path): ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path) ds = sgkit.load_dataset(zarr_path) ancestral_allele = ds.variant_ancestral_allele.values ancestral_allele[0] = "N" + ancestral_allele[2] = "" ancestral_allele[11] = "-" ancestral_allele[12] = "💩" ancestral_allele[15] = "💩" @@ -646,13 +675,14 @@ def test_ancestral_missingness(tmp_path): ds = sgkit.load_dataset(str(zarr_path) + ".tmp") with pytest.warns( UserWarning, - match=r"not found in the variant_allele array for the 4 [\s\S]*'💩': 2", + match=r"not found in the variant_allele array for the 5 [\s\S]*'💩': 2", ): sd = tsinfer.VariantData(str(zarr_path) + ".tmp", "variant_ancestral_allele") inf_ts = tsinfer.infer(sd) for i, (inf_var, var) in enumerate(zip(inf_ts.variants(), ts.variants())): - if i in [0, 11, 12, 15]: + if i in [0, 2, 11, 12, 15]: assert inf_var.site.metadata == {"inference_type": "parsimony"} + assert inf_var.site.ancestral_state in var.site.alleles else: assert inf_var.site.ancestral_state == var.site.ancestral_state diff --git a/tsinfer/formats.py b/tsinfer/formats.py index 39f59e13..7095e04b 100644 --- a/tsinfer/formats.py +++ b/tsinfer/formats.py @@ -31,6 +31,7 @@ import sys import threading import warnings +from typing import Union # noqa: F401 import attr import humanize @@ -2293,6 +2294,48 @@ def populations(self): class VariantData(SampleData): + """ + Class representing input variant data used for inference. This is + mostly a thin wrapper for a Zarr dataset storing information in + the VCF Zarr (.vcz) format, plus information specifing the ancestral allele + and (optional) data masks. It then provides various derived properties and + methods for accessing the data in a form suitable for inference. + + .. note:: + In the VariantData object, "samples" refer to the individuals in the dataset, + each of which can be of arbitrary ploidy. This is in contrast to ``tskit``, + in which each *haploid genome* is treated as a separate "sample". For example + in a diploid dataset, the inferred tree sequence returned at the end of + the inference process will have ``inferred_ts.num_samples`` equal to double + the number returned by ``VariantData.num_samples``. + + :param str path: The path to the file containing the input dataset in VCF-Zarr + format. + :param Union(array, str) ancestral_allele: A numpy array of strings specifying + the ancestral alleles used in inference. This must be the same length as + the number of unmasked sites in the dataset. Alternatively, a single string + can be provided, giving the name of an array in the input dataset which contains + the ancestral alleles. Unknown ancestral alleles can be specified using "N". + Any ancestral alleles which do not match any of the known alleles at that site, + will be tallied, and a warning issued summarizing the unknown ancestral states. + :param Union(array, str) sample_mask: A numpy array of booleans specifying which + samples to mask out (exclude) from the dataset. Alternatively, a string + can be provided, giving the name of an array in the input dataset which contains + the sample mask. If ``None`` (default), all samples are included. + :param Union(array, str) site_mask: A numpy array of booleans specifying which + sites to mask out (exclude) from the dataset. Alternatively, a string + can be provided, giving the name of an array in the input dataset which contains + the site mask. If ``None`` (default), all sites are included. + :param Union(array, str) sites_time: A numpy array of floats specifying the relative + time of occurrence of the mutation to the derived state at each site. This must + be of the same length as the number of unmasked sites. Alternatively, a + string can be provided, giving the name of an array in the input dataset + which contains the site times. If ``None`` (default), the frequency of the + derived allele is used as a proxy for the time of occurrence: this is usually a + reasonable approximation to the relative order of ancestors used for inference. + Time values are ignored for sites not used in inference, such as singletons, + sites with more than two alleles, or sites with an unknown ancestral allele. + """ FORMAT_NAME = "tsinfer-variant-data" FORMAT_VERSION = (0, 1) @@ -2412,16 +2455,18 @@ def __init__( ) self._sites_ancestral_allele = self._sites_ancestral_allele.astype(str) unknown_alleles = collections.Counter() - converted = np.zeros(self.num_sites, dtype=np.int8) + converted = np.full(self.num_sites, -1, dtype=np.int8) for i, allele in enumerate(self._sites_ancestral_allele): - allele_index = -1 - try: - allele_index = np.where(allele == self.sites_alleles[i])[0][0] - except IndexError: - unknown_alleles[allele] += 1 - converted[i] = allele_index + if not (allele in {"", "N", "n"}): # All these must represent unknown + try: + converted[i] = np.where(allele == self.sites_alleles[i])[0][0] + continue + except IndexError: + pass + unknown_alleles[allele] += 1 + deliberately_unknown = sum([unknown_alleles.get(c, 0) for c in ("N", "n")]) tot = sum(unknown_alleles.values()) - if tot > 0: + if tot != deliberately_unknown: frac_bad = tot / self.num_sites frac_bad_per_type = [v / self.num_sites for v in unknown_alleles.values()] summarise_unknown = [