Skip to content

Commit 9a48da1

Browse files
benjeffery authored and mergify[bot] committed
Return sgkit sample_id in individual metadata
1 parent aa33167 commit 9a48da1

File tree

2 files changed

+29
-6
lines changed

2 files changed

+29
-6
lines changed

tests/test_sgkit.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,11 @@ def test_sgkit_dataset_roundtrip(tmp_path):
260260
samples = tsinfer.SgkitSampleData(zarr_path)
261261
inf_ts = tsinfer.infer(samples)
262262
ds = sgkit.load_dataset(zarr_path)
263+
264+
assert ts.num_individuals == inf_ts.num_individuals == ds.dims["samples"]
265+
for (i, ind) in zip(inf_ts.individuals(), ds["sample_id"].values):
266+
assert i.metadata["sgkit_sample_id"] == ind
267+
263268
assert (
264269
ts.num_samples == inf_ts.num_samples == ds.dims["samples"] * ds.dims["ploidy"]
265270
)
@@ -283,6 +288,7 @@ def test_sgkit_dataset_roundtrip(tmp_path):
283288
def test_sgkit_dataset_accessors(tmp_path):
284289
ts, zarr_path = make_ts_and_zarr(tmp_path, add_optional=True, shuffle_alleles=False)
285290
samples = tsinfer.SgkitSampleData(zarr_path)
291+
ds = sgkit.load_dataset(zarr_path)
286292

287293
assert samples.format_name == "tsinfer-sgkit-sample-data"
288294
assert samples.format_version == (0, 1)
@@ -325,7 +331,10 @@ def test_sgkit_dataset_accessors(tmp_path):
325331
samples.individuals_metadata_schema
326332
== ts.tables.individuals.metadata_schema.schema
327333
)
328-
assert samples.individuals_metadata == [ind.metadata for ind in ts.individuals()]
334+
assert samples.individuals_metadata == [
335+
{"sgkit_sample_id": sample_id, **ind.metadata}
336+
for ind, sample_id in zip(ts.individuals(), ds["sample_id"].values)
337+
]
329338
assert np.array_equal(
330339
samples.individuals_location,
331340
np.tile(np.array([["0", "1"]], dtype="float32"), (ts.num_individuals, 1)),
@@ -354,6 +363,7 @@ def test_sgkit_dataset_accessors(tmp_path):
354363
def test_sgkit_accessors_defaults(tmp_path):
355364
ts, zarr_path = make_ts_and_zarr(tmp_path)
356365
samples = tsinfer.SgkitSampleData(zarr_path)
366+
ds = sgkit.load_dataset(zarr_path)
357367

358368
default_schema = tskit.MetadataSchema.permissive_json().schema
359369
assert samples.sequence_length == ts.sequence_length
@@ -369,7 +379,9 @@ def test_sgkit_accessors_defaults(tmp_path):
369379
assert samples.populations_metadata_schema == default_schema
370380
assert samples.populations_metadata == []
371381
assert samples.individuals_metadata_schema == default_schema
372-
assert samples.individuals_metadata == [{} for _ in range(ts.num_individuals)]
382+
assert samples.individuals_metadata == [
383+
{"sgkit_sample_id": sample_id} for sample_id in ds["sample_id"].values
384+
]
373385
for time in samples.individuals_time:
374386
assert tskit.is_unknown_time(time)
375387
assert np.array_equal(

tsinfer/formats.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2512,10 +2512,21 @@ def individuals_metadata_schema(self):
25122512
@functools.cached_property
25132513
def individuals_metadata(self):
25142514
schema = tskit.MetadataSchema(self.populations_metadata_schema)
2515-
try:
2516-
return [schema.decode_row(r) for r in self.data["individuals_metadata"][:]]
2517-
except KeyError:
2518-
return [{} for _ in range(self.num_individuals)]
2515+
if "individuals_metadata" in self.data:
2516+
assert len(self.data["individuals_metadata"]) == self.num_individuals
2517+
assert self.num_individuals == len(self.data["sample_id"])
2518+
md_list = []
2519+
for sample_id, r in zip(
2520+
self.data["sample_id"], self.data["individuals_metadata"][:]
2521+
):
2522+
md = schema.decode_row(r)
2523+
md["sgkit_sample_id"] = sample_id
2524+
md_list.append(md)
2525+
return md_list
2526+
else:
2527+
return [
2528+
{"sgkit_sample_id": sample_id} for sample_id in self.data["sample_id"]
2529+
]
25192530

25202531
@functools.cached_property
25212532
def individuals_location(self):

0 commit comments

Comments
 (0)