Skip to content

Commit 2eeaa25

Browse files
authored
Merge pull request #144 from marbl/overhaul-coordinates
Overhaul coordinates
2 parents 1b5b48b + 2f81014 commit 2eeaa25

File tree

5 files changed

+135
-168
lines changed

5 files changed

+135
-168
lines changed

extend.py

+53-38
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,15 @@
22
from Bio.Seq import Seq
33
from Bio.SeqIO import SeqRecord
44
from Bio.Align import MultipleSeqAlignment
5-
from glob import glob
6-
import tempfile
5+
import logging
76
from pathlib import Path
87
import re
9-
import subprocess
108
from collections import namedtuple, defaultdict, Counter
11-
import os
12-
from Bio.Align import substitution_matrices
13-
from itertools import product, combinations
9+
import bisect
1410
import numpy as np
15-
from Bio.AlignIO.MafIO import MafWriter, MafIterator
16-
from Bio.AlignIO.MauveIO import MauveWriter, MauveIterator
17-
from logger import logger
18-
import time
11+
from tqdm import tqdm
12+
from logger import logger, TqdmToLogger, MIN_TQDM_INTERVAL
13+
import spoa
1914
#%%
2015

2116

@@ -46,29 +41,38 @@ def parse_xmfa_header(xmfa_file):
4641
return index_to_gid, gid_to_index
4742

4843

49-
def index_input_sequences(xmfa_file, input_dir):
44+
def index_input_sequences(xmfa_file, file_list):
    """Index the FASTA records of every genome named in an XMFA header.

    Every ``##SequenceFile`` line of ``xmfa_file`` is matched, by file stem,
    to a path in ``file_list``; the matched FASTA file is then parsed and its
    records are indexed by id and by position.

    Args:
        xmfa_file: Path of the XMFA file whose ``##SequenceFile`` lines are
            scanned.
        file_list: Paths of the input FASTA files. When two paths share the
            same stem, the later one replaces the earlier one.

    Returns:
        A 3-tuple of dicts, each keyed by genome id (the FASTA file stem):
        ``gid_to_records`` maps record id -> record,
        ``gid_to_cid_to_index`` maps record id -> 1-based position in the
        file, and ``gid_to_index_to_cid`` maps 1-based position -> record id.

    Raises:
        KeyError: If a ``##SequenceFile`` entry has no file with the same
            stem in ``file_list``.
    """
    basename_to_path = {Path(f).stem: f for f in file_list}
    gid_to_records = {}
    gid_to_cid_to_index = {}
    gid_to_index_to_cid = {}
    with open(xmfa_file) as parsnp_fd:
        for raw_line in parsnp_fd:
            line = raw_line.strip()
            if not line.startswith("##SequenceFile"):
                continue
            basename = Path(line.split(' ')[1]).stem
            try:
                p = Path(basename_to_path[basename])
            except KeyError:
                # Same exception type as before, but say which header entry
                # could not be resolved instead of surfacing a bare stem.
                raise KeyError(
                    f"No input file matches '##SequenceFile' entry "
                    f"'{basename}' in {xmfa_file}"
                ) from None
            gid = p.stem
            gid_to_records[gid] = {}
            gid_to_cid_to_index[gid] = {}
            gid_to_index_to_cid[gid] = {}
            # Contig indices are 1-based, in file order.
            for idx, rec in enumerate(SeqIO.parse(str(p), "fasta"), start=1):
                gid_to_records[gid][rec.id] = rec
                gid_to_cid_to_index[gid][rec.id] = idx
                gid_to_index_to_cid[gid][idx] = rec.id
    return gid_to_records, gid_to_cid_to_index, gid_to_index_to_cid
6066

6167

62-
63-
def xmfa_to_covered(xmfa_file, index_to_gid, gid_to_cid_to_index):
68+
def xmfa_to_covered(xmfa_file, index_to_gid, gid_to_index_to_cid):
6469
seqid_parser = re.compile(r'^cluster(\d+) s(\d+):p(\d+)/.*')
6570
idpair_to_segments = defaultdict(list)
66-
idpair_to_tree = defaultdict(IntervalTree)
6771
cluster_to_named_segments = defaultdict(list)
68-
for aln in tqdm(AlignIO.parse(xmfa_file, "mauve")):
72+
for aln in AlignIO.parse(xmfa_file, "mauve"):
6973
for seq in aln:
7074
# Skip reference for now...
71-
aln_len = seq.annotations["end"] - seq.annotations["start"] + 1
75+
aln_len = seq.annotations["end"] - seq.annotations["start"]
7276
cluster_idx, contig_idx, startpos = [int(x) for x in seqid_parser.match(seq.id).groups()]
7377

7478
gid = index_to_gid[seq.name]
@@ -78,29 +82,29 @@ def xmfa_to_covered(xmfa_file, index_to_gid, gid_to_cid_to_index):
7882
else:
7983
endpos = startpos + aln_len
8084

81-
idp = IdPair(gid, gid_to_cid_to_index[gid][contig_idx])
85+
idp = IdPair(gid, gid_to_index_to_cid[gid][contig_idx])
8286
seg = Segment(idp, startpos, startpos + aln_len, seq.annotations["strand"])
8387
idpair_to_segments[idp].append(seg)
84-
idpair_to_tree[idp].addi(seg.start, seg.stop)
8588
cluster_to_named_segments[cluster_idx].append(seg)
8689

8790
for idp in idpair_to_segments:
8891
idpair_to_segments[idp] = sorted(idpair_to_segments[idp])
89-
idpair_to_tree[idp].merge_overlaps()
90-
return idpair_to_segments, idpair_to_tree, cluster_to_named_segments
92+
return idpair_to_segments, cluster_to_named_segments
9193

9294

9395
def run_msa(downstream_segs_to_align, gid_to_records):
9496
keep_extending = True
9597
iteration = 0
96-
seq_len_desc = stats.describe([seg.stop - seg.start for seg in downstream_segs_to_align])
97-
longest_seq = seq_len_desc.minmax[1]
98-
if sum(
99-
seq_len_desc.mean*(1 - length_window) <= (seg.stop - seg.start) <= seq_len_desc.mean*(1 + length_window) for seg in downstream_segs_to_align) > len(downstream_segs_to_align)*window_prop:
100-
base_length = int(seq_len_desc.mean*(1 + length_window))
101-
else:
102-
base_length = BASE_LENGTH
98+
seq_lens = [seg.stop - seg.start for seg in downstream_segs_to_align]
99+
longest_seq = max(seq_lens)
100+
mean_seq_len = np.mean(seq_lens)
101+
# if sum(
102+
# mean_seq_len*(1 - length_window) <= (seg.stop - seg.start) <= mean_seq_len*(1 + length_window) for seg in downstream_segs_to_align) > len(downstream_segs_to_align)*window_prop:
103+
# base_length = int(mean_seq_len*(1 + length_window))
104+
# else:
105+
# base_length = BASE_LENGTH
103106

107+
base_length = BASE_LENGTH
104108
while keep_extending:
105109
seqs_to_align = ["A" + (str(
106110
gid_to_records[seg.idp.gid][seg.idp.cid].seq[seg.start:seg.stop] if seg.strand == 1
@@ -131,11 +135,15 @@ def run_msa(downstream_segs_to_align, gid_to_records):
131135
return aligned_msa_seqs
132136

133137

134-
def extend_clusters(xmfa_file, index_to_gid, gid_to_cid_to_index, idpair_to_segments, idpair_to_tree, cluster_to_named_segments, gid_to_records):
138+
def extend_clusters(xmfa_file, gid_to_index, gid_to_cid_to_index, idpair_to_segments, cluster_to_named_segments, gid_to_records):
135139
ret_lcbs = []
136140
seqid_parser = re.compile(r'^cluster(\d+) s(\d+):p(\d+)/.*')
137141

138-
for aln_idx, aln in enumerate(tqdm(AlignIO.parse(xmfa_file, "mauve"), total=len(cluster_to_named_segments))):
142+
for aln_idx, aln in enumerate(tqdm(
143+
AlignIO.parse(xmfa_file, "mauve"),
144+
total=len(cluster_to_named_segments),
145+
file=TqdmToLogger(logger, level=logging.INFO),
146+
mininterval=MIN_TQDM_INTERVAL)):
139147
# validate_lcb(aln, gid_to_records, parsnp_header=True)
140148
seq = aln[0]
141149
cluster_idx, contig_idx, startpos = [int(x) for x in seqid_parser.match(seq.id).groups()]
@@ -167,29 +175,36 @@ def extend_clusters(xmfa_file, index_to_gid, gid_to_cid_to_index, idpair_to_segm
167175
new_lcb = MultipleSeqAlignment([])
168176
# Assumes alignments are always in the same order
169177
new_bp = []
170-
for seq_idx, (covered_seg, uncovered_seg, aln_str) in enumerate(zip(segs, downstream_segs_to_align, aligned_msa_seqs)):
178+
for seg_idx, (covered_seg, uncovered_seg, aln_str) in enumerate(zip(segs, downstream_segs_to_align, aligned_msa_seqs)):
171179
# Update segment in idpair_to_segments
180+
if len(aln_str) < MIN_LEN:
181+
continue
172182
new_bp_covered = len(aln_str) - aln_str.count("-")
173183
# print(f"Extending {covered_seg} by {new_bp_covered}")
174184
new_bp.append(new_bp_covered)
175185
new_seq = aln_str
176186
if covered_seg.strand == 1:
177187
new_seg = Segment(covered_seg.idp, uncovered_seg.start, uncovered_seg.start + new_bp_covered, covered_seg.strand)
188+
if new_bp_covered > 0:
189+
segs[seg_idx] = Segment(covered_seg.idp, covered_seg.start, new_seg.stop, covered_seg.strand)
178190
else:
179191
aln_str = Seq(aln_str).reverse_complement()
180192
new_seg = Segment(covered_seg.idp, covered_seg.start - new_bp_covered, covered_seg.start, covered_seg.strand)
193+
if new_bp_covered > 0:
194+
segs[seg_idx] = Segment(covered_seg.idp, new_seg.start, covered_seg.stop, covered_seg.strand)
181195

182196
new_record = SeqRecord(
183197
seq=new_seq,
184-
id=f"{covered_seg.idp.gid}#{covered_seg.idp.cid}",
198+
id=f"cluster{cluster_idx} s{gid_to_cid_to_index[covered_seg.idp.gid][covered_seg.idp.cid]}:p{new_seg.start if new_seg.strand == 1 else new_seg.stop}",
199+
name=gid_to_index[covered_seg.idp.gid],
185200
annotations={"start": new_seg.start, "end": new_seg.stop, "strand": new_seg.strand}
186201
)
202+
187203
# if covered_seg.strand == 1:
188204
new_lcb.append(new_record)
189-
if new_bp_covered > 0:
190-
idpair_to_tree[covered_seg.idp].addi(new_seg.start, new_seg.stop)
191205

192-
ret_lcbs.append(new_lcb)
206+
if len(new_lcb) > 0:
207+
ret_lcbs.append(new_lcb)
193208
return ret_lcbs
194209

195210

logger.py

+23
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import io
23
############################################# Logging ##############################################
34
BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8)
45
#These are the sequences needed to get colored output
@@ -14,6 +15,28 @@
1415
COLOR_SEQ = "\033[1;%dm"
1516
BOLD_SEQ = "\033[1m"
1617

18+
MIN_TQDM_INTERVAL=30
19+
20+
21+
# Logging redirect copied from https://stackoverflow.com/questions/14897756/python-progress-bar-through-logging-module
22+
class TqdmToLogger(io.StringIO):
    """
    Output stream for TQDM which will output to logger module instead of
    the StdOut.

    Meant to be passed as tqdm's ``file`` argument. ``write`` keeps only the
    most recently written text, stripped of surrounding carriage returns,
    newlines, tabs and spaces; ``flush`` emits that text as one log record.

    Args:
        logger: Object exposing ``log(level, msg)``.
        level: Level handed to ``logger.log``; ``logging.INFO`` when omitted.
    """
    logger = None
    level = None
    buf = ''

    def __init__(self, logger, level=None):
        super(TqdmToLogger, self).__init__()
        self.logger = logger
        self.level = level or logging.INFO

    def write(self, buf):
        """Store the stripped text and return the number of characters taken."""
        self.buf = buf.strip('\r\n\t ')
        # Text streams report how many characters they accepted.
        return len(buf)

    def flush(self):
        """Log the stored text, unless it is empty."""
        # A whitespace-only write (e.g. a bare newline) strips to '';
        # skip it rather than emit a blank log record.
        if self.buf:
            self.logger.log(self.level, self.buf)
38+
39+
1740
def formatter_message(message, use_color = True):
1841
if use_color:
1942
message = message.replace("$RESET", RESET_SEQ).replace("$BOLD", BOLD_SEQ)

parsnp

+42-54
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,16 @@
55
66
'''
77

8-
import os, sys, string, getopt, random,subprocess, time,operator, math, datetime,numpy #pysam
8+
import os, sys, string, random, subprocess, time, operator, math, datetime, numpy #pysam
99
from collections import defaultdict
10-
import csv
1110
import shutil
1211
import shlex
1312
from tempfile import TemporaryDirectory
1413
import re
1514
import logging
16-
from logger import logger
17-
import multiprocessing
15+
from logger import logger, TqdmToLogger, MIN_TQDM_INTERVAL
1816
import argparse
1917
import signal
20-
import inspect
2118
from multiprocessing import Pool
2219
from Bio import SeqIO
2320
from glob import glob
@@ -27,7 +24,7 @@ from pathlib import Path
2724
import extend as ext
2825
from tqdm import tqdm
2926

30-
__version__ = "2.0.1"
27+
__version__ = "2.0.2"
3128
reroot_tree = True #use --midpoint-reroot
3229
random_seeded = random.Random(42)
3330

@@ -149,7 +146,7 @@ def run_phipack(query,seqlen,workingdir):
149146
currdir = os.getcwd()
150147
os.chdir(workingdir)
151148
command = "Profile -o -v -n %d -w 100 -m 100 -f %s > %s.out"%(seqlen,query,query)
152-
run_command(command,1, prepend_time=True)
149+
run_command(command, 1)
153150
os.chdir(currdir)
154151

155152
def run_fasttree(query,workingdir,recombination_sites):
@@ -685,15 +682,20 @@ def create_output_directory(output_dir):
685682

686683
if os.path.exists(output_dir):
687684
logger.warning(f"Output directory {output_dir} exists, all results will be overwritten")
688-
shutil.rmtree(output_dir)
685+
if os.path.exists(output_dir + "/partition"):
686+
shutil.rmtree(output_dir + "/partition/")
687+
if os.path.exists(output_dir + "/config/"):
688+
shutil.rmtree(output_dir + "/config/")
689+
if os.path.exists(output_dir + "/log/"):
690+
shutil.rmtree(output_dir + "/log/")
689691
elif output_dir == "[P_CURRDATE_CURRTIME]":
690692
today = datetime.datetime.now()
691693
timestamp = "P_" + today.isoformat().replace("-", "_").replace(".", "").replace(":", "").replace("T", "_")
692694
output_dir = os.getcwd() + os.sep + timestamp
693-
os.makedirs(output_dir)
694-
os.makedirs(os.path.join(output_dir, "tmp"))
695-
os.makedirs(os.path.join(output_dir, "log"))
696-
os.makedirs(os.path.join(output_dir, "config"))
695+
os.makedirs(output_dir, exist_ok=True)
696+
os.makedirs(os.path.join(output_dir, "tmp"), exist_ok=True)
697+
os.makedirs(os.path.join(output_dir, "log"), exist_ok=True)
698+
os.makedirs(os.path.join(output_dir, "config"), exist_ok=True)
697699
return output_dir
698700

699701

@@ -1645,7 +1647,11 @@ SETTINGS:
16451647
logger.info("Running partitions...")
16461648
good_chunks = set(chunk_labels)
16471649
with Pool(args.threads) as pool:
1648-
return_codes = tqdm(pool.imap(run_parsnp_aligner, chunk_output_dirs, chunksize=1), total=len(chunk_output_dirs))
1650+
return_codes = tqdm(
1651+
pool.imap(run_parsnp_aligner, chunk_output_dirs, chunksize=1),
1652+
total=len(chunk_output_dirs),
1653+
file=TqdmToLogger(logger,level=logging.INFO),
1654+
mininterval=MIN_TQDM_INTERVAL)
16491655
for cl, rc in zip(chunk_labels, return_codes):
16501656
if rc != 0:
16511657
logger.error(f"Partition {cl} failed...")
@@ -1666,51 +1672,33 @@ SETTINGS:
16661672
partition.merge_xmfas(partition_output_dir, chunk_labels, xmfa_out_f, num_clusters, args.threads)
16671673

16681674

1669-
1670-
run_lcb_trees = 0
1675+
parsnp_output = f"{outputDir}/parsnp.xmfa"
16711676

16721677
# This is the stuff for LCB extension:
1673-
annotation_dict = {}
1674-
#TODO always using xtrafast?
1675-
parsnp_output = f"{outputDir}/parsnp.xmfa"
16761678
if args.extend_lcbs:
1677-
xmfa_file = f"{outputDir}/parsnp.xmfa"
1678-
with TemporaryDirectory() as temp_directory:
1679-
original_maf_file = f"{outputDir}/parsnp-original.maf"
1680-
extended_xmfa_file = f"{outputDir}/parsnp-extended.xmfa"
1681-
fname_contigid_to_length, fname_contigidx_to_header, fname_to_seqrecord = ext.get_sequence_data(
1682-
ref,
1683-
finalfiles,
1684-
index_files=False)
1685-
fname_to_contigid_to_coords, fname_header_to_gcontigidx = ext.xmfa_to_maf(
1686-
xmfa_file,
1687-
original_maf_file,
1688-
fname_contigidx_to_header,
1689-
fname_contigid_to_length)
1690-
packed_write_result = ext.write_intercluster_regions(finalfiles + [ref], temp_directory, fname_to_contigid_to_coords)
1691-
fname_contigid_to_cluster_dir_to_length, fname_contigid_to_cluster_dir_to_adjacent_cluster = packed_write_result
1692-
cluster_files = glob(f"{temp_directory}/*.fasta")
1693-
clusterdir_expand, clusterdir_len = ext.get_new_extensions(
1694-
cluster_files,
1695-
args.match_score,
1696-
args.mismatch_penalty,
1697-
args.gap_penalty)
1698-
ext.write_extended_xmfa(
1699-
original_maf_file,
1700-
extended_xmfa_file,
1701-
temp_directory,
1702-
clusterdir_expand,
1703-
clusterdir_len,
1704-
fname_contigid_to_cluster_dir_to_length,
1705-
fname_contigid_to_cluster_dir_to_adjacent_cluster,
1706-
fname_header_to_gcontigidx,
1707-
fname_contigid_to_length,
1708-
args.extend_ani_cutoff,
1709-
args.extend_indel_cutoff,
1710-
threads)
1711-
parsnp_output = extended_xmfa_file
1712-
os.remove(original_maf_file)
1679+
logger.warning("The LCB extension module is experimental. Runtime may be significantly increased and extended alignments may not be as high quality as the original core-genome. Extensions off of existing LCBs are in a separate xmfa file.")
1680+
import partition
1681+
import extend as ext
1682+
1683+
orig_parsnp_xmfa = parsnp_output
1684+
extended_parsnp_xmfa = orig_parsnp_xmfa + ".extended"
1685+
1686+
# Index input fasta files and original xmfa
1687+
index_to_gid, gid_to_index = ext.parse_xmfa_header(orig_parsnp_xmfa)
1688+
gid_to_records, gid_to_cid_to_index, gid_to_index_to_cid = ext.index_input_sequences(orig_parsnp_xmfa, finalfiles + [ref])
1689+
1690+
# Get covered regions of xmfa file
1691+
idpair_to_segments, cluster_to_named_segments = ext.xmfa_to_covered(orig_parsnp_xmfa, index_to_gid, gid_to_index_to_cid)
1692+
1693+
# Extend clusters
1694+
logger.info(f"Extending LCBs with SPOA...")
1695+
new_lcbs = ext.extend_clusters(orig_parsnp_xmfa, gid_to_index, gid_to_cid_to_index, idpair_to_segments, cluster_to_named_segments, gid_to_records)
17131696

1697+
# Write output
1698+
partition.copy_header(orig_parsnp_xmfa, extended_parsnp_xmfa)
1699+
with open(extended_parsnp_xmfa, 'a') as out_f:
1700+
for lcb in new_lcbs:
1701+
partition.write_aln_to_xmfa(lcb, out_f)
17141702

17151703
#add genbank here, if present
17161704
if len(genbank_ref) != 0:

0 commit comments

Comments
 (0)