19
19
"""
20
20
Tests for the data files.
21
21
"""
22
+ import json
22
23
import sys
23
24
import tempfile
24
25
25
26
import msprime
27
+ import numcodecs
26
28
import numpy as np
27
29
import pytest
28
30
import sgkit
29
31
import tskit
30
32
import xarray as xr
33
+ import zarr
31
34
32
35
import tsinfer
33
36
from tsinfer import formats
@@ -262,8 +265,8 @@ def test_sgkit_dataset_roundtrip(tmp_path):
262
265
ds = sgkit .load_dataset (zarr_path )
263
266
264
267
assert ts .num_individuals == inf_ts .num_individuals == ds .dims ["samples" ]
265
- for ( i , ind ) in zip (inf_ts .individuals (), ds ["sample_id" ].values ):
266
- assert i .metadata ["sgkit_sample_id" ] == ind
268
+ for ts_ind , sample_id in zip (inf_ts .individuals (), ds ["sample_id" ].values ):
269
+ assert ts_ind .metadata ["sgkit_sample_id" ] == sample_id
267
270
268
271
assert (
269
272
ts .num_samples == inf_ts .num_samples == ds .dims ["samples" ] * ds .dims ["ploidy" ]
@@ -284,6 +287,35 @@ def test_sgkit_dataset_roundtrip(tmp_path):
284
287
assert inf_ts .num_edges > 200
285
288
286
289
290
@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows")
def test_sgkit_individual_metadata_not_clobbered(tmp_path):
    """
    Write an ``individuals_metadata`` array into the zarr store in which a
    single individual already carries its own ``sgkit_sample_id``, run
    inference, and check that this individual keeps that value while all
    other individuals report the ``sample_id`` stored in the dataset.
    """
    # Index of the one individual given a pre-existing sgkit_sample_id.
    marked_index = 42

    ts, zarr_path = make_ts_and_zarr(tmp_path)

    # Open the store directly so the extra metadata array and its schema
    # can be added before tsinfer reads it.
    store = zarr.open(zarr_path)
    metadata_rows = np.array(
        [json.dumps({}).encode()] * ts.num_individuals, dtype=object
    )
    metadata_rows[marked_index] = json.dumps({"sgkit_sample_id": "foobar"}).encode()
    store.create_dataset(
        "individuals_metadata",
        data=metadata_rows,
        object_codec=numcodecs.VLenBytes(),
    )
    permissive_schema = tskit.MetadataSchema.permissive_json()
    store.attrs["individuals_metadata_schema"] = repr(permissive_schema)

    sample_data = tsinfer.SgkitSampleData(zarr_path)
    inf_ts = tsinfer.infer(sample_data)
    ds = sgkit.load_dataset(zarr_path)

    assert ts.num_individuals == inf_ts.num_individuals == ds.dims["samples"]
    individuals_and_ids = zip(inf_ts.individuals(), ds["sample_id"].values)
    for index, (individual, dataset_id) in enumerate(individuals_and_ids):
        # Only the marked individual should keep its pre-existing id.
        expected_id = "foobar" if index == marked_index else dataset_id
        assert individual.metadata["sgkit_sample_id"] == expected_id
317
+
318
+
287
319
@pytest .mark .skipif (sys .platform == "win32" , reason = "No cyvcf2 on windows" )
288
320
def test_sgkit_dataset_accessors (tmp_path ):
289
321
ts , zarr_path = make_ts_and_zarr (tmp_path , add_optional = True , shuffle_alleles = False )
0 commit comments