@@ -305,7 +305,7 @@ def zarr_summary(array):
305
305
return ret
306
306
307
307
308
- def chunk_iterator (array , indexes = None , mask = None , dimension = 0 ):
308
+ def chunk_iterator (array , indexes = None , mask = None , orthogonal_mask = None , dimension = 0 ):
309
309
"""
310
310
Utility to iterate over closely spaced rows in the specified array efficiently
311
311
by accessing one chunk at a time (normally used as an iterator over each row)
@@ -314,6 +314,8 @@ def chunk_iterator(array, indexes=None, mask=None, dimension=0):
314
314
assert dimension < 2
315
315
if mask is None :
316
316
mask = np .ones (array .shape [dimension ], dtype = bool )
317
+ if orthogonal_mask is None :
318
+ orthogonal_mask = np .ones (array .shape [int (not dimension )], dtype = bool )
317
319
if len (mask ) != array .shape [dimension ]:
318
320
raise ValueError ("Mask must be the same length as the array" )
319
321
@@ -339,14 +341,14 @@ def chunk_iterator(array, indexes=None, mask=None, dimension=0):
339
341
if chunk_id != prev_chunk_id :
340
342
chunk = array [chunk_id * chunk_size : (chunk_id + 1 ) * chunk_size ][:]
341
343
prev_chunk_id = chunk_id
342
- yield chunk [j % chunk_size ]
344
+ yield chunk [j % chunk_size , orthogonal_mask ]
343
345
elif dimension == 1 :
344
346
for j in indexes :
345
347
chunk_id = j // chunk_size
346
348
if chunk_id != prev_chunk_id :
347
349
chunk = array [:, chunk_id * chunk_size : (chunk_id + 1 ) * chunk_size ][:]
348
350
prev_chunk_id = chunk_id
349
- yield chunk [: , j % chunk_size ]
351
+ yield chunk [orthogonal_mask , j % chunk_size ]
350
352
351
353
352
354
def merge_variants (sd1 , sd2 ):
@@ -2297,9 +2299,9 @@ def __init__(self, path):
2297
2299
self .path = path
2298
2300
self .data = zarr .open (path , mode = "r" )
2299
2301
genotypes_arr = self .data ["call_genotype" ]
2300
- _ , self ._num_individuals , self .ploidy = genotypes_arr .shape
2302
+ _ , self ._num_unmasked_individuals , self .ploidy = genotypes_arr .shape
2301
2303
self ._num_sites = np .sum (self .sites_mask )
2302
- self ._num_samples = self ._num_individuals * self .ploidy
2304
+ self ._num_unmasked_samples = self ._num_unmasked_individuals * self .ploidy
2303
2305
2304
2306
assert self .ploidy == self .data ["call_genotype" ].chunks [2 ]
2305
2307
if self .ploidy > 1 :
@@ -2333,6 +2335,19 @@ def sequence_length(self):
2333
2335
def num_sites (self ):
2334
2336
return self ._num_sites
2335
2337
2338
+ @functools .cached_property
2339
+ def individuals_mask (self ):
2340
+ try :
2341
+ return self .data ["samples_mask" ][:].astype (bool )
2342
+ except KeyError :
2343
+ return np .full (self ._num_unmasked_individuals , True , dtype = bool )
2344
+
2345
+ @functools .cached_property
2346
+ def samples_mask (self ):
2347
+ # Samples in sgkit are individuals in tskit, so we need to expand
2348
+ # the mask to cover all the samples for each individual.
2349
+ return np .repeat (self .individuals_mask , self .ploidy )
2350
+
2336
2351
@functools .cached_property
2337
2352
def sites_metadata_schema (self ):
2338
2353
try :
@@ -2427,9 +2442,9 @@ def sites_genotypes(self):
2427
2442
gt = self .data ["call_genotype" ]
2428
2443
# This method is only used for test/debug so we retrieve and
2429
2444
# reshape the entire array.
2430
- return gt [...][self .sites_mask , :, :]. reshape (
2431
- gt . shape [ 0 ], gt . shape [ 1 ] * gt . shape [ 2 ]
2432
- )
2445
+ ret = gt [...][self .sites_mask , :, :]
2446
+ ret = ret [:, self . individuals_mask , : ]
2447
+ return ret . reshape ( ret . shape [ 0 ], ret . shape [ 1 ] * ret . shape [ 2 ] )
2433
2448
2434
2449
@functools .cached_property
2435
2450
def provenances_timestamp (self ):
@@ -2445,9 +2460,9 @@ def provenances_record(self):
2445
2460
except KeyError :
2446
2461
return np .array ([], dtype = object )
2447
2462
2448
- @property
2463
+ @functools . cached_property
2449
2464
def num_samples (self ):
2450
- return self ._num_samples
2465
+ return np . sum ( self .samples_mask )
2451
2466
2452
2467
@functools .cached_property
2453
2468
def samples_individual (self ):
@@ -2500,12 +2515,12 @@ def populations_metadata_schema(self):
2500
2515
2501
2516
@property
2502
2517
def num_individuals (self ):
2503
- return self ._num_individuals
2518
+ return np . sum ( self .individuals_mask )
2504
2519
2505
2520
@functools .cached_property
2506
2521
def individuals_time (self ):
2507
2522
try :
2508
- return self .data ["individuals_time" ]
2523
+ return self .data ["individuals_time" ][:][ self . individuals_mask ]
2509
2524
except KeyError :
2510
2525
return np .full (self .num_individuals , tskit .UNKNOWN_TIME )
2511
2526
@@ -2524,11 +2539,14 @@ def individuals_metadata(self):
2524
2539
# We set the sample_id in the individual metadata as this is often useful,
2525
2540
# however we silently don't overwrite if the key exists
2526
2541
if "individuals_metadata" in self .data :
2527
- assert len (self .data ["individuals_metadata" ]) == self .num_individuals
2528
- assert self .num_individuals == len (self .data ["sample_id" ])
2542
+ assert (
2543
+ len (self .data ["individuals_metadata" ]) == self ._num_unmasked_individuals
2544
+ )
2545
+ assert self ._num_unmasked_individuals == len (self .data ["sample_id" ])
2529
2546
md_list = []
2530
2547
for sample_id , r in zip (
2531
- self .data ["sample_id" ], self .data ["individuals_metadata" ][:]
2548
+ self .data ["sample_id" ][:][self .individuals_mask ],
2549
+ self .data ["individuals_metadata" ][:][self .individuals_mask ],
2532
2550
):
2533
2551
md = schema .decode_row (r )
2534
2552
if "sgkit_sample_id" not in md :
@@ -2537,27 +2555,28 @@ def individuals_metadata(self):
2537
2555
return md_list
2538
2556
else :
2539
2557
return [
2540
- {"sgkit_sample_id" : sample_id } for sample_id in self .data ["sample_id" ]
2558
+ {"sgkit_sample_id" : sample_id }
2559
+ for sample_id in self .data ["sample_id" ][:][self .individuals_mask ]
2541
2560
]
2542
2561
2543
2562
@functools .cached_property
2544
2563
def individuals_location (self ):
2545
2564
try :
2546
- return self .data ["individuals_location" ]
2565
+ return self .data ["individuals_location" ][:][ self . individuals_mask ]
2547
2566
except KeyError :
2548
2567
return np .array ([[]] * self .num_individuals , dtype = float )
2549
2568
2550
2569
@functools .cached_property
2551
2570
def individuals_population (self ):
2552
2571
try :
2553
- return self .data ["individuals_population" ]
2572
+ return self .data ["individuals_population" ][:][ self . individuals_mask ]
2554
2573
except KeyError :
2555
2574
return np .full ((self .num_individuals ), tskit .NULL , dtype = np .int32 )
2556
2575
2557
2576
@functools .cached_property
2558
2577
def individuals_flags (self ):
2559
2578
try :
2560
- return self .data ["individuals_flags" ]
2579
+ return self .data ["individuals_flags" ][:][ self . individuals_mask ]
2561
2580
except KeyError :
2562
2581
return np .full ((self .num_individuals ), 0 , dtype = np .int32 )
2563
2582
@@ -2585,7 +2604,10 @@ def variants(self, sites=None, recode_ancestral=None):
2585
2604
if recode_ancestral is None :
2586
2605
recode_ancestral = False
2587
2606
all_genotypes = chunk_iterator (
2588
- self .data ["call_genotype" ], indexes = sites , mask = self .sites_mask
2607
+ self .data ["call_genotype" ],
2608
+ indexes = sites ,
2609
+ mask = self .sites_mask ,
2610
+ orthogonal_mask = self .individuals_mask ,
2589
2611
)
2590
2612
assert MISSING_DATA < 0 # required for geno_map to remap MISSING_DATA
2591
2613
for genos , site in zip (all_genotypes , self .sites (ids = sites )):
@@ -2627,9 +2649,11 @@ def _all_haplotypes(self, sites=None, recode_ancestral=None):
2627
2649
aa_index [aa_index == MISSING_DATA ] = 0
2628
2650
gt = self .data ["call_genotype" ]
2629
2651
chunk_size = gt .chunks [1 ]
2630
- for j in range (self .num_individuals ):
2631
- if j % chunk_size == 0 :
2632
- chunk = gt [:, j : j + chunk_size , :]
2652
+ current_chunk = None
2653
+ for j in np .where (self .individuals_mask )[0 ]:
2654
+ if j // chunk_size != current_chunk :
2655
+ current_chunk = j // chunk_size
2656
+ chunk = gt [:, current_chunk * chunk_size : (current_chunk + 1 ) * chunk_size , :]
2633
2657
# Zarr doesn't support fancy indexing, so we have to do this after
2634
2658
chunk = chunk [self .sites_mask ]
2635
2659
indiv_gt = chunk [:, j % chunk_size , :]
0 commit comments