@@ -145,7 +145,7 @@ def test_variantdata_accessors(tmp_path, in_mem):
145
145
assert vd .format_name == "tsinfer-variant-data"
146
146
assert vd .format_version == (0 , 1 )
147
147
assert vd .finalised
148
- assert vd .sequence_length == ts .sequence_length + 1337
148
+ assert vd .sequence_length == ts .sequence_length
149
149
assert vd .num_sites == ts .num_sites
150
150
assert vd .sites_metadata_schema == ts .tables .sites .metadata_schema .schema
151
151
assert vd .sites_metadata == [site .metadata for site in ts .sites ()]
@@ -317,6 +317,11 @@ def test_simulate_genotype_call_dataset_length(tmp_path):
317
317
vdata = tsinfer .VariantData (tmp_path , ds ["variant_allele" ][:, 0 ].values .astype (str ))
318
318
assert vdata .sequence_length == ts .sites_position [- 1 ] + 1
319
319
320
+ vdata = tsinfer .VariantData (
321
+ tmp_path , ds ["variant_allele" ][:, 0 ].values .astype (str ), sequence_length = 1337
322
+ )
323
+ assert vdata .sequence_length == 1337
324
+
320
325
321
326
class TestMultiContig :
322
327
def make_two_ts_dataset (self , path ):
@@ -359,37 +364,44 @@ def test_mask(self, tmp_path):
359
364
assert vdata .sequence_length == ts2 .sequence_length
360
365
361
366
@pytest .mark .parametrize ("contig_id" , ["chr1" , "chr2" ])
362
- def test_contig_id_param (self , contig_id , tmp_path ):
367
+ def test_multi_contig (self , contig_id , tmp_path ):
363
368
tree_seqs = {}
364
369
tree_seqs ["chr1" ], tree_seqs ["chr2" ] = self .make_two_ts_dataset (tmp_path )
370
+ with pytest .raises (ValueError , match = "multiple contigs" ):
371
+ vdata = tsinfer .VariantData (tmp_path , "variant_ancestral_allele" )
372
+ root = zarr .open (tmp_path )
373
+ mask = root ["variant_contig" ][:] == (1 if contig_id == "chr1" else 0 )
365
374
vdata = tsinfer .VariantData (
366
- tmp_path , "variant_ancestral_allele" , contig_id = contig_id
375
+ tmp_path , "variant_ancestral_allele" , site_mask = mask
367
376
)
368
377
assert np .all (tree_seqs [contig_id ].sites_position == vdata .sites_position )
369
378
assert vdata .contig_id == contig_id
379
+ assert vdata ._contig_index == (0 if contig_id == "chr1" else 1 )
370
380
assert vdata .sequence_length == tree_seqs [contig_id ].sequence_length
371
381
372
- def test_contig_id_param_and_mask (self , tmp_path ):
382
+ def test_mixed_contigs_error (self , tmp_path ):
373
383
ts1 , ts2 = self .make_two_ts_dataset (tmp_path )
374
- vdata = tsinfer . VariantData (
375
- tmp_path ,
376
- "variant_ancestral_allele" ,
377
- site_mask = np . array (
378
- ( ts1 . num_sites + 1 ) * [ True ] + ( ts2 . num_sites - 1 ) * [ False ]
379
- ),
380
- contig_id = "chr2" ,
381
- )
382
- assert np . all ( ts2 . sites_position [ 1 :] == vdata . sites_position )
383
- assert vdata . contig_id == "chr2"
384
+ mask = np . ones ( ts1 . num_sites + ts2 . num_sites )
385
+ # Select two varaints, one from each contig
386
+ mask [ 0 ] = False
387
+ mask [ - 1 ] = False
388
+ with pytest . raises ( ValueError , match = "multiple contigs" ):
389
+ tsinfer . VariantData (
390
+ tmp_path ,
391
+ "variant_ancestral_allele" ,
392
+ site_mask = mask ,
393
+ )
384
394
385
- @pytest .mark .parametrize ("contig_id" , ["chr1" , "chr2" ])
386
- def test_contig_length (self , contig_id , tmp_path ):
387
- tree_seqs = {}
388
- tree_seqs ["chr1" ], tree_seqs ["chr2" ] = self .make_two_ts_dataset (tmp_path )
395
+ def test_no_variant_contig (self , tmp_path ):
396
+ ts1 , ts2 = self .make_two_ts_dataset (tmp_path )
397
+ root = zarr .open (tmp_path )
398
+ del root ["variant_contig" ]
399
+ mask = np .ones (ts1 .num_sites + ts2 .num_sites )
400
+ mask [0 ] = False
389
401
vdata = tsinfer .VariantData (
390
- tmp_path , "variant_ancestral_allele" , contig_id = contig_id
402
+ tmp_path , "variant_ancestral_allele" , site_mask = mask
391
403
)
392
- assert vdata .sequence_length == tree_seqs [ contig_id ]. sequence_length
404
+ assert vdata .sequence_length == ts1 . sites_position [ 0 ] + 1
393
405
394
406
395
407
@pytest .mark .skipif (sys .platform == "win32" , reason = "File permission errors on Windows" )
@@ -953,23 +965,6 @@ def test_unimplemented_from_tree_sequence(self):
953
965
with pytest .raises (NotImplementedError ):
954
966
tsinfer .VariantData .from_tree_sequence (None )
955
967
956
- def test_multiple_contigs (self , tmp_path ):
957
- path = tmp_path / "data.zarr"
958
- ds = sgkit .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , phased = True )
959
- ds ["contig_id" ] = (
960
- ds ["contig_id" ].dims ,
961
- np .array (["c10" , "c11" ], dtype = "<U3" ),
962
- )
963
- ds ["variant_contig" ] = (
964
- ds ["variant_contig" ].dims ,
965
- np .array ([0 , 0 , 1 ], dtype = ds ["variant_contig" ].dtype ),
966
- )
967
- sgkit .save_dataset (ds , path )
968
- with pytest .raises (
969
- ValueError , match = r'Sites belong to multiple contigs \("c10", "c11"\)'
970
- ):
971
- tsinfer .VariantData (path , ds ["variant_allele" ][:, 0 ].astype (str ))
972
-
973
968
def test_all_masked (self , tmp_path ):
974
969
path = tmp_path / "data.zarr"
975
970
ds = sgkit .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , phased = True )
@@ -979,28 +974,6 @@ def test_all_masked(self, tmp_path):
979
974
path , ds ["variant_allele" ][:, 0 ].astype (str ), site_mask = np .ones (3 , bool )
980
975
)
981
976
982
- def test_bad_contig_param (self , tmp_path ):
983
- path = tmp_path / "data.zarr"
984
- ds = sgkit .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , phased = True )
985
- sgkit .save_dataset (ds , path )
986
- with pytest .raises (ValueError , match = '"XX" not found' ):
987
- tsinfer .VariantData (
988
- path , ds ["variant_allele" ][:, 0 ].astype (str ), contig_id = "XX"
989
- )
990
-
991
- def test_multiple_contig_param (self , tmp_path ):
992
- path = tmp_path / "data.zarr"
993
- ds = sgkit .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , phased = True )
994
- ds ["contig_id" ] = (
995
- ds ["contig_id" ].dims ,
996
- np .array (["chr1" , "chr1" ], dtype = "<U4" ),
997
- )
998
- sgkit .save_dataset (ds , path )
999
- with pytest .raises (ValueError , match = 'Multiple contigs named "chr1"' ):
1000
- tsinfer .VariantData (
1001
- path , ds ["variant_allele" ][:, 0 ].astype (str ), contig_id = "chr1"
1002
- )
1003
-
1004
977
def test_missing_sites_time (self , tmp_path ):
1005
978
path = tmp_path / "data.zarr"
1006
979
ds = sgkit .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , phased = True )
0 commit comments