Skip to content

Commit dfb4c54

Browse files
committed
Add tests to check LS HMM of tskit compared to BEAGLE
1 parent b6f9872 commit dfb4c54

File tree

1 file changed

+251
-0
lines changed

1 file changed

+251
-0
lines changed

python/tests/test_imputation.py

Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
"""
2+
Tests for genotype imputation (forward and Baum-Welsh algorithms).
3+
"""
4+
import io
5+
6+
import numpy as np
7+
8+
import tskit
9+
10+
11+
# A tree sequence containing 3 diploid individuals with 5 sites and 5 mutations
12+
# (one per site). The first 2 individuals are used as reference panel,
13+
# the last one is the target individual.
14+
15+
_toy_ts_nodes_text = """\
16+
id is_sample time population individual metadata
17+
0 1 0.000000 0 0
18+
1 1 0.000000 0 0
19+
2 1 0.000000 0 1
20+
3 1 0.000000 0 1
21+
4 1 0.000000 0 2
22+
5 1 0.000000 0 2
23+
6 0 0.029768 0 -1
24+
7 0 0.133017 0 -1
25+
8 0 0.223233 0 -1
26+
9 0 0.651586 0 -1
27+
10 0 0.698831 0 -1
28+
11 0 2.114867 0 -1
29+
12 0 4.322031 0 -1
30+
13 0 7.432311 0 -1
31+
"""
32+
33+
_toy_ts_edges_text = """\
34+
left right parent child metadata
35+
0.000000 1000000.000000 6 0
36+
0.000000 1000000.000000 6 3
37+
0.000000 1000000.000000 7 2
38+
0.000000 1000000.000000 7 5
39+
0.000000 1000000.000000 8 1
40+
0.000000 1000000.000000 8 4
41+
0.000000 781157.000000 9 6
42+
0.000000 781157.000000 9 7
43+
0.000000 505438.000000 10 8
44+
0.000000 505438.000000 10 9
45+
505438.000000 549484.000000 11 8
46+
505438.000000 549484.000000 11 9
47+
781157.000000 1000000.000000 12 6
48+
781157.000000 1000000.000000 12 7
49+
549484.000000 1000000.000000 13 8
50+
549484.000000 781157.000000 13 9
51+
781157.000000 1000000.000000 13 12
52+
"""
53+
54+
_toy_ts_sites_text = """\
55+
position ancestral_state metadata
56+
200000.000000 A
57+
300000.000000 C
58+
520000.000000 G
59+
600000.000000 T
60+
900000.000000 A
61+
"""
62+
63+
_toy_ts_mutations_text = """\
64+
site node time derived_state parent metadata
65+
0 9 unknown G -1
66+
1 8 unknown A -1
67+
2 9 unknown T -1
68+
3 9 unknown C -1
69+
4 12 unknown C -1
70+
"""
71+
72+
_toy_ts_individuals_text = """\
73+
flags
74+
0
75+
0
76+
0
77+
"""
78+
79+
80+
def get_toy_ts():
81+
"""
82+
Returns the toy tree sequence in text format above.
83+
"""
84+
ts = tskit.load_text(
85+
nodes=io.StringIO(_toy_ts_nodes_text),
86+
edges=io.StringIO(_toy_ts_edges_text),
87+
sites=io.StringIO(_toy_ts_sites_text),
88+
mutations=io.StringIO(_toy_ts_mutations_text),
89+
individuals=io.StringIO(_toy_ts_individuals_text),
90+
strict=False,
91+
)
92+
return ts
93+
94+
95+
# BEAGLE 4.1 was run on the toy data set above using default parameters.
96+
# The following are the forward probability matrices and backward probability
97+
# matrices calculated when imputing into the third individual above. There are
98+
# two sets of matrices, one for each haplotype.
99+
#
100+
# Notes about calculations:
101+
# n = number of haplotypes in ref. panel
102+
# M = number of markers
103+
# m = index of marker (site)
104+
# h = index of haplotype in ref. panel
105+
#
106+
# In forward probability matrix,
107+
# fwd[m][h] = emission prob., if m = 0 (first marker)
108+
# fwd[m][h] = emission prob. * (scale * fwd[m - 1][h] + shift), otherwise
109+
# where scale = (1 - switch prob.)/sum of fwd[m - 1],
110+
# and shift = switch prob./n.
111+
#
112+
# In backward probability matrix,
113+
# bwd[m][h] = 1, if m = M - 1 (last marker) // DON'T SEE THIS IN BEAGLE
114+
# unadj. bwd[m][h] = emission prob. / n
115+
# bwd[m][h] = (unadj. bwd[m][h] + shift) * scale, otherwise
116+
# where scale = (1 - switch prob.)/sum of unadj. bwd[m],
117+
# and shift = switch prob./n.
118+
#
119+
# For each site, the sum of backward value over all haplotypes is calculated
120+
# before scaling and shifting.
121+
122+
_fwd_matrix_text_1 = """
123+
m,h,probRec,probNoRec,noErrProb,errProb,refAl,queryAl,shift,scale,sum,val
124+
0,0,0.000000,1.000000,0.999900,0.000100,1,0,0.000000,1.000000,0.000100,0.000100
125+
0,1,0.000000,1.000000,0.999900,0.000100,0,0,0.000000,1.000000,1.000000,0.999900
126+
0,2,0.000000,1.000000,0.999900,0.000100,1,0,0.000000,1.000000,1.000100,0.000100
127+
0,3,0.000000,1.000000,0.999900,0.000100,1,0,0.000000,1.000000,1.000200,0.000100
128+
1,0,1.000000,0.000000,0.999900,0.000100,0,1,0.250000,0.000000,0.000025,0.000025
129+
1,1,1.000000,0.000000,0.999900,0.000100,1,1,0.250000,0.000000,0.250000,0.249975
130+
1,2,1.000000,0.000000,0.999900,0.000100,0,1,0.250000,0.000000,0.250025,0.000025
131+
1,3,1.000000,0.000000,0.999900,0.000100,0,1,0.250000,0.000000,0.250050,0.000025
132+
2,0,1.000000,0.000000,0.999900,0.000100,1,0,0.250000,0.000000,0.000025,0.000025
133+
2,1,1.000000,0.000000,0.999900,0.000100,0,0,0.250000,0.000000,0.250000,0.249975
134+
2,2,1.000000,0.000000,0.999900,0.000100,1,0,0.250000,0.000000,0.250025,0.000025
135+
2,3,1.000000,0.000000,0.999900,0.000100,1,0,0.250000,0.000000,0.250050,0.000025
136+
3,0,1.000000,0.000000,0.999900,0.000100,1,0,0.250000,0.000000,0.000025,0.000025
137+
3,1,1.000000,0.000000,0.999900,0.000100,0,0,0.250000,0.000000,0.250000,0.249975
138+
3,2,1.000000,0.000000,0.999900,0.000100,1,0,0.250000,0.000000,0.250025,0.000025
139+
3,3,1.000000,0.000000,0.999900,0.000100,1,0,0.250000,0.000000,0.250050,0.000025
140+
"""
141+
142+
_bwd_matrix_text_1 = """
143+
m,h,probRec,probNoRec,noErrProb,errProb,refAl,queryAl,shift,scale,sum,val
144+
3,0,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000
145+
3,1,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000
146+
3,2,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000
147+
3,3,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000
148+
2,0,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.250000,0.250050,0.250000
149+
2,1,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.250000,0.250050,0.250000
150+
2,2,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.250000,0.250050,0.250000
151+
2,3,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.250000,0.250050,0.250000
152+
1,0,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.250000,0.250050,0.250000
153+
1,1,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.250000,0.250050,0.250000
154+
1,2,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.250000,0.250050,0.250000
155+
1,3,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.250000,0.250050,0.250000
156+
0,0,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.250000,0.250050,0.250000
157+
0,1,1.000000,0.000000,0.999900,0.000100,0,1,0.000000,0.250000,0.250050,0.250000
158+
0,2,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.250000,0.250050,0.250000
159+
0,3,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.250000,0.250050,0.250000
160+
"""
161+
162+
_fwd_matrix_text_2 = """
163+
m,h,probRec,probNoRec,noErrProb,errProb,refAl,queryAl,shift,scale,sum,val
164+
0,0,0.000000,1.000000,0.999900,0.000100,1,1,0.000000,1.000000,0.999900,0.999900
165+
0,1,0.000000,1.000000,0.999900,0.000100,0,1,0.000000,1.000000,1.000000,0.000100
166+
0,2,0.000000,1.000000,0.999900,0.000100,1,1,0.000000,1.000000,1.999900,0.999900
167+
0,3,0.000000,1.000000,0.999900,0.000100,1,1,0.000000,1.000000,2.999800,0.999900
168+
1,0,1.000000,0.000000,0.999900,0.000100,0,0,0.250000,0.000000,0.249975,0.249975
169+
1,1,1.000000,0.000000,0.999900,0.000100,1,0,0.250000,0.000000,0.250000,0.000025
170+
1,2,1.000000,0.000000,0.999900,0.000100,0,0,0.250000,0.000000,0.499975,0.249975
171+
1,3,1.000000,0.000000,0.999900,0.000100,0,0,0.250000,0.000000,0.749950,0.249975
172+
2,0,1.000000,0.000000,0.999900,0.000100,1,1,0.250000,0.000000,0.249975,0.249975
173+
2,1,1.000000,0.000000,0.999900,0.000100,0,1,0.250000,0.000000,0.250000,0.000025
174+
2,2,1.000000,0.000000,0.999900,0.000100,1,1,0.250000,0.000000,0.499975,0.249975
175+
2,3,1.000000,0.000000,0.999900,0.000100,1,1,0.250000,0.000000,0.749950,0.249975
176+
3,0,1.000000,0.000000,0.999900,0.000100,1,1,0.250000,0.000000,0.249975,0.249975
177+
3,1,1.000000,0.000000,0.999900,0.000100,0,1,0.250000,0.000000,0.250000,0.000025
178+
3,2,1.000000,0.000000,0.999900,0.000100,1,1,0.250000,0.000000,0.499975,0.249975
179+
3,3,1.000000,0.000000,0.999900,0.000100,1,1,0.250000,0.000000,0.749950,0.249975
180+
"""
181+
182+
_bwd_matrix_text_2 = """
183+
m,h,probRec,probNoRec,noErrProb,errProb,refAl,queryAl,shift,scale,sum,val
184+
3,0,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000
185+
3,1,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000
186+
3,2,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000
187+
3,3,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000
188+
2,0,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.250000,0.749950,0.250000
189+
2,1,1.000000,0.000000,0.999900,0.000100,0,1,0.000000,0.250000,0.749950,0.250000
190+
2,2,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.250000,0.749950,0.250000
191+
2,3,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.250000,0.749950,0.250000
192+
1,0,1.000000,0.000000,0.999900,0.000100,0,1,0.000000,0.250000,0.749950,0.250000
193+
1,1,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.250000,0.749950,0.250000
194+
1,2,1.000000,0.000000,0.999900,0.000100,0,1,0.000000,0.250000,0.749950,0.250000
195+
1,3,1.000000,0.000000,0.999900,0.000100,0,1,0.000000,0.250000,0.749950,0.250000
196+
0,0,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.250000,0.749950,0.250000
197+
0,1,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.250000,0.749950,0.250000
198+
0,2,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.250000,0.749950,0.250000
199+
0,3,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.250000,0.749950,0.250000
200+
"""
201+
202+
203+
def convert_to_numpy(matrix_text):
204+
"""
205+
Converts a matrix in text to numpy and returns it.
206+
"""
207+
x = np.loadtxt(io.StringIO(matrix_text), skiprows=1, delimiter=",")
208+
for i in np.arange(x.shape[0]):
209+
# Check that switch and non-switch probabilities sum to 1
210+
assert (x[i, 2] + x[i, 3]) == 1 or x[i, 2] == -1
211+
# Check that non-mismatch and mismatch probabilities sum to 1
212+
assert (x[i, 4] + x[i, 5]) == 1 or x[i, 4] == -1
213+
return x[:, -1].reshape((4, 4)) # size (m, h)
214+
215+
216+
def get_forward_backward_matrices():
217+
fwd_matrix_1 = convert_to_numpy(_fwd_matrix_text_1)
218+
bwd_matrix_1 = convert_to_numpy(_bwd_matrix_text_1)
219+
fwd_matrix_2 = convert_to_numpy(_fwd_matrix_text_2)
220+
bwd_matrix_2 = convert_to_numpy(_bwd_matrix_text_2)
221+
return [fwd_matrix_1, bwd_matrix_1, fwd_matrix_2, bwd_matrix_2]
222+
223+
224+
def get_test_data(matrix_text, par):
225+
"""Extracts data for checking forward or backward probability matrix calculations."""
226+
x = convert_to_numpy(matrix_text)
227+
if par == "switch":
228+
# Switch probability, one per site
229+
return x[:, 2].reshape((4, 4))[:, 0]
230+
elif par == "mismatch":
231+
# Mismatch probability, one per site
232+
return x[:, 4].reshape((4, 4))[:, 0]
233+
elif par == "ref_hap_allele":
234+
# Allele in haplotype in reference panel
235+
# 0 = ref allele, 1 = alt allele
236+
return x[:, 6].reshape((4, 4))
237+
elif par == "query_hap_allele":
238+
# Allele in haplotype in query
239+
# 0 = ref allele, 1 = alt allele
240+
return x[:, 7].reshape((4, 4))[:, 0]
241+
elif par == "shift":
242+
# Shift factor, one per site
243+
return x[:, 8].reshape((4, 4))[:, 0]
244+
elif par == "scale":
245+
# Scale factor, one per site
246+
return x[:, 9].reshape((4, 4))[:, 0]
247+
elif par == "sum":
248+
# Sum of values over haplotypes
249+
return x[:, 10].reshape((4, 4))[:, 0]
250+
else:
251+
raise ValueError(f"Unknown parameter: {par}")

0 commit comments

Comments
 (0)