Skip to content

Commit 93e386e

Browse files
benjeffery authored and mergify[bot] committed
Don't clobber existing
1 parent 9a48da1 commit 93e386e

File tree

3 files changed

+47
-3
lines changed

3 files changed

+47
-3
lines changed

CHANGELOG.md

+9
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,15 @@
44

55
In development
66

7+
**Features**
8+
9+
- `tsinfer` now supports inferring data from an `sgkit` dataset. This allows users to
10+
infer from VCFs via the optimised and parallel VCF parsing in `sgkit`.
11+
- The `variant_mask` boolean array in the `sgkit` dataset can be used to mask sites
12+
not wanted for inference.
13+
- `sgkit` `sample_ids` are inserted into individual metadata as `sgkit_sample_id` if
14+
this key does not already exist.
15+
716
**Breaking Changes**
817

918
- Remove the `uuid` field from SampleData. SampleData equality is now purely based

tests/test_sgkit.py

+34-2
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,18 @@
1919
"""
2020
Tests for the data files.
2121
"""
22+
import json
2223
import sys
2324
import tempfile
2425

2526
import msprime
27+
import numcodecs
2628
import numpy as np
2729
import pytest
2830
import sgkit
2931
import tskit
3032
import xarray as xr
33+
import zarr
3134

3235
import tsinfer
3336
from tsinfer import formats
@@ -262,8 +265,8 @@ def test_sgkit_dataset_roundtrip(tmp_path):
262265
ds = sgkit.load_dataset(zarr_path)
263266

264267
assert ts.num_individuals == inf_ts.num_individuals == ds.dims["samples"]
265-
for (i, ind) in zip(inf_ts.individuals(), ds["sample_id"].values):
266-
assert i.metadata["sgkit_sample_id"] == ind
268+
for ts_ind, sample_id in zip(inf_ts.individuals(), ds["sample_id"].values):
269+
assert ts_ind.metadata["sgkit_sample_id"] == sample_id
267270

268271
assert (
269272
ts.num_samples == inf_ts.num_samples == ds.dims["samples"] * ds.dims["ploidy"]
@@ -284,6 +287,35 @@ def test_sgkit_dataset_roundtrip(tmp_path):
284287
assert inf_ts.num_edges > 200
285288

286289

290+
@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows")
291+
def test_sgkit_individual_metadata_not_clobbered(tmp_path):
292+
ts, zarr_path = make_ts_and_zarr(tmp_path)
293+
# Load the zarr to add metadata for testing
294+
zarr_root = zarr.open(zarr_path)
295+
empty_obj = json.dumps({}).encode()
296+
indiv_metadata = np.array([empty_obj] * ts.num_individuals, dtype=object)
297+
indiv_metadata[42] = json.dumps({"sgkit_sample_id": "foobar"}).encode()
298+
zarr_root.create_dataset(
299+
"individuals_metadata", data=indiv_metadata, object_codec=numcodecs.VLenBytes()
300+
)
301+
zarr_root.attrs["individuals_metadata_schema"] = repr(
302+
tskit.MetadataSchema.permissive_json()
303+
)
304+
305+
samples = tsinfer.SgkitSampleData(zarr_path)
306+
inf_ts = tsinfer.infer(samples)
307+
ds = sgkit.load_dataset(zarr_path)
308+
309+
assert ts.num_individuals == inf_ts.num_individuals == ds.dims["samples"]
310+
for i, (ts_ind, sample_id) in enumerate(
311+
zip(inf_ts.individuals(), ds["sample_id"].values)
312+
):
313+
if i != 42:
314+
assert ts_ind.metadata["sgkit_sample_id"] == sample_id
315+
else:
316+
assert ts_ind.metadata["sgkit_sample_id"] == "foobar"
317+
318+
287319
@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows")
288320
def test_sgkit_dataset_accessors(tmp_path):
289321
ts, zarr_path = make_ts_and_zarr(tmp_path, add_optional=True, shuffle_alleles=False)

tsinfer/formats.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -2512,6 +2512,8 @@ def individuals_metadata_schema(self):
25122512
@functools.cached_property
25132513
def individuals_metadata(self):
25142514
schema = tskit.MetadataSchema(self.populations_metadata_schema)
2515+
# We set the sample_id in the individual metadata as this is often useful,
2516+
# however we silently don't overwrite if the key exists
25152517
if "individuals_metadata" in self.data:
25162518
assert len(self.data["individuals_metadata"]) == self.num_individuals
25172519
assert self.num_individuals == len(self.data["sample_id"])
@@ -2520,7 +2522,8 @@ def individuals_metadata(self):
25202522
self.data["sample_id"], self.data["individuals_metadata"][:]
25212523
):
25222524
md = schema.decode_row(r)
2523-
md["sgkit_sample_id"] = sample_id
2525+
if "sgkit_sample_id" not in md:
2526+
md["sgkit_sample_id"] = sample_id
25242527
md_list.append(md)
25252528
return md_list
25262529
else:

0 commit comments

Comments
 (0)