Skip to content

Commit 15267b7

Browse files
committed
Add mask names
1 parent 444e4fa commit 15267b7

File tree

2 files changed

+56
-21
lines changed

2 files changed

+56
-21
lines changed

tests/test_sgkit.py

+27-6
Original file line numberDiff line numberDiff line change
@@ -438,8 +438,8 @@ def test_sgkit_variant_mask(self, tmp_path, sites):
438438
sites_mask = np.ones_like(ds["variant_position"], dtype=bool)
439439
for i in sites:
440440
sites_mask[i] = False
441-
add_array_to_dataset("variant_mask", sites_mask, zarr_path)
442-
samples = tsinfer.SgkitSampleData(zarr_path)
441+
add_array_to_dataset("variant_mask_42", sites_mask, zarr_path)
442+
samples = tsinfer.SgkitSampleData(zarr_path, sites_mask_name="variant_mask_42")
443443
assert samples.num_sites == len(sites)
444444
assert np.array_equal(samples.sites_mask, ~sites_mask)
445445
assert np.array_equal(
@@ -465,12 +465,13 @@ def test_sgkit_variant_bad_mask_length(self, tmp_path):
465465
ts, zarr_path = make_ts_and_zarr(tmp_path)
466466
ds = sgkit.load_dataset(zarr_path)
467467
sites_mask = np.zeros(ds.sizes["variants"] + 1, dtype=int)
468-
add_array_to_dataset("variant_mask", sites_mask, zarr_path)
468+
add_array_to_dataset("variant_mask_foobar", sites_mask, zarr_path)
469+
tsinfer.SgkitSampleData(zarr_path)
469470
with pytest.raises(
470471
ValueError,
471472
match="Mask must be the same length as the number of unmasked sites",
472473
):
473-
tsinfer.SgkitSampleData(zarr_path)
474+
tsinfer.SgkitSampleData(zarr_path, sites_mask_name="variant_mask_foobar")
474475

475476
def test_bad_mask_length_at_iterator(self, tmp_path):
476477
ts, zarr_path = make_ts_and_zarr(tmp_path)
@@ -491,8 +492,10 @@ def test_sgkit_sample_mask(self, tmp_path, sample_list):
491492
samples_mask = np.ones_like(ds["sample_id"], dtype=bool)
492493
for i in sample_list:
493494
samples_mask[i] = False
494-
add_array_to_dataset("samples_mask", samples_mask, zarr_path)
495-
samples = tsinfer.SgkitSampleData(zarr_path)
495+
add_array_to_dataset("samples_mask_69", samples_mask, zarr_path)
496+
samples = tsinfer.SgkitSampleData(
497+
zarr_path, sgkit_samples_mask_name="samples_mask_69"
498+
)
496499
assert samples.ploidy == 3
497500
assert samples.num_individuals == len(sample_list)
498501
assert samples.num_samples == len(sample_list) * samples.ploidy
@@ -525,6 +528,24 @@ def test_sgkit_sample_mask(self, tmp_path, sample_list):
525528
assert id == i
526529
assert np.array_equal(haplo, expected_gt[:, i])
527530

531+
def test_sgkit_missing_masks(self, tmp_path):
532+
ts, zarr_path = make_ts_and_zarr(tmp_path)
533+
samples = tsinfer.SgkitSampleData(zarr_path)
534+
samples.individuals_mask
535+
samples.sites_mask
536+
with pytest.raises(
537+
ValueError, match="The sites mask foobar was not found in the dataset."
538+
):
539+
tsinfer.SgkitSampleData(zarr_path, sites_mask_name="foobar")
540+
with pytest.raises(
541+
ValueError,
542+
match="The sgkit samples mask foobar2 was not found in the dataset.",
543+
):
544+
samples = tsinfer.SgkitSampleData(
545+
zarr_path, sgkit_samples_mask_name="foobar2"
546+
)
547+
samples.individuals_mask
548+
528549

529550
@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on Windows")
530551
def test_sgkit_ancestral_allele_same_ancestors(tmp_path):

tsinfer/formats.py

+29-15
Original file line numberDiff line numberDiff line change
@@ -2295,9 +2295,11 @@ class SgkitSampleData(SampleData):
22952295
FORMAT_NAME = "tsinfer-sgkit-sample-data"
22962296
FORMAT_VERSION = (0, 1)
22972297

2298-
def __init__(self, path):
2298+
def __init__(self, path, sgkit_samples_mask_name=None, sites_mask_name=None):
22992299
self.path = path
23002300
self.data = zarr.open(path, mode="r")
2301+
self._sgkit_samples_mask_name = sgkit_samples_mask_name
2302+
self._sites_mask_name = sites_mask_name
23012303
genotypes_arr = self.data["call_genotype"]
23022304
_, self._num_unmasked_individuals, self.ploidy = genotypes_arr.shape
23032305
self._num_sites = np.sum(self.sites_mask)
@@ -2337,11 +2339,17 @@ def num_sites(self):
23372339

23382340
@functools.cached_property
23392341
def individuals_mask(self):
2340-
try:
2341-
# We negate the mask as it is much easier in numpy to have True=keep
2342-
return ~(self.data["samples_mask"][:].astype(bool))
2343-
except KeyError:
2342+
if self._sgkit_samples_mask_name is None:
23442343
return np.full(self._num_unmasked_individuals, True, dtype=bool)
2344+
else:
2345+
try:
2346+
# We negate the mask as it is much easier in numpy to have True=keep
2347+
return ~(self.data[self._sgkit_samples_mask_name][:].astype(bool))
2348+
except KeyError:
2349+
raise ValueError(
2350+
f"The sgkit samples mask {self._sgkit_samples_mask_name} was not"
2351+
f" found in the dataset."
2352+
)
23452353

23462354
@functools.cached_property
23472355
def samples_mask(self):
@@ -2386,18 +2394,24 @@ def sites_alleles(self):
23862394

23872395
@functools.cached_property
23882396
def sites_mask(self):
2389-
try:
2390-
if (
2391-
self.data["variant_mask"].shape[0]
2392-
!= self.data["variant_position"].shape[0]
2393-
):
2397+
if self._sites_mask_name is None:
2398+
return np.full(self.data["variant_position"].shape, True, dtype=bool)
2399+
else:
2400+
try:
2401+
if (
2402+
self.data[self._sites_mask_name].shape[0]
2403+
!= self.data["variant_position"].shape[0]
2404+
):
2405+
raise ValueError(
2406+
"Mask must be the same length as the number of unmasked sites"
2407+
)
2408+
# We negate the mask as it is much easier in numpy to have True=keep
2409+
return ~(self.data[self._sites_mask_name].astype(bool)[:])
2410+
except KeyError:
23942411
raise ValueError(
2395-
"Mask must be the same length as the number of unmasked sites"
2412+
f"The sites mask {self._sites_mask_name} was not found"
2413+
f" in the dataset."
23962414
)
2397-
# We negate the mask as it is much easier in numpy to have True=keep
2398-
return ~(self.data["variant_mask"].astype(bool)[:])
2399-
except KeyError:
2400-
return np.full(self.data["variant_position"].shape, True, dtype=bool)
24012415

24022416
@functools.cached_property
24032417
def sites_ancestral_allele(self):

0 commit comments

Comments
 (0)