20
20
Tests for the data files.
21
21
"""
22
22
import json
23
+ import logging
23
24
import sys
24
25
import tempfile
26
+ import warnings
25
27
26
28
import msprime
27
29
import numcodecs
@@ -627,14 +629,12 @@ def test_missing_ancestral_allele(tmp_path):
627
629
628
630
629
631
@pytest .mark .skipif (sys .platform == "win32" , reason = "No cyvcf2 on Windows" )
630
- def test_ancestral_missingness (tmp_path ):
632
+ def test_deliberate_ancestral_missingness (tmp_path ):
631
633
ts , zarr_path = tsutil .make_ts_and_zarr (tmp_path )
632
634
ds = sgkit .load_dataset (zarr_path )
633
635
ancestral_allele = ds .variant_ancestral_allele .values
634
636
ancestral_allele [0 ] = "N"
635
- ancestral_allele [11 ] = "-"
636
- ancestral_allele [12 ] = "💩"
637
- ancestral_allele [15 ] = "💩"
637
+ ancestral_allele [1 ] = "n"
638
638
ds = ds .drop_vars (["variant_ancestral_allele" ])
639
639
sgkit .save_dataset (ds , str (zarr_path ) + ".tmp" )
640
640
tsutil .add_array_to_dataset (
@@ -644,15 +644,57 @@ def test_ancestral_missingness(tmp_path):
644
644
["variants" ],
645
645
)
646
646
ds = sgkit .load_dataset (str (zarr_path ) + ".tmp" )
647
+ with warnings .catch_warnings ():
648
+ warnings .simplefilter ("error" ) # No warning raised if AA deliberately missing
649
+ sd = tsinfer .VariantData (str (zarr_path ) + ".tmp" , "variant_ancestral_allele" )
650
+ inf_ts = tsinfer .infer (sd )
651
+ for i , (inf_var , var ) in enumerate (zip (inf_ts .variants (), ts .variants ())):
652
+ if i in [0 , 1 ]:
653
+ assert inf_var .site .metadata == {"inference_type" : "parsimony" }
654
+ else :
655
+ assert inf_var .site .ancestral_state == var .site .ancestral_state
656
+
657
+
658
+ @pytest .mark .skipif (sys .platform == "win32" , reason = "No cyvcf2 on Windows" )
659
+ def test_ancestral_missing_warning (tmp_path ):
660
+ ts , zarr_path = tsutil .make_ts_and_zarr (tmp_path )
661
+ ds = sgkit .load_dataset (zarr_path )
662
+ anc_state = ds .variant_ancestral_allele .values
663
+ anc_state [0 ] = "N"
664
+ anc_state [11 ] = "-"
665
+ anc_state [12 ] = "💩"
666
+ anc_state [15 ] = "💩"
647
667
with pytest .warns (
648
668
UserWarning ,
649
669
match = r"not found in the variant_allele array for the 4 [\s\S]*'💩': 2" ,
650
670
):
651
- sd = tsinfer .VariantData (str (zarr_path ) + ".tmp" , "variant_ancestral_allele" )
652
- inf_ts = tsinfer .infer (sd )
671
+ vdata = tsinfer .VariantData (zarr_path , anc_state )
672
+ inf_ts = tsinfer .infer (vdata )
673
+ for i , (inf_var , var ) in enumerate (zip (inf_ts .variants (), ts .variants ())):
674
+ if i in [0 , 11 , 12 , 15 ]:
675
+ assert inf_var .site .metadata == {"inference_type" : "parsimony" }
676
+ assert inf_var .site .ancestral_state in var .site .alleles
677
+ else :
678
+ assert inf_var .site .ancestral_state == var .site .ancestral_state
679
+
680
+
681
+ @pytest .mark .skipif (sys .platform == "win32" , reason = "No cyvcf2 on Windows" )
682
+ def test_ancestral_missing_info (tmp_path , caplog ):
683
+ ts , zarr_path = tsutil .make_ts_and_zarr (tmp_path )
684
+ ds = sgkit .load_dataset (zarr_path )
685
+ anc_state = ds .variant_ancestral_allele .values
686
+ anc_state [0 ] = "N"
687
+ anc_state [11 ] = "N"
688
+ anc_state [12 ] = "n"
689
+ anc_state [15 ] = "n"
690
+ with caplog .at_level (logging .INFO ):
691
+ vdata = tsinfer .VariantData (zarr_path , anc_state )
692
+ assert f"4 sites ({ 4 / ts .num_sites * 100 :.2f} %) were deliberately " in caplog .text
693
+ inf_ts = tsinfer .infer (vdata )
653
694
for i , (inf_var , var ) in enumerate (zip (inf_ts .variants (), ts .variants ())):
654
695
if i in [0 , 11 , 12 , 15 ]:
655
696
assert inf_var .site .metadata == {"inference_type" : "parsimony" }
697
+ assert inf_var .site .ancestral_state in var .site .alleles
656
698
else :
657
699
assert inf_var .site .ancestral_state == var .site .ancestral_state
658
700
@@ -670,6 +712,25 @@ def test_sgkit_ancestor(small_sd_fixture, tmp_path):
670
712
671
713
672
714
class TestVariantDataErrors :
715
+ @staticmethod
716
+ def simulate_genotype_call_dataset (* args , ** kwargs ):
717
+ # roll our own simulate_genotype_call_dataset to hack around bug in sgkit where
718
+ # duplicate alleles are created. Doesn't need to be efficient: just for testing
719
+ if "seed" not in kwargs :
720
+ kwargs ["seed" ] = 123
721
+ ds = sgkit .simulate_genotype_call_dataset (* args , ** kwargs )
722
+ variant_alleles = ds ["variant_allele" ].values
723
+ allowed_alleles = np .array (
724
+ ["A" , "T" , "C" , "G" , "N" ], dtype = variant_alleles .dtype
725
+ )
726
+ for row in range (len (variant_alleles )):
727
+ alleles = variant_alleles [row ]
728
+ if len (set (alleles )) != len (alleles ):
729
+ # Just use a set that we know is unique
730
+ variant_alleles [row ] = allowed_alleles [0 : len (alleles )]
731
+ ds ["variant_allele" ] = ds ["variant_allele" ].dims , variant_alleles
732
+ return ds
733
+
673
734
def test_bad_zarr_spec (self ):
674
735
ds = zarr .group ()
675
736
ds ["call_genotype" ] = zarr .array (np .zeros (10 , dtype = np .int8 ))
@@ -680,7 +741,7 @@ def test_bad_zarr_spec(self):
680
741
681
742
def test_missing_phase (self , tmp_path ):
682
743
path = tmp_path / "data.zarr"
683
- ds = sgkit .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 )
744
+ ds = self .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 )
684
745
sgkit .save_dataset (ds , path )
685
746
with pytest .raises (
686
747
ValueError , match = "The call_genotype_phased array is missing"
@@ -689,7 +750,7 @@ def test_missing_phase(self, tmp_path):
689
750
690
751
def test_phased (self , tmp_path ):
691
752
path = tmp_path / "data.zarr"
692
- ds = sgkit .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 )
753
+ ds = self .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 )
693
754
ds ["call_genotype_phased" ] = (
694
755
ds ["call_genotype" ].dims ,
695
756
np .ones (ds ["call_genotype" ].shape , dtype = bool ),
@@ -700,13 +761,13 @@ def test_phased(self, tmp_path):
700
761
def test_ploidy1_missing_phase (self , tmp_path ):
701
762
path = tmp_path / "data.zarr"
702
763
# Ploidy==1 is always ok
703
- ds = sgkit .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , n_ploidy = 1 )
764
+ ds = self .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , n_ploidy = 1 )
704
765
sgkit .save_dataset (ds , path )
705
766
tsinfer .VariantData (path , ds ["variant_allele" ][:, 0 ].values .astype (str ))
706
767
707
768
def test_ploidy1_unphased (self , tmp_path ):
708
769
path = tmp_path / "data.zarr"
709
- ds = sgkit .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , n_ploidy = 1 )
770
+ ds = self .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , n_ploidy = 1 )
710
771
ds ["call_genotype_phased" ] = (
711
772
ds ["call_genotype" ].dims ,
712
773
np .zeros (ds ["call_genotype" ].shape , dtype = bool ),
@@ -716,31 +777,54 @@ def test_ploidy1_unphased(self, tmp_path):
716
777
717
778
def test_duplicate_positions (self , tmp_path ):
718
779
path = tmp_path / "data.zarr"
719
- ds = sgkit .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , phased = True )
780
+ ds = self .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , phased = True )
720
781
ds ["variant_position" ][2 ] = ds ["variant_position" ][1 ]
721
782
sgkit .save_dataset (ds , path )
722
783
with pytest .raises (ValueError , match = "duplicate or out-of-order values" ):
723
784
tsinfer .VariantData (path , "variant_ancestral_allele" )
724
785
725
786
def test_bad_order_positions (self , tmp_path ):
726
787
path = tmp_path / "data.zarr"
727
- ds = sgkit .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , phased = True )
788
+ ds = self .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , phased = True )
728
789
ds ["variant_position" ][0 ] = ds ["variant_position" ][2 ] - 0.5
729
790
sgkit .save_dataset (ds , path )
730
791
with pytest .raises (ValueError , match = "duplicate or out-of-order values" ):
731
792
tsinfer .VariantData (path , "variant_ancestral_allele" )
732
793
794
+ def test_bad_ancestral_state (self , tmp_path ):
795
+ path = tmp_path / "data.zarr"
796
+ ds = self .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , phased = True )
797
+ ancestral_state = ds ["variant_allele" ][:, 0 ].values .astype (str )
798
+ ancestral_state [1 ] = ""
799
+ sgkit .save_dataset (ds , path )
800
+ with pytest .raises (ValueError , match = "cannot contain empty strings" ):
801
+ tsinfer .VariantData (path , ancestral_state )
802
+
733
803
def test_empty_alleles_not_at_end (self , tmp_path ):
734
804
path = tmp_path / "data.zarr"
735
- ds = sgkit .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , n_ploidy = 1 )
805
+ ds = self .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , n_ploidy = 1 )
736
806
ds ["variant_allele" ] = (
737
807
ds ["variant_allele" ].dims ,
738
- np .array ([["" , "A " , "C" ], ["A" , "C" , "" ], ["A" , "C" , "" ]], dtype = "S1" ),
808
+ np .array ([["A " , "" , "C" ], ["A" , "C" , "" ], ["A" , "C" , "" ]], dtype = "S1" ),
739
809
)
740
810
sgkit .save_dataset (ds , path )
741
- vdata = tsinfer .VariantData (path , ds ["variant_allele" ][:, 0 ].values .astype (str ))
742
- with pytest .raises (ValueError , match = "Empty alleles must be at the end" ):
743
- tsinfer .infer (vdata )
811
+ with pytest .raises (
812
+ ValueError , match = 'Bad alleles: fill value "" in middle of list'
813
+ ):
814
+ tsinfer .VariantData (path , ds ["variant_allele" ][:, 0 ].values .astype (str ))
815
+
816
+ def test_unique_alleles (self , tmp_path ):
817
+ path = tmp_path / "data.zarr"
818
+ ds = self .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , n_ploidy = 1 )
819
+ ds ["variant_allele" ] = (
820
+ ds ["variant_allele" ].dims ,
821
+ np .array ([["A" , "C" , "T" ], ["A" , "C" , "" ], ["A" , "A" , "" ]], dtype = "S1" ),
822
+ )
823
+ sgkit .save_dataset (ds , path )
824
+ with pytest .raises (
825
+ ValueError , match = "Duplicate allele values provided at site 2"
826
+ ):
827
+ tsinfer .VariantData (path , np .array (["A" , "A" , "A" ], dtype = "S1" ))
744
828
745
829
def test_unimplemented_from_tree_sequence (self ):
746
830
# NB we should reimplement something like this functionality.
0 commit comments