@@ -435,25 +435,25 @@ class TestSgkitMask:
435
435
def test_sgkit_variant_mask (self , tmp_path , sites ):
436
436
ts , zarr_path = make_ts_and_zarr (tmp_path )
437
437
ds = sgkit .load_dataset (zarr_path )
438
- sites_mask = np .zeros_like (ds ["variant_position" ], dtype = bool )
438
+ sites_mask = np .ones_like (ds ["variant_position" ], dtype = bool )
439
439
for i in sites :
440
- sites_mask [i ] = True
440
+ sites_mask [i ] = False
441
441
add_array_to_dataset ("variant_mask" , sites_mask , zarr_path )
442
442
samples = tsinfer .SgkitSampleData (zarr_path )
443
443
assert samples .num_sites == len (sites )
444
- assert np .array_equal (samples .sites_mask , sites_mask )
444
+ assert np .array_equal (samples .sites_mask , ~ sites_mask )
445
445
assert np .array_equal (
446
- samples .sites_position , ts .tables .sites .position [sites_mask ]
446
+ samples .sites_position , ts .tables .sites .position [~ sites_mask ]
447
447
)
448
448
inf_ts = tsinfer .infer (samples )
449
449
assert np .array_equal (
450
- ts .genotype_matrix ()[sites_mask ], inf_ts .genotype_matrix ()
450
+ ts .genotype_matrix ()[~ sites_mask ], inf_ts .genotype_matrix ()
451
451
)
452
452
assert np .array_equal (
453
- ts .tables .sites .position [sites_mask ], inf_ts .tables .sites .position
453
+ ts .tables .sites .position [~ sites_mask ], inf_ts .tables .sites .position
454
454
)
455
455
assert np .array_equal (
456
- ts .tables .sites .ancestral_state [sites_mask ],
456
+ ts .tables .sites .ancestral_state [~ sites_mask ],
457
457
inf_ts .tables .sites .ancestral_state ,
458
458
)
459
459
# TODO - site metadata needs merging not replacing
@@ -464,7 +464,7 @@ def test_sgkit_variant_mask(self, tmp_path, sites):
464
464
def test_sgkit_variant_bad_mask_length (self , tmp_path ):
465
465
ts , zarr_path = make_ts_and_zarr (tmp_path )
466
466
ds = sgkit .load_dataset (zarr_path )
467
- sites_mask = np .ones (ds .sizes ["variants" ] + 1 , dtype = int )
467
+ sites_mask = np .zeros (ds .sizes ["variants" ] + 1 , dtype = int )
468
468
add_array_to_dataset ("variant_mask" , sites_mask , zarr_path )
469
469
with pytest .raises (
470
470
ValueError ,
@@ -475,7 +475,7 @@ def test_sgkit_variant_bad_mask_length(self, tmp_path):
475
475
def test_bad_mask_length_at_iterator (self , tmp_path ):
476
476
ts , zarr_path = make_ts_and_zarr (tmp_path )
477
477
ds = sgkit .load_dataset (zarr_path )
478
- sites_mask = np .ones (ds .sizes ["variants" ] + 1 , dtype = int )
478
+ sites_mask = np .zeros (ds .sizes ["variants" ] + 1 , dtype = int )
479
479
from tsinfer .formats import chunk_iterator
480
480
481
481
with pytest .raises (
@@ -488,33 +488,33 @@ def test_bad_mask_length_at_iterator(self, tmp_path):
488
488
def test_sgkit_sample_mask (self , tmp_path , sample_list ):
489
489
ts , zarr_path = make_ts_and_zarr (tmp_path , add_optional = True )
490
490
ds = sgkit .load_dataset (zarr_path )
491
- samples_mask = np .zeros_like (ds ["sample_id" ], dtype = bool )
491
+ samples_mask = np .ones_like (ds ["sample_id" ], dtype = bool )
492
492
for i in sample_list :
493
- samples_mask [i ] = True
493
+ samples_mask [i ] = False
494
494
add_array_to_dataset ("samples_mask" , samples_mask , zarr_path )
495
495
samples = tsinfer .SgkitSampleData (zarr_path )
496
496
assert samples .ploidy == 3
497
497
assert samples .num_individuals == len (sample_list )
498
498
assert samples .num_samples == len (sample_list ) * samples .ploidy
499
- assert np .array_equal (samples .individuals_mask , samples_mask )
500
- assert np .array_equal (samples .samples_mask , np .repeat (samples_mask , 3 ))
499
+ assert np .array_equal (samples .individuals_mask , ~ samples_mask )
500
+ assert np .array_equal (samples .samples_mask , np .repeat (~ samples_mask , 3 ))
501
501
assert np .array_equal (
502
- samples .individuals_time , ds .individuals_time .values [samples_mask ]
502
+ samples .individuals_time , ds .individuals_time .values [~ samples_mask ]
503
503
)
504
504
assert np .array_equal (
505
- samples .individuals_location , ds .individuals_location .values [samples_mask ]
505
+ samples .individuals_location , ds .individuals_location .values [~ samples_mask ]
506
506
)
507
507
assert np .array_equal (
508
508
samples .individuals_population ,
509
- ds .individuals_population .values [samples_mask ],
509
+ ds .individuals_population .values [~ samples_mask ],
510
510
)
511
511
assert np .array_equal (
512
- samples .individuals_flags , ds .individuals_flags .values [samples_mask ]
512
+ samples .individuals_flags , ds .individuals_flags .values [~ samples_mask ]
513
513
)
514
514
assert np .array_equal (
515
515
samples .samples_individual , np .repeat (np .arange (len (sample_list )), 3 )
516
516
)
517
- expected_gt = ds .call_genotype .values [:, samples_mask , :].reshape (
517
+ expected_gt = ds .call_genotype .values [:, ~ samples_mask , :].reshape (
518
518
samples .num_sites , len (sample_list ) * 3
519
519
)
520
520
assert np .array_equal (samples .sites_genotypes , expected_gt )
0 commit comments