Skip to content

Commit 788d4df

Browse files
hyanwong authored and mergify[bot] committed
Allow VariantData to take either a path to a vcz or the zarr store itself
Add a check for a badly formatted vcz. It only checks the number of dimensions of the call_genotype array, but that is the property most likely to be wrong (e.g. a 2D array by mistake instead of the expected 3D one).
1 parent 6c02f9d commit 788d4df

File tree

3 files changed

+120
-86
lines changed

3 files changed

+120
-86
lines changed

tests/test_variantdata.py

Lines changed: 85 additions & 82 deletions
Original file line number | Diff line number | Diff line change
@@ -133,146 +133,141 @@ def test_sgkit_individual_metadata_not_clobbered(tmp_path):
133133

134134

135135
@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows")
136-
def test_sgkit_dataset_accessors(tmp_path):
137-
ts, zarr_path = tsutil.make_ts_and_zarr(
138-
tmp_path, add_optional=True, shuffle_alleles=False
139-
)
140-
samples = tsinfer.VariantData(
141-
zarr_path, "variant_ancestral_allele", sites_time="sites_time"
142-
)
143-
ds = sgkit.load_dataset(zarr_path)
144-
145-
assert samples.format_name == "tsinfer-variant-data"
146-
assert samples.format_version == (0, 1)
147-
assert samples.finalised
148-
assert samples.sequence_length == ts.sequence_length + 1337
149-
assert samples.num_sites == ts.num_sites
150-
assert samples.sites_metadata_schema == ts.tables.sites.metadata_schema.schema
151-
assert samples.sites_metadata == [site.metadata for site in ts.sites()]
152-
assert np.array_equal(samples.sites_time, np.arange(ts.num_sites) / ts.num_sites)
153-
assert np.array_equal(samples.sites_position, ts.tables.sites.position)
154-
for alleles, v in zip(samples.sites_alleles, ts.variants()):
136+
@pytest.mark.parametrize("in_mem", [True, False])
137+
def test_variantdata_accessors(tmp_path, in_mem):
138+
path = None if in_mem else tmp_path
139+
ts, data = tsutil.make_ts_and_zarr(path, add_optional=True, shuffle_alleles=False)
140+
vd = tsinfer.VariantData(data, "variant_ancestral_allele", sites_time="sites_time")
141+
ds = data if in_mem else sgkit.load_dataset(data)
142+
143+
assert vd.format_name == "tsinfer-variant-data"
144+
assert vd.format_version == (0, 1)
145+
assert vd.finalised
146+
assert vd.sequence_length == ts.sequence_length + 1337
147+
assert vd.num_sites == ts.num_sites
148+
assert vd.sites_metadata_schema == ts.tables.sites.metadata_schema.schema
149+
assert vd.sites_metadata == [site.metadata for site in ts.sites()]
150+
assert np.array_equal(vd.sites_time, np.arange(ts.num_sites) / ts.num_sites)
151+
assert np.array_equal(vd.sites_position, ts.tables.sites.position)
152+
for alleles, v in zip(vd.sites_alleles, ts.variants()):
155153
# sgkit alleles are padded to be rectangular
156154
assert np.all(alleles[: len(v.alleles)] == v.alleles)
157155
assert np.all(alleles[len(v.alleles) :] == "")
158-
assert np.array_equal(samples.sites_select, np.ones(ts.num_sites, dtype=bool))
156+
assert np.array_equal(vd.sites_select, np.ones(ts.num_sites, dtype=bool))
159157
assert np.array_equal(
160-
samples.sites_ancestral_allele, np.zeros(ts.num_sites, dtype=np.int8)
158+
vd.sites_ancestral_allele, np.zeros(ts.num_sites, dtype=np.int8)
161159
)
162-
assert np.array_equal(samples.sites_genotypes, ts.genotype_matrix())
160+
assert np.array_equal(vd.sites_genotypes, ts.genotype_matrix())
163161
assert np.array_equal(
164-
samples.provenances_timestamp, ["2021-01-01T00:00:00", "2021-01-02T00:00:00"]
162+
vd.provenances_timestamp, ["2021-01-01T00:00:00", "2021-01-02T00:00:00"]
165163
)
166-
assert samples.provenances_record == [{"foo": 1}, {"foo": 2}]
167-
assert samples.num_samples == ts.num_samples
164+
assert vd.provenances_record == [{"foo": 1}, {"foo": 2}]
165+
assert vd.num_samples == ts.num_samples
168166
assert np.array_equal(
169-
samples.samples_individual, np.repeat(np.arange(ts.num_samples // 3), 3)
167+
vd.samples_individual, np.repeat(np.arange(ts.num_samples // 3), 3)
170168
)
171-
assert samples.metadata_schema == tsutil.example_schema("example").schema
172-
assert samples.metadata == ts.tables.metadata
169+
assert vd.metadata_schema == tsutil.example_schema("example").schema
170+
assert vd.metadata == ts.tables.metadata
173171
assert (
174-
samples.populations_metadata_schema
175-
== ts.tables.populations.metadata_schema.schema
172+
vd.populations_metadata_schema == ts.tables.populations.metadata_schema.schema
176173
)
177-
assert samples.populations_metadata == [pop.metadata for pop in ts.populations()]
178-
assert samples.num_individuals == ts.num_individuals
174+
assert vd.populations_metadata == [pop.metadata for pop in ts.populations()]
175+
assert vd.num_individuals == ts.num_individuals
179176
assert np.array_equal(
180-
samples.individuals_time, np.arange(ts.num_individuals, dtype=np.float32)
177+
vd.individuals_time, np.arange(ts.num_individuals, dtype=np.float32)
181178
)
182179
assert (
183-
samples.individuals_metadata_schema
184-
== ts.tables.individuals.metadata_schema.schema
180+
vd.individuals_metadata_schema == ts.tables.individuals.metadata_schema.schema
185181
)
186-
assert samples.individuals_metadata == [
182+
assert vd.individuals_metadata == [
187183
{"variant_data_sample_id": sample_id, **ind.metadata}
188-
for ind, sample_id in zip(ts.individuals(), ds["sample_id"].values)
184+
for ind, sample_id in zip(ts.individuals(), ds.sample_id[:])
189185
]
190186
assert np.array_equal(
191-
samples.individuals_location,
187+
vd.individuals_location,
192188
np.tile(np.array([["0", "1"]], dtype="float32"), (ts.num_individuals, 1)),
193189
)
194190
assert np.array_equal(
195-
samples.individuals_population, np.zeros(ts.num_individuals, dtype="int32")
191+
vd.individuals_population, np.zeros(ts.num_individuals, dtype="int32")
196192
)
197193
assert np.array_equal(
198-
samples.individuals_flags,
194+
vd.individuals_flags,
199195
np.random.RandomState(42).randint(
200196
0, 2_000_000, ts.num_individuals, dtype="int32"
201197
),
202198
)
203199

204200
# Need to shuffle for the ancestral allele test
205-
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path, add_optional=True)
206-
samples = tsinfer.VariantData(zarr_path, "variant_ancestral_allele")
201+
ts, data = tsutil.make_ts_and_zarr(path, add_optional=True)
202+
vd = tsinfer.VariantData(data, "variant_ancestral_allele")
207203
for i in range(ts.num_sites):
208204
assert (
209-
samples.sites_alleles[i][samples.sites_ancestral_allele[i]]
205+
vd.sites_alleles[i][vd.sites_ancestral_allele[i]]
210206
== ts.site(i).ancestral_state
211207
)
212208

213209

214210
@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows")
215-
def test_sgkit_accessors_defaults(tmp_path):
216-
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
217-
samples = tsinfer.VariantData(zarr_path, "variant_ancestral_allele")
218-
ds = sgkit.load_dataset(zarr_path)
211+
@pytest.mark.parametrize("in_mem", [True, False])
212+
def test_variantdata_accessors_defaults(tmp_path, in_mem):
213+
path = None if in_mem else tmp_path
214+
ts, data = tsutil.make_ts_and_zarr(path)
215+
vdata = tsinfer.VariantData(data, "variant_ancestral_allele")
216+
ds = data if in_mem else sgkit.load_dataset(data)
219217

220218
default_schema = tskit.MetadataSchema.permissive_json().schema
221-
assert samples.sequence_length == ts.sequence_length
222-
assert samples.sites_metadata_schema == default_schema
223-
assert samples.sites_metadata == [{} for _ in range(ts.num_sites)]
224-
for time in samples.sites_time:
219+
assert vdata.sequence_length == ts.sequence_length
220+
assert vdata.sites_metadata_schema == default_schema
221+
assert vdata.sites_metadata == [{} for _ in range(ts.num_sites)]
222+
for time in vdata.sites_time:
225223
assert tskit.is_unknown_time(time)
226-
assert np.array_equal(samples.sites_select, np.ones(ts.num_sites, dtype=bool))
227-
assert np.array_equal(samples.provenances_timestamp, [])
228-
assert np.array_equal(samples.provenances_record, [])
229-
assert samples.metadata_schema == default_schema
230-
assert samples.metadata == {}
231-
assert samples.populations_metadata_schema == default_schema
232-
assert samples.populations_metadata == []
233-
assert samples.individuals_metadata_schema == default_schema
234-
assert samples.individuals_metadata == [
235-
{"variant_data_sample_id": sample_id} for sample_id in ds["sample_id"].values
224+
assert np.array_equal(vdata.sites_select, np.ones(ts.num_sites, dtype=bool))
225+
assert np.array_equal(vdata.provenances_timestamp, [])
226+
assert np.array_equal(vdata.provenances_record, [])
227+
assert vdata.metadata_schema == default_schema
228+
assert vdata.metadata == {}
229+
assert vdata.populations_metadata_schema == default_schema
230+
assert vdata.populations_metadata == []
231+
assert vdata.individuals_metadata_schema == default_schema
232+
assert vdata.individuals_metadata == [
233+
{"variant_data_sample_id": sample_id} for sample_id in ds.sample_id[:]
236234
]
237-
for time in samples.individuals_time:
235+
for time in vdata.individuals_time:
238236
assert tskit.is_unknown_time(time)
239237
assert np.array_equal(
240-
samples.individuals_location, np.array([[]] * ts.num_individuals, dtype=float)
238+
vdata.individuals_location, np.array([[]] * ts.num_individuals, dtype=float)
241239
)
242240
assert np.array_equal(
243-
samples.individuals_population, np.full(ts.num_individuals, tskit.NULL)
241+
vdata.individuals_population, np.full(ts.num_individuals, tskit.NULL)
244242
)
245243
assert np.array_equal(
246-
samples.individuals_flags, np.zeros(ts.num_individuals, dtype=int)
244+
vdata.individuals_flags, np.zeros(ts.num_individuals, dtype=int)
247245
)
248246

249247

250248
@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows")
251-
def test_variantdata_sites_time_default(tmp_path):
252-
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
253-
samples = tsinfer.VariantData(zarr_path, "variant_ancestral_allele")
249+
def test_variantdata_sites_time_default():
250+
ts, data = tsutil.make_ts_and_zarr()
251+
vdata = tsinfer.VariantData(data, "variant_ancestral_allele")
254252

255253
assert (
256-
np.all(np.isnan(samples.sites_time))
257-
and samples.sites_time.size == samples.num_sites
254+
np.all(np.isnan(vdata.sites_time)) and vdata.sites_time.size == vdata.num_sites
258255
)
259256

260257

261258
@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows")
262-
def test_variantdata_sites_time_array(tmp_path):
263-
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
259+
def test_variantdata_sites_time_array():
260+
ts, data = tsutil.make_ts_and_zarr()
264261
sites_time = np.arange(ts.num_sites)
265-
samples = tsinfer.VariantData(
266-
zarr_path, "variant_ancestral_allele", sites_time=sites_time
267-
)
268-
assert np.array_equal(samples.sites_time, sites_time)
262+
vdata = tsinfer.VariantData(data, "variant_ancestral_allele", sites_time=sites_time)
263+
assert np.array_equal(vdata.sites_time, sites_time)
269264
wrong_length_sites_time = np.arange(ts.num_sites + 1)
270265
with pytest.raises(
271266
ValueError,
272267
match="Sites time array must be the same length as the number of selected sites",
273268
):
274269
tsinfer.VariantData(
275-
zarr_path,
270+
data,
276271
"variant_ancestral_allele",
277272
sites_time=wrong_length_sites_time,
278273
)
@@ -302,17 +297,17 @@ def test_sgkit_variant_mask(self, tmp_path, sites):
302297
for i in sites:
303298
sites_mask[i] = False
304299
tsutil.add_array_to_dataset("variant_mask_42", sites_mask, zarr_path)
305-
samples = tsinfer.VariantData(
300+
vdata = tsinfer.VariantData(
306301
zarr_path,
307302
"variant_ancestral_allele",
308303
site_mask="variant_mask_42",
309304
)
310-
assert samples.num_sites == len(sites)
311-
assert np.array_equal(samples.sites_select, ~sites_mask)
305+
assert vdata.num_sites == len(sites)
306+
assert np.array_equal(vdata.sites_select, ~sites_mask)
312307
assert np.array_equal(
313-
samples.sites_position, ts.tables.sites.position[~sites_mask]
308+
vdata.sites_position, ts.tables.sites.position[~sites_mask]
314309
)
315-
inf_ts = tsinfer.infer(samples)
310+
inf_ts = tsinfer.infer(vdata)
316311
assert np.array_equal(
317312
ts.genotype_matrix()[~sites_mask], inf_ts.genotype_matrix()
318313
)
@@ -675,6 +670,14 @@ def test_sgkit_ancestor(small_sd_fixture, tmp_path):
675670

676671

677672
class TestVariantDataErrors:
673+
def test_bad_zarr_spec(self):
674+
ds = zarr.group()
675+
ds["call_genotype"] = zarr.array(np.zeros(10, dtype=np.int8))
676+
with pytest.raises(
677+
ValueError, match="Expecting a VCF Zarr object with 3D call_genotype array"
678+
):
679+
tsinfer.VariantData(ds, np.zeros(10, dtype="<U1"))
680+
678681
def test_missing_phase(self, tmp_path):
679682
path = tmp_path / "data.zarr"
680683
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3)

tests/tsutil.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,15 @@
2020
Extra utility functions used in several test files
2121
"""
2222
import json
23+
import tempfile
24+
from pathlib import Path
2325

2426
import msprime
2527
import numpy as np
2628
import sgkit
2729
import tskit
2830
import xarray as xr
31+
import zarr
2932

3033
import tsinfer
3134

@@ -219,7 +222,23 @@ def add_attribute_to_dataset(name, contents, zarr_path):
219222
sgkit.save_dataset(ds, zarr_path, mode="a")
220223

221224

222-
def make_ts_and_zarr(path, add_optional=False, shuffle_alleles=True):
225+
def make_ts_and_zarr(path=None, add_optional=False, shuffle_alleles=True):
226+
if path is None:
227+
in_mem_copy = zarr.group()
228+
with tempfile.TemporaryDirectory() as path:
229+
ts, zarr_path = _make_ts_and_zarr(
230+
Path(path), add_optional=add_optional, shuffle_alleles=shuffle_alleles
231+
)
232+
# For testing only, return an in-memory copy of the dataset we just made
233+
zarr.convenience.copy_all(zarr.open(zarr_path), in_mem_copy)
234+
return ts, in_mem_copy
235+
else:
236+
return _make_ts_and_zarr(
237+
path, add_optional=add_optional, shuffle_alleles=shuffle_alleles
238+
)
239+
240+
241+
def _make_ts_and_zarr(path, add_optional=False, shuffle_alleles=True):
223242
import sgkit.io.vcf
224243

225244
ts = msprime.sim_ancestry(

tsinfer/formats.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2299,15 +2299,27 @@ class VariantData(SampleData):
22992299

23002300
def __init__(
23012301
self,
2302-
path,
2302+
path_or_zarr,
23032303
ancestral_allele,
23042304
*,
23052305
sample_mask=None,
23062306
site_mask=None,
23072307
sites_time=None,
23082308
):
2309-
self.path = path
2310-
self.data = zarr.open(path, mode="r")
2309+
try:
2310+
if len(path_or_zarr.call_genotype.shape) == 3:
2311+
# Assumed to be a VCF Zarr hierarchy
2312+
self.path = None
2313+
self.data = path_or_zarr
2314+
else:
2315+
raise ValueError(
2316+
"Expecting a VCF Zarr object with 3D call_genotype array: "
2317+
"see https://github.com/sgkit-dev/vcf-zarr-spec/"
2318+
)
2319+
except AttributeError:
2320+
self.path = path_or_zarr
2321+
self.data = zarr.open(path_or_zarr, mode="r")
2322+
23112323
genotypes_arr = self.data["call_genotype"]
23122324
_, self._num_individuals_before_mask, self.ploidy = genotypes_arr.shape
23132325

0 commit comments

Comments
 (0)