Skip to content

Commit e542f7c

Browse files
hyanwongmergify[bot]
authored andcommitted
Properly treat blank ancestral allele, and set "N" as the default "unknown" state
Also document the class
1 parent 1d04fb8 commit e542f7c

File tree

6 files changed

+246
-81
lines changed

6 files changed

+246
-81
lines changed

CHANGELOG.md

+7
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
# Changelog
22

3+
## [0.4.0a3] - ****-**-**
4+
5+
**Fixes**
6+
7+
- Properly account for "N" as an unknown ancestral state, and ban "" from being
8+
set as an ancestral state ({pr}`963`, {user}`hyanwong`))
9+
310
## [0.4.0a2] - 2024-09-06
411

512
2nd Alpha release of tsinfer 0.4.0

docs/usage.md

+14-14
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,14 @@ for sample in range(ds['call_genotype'].shape[1]):
6060

6161
We wish to infer a genealogy that could have given rise to this data set. To run _tsinfer_
6262
we wrap the .vcz file in a `tsinfer.VariantData` object. This requires an
63-
*ancestral allele* to be specified for each site; there are
63+
*ancestral state* to be specified for each site; there are
6464
many methods for calculating these: details are outside the scope of this manual, but we
6565
have started a [discussion topic](https://github.com/tskit-dev/tsinfer/discussions/523)
6666
on this issue to provide some recommendations.
6767

6868
Sometimes VCF files will contain the
69-
ancestral allele in the "AA" info field, in which case it will be encoded in the
70-
`variant_AA` field of the .vcz file. It's also possible to provide a numpy array
69+
ancestral state in the "AA" ("ancestral allele") info field, in which case it will be encoded
70+
in the `variant_AA` field of the .vcz file. It's also possible to provide a numpy array
7171
of ancestral alleles, of the same length as the number of variants. Ancestral
7272
alleles that are not in the list of alleles for their respective site are treated as unknown
7373
and not used for inference (with a warning given).
@@ -76,11 +76,11 @@ and not used for inference (with a warning given).
7676
import tsinfer
7777
7878
# For this example take the REF allele (index 0) as ancestral
79-
ancestral_allele = ds['variant_allele'][:,0].astype(str)
79+
ancestral_state = ds['variant_allele'][:,0].astype(str)
8080
# This is just a numpy array, set the last site to an unknown value, for demo purposes
81-
ancestral_allele[-1] = "."
81+
ancestral_state[-1] = "."
8282
83-
vdata = tsinfer.VariantData("_static/example_data.vcz", ancestral_allele)
83+
vdata = tsinfer.VariantData("_static/example_data.vcz", ancestral_state)
8484
```
8585

8686
The `VariantData` object is a lightweight wrapper around the .vcz file.
@@ -127,7 +127,7 @@ site_mask[ds.variant_position[:] >= 6] = True
127127
128128
smaller_vdata = tsinfer.VariantData(
129129
"_static/example_data.vcz",
130-
ancestral_allele=ancestral_allele[site_mask == False],
130+
ancestral_state=ancestral_state[site_mask == False],
131131
site_mask=site_mask,
132132
)
133133
print(f"The `smaller_vdata` object returns data for only {smaller_vdata.num_sites} sites")
@@ -351,8 +351,8 @@ Once we have our `.vcz` file created, running the inference is straightforward.
351351

352352
```{code-cell} ipython3
353353
# Infer & save a ts from the notebook simulation.
354-
ancestral_alleles = np.load(f"{name}-AA.npy")
355-
vdata = tsinfer.VariantData(f"{name}.vcz", ancestral_alleles)
354+
ancestral_states = np.load(f"{name}-AA.npy")
355+
vdata = tsinfer.VariantData(f"{name}.vcz", ancestral_states)
356356
tsinfer.infer(vdata, progress_monitor=True, num_threads=4).dump(name + ".trees")
357357
```
358358

@@ -477,12 +477,12 @@ vcf_location = "_static/P_dom_chr24_phased.vcf.gz"
477477
```
478478

479479
This creates the `sparrows.vcz` datastore, which we open using `tsinfer.VariantData`.
480-
The original VCF had ancestral alleles specified in the `AA` INFO field, so we can
481-
simply provide the string `"variant_AA"` as the ancestral_allele parameter.
480+
The original VCF had the ancestral allelic state specified in the `AA` INFO field,
481+
so we can simply provide the string `"variant_AA"` as the ancestral_state parameter.
482482

483483
```{code-cell} ipython3
484-
# Do the inference: this VCF has ancestral alleles in the AA field
485-
vdata = tsinfer.VariantData("sparrows.vcz", ancestral_allele="variant_AA")
484+
# Do the inference: this VCF has ancestral states in the AA field
485+
vdata = tsinfer.VariantData("sparrows.vcz", ancestral_state="variant_AA")
486486
ts = tsinfer.infer(vdata)
487487
print(
488488
"Inferred tree sequence: {} trees over {} Mb ({} edges)".format(
@@ -534,7 +534,7 @@ Now when we carry out the inference, we get a tree sequence in which the nodes a
534534
correctly assigned to named populations
535535

536536
```{code-cell} ipython3
537-
vdata = tsinfer.VariantData("sparrows.vcz", ancestral_allele="variant_AA")
537+
vdata = tsinfer.VariantData("sparrows.vcz", ancestral_state="variant_AA")
538538
sparrow_ts = tsinfer.infer(vdata)
539539
540540
for sample_node_id in sparrow_ts.samples():

tests/test_inference.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1532,7 +1532,7 @@ def test_match_samples_batch(self, tmp_path, tmpdir):
15321532
mat_wd = tsinfer.match_samples_batch_init(
15331533
work_dir=tmpdir / "working_mat",
15341534
sample_data_path=mat_sd.path,
1535-
ancestral_allele="variant_ancestral_allele",
1535+
ancestral_state="variant_ancestral_allele",
15361536
ancestor_ts_path=tmpdir / "mat_anc.trees",
15371537
min_work_per_job=1,
15381538
max_num_partitions=10,
@@ -1547,7 +1547,7 @@ def test_match_samples_batch(self, tmp_path, tmpdir):
15471547
mask_wd = tsinfer.match_samples_batch_init(
15481548
work_dir=tmpdir / "working_mask",
15491549
sample_data_path=mask_sd.path,
1550-
ancestral_allele="variant_ancestral_allele",
1550+
ancestral_state="variant_ancestral_allele",
15511551
ancestor_ts_path=tmpdir / "mask_anc.trees",
15521552
min_work_per_job=1,
15531553
max_num_partitions=10,

tests/test_variantdata.py

+101-17
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@
2020
Tests for the data files.
2121
"""
2222
import json
23+
import logging
2324
import sys
2425
import tempfile
26+
import warnings
2527

2628
import msprime
2729
import numcodecs
@@ -627,14 +629,12 @@ def test_missing_ancestral_allele(tmp_path):
627629

628630

629631
@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on Windows")
630-
def test_ancestral_missingness(tmp_path):
632+
def test_deliberate_ancestral_missingness(tmp_path):
631633
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
632634
ds = sgkit.load_dataset(zarr_path)
633635
ancestral_allele = ds.variant_ancestral_allele.values
634636
ancestral_allele[0] = "N"
635-
ancestral_allele[11] = "-"
636-
ancestral_allele[12] = "💩"
637-
ancestral_allele[15] = "💩"
637+
ancestral_allele[1] = "n"
638638
ds = ds.drop_vars(["variant_ancestral_allele"])
639639
sgkit.save_dataset(ds, str(zarr_path) + ".tmp")
640640
tsutil.add_array_to_dataset(
@@ -644,15 +644,57 @@ def test_ancestral_missingness(tmp_path):
644644
["variants"],
645645
)
646646
ds = sgkit.load_dataset(str(zarr_path) + ".tmp")
647+
with warnings.catch_warnings():
648+
warnings.simplefilter("error") # No warning raised if AA deliberately missing
649+
sd = tsinfer.VariantData(str(zarr_path) + ".tmp", "variant_ancestral_allele")
650+
inf_ts = tsinfer.infer(sd)
651+
for i, (inf_var, var) in enumerate(zip(inf_ts.variants(), ts.variants())):
652+
if i in [0, 1]:
653+
assert inf_var.site.metadata == {"inference_type": "parsimony"}
654+
else:
655+
assert inf_var.site.ancestral_state == var.site.ancestral_state
656+
657+
658+
@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on Windows")
659+
def test_ancestral_missing_warning(tmp_path):
660+
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
661+
ds = sgkit.load_dataset(zarr_path)
662+
anc_state = ds.variant_ancestral_allele.values
663+
anc_state[0] = "N"
664+
anc_state[11] = "-"
665+
anc_state[12] = "💩"
666+
anc_state[15] = "💩"
647667
with pytest.warns(
648668
UserWarning,
649669
match=r"not found in the variant_allele array for the 4 [\s\S]*'💩': 2",
650670
):
651-
sd = tsinfer.VariantData(str(zarr_path) + ".tmp", "variant_ancestral_allele")
652-
inf_ts = tsinfer.infer(sd)
671+
vdata = tsinfer.VariantData(zarr_path, anc_state)
672+
inf_ts = tsinfer.infer(vdata)
673+
for i, (inf_var, var) in enumerate(zip(inf_ts.variants(), ts.variants())):
674+
if i in [0, 11, 12, 15]:
675+
assert inf_var.site.metadata == {"inference_type": "parsimony"}
676+
assert inf_var.site.ancestral_state in var.site.alleles
677+
else:
678+
assert inf_var.site.ancestral_state == var.site.ancestral_state
679+
680+
681+
@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on Windows")
682+
def test_ancestral_missing_info(tmp_path, caplog):
683+
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
684+
ds = sgkit.load_dataset(zarr_path)
685+
anc_state = ds.variant_ancestral_allele.values
686+
anc_state[0] = "N"
687+
anc_state[11] = "N"
688+
anc_state[12] = "n"
689+
anc_state[15] = "n"
690+
with caplog.at_level(logging.INFO):
691+
vdata = tsinfer.VariantData(zarr_path, anc_state)
692+
assert f"4 sites ({4/ts.num_sites * 100 :.2f}%) were deliberately " in caplog.text
693+
inf_ts = tsinfer.infer(vdata)
653694
for i, (inf_var, var) in enumerate(zip(inf_ts.variants(), ts.variants())):
654695
if i in [0, 11, 12, 15]:
655696
assert inf_var.site.metadata == {"inference_type": "parsimony"}
697+
assert inf_var.site.ancestral_state in var.site.alleles
656698
else:
657699
assert inf_var.site.ancestral_state == var.site.ancestral_state
658700

@@ -670,6 +712,25 @@ def test_sgkit_ancestor(small_sd_fixture, tmp_path):
670712

671713

672714
class TestVariantDataErrors:
715+
@staticmethod
716+
def simulate_genotype_call_dataset(*args, **kwargs):
717+
# roll our own simulate_genotype_call_dataset to hack around bug in sgkit where
718+
# duplicate alleles are created. Doesn't need to be efficient: just for testing
719+
if "seed" not in kwargs:
720+
kwargs["seed"] = 123
721+
ds = sgkit.simulate_genotype_call_dataset(*args, **kwargs)
722+
variant_alleles = ds["variant_allele"].values
723+
allowed_alleles = np.array(
724+
["A", "T", "C", "G", "N"], dtype=variant_alleles.dtype
725+
)
726+
for row in range(len(variant_alleles)):
727+
alleles = variant_alleles[row]
728+
if len(set(alleles)) != len(alleles):
729+
# Just use a set that we know is unique
730+
variant_alleles[row] = allowed_alleles[0 : len(alleles)]
731+
ds["variant_allele"] = ds["variant_allele"].dims, variant_alleles
732+
return ds
733+
673734
def test_bad_zarr_spec(self):
674735
ds = zarr.group()
675736
ds["call_genotype"] = zarr.array(np.zeros(10, dtype=np.int8))
@@ -680,7 +741,7 @@ def test_bad_zarr_spec(self):
680741

681742
def test_missing_phase(self, tmp_path):
682743
path = tmp_path / "data.zarr"
683-
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3)
744+
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3)
684745
sgkit.save_dataset(ds, path)
685746
with pytest.raises(
686747
ValueError, match="The call_genotype_phased array is missing"
@@ -689,7 +750,7 @@ def test_missing_phase(self, tmp_path):
689750

690751
def test_phased(self, tmp_path):
691752
path = tmp_path / "data.zarr"
692-
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3)
753+
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3)
693754
ds["call_genotype_phased"] = (
694755
ds["call_genotype"].dims,
695756
np.ones(ds["call_genotype"].shape, dtype=bool),
@@ -700,13 +761,13 @@ def test_phased(self, tmp_path):
700761
def test_ploidy1_missing_phase(self, tmp_path):
701762
path = tmp_path / "data.zarr"
702763
# Ploidy==1 is always ok
703-
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
764+
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
704765
sgkit.save_dataset(ds, path)
705766
tsinfer.VariantData(path, ds["variant_allele"][:, 0].values.astype(str))
706767

707768
def test_ploidy1_unphased(self, tmp_path):
708769
path = tmp_path / "data.zarr"
709-
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
770+
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
710771
ds["call_genotype_phased"] = (
711772
ds["call_genotype"].dims,
712773
np.zeros(ds["call_genotype"].shape, dtype=bool),
@@ -716,31 +777,54 @@ def test_ploidy1_unphased(self, tmp_path):
716777

717778
def test_duplicate_positions(self, tmp_path):
718779
path = tmp_path / "data.zarr"
719-
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
780+
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
720781
ds["variant_position"][2] = ds["variant_position"][1]
721782
sgkit.save_dataset(ds, path)
722783
with pytest.raises(ValueError, match="duplicate or out-of-order values"):
723784
tsinfer.VariantData(path, "variant_ancestral_allele")
724785

725786
def test_bad_order_positions(self, tmp_path):
726787
path = tmp_path / "data.zarr"
727-
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
788+
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
728789
ds["variant_position"][0] = ds["variant_position"][2] - 0.5
729790
sgkit.save_dataset(ds, path)
730791
with pytest.raises(ValueError, match="duplicate or out-of-order values"):
731792
tsinfer.VariantData(path, "variant_ancestral_allele")
732793

794+
def test_bad_ancestral_state(self, tmp_path):
795+
path = tmp_path / "data.zarr"
796+
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
797+
ancestral_state = ds["variant_allele"][:, 0].values.astype(str)
798+
ancestral_state[1] = ""
799+
sgkit.save_dataset(ds, path)
800+
with pytest.raises(ValueError, match="cannot contain empty strings"):
801+
tsinfer.VariantData(path, ancestral_state)
802+
733803
def test_empty_alleles_not_at_end(self, tmp_path):
734804
path = tmp_path / "data.zarr"
735-
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
805+
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
736806
ds["variant_allele"] = (
737807
ds["variant_allele"].dims,
738-
np.array([["", "A", "C"], ["A", "C", ""], ["A", "C", ""]], dtype="S1"),
808+
np.array([["A", "", "C"], ["A", "C", ""], ["A", "C", ""]], dtype="S1"),
739809
)
740810
sgkit.save_dataset(ds, path)
741-
vdata = tsinfer.VariantData(path, ds["variant_allele"][:, 0].values.astype(str))
742-
with pytest.raises(ValueError, match="Empty alleles must be at the end"):
743-
tsinfer.infer(vdata)
811+
with pytest.raises(
812+
ValueError, match='Bad alleles: fill value "" in middle of list'
813+
):
814+
tsinfer.VariantData(path, ds["variant_allele"][:, 0].values.astype(str))
815+
816+
def test_unique_alleles(self, tmp_path):
817+
path = tmp_path / "data.zarr"
818+
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
819+
ds["variant_allele"] = (
820+
ds["variant_allele"].dims,
821+
np.array([["A", "C", "T"], ["A", "C", ""], ["A", "A", ""]], dtype="S1"),
822+
)
823+
sgkit.save_dataset(ds, path)
824+
with pytest.raises(
825+
ValueError, match="Duplicate allele values provided at site 2"
826+
):
827+
tsinfer.VariantData(path, np.array(["A", "A", "A"], dtype="S1"))
744828

745829
def test_unimplemented_from_tree_sequence(self):
746830
# NB we should reimplement something like this functionality.

0 commit comments

Comments
 (0)