@@ -133,146 +133,141 @@ def test_sgkit_individual_metadata_not_clobbered(tmp_path):
133
133
134
134
135
135
@pytest .mark .skipif (sys .platform == "win32" , reason = "No cyvcf2 on windows" )
136
- def test_sgkit_dataset_accessors (tmp_path ):
137
- ts , zarr_path = tsutil .make_ts_and_zarr (
138
- tmp_path , add_optional = True , shuffle_alleles = False
139
- )
140
- samples = tsinfer .VariantData (
141
- zarr_path , "variant_ancestral_allele" , sites_time = "sites_time"
142
- )
143
- ds = sgkit .load_dataset (zarr_path )
144
-
145
- assert samples .format_name == "tsinfer-variant-data"
146
- assert samples .format_version == (0 , 1 )
147
- assert samples .finalised
148
- assert samples .sequence_length == ts .sequence_length + 1337
149
- assert samples .num_sites == ts .num_sites
150
- assert samples .sites_metadata_schema == ts .tables .sites .metadata_schema .schema
151
- assert samples .sites_metadata == [site .metadata for site in ts .sites ()]
152
- assert np .array_equal (samples .sites_time , np .arange (ts .num_sites ) / ts .num_sites )
153
- assert np .array_equal (samples .sites_position , ts .tables .sites .position )
154
- for alleles , v in zip (samples .sites_alleles , ts .variants ()):
136
+ @pytest .mark .parametrize ("in_mem" , [True , False ])
137
+ def test_variantdata_accessors (tmp_path , in_mem ):
138
+ path = None if in_mem else tmp_path
139
+ ts , data = tsutil .make_ts_and_zarr (path , add_optional = True , shuffle_alleles = False )
140
+ vd = tsinfer .VariantData (data , "variant_ancestral_allele" , sites_time = "sites_time" )
141
+ ds = data if in_mem else sgkit .load_dataset (data )
142
+
143
+ assert vd .format_name == "tsinfer-variant-data"
144
+ assert vd .format_version == (0 , 1 )
145
+ assert vd .finalised
146
+ assert vd .sequence_length == ts .sequence_length + 1337
147
+ assert vd .num_sites == ts .num_sites
148
+ assert vd .sites_metadata_schema == ts .tables .sites .metadata_schema .schema
149
+ assert vd .sites_metadata == [site .metadata for site in ts .sites ()]
150
+ assert np .array_equal (vd .sites_time , np .arange (ts .num_sites ) / ts .num_sites )
151
+ assert np .array_equal (vd .sites_position , ts .tables .sites .position )
152
+ for alleles , v in zip (vd .sites_alleles , ts .variants ()):
155
153
# sgkit alleles are padded to be rectangular
156
154
assert np .all (alleles [: len (v .alleles )] == v .alleles )
157
155
assert np .all (alleles [len (v .alleles ) :] == "" )
158
- assert np .array_equal (samples .sites_select , np .ones (ts .num_sites , dtype = bool ))
156
+ assert np .array_equal (vd .sites_select , np .ones (ts .num_sites , dtype = bool ))
159
157
assert np .array_equal (
160
- samples .sites_ancestral_allele , np .zeros (ts .num_sites , dtype = np .int8 )
158
+ vd .sites_ancestral_allele , np .zeros (ts .num_sites , dtype = np .int8 )
161
159
)
162
- assert np .array_equal (samples .sites_genotypes , ts .genotype_matrix ())
160
+ assert np .array_equal (vd .sites_genotypes , ts .genotype_matrix ())
163
161
assert np .array_equal (
164
- samples .provenances_timestamp , ["2021-01-01T00:00:00" , "2021-01-02T00:00:00" ]
162
+ vd .provenances_timestamp , ["2021-01-01T00:00:00" , "2021-01-02T00:00:00" ]
165
163
)
166
- assert samples .provenances_record == [{"foo" : 1 }, {"foo" : 2 }]
167
- assert samples .num_samples == ts .num_samples
164
+ assert vd .provenances_record == [{"foo" : 1 }, {"foo" : 2 }]
165
+ assert vd .num_samples == ts .num_samples
168
166
assert np .array_equal (
169
- samples .samples_individual , np .repeat (np .arange (ts .num_samples // 3 ), 3 )
167
+ vd .samples_individual , np .repeat (np .arange (ts .num_samples // 3 ), 3 )
170
168
)
171
- assert samples .metadata_schema == tsutil .example_schema ("example" ).schema
172
- assert samples .metadata == ts .tables .metadata
169
+ assert vd .metadata_schema == tsutil .example_schema ("example" ).schema
170
+ assert vd .metadata == ts .tables .metadata
173
171
assert (
174
- samples .populations_metadata_schema
175
- == ts .tables .populations .metadata_schema .schema
172
+ vd .populations_metadata_schema == ts .tables .populations .metadata_schema .schema
176
173
)
177
- assert samples .populations_metadata == [pop .metadata for pop in ts .populations ()]
178
- assert samples .num_individuals == ts .num_individuals
174
+ assert vd .populations_metadata == [pop .metadata for pop in ts .populations ()]
175
+ assert vd .num_individuals == ts .num_individuals
179
176
assert np .array_equal (
180
- samples .individuals_time , np .arange (ts .num_individuals , dtype = np .float32 )
177
+ vd .individuals_time , np .arange (ts .num_individuals , dtype = np .float32 )
181
178
)
182
179
assert (
183
- samples .individuals_metadata_schema
184
- == ts .tables .individuals .metadata_schema .schema
180
+ vd .individuals_metadata_schema == ts .tables .individuals .metadata_schema .schema
185
181
)
186
- assert samples .individuals_metadata == [
182
+ assert vd .individuals_metadata == [
187
183
{"variant_data_sample_id" : sample_id , ** ind .metadata }
188
- for ind , sample_id in zip (ts .individuals (), ds [ " sample_id" ]. values )
184
+ for ind , sample_id in zip (ts .individuals (), ds . sample_id [:] )
189
185
]
190
186
assert np .array_equal (
191
- samples .individuals_location ,
187
+ vd .individuals_location ,
192
188
np .tile (np .array ([["0" , "1" ]], dtype = "float32" ), (ts .num_individuals , 1 )),
193
189
)
194
190
assert np .array_equal (
195
- samples .individuals_population , np .zeros (ts .num_individuals , dtype = "int32" )
191
+ vd .individuals_population , np .zeros (ts .num_individuals , dtype = "int32" )
196
192
)
197
193
assert np .array_equal (
198
- samples .individuals_flags ,
194
+ vd .individuals_flags ,
199
195
np .random .RandomState (42 ).randint (
200
196
0 , 2_000_000 , ts .num_individuals , dtype = "int32"
201
197
),
202
198
)
203
199
204
200
# Need to shuffle for the ancestral allele test
205
- ts , zarr_path = tsutil .make_ts_and_zarr (tmp_path , add_optional = True )
206
- samples = tsinfer .VariantData (zarr_path , "variant_ancestral_allele" )
201
+ ts , data = tsutil .make_ts_and_zarr (path , add_optional = True )
202
+ vd = tsinfer .VariantData (data , "variant_ancestral_allele" )
207
203
for i in range (ts .num_sites ):
208
204
assert (
209
- samples .sites_alleles [i ][samples .sites_ancestral_allele [i ]]
205
+ vd .sites_alleles [i ][vd .sites_ancestral_allele [i ]]
210
206
== ts .site (i ).ancestral_state
211
207
)
212
208
213
209
214
210
@pytest .mark .skipif (sys .platform == "win32" , reason = "No cyvcf2 on windows" )
215
- def test_sgkit_accessors_defaults (tmp_path ):
216
- ts , zarr_path = tsutil .make_ts_and_zarr (tmp_path )
217
- samples = tsinfer .VariantData (zarr_path , "variant_ancestral_allele" )
218
- ds = sgkit .load_dataset (zarr_path )
211
+ @pytest .mark .parametrize ("in_mem" , [True , False ])
212
+ def test_variantdata_accessors_defaults (tmp_path , in_mem ):
213
+ path = None if in_mem else tmp_path
214
+ ts , data = tsutil .make_ts_and_zarr (path )
215
+ vdata = tsinfer .VariantData (data , "variant_ancestral_allele" )
216
+ ds = data if in_mem else sgkit .load_dataset (data )
219
217
220
218
default_schema = tskit .MetadataSchema .permissive_json ().schema
221
- assert samples .sequence_length == ts .sequence_length
222
- assert samples .sites_metadata_schema == default_schema
223
- assert samples .sites_metadata == [{} for _ in range (ts .num_sites )]
224
- for time in samples .sites_time :
219
+ assert vdata .sequence_length == ts .sequence_length
220
+ assert vdata .sites_metadata_schema == default_schema
221
+ assert vdata .sites_metadata == [{} for _ in range (ts .num_sites )]
222
+ for time in vdata .sites_time :
225
223
assert tskit .is_unknown_time (time )
226
- assert np .array_equal (samples .sites_select , np .ones (ts .num_sites , dtype = bool ))
227
- assert np .array_equal (samples .provenances_timestamp , [])
228
- assert np .array_equal (samples .provenances_record , [])
229
- assert samples .metadata_schema == default_schema
230
- assert samples .metadata == {}
231
- assert samples .populations_metadata_schema == default_schema
232
- assert samples .populations_metadata == []
233
- assert samples .individuals_metadata_schema == default_schema
234
- assert samples .individuals_metadata == [
235
- {"variant_data_sample_id" : sample_id } for sample_id in ds [ " sample_id" ]. values
224
+ assert np .array_equal (vdata .sites_select , np .ones (ts .num_sites , dtype = bool ))
225
+ assert np .array_equal (vdata .provenances_timestamp , [])
226
+ assert np .array_equal (vdata .provenances_record , [])
227
+ assert vdata .metadata_schema == default_schema
228
+ assert vdata .metadata == {}
229
+ assert vdata .populations_metadata_schema == default_schema
230
+ assert vdata .populations_metadata == []
231
+ assert vdata .individuals_metadata_schema == default_schema
232
+ assert vdata .individuals_metadata == [
233
+ {"variant_data_sample_id" : sample_id } for sample_id in ds . sample_id [:]
236
234
]
237
- for time in samples .individuals_time :
235
+ for time in vdata .individuals_time :
238
236
assert tskit .is_unknown_time (time )
239
237
assert np .array_equal (
240
- samples .individuals_location , np .array ([[]] * ts .num_individuals , dtype = float )
238
+ vdata .individuals_location , np .array ([[]] * ts .num_individuals , dtype = float )
241
239
)
242
240
assert np .array_equal (
243
- samples .individuals_population , np .full (ts .num_individuals , tskit .NULL )
241
+ vdata .individuals_population , np .full (ts .num_individuals , tskit .NULL )
244
242
)
245
243
assert np .array_equal (
246
- samples .individuals_flags , np .zeros (ts .num_individuals , dtype = int )
244
+ vdata .individuals_flags , np .zeros (ts .num_individuals , dtype = int )
247
245
)
248
246
249
247
250
248
@pytest .mark .skipif (sys .platform == "win32" , reason = "No cyvcf2 on windows" )
251
- def test_variantdata_sites_time_default (tmp_path ):
252
- ts , zarr_path = tsutil .make_ts_and_zarr (tmp_path )
253
- samples = tsinfer .VariantData (zarr_path , "variant_ancestral_allele" )
249
+ def test_variantdata_sites_time_default ():
250
+ ts , data = tsutil .make_ts_and_zarr ()
251
+ vdata = tsinfer .VariantData (data , "variant_ancestral_allele" )
254
252
255
253
assert (
256
- np .all (np .isnan (samples .sites_time ))
257
- and samples .sites_time .size == samples .num_sites
254
+ np .all (np .isnan (vdata .sites_time )) and vdata .sites_time .size == vdata .num_sites
258
255
)
259
256
260
257
261
258
@pytest .mark .skipif (sys .platform == "win32" , reason = "No cyvcf2 on windows" )
262
- def test_variantdata_sites_time_array (tmp_path ):
263
- ts , zarr_path = tsutil .make_ts_and_zarr (tmp_path )
259
+ def test_variantdata_sites_time_array ():
260
+ ts , data = tsutil .make_ts_and_zarr ()
264
261
sites_time = np .arange (ts .num_sites )
265
- samples = tsinfer .VariantData (
266
- zarr_path , "variant_ancestral_allele" , sites_time = sites_time
267
- )
268
- assert np .array_equal (samples .sites_time , sites_time )
262
+ vdata = tsinfer .VariantData (data , "variant_ancestral_allele" , sites_time = sites_time )
263
+ assert np .array_equal (vdata .sites_time , sites_time )
269
264
wrong_length_sites_time = np .arange (ts .num_sites + 1 )
270
265
with pytest .raises (
271
266
ValueError ,
272
267
match = "Sites time array must be the same length as the number of selected sites" ,
273
268
):
274
269
tsinfer .VariantData (
275
- zarr_path ,
270
+ data ,
276
271
"variant_ancestral_allele" ,
277
272
sites_time = wrong_length_sites_time ,
278
273
)
@@ -302,17 +297,17 @@ def test_sgkit_variant_mask(self, tmp_path, sites):
302
297
for i in sites :
303
298
sites_mask [i ] = False
304
299
tsutil .add_array_to_dataset ("variant_mask_42" , sites_mask , zarr_path )
305
- samples = tsinfer .VariantData (
300
+ vdata = tsinfer .VariantData (
306
301
zarr_path ,
307
302
"variant_ancestral_allele" ,
308
303
site_mask = "variant_mask_42" ,
309
304
)
310
- assert samples .num_sites == len (sites )
311
- assert np .array_equal (samples .sites_select , ~ sites_mask )
305
+ assert vdata .num_sites == len (sites )
306
+ assert np .array_equal (vdata .sites_select , ~ sites_mask )
312
307
assert np .array_equal (
313
- samples .sites_position , ts .tables .sites .position [~ sites_mask ]
308
+ vdata .sites_position , ts .tables .sites .position [~ sites_mask ]
314
309
)
315
- inf_ts = tsinfer .infer (samples )
310
+ inf_ts = tsinfer .infer (vdata )
316
311
assert np .array_equal (
317
312
ts .genotype_matrix ()[~ sites_mask ], inf_ts .genotype_matrix ()
318
313
)
@@ -675,6 +670,14 @@ def test_sgkit_ancestor(small_sd_fixture, tmp_path):
675
670
676
671
677
672
class TestVariantDataErrors :
673
+ def test_bad_zarr_spec (self ):
674
+ ds = zarr .group ()
675
+ ds ["call_genotype" ] = zarr .array (np .zeros (10 , dtype = np .int8 ))
676
+ with pytest .raises (
677
+ ValueError , match = "Expecting a VCF Zarr object with 3D call_genotype array"
678
+ ):
679
+ tsinfer .VariantData (ds , np .zeros (10 , dtype = "<U1" ))
680
+
678
681
def test_missing_phase (self , tmp_path ):
679
682
path = tmp_path / "data.zarr"
680
683
ds = sgkit .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 )
0 commit comments