
Commit e529ac0

Add sample mask
1 parent 6b9099a commit e529ac0

File tree

tests/test_sgkit.py
tsinfer/formats.py
tsinfer/inference.py

3 files changed: +90 -24 lines changed

tests/test_sgkit.py

Lines changed: 42 additions & 1 deletion
@@ -481,9 +481,50 @@ def test_bad_mask_length_at_iterator(self, tmp_path):
         with pytest.raises(
             ValueError, match="Mask must be the same length as the array"
         ):
-            for _ in chunk_iterator(ds.variant_position, mask=sites_mask):
+            for _ in chunk_iterator(ds.call_genotype, mask=sites_mask):
                 pass
 
+    @pytest.mark.parametrize("sample_list", [[1, 2, 3, 5, 9, 27], [0], []])
+    def test_sgkit_sample_mask(self, tmp_path, sample_list):
+        ts, zarr_path = make_ts_and_zarr(tmp_path, add_optional=True)
+        ds = sgkit.load_dataset(zarr_path)
+        samples_mask = np.zeros_like(ds["sample_id"], dtype=bool)
+        for i in sample_list:
+            samples_mask[i] = True
+        add_array_to_dataset("samples_mask", samples_mask, zarr_path)
+        samples = tsinfer.SgkitSampleData(zarr_path)
+        assert samples.ploidy == 3
+        assert samples.num_individuals == len(sample_list)
+        assert samples.num_samples == len(sample_list) * samples.ploidy
+        assert np.array_equal(samples.individuals_mask, samples_mask)
+        assert np.array_equal(samples.samples_mask, np.repeat(samples_mask, 3))
+        assert np.array_equal(
+            samples.individuals_time, ds.individuals_time.values[samples_mask]
+        )
+        assert np.array_equal(
+            samples.individuals_location, ds.individuals_location.values[samples_mask]
+        )
+        assert np.array_equal(
+            samples.individuals_population,
+            ds.individuals_population.values[samples_mask],
+        )
+        assert np.array_equal(
+            samples.individuals_flags, ds.individuals_flags.values[samples_mask]
+        )
+        assert np.array_equal(
+            samples.samples_individual, np.repeat(np.arange(len(sample_list)), 3)
+        )
+        expected_gt = ds.call_genotype.values[:, samples_mask, :].reshape(
+            samples.num_sites, len(sample_list) * 3
+        )
+        assert np.array_equal(samples.sites_genotypes, expected_gt)
+        for v, gt in zip(samples.variants(), expected_gt):
+            assert np.array_equal(v.genotypes, gt)
+
+        for i, (id, haplo) in enumerate(samples.haplotypes()):
+            assert id == i
+            assert np.array_equal(haplo, expected_gt[:, i])
+
 
 @pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on Windows")
 def test_sgkit_ancestral_allele_same_ancestors(tmp_path):
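
For illustration only (not part of this commit): the new test drives the sample-mask workflow through the test helpers make_ts_and_zarr and add_array_to_dataset. Outside the test suite, a rough sketch of the same workflow might look like the following, assuming an existing sgkit dataset at the hypothetical path "example.sgkit.zarr" and writing the mask into the zarr store directly.

    import numpy as np
    import sgkit
    import tsinfer
    import zarr

    zarr_path = "example.sgkit.zarr"  # hypothetical, pre-existing sgkit dataset
    ds = sgkit.load_dataset(zarr_path)

    # Keep only a subset of sgkit samples (tsinfer individuals).
    keep = [1, 2, 3]
    samples_mask = np.zeros(ds["sample_id"].shape, dtype=bool)
    samples_mask[keep] = True

    # Store the mask at the root of the zarr store so SgkitSampleData can find it.
    root = zarr.open(zarr_path, mode="a")
    root.array("samples_mask", samples_mask)

    samples = tsinfer.SgkitSampleData(zarr_path)
    assert samples.num_individuals == len(keep)
    assert samples.num_samples == len(keep) * samples.ploidy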

tsinfer/formats.py

Lines changed: 47 additions & 23 deletions
@@ -305,7 +305,7 @@ def zarr_summary(array):
     return ret
 
 
-def chunk_iterator(array, indexes=None, mask=None, dimension=0):
+def chunk_iterator(array, indexes=None, mask=None, orthogonal_mask=None, dimension=0):
     """
     Utility to iterate over closely spaced rows in the specified array efficiently
     by accessing one chunk at a time (normally used as an iterator over each row)
@@ -314,6 +314,8 @@ def chunk_iterator(array, indexes=None, mask=None, dimension=0):
     assert dimension < 2
     if mask is None:
         mask = np.ones(array.shape[dimension], dtype=bool)
+    if orthogonal_mask is None:
+        orthogonal_mask = np.ones(array.shape[int(not dimension)], dtype=bool)
     if len(mask) != array.shape[dimension]:
         raise ValueError("Mask must be the same length as the array")
 
@@ -339,14 +341,14 @@ def chunk_iterator(array, indexes=None, mask=None, dimension=0):
             if chunk_id != prev_chunk_id:
                 chunk = array[chunk_id * chunk_size : (chunk_id + 1) * chunk_size][:]
                 prev_chunk_id = chunk_id
-            yield chunk[j % chunk_size]
+            yield chunk[j % chunk_size, orthogonal_mask]
     elif dimension == 1:
         for j in indexes:
             chunk_id = j // chunk_size
             if chunk_id != prev_chunk_id:
                 chunk = array[:, chunk_id * chunk_size : (chunk_id + 1) * chunk_size][:]
                 prev_chunk_id = chunk_id
-            yield chunk[:, j % chunk_size]
+            yield chunk[orthogonal_mask, j % chunk_size]
 
 
 def merge_variants(sd1, sd2):
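
For illustration only (not part of this commit): a minimal sketch of how the new orthogonal_mask parameter combines with mask. It assumes chunk_iterator (an internal utility) is imported from tsinfer.formats and that, as in the existing code, the rows iterated over default to those selected by mask; mask filters the iterated dimension, while orthogonal_mask filters the other dimension of each yielded slice.

    import numpy as np
    import zarr

    from tsinfer.formats import chunk_iterator

    # 5 rows (e.g. sites) x 4 columns (e.g. samples), stored in 2 x 2 chunks.
    data = zarr.array(np.arange(20).reshape(5, 4), chunks=(2, 2))
    row_mask = np.array([True, False, True, False, True])  # keep rows 0, 2, 4
    col_mask = np.array([True, True, False, False])        # keep columns 0, 1

    for row in chunk_iterator(data, mask=row_mask, orthogonal_mask=col_mask):
        # Each yielded row is already restricted to the unmasked columns,
        # e.g. [0 1], then [8 9], then [16 17] for this array.
        print(row)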
@@ -2297,9 +2299,9 @@ def __init__(self, path):
         self.path = path
         self.data = zarr.open(path, mode="r")
         genotypes_arr = self.data["call_genotype"]
-        _, self._num_individuals, self.ploidy = genotypes_arr.shape
+        _, self._num_unmasked_individuals, self.ploidy = genotypes_arr.shape
         self._num_sites = np.sum(self.sites_mask)
-        self._num_samples = self._num_individuals * self.ploidy
+        self._num_unmasked_samples = self._num_unmasked_individuals * self.ploidy
 
         assert self.ploidy == self.data["call_genotype"].chunks[2]
         if self.ploidy > 1:
@@ -2333,6 +2335,19 @@ def sequence_length(self):
     def num_sites(self):
         return self._num_sites
 
+    @functools.cached_property
+    def individuals_mask(self):
+        try:
+            return self.data["samples_mask"][:].astype(bool)
+        except KeyError:
+            return np.full(self._num_unmasked_individuals, True, dtype=bool)
+
+    @functools.cached_property
+    def samples_mask(self):
+        # Samples in sgkit are individuals in tskit, so we need to expand
+        # the mask to cover all the samples for each individual.
+        return np.repeat(self.individuals_mask, self.ploidy)
+
     @functools.cached_property
     def sites_metadata_schema(self):
         try:
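
For illustration only (not part of this commit): what the new samples_mask property computes. np.repeat expands a per-individual mask to one entry per sample, and summing either mask gives the masked counts returned by num_individuals and num_samples below.

    import numpy as np

    individuals_mask = np.array([True, False, True])  # 3 individuals, keep 2
    ploidy = 3

    samples_mask = np.repeat(individuals_mask, ploidy)
    print(samples_mask)
    # [ True  True  True False False False  True  True  True]

    print(np.sum(individuals_mask), np.sum(samples_mask))  # 2 6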
@@ -2427,9 +2442,9 @@ def sites_genotypes(self):
         gt = self.data["call_genotype"]
         # This method is only used for test/debug so we retrieve and
         # reshape the entire array.
-        return gt[...][self.sites_mask, :, :].reshape(
-            gt.shape[0], gt.shape[1] * gt.shape[2]
-        )
+        ret = gt[...][self.sites_mask, :, :]
+        ret = ret[:, self.individuals_mask, :]
+        return ret.reshape(ret.shape[0], ret.shape[1] * ret.shape[2])
 
     @functools.cached_property
     def provenances_timestamp(self):
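
For illustration only (not part of this commit): a small numpy example of the reshaping that sites_genotypes now performs, i.e. mask sites on axis 0, mask individuals on axis 1, then flatten the individual and ploidy axes into a single samples axis.

    import numpy as np

    # (sites, individuals, ploidy) genotypes: 4 sites, 3 individuals, ploidy 2.
    gt = np.arange(24).reshape(4, 3, 2)
    sites_mask = np.array([True, True, False, True])
    individuals_mask = np.array([True, False, True])

    ret = gt[sites_mask, :, :]
    ret = ret[:, individuals_mask, :]
    flat = ret.reshape(ret.shape[0], ret.shape[1] * ret.shape[2])
    print(flat.shape)  # (3, 4): 3 unmasked sites x (2 individuals * ploidy 2)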
@@ -2445,9 +2460,9 @@ def provenances_record(self):
         except KeyError:
             return np.array([], dtype=object)
 
-    @property
+    @functools.cached_property
     def num_samples(self):
-        return self._num_samples
+        return np.sum(self.samples_mask)
 
     @functools.cached_property
     def samples_individual(self):
@@ -2500,12 +2515,12 @@ def populations_metadata_schema(self):
 
     @property
     def num_individuals(self):
-        return self._num_individuals
+        return np.sum(self.individuals_mask)
 
     @functools.cached_property
     def individuals_time(self):
         try:
-            return self.data["individuals_time"]
+            return self.data["individuals_time"][:][self.individuals_mask]
         except KeyError:
             return np.full(self.num_individuals, tskit.UNKNOWN_TIME)
 
@@ -2524,11 +2539,14 @@ def individuals_metadata(self):
         # We set the sample_id in the individual metadata as this is often useful,
         # however we silently don't overwrite if the key exists
         if "individuals_metadata" in self.data:
-            assert len(self.data["individuals_metadata"]) == self.num_individuals
-            assert self.num_individuals == len(self.data["sample_id"])
+            assert (
+                len(self.data["individuals_metadata"]) == self._num_unmasked_individuals
+            )
+            assert self._num_unmasked_individuals == len(self.data["sample_id"])
             md_list = []
             for sample_id, r in zip(
-                self.data["sample_id"], self.data["individuals_metadata"][:]
+                self.data["sample_id"][:][self.individuals_mask],
+                self.data["individuals_metadata"][:][self.individuals_mask],
             ):
                 md = schema.decode_row(r)
                 if "sgkit_sample_id" not in md:
@@ -2537,27 +2555,28 @@ def individuals_metadata(self):
             return md_list
         else:
             return [
-                {"sgkit_sample_id": sample_id} for sample_id in self.data["sample_id"]
+                {"sgkit_sample_id": sample_id}
+                for sample_id in self.data["sample_id"][:][self.individuals_mask]
             ]
 
     @functools.cached_property
     def individuals_location(self):
         try:
-            return self.data["individuals_location"]
+            return self.data["individuals_location"][:][self.individuals_mask]
         except KeyError:
             return np.array([[]] * self.num_individuals, dtype=float)
 
     @functools.cached_property
     def individuals_population(self):
         try:
-            return self.data["individuals_population"]
+            return self.data["individuals_population"][:][self.individuals_mask]
         except KeyError:
             return np.full((self.num_individuals), tskit.NULL, dtype=np.int32)
 
     @functools.cached_property
     def individuals_flags(self):
         try:
-            return self.data["individuals_flags"]
+            return self.data["individuals_flags"][:][self.individuals_mask]
         except KeyError:
             return np.full((self.num_individuals), 0, dtype=np.int32)
 
@@ -2585,7 +2604,10 @@ def variants(self, sites=None, recode_ancestral=None):
         if recode_ancestral is None:
             recode_ancestral = False
         all_genotypes = chunk_iterator(
-            self.data["call_genotype"], indexes=sites, mask=self.sites_mask
+            self.data["call_genotype"],
+            indexes=sites,
+            mask=self.sites_mask,
+            orthogonal_mask=self.individuals_mask,
         )
         assert MISSING_DATA < 0  # required for geno_map to remap MISSING_DATA
         for genos, site in zip(all_genotypes, self.sites(ids=sites)):
@@ -2627,9 +2649,11 @@ def _all_haplotypes(self, sites=None, recode_ancestral=None):
         aa_index[aa_index == MISSING_DATA] = 0
         gt = self.data["call_genotype"]
         chunk_size = gt.chunks[1]
-        for j in range(self.num_individuals):
-            if j % chunk_size == 0:
-                chunk = gt[:, j : j + chunk_size, :]
+        current_chunk = None
+        for j in np.where(self.individuals_mask)[0]:
+            if j // chunk_size != current_chunk:
+                current_chunk = j // chunk_size
+                chunk = gt[:, j // chunk_size : (j // chunk_size) + chunk_size, :]
                 # Zarr doesn't support fancy indexing, so we have to do this after
                 chunk = chunk[self.sites_mask]
             indiv_gt = chunk[:, j % chunk_size, :]
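
For illustration only (not part of this commit): the loop above re-reads a genotype chunk only when the chunk index of the current individual changes. A generic numpy-only sketch of that chunk-caching pattern follows, with the slicing written in terms of the chunk index for clarity (this sketch's own arithmetic, not copied from the diff).

    import numpy as np

    gt = np.arange(4 * 10 * 2).reshape(4, 10, 2)  # (sites, individuals, ploidy)
    individuals_mask = np.zeros(10, dtype=bool)
    individuals_mask[[1, 2, 7]] = True
    chunk_size = 4

    current_chunk = None
    chunk = None
    for j in np.where(individuals_mask)[0]:
        chunk_id = j // chunk_size
        if chunk_id != current_chunk:
            # Load the whole block of individuals containing column j.
            current_chunk = chunk_id
            chunk = gt[:, chunk_id * chunk_size : (chunk_id + 1) * chunk_size, :]
        indiv_gt = chunk[:, j % chunk_size, :]
        assert np.array_equal(indiv_gt, gt[:, j, :])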

tsinfer/inference.py

Lines changed: 1 addition & 0 deletions
@@ -1644,6 +1644,7 @@ def group_by_linesweep(self):
         epoch_end = np.hstack([breaks + 1, [self.num_ancestors]])
         time_slices = np.vstack([epoch_start, epoch_end]).T
         epoch_sizes = time_slices[:, 1] - time_slices[:, 0]
+
         median_size = np.median(epoch_sizes)
         cutoff = 500 * median_size
         # Zero out the first half so that an initial large epoch doesn't
