Skip to content

Commit 2d6347c

Browse files
committed
Add a convenience function to get data from the toy data
1 parent 10543bd commit 2d6347c

File tree

1 file changed

+32
-0
lines changed

1 file changed

+32
-0
lines changed

python/tests/test_imputation.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,3 +219,35 @@ def get_forward_backward_matrices():
219219
fwd_matrix_2 = convert_to_numpy(_fwd_matrix_text_2)
220220
bwd_matrix_2 = convert_to_numpy(_bwd_matrix_text_2)
221221
return [fwd_matrix_1, bwd_matrix_1, fwd_matrix_2, bwd_matrix_2]
222+
223+
224+
def get_test_data(matrix_text, par):
225+
"""Extracts data for checking forward or backward probability matrix calculations."""
226+
x = convert_to_numpy(matrix_text)
227+
if par == "switch":
228+
# Switch probability, one per site
229+
return x[:, 2].reshape((4, 4))[:, 2]
230+
elif par == "mismatch":
231+
# Mismatch probability
232+
return x[:, 2].reshape((4, 4))[:, 4]
233+
elif par == "ref_hap_allele":
234+
# Allele in haplotype in reference panel
235+
# 0 = ref allele, 1 = alt allele
236+
return x[:, 2].reshape((4, 4))[:, 6]
237+
elif par == "query_hap_allele":
238+
# Allele in haplotype in query
239+
# 0 = ref allele, 1 = alt allele
240+
return x[:, 2].reshape((4, 4))[:, 7]
241+
elif par == "shift":
242+
# Shift factor
243+
# TODO
244+
pass
245+
elif par == "scale":
246+
# Scale factor
247+
# TODO
248+
pass
249+
elif par == "sum":
250+
# Sum of values over haplotypes
251+
return x[:, 2].reshape((4, 4))[:, 10]
252+
else:
253+
raise ValueError(f"Unknown parameter: {par}")

0 commit comments

Comments
 (0)