38
38
from tsinfer import formats
39
39
40
40
41
- def ts_to_dataset (ts , chunks = None , samples = None ):
41
+ def ts_to_dataset (ts , chunks = None , samples = None , contigs = None ):
42
42
"""
43
43
# From https://github.com/sgkit-dev/sgkit/blob/main/sgkit/tests/test_popgen.py#L63
44
44
Convert the specified tskit tree sequence into an sgkit dataset.
@@ -63,7 +63,7 @@ def ts_to_dataset(ts, chunks=None, samples=None):
63
63
genotypes = np .expand_dims (genotypes , axis = 2 )
64
64
65
65
ds = sgkit .create_genotype_call_dataset (
66
- variant_contig_names = ["1" ],
66
+ variant_contig_names = ["1" ] if contigs is None else contigs ,
67
67
variant_contig = np .zeros (len (tables .sites ), dtype = int ),
68
68
variant_position = tables .sites .position .astype (int ),
69
69
variant_allele = alleles ,
@@ -299,18 +299,102 @@ def test_simulate_genotype_call_dataset(tmp_path):
299
299
ts = msprime .sim_ancestry (4 , sequence_length = 1000 , random_seed = 123 )
300
300
ts = msprime .sim_mutations (ts , rate = 2e-3 , random_seed = 123 )
301
301
ds = ts_to_dataset (ts )
302
- ds .update ({"variant_ancestral_allele" : ds ["variant_allele" ][:, 0 ]})
303
302
ds .to_zarr (tmp_path , mode = "w" )
304
- sd = tsinfer .VariantData (tmp_path , "variant_ancestral_allele" )
305
- ts = tsinfer .infer (sd )
306
- for v , ds_v , sd_v in zip (ts .variants (), ds .call_genotype , sd .sites_genotypes ):
303
+ vdata = tsinfer .VariantData (tmp_path , ds [ "variant_allele" ][:, 0 ]. values . astype ( str ) )
304
+ ts = tsinfer .infer (vdata )
305
+ for v , ds_v , vd_v in zip (ts .variants (), ds .call_genotype , vdata .sites_genotypes ):
307
306
assert np .all (v .genotypes == ds_v .values .flatten ())
308
- assert np .all (v .genotypes == sd_v )
307
+ assert np .all (v .genotypes == vd_v )
308
+
309
+
310
def test_simulate_genotype_call_dataset_length(tmp_path):
    """
    Check the sequence length reported by VariantData for a dataset that has
    no ``contig_length`` array: it must equal the last site position plus one.
    """
    base_ts = msprime.sim_ancestry(4, sequence_length=1000, random_seed=123)
    mutated_ts = msprime.sim_mutations(base_ts, rate=2e-3, random_seed=123)
    ds = ts_to_dataset(mutated_ts)
    # create_genotype_call_dataset does not save contig lengths
    assert "contig_length" not in ds
    ds.to_zarr(tmp_path, mode="w")

    # The first allele at every site is used as the ancestral state
    ancestral_state = ds["variant_allele"][:, 0].values.astype(str)
    vdata = tsinfer.VariantData(tmp_path, ancestral_state)

    expected_length = mutated_ts.sites_position[-1] + 1
    assert vdata.sequence_length == expected_length
319
+
320
+
321
class TestMultiContig:
    """
    Tests of VariantData on a dataset whose sites are assigned to two
    contigs, "chr1" and "chr2".
    """

    def make_two_ts_dataset(self, path):
        """
        Simulate one tree sequence, split it in two at a site, and write all of
        its sites to ``path`` as a single dataset in which the sites before the
        split are on contig "chr1" and the remaining sites are on "chr2".

        Returns the two trimmed tree sequences, left part first.
        """
        split_at_site = 7
        ts = msprime.sim_ancestry(4, sequence_length=1000, random_seed=123)
        ts = msprime.sim_mutations(ts, rate=2e-3, random_seed=123)
        assert ts.num_sites > 10
        site_break = ts.site(split_at_site).position
        left_ts = ts.keep_intervals([(0, site_break)]).rtrim()
        right_ts = ts.keep_intervals([(site_break, ts.sequence_length)]).ltrim()

        ds = ts_to_dataset(ts, contigs=["chr1", "chr2"])
        ds.update({"variant_ancestral_allele": ds["variant_allele"][:, 0]})

        # Sites from the split point onwards go on the second contig (index 1)
        contig_index = ds["variant_contig"][:]
        contig_index[split_at_site:] = 1
        ds.update({"variant_contig": contig_index})

        # Make positions on the second contig relative to the break point. The
        # array returned by .values is edited in place, after which the
        # variable is re-assigned to itself.
        positions = ds["variant_position"].values
        positions[split_at_site:] -= int(site_break)
        ds.update({"variant_position": ds["variant_position"]})

        contig_lengths = np.array([left_ts.sequence_length, right_ts.sequence_length])
        ds.update({"contig_length": contig_lengths})
        ds.to_zarr(path, mode="w")
        return left_ts, right_ts

    def test_unmasked(self, tmp_path):
        """With no mask and no contig_id, the two-contig dataset is rejected."""
        self.make_two_ts_dataset(tmp_path)
        expected_msg = r'multiple contigs \("chr1", "chr2"\)'
        with pytest.raises(ValueError, match=expected_msg):
            tsinfer.VariantData(tmp_path, "variant_ancestral_allele")

    def test_mask(self, tmp_path):
        """Masking out every "chr1" site leaves a dataset that is all "chr2"."""
        ts1, ts2 = self.make_two_ts_dataset(tmp_path)
        # True entries are the sites to mask out
        mask_out_chr1 = np.array([True] * ts1.num_sites + [False] * ts2.num_sites)
        vdata = tsinfer.VariantData(
            tmp_path, "variant_ancestral_allele", site_mask=mask_out_chr1
        )
        assert np.all(ts2.sites_position == vdata.sites_position)
        assert vdata.contig_id == "chr2"
        assert vdata.sequence_length == ts2.sequence_length

    @pytest.mark.parametrize("contig_id", ["chr1", "chr2"])
    def test_contig_id_param(self, contig_id, tmp_path):
        """Passing contig_id selects the sites of that contig only."""
        tree_seqs = dict(zip(("chr1", "chr2"), self.make_two_ts_dataset(tmp_path)))
        expected_ts = tree_seqs[contig_id]
        vdata = tsinfer.VariantData(
            tmp_path, "variant_ancestral_allele", contig_id=contig_id
        )
        assert np.all(expected_ts.sites_position == vdata.sites_position)
        assert vdata.contig_id == contig_id
        assert vdata.sequence_length == expected_ts.sequence_length

    def test_contig_id_param_and_mask(self, tmp_path):
        """A site mask and a contig_id can be used together."""
        ts1, ts2 = self.make_two_ts_dataset(tmp_path)
        # Mask out all of "chr1" plus the first site of "chr2"
        num_masked = ts1.num_sites + 1
        num_kept = ts2.num_sites - 1
        site_mask = np.array([True] * num_masked + [False] * num_kept)
        vdata = tsinfer.VariantData(
            tmp_path,
            "variant_ancestral_allele",
            site_mask=site_mask,
            contig_id="chr2",
        )
        assert np.all(ts2.sites_position[1:] == vdata.sites_position)
        assert vdata.contig_id == "chr2"

    @pytest.mark.parametrize("contig_id", ["chr1", "chr2"])
    def test_contig_length(self, contig_id, tmp_path):
        """The sequence length matches the tree sequence of the chosen contig."""
        tree_seqs = dict(zip(("chr1", "chr2"), self.make_two_ts_dataset(tmp_path)))
        vdata = tsinfer.VariantData(
            tmp_path, "variant_ancestral_allele", contig_id=contig_id
        )
        assert vdata.sequence_length == tree_seqs[contig_id].sequence_length
309
393
310
394
311
395
@pytest .mark .skipif (sys .platform == "win32" , reason = "File permission errors on Windows" )
312
396
class TestSgkitMask :
313
- @pytest .mark .parametrize ("sites" , [[1 , 2 , 3 , 5 , 9 , 27 ], [0 ], [] ])
397
+ @pytest .mark .parametrize ("sites" , [[1 , 2 , 3 , 5 , 9 , 27 ], [0 ]])
314
398
def test_sgkit_variant_mask (self , tmp_path , sites ):
315
399
ts , zarr_path = tsutil .make_ts_and_zarr (tmp_path )
316
400
ds = sgkit .load_dataset (zarr_path )
@@ -823,6 +907,20 @@ def test_bad_ancestral_state(self, tmp_path):
823
907
with pytest .raises (ValueError , match = "cannot contain empty strings" ):
824
908
tsinfer .VariantData (path , ancestral_state )
825
909
910
def test_ancestral_state_len_not_same_as_mask(self, tmp_path):
    """
    An ancestral state array with one entry per site in the dataset must be
    rejected when a site_mask that masks out a site is passed as well.
    """
    ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
    path = tmp_path / "data.zarr"
    sgkit.save_dataset(ds, path)
    ancestral_state = ds["variant_allele"][:, 0].values.astype(str)

    # Mask out the first site only
    site_mask = np.zeros(ds.sizes["variants"], dtype=bool)
    site_mask[0] = True

    expected_msg = (
        "Ancestral state array must be the same length as the number of"
        " selected sites"
    )
    with pytest.raises(ValueError, match=expected_msg):
        tsinfer.VariantData(path, ancestral_state, site_mask=site_mask)
923
+
826
924
def test_empty_alleles_not_at_end (self , tmp_path ):
827
925
path = tmp_path / "data.zarr"
828
926
ds = self .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , n_ploidy = 1 )
@@ -854,3 +952,62 @@ def test_unimplemented_from_tree_sequence(self):
854
952
# Requires e.g. https://github.com/tskit-dev/tsinfer/issues/924
855
953
with pytest .raises (NotImplementedError ):
856
954
tsinfer .VariantData .from_tree_sequence (None )
955
+
956
def test_multiple_contigs(self, tmp_path):
    """Sites spread over two contigs, with no selection, raise a ValueError."""
    path = tmp_path / "data.zarr"
    ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)

    # Rename the contigs and put the last of the three sites on the second one
    contig_names = np.array(["c10", "c11"], dtype="<U3")
    ds["contig_id"] = (ds["contig_id"].dims, contig_names)
    site_contigs = np.array([0, 0, 1], dtype=ds["variant_contig"].dtype)
    ds["variant_contig"] = (ds["variant_contig"].dims, site_contigs)
    sgkit.save_dataset(ds, path)

    expected_msg = r'Sites belong to multiple contigs \("c10", "c11"\)'
    with pytest.raises(ValueError, match=expected_msg):
        tsinfer.VariantData(path, ds["variant_allele"][:, 0].astype(str))
972
+
973
def test_all_masked(self, tmp_path):
    """A site_mask that is True for every site raises a ValueError."""
    path = tmp_path / "data.zarr"
    ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
    sgkit.save_dataset(ds, path)
    with pytest.raises(ValueError, match="All sites have been masked out"):
        ancestral_state = ds["variant_allele"][:, 0].astype(str)
        everything_masked = np.ones(3, bool)
        tsinfer.VariantData(path, ancestral_state, site_mask=everything_masked)
981
+
982
def test_bad_contig_param(self, tmp_path):
    """Requesting the contig_id "XX" on this dataset raises a ValueError."""
    path = tmp_path / "data.zarr"
    ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
    sgkit.save_dataset(ds, path)
    with pytest.raises(ValueError, match='"XX" not found'):
        ancestral_state = ds["variant_allele"][:, 0].astype(str)
        tsinfer.VariantData(path, ancestral_state, contig_id="XX")
990
+
991
def test_multiple_contig_param(self, tmp_path):
    """Asking for a contig name that occurs twice in contig_id is an error."""
    path = tmp_path / "data.zarr"
    ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
    # Give both contigs the same name
    duplicated_names = np.array(["chr1", "chr1"], dtype="<U4")
    ds["contig_id"] = (ds["contig_id"].dims, duplicated_names)
    sgkit.save_dataset(ds, path)
    with pytest.raises(ValueError, match='Multiple contigs named "chr1"'):
        ancestral_state = ds["variant_allele"][:, 0].astype(str)
        tsinfer.VariantData(path, ancestral_state, contig_id="chr1")
1003
+
1004
def test_missing_sites_time(self, tmp_path):
    """Passing sites_time="XX" for this dataset raises a ValueError."""
    path = tmp_path / "data.zarr"
    ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
    sgkit.save_dataset(ds, path)
    expected_msg = "The sites time array XX was not found in the dataset"
    with pytest.raises(ValueError, match=expected_msg):
        ancestral_state = ds["variant_allele"][:, 0].astype(str)
        tsinfer.VariantData(path, ancestral_state, sites_time="XX")
0 commit comments