@@ -438,8 +438,8 @@ def test_sgkit_variant_mask(self, tmp_path, sites):
438
438
sites_mask = np .ones_like (ds ["variant_position" ], dtype = bool )
439
439
for i in sites :
440
440
sites_mask [i ] = False
441
- add_array_to_dataset ("variant_mask " , sites_mask , zarr_path )
442
- samples = tsinfer .SgkitSampleData (zarr_path )
441
+ add_array_to_dataset ("variant_mask_42 " , sites_mask , zarr_path )
442
+ samples = tsinfer .SgkitSampleData (zarr_path , sites_mask_name = "variant_mask_42" )
443
443
assert samples .num_sites == len (sites )
444
444
assert np .array_equal (samples .sites_mask , ~ sites_mask )
445
445
assert np .array_equal (
@@ -465,12 +465,13 @@ def test_sgkit_variant_bad_mask_length(self, tmp_path):
465
465
ts , zarr_path = make_ts_and_zarr (tmp_path )
466
466
ds = sgkit .load_dataset (zarr_path )
467
467
sites_mask = np .zeros (ds .sizes ["variants" ] + 1 , dtype = int )
468
- add_array_to_dataset ("variant_mask" , sites_mask , zarr_path )
468
+ add_array_to_dataset ("variant_mask_foobar" , sites_mask , zarr_path )
469
+ tsinfer .SgkitSampleData (zarr_path )
469
470
with pytest .raises (
470
471
ValueError ,
471
472
match = "Mask must be the same length as the number of unmasked sites" ,
472
473
):
473
- tsinfer .SgkitSampleData (zarr_path )
474
+ tsinfer .SgkitSampleData (zarr_path , sites_mask_name = "variant_mask_foobar" )
474
475
475
476
def test_bad_mask_length_at_iterator (self , tmp_path ):
476
477
ts , zarr_path = make_ts_and_zarr (tmp_path )
@@ -491,8 +492,10 @@ def test_sgkit_sample_mask(self, tmp_path, sample_list):
491
492
samples_mask = np .ones_like (ds ["sample_id" ], dtype = bool )
492
493
for i in sample_list :
493
494
samples_mask [i ] = False
494
- add_array_to_dataset ("samples_mask" , samples_mask , zarr_path )
495
- samples = tsinfer .SgkitSampleData (zarr_path )
495
+ add_array_to_dataset ("samples_mask_69" , samples_mask , zarr_path )
496
+ samples = tsinfer .SgkitSampleData (
497
+ zarr_path , sgkit_samples_mask_name = "samples_mask_69"
498
+ )
496
499
assert samples .ploidy == 3
497
500
assert samples .num_individuals == len (sample_list )
498
501
assert samples .num_samples == len (sample_list ) * samples .ploidy
@@ -525,6 +528,24 @@ def test_sgkit_sample_mask(self, tmp_path, sample_list):
525
528
assert id == i
526
529
assert np .array_equal (haplo , expected_gt [:, i ])
527
530
531
def test_sgkit_missing_masks(self, tmp_path):
    """Check mask handling when masks are absent or misnamed.

    With no mask names given, both mask properties must be readable.
    Naming a mask array that does not exist in the dataset must raise a
    ``ValueError`` whose message names the missing array.
    """
    _, zarr_path = make_ts_and_zarr(tmp_path)

    # No mask names supplied: reading either mask must not raise.
    unmasked = tsinfer.SgkitSampleData(zarr_path)
    unmasked.individuals_mask
    unmasked.sites_mask

    missing_sites_msg = "The sites mask foobar was not found in the dataset."
    with pytest.raises(ValueError, match=missing_sites_msg):
        tsinfer.SgkitSampleData(zarr_path, sites_mask_name="foobar")

    missing_samples_msg = (
        "The sgkit samples mask foobar2 was not found in the dataset."
    )
    # Both the construction and the property read stay inside the context
    # so the error is caught wherever it is raised.
    with pytest.raises(ValueError, match=missing_samples_msg):
        masked = tsinfer.SgkitSampleData(
            zarr_path, sgkit_samples_mask_name="foobar2"
        )
        masked.individuals_mask
548
+
528
549
529
550
@pytest .mark .skipif (sys .platform == "win32" , reason = "No cyvcf2 on Windows" )
530
551
def test_sgkit_ancestral_allele_same_ancestors (tmp_path ):
0 commit comments