Skip to content

Commit f15abee

Browse files
benjeffery authored and mergify[bot] committed
Add match_samples_to_disk function
1 parent 93e386e commit f15abee

File tree

3 files changed

+256
-73
lines changed

3 files changed

+256
-73
lines changed

tests/test_sgkit.py

+57
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
Tests for the data files.
2121
"""
2222
import json
23+
import os
24+
import pickle
2325
import sys
2426
import tempfile
2527

@@ -615,3 +617,58 @@ def test_empty_alleles_not_at_end(self, tmp_path):
615617
samples = tsinfer.SgkitSampleData(path)
616618
with pytest.raises(ValueError, match="Empty alleles must be at the end"):
617619
tsinfer.infer(samples)
620+
621+
622+
class TestSgkitMatchSamplesToDisk:
    """
    Tests for ``tsinfer.match_samples_slice_to_disk`` and for reading the
    resulting files back through ``match_samples(match_file_pattern=...)``.
    """

    @pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows")
    @pytest.mark.parametrize("slice", [(0, 5), (0, 0), (0, 1), (10, 15)])
    def test_match_samples_to_disk_write(
        self, slice, small_sd_fixture, tmp_path, tmpdir
    ):
        """
        Matching a slice of samples writes a pickled ``(slice, matches)``
        pair containing one ``MatchResult`` per sample index in the slice.
        """
        _, zarr_path = make_ts_and_zarr(tmp_path)
        samples = tsinfer.SgkitSampleData(zarr_path)
        ancestors = tsinfer.generate_ancestors(samples)
        anc_ts = tsinfer.match_ancestors(samples, ancestors)

        match_path = tmpdir / "test.path"
        tsinfer.match_samples_slice_to_disk(samples, anc_ts, slice, match_path)

        # The file was produced by the call just above, so unpickling it is
        # safe. Use a context manager so the handle is closed promptly
        # instead of being left for the garbage collector.
        with open(match_path, "rb") as f:
            file_slice, matches = pickle.load(f)

        assert slice == file_slice
        # Holds for the empty (0, 0) slice too: zero matches are expected.
        assert len(matches) == slice[1] - slice[0]
        for m in matches:
            assert isinstance(m, tsinfer.inference.MatchResult)

    @pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows")
    def test_match_samples_to_disk_full(self, small_sd_fixture, tmp_path, tmpdir):
        """
        Matching all samples in batches of five via files on disk gives the
        same tables as a direct ``match_samples`` call, and duplicated or
        missing sample indexes in the match files raise ``ValueError``.
        """
        ts, zarr_path = make_ts_and_zarr(tmp_path)
        samples = tsinfer.SgkitSampleData(zarr_path)
        ancestors = tsinfer.generate_ancestors(samples)
        anc_ts = tsinfer.match_ancestors(samples, ancestors)
        ts = tsinfer.match_samples(samples, anc_ts)

        # Write one match file per batch of (at most) five samples.
        start_index = 0
        while start_index < ts.num_samples:
            end_index = min(start_index + 5, ts.num_samples)
            tsinfer.match_samples_slice_to_disk(
                samples,
                anc_ts,
                (start_index, end_index),
                tmpdir / f"test-{start_index}.path",
            )
            start_index = end_index

        # Glob pattern picking up every match file written to tmpdir.
        match_file_pattern = str(tmpdir / "*.path")

        batch_ts = tsinfer.match_samples(
            samples, anc_ts, match_file_pattern=match_file_pattern
        )
        ts.tables.assert_equals(batch_ts.tables, ignore_provenance=True)

        # Duplicating a batch file means its sample indexes appear twice.
        # NOTE(review): relies on the data having more than five samples so
        # that "test-5.path" exists - confirm against make_ts_and_zarr.
        tmpdir.join("test-5.path").copy(tmpdir.join("test-5-copy.path"))
        with pytest.raises(ValueError, match="Duplicate sample index 5"):
            tsinfer.match_samples(
                samples, anc_ts, match_file_pattern=match_file_pattern
            )

        # Removing both copies leaves the batch starting at index 5 missing.
        os.remove(tmpdir / "test-5.path")
        os.remove(tmpdir / "test-5-copy.path")
        with pytest.raises(ValueError, match="index 5 not found"):
            tsinfer.match_samples(
                samples, anc_ts, match_file_pattern=match_file_pattern
            )

0 commit comments

Comments
 (0)