perf: speedup parsimony sequence reconstruction and fasta writer #295

Merged
packages/treetime/src/commands/ancestral/fitch.rs (28 additions, 41 deletions)

@@ -462,7 +462,7 @@ pub fn ancestral_reconstruction_fitch(
   graph: &SparseGraph,
   include_leaves: bool,
   partitions: &[PartitionParsimony],
-  mut visitor: impl FnMut(&SparseNode, Vec<char>),
+  mut visitor: impl FnMut(&SparseNode, &[char]),
 ) -> Result<(), Report> {
   let n_partitions = partitions.len();

@@ -471,54 +471,41 @@
       return;
     }

-    let seq = (0..n_partitions)
-      .flat_map(|si| {
-        let PartitionParsimony { alphabet, .. } = &partitions[si];
+    for si in 0..n_partitions {
+      let PartitionParsimony { alphabet, .. } = &partitions[si];

-        let mut seq = if node.is_root {
-          node.payload.sparse_partitions[si].seq.sequence.clone()
-        } else {
-          let (parent, edge) = node.get_exactly_one_parent().unwrap();
-          let parent = &parent.read_arc().sparse_partitions[si];
-          let edge = &edge.read_arc().sparse_partitions[si];
-
-          let mut seq = parent.seq.sequence.clone();
-
-          // Implant mutations
-          for sub in &edge.subs {
-            seq[sub.pos] = sub.qry;
-          }
-
-          // Implant indels
-          for indel in &edge.indels {
-            if indel.deletion {
-              seq[indel.range.0..indel.range.1].fill(alphabet.gap());
-            } else {
-              seq[indel.range.0..indel.range.1].copy_from_slice(&indel.seq);
-            }
-          }
-
-          seq
-        };
-
-        let node = &mut node.payload.sparse_partitions[si].seq;
-
-        // At the node itself, mask whatever is unknown in the node.
-        for r in &node.unknown {
-          seq[r.0..r.1].fill(alphabet.unknown());
-        }
-
-        for (&pos, &states) in &node.fitch.variable {
-          seq[pos] = alphabet.set_to_char(states);
-        }
-
-        node.sequence = seq.clone();
-
-        seq
-      })
-      .collect();
-
-    visitor(&node.payload, seq);
+      if !node.is_root {
+        let (parent, edge) = node.get_exactly_one_parent().unwrap();
+        let parent_seq = &parent.read_arc().sparse_partitions[si].seq.sequence;
+        let edge_part = &edge.read_arc().sparse_partitions[si];
+
+        node.payload.sparse_partitions[si].seq.sequence = parent_seq.clone();
+
+        // Implant mutations
+        for sub in &edge_part.subs {
+          node.payload.sparse_partitions[si].seq.sequence[sub.pos] = sub.qry;
+        }
+
+        // Implant indels
+        for indel in &edge_part.indels {
+          if indel.deletion {
+            node.payload.sparse_partitions[si].seq.sequence[indel.range.0..indel.range.1].fill(alphabet.gap());
+          } else {
+            node.payload.sparse_partitions[si].seq.sequence[indel.range.0..indel.range.1].copy_from_slice(&indel.seq);
+          }
+        }
+      }
+
+      let seq = &mut node.payload.sparse_partitions[si].seq;
+
+      // At the node itself, mask whatever is unknown in the node.
+      for r in &mut seq.unknown {
+        seq.sequence[r.0..r.1].fill(alphabet.unknown());
+      }
+
+      for (pos, states) in &mut seq.fitch.variable {
+        seq.sequence[*pos] = alphabet.set_to_char(*states);
+      }
+
+      visitor(&node.payload, &node.payload.sparse_partitions[si].seq.sequence);
+    }
   });

   Ok(())
@@ -610,7 +597,7 @@ mod tests {

     let mut actual = BTreeMap::new();
     ancestral_reconstruction_fitch(&graph, false, &partitions, |node, seq| {
-      actual.insert(node.name.clone(), vec_to_string(seq));
+      actual.insert(node.name.clone(), vec_to_string(seq.to_owned()));
     })?;

     assert_eq!(
Expand Down Expand Up @@ -670,7 +657,7 @@ mod tests {

let mut actual = BTreeMap::new();
ancestral_reconstruction_fitch(&graph, true, &partitions, |node, seq| {
actual.insert(node.name.clone(), vec_to_string(seq));
actual.insert(node.name.clone(), vec_to_string(seq.to_owned()));
})?;

assert_eq!(
Expand Down
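
Note on the change above: the Fitch reconstruction now writes each partition's sequence directly into node.payload.sparse_partitions[si].seq.sequence and lends it to the visitor as a &[char] slice, instead of building a fresh Vec<char> per node via flat_map().collect(). Callers that still need an owned copy convert explicitly, as the updated tests do with seq.to_owned(). A minimal, self-contained sketch of that borrow-instead-of-allocate visitor pattern (hypothetical Node type, not treetime's):

// Sketch only: a stand-in Node type to illustrate lending a &[char] to the visitor.
struct Node {
  name: String,
  sequence: Vec<char>,
}

// The reconstruction mutates each node's buffer in place, then passes a borrowed
// slice to the callback, so no per-node Vec<char> allocation or move is needed.
fn visit_sequences(nodes: &mut [Node], mut visitor: impl FnMut(&str, &[char])) {
  for node in nodes.iter_mut() {
    // ...in-place reconstruction of node.sequence would happen here...
    visitor(&node.name, &node.sequence);
  }
}

fn main() {
  let mut nodes = vec![Node { name: "root".into(), sequence: "ACGT".chars().collect() }];
  // Callers needing ownership copy explicitly, as the updated tests do with seq.to_owned().
  visit_sequences(&mut nodes, |name, seq| {
    let owned: Vec<char> = seq.to_owned();
    println!(">{name}\n{}", owned.iter().collect::<String>());
  });
}
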
packages/treetime/src/commands/ancestral/run_ancestral.rs (3 additions, 7 deletions)

@@ -17,7 +17,6 @@ use crate::representation::infer_dense::infer_dense;
 use crate::representation::partitions_likelihood::{PartitionLikelihood, PartitionLikelihoodWithAln};
 use crate::representation::partitions_parsimony::PartitionParsimonyWithAln;
 use crate::utils::random::get_random_number_generator;
-use crate::utils::string::vec_to_string;
 use eyre::Report;
 use itertools::Itertools;
 use serde::Serialize;
@@ -89,8 +88,7 @@ pub fn run_ancestral_reconstruction(ancestral_args: &TreetimeAncestralArgs) -> R
     ancestral_reconstruction_fitch(&graph, *reconstruct_tip_states, &partitions, |node, seq| {
       let name = node.name.as_deref().unwrap_or("");
       let desc = &node.desc;
-      // TODO: avoid converting vec to string, write vec chars directly
-      output_fasta.write(name, desc, vec_to_string(seq)).unwrap();
+      output_fasta.write(name, desc, seq).unwrap();
     })?;

     write_graph(outdir, &graph)?;
@@ -112,8 +110,7 @@ pub fn run_ancestral_reconstruction(ancestral_args: &TreetimeAncestralArgs) -> R
     ancestral_reconstruction_marginal_sparse(&graph, *reconstruct_tip_states, &partitions, |node, seq| {
       let name = node.name.as_deref().unwrap_or("");
       let desc = &node.desc;
-      // TODO: avoid converting vec to string, write vec chars directly
-      output_fasta.write(name, desc, vec_to_string(seq)).unwrap();
+      output_fasta.write(name, desc, &seq).unwrap();
     })?;

     write_graph(outdir, &graph)?;
@@ -127,8 +124,7 @@ pub fn run_ancestral_reconstruction(ancestral_args: &TreetimeAncestralArgs) -> R
     ancestral_reconstruction_marginal_dense(&graph, *reconstruct_tip_states, |node, seq| {
       let name = node.name.as_deref().unwrap_or("");
       let desc = &node.desc;
-      // TODO: avoid converting vec to string, write vec chars directly
-      output_fasta.write(name, desc, vec_to_string(seq)).unwrap();
+      output_fasta.write(name, desc, &seq).unwrap();
     })?;

     write_graph(outdir, &graph)?;

packages/treetime/src/graph/graph.rs (4 additions, 2 deletions)

@@ -87,8 +87,10 @@ where
     }
   }

-  pub fn get_exactly_one_parent(&self) -> Result<&NodeEdgePayloadPair<N, E>, Report> {
-    get_exactly_one(&self.parents).wrap_err("Nodes with multiple parents are not yet supported")
+  pub fn get_exactly_one_parent(&self) -> Result<NodeEdgePayloadPair<N, E>, Report> {
+    get_exactly_one(&self.parents)
+      .cloned()
+      .wrap_err("Nodes with multiple parents are not yet supported")
   }
 }
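
Note on the change above: get_exactly_one_parent now returns an owned (cloned) payload pair rather than a reference into the node's parent list, likely so callers such as the Fitch reconstruction can keep the parent handle while mutating the node's own payload. A rough, self-contained sketch of the get_exactly_one(...).cloned() pattern (the get_exactly_one below is a stand-in, not treetime's helper):

// Stand-in for the crate's helper: succeed only when the slice has exactly one element.
fn get_exactly_one<T>(items: &[T]) -> Result<&T, String> {
  match items {
    [only] => Ok(only),
    other => Err(format!("expected exactly one element, found {}", other.len())),
  }
}

// Result<&T, E>::cloned() turns Ok(&T) into Ok(T) when T: Clone, which is what
// the updated get_exactly_one_parent relies on to hand back an owned pair.
fn get_exactly_one_cloned<T: Clone>(items: &[T]) -> Result<T, String> {
  get_exactly_one(items).cloned()
}

fn main() {
  let parents = vec![(String::from("parent payload"), String::from("edge payload"))];
  let (parent, edge) = get_exactly_one_cloned(&parents).unwrap();
  println!("{parent} / {edge}");
}
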
packages/treetime/src/io/fasta.rs (24 additions, 8 deletions)

@@ -253,12 +253,7 @@ impl FastaWriter {
     Ok(Self::new(create_file_or_stdout(filepath)?))
   }

-  pub fn write(
-    &mut self,
-    seq_name: impl AsRef<str>,
-    desc: &Option<String>,
-    seq: impl AsRef<str>,
-  ) -> Result<(), Report> {
+  pub fn write(&mut self, seq_name: impl AsRef<str>, desc: &Option<String>, seq: &[char]) -> Result<(), Report> {
     self.writer.write_all(b">")?;
     self.writer.write_all(seq_name.as_ref().as_bytes())?;

@@ -268,7 +263,7 @@
     }

     self.writer.write_all(b"\n")?;
-    self.writer.write_all(seq.as_ref().as_bytes())?;
+    write_chars_chunked(&mut self.writer, seq)?;
     self.writer.write_all(b"\n")?;
     Ok(())
   }
@@ -279,11 +274,32 @@
   }
 }

+fn write_chars_chunked<W>(writer: &mut W, chars: &[char]) -> Result<(), Report>
+where
+  W: std::io::Write,
+{
+  const CHUNK_SIZE: usize = 1024;
+  let mut buffer = [0_u8; CHUNK_SIZE];
+  let mut buffer_index = 0;
+  for c in chars {
+    if buffer_index >= CHUNK_SIZE {
+      writer.write_all(&buffer)?;
+      buffer_index = 0;
+    }
+    buffer[buffer_index] = *c as u8;
+    buffer_index += 1;
+  }
+  if buffer_index > 0 {
+    writer.write_all(&buffer[..buffer_index])?;
+  }
+  Ok(())
+}
+
 pub fn write_one_fasta(
   filepath: impl AsRef<Path>,
   seq_name: impl AsRef<str>,
   desc: &Option<String>,
-  seq: impl AsRef<str>,
+  seq: &[char],
 ) -> Result<(), Report> {
   let mut writer = FastaWriter::from_path(&filepath)?;
   writer.write(seq_name, desc, seq)
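
Note on write_chars_chunked: it converts each char to a byte with *c as u8, which is lossless only for code points below 256 (in practice, the ASCII characters of nucleotide and amino-acid alphabets plus gap and unknown symbols), and it batches output into a fixed 1024-byte stack buffer, avoiding both a per-character write and the full-length intermediate String that vec_to_string used to allocate. A small, self-contained sketch of the same buffering idea, checked against a Vec<u8> sink (standalone function, not the crate's; assumes ASCII-only input):

use std::io::Write;

// Copy chars into a fixed-size byte buffer and flush whenever it fills,
// instead of issuing one write per character. Assumes every char is ASCII.
fn write_chars_chunked<W: Write>(writer: &mut W, chars: &[char]) -> std::io::Result<()> {
  const CHUNK_SIZE: usize = 1024;
  let mut buffer = [0_u8; CHUNK_SIZE];
  let mut len = 0;
  for &c in chars {
    if len == CHUNK_SIZE {
      writer.write_all(&buffer)?;
      len = 0;
    }
    buffer[len] = c as u8;
    len += 1;
  }
  if len > 0 {
    writer.write_all(&buffer[..len])?;
  }
  Ok(())
}

fn main() -> std::io::Result<()> {
  // Vec<u8> implements std::io::Write, so it serves as an in-memory sink for a quick check.
  let seq: Vec<char> = "ACGTACGT-N".chars().collect();
  let mut out = Vec::new();
  write_chars_chunked(&mut out, &seq)?;
  assert_eq!(out, b"ACGTACGT-N".to_vec());
  println!("{}", String::from_utf8(out).unwrap());
  Ok(())
}
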