Skip to content

Commit ba629de

Browse files
committed
Deal with multiple contigs and sequence lengths
Introduces a `contig_id` parameter to variant_data, as described in #949. Fixes #249
1 parent a080e8f commit ba629de

File tree

2 files changed

+157
-3
lines changed

2 files changed

+157
-3
lines changed

tests/test_variantdata.py

Lines changed: 113 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from tsinfer import formats
3737

3838

39-
def ts_to_dataset(ts, chunks=None, samples=None):
39+
def ts_to_dataset(ts, chunks=None, samples=None, contigs=None):
4040
"""
4141
# From https://github.com/sgkit-dev/sgkit/blob/main/sgkit/tests/test_popgen.py#L63
4242
Convert the specified tskit tree sequence into an sgkit dataset.
@@ -61,7 +61,7 @@ def ts_to_dataset(ts, chunks=None, samples=None):
6161
genotypes = np.expand_dims(genotypes, axis=2)
6262

6363
ds = sgkit.create_genotype_call_dataset(
64-
variant_contig_names=["1"],
64+
variant_contig_names=["1"] if contigs is None else contigs,
6565
variant_contig=np.zeros(len(tables.sites), dtype=int),
6666
variant_position=tables.sites.position.astype(int),
6767
variant_allele=alleles,
@@ -292,6 +292,78 @@ def test_simulate_genotype_call_dataset(tmp_path):
292292
assert np.all(v.genotypes == sd_v)
293293

294294

295+
class TestMultiContig:
296+
def make_two_ts_dataset(self, path):
297+
# split ts into 2; put them as different contigs in the same dataset
298+
ts = msprime.sim_ancestry(4, sequence_length=1000, random_seed=123)
299+
ts = msprime.sim_mutations(ts, rate=2e-3, random_seed=123)
300+
split_at_site = 7
301+
assert ts.num_sites > 10
302+
site_break = ts.site(split_at_site).position
303+
ts1 = ts.keep_intervals([(0, site_break)]).rtrim()
304+
ts2 = ts.keep_intervals([(site_break, ts.sequence_length)]).ltrim()
305+
ds = ts_to_dataset(ts, contigs=["chr1", "chr2"])
306+
ds.update({"variant_ancestral_allele": ds["variant_allele"][:, 0]})
307+
variant_contig = ds["variant_contig"][:]
308+
variant_contig[split_at_site:] = 1
309+
ds.update({"variant_contig": variant_contig})
310+
variant_position = ds["variant_position"].values
311+
variant_position[split_at_site:] -= int(site_break)
312+
ds.update({"variant_position": ds["variant_position"]})
313+
ds.update(
314+
{"contig_length": np.array([ts1.sequence_length, ts2.sequence_length])}
315+
)
316+
ds.to_zarr(path, mode="w")
317+
return ts1, ts2
318+
319+
def test_unmasked(self, tmp_path):
320+
self.make_two_ts_dataset(tmp_path)
321+
with pytest.raises(ValueError, match=r'multiple contigs \("chr1", "chr2"\)'):
322+
tsinfer.VariantData(tmp_path, "variant_ancestral_allele")
323+
324+
def test_mask(self, tmp_path):
325+
ts1, ts2 = self.make_two_ts_dataset(tmp_path)
326+
vdata = tsinfer.VariantData(
327+
tmp_path,
328+
"variant_ancestral_allele",
329+
site_mask=np.array(ts1.num_sites * [True] + ts2.num_sites * [False]),
330+
)
331+
assert np.all(ts2.sites_position == vdata.sites_position)
332+
assert vdata.contig_id == "chr2"
333+
334+
@pytest.mark.parametrize("contig_id", ["chr1", "chr2"])
335+
def test_contig_id_param(self, contig_id, tmp_path):
336+
tree_seqs = {}
337+
tree_seqs["chr1"], tree_seqs["chr2"] = self.make_two_ts_dataset(tmp_path)
338+
vdata = tsinfer.VariantData(
339+
tmp_path, "variant_ancestral_allele", contig_id=contig_id
340+
)
341+
assert np.all(tree_seqs[contig_id].sites_position == vdata.sites_position)
342+
assert vdata.contig_id == contig_id
343+
344+
def test_contig_id_param_and_mask(self, tmp_path):
345+
ts1, ts2 = self.make_two_ts_dataset(tmp_path)
346+
vdata = tsinfer.VariantData(
347+
tmp_path,
348+
"variant_ancestral_allele",
349+
site_mask=np.array(
350+
(ts1.num_sites + 1) * [True] + (ts2.num_sites - 1) * [False]
351+
),
352+
contig_id="chr2",
353+
)
354+
assert np.all(ts2.sites_position[1:] == vdata.sites_position)
355+
assert vdata.contig_id == "chr2"
356+
357+
@pytest.mark.parametrize("contig_id", ["chr1", "chr2"])
358+
def test_contig_length(self, contig_id, tmp_path):
359+
tree_seqs = {}
360+
tree_seqs["chr1"], tree_seqs["chr2"] = self.make_two_ts_dataset(tmp_path)
361+
vdata = tsinfer.VariantData(
362+
tmp_path, "variant_ancestral_allele", contig_id=contig_id
363+
)
364+
assert vdata.sequence_length == tree_seqs[contig_id].sequence_length
365+
366+
295367
@pytest.mark.skipif(sys.platform == "win32", reason="File permission errors on Windows")
296368
class TestSgkitMask:
297369
@pytest.mark.parametrize("sites", [[1, 2, 3, 5, 9, 27], [0], []])
@@ -754,3 +826,42 @@ def test_empty_alleles_not_at_end(self, tmp_path):
754826
samples = tsinfer.VariantData(path, "variant_ancestral_allele")
755827
with pytest.raises(ValueError, match="Empty alleles must be at the end"):
756828
tsinfer.infer(samples)
829+
830+
def test_multiple_contigs(self, tmp_path):
831+
path = tmp_path / "data.zarr"
832+
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
833+
ds["contig_id"] = (
834+
ds["contig_id"].dims,
835+
np.array(["c10", "c11"], dtype="<U3"),
836+
)
837+
ds["variant_contig"] = (
838+
ds["variant_contig"].dims,
839+
np.array([0, 0, 1], dtype=ds["variant_contig"].dtype),
840+
)
841+
sgkit.save_dataset(ds, path)
842+
with pytest.raises(
843+
ValueError, match=r'Sites belong to multiple contigs \("c10", "c11"\)'
844+
):
845+
tsinfer.VariantData(path, ds["variant_allele"][:, 0].astype(str))
846+
847+
def test_bad_contig_param(self, tmp_path):
848+
path = tmp_path / "data.zarr"
849+
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
850+
sgkit.save_dataset(ds, path)
851+
with pytest.raises(ValueError, match='"XX" not found'):
852+
tsinfer.VariantData(
853+
path, ds["variant_allele"][:, 0].astype(str), contig_id="XX"
854+
)
855+
856+
def test_multiple_contig_param(self, tmp_path):
857+
path = tmp_path / "data.zarr"
858+
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
859+
ds["contig_id"] = (
860+
ds["contig_id"].dims,
861+
np.array(["chr1", "chr1"], dtype="<U4"),
862+
)
863+
sgkit.save_dataset(ds, path)
864+
with pytest.raises(ValueError, match='Multiple contigs named "chr1"'):
865+
tsinfer.VariantData(
866+
path, ds["variant_allele"][:, 0].astype(str), contig_id="chr1"
867+
)

tsinfer/formats.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2305,6 +2305,7 @@ def __init__(
23052305
sample_mask=None,
23062306
site_mask=None,
23072307
sites_time=None,
2308+
contig_id=None,
23082309
):
23092310
self.path = path
23102311
self.data = zarr.open(path, mode="r")
@@ -2326,6 +2327,20 @@ def __init__(
23262327
raise ValueError(
23272328
"Site mask array must be the same length as the number of unmasked sites"
23282329
)
2330+
if contig_id is not None:
2331+
contig_index = np.where(self.data.contig_id[:] == contig_id)[0]
2332+
if len(contig_index) == 0:
2333+
raise ValueError(
2334+
f'"{contig_id}" not found among the available contig IDs: '
2335+
+ ",".join(f"{n}" for n in self.data.contig_id[:])
2336+
)
2337+
elif len(contig_index) > 1:
2338+
raise ValueError(f'Multiple contigs named "{contig_id}"')
2339+
contig_index = contig_index[0]
2340+
site_mask = np.logical_or(
2341+
site_mask, self.data["variant_contig"][:] != contig_index
2342+
)
2343+
23292344
# We negate the mask as it is much easier in numpy to have True=keep
23302345
self.sites_select = ~site_mask.astype(bool)
23312346

@@ -2357,6 +2372,21 @@ def __init__(
23572372
" sgkit dataset, indicating that all the genotypes are"
23582373
" unphased"
23592374
)
2375+
self._contig_index = None
2376+
self._contig_id = None
2377+
contig = self.data.variant_contig[:][self.sites_select]
2378+
try:
2379+
self._contig_index = contig[0]
2380+
self._contig_id = self.data.contig_id[self._contig_index]
2381+
except (IndexError, AttributeError):
2382+
pass
2383+
if self._contig_index is not None and np.any(contig != self._contig_index):
2384+
ctigs = ", ".join(f'"{self.data.contig_id[c]}"' for c in np.unique(contig))
2385+
raise ValueError(
2386+
f"Sites belong to multiple contigs ({ctigs}). Please restrict "
2387+
"sites to one contig e.g. via the `contig_id` argument."
2388+
)
2389+
23602390
if np.any(np.diff(self.sites_position) <= 0):
23612391
raise ValueError(
23622392
"Values taken from the variant_position array are not strictly "
@@ -2460,7 +2490,20 @@ def sequence_length(self):
24602490
try:
24612491
return self.data.attrs["sequence_length"]
24622492
except KeyError:
2463-
return int(np.max(self.data["variant_position"])) + 1
2493+
if self._contig_index is not None:
2494+
try:
2495+
return self.data.contig_length[self._contig_index]
2496+
except AttributeError:
2497+
pass
2498+
return int(np.max(self.data["variant_position"])) + 1
2499+
2500+
@property
2501+
def contig_id(self):
2502+
"""
2503+
The contig ID (name) for all used sites, or None if no
2504+
contig IDs were provided
2505+
"""
2506+
return self._contig_id
24642507

24652508
@property
24662509
def num_sites(self):

0 commit comments

Comments
 (0)