Skip to content

Commit fa78171

Browse files
Some refactoring
1 parent c72c494 commit fa78171

File tree

1 file changed

+83
-38
lines changed

1 file changed

+83
-38
lines changed

python/tests/test_haplotype_matching.py

+83-38
Original file line numberDiff line numberDiff line change
@@ -436,13 +436,13 @@ def update_probabilities(self, site, haplotype_state):
436436
def process_site(self, site, haplotype_state):
437437
self.update_probabilities(site, haplotype_state)
438438
# d1 = self.node_values()
439-
print("PRE")
440-
self.print_state()
439+
# print("PRE")
440+
# self.print_state()
441441
self.compress()
442442
# d2 = self.node_values()
443443
# assert d1 == d2
444-
print("AFTER COMPRESS")
445-
self.print_state()
444+
# print("AFTER COMPRESS")
445+
# self.print_state()
446446
s = self.compute_normalisation_factor()
447447
for st in self.T:
448448
assert st.tree_node != tskit.NULL
@@ -493,13 +493,13 @@ def run(self, h):
493493
self.initialise(1 / n)
494494
while self.tree.next():
495495
self.update_tree()
496-
if self.tree.index != 0:
497-
print("AFTER UPDATE TREE")
498-
self.print_state()
496+
# if self.tree.index != 0:
497+
# print("AFTER UPDATE TREE")
498+
# self.print_state()
499499
for site in self.tree.sites():
500500
self.process_site(site, h[site.id])
501-
print("BEFORE UPDATE TREE")
502-
self.print_state()
501+
# print("BEFORE UPDATE TREE")
502+
# self.print_state()
503503
return self.output
504504

505505
def compute_normalisation_factor(self):
@@ -1197,6 +1197,7 @@ def check_viterbi(
11971197
recombination=None,
11981198
mutation=None,
11991199
match_all_nodes=False,
1200+
compare_fm_ll=True,
12001201
compare_lib=True,
12011202
compare_lshmm=None,
12021203
):
@@ -1220,12 +1221,28 @@ def check_viterbi(
12201221
cm = ls_viterbi_tree(
12211222
h, ts, rho=recombination, mu=mutation, match_all_nodes=match_all_nodes
12221223
)
1223-
cm.print_state()
1224+
# cm.print_state()
12241225
path_tree = cm.traceback(match_all_nodes=match_all_nodes)
12251226
ll_tree = np.sum(np.log10(cm.normalisation_factor))
12261227
assert np.isscalar(ll_tree)
12271228
# print("path tree = ", path_tree)
12281229

1230+
if compare_fm_ll:
1231+
# Compare the log-likelihood of the Viterbi path (ll_tree)
1232+
# with the log-likelihood of the most likely path from
1233+
# the forward matrix.
1234+
fm = ls_forward_tree(
1235+
h,
1236+
ts,
1237+
recombination,
1238+
mutation,
1239+
scale_mutation_based_on_n_alleles=False,
1240+
match_all_nodes=match_all_nodes,
1241+
)
1242+
ll_fm = np.sum(np.log10(fm.normalisation_factor))
1243+
print("FMLL", ll_tree, ll_fm)
1244+
# np.testing.assert_allclose(ll_tree, ll_fm)
1245+
12291246
if compare_lshmm:
12301247
# Check that the likelihood of the preferred path is
12311248
# the same as ll_tree (and ll).
@@ -1239,6 +1256,8 @@ def check_viterbi(
12391256
scale_mutation_based_on_n_alleles=False,
12401257
)
12411258
assert np.isscalar(ll)
1259+
# This is the log likelihood returned by viterbi alg
1260+
nt.assert_allclose(ll_tree, ll)
12421261
# print()
12431262
# print("ls path = ", path)
12441263
ll_check = ls.path_ll(
@@ -1249,7 +1268,9 @@ def check_viterbi(
12491268
p_mutation=mutation,
12501269
scale_mutation_based_on_n_alleles=False,
12511270
)
1252-
nt.assert_allclose(ll_tree, ll)
1271+
# This is the log-likelihood of the path itself, computed
1272+
# different way
1273+
nt.assert_allclose(ll_tree, ll_check)
12531274

12541275
if compare_lib:
12551276
nt.assert_allclose(ll_check, ll)
@@ -1267,7 +1288,6 @@ def check_viterbi(
12671288
return path_tree
12681289

12691290

1270-
# TODO add params to run the various checks
12711291
def check_forward_matrix(
12721292
ts,
12731293
h,
@@ -1319,8 +1339,9 @@ def check_forward_matrix(
13191339
assert c.shape == (m,)
13201340
assert np.isscalar(ll)
13211341

1322-
# print(F)
1323-
# print(F2)
1342+
print(ll_tree)
1343+
print(F)
1344+
print(F2)
13241345
nt.assert_allclose(F, F2)
13251346
nt.assert_allclose(c, cm.normalisation_factor)
13261347
nt.assert_allclose(ll_tree, ll)
@@ -1447,8 +1468,7 @@ def test_match_sample(self, j):
14471468
h[j] = 1
14481469
path = check_viterbi(ts, h)
14491470
nt.assert_array_equal([j, j, j, j], path)
1450-
cm = check_forward_matrix(ts, h)
1451-
check_backward_matrix(ts, h, cm)
1471+
check_fb_matrices(ts, h)
14521472

14531473
@pytest.mark.parametrize("j", [1, 2])
14541474
def test_match_sample_missing_flanks(self, j):
@@ -1459,16 +1479,14 @@ def test_match_sample_missing_flanks(self, j):
14591479
h[j] = 1
14601480
path = check_viterbi(ts, h)
14611481
nt.assert_array_equal([j, j, j, j], path)
1462-
cm = check_forward_matrix(ts, h)
1463-
check_backward_matrix(ts, h, cm)
1482+
check_fb_matrices(ts, h)
14641483

14651484
def test_switch_each_sample(self):
14661485
ts = self.ts()
14671486
h = np.ones(4)
14681487
path = check_viterbi(ts, h)
14691488
nt.assert_array_equal([0, 1, 2, 3], path)
1470-
cm = check_forward_matrix(ts, h)
1471-
check_backward_matrix(ts, h, cm)
1489+
check_fb_matrices(ts, h)
14721490

14731491
def test_switch_each_sample_missing_flanks(self):
14741492
ts = self.ts()
@@ -1477,8 +1495,7 @@ def test_switch_each_sample_missing_flanks(self):
14771495
h[-1] = -1
14781496
path = check_viterbi(ts, h)
14791497
nt.assert_array_equal([1, 1, 2, 2], path)
1480-
cm = check_forward_matrix(ts, h)
1481-
check_backward_matrix(ts, h, cm)
1498+
check_fb_matrices(ts, h)
14821499

14831500
def test_switch_each_sample_missing_middle(self):
14841501
ts = self.ts()
@@ -1487,8 +1504,7 @@ def test_switch_each_sample_missing_middle(self):
14871504
path = check_viterbi(ts, h)
14881505
# Implementation of Viterbi switches at right-most position
14891506
nt.assert_array_equal([0, 0, 0, 3], path)
1490-
cm = check_forward_matrix(ts, h)
1491-
check_backward_matrix(ts, h, cm)
1507+
check_fb_matrices(ts, h)
14921508

14931509

14941510
class TestSingleBalancedTreeAllSamplesExample:
@@ -1525,25 +1541,54 @@ def test_match_sample(self, u, h):
15251541
ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=True
15261542
)
15271543
nt.assert_array_equal([u] * 7, path)
1528-
cm = check_forward_matrix(
1544+
fm = check_forward_matrix(
15291545
ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=True
15301546
)
1531-
check_backward_matrix(
1532-
ts, h, cm, match_all_nodes=True, compare_lib=False, compare_lshmm=True
1547+
bm = check_backward_matrix(
1548+
ts, h, fm, match_all_nodes=True, compare_lib=False, compare_lshmm=True
15331549
)
1550+
check_fb_matrix_integrity(fm, bm)
1551+
1552+
1553+
def check_fb_matrix_integrity(fm, bm):
1554+
"""
1555+
Validate properties of the forward and backward matrices.
1556+
"""
1557+
F = fm.decode()
1558+
B = bm.decode()
1559+
assert F.shape == B.shape
1560+
for j in range(len(F)):
1561+
s = np.sum(B[j] * F[j])
1562+
np.testing.assert_allclose(s, 1)
1563+
1564+
1565+
def check_fb_matrices(ts, h):
1566+
fm = check_forward_matrix(ts, h)
1567+
bm = check_backward_matrix(ts, h, fm)
1568+
check_fb_matrix_integrity(fm, bm)
15341569

15351570

15361571
def validate_match_all_nodes(ts, h, expected_path):
1537-
path = check_viterbi(
1538-
ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=False
1539-
)
1540-
nt.assert_array_equal(expected_path, path)
1541-
cm = check_forward_matrix(
1572+
# path = check_viterbi(
1573+
# ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=False
1574+
# )
1575+
# nt.assert_array_equal(expected_path, path)
1576+
fm = check_forward_matrix(
15421577
ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=False
15431578
)
1579+
F = fm.decode()
1580+
# print(cm.decode())
1581+
# cm.print_state()
15441582
bm = check_backward_matrix(
1545-
ts, h, cm, match_all_nodes=True, compare_lib=False, compare_lshmm=False
1583+
ts, h, fm, match_all_nodes=True, compare_lib=False, compare_lshmm=False
15461584
)
1585+
print("sites = ", ts.num_sites)
1586+
B = bm.decode()
1587+
print(F)
1588+
for j in range(ts.num_sites):
1589+
print(j, np.sum(B[j] * F[j]))
1590+
1591+
# sum(B[variant,:] * F[variant,:]) = 1
15471592

15481593

15491594
class TestSingleBalancedTreeAllNodesExample:
@@ -1640,11 +1685,11 @@ def ts():
16401685
[
16411686
# Just samples
16421687
([1, 0, 0, 0, 0, 1, 1], [0] * 7),
1643-
([0, 1, 0, 0, 1, 1, 0], [1] * 7),
1644-
([0, 0, 1, 0, 1, 1, 0], [2] * 7),
1645-
([0, 0, 0, 1, 0, 0, 1], [3] * 7),
1646-
# Match root
1647-
([0, 0, 0, 0, 0, 0, 0], [7] * 7),
1688+
# ([0, 1, 0, 0, 1, 1, 0], [1] * 7),
1689+
# ([0, 0, 1, 0, 1, 1, 0], [2] * 7),
1690+
# ([0, 0, 0, 1, 0, 0, 1], [3] * 7),
1691+
# # Match root
1692+
# ([0, 0, 0, 0, 0, 0, 0], [7] * 7),
16481693
],
16491694
)
16501695
def test_match_all_nodes(self, h, expected_path):

0 commit comments

Comments
 (0)