Skip to content

Commit 444e4fa

Browse files
author
Ben Jeffery
committed
Flip sgkit mask polarity
1 parent e529ac0 commit 444e4fa

File tree

2 files changed

+22
-21
lines changed

2 files changed

+22
-21
lines changed

tests/test_sgkit.py

Lines changed: 18 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -435,25 +435,25 @@ class TestSgkitMask:
435435
def test_sgkit_variant_mask(self, tmp_path, sites):
436436
ts, zarr_path = make_ts_and_zarr(tmp_path)
437437
ds = sgkit.load_dataset(zarr_path)
438-
sites_mask = np.zeros_like(ds["variant_position"], dtype=bool)
438+
sites_mask = np.ones_like(ds["variant_position"], dtype=bool)
439439
for i in sites:
440-
sites_mask[i] = True
440+
sites_mask[i] = False
441441
add_array_to_dataset("variant_mask", sites_mask, zarr_path)
442442
samples = tsinfer.SgkitSampleData(zarr_path)
443443
assert samples.num_sites == len(sites)
444-
assert np.array_equal(samples.sites_mask, sites_mask)
444+
assert np.array_equal(samples.sites_mask, ~sites_mask)
445445
assert np.array_equal(
446-
samples.sites_position, ts.tables.sites.position[sites_mask]
446+
samples.sites_position, ts.tables.sites.position[~sites_mask]
447447
)
448448
inf_ts = tsinfer.infer(samples)
449449
assert np.array_equal(
450-
ts.genotype_matrix()[sites_mask], inf_ts.genotype_matrix()
450+
ts.genotype_matrix()[~sites_mask], inf_ts.genotype_matrix()
451451
)
452452
assert np.array_equal(
453-
ts.tables.sites.position[sites_mask], inf_ts.tables.sites.position
453+
ts.tables.sites.position[~sites_mask], inf_ts.tables.sites.position
454454
)
455455
assert np.array_equal(
456-
ts.tables.sites.ancestral_state[sites_mask],
456+
ts.tables.sites.ancestral_state[~sites_mask],
457457
inf_ts.tables.sites.ancestral_state,
458458
)
459459
# TODO - site metadata needs merging not replacing
@@ -464,7 +464,7 @@ def test_sgkit_variant_mask(self, tmp_path, sites):
464464
def test_sgkit_variant_bad_mask_length(self, tmp_path):
465465
ts, zarr_path = make_ts_and_zarr(tmp_path)
466466
ds = sgkit.load_dataset(zarr_path)
467-
sites_mask = np.ones(ds.sizes["variants"] + 1, dtype=int)
467+
sites_mask = np.zeros(ds.sizes["variants"] + 1, dtype=int)
468468
add_array_to_dataset("variant_mask", sites_mask, zarr_path)
469469
with pytest.raises(
470470
ValueError,
@@ -475,7 +475,7 @@ def test_sgkit_variant_bad_mask_length(self, tmp_path):
475475
def test_bad_mask_length_at_iterator(self, tmp_path):
476476
ts, zarr_path = make_ts_and_zarr(tmp_path)
477477
ds = sgkit.load_dataset(zarr_path)
478-
sites_mask = np.ones(ds.sizes["variants"] + 1, dtype=int)
478+
sites_mask = np.zeros(ds.sizes["variants"] + 1, dtype=int)
479479
from tsinfer.formats import chunk_iterator
480480

481481
with pytest.raises(
@@ -488,33 +488,33 @@ def test_bad_mask_length_at_iterator(self, tmp_path):
488488
def test_sgkit_sample_mask(self, tmp_path, sample_list):
489489
ts, zarr_path = make_ts_and_zarr(tmp_path, add_optional=True)
490490
ds = sgkit.load_dataset(zarr_path)
491-
samples_mask = np.zeros_like(ds["sample_id"], dtype=bool)
491+
samples_mask = np.ones_like(ds["sample_id"], dtype=bool)
492492
for i in sample_list:
493-
samples_mask[i] = True
493+
samples_mask[i] = False
494494
add_array_to_dataset("samples_mask", samples_mask, zarr_path)
495495
samples = tsinfer.SgkitSampleData(zarr_path)
496496
assert samples.ploidy == 3
497497
assert samples.num_individuals == len(sample_list)
498498
assert samples.num_samples == len(sample_list) * samples.ploidy
499-
assert np.array_equal(samples.individuals_mask, samples_mask)
500-
assert np.array_equal(samples.samples_mask, np.repeat(samples_mask, 3))
499+
assert np.array_equal(samples.individuals_mask, ~samples_mask)
500+
assert np.array_equal(samples.samples_mask, np.repeat(~samples_mask, 3))
501501
assert np.array_equal(
502-
samples.individuals_time, ds.individuals_time.values[samples_mask]
502+
samples.individuals_time, ds.individuals_time.values[~samples_mask]
503503
)
504504
assert np.array_equal(
505-
samples.individuals_location, ds.individuals_location.values[samples_mask]
505+
samples.individuals_location, ds.individuals_location.values[~samples_mask]
506506
)
507507
assert np.array_equal(
508508
samples.individuals_population,
509-
ds.individuals_population.values[samples_mask],
509+
ds.individuals_population.values[~samples_mask],
510510
)
511511
assert np.array_equal(
512-
samples.individuals_flags, ds.individuals_flags.values[samples_mask]
512+
samples.individuals_flags, ds.individuals_flags.values[~samples_mask]
513513
)
514514
assert np.array_equal(
515515
samples.samples_individual, np.repeat(np.arange(len(sample_list)), 3)
516516
)
517-
expected_gt = ds.call_genotype.values[:, samples_mask, :].reshape(
517+
expected_gt = ds.call_genotype.values[:, ~samples_mask, :].reshape(
518518
samples.num_sites, len(sample_list) * 3
519519
)
520520
assert np.array_equal(samples.sites_genotypes, expected_gt)

tsinfer/formats.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2338,7 +2338,8 @@ def num_sites(self):
23382338
@functools.cached_property
23392339
def individuals_mask(self):
23402340
try:
2341-
return self.data["samples_mask"][:].astype(bool)
2341+
# We negate the mask as it is much easier in numpy to have True=keep
2342+
return ~(self.data["samples_mask"][:].astype(bool))
23422343
except KeyError:
23432344
return np.full(self._num_unmasked_individuals, True, dtype=bool)
23442345

@@ -2393,8 +2394,8 @@ def sites_mask(self):
23932394
raise ValueError(
23942395
"Mask must be the same length as the number of unmasked sites"
23952396
)
2396-
2397-
return self.data["variant_mask"].astype(bool)
2397+
# We negate the mask as it is much easier in numpy to have True=keep
2398+
return ~(self.data["variant_mask"].astype(bool)[:])
23982399
except KeyError:
23992400
return np.full(self.data["variant_position"].shape, True, dtype=bool)
24002401

0 commit comments

Comments
 (0)