Skip to content

Commit 8742ee8

Browse files
committed
Deal with multiple contigs and sequence lengths
Introduces a `contig_id` parameter to variant_data, as described in #949. Fixes #249
1 parent 7f23758 commit 8742ee8

File tree

4 files changed

+234
-15
lines changed

4 files changed

+234
-15
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
- Document that the `zarr-vcf` dataset can be either a path or an in-memory zarr group.
88
(feature introduced in {pr}`966`, documented in {pr}`974`, {user}`hyanwong`)
99

10+
- Allow a contig to be selected by name (`contig_id`), and get the `sequence_length`
11+
of the contig associated with the unmasked sites, if contig lengths are provided
12+
({pr}`964`, {user}`hyanwong`)
13+
1014
**Fixes**
1115

1216
- Properly account for "N" as an unknown ancestral state, and ban "" from being

docs/usage.md

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,11 @@ onto branches by {meth}`parsimony<tskit.Tree.map_mutations>`.
107107
It is also possible to *completely* exclude sites and samples, by specifing a boolean
108108
`site_mask` and/or a `sample_mask` when creating the `VariantData` object. Sites or samples with
109109
a mask value of `True` will be completely omitted both from inference and the final tree sequence.
110-
This can be useful, for example, if your VCF file contains multiple chromosomes (in which case
111-
`tsinfer` will need to be run separately on each chromosome) or if you wish to select only a subset
112-
of the chromosome for inference (e.g. to reduce computational load). If a `site_mask` is provided,
113-
note that the ancestral alleles array only specifies alleles for the unmasked sites.
110+
This can be useful, for example, if you wish to select only a subset of the chromosome for
111+
inference, e.g. to reduce computational load. You can also use it to subset inference to a
112+
particular contig, if your dataset contains multiple contigs (although this can be more easily
113+
done using the `contig_id` parameter). Note that if a `site_mask` is provided,
114+
the ancestral states array should only specify alleles for the unmasked sites.
114115

115116
Below, for instance, is an example of including only sites up to position six in the contig
116117
labelled "chr1" in the `example_data.vcz` file:

tests/test_variantdata.py

Lines changed: 165 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from tsinfer import formats
3939

4040

41-
def ts_to_dataset(ts, chunks=None, samples=None):
41+
def ts_to_dataset(ts, chunks=None, samples=None, contigs=None):
4242
"""
4343
# From https://github.com/sgkit-dev/sgkit/blob/main/sgkit/tests/test_popgen.py#L63
4444
Convert the specified tskit tree sequence into an sgkit dataset.
@@ -63,7 +63,7 @@ def ts_to_dataset(ts, chunks=None, samples=None):
6363
genotypes = np.expand_dims(genotypes, axis=2)
6464

6565
ds = sgkit.create_genotype_call_dataset(
66-
variant_contig_names=["1"],
66+
variant_contig_names=["1"] if contigs is None else contigs,
6767
variant_contig=np.zeros(len(tables.sites), dtype=int),
6868
variant_position=tables.sites.position.astype(int),
6969
variant_allele=alleles,
@@ -280,18 +280,102 @@ def test_simulate_genotype_call_dataset(tmp_path):
280280
ts = msprime.sim_ancestry(4, sequence_length=1000, random_seed=123)
281281
ts = msprime.sim_mutations(ts, rate=2e-3, random_seed=123)
282282
ds = ts_to_dataset(ts)
283-
ds.update({"variant_ancestral_allele": ds["variant_allele"][:, 0]})
284283
ds.to_zarr(tmp_path, mode="w")
285-
sd = tsinfer.VariantData(tmp_path, "variant_ancestral_allele")
286-
ts = tsinfer.infer(sd)
287-
for v, ds_v, sd_v in zip(ts.variants(), ds.call_genotype, sd.sites_genotypes):
284+
vdata = tsinfer.VariantData(tmp_path, ds["variant_allele"][:, 0].values.astype(str))
285+
ts = tsinfer.infer(vdata)
286+
for v, ds_v, vd_v in zip(ts.variants(), ds.call_genotype, vdata.sites_genotypes):
288287
assert np.all(v.genotypes == ds_v.values.flatten())
289-
assert np.all(v.genotypes == sd_v)
288+
assert np.all(v.genotypes == vd_v)
289+
290+
291+
def test_simulate_genotype_call_dataset_length(tmp_path):
292+
# create_genotype_call_dataset does not save contig lengths
293+
ts = msprime.sim_ancestry(4, sequence_length=1000, random_seed=123)
294+
ts = msprime.sim_mutations(ts, rate=2e-3, random_seed=123)
295+
ds = ts_to_dataset(ts)
296+
assert "contig_length" not in ds
297+
ds.to_zarr(tmp_path, mode="w")
298+
vdata = tsinfer.VariantData(tmp_path, ds["variant_allele"][:, 0].values.astype(str))
299+
assert vdata.sequence_length == ts.sites_position[-1] + 1
300+
301+
302+
class TestMultiContig:
303+
def make_two_ts_dataset(self, path):
304+
# split ts into 2; put them as different contigs in the same dataset
305+
ts = msprime.sim_ancestry(4, sequence_length=1000, random_seed=123)
306+
ts = msprime.sim_mutations(ts, rate=2e-3, random_seed=123)
307+
split_at_site = 7
308+
assert ts.num_sites > 10
309+
site_break = ts.site(split_at_site).position
310+
ts1 = ts.keep_intervals([(0, site_break)]).rtrim()
311+
ts2 = ts.keep_intervals([(site_break, ts.sequence_length)]).ltrim()
312+
ds = ts_to_dataset(ts, contigs=["chr1", "chr2"])
313+
ds.update({"variant_ancestral_allele": ds["variant_allele"][:, 0]})
314+
variant_contig = ds["variant_contig"][:]
315+
variant_contig[split_at_site:] = 1
316+
ds.update({"variant_contig": variant_contig})
317+
variant_position = ds["variant_position"].values
318+
variant_position[split_at_site:] -= int(site_break)
319+
ds.update({"variant_position": ds["variant_position"]})
320+
ds.update(
321+
{"contig_length": np.array([ts1.sequence_length, ts2.sequence_length])}
322+
)
323+
ds.to_zarr(path, mode="w")
324+
return ts1, ts2
325+
326+
def test_unmasked(self, tmp_path):
327+
self.make_two_ts_dataset(tmp_path)
328+
with pytest.raises(ValueError, match=r'multiple contigs \("chr1", "chr2"\)'):
329+
tsinfer.VariantData(tmp_path, "variant_ancestral_allele")
330+
331+
def test_mask(self, tmp_path):
332+
ts1, ts2 = self.make_two_ts_dataset(tmp_path)
333+
vdata = tsinfer.VariantData(
334+
tmp_path,
335+
"variant_ancestral_allele",
336+
site_mask=np.array(ts1.num_sites * [True] + ts2.num_sites * [False]),
337+
)
338+
assert np.all(ts2.sites_position == vdata.sites_position)
339+
assert vdata.contig_id == "chr2"
340+
assert vdata.sequence_length == ts2.sequence_length
341+
342+
@pytest.mark.parametrize("contig_id", ["chr1", "chr2"])
343+
def test_contig_id_param(self, contig_id, tmp_path):
344+
tree_seqs = {}
345+
tree_seqs["chr1"], tree_seqs["chr2"] = self.make_two_ts_dataset(tmp_path)
346+
vdata = tsinfer.VariantData(
347+
tmp_path, "variant_ancestral_allele", contig_id=contig_id
348+
)
349+
assert np.all(tree_seqs[contig_id].sites_position == vdata.sites_position)
350+
assert vdata.contig_id == contig_id
351+
assert vdata.sequence_length == tree_seqs[contig_id].sequence_length
352+
353+
def test_contig_id_param_and_mask(self, tmp_path):
354+
ts1, ts2 = self.make_two_ts_dataset(tmp_path)
355+
vdata = tsinfer.VariantData(
356+
tmp_path,
357+
"variant_ancestral_allele",
358+
site_mask=np.array(
359+
(ts1.num_sites + 1) * [True] + (ts2.num_sites - 1) * [False]
360+
),
361+
contig_id="chr2",
362+
)
363+
assert np.all(ts2.sites_position[1:] == vdata.sites_position)
364+
assert vdata.contig_id == "chr2"
365+
366+
@pytest.mark.parametrize("contig_id", ["chr1", "chr2"])
367+
def test_contig_length(self, contig_id, tmp_path):
368+
tree_seqs = {}
369+
tree_seqs["chr1"], tree_seqs["chr2"] = self.make_two_ts_dataset(tmp_path)
370+
vdata = tsinfer.VariantData(
371+
tmp_path, "variant_ancestral_allele", contig_id=contig_id
372+
)
373+
assert vdata.sequence_length == tree_seqs[contig_id].sequence_length
290374

291375

292376
@pytest.mark.skipif(sys.platform == "win32", reason="File permission errors on Windows")
293377
class TestSgkitMask:
294-
@pytest.mark.parametrize("sites", [[1, 2, 3, 5, 9, 27], [0], []])
378+
@pytest.mark.parametrize("sites", [[1, 2, 3, 5, 9, 27], [0]])
295379
def test_sgkit_variant_mask(self, tmp_path, sites):
296380
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
297381
ds = sgkit.load_dataset(zarr_path)
@@ -800,6 +884,20 @@ def test_bad_ancestral_state(self, tmp_path):
800884
with pytest.raises(ValueError, match="cannot contain empty strings"):
801885
tsinfer.VariantData(path, ancestral_state)
802886

887+
def test_ancestral_state_len_not_same_as_mask(self, tmp_path):
888+
path = tmp_path / "data.zarr"
889+
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
890+
sgkit.save_dataset(ds, path)
891+
ancestral_state = ds["variant_allele"][:, 0].values.astype(str)
892+
site_mask = np.zeros(ds.sizes["variants"], dtype=bool)
893+
site_mask[0] = True
894+
with pytest.raises(
895+
ValueError,
896+
match="Ancestral state array must be the same length as the number of"
897+
" selected sites",
898+
):
899+
tsinfer.VariantData(path, ancestral_state, site_mask=site_mask)
900+
803901
def test_empty_alleles_not_at_end(self, tmp_path):
804902
path = tmp_path / "data.zarr"
805903
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
@@ -831,3 +929,62 @@ def test_unimplemented_from_tree_sequence(self):
831929
# Requires e.g. https://github.com/tskit-dev/tsinfer/issues/924
832930
with pytest.raises(NotImplementedError):
833931
tsinfer.VariantData.from_tree_sequence(None)
932+
933+
def test_multiple_contigs(self, tmp_path):
934+
path = tmp_path / "data.zarr"
935+
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
936+
ds["contig_id"] = (
937+
ds["contig_id"].dims,
938+
np.array(["c10", "c11"], dtype="<U3"),
939+
)
940+
ds["variant_contig"] = (
941+
ds["variant_contig"].dims,
942+
np.array([0, 0, 1], dtype=ds["variant_contig"].dtype),
943+
)
944+
sgkit.save_dataset(ds, path)
945+
with pytest.raises(
946+
ValueError, match=r'Sites belong to multiple contigs \("c10", "c11"\)'
947+
):
948+
tsinfer.VariantData(path, ds["variant_allele"][:, 0].astype(str))
949+
950+
def test_all_masked(self, tmp_path):
951+
path = tmp_path / "data.zarr"
952+
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
953+
sgkit.save_dataset(ds, path)
954+
with pytest.raises(ValueError, match="All sites have been masked out"):
955+
tsinfer.VariantData(
956+
path, ds["variant_allele"][:, 0].astype(str), site_mask=np.ones(3, bool)
957+
)
958+
959+
def test_bad_contig_param(self, tmp_path):
960+
path = tmp_path / "data.zarr"
961+
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
962+
sgkit.save_dataset(ds, path)
963+
with pytest.raises(ValueError, match='"XX" not found'):
964+
tsinfer.VariantData(
965+
path, ds["variant_allele"][:, 0].astype(str), contig_id="XX"
966+
)
967+
968+
def test_multiple_contig_param(self, tmp_path):
969+
path = tmp_path / "data.zarr"
970+
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
971+
ds["contig_id"] = (
972+
ds["contig_id"].dims,
973+
np.array(["chr1", "chr1"], dtype="<U4"),
974+
)
975+
sgkit.save_dataset(ds, path)
976+
with pytest.raises(ValueError, match='Multiple contigs named "chr1"'):
977+
tsinfer.VariantData(
978+
path, ds["variant_allele"][:, 0].astype(str), contig_id="chr1"
979+
)
980+
981+
def test_missing_sites_time(self, tmp_path):
982+
path = tmp_path / "data.zarr"
983+
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
984+
sgkit.save_dataset(ds, path)
985+
with pytest.raises(
986+
ValueError, match="The sites time array XX was not found in the dataset"
987+
):
988+
tsinfer.VariantData(
989+
path, ds["variant_allele"][:, 0].astype(str), sites_time="XX"
990+
)

tsinfer/formats.py

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2326,7 +2326,8 @@ class VariantData(SampleData):
23262326
:param Union(array, str) site_mask: A numpy array of booleans specifying which
23272327
sites to mask out (exclude) from the dataset. Alternatively, a string
23282328
can be provided, giving the name of an array in the input dataset which contains
2329-
the site mask. If ``None`` (default), all sites are included.
2329+
the site mask. If ``None`` (default), all sites are included (unless restricted
2330+
to a ``contig_id``, see below).
23302331
:param Union(array, str) sites_time: A numpy array of floats specifying the relative
23312332
time of occurrence of the mutation to the derived state at each site. This must
23322333
be of the same length as the number of unmasked sites. Alternatively, a
@@ -2336,6 +2337,11 @@ class VariantData(SampleData):
23362337
reasonable approximation to the relative order of ancestors used for inference.
23372338
Time values are ignored for sites not used in inference, such as singletons,
23382339
sites with more than two alleles, or sites with an unknown ancestral state.
2340+
:param str contig_id: The name of the contig to use (e.g. "chr1"), if the .vcz file
2341+
contains multiple contigs; contig names can be found in the `.contig_id array
2342+
of the input dataset. If provided, sites associated with any other contigs will
2343+
be added to the sites that are masked out. If ``None`` (default), do not mark
2344+
out sites on the basis of their contig ID.
23392345
"""
23402346

23412347
FORMAT_NAME = "tsinfer-variant-data"
@@ -2349,6 +2355,7 @@ def __init__(
23492355
sample_mask=None,
23502356
site_mask=None,
23512357
sites_time=None,
2358+
contig_id=None,
23522359
):
23532360
try:
23542361
if len(path_or_zarr.call_genotype.shape) == 3:
@@ -2382,8 +2389,24 @@ def __init__(
23822389
raise ValueError(
23832390
"Site mask array must be the same length as the number of unmasked sites"
23842391
)
2392+
if contig_id is not None:
2393+
contig_index = np.where(self.data.contig_id[:] == contig_id)[0]
2394+
if len(contig_index) == 0:
2395+
raise ValueError(
2396+
f'"{contig_id}" not found among the available contig IDs: '
2397+
+ ",".join(f"{n}" for n in self.data.contig_id[:])
2398+
)
2399+
elif len(contig_index) > 1:
2400+
raise ValueError(f'Multiple contigs named "{contig_id}"')
2401+
contig_index = contig_index[0]
2402+
site_mask = np.logical_or(
2403+
site_mask, self.data["variant_contig"][:] != contig_index
2404+
)
2405+
23852406
# We negate the mask as it is much easier in numpy to have True=keep
23862407
self.sites_select = ~site_mask.astype(bool)
2408+
if np.sum(self.sites_select) == 0:
2409+
raise ValueError("All sites have been masked out. Please unmask some")
23872410

23882411
if sample_mask is None:
23892412
sample_mask = np.full(self._num_individuals_before_mask, False, dtype=bool)
@@ -2413,6 +2436,20 @@ def __init__(
24132436
" zarr dataset, indicating that all the genotypes are"
24142437
" unphased"
24152438
)
2439+
2440+
used_contigs = self.data.variant_contig[:][self.sites_select]
2441+
self._contig_index = used_contigs[0]
2442+
self._contig_id = self.data.contig_id[self._contig_index]
2443+
2444+
if np.any(used_contigs != self._contig_index):
2445+
contig_names = ", ".join(
2446+
f'"{self.data.contig_id[c]}"' for c in np.unique(used_contigs)
2447+
)
2448+
raise ValueError(
2449+
f"Sites belong to multiple contigs ({contig_names}). Please restrict "
2450+
"sites to one contig e.g. via the `contig_id` argument."
2451+
)
2452+
24162453
if np.any(np.diff(self.sites_position) <= 0):
24172454
raise ValueError(
24182455
"Values taken from the variant_position array are not strictly "
@@ -2434,7 +2471,7 @@ def __init__(
24342471
self._sites_time = self.data[sites_time][:][self.sites_select]
24352472
except KeyError:
24362473
raise ValueError(
2437-
f"The sites time {sites_time} was not found" f" in the dataset."
2474+
f"The sites time array {sites_time} was not found in the dataset"
24382475
)
24392476

24402477
if isinstance(ancestral_state, np.ndarray):
@@ -2517,10 +2554,30 @@ def finalised(self):
25172554

25182555
@functools.cached_property
25192556
def sequence_length(self):
2557+
"""
2558+
The sequence length of the contig associated with sites used in the dataset.
2559+
If the dataset has a "sequence_length" attribute, this is always used, otherwise
2560+
if the dataset has recorded contig lengths, the appropriate length is taken,
2561+
otherwise the length is calculated from the maximum variant position plus one.
2562+
"""
25202563
try:
25212564
return self.data.attrs["sequence_length"]
25222565
except KeyError:
2523-
return int(np.max(self.data["variant_position"])) + 1
2566+
if self._contig_index is not None:
2567+
try:
2568+
if self._contig_index < len(self.data.contig_length):
2569+
return self.data.contig_length[self._contig_index]
2570+
except AttributeError:
2571+
pass # contig_length is optional, fall back to calculating length
2572+
return int(np.max(self.data["variant_position"])) + 1
2573+
2574+
@property
2575+
def contig_id(self):
2576+
"""
2577+
The contig ID (name) for all used sites, or None if no
2578+
contig IDs were provided
2579+
"""
2580+
return self._contig_id
25242581

25252582
@property
25262583
def num_sites(self):

0 commit comments

Comments
 (0)