Skip to content

Commit d575342

Browse files
Python changes for precision->likelihood_thresold
Also update the Python implementation to match the C one
1 parent 2aa5e27 commit d575342

File tree

3 files changed

+29
-15
lines changed

3 files changed

+29
-15
lines changed

_tsinfermodule.c

+5-5
Original file line numberDiff line numberDiff line change
@@ -1283,21 +1283,21 @@ AncestorMatcher_init(AncestorMatcher *self, PyObject *args, PyObject *kwds)
12831283
int err;
12841284
int extended_checks = 0;
12851285
static char *kwlist[] = {"tree_sequence_builder", "recombination",
1286-
"mismatch", "precision", "extended_checks", NULL};
1286+
"mismatch", "likelihood_threshold", "extended_checks", NULL};
12871287
TreeSequenceBuilder *tree_sequence_builder = NULL;
12881288
PyObject *recombination = NULL;
12891289
PyObject *mismatch = NULL;
12901290
PyArrayObject *recombination_array = NULL;
12911291
PyArrayObject *mismatch_array = NULL;
12921292
npy_intp *shape;
1293-
unsigned int precision = 22;
1293+
double likelihood_threshold = DBL_MIN;
12941294
int flags = 0;
12951295

12961296
self->ancestor_matcher = NULL;
12971297
self->tree_sequence_builder = NULL;
1298-
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!OO|Ii", kwlist,
1298+
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!OO|di", kwlist,
12991299
&TreeSequenceBuilderType, &tree_sequence_builder,
1300-
&recombination, &mismatch, &precision,
1300+
&recombination, &mismatch, &likelihood_threshold,
13011301
&extended_checks)) {
13021302
goto out;
13031303
}
@@ -1343,7 +1343,7 @@ AncestorMatcher_init(AncestorMatcher *self, PyObject *args, PyObject *kwds)
13431343
self->tree_sequence_builder->tree_sequence_builder,
13441344
PyArray_DATA(recombination_array),
13451345
PyArray_DATA(mismatch_array),
1346-
precision, flags);
1346+
likelihood_threshold, flags);
13471347
if (err != 0) {
13481348
handle_library_error(err);
13491349
goto out;

tsinfer/algorithm.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -636,13 +636,13 @@ def __init__(
636636
tree_sequence_builder,
637637
recombination=None,
638638
mismatch=None,
639-
precision=None,
639+
likelihood_threshold=None,
640640
extended_checks=False,
641641
):
642642
self.tree_sequence_builder = tree_sequence_builder
643643
self.mismatch = mismatch
644644
self.recombination = recombination
645-
self.precision = precision
645+
self.likelihood_threshold = likelihood_threshold
646646
self.extended_checks = extended_checks
647647
self.num_sites = tree_sequence_builder.num_sites
648648
self.parent = None
@@ -705,7 +705,8 @@ def unset_allelic_state(self, site):
705705
assert np.all(self.allelic_state == -1)
706706

707707
def update_site(self, site, haplotype_state):
708-
n = self.tree_sequence_builder.num_match_nodes
708+
# n = self.tree_sequence_builder.num_match_nodes
709+
n = 1
709710
rho = self.recombination[site]
710711
mu = self.mismatch[site]
711712
num_alleles = self.tree_sequence_builder.num_alleles[site]
@@ -763,13 +764,13 @@ def update_site(self, site, haplotype_state):
763764
elif rho == 0:
764765
raise _tsinfer.MatchImpossible(
765766
"Matching failed with recombination=0, potentially due to "
766-
"rounding issues. Try increasing the precision value"
767+
"rounding issues. Try increasing the likelihood_threshold value"
767768
)
768769
raise AssertionError("Unexpected matching failure")
769770

770771
for u in self.likelihood_nodes:
771772
x = self.likelihood[u] / max_L
772-
self.likelihood[u] = round(x, self.precision)
773+
self.likelihood[u] = max(x, self.likelihood_threshold)
773774

774775
self.max_likelihood_node[site] = max_L_node
775776
self.unset_allelic_state(site)

tsinfer/inference.py

+18-5
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,7 @@ def infer(
269269
num_threads=0,
270270
# Deliberately undocumented parameters below
271271
precision=None,
272+
likelihood_threshold=None,
272273
engine=constants.C_ENGINE,
273274
progress_monitor=None,
274275
time_units=None,
@@ -349,6 +350,7 @@ def infer(
349350
recombination_rate=recombination_rate,
350351
mismatch_ratio=mismatch_ratio,
351352
precision=precision,
353+
likelihood_threshold=likelihood_threshold,
352354
path_compression=path_compression,
353355
progress_monitor=progress_monitor,
354356
time_units=time_units,
@@ -362,6 +364,7 @@ def infer(
362364
recombination_rate=recombination_rate,
363365
mismatch_ratio=mismatch_ratio,
364366
precision=precision,
367+
likelihood_threshold=likelihood_threshold,
365368
post_process=post_process,
366369
path_compression=path_compression,
367370
progress_monitor=progress_monitor,
@@ -457,6 +460,7 @@ def match_ancestors(
457460
recombination=None, # See :class:`Matcher`
458461
mismatch=None, # See :class:`Matcher`
459462
precision=None,
463+
likelihood_threshold=None,
460464
engine=constants.C_ENGINE,
461465
progress_monitor=None,
462466
extended_checks=False,
@@ -514,6 +518,7 @@ def match_ancestors(
514518
path_compression=path_compression,
515519
num_threads=num_threads,
516520
precision=precision,
521+
likelihood_threshold=likelihood_threshold,
517522
extended_checks=extended_checks,
518523
engine=engine,
519524
progress_monitor=progress_monitor,
@@ -639,6 +644,7 @@ def match_samples(
639644
recombination=None, # See :class:`Matcher`
640645
mismatch=None, # See :class:`Matcher`
641646
precision=None,
647+
likelihood_threshold=None,
642648
extended_checks=False,
643649
engine=constants.C_ENGINE,
644650
progress_monitor=None,
@@ -723,6 +729,7 @@ def match_samples(
723729
path_compression=path_compression,
724730
num_threads=num_threads,
725731
precision=precision,
732+
likelihood_threshold=likelihood_threshold,
726733
extended_checks=extended_checks,
727734
engine=engine,
728735
progress_monitor=progress_monitor,
@@ -1141,6 +1148,7 @@ def __init__(
11411148
recombination=None,
11421149
mismatch=None,
11431150
precision=None,
1151+
likelihood_threshold=None,
11441152
extended_checks=False,
11451153
engine=constants.C_ENGINE,
11461154
progress_monitor=None,
@@ -1233,11 +1241,16 @@ def __init__(
12331241
if not (np.all(mismatch >= 0) and np.all(mismatch <= 1)):
12341242
raise ValueError("Underlying mismatch probabilities not between 0 & 1")
12351243

1236-
if precision is None:
1237-
precision = 13
1244+
if precision is not None and likelihood_threshold is not None:
1245+
raise ValueError("Cannot specify likelihood_threshold and precision")
1246+
if precision is not None:
1247+
likelihood_threshold = pow(10, -precision)
1248+
if likelihood_threshold is None:
1249+
likelihood_threshold = 1e-13 # ~Same as previous precision default.
1250+
12381251
self.recombination[1:] = recombination
12391252
self.mismatch[:] = mismatch
1240-
self.precision = precision
1253+
self.likelihood_threshold = likelihood_threshold
12411254

12421255
if len(recombination) == 0:
12431256
logger.info("Fewer than two inference sites: no recombination possible")
@@ -1261,7 +1274,7 @@ def __init__(
12611274
f"mean={np.mean(mismatch):.5g}"
12621275
)
12631276
logger.info(
1264-
f"Matching using {precision} digits of precision in likelihood calcs"
1277+
f"Matching using likelihood_threshold of {likelihood_threshold:.5g}"
12651278
)
12661279

12671280
if engine == constants.C_ENGINE:
@@ -1303,7 +1316,7 @@ def __init__(
13031316
self.tree_sequence_builder,
13041317
recombination=self.recombination,
13051318
mismatch=self.mismatch,
1306-
precision=precision,
1319+
likelihood_threshold=likelihood_threshold,
13071320
extended_checks=self.extended_checks,
13081321
)
13091322
for _ in range(num_threads)

0 commit comments

Comments
 (0)