@@ -260,6 +260,11 @@ def test_sgkit_dataset_roundtrip(tmp_path):
260
260
samples = tsinfer .SgkitSampleData (zarr_path )
261
261
inf_ts = tsinfer .infer (samples )
262
262
ds = sgkit .load_dataset (zarr_path )
263
+
264
+ assert ts .num_individuals == inf_ts .num_individuals == ds .dims ["samples" ]
265
+ for (i , ind ) in zip (inf_ts .individuals (), ds ["sample_id" ].values ):
266
+ assert i .metadata ["sgkit_sample_id" ] == ind
267
+
263
268
assert (
264
269
ts .num_samples == inf_ts .num_samples == ds .dims ["samples" ] * ds .dims ["ploidy" ]
265
270
)
@@ -283,6 +288,7 @@ def test_sgkit_dataset_roundtrip(tmp_path):
283
288
def test_sgkit_dataset_accessors (tmp_path ):
284
289
ts , zarr_path = make_ts_and_zarr (tmp_path , add_optional = True , shuffle_alleles = False )
285
290
samples = tsinfer .SgkitSampleData (zarr_path )
291
+ ds = sgkit .load_dataset (zarr_path )
286
292
287
293
assert samples .format_name == "tsinfer-sgkit-sample-data"
288
294
assert samples .format_version == (0 , 1 )
@@ -325,7 +331,10 @@ def test_sgkit_dataset_accessors(tmp_path):
325
331
samples .individuals_metadata_schema
326
332
== ts .tables .individuals .metadata_schema .schema
327
333
)
328
- assert samples .individuals_metadata == [ind .metadata for ind in ts .individuals ()]
334
+ assert samples .individuals_metadata == [
335
+ {"sgkit_sample_id" : sample_id , ** ind .metadata }
336
+ for ind , sample_id in zip (ts .individuals (), ds ["sample_id" ].values )
337
+ ]
329
338
assert np .array_equal (
330
339
samples .individuals_location ,
331
340
np .tile (np .array ([["0" , "1" ]], dtype = "float32" ), (ts .num_individuals , 1 )),
@@ -354,6 +363,7 @@ def test_sgkit_dataset_accessors(tmp_path):
354
363
def test_sgkit_accessors_defaults (tmp_path ):
355
364
ts , zarr_path = make_ts_and_zarr (tmp_path )
356
365
samples = tsinfer .SgkitSampleData (zarr_path )
366
+ ds = sgkit .load_dataset (zarr_path )
357
367
358
368
default_schema = tskit .MetadataSchema .permissive_json ().schema
359
369
assert samples .sequence_length == ts .sequence_length
@@ -369,7 +379,9 @@ def test_sgkit_accessors_defaults(tmp_path):
369
379
assert samples .populations_metadata_schema == default_schema
370
380
assert samples .populations_metadata == []
371
381
assert samples .individuals_metadata_schema == default_schema
372
- assert samples .individuals_metadata == [{} for _ in range (ts .num_individuals )]
382
+ assert samples .individuals_metadata == [
383
+ {"sgkit_sample_id" : sample_id } for sample_id in ds ["sample_id" ].values
384
+ ]
373
385
for time in samples .individuals_time :
374
386
assert tskit .is_unknown_time (time )
375
387
assert np .array_equal (
0 commit comments