Skip to content

Commit 4e0440b

Browse files
committed
Contig stitcher: improve concordance calculations
Also add more tests for it.
1 parent bf1390f commit 4e0440b

File tree

2 files changed

+57
-3
lines changed

2 files changed

+57
-3
lines changed

micall/core/contig_stitcher.py

+20-2
Original file line numberDiff line numberDiff line change
@@ -384,11 +384,29 @@ def slide(start, end):
384384
return result
385385

386386

387+
def disambiguate_concordance(concordance: List[float]) -> List[Tuple[float, int, int]]:
    """Attach tie-breaking ranks to every concordance value.

    For each position ``i`` the result holds a tuple
    ``(value, local_rank, global_rank)`` that can be used directly as a
    sort key:

    * ``value`` is ``concordance[i]`` unchanged.
    * ``local_rank`` is the product of two counters, one computed while
      walking the list left-to-right and one while walking right-to-left.
      Each counter is 0 wherever the value differs from the one visited
      just before it, so the product is 0 at either edge of a run of
      equal values.
    * ``global_rank`` is the distance from position ``i`` to the nearer
      end of the list.

    :param concordance: per-position concordance scores.
    :return: one ``(value, local_rank, global_rank)`` tuple per position,
        in the same order as the input.
    """

    def slide(values):
        # NOTE(review): ``count`` is not reset when a run of equal values
        # ends, so later runs continue from the running total of earlier
        # ones.  Kept as in the original -- confirm this is intended.
        count = 0
        for prev, current, _following in sliding_window(values):
            if current == prev:
                count += 1
                yield count
            else:
                yield 0

    size = len(concordance)
    forward = list(slide(concordance))
    # Walk the list backwards, then flip the counters back so that they
    # line up with the original positions.
    reverse = list(slide(reversed(concordance)))
    reverse.reverse()

    ranked = []
    for i, (x, f, r) in enumerate(zip(concordance, forward, reverse)):
        local_rank = f * r
        global_rank = i if i < size / 2 else size - i - 1
        ranked.append((x, local_rank, global_rank))
    return ranked
403+
404+
387405
def concordance_to_cut_points(left_overlap, right_overlap, aligned_left, aligned_right, concordance):
388406
""" Determine optimal cut points for stitching based on sequence concordance in the overlap region. """
389407

390-
valuator = lambda i: (concordance[i], i if i < len(concordance) / 2 else len(concordance) - i - 1)
391-
sorted_concordance_indexes = sorted(range(len(concordance)), key=valuator)
408+
concordance_d = list(disambiguate_concordance(concordance))
409+
sorted_concordance_indexes = sorted(range(len(concordance)), key=lambda i: concordance_d[i])
392410
remove_dashes = lambda s: ''.join(c for c in s if c != '-')
393411

394412
for max_concordance_index in reversed(sorted_concordance_indexes):

micall/tests/test_contig_stitcher.py

+37-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import os
55
import pytest
66

7-
from micall.core.contig_stitcher import split_contigs_with_gaps, stitch_contigs, GenotypedContig, merge_intervals, find_covered_contig, stitch_consensus, calculate_concordance, align_all_to_reference, main, AlignedContig
7+
from micall.core.contig_stitcher import split_contigs_with_gaps, stitch_contigs, GenotypedContig, merge_intervals, find_covered_contig, stitch_consensus, calculate_concordance, align_all_to_reference, main, AlignedContig, disambiguate_concordance
88
from micall.core.plot_contigs import plot_stitcher_coverage
99
from micall.tests.utils import MockAligner, fixed_random_seed
1010
from micall.utils.structured_logger import add_structured_handler
@@ -974,6 +974,42 @@ def generate_random_string_pair(length):
974974
right = ''.join(random.choice('ACGT') for _ in range(length))
975975
return left, right
976976

977+
978+
@pytest.mark.parametrize(
    'left, right, expected',
    [
        ("aaaaa", "aaaaa", [0.1] * 5),
        ("abcdd", "abcdd", [0.1] * 5),
        ("aaaaaaaa", "baaaaaab", [0.1, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.1]),
        # This case was listed twice in the original table; one copy is enough.
        ("aaaaaaaa", "aaaaaaab", [0.13, 0.13, 0.13, 0.13, 0.13, 0.13, 0.13, 0.12]),
        ("aaaaaaaa", "aaaaabbb", [0.1, 0.1, 0.1, 0.1, 0.1, 0.08, 0.08, 0.08]),
        ("aaaaaaaa", "aaabbaaa", [0.12, 0.12, 0.12, 0.1, 0.1, 0.12, 0.12, 0.12]),
        ("aaaaa", "bbbbb", [0] * 5),
    ],
)
def test_concordance_simple(left, right, expected):
    """Check the per-position concordance of two strings, rounded to 2 decimals."""
    result = [round(x, 2) for x in calculate_concordance(left, right)]
    assert result == expected
993+
994+
995+
@pytest.mark.parametrize(
    'left, right, expected',
    [
        ("a" * 128, "a" * 128, 64),
        ("a" * 128, "a" * 64 + "b" * 64, 32),
        ("a" * 128, "a" * 64 + "ba" * 32, 32),
        ("a" * 128, "a" * 54 + "b" * 20 + "a" * 54, 28),  # two peaks
        ("a" * 128, "a" * 63 + "b" * 2 + "a" * 63, 32),  # two peaks
        # the window is too small to account for all of the context
        ("a" * 1280, "b" * 640 + "a" * 640, 640 + 30),
    ],
)
def test_concordance_simple_index(left, right, expected):
    """Check which position ranks highest after disambiguation.

    The index with the largest disambiguated concordance key must lie
    within 3 positions of ``expected``.
    """
    concordance = calculate_concordance(left, right)
    concordance_d = list(disambiguate_concordance(concordance))
    index = max(range(len(concordance)), key=lambda i: concordance_d[i])
    # State the tolerance directly instead of hiding it behind a
    # conditional exact-equality assert.
    assert abs(index - expected) <= 3, (
        f"best index {index} is more than 3 positions away from expected {expected}"
    )
1011+
1012+
9771013
def generate_test_cases(num_cases):
9781014
with fixed_random_seed(42):
9791015
length = random.randint(1, 80)

0 commit comments

Comments
 (0)