Skip to content

Commit d24e1e8

Browse files
committed
Add tests to check LS HMM of tskit compared to BEAGLE
1 parent b6f9872 commit d24e1e8

File tree

1 file changed

+261
-0
lines changed

1 file changed

+261
-0
lines changed

python/tests/test_imputation.py

Lines changed: 261 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,261 @@
1+
"""
2+
Tests for genotype imputation (forward and Baum-Welsh algorithms).
3+
"""
4+
import io
5+
6+
import numpy as np
7+
import pandas as pd
8+
9+
import tskit
10+
11+
12+
# A tree sequence containing 3 diploid individuals with 5 sites and 5 mutations
# (one per site). The first 2 individuals are used as reference panel,
# the last one is the target individual.

# Node table. Sample nodes 0-5 are the six haplotypes of the three diploid
# individuals (two consecutive nodes per individual; see the individual
# column); nodes 6-13 are non-sample ancestral nodes.
toy_ts_nodes_text = """\
id is_sample time population individual metadata
0 1 0.000000 0 0
1 1 0.000000 0 0
2 1 0.000000 0 1
3 1 0.000000 0 1
4 1 0.000000 0 2
5 1 0.000000 0 2
6 0 0.029768 0 -1
7 0 0.133017 0 -1
8 0 0.223233 0 -1
9 0 0.651586 0 -1
10 0 0.698831 0 -1
11 0 2.114867 0 -1
12 0 4.322031 0 -1
13 0 7.432311 0 -1
"""

# Edge table. The sequence spans [0, 1000000); the topology changes at
# positions 505438, 549484, and 781157.
toy_ts_edges_text = """\
left right parent child metadata
0.000000 1000000.000000 6 0
0.000000 1000000.000000 6 3
0.000000 1000000.000000 7 2
0.000000 1000000.000000 7 5
0.000000 1000000.000000 8 1
0.000000 1000000.000000 8 4
0.000000 781157.000000 9 6
0.000000 781157.000000 9 7
0.000000 505438.000000 10 8
0.000000 505438.000000 10 9
505438.000000 549484.000000 11 8
505438.000000 549484.000000 11 9
781157.000000 1000000.000000 12 6
781157.000000 1000000.000000 12 7
549484.000000 1000000.000000 13 8
549484.000000 781157.000000 13 9
781157.000000 1000000.000000 13 12
"""

# Site table: five biallelic sites with their ancestral states.
toy_ts_sites_text = """\
position ancestral_state metadata
200000.000000 A
300000.000000 C
520000.000000 G
600000.000000 T
900000.000000 A
"""

# Mutation table: exactly one mutation per site, placed on ancestral nodes.
toy_ts_mutations_text = """\
site node time derived_state parent metadata
0 9 unknown G -1
1 8 unknown A -1
2 9 unknown T -1
3 9 unknown C -1
4 12 unknown C -1
"""

# Individual table: three individuals, all with default (zero) flags.
toy_ts_individuals_text = """\
flags
0
0
0
"""
79+
80+
81+
def get_toy_data():
    """
    Returns toy data contained in the toy tree sequence in text format above.

    The first two diploid individuals (sample nodes 0 to 3) form the
    reference panel; the third individual (sample nodes 4 and 5) is the
    query whose haplotypes are to be imputed.

    :param: None
    :return: Reference panel tree sequence and query haplotypes.
    :rtype: list
    """
    ts = tskit.load_text(
        nodes=io.StringIO(toy_ts_nodes_text),
        edges=io.StringIO(toy_ts_edges_text),
        sites=io.StringIO(toy_ts_sites_text),
        mutations=io.StringIO(toy_ts_mutations_text),
        individuals=io.StringIO(toy_ts_individuals_text),
        strict=False,
    )
    ref_ts = ts.simplify(samples=np.arange(2 * 2), filter_sites=False)
    # BUG FIX: the query individual owns sample nodes 4 and 5 (see the node
    # table); node 6 is a non-sample ancestral node. Using [5, 6] selected one
    # query haplotype and one ancestral node instead of the query individual's
    # two haplotypes.
    query_ts = ts.simplify(samples=[4, 5], filter_sites=False)
    # genotype_matrix() is (sites x samples); transpose to (haplotypes x sites).
    query_h = query_ts.genotype_matrix().T
    return [ref_ts, query_h]
101+
102+
103+
# BEAGLE 4.1 was run on the toy data set above using default parameters.
#
# In the query VCF, the site at position 520,000 was redacted and then imputed.
# Note that the ancestral allele in the simulated tree sequence is
# treated as the REF in the VCFs.
#
# The following are the forward probability matrices and backward probability
# matrices calculated when imputing into the third individual above. There are
# two sets of matrices, one for each haplotype.
#
# Notes about calculations:
# n = number of haplotypes in ref. panel
# M = number of markers
# m = index of marker (site)
# h = index of haplotype in ref. panel
#
# In forward probability matrix,
# fwd[m][h] = emission prob., if m = 0 (first marker)
# fwd[m][h] = emission prob. * (scale * fwd[m - 1][h] + shift), otherwise
# where scale = (1 - switch prob.)/sum of fwd[m - 1],
# and shift = switch prob./n.
#
# In backward probability matrix,
# bwd[m][h] = 1, if m = M - 1 (last marker) // DON'T SEE THIS IN BEAGLE
# unadj. bwd[m][h] = emission prob. / n
# bwd[m][h] = (unadj. bwd[m][h] + shift) * scale, otherwise
# where scale = (1 - switch prob.)/sum of unadj. bwd[m],
# and shift = switch prob./n.
#
# For each site, the sum of backward value over all haplotypes is calculated
# before scaling and shifting.

# Forward probability matrix for the first haplotype of the query individual.
# Rows are (marker, ref. haplotype) pairs; `val` holds the probability.
fwd_matrix_text_1 = """
m,h,probRec,probNoRec,noErrProb,errProb,refAl,queryAl,shiftFac,scaleFac,sumSite,val
0,0,0.000000,1.000000,0.999900,0.000100,1,0,0.000000,1.000000,0.000100,0.000100
0,1,0.000000,1.000000,0.999900,0.000100,0,0,0.000000,1.000000,1.000000,0.999900
0,2,0.000000,1.000000,0.999900,0.000100,1,0,0.000000,1.000000,1.000100,0.000100
0,3,0.000000,1.000000,0.999900,0.000100,1,0,0.000000,1.000000,1.000200,0.000100
1,0,1.000000,0.000000,0.999900,0.000100,0,1,0.250000,0.000000,0.000025,0.000025
1,1,1.000000,0.000000,0.999900,0.000100,1,1,0.250000,0.000000,0.250000,0.249975
1,2,1.000000,0.000000,0.999900,0.000100,0,1,0.250000,0.000000,0.250025,0.000025
1,3,1.000000,0.000000,0.999900,0.000100,0,1,0.250000,0.000000,0.250050,0.000025
2,0,1.000000,0.000000,0.999900,0.000100,1,0,0.250000,0.000000,0.000025,0.000025
2,1,1.000000,0.000000,0.999900,0.000100,0,0,0.250000,0.000000,0.250000,0.249975
2,2,1.000000,0.000000,0.999900,0.000100,1,0,0.250000,0.000000,0.250025,0.000025
2,3,1.000000,0.000000,0.999900,0.000100,1,0,0.250000,0.000000,0.250050,0.000025
3,0,1.000000,0.000000,0.999900,0.000100,1,0,0.250000,0.000000,0.000025,0.000025
3,1,1.000000,0.000000,0.999900,0.000100,0,0,0.250000,0.000000,0.250000,0.249975
3,2,1.000000,0.000000,0.999900,0.000100,1,0,0.250000,0.000000,0.250025,0.000025
3,3,1.000000,0.000000,0.999900,0.000100,1,0,0.250000,0.000000,0.250050,0.000025
"""

# Backward probability matrix for the first haplotype of the query individual.
# Markers appear in reverse order; the last marker's rows are sentinels (-1)
# with val initialized to 1.
bwd_matrix_text_1 = """
m,h,probRec,probNoRec,noErrProb,errProb,refAl,queryAl,shiftFac,scaleFac,sumSite,val
3,0,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000
3,1,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000
3,2,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000
3,3,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000
2,0,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.250000,0.250050,0.250000
2,1,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.250000,0.250050,0.250000
2,2,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.250000,0.250050,0.250000
2,3,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.250000,0.250050,0.250000
1,0,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.250000,0.250050,0.250000
1,1,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.250000,0.250050,0.250000
1,2,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.250000,0.250050,0.250000
1,3,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.250000,0.250050,0.250000
0,0,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.250000,0.250050,0.250000
0,1,1.000000,0.000000,0.999900,0.000100,0,1,0.000000,0.250000,0.250050,0.250000
0,2,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.250000,0.250050,0.250000
0,3,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.250000,0.250050,0.250000
"""

# Forward probability matrix for the second haplotype of the query individual.
fwd_matrix_text_2 = """
m,h,probRec,probNoRec,noErrProb,errProb,refAl,queryAl,shiftFac,scaleFac,sumSite,val
0,0,0.000000,1.000000,0.999900,0.000100,1,1,0.000000,1.000000,0.999900,0.999900
0,1,0.000000,1.000000,0.999900,0.000100,0,1,0.000000,1.000000,1.000000,0.000100
0,2,0.000000,1.000000,0.999900,0.000100,1,1,0.000000,1.000000,1.999900,0.999900
0,3,0.000000,1.000000,0.999900,0.000100,1,1,0.000000,1.000000,2.999800,0.999900
1,0,1.000000,0.000000,0.999900,0.000100,0,0,0.250000,0.000000,0.249975,0.249975
1,1,1.000000,0.000000,0.999900,0.000100,1,0,0.250000,0.000000,0.250000,0.000025
1,2,1.000000,0.000000,0.999900,0.000100,0,0,0.250000,0.000000,0.499975,0.249975
1,3,1.000000,0.000000,0.999900,0.000100,0,0,0.250000,0.000000,0.749950,0.249975
2,0,1.000000,0.000000,0.999900,0.000100,1,1,0.250000,0.000000,0.249975,0.249975
2,1,1.000000,0.000000,0.999900,0.000100,0,1,0.250000,0.000000,0.250000,0.000025
2,2,1.000000,0.000000,0.999900,0.000100,1,1,0.250000,0.000000,0.499975,0.249975
2,3,1.000000,0.000000,0.999900,0.000100,1,1,0.250000,0.000000,0.749950,0.249975
3,0,1.000000,0.000000,0.999900,0.000100,1,1,0.250000,0.000000,0.249975,0.249975
3,1,1.000000,0.000000,0.999900,0.000100,0,1,0.250000,0.000000,0.250000,0.000025
3,2,1.000000,0.000000,0.999900,0.000100,1,1,0.250000,0.000000,0.499975,0.249975
3,3,1.000000,0.000000,0.999900,0.000100,1,1,0.250000,0.000000,0.749950,0.249975
"""

# Backward probability matrix for the second haplotype of the query individual.
bwd_matrix_text_2 = """
m,h,probRec,probNoRec,noErrProb,errProb,refAl,queryAl,shiftFac,scaleFac,sumSite,val
3,0,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000
3,1,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000
3,2,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000
3,3,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000
2,0,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.250000,0.749950,0.250000
2,1,1.000000,0.000000,0.999900,0.000100,0,1,0.000000,0.250000,0.749950,0.250000
2,2,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.250000,0.749950,0.250000
2,3,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.250000,0.749950,0.250000
1,0,1.000000,0.000000,0.999900,0.000100,0,1,0.000000,0.250000,0.749950,0.250000
1,1,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.250000,0.749950,0.250000
1,2,1.000000,0.000000,0.999900,0.000100,0,1,0.000000,0.250000,0.749950,0.250000
1,3,1.000000,0.000000,0.999900,0.000100,0,1,0.000000,0.250000,0.749950,0.250000
0,0,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.250000,0.749950,0.250000
0,1,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.250000,0.749950,0.250000
0,2,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.250000,0.749950,0.250000
0,3,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.250000,0.749950,0.250000
"""
214+
215+
216+
def convert_to_numpy(matrix_text):
    """
    Converts a forward or backward matrix in text format to numpy.

    :param str matrix_text: CSV dump of a BEAGLE probability matrix.
    :return: Matrix of values of shape (num. markers, num. haplotypes).
    :rtype: numpy.ndarray
    """
    df = pd.read_csv(io.StringIO(matrix_text))
    # Check that switch and non-switch probabilities sum to 1.
    # (-2 arises from the -1 sentinel rows of the last marker in a
    # backward matrix.)
    assert np.all(np.isin(df.probRec + df.probNoRec, [1, -2]))
    # Check that non-mismatch and mismatch probabilities sum to 1.
    assert np.all(np.isin(df.noErrProb + df.errProb, [1, -2]))
    # Infer the matrix shape from the data instead of hard-coding 4 x 4,
    # so matrices from other data sets can be parsed as well.
    num_markers = df.m.nunique()
    num_haplotypes = df.h.nunique()
    return df.val.to_numpy().reshape((num_markers, num_haplotypes))
224+
225+
226+
def get_forward_backward_matrices():
    """Parses the four BEAGLE probability matrix dumps into numpy arrays."""
    matrix_texts = (
        fwd_matrix_text_1,
        bwd_matrix_text_1,
        fwd_matrix_text_2,
        bwd_matrix_text_2,
    )
    return [convert_to_numpy(text) for text in matrix_texts]
232+
233+
234+
def get_test_data(matrix_text, par):
    """
    Extracts data to check forward or backward probability matrix calculations.

    :param str matrix_text: CSV dump of a BEAGLE probability matrix.
    :param str par: Name of the quantity to extract ("switch", "mismatch",
        "ref_hap_allele", "query_hap_allele", "shift", "scale", or "sum").
    :return: One value per site, except for "ref_hap_allele", which yields
        one value per site and reference haplotype.
    :rtype: numpy.ndarray
    :raises ValueError: If ``par`` is not a recognized parameter name.
    """
    df = pd.read_csv(io.StringIO(matrix_text))
    # Infer the matrix shape from the data instead of hard-coding 4 x 4,
    # so matrices from other data sets can be parsed as well.
    shape = (df.m.nunique(), df.h.nunique())
    if par == "ref_hap_allele":
        # Allele in haplotype in reference panel
        # 0 = ref allele, 1 = alt allele
        # This is the only per-(site, haplotype) quantity; return the
        # full matrix.
        return df.refAl.to_numpy().reshape(shape)
    # All remaining quantities are constant across haplotypes at a site,
    # so take the first haplotype's column.
    per_site_columns = {
        "switch": "probRec",  # switch probability, one per site
        "mismatch": "errProb",  # mismatch probability, one per site
        "query_hap_allele": "queryAl",  # query allele (0 = ref, 1 = alt)
        "shift": "shiftFac",  # shift factor, one per site
        "scale": "scaleFac",  # scale factor, one per site
        "sum": "sumSite",  # sum of values over haplotypes
    }
    if par not in per_site_columns:
        raise ValueError(f"Unknown parameter: {par}")
    return df[per_site_columns[par]].to_numpy().reshape(shape)[:, 0]

0 commit comments

Comments
 (0)