Skip to content

Commit dbdae8b

Browse files
committed
Deal with multiple contigs and sequence lengths
Introduces a `contig_id` parameter to variant_data, as described in #949. Fixes #249
1 parent f9de549 commit dbdae8b

File tree

2 files changed

+157
-3
lines changed

2 files changed

+157
-3
lines changed

tests/test_variantdata.py

+113-2
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from tsinfer import formats
3737

3838

39-
def ts_to_dataset(ts, chunks=None, samples=None):
39+
def ts_to_dataset(ts, chunks=None, samples=None, contigs=None):
4040
"""
4141
# From https://github.com/sgkit-dev/sgkit/blob/main/sgkit/tests/test_popgen.py#L63
4242
Convert the specified tskit tree sequence into an sgkit dataset.
@@ -61,7 +61,7 @@ def ts_to_dataset(ts, chunks=None, samples=None):
6161
genotypes = np.expand_dims(genotypes, axis=2)
6262

6363
ds = sgkit.create_genotype_call_dataset(
64-
variant_contig_names=["1"],
64+
variant_contig_names=["1"] if contigs is None else contigs,
6565
variant_contig=np.zeros(len(tables.sites), dtype=int),
6666
variant_position=tables.sites.position.astype(int),
6767
variant_allele=alleles,
@@ -287,6 +287,78 @@ def test_simulate_genotype_call_dataset(tmp_path):
287287
assert np.all(v.genotypes == sd_v)
288288

289289

290+
class TestMultiContig:
291+
def make_two_ts_dataset(self, path):
292+
# split ts into 2; put them as different contigs in the same dataset
293+
ts = msprime.sim_ancestry(4, sequence_length=1000, random_seed=123)
294+
ts = msprime.sim_mutations(ts, rate=2e-3, random_seed=123)
295+
split_at_site = 7
296+
assert ts.num_sites > 10
297+
site_break = ts.site(split_at_site).position
298+
ts1 = ts.keep_intervals([(0, site_break)]).rtrim()
299+
ts2 = ts.keep_intervals([(site_break, ts.sequence_length)]).ltrim()
300+
ds = ts_to_dataset(ts, contigs=["chr1", "chr2"])
301+
ds.update({"variant_ancestral_allele": ds["variant_allele"][:, 0]})
302+
variant_contig = ds["variant_contig"][:]
303+
variant_contig[split_at_site:] = 1
304+
ds.update({"variant_contig": variant_contig})
305+
variant_position = ds["variant_position"].values
306+
variant_position[split_at_site:] -= int(site_break)
307+
ds.update({"variant_position": ds["variant_position"]})
308+
ds.update(
309+
{"contig_length": np.array([ts1.sequence_length, ts2.sequence_length])}
310+
)
311+
ds.to_zarr(path, mode="w")
312+
return ts1, ts2
313+
314+
def test_unmasked(self, tmp_path):
315+
self.make_two_ts_dataset(tmp_path)
316+
with pytest.raises(ValueError, match=r'multiple contigs \("chr1", "chr2"\)'):
317+
tsinfer.VariantData(tmp_path, "variant_ancestral_allele")
318+
319+
def test_mask(self, tmp_path):
320+
ts1, ts2 = self.make_two_ts_dataset(tmp_path)
321+
vdata = tsinfer.VariantData(
322+
tmp_path,
323+
"variant_ancestral_allele",
324+
site_mask=np.array(ts1.num_sites * [True] + ts2.num_sites * [False]),
325+
)
326+
assert np.all(ts2.sites_position == vdata.sites_position)
327+
assert vdata.contig_id == "chr2"
328+
329+
@pytest.mark.parametrize("contig_id", ["chr1", "chr2"])
330+
def test_contig_id_param(self, contig_id, tmp_path):
331+
tree_seqs = {}
332+
tree_seqs["chr1"], tree_seqs["chr2"] = self.make_two_ts_dataset(tmp_path)
333+
vdata = tsinfer.VariantData(
334+
tmp_path, "variant_ancestral_allele", contig_id=contig_id
335+
)
336+
assert np.all(tree_seqs[contig_id].sites_position == vdata.sites_position)
337+
assert vdata.contig_id == contig_id
338+
339+
def test_contig_id_param_and_mask(self, tmp_path):
340+
ts1, ts2 = self.make_two_ts_dataset(tmp_path)
341+
vdata = tsinfer.VariantData(
342+
tmp_path,
343+
"variant_ancestral_allele",
344+
site_mask=np.array(
345+
(ts1.num_sites + 1) * [True] + (ts2.num_sites - 1) * [False]
346+
),
347+
contig_id="chr2",
348+
)
349+
assert np.all(ts2.sites_position[1:] == vdata.sites_position)
350+
assert vdata.contig_id == "chr2"
351+
352+
@pytest.mark.parametrize("contig_id", ["chr1", "chr2"])
353+
def test_contig_length(self, contig_id, tmp_path):
354+
tree_seqs = {}
355+
tree_seqs["chr1"], tree_seqs["chr2"] = self.make_two_ts_dataset(tmp_path)
356+
vdata = tsinfer.VariantData(
357+
tmp_path, "variant_ancestral_allele", contig_id=contig_id
358+
)
359+
assert vdata.sequence_length == tree_seqs[contig_id].sequence_length
360+
361+
290362
@pytest.mark.skipif(sys.platform == "win32", reason="File permission errors on Windows")
291363
class TestSgkitMask:
292364
@pytest.mark.parametrize("sites", [[1, 2, 3, 5, 9, 27], [0], []])
@@ -747,3 +819,42 @@ def test_unimplemented_from_tree_sequence(self):
747819
# Requires e.g. https://github.com/tskit-dev/tsinfer/issues/924
748820
with pytest.raises(NotImplementedError):
749821
tsinfer.VariantData.from_tree_sequence(None)
822+
823+
def test_multiple_contigs(self, tmp_path):
824+
path = tmp_path / "data.zarr"
825+
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
826+
ds["contig_id"] = (
827+
ds["contig_id"].dims,
828+
np.array(["c10", "c11"], dtype="<U3"),
829+
)
830+
ds["variant_contig"] = (
831+
ds["variant_contig"].dims,
832+
np.array([0, 0, 1], dtype=ds["variant_contig"].dtype),
833+
)
834+
sgkit.save_dataset(ds, path)
835+
with pytest.raises(
836+
ValueError, match=r'Sites belong to multiple contigs \("c10", "c11"\)'
837+
):
838+
tsinfer.VariantData(path, ds["variant_allele"][:, 0].astype(str))
839+
840+
def test_bad_contig_param(self, tmp_path):
841+
path = tmp_path / "data.zarr"
842+
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
843+
sgkit.save_dataset(ds, path)
844+
with pytest.raises(ValueError, match='"XX" not found'):
845+
tsinfer.VariantData(
846+
path, ds["variant_allele"][:, 0].astype(str), contig_id="XX"
847+
)
848+
849+
def test_multiple_contig_param(self, tmp_path):
850+
path = tmp_path / "data.zarr"
851+
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
852+
ds["contig_id"] = (
853+
ds["contig_id"].dims,
854+
np.array(["chr1", "chr1"], dtype="<U4"),
855+
)
856+
sgkit.save_dataset(ds, path)
857+
with pytest.raises(ValueError, match='Multiple contigs named "chr1"'):
858+
tsinfer.VariantData(
859+
path, ds["variant_allele"][:, 0].astype(str), contig_id="chr1"
860+
)

tsinfer/formats.py

+44-1
Original file line numberDiff line numberDiff line change
@@ -2305,6 +2305,7 @@ def __init__(
23052305
sample_mask=None,
23062306
site_mask=None,
23072307
sites_time=None,
2308+
contig_id=None,
23082309
):
23092310
try:
23102311
if len(path_or_zarr.call_genotype.shape) == 3:
@@ -2338,6 +2339,20 @@ def __init__(
23382339
raise ValueError(
23392340
"Site mask array must be the same length as the number of unmasked sites"
23402341
)
2342+
if contig_id is not None:
2343+
contig_index = np.where(self.data.contig_id[:] == contig_id)[0]
2344+
if len(contig_index) == 0:
2345+
raise ValueError(
2346+
f'"{contig_id}" not found among the available contig IDs: '
2347+
+ ",".join(f"{n}" for n in self.data.contig_id[:])
2348+
)
2349+
elif len(contig_index) > 1:
2350+
raise ValueError(f'Multiple contigs named "{contig_id}"')
2351+
contig_index = contig_index[0]
2352+
site_mask = np.logical_or(
2353+
site_mask, self.data["variant_contig"][:] != contig_index
2354+
)
2355+
23412356
# We negate the mask as it is much easier in numpy to have True=keep
23422357
self.sites_select = ~site_mask.astype(bool)
23432358

@@ -2369,6 +2384,21 @@ def __init__(
23692384
" zarr dataset, indicating that all the genotypes are"
23702385
" unphased"
23712386
)
2387+
self._contig_index = None
2388+
self._contig_id = None
2389+
contig = self.data.variant_contig[:][self.sites_select]
2390+
try:
2391+
self._contig_index = contig[0]
2392+
self._contig_id = self.data.contig_id[self._contig_index]
2393+
except (IndexError, AttributeError):
2394+
pass
2395+
if self._contig_index is not None and np.any(contig != self._contig_index):
2396+
ctigs = ", ".join(f'"{self.data.contig_id[c]}"' for c in np.unique(contig))
2397+
raise ValueError(
2398+
f"Sites belong to multiple contigs ({ctigs}). Please restrict "
2399+
"sites to one contig e.g. via the `contig_id` argument."
2400+
)
2401+
23722402
if np.any(np.diff(self.sites_position) <= 0):
23732403
raise ValueError(
23742404
"Values taken from the variant_position array are not strictly "
@@ -2472,7 +2502,20 @@ def sequence_length(self):
24722502
try:
24732503
return self.data.attrs["sequence_length"]
24742504
except KeyError:
2475-
return int(np.max(self.data["variant_position"])) + 1
2505+
if self._contig_index is not None:
2506+
try:
2507+
return self.data.contig_length[self._contig_index]
2508+
except AttributeError:
2509+
pass
2510+
return int(np.max(self.data["variant_position"])) + 1
2511+
2512+
@property
2513+
def contig_id(self):
2514+
"""
2515+
The contig ID (name) for all used sites, or None if no
2516+
contig IDs were provided
2517+
"""
2518+
return self._contig_id
24762519

24772520
@property
24782521
def num_sites(self):

0 commit comments

Comments
 (0)