diff --git a/tests/test_variantdata.py b/tests/test_variantdata.py index a353c4c0..838dee4f 100644 --- a/tests/test_variantdata.py +++ b/tests/test_variantdata.py @@ -36,7 +36,7 @@ from tsinfer import formats -def ts_to_dataset(ts, chunks=None, samples=None): +def ts_to_dataset(ts, chunks=None, samples=None, contigs=None): """ # From https://github.com/sgkit-dev/sgkit/blob/main/sgkit/tests/test_popgen.py#L63 Convert the specified tskit tree sequence into an sgkit dataset. @@ -61,7 +61,7 @@ def ts_to_dataset(ts, chunks=None, samples=None): genotypes = np.expand_dims(genotypes, axis=2) ds = sgkit.create_genotype_call_dataset( - variant_contig_names=["1"], + variant_contig_names=["1"] if contigs is None else contigs, variant_contig=np.zeros(len(tables.sites), dtype=int), variant_position=tables.sites.position.astype(int), variant_allele=alleles, @@ -287,6 +287,78 @@ def test_simulate_genotype_call_dataset(tmp_path): assert np.all(v.genotypes == sd_v) +class TestMultiContig: + def make_two_ts_dataset(self, path): + # split ts into 2; put them as different contigs in the same dataset + ts = msprime.sim_ancestry(4, sequence_length=1000, random_seed=123) + ts = msprime.sim_mutations(ts, rate=2e-3, random_seed=123) + split_at_site = 7 + assert ts.num_sites > 10 + site_break = ts.site(split_at_site).position + ts1 = ts.keep_intervals([(0, site_break)]).rtrim() + ts2 = ts.keep_intervals([(site_break, ts.sequence_length)]).ltrim() + ds = ts_to_dataset(ts, contigs=["chr1", "chr2"]) + ds.update({"variant_ancestral_allele": ds["variant_allele"][:, 0]}) + variant_contig = ds["variant_contig"][:] + variant_contig[split_at_site:] = 1 + ds.update({"variant_contig": variant_contig}) + variant_position = ds["variant_position"].values + variant_position[split_at_site:] -= int(site_break) + ds.update({"variant_position": ds["variant_position"]}) + ds.update( + {"contig_length": np.array([ts1.sequence_length, ts2.sequence_length])} + ) + ds.to_zarr(path, mode="w") + return ts1, ts2 + + def test_unmasked(self, tmp_path): + self.make_two_ts_dataset(tmp_path) + with pytest.raises(ValueError, match=r'multiple contigs \("chr1", "chr2"\)'): + tsinfer.VariantData(tmp_path, "variant_ancestral_allele") + + def test_mask(self, tmp_path): + ts1, ts2 = self.make_two_ts_dataset(tmp_path) + vdata = tsinfer.VariantData( + tmp_path, + "variant_ancestral_allele", + site_mask=np.array(ts1.num_sites * [True] + ts2.num_sites * [False]), + ) + assert np.all(ts2.sites_position == vdata.sites_position) + assert vdata.contig_id == "chr2" + + @pytest.mark.parametrize("contig_id", ["chr1", "chr2"]) + def test_contig_id_param(self, contig_id, tmp_path): + tree_seqs = {} + tree_seqs["chr1"], tree_seqs["chr2"] = self.make_two_ts_dataset(tmp_path) + vdata = tsinfer.VariantData( + tmp_path, "variant_ancestral_allele", contig_id=contig_id + ) + assert np.all(tree_seqs[contig_id].sites_position == vdata.sites_position) + assert vdata.contig_id == contig_id + + def test_contig_id_param_and_mask(self, tmp_path): + ts1, ts2 = self.make_two_ts_dataset(tmp_path) + vdata = tsinfer.VariantData( + tmp_path, + "variant_ancestral_allele", + site_mask=np.array( + (ts1.num_sites + 1) * [True] + (ts2.num_sites - 1) * [False] + ), + contig_id="chr2", + ) + assert np.all(ts2.sites_position[1:] == vdata.sites_position) + assert vdata.contig_id == "chr2" + + @pytest.mark.parametrize("contig_id", ["chr1", "chr2"]) + def test_contig_length(self, contig_id, tmp_path): + tree_seqs = {} + tree_seqs["chr1"], tree_seqs["chr2"] = self.make_two_ts_dataset(tmp_path) + vdata = tsinfer.VariantData( + tmp_path, "variant_ancestral_allele", contig_id=contig_id + ) + assert vdata.sequence_length == tree_seqs[contig_id].sequence_length + + @pytest.mark.skipif(sys.platform == "win32", reason="File permission errors on Windows") class TestSgkitMask: @pytest.mark.parametrize("sites", [[1, 2, 3, 5, 9, 27], [0], []]) @@ -747,3 +819,42 @@ def test_unimplemented_from_tree_sequence(self): # Requires e.g. https://github.com/tskit-dev/tsinfer/issues/924 with pytest.raises(NotImplementedError): tsinfer.VariantData.from_tree_sequence(None) + + def test_multiple_contigs(self, tmp_path): + path = tmp_path / "data.zarr" + ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True) + ds["contig_id"] = ( + ds["contig_id"].dims, + np.array(["c10", "c11"], dtype=" 1: + raise ValueError(f'Multiple contigs named "{contig_id}"') + contig_index = contig_index[0] + site_mask = np.logical_or( + site_mask, self.data["variant_contig"][:] != contig_index + ) + # We negate the mask as it is much easier in numpy to have True=keep self.sites_select = ~site_mask.astype(bool) @@ -2369,6 +2384,21 @@ def __init__( " zarr dataset, indicating that all the genotypes are" " unphased" ) + self._contig_index = None + self._contig_id = None + contig = self.data.variant_contig[:][self.sites_select] + try: + self._contig_index = contig[0] + self._contig_id = self.data.contig_id[self._contig_index] + except (IndexError, AttributeError): + pass + if self._contig_index is not None and np.any(contig != self._contig_index): + ctigs = ", ".join(f'"{self.data.contig_id[c]}"' for c in np.unique(contig)) + raise ValueError( + f"Sites belong to multiple contigs ({ctigs}). Please restrict " + "sites to one contig e.g. via the `contig_id` argument." + ) + if np.any(np.diff(self.sites_position) <= 0): raise ValueError( "Values taken from the variant_position array are not strictly " @@ -2472,7 +2502,20 @@ def sequence_length(self): try: return self.data.attrs["sequence_length"] except KeyError: - return int(np.max(self.data["variant_position"])) + 1 + if self._contig_index is not None: + try: + return self.data.contig_length[self._contig_index] + except AttributeError: + pass + return int(np.max(self.data["variant_position"])) + 1 + + @property + def contig_id(self): + """ + The contig ID (name) for all used sites, or None if no + contig IDs were provided + """ + return self._contig_id @property def num_sites(self):