36
36
from tsinfer import formats
37
37
38
38
39
- def ts_to_dataset (ts , chunks = None , samples = None ):
39
+ def ts_to_dataset (ts , chunks = None , samples = None , contigs = None ):
40
40
"""
41
41
# From https://github.com/sgkit-dev/sgkit/blob/main/sgkit/tests/test_popgen.py#L63
42
42
Convert the specified tskit tree sequence into an sgkit dataset.
@@ -61,7 +61,7 @@ def ts_to_dataset(ts, chunks=None, samples=None):
61
61
genotypes = np .expand_dims (genotypes , axis = 2 )
62
62
63
63
ds = sgkit .create_genotype_call_dataset (
64
- variant_contig_names = ["1" ],
64
+ variant_contig_names = ["1" ] if contigs is None else contigs ,
65
65
variant_contig = np .zeros (len (tables .sites ), dtype = int ),
66
66
variant_position = tables .sites .position .astype (int ),
67
67
variant_allele = alleles ,
@@ -292,6 +292,78 @@ def test_simulate_genotype_call_dataset(tmp_path):
292
292
assert np .all (v .genotypes == sd_v )
293
293
294
294
295
class TestMultiContig:
    """
    Tests for ``tsinfer.VariantData`` on a zarr dataset whose sites are
    spread over two contigs ("chr1" and "chr2").
    """

    def make_two_ts_dataset(self, path):
        """
        Simulate a tree sequence, split it in two at a fixed site, and write
        all of its sites to ``path`` as a single dataset in which the sites
        before the split are assigned to contig 0 and the rest to contig 1.

        :param path: Location the zarr dataset is written to (overwritten).
        :return: The two trimmed tree sequences ``(ts1, ts2)`` covering the
            regions before and after the split respectively.
        """
        # split ts into 2; put them as different contigs in the same dataset
        ts = msprime.sim_ancestry(4, sequence_length=1000, random_seed=123)
        ts = msprime.sim_mutations(ts, rate=2e-3, random_seed=123)
        split_at_site = 7
        assert ts.num_sites > 10
        site_break = ts.site(split_at_site).position
        ts1 = ts.keep_intervals([(0, site_break)]).rtrim()
        ts2 = ts.keep_intervals([(site_break, ts.sequence_length)]).ltrim()
        ds = ts_to_dataset(ts, contigs=["chr1", "chr2"])
        ds.update({"variant_ancestral_allele": ds["variant_allele"][:, 0]})

        # Edit explicit copies and assign them back, rather than relying on
        # ``.values`` / ``[:]`` aliasing the dataset's own storage: if the
        # variable were ever backed by a lazy array, in-place edits on the
        # returned buffer would be silently dropped.
        contig_var = ds["variant_contig"]
        variant_contig = contig_var.values.copy()
        variant_contig[split_at_site:] = 1
        ds["variant_contig"] = (contig_var.dims, variant_contig, contig_var.attrs)

        # Positions on the second contig are made relative to the split point
        position_var = ds["variant_position"]
        variant_position = position_var.values.copy()
        variant_position[split_at_site:] -= int(site_break)
        ds["variant_position"] = (
            position_var.dims,
            variant_position,
            position_var.attrs,
        )

        # NOTE(review): a bare 1-D array is passed here without explicit
        # dims -- confirm the resulting dimension name is what is intended.
        ds.update(
            {"contig_length": np.array([ts1.sequence_length, ts2.sequence_length])}
        )
        ds.to_zarr(path, mode="w")
        return ts1, ts2

    def test_unmasked(self, tmp_path):
        """Loading both contigs with no mask or contig_id raises ValueError."""
        self.make_two_ts_dataset(tmp_path)
        with pytest.raises(ValueError, match=r'multiple contigs \("chr1", "chr2"\)'):
            tsinfer.VariantData(tmp_path, "variant_ancestral_allele")

    def test_mask(self, tmp_path):
        """A site mask flagging every chr1 site leaves only the chr2 sites."""
        ts1, ts2 = self.make_two_ts_dataset(tmp_path)
        vdata = tsinfer.VariantData(
            tmp_path,
            "variant_ancestral_allele",
            site_mask=np.array(ts1.num_sites * [True] + ts2.num_sites * [False]),
        )
        assert np.all(ts2.sites_position == vdata.sites_position)
        assert vdata.contig_id == "chr2"

    @pytest.mark.parametrize("contig_id", ["chr1", "chr2"])
    def test_contig_id_param(self, contig_id, tmp_path):
        """Passing ``contig_id`` selects the sites of that contig only."""
        tree_seqs = {}
        tree_seqs["chr1"], tree_seqs["chr2"] = self.make_two_ts_dataset(tmp_path)
        vdata = tsinfer.VariantData(
            tmp_path, "variant_ancestral_allele", contig_id=contig_id
        )
        assert np.all(tree_seqs[contig_id].sites_position == vdata.sites_position)
        assert vdata.contig_id == contig_id

    def test_contig_id_param_and_mask(self, tmp_path):
        """``contig_id`` and ``site_mask`` can be combined."""
        ts1, ts2 = self.make_two_ts_dataset(tmp_path)
        vdata = tsinfer.VariantData(
            tmp_path,
            "variant_ancestral_allele",
            # Additionally flag the first site of the second contig
            site_mask=np.array(
                (ts1.num_sites + 1) * [True] + (ts2.num_sites - 1) * [False]
            ),
            contig_id="chr2",
        )
        assert np.all(ts2.sites_position[1:] == vdata.sites_position)
        assert vdata.contig_id == "chr2"

    @pytest.mark.parametrize("contig_id", ["chr1", "chr2"])
    def test_contig_length(self, contig_id, tmp_path):
        """``sequence_length`` matches the tree sequence of the chosen contig."""
        tree_seqs = {}
        tree_seqs["chr1"], tree_seqs["chr2"] = self.make_two_ts_dataset(tmp_path)
        vdata = tsinfer.VariantData(
            tmp_path, "variant_ancestral_allele", contig_id=contig_id
        )
        assert vdata.sequence_length == tree_seqs[contig_id].sequence_length
365
+
366
+
295
367
@pytest .mark .skipif (sys .platform == "win32" , reason = "File permission errors on Windows" )
296
368
class TestSgkitMask :
297
369
@pytest .mark .parametrize ("sites" , [[1 , 2 , 3 , 5 , 9 , 27 ], [0 ], []])
@@ -754,3 +826,42 @@ def test_empty_alleles_not_at_end(self, tmp_path):
754
826
samples = tsinfer .VariantData (path , "variant_ancestral_allele" )
755
827
with pytest .raises (ValueError , match = "Empty alleles must be at the end" ):
756
828
tsinfer .infer (samples )
829
+
830
def test_multiple_contigs(self, tmp_path):
    """
    A dataset whose sites reference two different contigs is rejected
    with a ValueError naming both contigs.
    """
    zarr_path = tmp_path / "data.zarr"
    dataset = sgkit.simulate_genotype_call_dataset(
        n_variant=3, n_sample=3, phased=True
    )
    # Two contig names, with the last of the three sites on the second one
    contig_names = np.array(["c10", "c11"], dtype="<U3")
    dataset["contig_id"] = (dataset["contig_id"].dims, contig_names)
    contig_dims = dataset["variant_contig"].dims
    contig_index = np.array([0, 0, 1], dtype=dataset["variant_contig"].dtype)
    dataset["variant_contig"] = (contig_dims, contig_index)
    sgkit.save_dataset(dataset, zarr_path)
    expected_message = r'Sites belong to multiple contigs \("c10", "c11"\)'
    with pytest.raises(ValueError, match=expected_message):
        tsinfer.VariantData(zarr_path, dataset["variant_allele"][:, 0].astype(str))
846
+
847
def test_bad_contig_param(self, tmp_path):
    """
    Asking for the contig_id "XX" on a simulated dataset raises a
    ValueError reporting that it was not found.
    """
    zarr_path = tmp_path / "data.zarr"
    dataset = sgkit.simulate_genotype_call_dataset(
        n_variant=3, n_sample=3, phased=True
    )
    sgkit.save_dataset(dataset, zarr_path)
    expect_not_found = pytest.raises(ValueError, match='"XX" not found')
    with expect_not_found:
        tsinfer.VariantData(
            zarr_path,
            dataset["variant_allele"][:, 0].astype(str),
            contig_id="XX",
        )
855
+
856
def test_multiple_contig_param(self, tmp_path):
    """
    When two contigs share the name "chr1", selecting that contig_id
    raises a ValueError.
    """
    zarr_path = tmp_path / "data.zarr"
    dataset = sgkit.simulate_genotype_call_dataset(
        n_variant=3, n_sample=3, phased=True
    )
    # Give both contigs the same name
    duplicate_names = np.array(["chr1", "chr1"], dtype="<U4")
    dataset["contig_id"] = (dataset["contig_id"].dims, duplicate_names)
    sgkit.save_dataset(dataset, zarr_path)
    expect_duplicate = pytest.raises(
        ValueError, match='Multiple contigs named "chr1"'
    )
    with expect_duplicate:
        tsinfer.VariantData(
            zarr_path,
            dataset["variant_allele"][:, 0].astype(str),
            contig_id="chr1",
        )
0 commit comments