38
38
from tsinfer import formats
39
39
40
40
41
- def ts_to_dataset (ts , chunks = None , samples = None ):
41
+ def ts_to_dataset (ts , chunks = None , samples = None , contigs = None ):
42
42
"""
43
43
# From https://github.com/sgkit-dev/sgkit/blob/main/sgkit/tests/test_popgen.py#L63
44
44
Convert the specified tskit tree sequence into an sgkit dataset.
@@ -63,7 +63,7 @@ def ts_to_dataset(ts, chunks=None, samples=None):
63
63
genotypes = np .expand_dims (genotypes , axis = 2 )
64
64
65
65
ds = sgkit .create_genotype_call_dataset (
66
- variant_contig_names = ["1" ],
66
+ variant_contig_names = ["1" ] if contigs is None else contigs ,
67
67
variant_contig = np .zeros (len (tables .sites ), dtype = int ),
68
68
variant_position = tables .sites .position .astype (int ),
69
69
variant_allele = alleles ,
@@ -145,7 +145,7 @@ def test_variantdata_accessors(tmp_path, in_mem):
145
145
assert vd .format_name == "tsinfer-variant-data"
146
146
assert vd .format_version == (0 , 1 )
147
147
assert vd .finalised
148
- assert vd .sequence_length == ts .sequence_length + 1337
148
+ assert vd .sequence_length == ts .sequence_length
149
149
assert vd .num_sites == ts .num_sites
150
150
assert vd .sites_metadata_schema == ts .tables .sites .metadata_schema .schema
151
151
assert vd .sites_metadata == [site .metadata for site in ts .sites ()]
@@ -218,11 +218,7 @@ def test_variantdata_accessors_defaults(tmp_path, in_mem):
218
218
ds = data if in_mem else sgkit .load_dataset (data )
219
219
220
220
default_schema = tskit .MetadataSchema .permissive_json ().schema
221
- with pytest .warns (
222
- UserWarning ,
223
- match = "`sequence_length` was not found as an attribute in the dataset" ,
224
- ):
225
- assert vdata .sequence_length == ts .sequence_length
221
+ assert vdata .sequence_length == ts .sequence_length
226
222
assert vdata .sites_metadata_schema == default_schema
227
223
assert vdata .sites_metadata == [{} for _ in range (ts .num_sites )]
228
224
for time in vdata .sites_time :
@@ -299,18 +295,116 @@ def test_simulate_genotype_call_dataset(tmp_path):
299
295
ts = msprime .sim_ancestry (4 , sequence_length = 1000 , random_seed = 123 )
300
296
ts = msprime .sim_mutations (ts , rate = 2e-3 , random_seed = 123 )
301
297
ds = ts_to_dataset (ts )
302
- ds .update ({"variant_ancestral_allele" : ds ["variant_allele" ][:, 0 ]})
303
298
ds .to_zarr (tmp_path , mode = "w" )
304
- sd = tsinfer .VariantData (tmp_path , "variant_ancestral_allele" )
305
- ts = tsinfer .infer (sd )
306
- for v , ds_v , sd_v in zip (ts .variants (), ds .call_genotype , sd .sites_genotypes ):
299
+ vdata = tsinfer .VariantData (tmp_path , ds [ "variant_allele" ][:, 0 ]. values . astype ( str ) )
300
+ ts = tsinfer .infer (vdata )
301
+ for v , ds_v , vd_v in zip (ts .variants (), ds .call_genotype , vdata .sites_genotypes ):
307
302
assert np .all (v .genotypes == ds_v .values .flatten ())
308
- assert np .all (v .genotypes == sd_v )
303
+ assert np .all (v .genotypes == vd_v )
304
+
305
+
306
def test_simulate_genotype_call_dataset_length(tmp_path):
    """
    Check ``VariantData.sequence_length`` for a dataset with no contig lengths.

    Without an explicit ``sequence_length`` the value is asserted to be one
    more than the last site position; when ``sequence_length`` is passed to
    the constructor, that value is asserted to be reported back unchanged.
    """
    # create_genotype_call_dataset does not save contig lengths
    tree_seq = msprime.sim_mutations(
        msprime.sim_ancestry(4, sequence_length=1000, random_seed=123),
        rate=2e-3,
        random_seed=123,
    )
    ds = ts_to_dataset(tree_seq)
    assert "contig_length" not in ds
    ds.to_zarr(tmp_path, mode="w")
    # The first listed allele at each site is supplied as the ancestral state.
    ancestral_state = ds["variant_allele"][:, 0].values.astype(str)

    inferred_length = tsinfer.VariantData(tmp_path, ancestral_state).sequence_length
    assert inferred_length == tree_seq.sites_position[-1] + 1

    explicit_length = tsinfer.VariantData(
        tmp_path, ancestral_state, sequence_length=1337
    ).sequence_length
    assert explicit_length == 1337
320
+
321
+
322
class TestMultiContig:
    """
    Tests for ``tsinfer.VariantData`` on a dataset that holds two contigs.
    """

    def make_two_ts_dataset(self, path):
        """
        Write a two-contig dataset to ``path`` and return the matching tree
        sequences.

        One simulated tree sequence is cut at a fixed site. Sites before the
        cut stay on contig index 0 ("chr1"); sites from the cut onwards are
        assigned contig index 1 ("chr2") and have their positions shifted
        down by the cut position.

        :param path: Location passed to ``Dataset.to_zarr``.
        :return: ``(ts1, ts2)``, the trimmed tree sequences on either side of
            the cut.
        """
        # split ts into 2; put them as different contigs in the same dataset
        ts = msprime.sim_ancestry(4, sequence_length=1000, random_seed=123)
        ts = msprime.sim_mutations(ts, rate=2e-3, random_seed=123)
        split_at_site = 7
        assert ts.num_sites > 10
        site_break = ts.site(split_at_site).position
        ts1 = ts.keep_intervals([(0, site_break)]).rtrim()
        ts2 = ts.keep_intervals([(site_break, ts.sequence_length)]).ltrim()
        ds = ts_to_dataset(ts, contigs=["chr1", "chr2"])
        ds.update({"variant_ancestral_allele": ds["variant_allele"][:, 0]})
        # Edit copies and store them back explicitly, so the result does not
        # depend on ``.values`` aliasing the dataset's own storage.
        variant_contig = ds["variant_contig"].values.copy()
        variant_contig[split_at_site:] = 1
        ds.update({"variant_contig": ds["variant_contig"].copy(data=variant_contig)})
        variant_position = ds["variant_position"].values.copy()
        variant_position[split_at_site:] -= int(site_break)
        ds.update(
            {"variant_position": ds["variant_position"].copy(data=variant_position)}
        )
        ds.update(
            {"contig_length": np.array([ts1.sequence_length, ts2.sequence_length])}
        )
        ds.to_zarr(path, mode="w")
        return ts1, ts2

    def test_unmasked(self, tmp_path):
        """Loading both contigs with no site mask must raise, naming both."""
        self.make_two_ts_dataset(tmp_path)
        with pytest.raises(ValueError, match=r'multiple contigs \("chr1", "chr2"\)'):
            tsinfer.VariantData(tmp_path, "variant_ancestral_allele")

    def test_mask(self, tmp_path):
        """Masking out every chr1 site leaves a VariantData describing chr2."""
        ts1, ts2 = self.make_two_ts_dataset(tmp_path)
        vdata = tsinfer.VariantData(
            tmp_path,
            "variant_ancestral_allele",
            # True entries are the sites to exclude: all of chr1 here.
            site_mask=np.array(ts1.num_sites * [True] + ts2.num_sites * [False]),
        )
        assert np.all(ts2.sites_position == vdata.sites_position)
        assert vdata.contig_id == "chr2"
        assert vdata.sequence_length == ts2.sequence_length

    @pytest.mark.parametrize("contig_id", ["chr1", "chr2"])
    def test_multi_contig(self, contig_id, tmp_path):
        """Each contig can be selected by masking out the other one."""
        tree_seqs = {}
        tree_seqs["chr1"], tree_seqs["chr2"] = self.make_two_ts_dataset(tmp_path)
        with pytest.raises(ValueError, match="multiple contigs"):
            tsinfer.VariantData(tmp_path, "variant_ancestral_allele")
        root = zarr.open(tmp_path)
        # The mask marks the sites to drop, so it selects the *other* contig.
        mask = root["variant_contig"][:] == (1 if contig_id == "chr1" else 0)
        vdata = tsinfer.VariantData(
            tmp_path, "variant_ancestral_allele", site_mask=mask
        )
        assert np.all(tree_seqs[contig_id].sites_position == vdata.sites_position)
        assert vdata.contig_id == contig_id
        assert vdata._contig_index == (0 if contig_id == "chr1" else 1)
        assert vdata.sequence_length == tree_seqs[contig_id].sequence_length

    def test_mixed_contigs_error(self, tmp_path):
        """A mask that keeps sites from both contigs must be rejected."""
        ts1, ts2 = self.make_two_ts_dataset(tmp_path)
        mask = np.ones(ts1.num_sites + ts2.num_sites, dtype=bool)
        # Select two variants, one from each contig
        mask[0] = False
        mask[-1] = False
        with pytest.raises(ValueError, match="multiple contigs"):
            tsinfer.VariantData(
                tmp_path,
                "variant_ancestral_allele",
                site_mask=mask,
            )

    def test_no_variant_contig(self, tmp_path):
        """With ``variant_contig`` deleted, no contig is reported at all."""
        ts1, ts2 = self.make_two_ts_dataset(tmp_path)
        root = zarr.open(tmp_path)
        del root["variant_contig"]
        # Keep only the very first site.
        mask = np.ones(ts1.num_sites + ts2.num_sites, dtype=bool)
        mask[0] = False
        vdata = tsinfer.VariantData(
            tmp_path, "variant_ancestral_allele", site_mask=mask
        )
        assert vdata.sequence_length == ts1.sites_position[0] + 1
        assert vdata.contig_id is None
        assert vdata._contig_index is None
309
403
310
404
311
405
@pytest .mark .skipif (sys .platform == "win32" , reason = "File permission errors on Windows" )
312
406
class TestSgkitMask :
313
- @pytest .mark .parametrize ("sites" , [[1 , 2 , 3 , 5 , 9 , 27 ], [0 ], [] ])
407
+ @pytest .mark .parametrize ("sites" , [[1 , 2 , 3 , 5 , 9 , 27 ], [0 ]])
314
408
def test_sgkit_variant_mask (self , tmp_path , sites ):
315
409
ts , zarr_path = tsutil .make_ts_and_zarr (tmp_path )
316
410
ds = sgkit .load_dataset (zarr_path )
@@ -823,6 +917,20 @@ def test_bad_ancestral_state(self, tmp_path):
823
917
with pytest .raises (ValueError , match = "cannot contain empty strings" ):
824
918
tsinfer .VariantData (path , ancestral_state )
825
919
920
def test_ancestral_state_len_not_same_as_mask(self, tmp_path):
    """
    An ancestral state array with one entry per dataset site must be rejected
    when a site mask is also supplied that flags one of those sites.
    """
    zarr_path = tmp_path / "data.zarr"
    ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
    sgkit.save_dataset(ds, zarr_path)
    # One ancestral state per site in the unmasked dataset.
    full_length_state = ds["variant_allele"][:, 0].values.astype(str)
    # Flag only the first site in the mask.
    first_site_mask = np.zeros(ds.sizes["variants"], dtype=bool)
    first_site_mask[0] = True
    expected_message = (
        "Ancestral state array must be the same length as the number of"
        " selected sites"
    )
    with pytest.raises(ValueError, match=expected_message):
        tsinfer.VariantData(zarr_path, full_length_state, site_mask=first_site_mask)
933
+
826
934
def test_empty_alleles_not_at_end (self , tmp_path ):
827
935
path = tmp_path / "data.zarr"
828
936
ds = self .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , n_ploidy = 1 )
@@ -854,3 +962,23 @@ def test_unimplemented_from_tree_sequence(self):
854
962
# Requires e.g. https://github.com/tskit-dev/tsinfer/issues/924
855
963
with pytest .raises (NotImplementedError ):
856
964
tsinfer .VariantData .from_tree_sequence (None )
965
+
966
def test_all_masked(self, tmp_path):
    """A site mask that is True for every site must raise ``ValueError``."""
    zarr_path = tmp_path / "data.zarr"
    ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
    sgkit.save_dataset(ds, zarr_path)
    ancestral_state = ds["variant_allele"][:, 0].astype(str)
    # One True entry for each of the three simulated variants.
    everything_masked = np.ones(3, bool)
    with pytest.raises(ValueError, match="All sites have been masked out"):
        tsinfer.VariantData(zarr_path, ancestral_state, site_mask=everything_masked)
974
+
975
def test_missing_sites_time(self, tmp_path):
    """Naming a ``sites_time`` array absent from the dataset must raise."""
    zarr_path = tmp_path / "data.zarr"
    ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
    sgkit.save_dataset(ds, zarr_path)
    ancestral_state = ds["variant_allele"][:, 0].astype(str)
    expected_message = "The sites time array XX was not found in the dataset"
    with pytest.raises(ValueError, match=expected_message):
        # "XX" is not an array in the simulated dataset.
        tsinfer.VariantData(zarr_path, ancestral_state, sites_time="XX")
0 commit comments