|
20 | 20 | Tests for the data files.
|
21 | 21 | """
|
22 | 22 | import json
|
| 23 | +import os |
| 24 | +import pickle |
23 | 25 | import sys
|
24 | 26 | import tempfile
|
25 | 27 |
|
@@ -615,3 +617,58 @@ def test_empty_alleles_not_at_end(self, tmp_path):
|
615 | 617 | samples = tsinfer.SgkitSampleData(path)
|
616 | 618 | with pytest.raises(ValueError, match="Empty alleles must be at the end"):
|
617 | 619 | tsinfer.infer(samples)
|
| 620 | + |
| 621 | + |
class TestSgkitMatchSamplesToDisk:
    """
    Tests for ``tsinfer.match_samples_slice_to_disk`` and for reading the
    files it writes back in via ``match_samples(match_file_pattern=...)``.
    """

    @pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows")
    @pytest.mark.parametrize("slice", [(0, 5), (0, 0), (0, 1), (10, 15)])
    def test_match_samples_to_disk_write(
        self, slice, small_sd_fixture, tmp_path, tmpdir
    ):
        """
        Write the matches for a single sample slice to disk and check the
        pickled payload: the stored slice equals the requested one, there is
        one match per sample index in the slice, and every match is a
        ``MatchResult``.

        NOTE(review): ``small_sd_fixture`` is requested but never used in the
        body; it is kept so fixture resolution is unchanged.
        """
        # Only the zarr path is needed here; the simulated tree sequence
        # returned alongside it is not used.
        _, zarr_path = make_ts_and_zarr(tmp_path)
        samples = tsinfer.SgkitSampleData(zarr_path)
        ancestors = tsinfer.generate_ancestors(samples)
        anc_ts = tsinfer.match_ancestors(samples, ancestors)
        match_path = tmpdir / "test.path"
        tsinfer.match_samples_slice_to_disk(samples, anc_ts, slice, match_path)
        # Use a context manager so the file handle is closed deterministically
        # instead of being leaked until garbage collection.
        with open(match_path, "rb") as match_file:
            file_slice, matches = pickle.load(match_file)
        assert slice == file_slice
        assert len(matches) == slice[1] - slice[0]
        for m in matches:
            assert isinstance(m, tsinfer.inference.MatchResult)

    @pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows")
    def test_match_samples_to_disk_full(self, small_sd_fixture, tmp_path, tmpdir):
        """
        Match every sample in batches of at most five, writing each batch to
        its own file, then check that ``match_samples`` given the resulting
        file pattern produces the same tables (ignoring provenance) as a
        direct ``match_samples`` call. Afterwards check that a duplicated
        batch file and a missing batch file each raise ``ValueError``.

        NOTE(review): ``small_sd_fixture`` is requested but never used in the
        body; it is kept so fixture resolution is unchanged.
        """
        _, zarr_path = make_ts_and_zarr(tmp_path)
        samples = tsinfer.SgkitSampleData(zarr_path)
        ancestors = tsinfer.generate_ancestors(samples)
        anc_ts = tsinfer.match_ancestors(samples, ancestors)
        # Reference result from matching all samples in one go.
        ts = tsinfer.match_samples(samples, anc_ts)
        match_file_pattern = str(tmpdir / "*.path")

        # Write consecutive, non-overlapping slices covering all samples.
        start_index = 0
        while start_index < ts.num_samples:
            end_index = min(start_index + 5, ts.num_samples)
            tsinfer.match_samples_slice_to_disk(
                samples,
                anc_ts,
                (start_index, end_index),
                tmpdir / f"test-{start_index}.path",
            )
            start_index = end_index
        batch_ts = tsinfer.match_samples(
            samples, anc_ts, match_file_pattern=match_file_pattern
        )
        ts.tables.assert_equals(batch_ts.tables, ignore_provenance=True)

        # A second copy of the batch starting at index 5 also matches the
        # glob, so the same sample index is supplied twice.
        tmpdir.join("test-5.path").copy(tmpdir.join("test-5-copy.path"))
        with pytest.raises(ValueError, match="Duplicate sample index 5"):
            tsinfer.match_samples(
                samples, anc_ts, match_file_pattern=match_file_pattern
            )

        # With both copies removed, the batch starting at index 5 is absent.
        os.remove(tmpdir / "test-5.path")
        os.remove(tmpdir / "test-5-copy.path")
        with pytest.raises(ValueError, match="index 5 not found"):
            tsinfer.match_samples(
                samples, anc_ts, match_file_pattern=match_file_pattern
            )
0 commit comments