diff --git a/pyproject.toml b/pyproject.toml index 17c5237..409331a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,9 @@ authors = [ ] requires-python = ">=3.9" dependencies = [ - "tsinfer==0.3.3", # https://github.com/jeromekelleher/sc2ts/issues/201 + # "tsinfer==0.3.3", # https://github.com/jeromekelleher/sc2ts/issues/201 + # FIXME + "tsinfer @ git+https://github.com/jeromekelleher/tsinfer.git@experimental-hmm", "pyfaidx", "tskit>=0.5.3", "tszip", diff --git a/sc2ts/inference.py b/sc2ts/inference.py index f534297..faf4b0c 100644 --- a/sc2ts/inference.py +++ b/sc2ts/inference.py @@ -402,18 +402,18 @@ def match_samples( show_progress=False, num_threads=None, ): - # First pass, compute the matches at precision=0. run_batch = samples - # Values based on https://github.com/jeromekelleher/sc2ts/issues/242, - # but somewhat arbitrary. - for precision, cost_threshold in [(0, 1), (1, 2), (2, 3)]: - logger.info(f"Running batch of {len(run_batch)} at p={precision}") + mu = 0.125 ## FIXME + for k in range(num_mismatches): + # To catch k mismatches we need a likelihood threshold of mu**k + likelihood_threshold = mu**k - 1e-15 + logger.info(f"Running match={k} batch of {len(run_batch)} at threshold={likelihood_threshold}") match_tsinfer( samples=run_batch, ts=base_ts, num_mismatches=num_mismatches, - precision=precision, + likelihood_threshold=likelihood_threshold, num_threads=num_threads, show_progress=show_progress, ) @@ -421,27 +421,26 @@ def match_samples( exceeding_threshold = [] for sample in run_batch: cost = sample.get_hmm_cost(num_mismatches) - logger.debug(f"HMM@p={precision}: hmm_cost={cost} {sample.summary()}") - if cost > cost_threshold: + logger.debug(f"HMM@k={k}: hmm_cost={cost} {sample.summary()}") + if cost > k + 1: sample.path.clear() sample.mutations.clear() exceeding_threshold.append(sample) num_matches_found = len(run_batch) - len(exceeding_threshold) logger.info( - f"{num_matches_found} final matches for found p={precision}; " + f"{num_matches_found} final matches found at k={k}; " f"{len(exceeding_threshold)} remain" ) run_batch = exceeding_threshold - precision = 6 - logger.info(f"Running final batch of {len(run_batch)} at p={precision}") + logger.info(f"Running final batch of {len(run_batch)} at full precision") match_tsinfer( samples=run_batch, ts=base_ts, num_mismatches=num_mismatches, - precision=precision, num_threads=num_threads, + likelihood_threshold=1e-200, show_progress=show_progress, ) for sample in run_batch: @@ -798,7 +797,7 @@ def add_matching_results( return ts # , excluded_samples, added_samples -def solve_num_mismatches(ts, k): +def solve_num_mismatches(k, num_sites, mu=0.125): """ Return the low-level LS parameters corresponding to accepting k mismatches in favour of a single recombination. @@ -806,28 +805,18 @@ def solve_num_mismatches(ts, k): NOTE! This is NOT taking into account the spatial distance along the genome, and so is not a very good model in some ways. """ - # We can match against any node in tsinfer - m = ts.num_sites - n = ts.num_nodes # values of k <= 1 are not relevant for SC2 and lead to awkward corner cases assert k > 1 - # NOTE: the magnitude of mu matters because it puts a limit - # on how low we can push the HMM precision. We should be able to solve - # for the optimal value of this parameter such that the magnitude of the - # values within the HMM are as large as possible (so that we can truncate - # usefully). - # mu = 1e-2 - mu = 0.125 - denom = (1 - mu) ** k + (n - 1) * mu**k - r = n * mu**k / denom + denom = (1 - mu) ** k + r = mu**k / denom # Add a little bit of extra mass for recombination so that we deterministically # chose to recombine over k mutations # NOTE: the magnitude of this value will depend also on mu, see above. - r += r * 0.01 - ls_recomb = np.full(m - 1, r) - ls_mismatch = np.full(m, mu) + r += r * 0.125 + ls_recomb = np.full(num_sites - 1, r) + ls_mismatch = np.full(num_sites, mu) return ls_recomb, ls_mismatch @@ -1268,7 +1257,7 @@ def match_tsinfer( ts, *, num_mismatches, - precision=None, + likelihood_threshold=None, num_threads=0, show_progress=False, mirror_coordinates=False, @@ -1284,7 +1273,7 @@ def match_tsinfer( sd = convert_tsinfer_sample_data(ts, genotypes) L = int(ts.sequence_length) - ls_recomb, ls_mismatch = solve_num_mismatches(ts, num_mismatches) + ls_recomb, ls_mismatch = solve_num_mismatches(num_mismatches, ts.num_sites) pm = tsinfer.inference._get_progress_monitor( show_progress, generate_ancestors=False, @@ -1309,7 +1298,7 @@ def match_tsinfer( mismatch=ls_mismatch, progress_monitor=pm, num_threads=num_threads, - precision=precision, + likelihood_threshold=likelihood_threshold ) results = manager.run_match(np.arange(sd.num_samples)) diff --git a/tests/test_inference.py b/tests/test_inference.py index 9e4550d..ba09948 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -1,4 +1,5 @@ import numpy as np +import numpy.testing as nt import pytest import tsinfer import tskit @@ -8,6 +9,18 @@ import util +class TestSolveNumMismatches: + + @pytest.mark.parametrize( + ["k", "expected_rho"], + [(2, 0.02295918), (3, 0.00327988), (4, 0.00046855), (1000, 0)], + ) + def test_examples(self, k, expected_rho): + rho, mu = sc2ts.solve_num_mismatches(k, num_sites=2) + assert mu[0] == 0.125 + nt.assert_almost_equal(rho[0], expected_rho) + + class TestInitialTs: def test_reference_sequence(self): ts = sc2ts.initial_ts() @@ -612,13 +625,13 @@ def test_node_mutation_counts(self, fx_ts_map, date): "2020-02-03": {"nodes": 36, "mutations": 42}, "2020-02-04": {"nodes": 41, "mutations": 48}, "2020-02-05": {"nodes": 42, "mutations": 48}, - "2020-02-06": {"nodes": 49, "mutations": 51}, - "2020-02-07": {"nodes": 51, "mutations": 57}, - "2020-02-08": {"nodes": 57, "mutations": 58}, - "2020-02-09": {"nodes": 59, "mutations": 61}, - "2020-02-10": {"nodes": 60, "mutations": 65}, - "2020-02-11": {"nodes": 62, "mutations": 66}, - "2020-02-13": {"nodes": 66, "mutations": 68}, + "2020-02-06": {"nodes": 48, "mutations": 51}, + "2020-02-07": {"nodes": 50, "mutations": 57}, + "2020-02-08": {"nodes": 56, "mutations": 58}, + "2020-02-09": {"nodes": 58, "mutations": 61}, + "2020-02-10": {"nodes": 59, "mutations": 65}, + "2020-02-11": {"nodes": 61, "mutations": 66}, + "2020-02-13": {"nodes": 65, "mutations": 68}, } assert ts.num_nodes == expected[date]["nodes"] assert ts.num_mutations == expected[date]["mutations"] @@ -631,9 +644,9 @@ def test_node_mutation_counts(self, fx_ts_map, date): (13, "SRR11597132", 10), (16, "SRR11597177", 10), (41, "SRR11597156", 10), - (57, "SRR11597216", 1), - (60, "SRR11597207", 40), - (62, "ERR4205570", 58), + (56, "SRR11597216", 1), + (59, "SRR11597207", 40), + (61, "ERR4205570", 57), ], ) def test_exact_matches(self, fx_ts_map, node, strain, parent): @@ -693,10 +706,9 @@ class TestMatchingDetails: # assert s.path[0].parent == 37 @pytest.mark.parametrize( - ("strain", "parent"), [("SRR11597207", 40), ("ERR4205570", 58)] + ("strain", "parent"), [("SRR11597207", 40), ("ERR4205570", 57)] ) @pytest.mark.parametrize("num_mismatches", [2, 3, 4]) - @pytest.mark.parametrize("precision", [0, 1, 2, 12]) def test_exact_matches( self, fx_ts_map, @@ -705,17 +717,18 @@ def test_exact_matches( strain, parent, num_mismatches, - precision, ): ts = fx_ts_map["2020-02-10"] samples = sc2ts.preprocess( [fx_metadata_db[strain]], ts, "2020-02-20", fx_alignment_store ) + # FIXME + mu = 0.125 sc2ts.match_tsinfer( samples=samples, ts=ts, num_mismatches=num_mismatches, - precision=precision, + likelihood_threshold = mu**num_mismatches - 1e-12, num_threads=0, ) s = samples[0] @@ -725,10 +738,10 @@ def test_exact_matches( @pytest.mark.parametrize( ("strain", "parent", "position", "derived_state"), - [("SRR11597218", 10, 289, "T"), ("ERR4206593", 58, 26994, "T")], + [("SRR11597218", 10, 289, "T"), ("ERR4206593", 57, 26994, "T")], ) @pytest.mark.parametrize("num_mismatches", [2, 3, 4]) - @pytest.mark.parametrize("precision", [0, 1, 2, 12]) + # @pytest.mark.parametrize("precision", [0, 1, 2, 12]) def test_one_mismatch( self, fx_ts_map, @@ -739,7 +752,6 @@ def test_one_mismatch( position, derived_state, num_mismatches, - precision, ): ts = fx_ts_map["2020-02-10"] samples = sc2ts.preprocess( @@ -749,7 +761,8 @@ def test_one_mismatch( samples=samples, ts=ts, num_mismatches=num_mismatches, - precision=precision, + # FIXME + likelihood_threshold=0.12499999, num_threads=0, ) s = samples[0] @@ -760,30 +773,27 @@ def test_one_mismatch( assert s.path[0].parent == parent @pytest.mark.parametrize("num_mismatches", [2, 3, 4]) - @pytest.mark.parametrize("precision", [0, 1, 2, 12]) def test_two_mismatches( self, fx_ts_map, fx_alignment_store, fx_metadata_db, num_mismatches, - precision, ): strain = "ERR4204459" ts = fx_ts_map["2020-02-10"] samples = sc2ts.preprocess( [fx_metadata_db[strain]], ts, "2020-02-20", fx_alignment_store ) + mu = 0.125 sc2ts.match_tsinfer( samples=samples, ts=ts, num_mismatches=num_mismatches, - precision=precision, + likelihood_threshold=mu**2 - 1e-12, num_threads=0, ) s = samples[0] assert len(s.path) == 1 assert s.path[0].parent == 5 assert len(s.mutations) == 2 - # assert s.mutations[0].site_position == position - # assert s.mutations[0].derived_state == derived_state