Skip to content

Commit 0238f04

Browse files
committed
Contig stitcher: simplify the concordance algorithm
1 parent ea50a6d commit 0238f04

File tree

3 files changed

+33
-52
lines changed

3 files changed

+33
-52
lines changed

micall/core/contig_stitcher.py

Lines changed: 20 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from Bio import Seq
1111
import logging
1212
from contextvars import ContextVar, Context
13+
from fractions import Fraction
1314

1415
from micall.utils.cigar_tools import Cigar, connect_cigar_hits, CigarHit
1516
from micall.utils.consensus_aligner import CigarActions
@@ -350,15 +351,17 @@ def find_overlapping_contig(self, aligned_contigs):
350351
return max(every, key=lambda other: other.alignment.ref_length if other else 0, default=None)
351352

352353

353-
def calculate_concordance(left: str, right: str) -> List[float]:
354+
def calculate_concordance(left: str, right: str) -> List[Fraction]:
354355
"""
355-
Calculate concordance for two given sequences using a sliding window method.
356+
Calculate concordance for two given sequences using a sliding average.
356357
357-
The function compares the two strings from both left to right and then right to left,
358-
calculating for each position the ratio of matching characters in a window around the
359-
current position. So position holds a moving avarage score.
360-
361-
It's required that the input strings are of the same length.
358+
The function compares the two strings character by character, simultaneously from
359+
both left to right and right to left, calculating a score that represents a moving
360+
average of matches at each position. If characters match at a given position,
361+
a score of 1 is added; otherwise, a score of 0 is added. The score is then
362+
averaged with the previous scores using a weighted sliding average where the
363+
current score has a weight of 1/3 and the accumulated score has a weight of 2/3.
364+
This sliding average score is halved and then processed again, but in reverse direction.
362365
363366
:param left: string representing first sequence
364367
:param right: string representing second sequence
@@ -368,22 +371,18 @@ def calculate_concordance(left: str, right: str) -> List[float]:
368371
if len(left) != len(right):
369372
raise ValueError("Can only calculate concordance for same sized sequences")
370373

371-
result: List[float] = [0] * len(left)
374+
result: List[Fraction] = [Fraction(0)] * len(left)
372375

373376
def slide(start, end):
374-
window_size = 30
375-
scores = deque([0] * window_size, maxlen=window_size)
376-
scores_sum = 0
377+
scores_sum = Fraction(0)
377378
inputs = list(zip(left, right))
378379
increment = 1 if start <= end else -1
379380

380381
for i in range(start, end, increment):
381382
(a, b) = inputs[i]
382-
current = a == b
383-
scores_sum -= scores.popleft()
384-
scores_sum += current
385-
scores.append(current)
386-
result[i] += (scores_sum / window_size) / 2
383+
current = Fraction(1) if a == b else Fraction(0)
384+
scores_sum = (scores_sum * 2 / 3 + current * 1 / 3)
385+
result[i] += scores_sum / 2
387386

388387
# Slide forward, then in reverse, adding the scores at each position.
389388
slide(0, len(left))
@@ -392,22 +391,10 @@ def slide(start, end):
392391
return result
393392

394393

395-
def disambiguate_concordance(concordance: List[float]) -> Iterable[Tuple[float, int, int]]:
396-
def slide(concordance):
397-
count = 0
398-
for i, (prev, current, next) in enumerate(sliding_window(concordance)):
399-
if current == prev:
400-
count += 1
401-
yield count
402-
else:
403-
yield 0
404-
405-
forward = list(slide(concordance))
406-
reverse = list(reversed(list(slide(reversed(concordance)))))
407-
for i, (x, f, r) in enumerate(zip(concordance, forward, reverse)):
408-
local_rank = f * r
394+
def disambiguate_concordance(concordance: List[float]) -> Iterable[Tuple[float, int]]:
395+
for i, x in enumerate(concordance):
409396
global_rank = i if i < len(concordance) / 2 else len(concordance) - i - 1
410-
yield (x, local_rank, global_rank)
397+
yield (x, global_rank)
411398

412399

413400
def concordance_to_cut_points(left_overlap, right_overlap, aligned_left, aligned_right, concordance):
@@ -467,8 +454,8 @@ def stitch_2_contigs(left, right):
467454
right_overlap_drop, right_overlap_take = right_overlap.cut_reference(aligned_right_cutpoint)
468455

469456
# Log it.
470-
average_concordance = sum(concordance) / (len(concordance) or 1)
471-
concordance_str = ', '.join(map(lambda x: str(round(x, 2)), concordance))
457+
average_concordance = Fraction(sum(concordance) / (len(concordance) or 1))
458+
concordance_str = ', '.join(map(lambda x: str(int(round(x * 100)) / 100), concordance))
472459
cut_point_location_scaled = max_concordance_index / (((len(concordance) or 1) - 1) or 1)
473460
logger.debug("Created overlap contigs %r at %s and %r at %s based on parts of %r and %r, with avg. concordance %s%%, cut point at %s%%, and full concordance [%s].",
474461
left_overlap_take.name, left_overlap.alignment, right_overlap_take.name, right_overlap_take.alignment,

micall/tests/test_contig_stitcher.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -978,18 +978,18 @@ def generate_random_string_pair(length):
978978

979979
@pytest.mark.parametrize(
980980
'left, right, expected',
981-
[("aaaaa", "aaaaa", [0.1] * 5),
982-
("abcdd", "abcdd", [0.1] * 5),
983-
("aaaaaaaa", "baaaaaab", [0.1, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.1]),
984-
("aaaaaaaa", "aaaaaaab", [0.13, 0.13, 0.13, 0.13, 0.13, 0.13, 0.13, 0.12]),
985-
("aaaaaaaa", "aaaaaaab", [0.13, 0.13, 0.13, 0.13, 0.13, 0.13, 0.13, 0.12]),
986-
("aaaaaaaa", "aaaaabbb", [0.1, 0.1, 0.1, 0.1, 0.1, 0.08, 0.08, 0.08]),
987-
("aaaaaaaa", "aaabbaaa", [0.12, 0.12, 0.12, 0.1, 0.1, 0.12, 0.12, 0.12]),
981+
[("aaaaa", "aaaaa", [0.6, 0.68, 0.7, 0.68, 0.6]),
982+
("abcdd", "abcdd", [0.6, 0.68, 0.7, 0.68, 0.6]),
983+
("aaaaaaaa", "baaaaaab", [0.3, 0.62, 0.71, 0.75, 0.75, 0.71, 0.62, 0.3]),
984+
("aaaaaaaa", "aaaaaaab", [0.64, 0.73, 0.79, 0.8, 0.79, 0.73, 0.64, 0.31]),
985+
("aaaaaaaa", "aaaaaaab", [0.64, 0.73, 0.79, 0.8, 0.79, 0.73, 0.64, 0.31]),
986+
("aaaaaaaa", "aaaaabbb", [0.6, 0.68, 0.7, 0.68, 0.6, 0.29, 0.19, 0.13]),
987+
("aaaaaaaa", "aaabbaaa", [0.56, 0.63, 0.62, 0.39, 0.39, 0.62, 0.63, 0.56]),
988988
("aaaaa", "bbbbb", [0] * 5),
989989
]
990990
)
991991
def test_concordance_simple(left, right, expected):
992-
result = [round(x, 2) for x in calculate_concordance(left, right)]
992+
result = [round(float(x), 2) for x in calculate_concordance(left, right)]
993993
assert result == expected
994994

995995

@@ -1000,7 +1000,7 @@ def test_concordance_simple(left, right, expected):
10001000
("a" * 128, "a" * 64 + "ba" * 32, 32),
10011001
("a" * 128, "a" * 54 + "b" * 20 + "a" * 54, 28), # two peaks
10021002
("a" * 128, "a" * 63 + "b" * 2 + "a" * 63, 32), # two peaks
1003-
("a" * 1280, "b" * 640 + "a" * 640, 640 + 30), # the window is too small to account for all of the context
1003+
("a" * 1280, "b" * 640 + "a" * 640, round(1280 * 3 / 4)),
10041004
]
10051005
)
10061006
def test_concordance_simple_index(left, right, expected):
@@ -1019,13 +1019,6 @@ def generate_test_cases(num_cases):
10191019
concordance_cases = generate_test_cases(num_cases=100)
10201020

10211021

1022-
@pytest.mark.parametrize('left, right', concordance_cases)
1023-
def test_concordance_output_is_list_of_floats(left, right):
1024-
result = calculate_concordance(left, right)
1025-
assert isinstance(result, list), "Result should be a list"
1026-
assert all(isinstance(n, float) for n in result), "All items in result should be float"
1027-
1028-
10291022
@pytest.mark.parametrize('left, right', concordance_cases)
10301023
def test_concordance_output_range(left, right):
10311024
result = calculate_concordance(left, right)

micall/utils/contig_stitcher_events.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Union, List
22
from dataclasses import dataclass
3+
from fractions import Fraction
34

45

56
@dataclass
@@ -89,10 +90,10 @@ class Overlap:
8990
right_remainder: 'AlignedContig'
9091
left_take: 'AlignedContig'
9192
right_take: 'AlignedContig'
92-
concordance: List[float]
93-
average: float
93+
concordance: List[Fraction]
94+
average: Fraction
9495
cut_point: int
95-
cut_point_scaled: float
96+
cut_point_scaled: Fraction
9697

9798

9899
@dataclass

0 commit comments

Comments
 (0)