Skip to content

Commit 41cc370

Browse files
hyanwong authored and
benjeffery committed
Deal with multiple contigs and sequence lengths
Introduces a `contig_id` parameter to variant_data, as described in #949. Fixes #249
1 parent 1aa0233 commit 41cc370

File tree

4 files changed

+234
-21
lines changed

4 files changed

+234
-21
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
- If a mismatch ratio is provided to the `infer` command, it only applies during the
1313
`match_samples` phase ({issue}`980`, {pr}`981`, {user}`hyanwong`)
1414

15+
- Allow a contig to be selected by name (`contig_id`), and get the `sequence_length`
16+
of the contig associated with the unmasked sites, if contig lengths are provided
17+
({pr}`964`, {user}`hyanwong`)
18+
1519
**Fixes**
1620

1721
- Properly account for "N" as an unknown ancestral state, and ban "" from being

docs/usage.md

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,11 @@ onto branches by {meth}`parsimony<tskit.Tree.map_mutations>`.
107107
It is also possible to *completely* exclude sites and samples, by specifying a boolean
108108
`site_mask` and/or a `sample_mask` when creating the `VariantData` object. Sites or samples with
109109
a mask value of `True` will be completely omitted both from inference and the final tree sequence.
110-
This can be useful, for example, if your VCF file contains multiple chromosomes (in which case
111-
`tsinfer` will need to be run separately on each chromosome) or if you wish to select only a subset
112-
of the chromosome for inference (e.g. to reduce computational load). If a `site_mask` is provided,
113-
note that the ancestral alleles array only specifies alleles for the unmasked sites.
110+
This can be useful, for example, if you wish to select only a subset of the chromosome for
111+
inference, e.g. to reduce computational load. You can also use it to subset inference to a
112+
particular contig, if your dataset contains multiple contigs (although this can be more easily
113+
done using the `contig_id` parameter). Note that if a `site_mask` is provided,
114+
the ancestral states array should only specify alleles for the unmasked sites.
114115

115116
Below, for instance, is an example of including only sites up to position six in the contig
116117
labelled "chr1" in the `example_data.vcz` file:

tests/test_variantdata.py

Lines changed: 165 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from tsinfer import formats
3939

4040

41-
def ts_to_dataset(ts, chunks=None, samples=None):
41+
def ts_to_dataset(ts, chunks=None, samples=None, contigs=None):
4242
"""
4343
# From https://github.com/sgkit-dev/sgkit/blob/main/sgkit/tests/test_popgen.py#L63
4444
Convert the specified tskit tree sequence into an sgkit dataset.
@@ -63,7 +63,7 @@ def ts_to_dataset(ts, chunks=None, samples=None):
6363
genotypes = np.expand_dims(genotypes, axis=2)
6464

6565
ds = sgkit.create_genotype_call_dataset(
66-
variant_contig_names=["1"],
66+
variant_contig_names=["1"] if contigs is None else contigs,
6767
variant_contig=np.zeros(len(tables.sites), dtype=int),
6868
variant_position=tables.sites.position.astype(int),
6969
variant_allele=alleles,
@@ -299,18 +299,102 @@ def test_simulate_genotype_call_dataset(tmp_path):
299299
ts = msprime.sim_ancestry(4, sequence_length=1000, random_seed=123)
300300
ts = msprime.sim_mutations(ts, rate=2e-3, random_seed=123)
301301
ds = ts_to_dataset(ts)
302-
ds.update({"variant_ancestral_allele": ds["variant_allele"][:, 0]})
303302
ds.to_zarr(tmp_path, mode="w")
304-
sd = tsinfer.VariantData(tmp_path, "variant_ancestral_allele")
305-
ts = tsinfer.infer(sd)
306-
for v, ds_v, sd_v in zip(ts.variants(), ds.call_genotype, sd.sites_genotypes):
303+
vdata = tsinfer.VariantData(tmp_path, ds["variant_allele"][:, 0].values.astype(str))
304+
ts = tsinfer.infer(vdata)
305+
for v, ds_v, vd_v in zip(ts.variants(), ds.call_genotype, vdata.sites_genotypes):
307306
assert np.all(v.genotypes == ds_v.values.flatten())
308-
assert np.all(v.genotypes == sd_v)
307+
assert np.all(v.genotypes == vd_v)
308+
309+
310+
def test_simulate_genotype_call_dataset_length(tmp_path):
311+
# create_genotype_call_dataset does not save contig lengths
312+
ts = msprime.sim_ancestry(4, sequence_length=1000, random_seed=123)
313+
ts = msprime.sim_mutations(ts, rate=2e-3, random_seed=123)
314+
ds = ts_to_dataset(ts)
315+
assert "contig_length" not in ds
316+
ds.to_zarr(tmp_path, mode="w")
317+
vdata = tsinfer.VariantData(tmp_path, ds["variant_allele"][:, 0].values.astype(str))
318+
assert vdata.sequence_length == ts.sites_position[-1] + 1
319+
320+
321+
class TestMultiContig:
322+
def make_two_ts_dataset(self, path):
323+
# split ts into 2; put them as different contigs in the same dataset
324+
ts = msprime.sim_ancestry(4, sequence_length=1000, random_seed=123)
325+
ts = msprime.sim_mutations(ts, rate=2e-3, random_seed=123)
326+
split_at_site = 7
327+
assert ts.num_sites > 10
328+
site_break = ts.site(split_at_site).position
329+
ts1 = ts.keep_intervals([(0, site_break)]).rtrim()
330+
ts2 = ts.keep_intervals([(site_break, ts.sequence_length)]).ltrim()
331+
ds = ts_to_dataset(ts, contigs=["chr1", "chr2"])
332+
ds.update({"variant_ancestral_allele": ds["variant_allele"][:, 0]})
333+
variant_contig = ds["variant_contig"][:]
334+
variant_contig[split_at_site:] = 1
335+
ds.update({"variant_contig": variant_contig})
336+
variant_position = ds["variant_position"].values
337+
variant_position[split_at_site:] -= int(site_break)
338+
ds.update({"variant_position": ds["variant_position"]})
339+
ds.update(
340+
{"contig_length": np.array([ts1.sequence_length, ts2.sequence_length])}
341+
)
342+
ds.to_zarr(path, mode="w")
343+
return ts1, ts2
344+
345+
def test_unmasked(self, tmp_path):
346+
self.make_two_ts_dataset(tmp_path)
347+
with pytest.raises(ValueError, match=r'multiple contigs \("chr1", "chr2"\)'):
348+
tsinfer.VariantData(tmp_path, "variant_ancestral_allele")
349+
350+
def test_mask(self, tmp_path):
351+
ts1, ts2 = self.make_two_ts_dataset(tmp_path)
352+
vdata = tsinfer.VariantData(
353+
tmp_path,
354+
"variant_ancestral_allele",
355+
site_mask=np.array(ts1.num_sites * [True] + ts2.num_sites * [False]),
356+
)
357+
assert np.all(ts2.sites_position == vdata.sites_position)
358+
assert vdata.contig_id == "chr2"
359+
assert vdata.sequence_length == ts2.sequence_length
360+
361+
@pytest.mark.parametrize("contig_id", ["chr1", "chr2"])
362+
def test_contig_id_param(self, contig_id, tmp_path):
363+
tree_seqs = {}
364+
tree_seqs["chr1"], tree_seqs["chr2"] = self.make_two_ts_dataset(tmp_path)
365+
vdata = tsinfer.VariantData(
366+
tmp_path, "variant_ancestral_allele", contig_id=contig_id
367+
)
368+
assert np.all(tree_seqs[contig_id].sites_position == vdata.sites_position)
369+
assert vdata.contig_id == contig_id
370+
assert vdata.sequence_length == tree_seqs[contig_id].sequence_length
371+
372+
def test_contig_id_param_and_mask(self, tmp_path):
373+
ts1, ts2 = self.make_two_ts_dataset(tmp_path)
374+
vdata = tsinfer.VariantData(
375+
tmp_path,
376+
"variant_ancestral_allele",
377+
site_mask=np.array(
378+
(ts1.num_sites + 1) * [True] + (ts2.num_sites - 1) * [False]
379+
),
380+
contig_id="chr2",
381+
)
382+
assert np.all(ts2.sites_position[1:] == vdata.sites_position)
383+
assert vdata.contig_id == "chr2"
384+
385+
@pytest.mark.parametrize("contig_id", ["chr1", "chr2"])
386+
def test_contig_length(self, contig_id, tmp_path):
387+
tree_seqs = {}
388+
tree_seqs["chr1"], tree_seqs["chr2"] = self.make_two_ts_dataset(tmp_path)
389+
vdata = tsinfer.VariantData(
390+
tmp_path, "variant_ancestral_allele", contig_id=contig_id
391+
)
392+
assert vdata.sequence_length == tree_seqs[contig_id].sequence_length
309393

310394

311395
@pytest.mark.skipif(sys.platform == "win32", reason="File permission errors on Windows")
312396
class TestSgkitMask:
313-
@pytest.mark.parametrize("sites", [[1, 2, 3, 5, 9, 27], [0], []])
397+
@pytest.mark.parametrize("sites", [[1, 2, 3, 5, 9, 27], [0]])
314398
def test_sgkit_variant_mask(self, tmp_path, sites):
315399
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
316400
ds = sgkit.load_dataset(zarr_path)
@@ -823,6 +907,20 @@ def test_bad_ancestral_state(self, tmp_path):
823907
with pytest.raises(ValueError, match="cannot contain empty strings"):
824908
tsinfer.VariantData(path, ancestral_state)
825909

910+
def test_ancestral_state_len_not_same_as_mask(self, tmp_path):
911+
path = tmp_path / "data.zarr"
912+
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
913+
sgkit.save_dataset(ds, path)
914+
ancestral_state = ds["variant_allele"][:, 0].values.astype(str)
915+
site_mask = np.zeros(ds.sizes["variants"], dtype=bool)
916+
site_mask[0] = True
917+
with pytest.raises(
918+
ValueError,
919+
match="Ancestral state array must be the same length as the number of"
920+
" selected sites",
921+
):
922+
tsinfer.VariantData(path, ancestral_state, site_mask=site_mask)
923+
826924
def test_empty_alleles_not_at_end(self, tmp_path):
827925
path = tmp_path / "data.zarr"
828926
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
@@ -854,3 +952,62 @@ def test_unimplemented_from_tree_sequence(self):
854952
# Requires e.g. https://github.com/tskit-dev/tsinfer/issues/924
855953
with pytest.raises(NotImplementedError):
856954
tsinfer.VariantData.from_tree_sequence(None)
955+
956+
def test_multiple_contigs(self, tmp_path):
957+
path = tmp_path / "data.zarr"
958+
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
959+
ds["contig_id"] = (
960+
ds["contig_id"].dims,
961+
np.array(["c10", "c11"], dtype="<U3"),
962+
)
963+
ds["variant_contig"] = (
964+
ds["variant_contig"].dims,
965+
np.array([0, 0, 1], dtype=ds["variant_contig"].dtype),
966+
)
967+
sgkit.save_dataset(ds, path)
968+
with pytest.raises(
969+
ValueError, match=r'Sites belong to multiple contigs \("c10", "c11"\)'
970+
):
971+
tsinfer.VariantData(path, ds["variant_allele"][:, 0].astype(str))
972+
973+
def test_all_masked(self, tmp_path):
974+
path = tmp_path / "data.zarr"
975+
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
976+
sgkit.save_dataset(ds, path)
977+
with pytest.raises(ValueError, match="All sites have been masked out"):
978+
tsinfer.VariantData(
979+
path, ds["variant_allele"][:, 0].astype(str), site_mask=np.ones(3, bool)
980+
)
981+
982+
def test_bad_contig_param(self, tmp_path):
983+
path = tmp_path / "data.zarr"
984+
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
985+
sgkit.save_dataset(ds, path)
986+
with pytest.raises(ValueError, match='"XX" not found'):
987+
tsinfer.VariantData(
988+
path, ds["variant_allele"][:, 0].astype(str), contig_id="XX"
989+
)
990+
991+
def test_multiple_contig_param(self, tmp_path):
992+
path = tmp_path / "data.zarr"
993+
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
994+
ds["contig_id"] = (
995+
ds["contig_id"].dims,
996+
np.array(["chr1", "chr1"], dtype="<U4"),
997+
)
998+
sgkit.save_dataset(ds, path)
999+
with pytest.raises(ValueError, match='Multiple contigs named "chr1"'):
1000+
tsinfer.VariantData(
1001+
path, ds["variant_allele"][:, 0].astype(str), contig_id="chr1"
1002+
)
1003+
1004+
def test_missing_sites_time(self, tmp_path):
1005+
path = tmp_path / "data.zarr"
1006+
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
1007+
sgkit.save_dataset(ds, path)
1008+
with pytest.raises(
1009+
ValueError, match="The sites time array XX was not found in the dataset"
1010+
):
1011+
tsinfer.VariantData(
1012+
path, ds["variant_allele"][:, 0].astype(str), sites_time="XX"
1013+
)

tsinfer/formats.py

Lines changed: 60 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2328,7 +2328,8 @@ class VariantData(SampleData):
23282328
:param Union(array, str) site_mask: A numpy array of booleans specifying which
23292329
sites to mask out (exclude) from the dataset. Alternatively, a string
23302330
can be provided, giving the name of an array in the input dataset which contains
2331-
the site mask. If ``None`` (default), all sites are included.
2331+
the site mask. If ``None`` (default), all sites are included (unless restricted
2332+
to a ``contig_id``, see below).
23322333
:param Union(array, str) sites_time: A numpy array of floats specifying the relative
23332334
time of occurrence of the mutation to the derived state at each site. This must
23342335
be of the same length as the number of unmasked sites. Alternatively, a
@@ -2338,6 +2339,11 @@ class VariantData(SampleData):
23382339
reasonable approximation to the relative order of ancestors used for inference.
23392340
Time values are ignored for sites not used in inference, such as singletons,
23402341
sites with more than two alleles, or sites with an unknown ancestral state.
2342+
:param str contig_id: The name of the contig to use (e.g. "chr1"), if the .vcz file
2343+
contains multiple contigs; contig names can be found in the ``contig_id`` array
2344+
of the input dataset. If provided, sites associated with any other contigs will
2345+
be added to the sites that are masked out. If ``None`` (default), do not mask
2346+
out sites on the basis of their contig ID.
23412347
"""
23422348

23432349
FORMAT_NAME = "tsinfer-variant-data"
@@ -2351,6 +2357,7 @@ def __init__(
23512357
sample_mask=None,
23522358
site_mask=None,
23532359
sites_time=None,
2360+
contig_id=None,
23542361
):
23552362
try:
23562363
if len(path_or_zarr.call_genotype.shape) == 3:
@@ -2384,8 +2391,24 @@ def __init__(
23842391
raise ValueError(
23852392
"Site mask array must be the same length as the number of unmasked sites"
23862393
)
2394+
if contig_id is not None:
2395+
contig_index = np.where(self.data.contig_id[:] == contig_id)[0]
2396+
if len(contig_index) == 0:
2397+
raise ValueError(
2398+
f'"{contig_id}" not found among the available contig IDs: '
2399+
+ ",".join(f"{n}" for n in self.data.contig_id[:])
2400+
)
2401+
elif len(contig_index) > 1:
2402+
raise ValueError(f'Multiple contigs named "{contig_id}"')
2403+
contig_index = contig_index[0]
2404+
site_mask = np.logical_or(
2405+
site_mask, self.data["variant_contig"][:] != contig_index
2406+
)
2407+
23872408
# We negate the mask as it is much easier in numpy to have True=keep
23882409
self.sites_select = ~site_mask.astype(bool)
2410+
if np.sum(self.sites_select) == 0:
2411+
raise ValueError("All sites have been masked out. Please unmask some")
23892412

23902413
if sample_mask is None:
23912414
sample_mask = np.full(self._num_individuals_before_mask, False, dtype=bool)
@@ -2415,6 +2438,20 @@ def __init__(
24152438
" zarr dataset, indicating that all the genotypes are"
24162439
" unphased"
24172440
)
2441+
2442+
used_contigs = self.data.variant_contig[:][self.sites_select]
2443+
self._contig_index = used_contigs[0]
2444+
self._contig_id = self.data.contig_id[self._contig_index]
2445+
2446+
if np.any(used_contigs != self._contig_index):
2447+
contig_names = ", ".join(
2448+
f'"{self.data.contig_id[c]}"' for c in np.unique(used_contigs)
2449+
)
2450+
raise ValueError(
2451+
f"Sites belong to multiple contigs ({contig_names}). Please restrict "
2452+
"sites to one contig e.g. via the `contig_id` argument."
2453+
)
2454+
24182455
if np.any(np.diff(self.sites_position) <= 0):
24192456
raise ValueError(
24202457
"Values taken from the variant_position array are not strictly "
@@ -2436,7 +2473,7 @@ def __init__(
24362473
self._sites_time = self.data[sites_time][:][self.sites_select]
24372474
except KeyError:
24382475
raise ValueError(
2439-
f"The sites time {sites_time} was not found" f" in the dataset."
2476+
f"The sites time array {sites_time} was not found in the dataset"
24402477
)
24412478

24422479
if isinstance(ancestral_state, np.ndarray):
@@ -2519,16 +2556,30 @@ def finalised(self):
25192556

25202557
@functools.cached_property
25212558
def sequence_length(self):
2559+
"""
2560+
The sequence length of the contig associated with sites used in the dataset.
2561+
If the dataset has a "sequence_length" attribute, this is always used, otherwise
2562+
if the dataset has recorded contig lengths, the appropriate length is taken,
2563+
otherwise the length is calculated from the maximum variant position plus one.
2564+
"""
25222565
try:
25232566
return self.data.attrs["sequence_length"]
25242567
except KeyError:
2525-
warnings.warn(
2526-
"`sequence_length` was not found as an attribute in the dataset, so"
2527-
" the largest position has been used. It can be set with"
2528-
" ds.attrs['sequence_length'] = 1337; ds.to_zarr('path/to/store',"
2529-
" mode='a')"
2530-
)
2531-
return int(np.max(self.data["variant_position"])) + 1
2568+
if self._contig_index is not None:
2569+
try:
2570+
if self._contig_index < len(self.data.contig_length):
2571+
return self.data.contig_length[self._contig_index]
2572+
except AttributeError:
2573+
pass # contig_length is optional, fall back to calculating length
2574+
return int(np.max(self.data["variant_position"])) + 1
2575+
2576+
@property
2577+
def contig_id(self):
2578+
"""
2579+
The contig ID (name) for all used sites, or None if no
2580+
contig IDs were provided
2581+
"""
2582+
return self._contig_id
25322583

25332584
@property
25342585
def num_sites(self):

0 commit comments

Comments
 (0)