2020Tests for the data files.
2121"""
2222import json
23+ import logging
2324import sys
2425import tempfile
26+ import warnings
2527
2628import msprime
2729import numcodecs
@@ -627,14 +629,12 @@ def test_missing_ancestral_allele(tmp_path):
627629
628630
629631@pytest .mark .skipif (sys .platform == "win32" , reason = "No cyvcf2 on Windows" )
630- def test_ancestral_missingness (tmp_path ):
632+ def test_deliberate_ancestral_missingness (tmp_path ):
631633 ts , zarr_path = tsutil .make_ts_and_zarr (tmp_path )
632634 ds = sgkit .load_dataset (zarr_path )
633635 ancestral_allele = ds .variant_ancestral_allele .values
634636 ancestral_allele [0 ] = "N"
635- ancestral_allele [11 ] = "-"
636- ancestral_allele [12 ] = "💩"
637- ancestral_allele [15 ] = "💩"
637+ ancestral_allele [1 ] = "n"
638638 ds = ds .drop_vars (["variant_ancestral_allele" ])
639639 sgkit .save_dataset (ds , str (zarr_path ) + ".tmp" )
640640 tsutil .add_array_to_dataset (
@@ -644,15 +644,57 @@ def test_ancestral_missingness(tmp_path):
644644 ["variants" ],
645645 )
646646 ds = sgkit .load_dataset (str (zarr_path ) + ".tmp" )
647+ with warnings .catch_warnings ():
648+ warnings .simplefilter ("error" ) # No warning raised if AA deliberately missing
649+ sd = tsinfer .VariantData (str (zarr_path ) + ".tmp" , "variant_ancestral_allele" )
650+ inf_ts = tsinfer .infer (sd )
651+ for i , (inf_var , var ) in enumerate (zip (inf_ts .variants (), ts .variants ())):
652+ if i in [0 , 1 ]:
653+ assert inf_var .site .metadata == {"inference_type" : "parsimony" }
654+ else :
655+ assert inf_var .site .ancestral_state == var .site .ancestral_state
656+
657+
658+ @pytest .mark .skipif (sys .platform == "win32" , reason = "No cyvcf2 on Windows" )
659+ def test_ancestral_missing_warning (tmp_path ):
660+ ts , zarr_path = tsutil .make_ts_and_zarr (tmp_path )
661+ ds = sgkit .load_dataset (zarr_path )
662+ anc_state = ds .variant_ancestral_allele .values
663+ anc_state [0 ] = "N"
664+ anc_state [11 ] = "-"
665+ anc_state [12 ] = "💩"
666+ anc_state [15 ] = "💩"
647667 with pytest .warns (
648668 UserWarning ,
649669 match = r"not found in the variant_allele array for the 4 [\s\S]*'💩': 2" ,
650670 ):
651- sd = tsinfer .VariantData (str (zarr_path ) + ".tmp" , "variant_ancestral_allele" )
652- inf_ts = tsinfer .infer (sd )
671+ vdata = tsinfer .VariantData (zarr_path , anc_state )
672+ inf_ts = tsinfer .infer (vdata )
673+ for i , (inf_var , var ) in enumerate (zip (inf_ts .variants (), ts .variants ())):
674+ if i in [0 , 11 , 12 , 15 ]:
675+ assert inf_var .site .metadata == {"inference_type" : "parsimony" }
676+ assert inf_var .site .ancestral_state in var .site .alleles
677+ else :
678+ assert inf_var .site .ancestral_state == var .site .ancestral_state
679+
680+
681+ @pytest .mark .skipif (sys .platform == "win32" , reason = "No cyvcf2 on Windows" )
682+ def test_ancestral_missing_info (tmp_path , caplog ):
683+ ts , zarr_path = tsutil .make_ts_and_zarr (tmp_path )
684+ ds = sgkit .load_dataset (zarr_path )
685+ anc_state = ds .variant_ancestral_allele .values
686+ anc_state [0 ] = "N"
687+ anc_state [11 ] = "N"
688+ anc_state [12 ] = "n"
689+ anc_state [15 ] = "n"
690+ with caplog .at_level (logging .INFO ):
691+ vdata = tsinfer .VariantData (zarr_path , anc_state )
692+ assert f"4 sites ({ 4 / ts .num_sites * 100 :.2f} %) were deliberately " in caplog .text
693+ inf_ts = tsinfer .infer (vdata )
653694 for i , (inf_var , var ) in enumerate (zip (inf_ts .variants (), ts .variants ())):
654695 if i in [0 , 11 , 12 , 15 ]:
655696 assert inf_var .site .metadata == {"inference_type" : "parsimony" }
697+ assert inf_var .site .ancestral_state in var .site .alleles
656698 else :
657699 assert inf_var .site .ancestral_state == var .site .ancestral_state
658700
@@ -670,6 +712,25 @@ def test_sgkit_ancestor(small_sd_fixture, tmp_path):
670712
671713
672714class TestVariantDataErrors :
715+ @staticmethod
716+ def simulate_genotype_call_dataset (* args , ** kwargs ):
717+ # roll our own simulate_genotype_call_dataset to hack around bug in sgkit where
718+ # duplicate alleles are created. Doesn't need to be efficient: just for testing
719+ if "seed" not in kwargs :
720+ kwargs ["seed" ] = 123
721+ ds = sgkit .simulate_genotype_call_dataset (* args , ** kwargs )
722+ variant_alleles = ds ["variant_allele" ].values
723+ allowed_alleles = np .array (
724+ ["A" , "T" , "C" , "G" , "N" ], dtype = variant_alleles .dtype
725+ )
726+ for row in range (len (variant_alleles )):
727+ alleles = variant_alleles [row ]
728+ if len (set (alleles )) != len (alleles ):
729+ # Just use a set that we know is unique
730+ variant_alleles [row ] = allowed_alleles [0 : len (alleles )]
731+ ds ["variant_allele" ] = ds ["variant_allele" ].dims , variant_alleles
732+ return ds
733+
673734 def test_bad_zarr_spec (self ):
674735 ds = zarr .group ()
675736 ds ["call_genotype" ] = zarr .array (np .zeros (10 , dtype = np .int8 ))
@@ -680,7 +741,7 @@ def test_bad_zarr_spec(self):
680741
681742 def test_missing_phase (self , tmp_path ):
682743 path = tmp_path / "data.zarr"
683- ds = sgkit .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 )
744+ ds = self .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 )
684745 sgkit .save_dataset (ds , path )
685746 with pytest .raises (
686747 ValueError , match = "The call_genotype_phased array is missing"
@@ -689,7 +750,7 @@ def test_missing_phase(self, tmp_path):
689750
690751 def test_phased (self , tmp_path ):
691752 path = tmp_path / "data.zarr"
692- ds = sgkit .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 )
753+ ds = self .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 )
693754 ds ["call_genotype_phased" ] = (
694755 ds ["call_genotype" ].dims ,
695756 np .ones (ds ["call_genotype" ].shape , dtype = bool ),
@@ -700,13 +761,13 @@ def test_phased(self, tmp_path):
700761 def test_ploidy1_missing_phase (self , tmp_path ):
701762 path = tmp_path / "data.zarr"
702763 # Ploidy==1 is always ok
703- ds = sgkit .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , n_ploidy = 1 )
764+ ds = self .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , n_ploidy = 1 )
704765 sgkit .save_dataset (ds , path )
705766 tsinfer .VariantData (path , ds ["variant_allele" ][:, 0 ].values .astype (str ))
706767
707768 def test_ploidy1_unphased (self , tmp_path ):
708769 path = tmp_path / "data.zarr"
709- ds = sgkit .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , n_ploidy = 1 )
770+ ds = self .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , n_ploidy = 1 )
710771 ds ["call_genotype_phased" ] = (
711772 ds ["call_genotype" ].dims ,
712773 np .zeros (ds ["call_genotype" ].shape , dtype = bool ),
@@ -716,31 +777,54 @@ def test_ploidy1_unphased(self, tmp_path):
716777
717778 def test_duplicate_positions (self , tmp_path ):
718779 path = tmp_path / "data.zarr"
719- ds = sgkit .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , phased = True )
780+ ds = self .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , phased = True )
720781 ds ["variant_position" ][2 ] = ds ["variant_position" ][1 ]
721782 sgkit .save_dataset (ds , path )
722783 with pytest .raises (ValueError , match = "duplicate or out-of-order values" ):
723784 tsinfer .VariantData (path , "variant_ancestral_allele" )
724785
725786 def test_bad_order_positions (self , tmp_path ):
726787 path = tmp_path / "data.zarr"
727- ds = sgkit .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , phased = True )
788+ ds = self .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , phased = True )
728789 ds ["variant_position" ][0 ] = ds ["variant_position" ][2 ] - 0.5
729790 sgkit .save_dataset (ds , path )
730791 with pytest .raises (ValueError , match = "duplicate or out-of-order values" ):
731792 tsinfer .VariantData (path , "variant_ancestral_allele" )
732793
794+ def test_bad_ancestral_state (self , tmp_path ):
795+ path = tmp_path / "data.zarr"
796+ ds = self .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , phased = True )
797+ ancestral_state = ds ["variant_allele" ][:, 0 ].values .astype (str )
798+ ancestral_state [1 ] = ""
799+ sgkit .save_dataset (ds , path )
800+ with pytest .raises (ValueError , match = "cannot contain empty strings" ):
801+ tsinfer .VariantData (path , ancestral_state )
802+
733803 def test_empty_alleles_not_at_end (self , tmp_path ):
734804 path = tmp_path / "data.zarr"
735- ds = sgkit .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , n_ploidy = 1 )
805+ ds = self .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , n_ploidy = 1 )
736806 ds ["variant_allele" ] = (
737807 ds ["variant_allele" ].dims ,
738- np .array ([["" , "A " , "C" ], ["A" , "C" , "" ], ["A" , "C" , "" ]], dtype = "S1" ),
808+ np .array ([["A " , "" , "C" ], ["A" , "C" , "" ], ["A" , "C" , "" ]], dtype = "S1" ),
739809 )
740810 sgkit .save_dataset (ds , path )
741- vdata = tsinfer .VariantData (path , ds ["variant_allele" ][:, 0 ].values .astype (str ))
742- with pytest .raises (ValueError , match = "Empty alleles must be at the end" ):
743- tsinfer .infer (vdata )
811+ with pytest .raises (
812+ ValueError , match = 'Bad alleles: fill value "" in middle of list'
813+ ):
814+ tsinfer .VariantData (path , ds ["variant_allele" ][:, 0 ].values .astype (str ))
815+
816+ def test_unique_alleles (self , tmp_path ):
817+ path = tmp_path / "data.zarr"
818+ ds = self .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , n_ploidy = 1 )
819+ ds ["variant_allele" ] = (
820+ ds ["variant_allele" ].dims ,
821+ np .array ([["A" , "C" , "T" ], ["A" , "C" , "" ], ["A" , "A" , "" ]], dtype = "S1" ),
822+ )
823+ sgkit .save_dataset (ds , path )
824+ with pytest .raises (
825+ ValueError , match = "Duplicate allele values provided at site 2"
826+ ):
827+ tsinfer .VariantData (path , np .array (["A" , "A" , "A" ], dtype = "S1" ))
744828
745829 def test_unimplemented_from_tree_sequence (self ):
746830 # NB we should reimplement something like this functionality.
0 commit comments