Skip to content

Commit d5dd3f4

Browse files
hyanwong authored and benjeffery committed
Deal with multiple contigs and sequence lengths
1 parent 1aa0233 commit d5dd3f4

File tree

5 files changed

+201
-34
lines changed

5 files changed

+201
-34
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
- If a mismatch ratio is provided to the `infer` command, it only applies during the
1313
`match_samples` phase ({issue}`980`, {pr}`981`, {user}`hyanwong`)
1414

15+
- Get the `sequence_length` of the contig associated with the unmasked sites,
16+
if contig lengths are provided ({pr}`964`, {user}`hyanwong`, {user}`benjeffery`)
17+
1518
**Fixes**
1619

1720
- Properly account for "N" as an unknown ancestral state, and ban "" from being

docs/usage.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,10 @@ onto branches by {meth}`parsimony<tskit.Tree.map_mutations>`.
107107
It is also possible to *completely* exclude sites and samples, by specifying a boolean
108108
`site_mask` and/or a `sample_mask` when creating the `VariantData` object. Sites or samples with
109109
a mask value of `True` will be completely omitted both from inference and the final tree sequence.
110-
This can be useful, for example, if your VCF file contains multiple chromosomes (in which case
111-
`tsinfer` will need to be run separately on each chromosome) or if you wish to select only a subset
112-
of the chromosome for inference (e.g. to reduce computational load). If a `site_mask` is provided,
113-
note that the ancestral alleles array only specifies alleles for the unmasked sites.
110+
This can be useful, for example, if you wish to select only a subset of the chromosome for
111+
inference, e.g. to reduce computational load. You can also use it to subset inference to a
112+
particular contig, if your dataset contains multiple contigs. Note that if a `site_mask` is provided,
113+
the ancestral states array should only specify alleles for the unmasked sites.
114114

115115
Below, for instance, is an example of including only sites up to position six in the contig
116116
labelled "chr1" in the `example_data.vcz` file:

tests/test_variantdata.py

Lines changed: 142 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from tsinfer import formats
3939

4040

41-
def ts_to_dataset(ts, chunks=None, samples=None):
41+
def ts_to_dataset(ts, chunks=None, samples=None, contigs=None):
4242
"""
4343
# From https://github.com/sgkit-dev/sgkit/blob/main/sgkit/tests/test_popgen.py#L63
4444
Convert the specified tskit tree sequence into an sgkit dataset.
@@ -63,7 +63,7 @@ def ts_to_dataset(ts, chunks=None, samples=None):
6363
genotypes = np.expand_dims(genotypes, axis=2)
6464

6565
ds = sgkit.create_genotype_call_dataset(
66-
variant_contig_names=["1"],
66+
variant_contig_names=["1"] if contigs is None else contigs,
6767
variant_contig=np.zeros(len(tables.sites), dtype=int),
6868
variant_position=tables.sites.position.astype(int),
6969
variant_allele=alleles,
@@ -145,7 +145,7 @@ def test_variantdata_accessors(tmp_path, in_mem):
145145
assert vd.format_name == "tsinfer-variant-data"
146146
assert vd.format_version == (0, 1)
147147
assert vd.finalised
148-
assert vd.sequence_length == ts.sequence_length + 1337
148+
assert vd.sequence_length == ts.sequence_length
149149
assert vd.num_sites == ts.num_sites
150150
assert vd.sites_metadata_schema == ts.tables.sites.metadata_schema.schema
151151
assert vd.sites_metadata == [site.metadata for site in ts.sites()]
@@ -218,11 +218,7 @@ def test_variantdata_accessors_defaults(tmp_path, in_mem):
218218
ds = data if in_mem else sgkit.load_dataset(data)
219219

220220
default_schema = tskit.MetadataSchema.permissive_json().schema
221-
with pytest.warns(
222-
UserWarning,
223-
match="`sequence_length` was not found as an attribute in the dataset",
224-
):
225-
assert vdata.sequence_length == ts.sequence_length
221+
assert vdata.sequence_length == ts.sequence_length
226222
assert vdata.sites_metadata_schema == default_schema
227223
assert vdata.sites_metadata == [{} for _ in range(ts.num_sites)]
228224
for time in vdata.sites_time:
@@ -299,18 +295,116 @@ def test_simulate_genotype_call_dataset(tmp_path):
299295
ts = msprime.sim_ancestry(4, sequence_length=1000, random_seed=123)
300296
ts = msprime.sim_mutations(ts, rate=2e-3, random_seed=123)
301297
ds = ts_to_dataset(ts)
302-
ds.update({"variant_ancestral_allele": ds["variant_allele"][:, 0]})
303298
ds.to_zarr(tmp_path, mode="w")
304-
sd = tsinfer.VariantData(tmp_path, "variant_ancestral_allele")
305-
ts = tsinfer.infer(sd)
306-
for v, ds_v, sd_v in zip(ts.variants(), ds.call_genotype, sd.sites_genotypes):
299+
vdata = tsinfer.VariantData(tmp_path, ds["variant_allele"][:, 0].values.astype(str))
300+
ts = tsinfer.infer(vdata)
301+
for v, ds_v, vd_v in zip(ts.variants(), ds.call_genotype, vdata.sites_genotypes):
307302
assert np.all(v.genotypes == ds_v.values.flatten())
308-
assert np.all(v.genotypes == sd_v)
303+
assert np.all(v.genotypes == vd_v)
304+
305+
306+
def test_simulate_genotype_call_dataset_length(tmp_path):
307+
# create_genotype_call_dataset does not save contig lengths
308+
ts = msprime.sim_ancestry(4, sequence_length=1000, random_seed=123)
309+
ts = msprime.sim_mutations(ts, rate=2e-3, random_seed=123)
310+
ds = ts_to_dataset(ts)
311+
assert "contig_length" not in ds
312+
ds.to_zarr(tmp_path, mode="w")
313+
vdata = tsinfer.VariantData(tmp_path, ds["variant_allele"][:, 0].values.astype(str))
314+
assert vdata.sequence_length == ts.sites_position[-1] + 1
315+
316+
vdata = tsinfer.VariantData(
317+
tmp_path, ds["variant_allele"][:, 0].values.astype(str), sequence_length=1337
318+
)
319+
assert vdata.sequence_length == 1337
320+
321+
322+
class TestMultiContig:
323+
def make_two_ts_dataset(self, path):
324+
# split ts into 2; put them as different contigs in the same dataset
325+
ts = msprime.sim_ancestry(4, sequence_length=1000, random_seed=123)
326+
ts = msprime.sim_mutations(ts, rate=2e-3, random_seed=123)
327+
split_at_site = 7
328+
assert ts.num_sites > 10
329+
site_break = ts.site(split_at_site).position
330+
ts1 = ts.keep_intervals([(0, site_break)]).rtrim()
331+
ts2 = ts.keep_intervals([(site_break, ts.sequence_length)]).ltrim()
332+
ds = ts_to_dataset(ts, contigs=["chr1", "chr2"])
333+
ds.update({"variant_ancestral_allele": ds["variant_allele"][:, 0]})
334+
variant_contig = ds["variant_contig"][:]
335+
variant_contig[split_at_site:] = 1
336+
ds.update({"variant_contig": variant_contig})
337+
variant_position = ds["variant_position"].values
338+
variant_position[split_at_site:] -= int(site_break)
339+
ds.update({"variant_position": ds["variant_position"]})
340+
ds.update(
341+
{"contig_length": np.array([ts1.sequence_length, ts2.sequence_length])}
342+
)
343+
ds.to_zarr(path, mode="w")
344+
return ts1, ts2
345+
346+
def test_unmasked(self, tmp_path):
347+
self.make_two_ts_dataset(tmp_path)
348+
with pytest.raises(ValueError, match=r'multiple contigs \("chr1", "chr2"\)'):
349+
tsinfer.VariantData(tmp_path, "variant_ancestral_allele")
350+
351+
def test_mask(self, tmp_path):
352+
ts1, ts2 = self.make_two_ts_dataset(tmp_path)
353+
vdata = tsinfer.VariantData(
354+
tmp_path,
355+
"variant_ancestral_allele",
356+
site_mask=np.array(ts1.num_sites * [True] + ts2.num_sites * [False]),
357+
)
358+
assert np.all(ts2.sites_position == vdata.sites_position)
359+
assert vdata.contig_id == "chr2"
360+
assert vdata.sequence_length == ts2.sequence_length
361+
362+
@pytest.mark.parametrize("contig_id", ["chr1", "chr2"])
363+
def test_multi_contig(self, contig_id, tmp_path):
364+
tree_seqs = {}
365+
tree_seqs["chr1"], tree_seqs["chr2"] = self.make_two_ts_dataset(tmp_path)
366+
with pytest.raises(ValueError, match="multiple contigs"):
367+
vdata = tsinfer.VariantData(tmp_path, "variant_ancestral_allele")
368+
root = zarr.open(tmp_path)
369+
mask = root["variant_contig"][:] == (1 if contig_id == "chr1" else 0)
370+
vdata = tsinfer.VariantData(
371+
tmp_path, "variant_ancestral_allele", site_mask=mask
372+
)
373+
assert np.all(tree_seqs[contig_id].sites_position == vdata.sites_position)
374+
assert vdata.contig_id == contig_id
375+
assert vdata._contig_index == (0 if contig_id == "chr1" else 1)
376+
assert vdata.sequence_length == tree_seqs[contig_id].sequence_length
377+
378+
def test_mixed_contigs_error(self, tmp_path):
379+
ts1, ts2 = self.make_two_ts_dataset(tmp_path)
380+
mask = np.ones(ts1.num_sites + ts2.num_sites)
381+
# Select two variants, one from each contig
382+
mask[0] = False
383+
mask[-1] = False
384+
with pytest.raises(ValueError, match="multiple contigs"):
385+
tsinfer.VariantData(
386+
tmp_path,
387+
"variant_ancestral_allele",
388+
site_mask=mask,
389+
)
390+
391+
def test_no_variant_contig(self, tmp_path):
392+
ts1, ts2 = self.make_two_ts_dataset(tmp_path)
393+
root = zarr.open(tmp_path)
394+
del root["variant_contig"]
395+
mask = np.ones(ts1.num_sites + ts2.num_sites)
396+
mask[0] = False
397+
vdata = tsinfer.VariantData(
398+
tmp_path, "variant_ancestral_allele", site_mask=mask
399+
)
400+
assert vdata.sequence_length == ts1.sites_position[0] + 1
401+
assert vdata.contig_id is None
402+
assert vdata._contig_index is None
309403

310404

311405
@pytest.mark.skipif(sys.platform == "win32", reason="File permission errors on Windows")
312406
class TestSgkitMask:
313-
@pytest.mark.parametrize("sites", [[1, 2, 3, 5, 9, 27], [0], []])
407+
@pytest.mark.parametrize("sites", [[1, 2, 3, 5, 9, 27], [0]])
314408
def test_sgkit_variant_mask(self, tmp_path, sites):
315409
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
316410
ds = sgkit.load_dataset(zarr_path)
@@ -823,6 +917,20 @@ def test_bad_ancestral_state(self, tmp_path):
823917
with pytest.raises(ValueError, match="cannot contain empty strings"):
824918
tsinfer.VariantData(path, ancestral_state)
825919

920+
def test_ancestral_state_len_not_same_as_mask(self, tmp_path):
921+
path = tmp_path / "data.zarr"
922+
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
923+
sgkit.save_dataset(ds, path)
924+
ancestral_state = ds["variant_allele"][:, 0].values.astype(str)
925+
site_mask = np.zeros(ds.sizes["variants"], dtype=bool)
926+
site_mask[0] = True
927+
with pytest.raises(
928+
ValueError,
929+
match="Ancestral state array must be the same length as the number of"
930+
" selected sites",
931+
):
932+
tsinfer.VariantData(path, ancestral_state, site_mask=site_mask)
933+
826934
def test_empty_alleles_not_at_end(self, tmp_path):
827935
path = tmp_path / "data.zarr"
828936
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
@@ -854,3 +962,23 @@ def test_unimplemented_from_tree_sequence(self):
854962
# Requires e.g. https://github.com/tskit-dev/tsinfer/issues/924
855963
with pytest.raises(NotImplementedError):
856964
tsinfer.VariantData.from_tree_sequence(None)
965+
966+
def test_all_masked(self, tmp_path):
967+
path = tmp_path / "data.zarr"
968+
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
969+
sgkit.save_dataset(ds, path)
970+
with pytest.raises(ValueError, match="All sites have been masked out"):
971+
tsinfer.VariantData(
972+
path, ds["variant_allele"][:, 0].astype(str), site_mask=np.ones(3, bool)
973+
)
974+
975+
def test_missing_sites_time(self, tmp_path):
976+
path = tmp_path / "data.zarr"
977+
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
978+
sgkit.save_dataset(ds, path)
979+
with pytest.raises(
980+
ValueError, match="The sites time array XX was not found in the dataset"
981+
):
982+
tsinfer.VariantData(
983+
path, ds["variant_allele"][:, 0].astype(str), sites_time="XX"
984+
)

tests/tsutil.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -339,11 +339,6 @@ def _make_ts_and_zarr(path, add_optional=False, shuffle_alleles=True):
339339
)
340340

341341
if add_optional:
342-
add_attribute_to_dataset(
343-
"sequence_length",
344-
ts.sequence_length + 1337,
345-
path / "data.zarr",
346-
)
347342
sites_md = tables.sites.metadata
348343
sites_md_offset = tables.sites.metadata_offset
349344
add_array_to_dataset(

tsinfer/formats.py

Lines changed: 52 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2338,6 +2338,11 @@ class VariantData(SampleData):
23382338
reasonable approximation to the relative order of ancestors used for inference.
23392339
Time values are ignored for sites not used in inference, such as singletons,
23402340
sites with more than two alleles, or sites with an unknown ancestral state.
2341+
:param int sequence_length: An integer specifying the resulting `sequence_length`
2342+
attribute of the output tree sequence. If not specified, the `contig_length`
2343+
attribute from the underlying zarr store for the contig of the selected variants
2344+
is used. If that is not present, then the maximum position of the used
2345+
variants plus one is used.
23412346
"""
23422347

23432348
FORMAT_NAME = "tsinfer-variant-data"
@@ -2351,7 +2356,11 @@ def __init__(
23512356
sample_mask=None,
23522357
site_mask=None,
23532358
sites_time=None,
2359+
sequence_length=None,
23542360
):
2361+
self._sequence_length = sequence_length
2362+
self._contig_index = None
2363+
self._contig_id = None
23552364
try:
23562365
if len(path_or_zarr.call_genotype.shape) == 3:
23572366
# Assumed to be a VCF Zarr hierarchy
@@ -2384,9 +2393,16 @@ def __init__(
23842393
raise ValueError(
23852394
"Site mask array must be the same length as the number of unmasked sites"
23862395
)
2396+
23872397
# We negate the mask as it is much easier in numpy to have True=keep
23882398
self.sites_select = ~site_mask.astype(bool)
23892399

2400+
if np.sum(self.sites_select) == 0:
2401+
raise ValueError(
2402+
"All sites have been masked out, at least one value"
2403+
"must be 'False' in the site mask"
2404+
)
2405+
23902406
if sample_mask is None:
23912407
sample_mask = np.full(self._num_individuals_before_mask, False, dtype=bool)
23922408
elif isinstance(sample_mask, np.ndarray):
@@ -2415,6 +2431,22 @@ def __init__(
24152431
" zarr dataset, indicating that all the genotypes are"
24162432
" unphased"
24172433
)
2434+
2435+
if "variant_contig" in self.data:
2436+
used_contigs = self.data.variant_contig[:][self.sites_select]
2437+
self._contig_index = used_contigs[0]
2438+
self._contig_id = self.data.contig_id[self._contig_index]
2439+
2440+
if np.any(used_contigs != self._contig_index):
2441+
contig_names = ", ".join(
2442+
f'"{self.data.contig_id[c]}"' for c in np.unique(used_contigs)
2443+
)
2444+
raise ValueError(
2445+
f"Sites belong to multiple contigs ({contig_names}). "
2446+
"Please restrict sites to one contig using the sites_mask argument."
2447+
"e.g. `mask=zarr_group['variant_contig'] != wanted_index`"
2448+
)
2449+
24182450
if np.any(np.diff(self.sites_position) <= 0):
24192451
raise ValueError(
24202452
"Values taken from the variant_position array are not strictly "
@@ -2436,7 +2468,7 @@ def __init__(
24362468
self._sites_time = self.data[sites_time][:][self.sites_select]
24372469
except KeyError:
24382470
raise ValueError(
2439-
f"The sites time {sites_time} was not found" f" in the dataset."
2471+
f"The sites time array {sites_time} was not found in the dataset"
24402472
)
24412473

24422474
if isinstance(ancestral_state, np.ndarray):
@@ -2519,16 +2551,25 @@ def finalised(self):
25192551

25202552
@functools.cached_property
25212553
def sequence_length(self):
2522-
try:
2523-
return self.data.attrs["sequence_length"]
2524-
except KeyError:
2525-
warnings.warn(
2526-
"`sequence_length` was not found as an attribute in the dataset, so"
2527-
" the largest position has been used. It can be set with"
2528-
" ds.attrs['sequence_length'] = 1337; ds.to_zarr('path/to/store',"
2529-
" mode='a')"
2530-
)
2531-
return int(np.max(self.data["variant_position"])) + 1
2554+
"""
2555+
The sequence length of the contig associated with sites used in the dataset.
2556+
If set manually, that value is used; otherwise, if the dataset has recorded
2557+
contig lengths, the length for that contig is used; failing that, the length
2558+
is calculated from the maximum variant position plus one.
2559+
"""
2560+
if self._sequence_length is not None:
2561+
return self._sequence_length
2562+
if self._contig_index is not None and "contig_length" in self.data:
2563+
return self.data.contig_length[self._contig_index]
2564+
return int(np.max(self.sites_position)) + 1
2565+
2566+
@property
2567+
def contig_id(self):
2568+
"""
2569+
The contig ID (name) for all used sites, or None if no
2570+
contig IDs were present in the zarr dataset
2571+
"""
2572+
return self._contig_id
25322573

25332574
@property
25342575
def num_sites(self):

0 commit comments

Comments
 (0)