From 285f0be77f2c30ca9270c2a6c01939fe850bb0c8 Mon Sep 17 00:00:00 2001 From: ivan-aksamentov Date: Tue, 10 Dec 2024 20:19:12 +0100 Subject: [PATCH] perf: speedup parsimony sequence reconstruction and fasta writer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This removes unnecessary heap allocations and copies in the reconstruction step (reconstruction from copressed representation - `ancestral_reconstruction_fitch()`) and in the fasta writer code. Speedup: ~16% on mpox-500 dataset. Command: ``` $ cargo -q build --release --target-dir=/workdir/.build/docker --bin=treetime $ hyperfine --warmup 1 --show-output '/workdir/.build/docker/release/treetime ancestral --method-anc=parsimony --dense=false --tree=data/mpox/clade-ii/500/tree.nwk --outdir=tmp/smoke-tests/ancestral/marginal/mpox/clade-ii/500 data/mpox/clade-ii/500/aln.fasta.xz' ``` Before (branch: rust, commit f676392): ``` Time (mean ± σ): 1.866 s ± 0.113 s [User: 4.547 s, System: 0.576 s] Range (min … max): 1.768 s … 2.041 s 10 runs ``` After: ``` Time (mean ± σ): 1.582 s ± 0.014 s [User: 4.358 s, System: 0.505 s] Range (min … max): 1.558 s … 1.599 s 10 runs ``` According to profiler, on the same data, `ancestral_reconstruction_fitch()` is now 6% of the running time, down from 16%. --- .../treetime/src/commands/ancestral/fitch.rs | 69 ++++++++----------- .../src/commands/ancestral/run_ancestral.rs | 10 +-- packages/treetime/src/graph/graph.rs | 6 +- packages/treetime/src/io/fasta.rs | 32 ++++++--- 4 files changed, 59 insertions(+), 58 deletions(-) diff --git a/packages/treetime/src/commands/ancestral/fitch.rs b/packages/treetime/src/commands/ancestral/fitch.rs index f15c9ed6..46ef7f59 100644 --- a/packages/treetime/src/commands/ancestral/fitch.rs +++ b/packages/treetime/src/commands/ancestral/fitch.rs @@ -462,7 +462,7 @@ pub fn ancestral_reconstruction_fitch( graph: &SparseGraph, include_leaves: bool, partitions: &[PartitionParsimony], - mut visitor: impl FnMut(&SparseNode, Vec), + mut visitor: impl FnMut(&SparseNode, &[char]), ) -> Result<(), Report> { let n_partitions = partitions.len(); @@ -471,54 +471,41 @@ pub fn ancestral_reconstruction_fitch( return; } - let seq = (0..n_partitions) - .flat_map(|si| { - let PartitionParsimony { alphabet, .. } = &partitions[si]; + for si in 0..n_partitions { + let PartitionParsimony { alphabet, .. } = &partitions[si]; - let mut seq = if node.is_root { - node.payload.sparse_partitions[si].seq.sequence.clone() - } else { - let (parent, edge) = node.get_exactly_one_parent().unwrap(); - let parent = &parent.read_arc().sparse_partitions[si]; - let edge = &edge.read_arc().sparse_partitions[si]; + if !node.is_root { + let (parent, edge) = node.get_exactly_one_parent().unwrap(); + let parent_seq = &parent.read_arc().sparse_partitions[si].seq.sequence; + let edge_part = &edge.read_arc().sparse_partitions[si]; - let mut seq = parent.seq.sequence.clone(); + node.payload.sparse_partitions[si].seq.sequence = parent_seq.clone(); - // Implant mutations - for sub in &edge.subs { - seq[sub.pos] = sub.qry; - } + for sub in &edge_part.subs { + node.payload.sparse_partitions[si].seq.sequence[sub.pos] = sub.qry; + } - // Implant indels - for indel in &edge.indels { - if indel.deletion { - seq[indel.range.0..indel.range.1].fill(alphabet.gap()); - } else { - seq[indel.range.0..indel.range.1].copy_from_slice(&indel.seq); - } + for indel in &edge_part.indels { + if indel.deletion { + node.payload.sparse_partitions[si].seq.sequence[indel.range.0..indel.range.1].fill(alphabet.gap()); + } else { + node.payload.sparse_partitions[si].seq.sequence[indel.range.0..indel.range.1].copy_from_slice(&indel.seq); } - - seq - }; - - let node = &mut node.payload.sparse_partitions[si].seq; - - // At the node itself, mask whatever is unknown in the node. - for r in &node.unknown { - seq[r.0..r.1].fill(alphabet.unknown()); } + } - for (&pos, &states) in &node.fitch.variable { - seq[pos] = alphabet.set_to_char(states); - } + let seq = &mut node.payload.sparse_partitions[si].seq; - node.sequence = seq.clone(); + for r in &mut seq.unknown { + seq.sequence[r.0..r.1].fill(alphabet.unknown()); + } - seq - }) - .collect(); + for (pos, states) in &mut seq.fitch.variable { + seq.sequence[*pos] = alphabet.set_to_char(*states); + } - visitor(&node.payload, seq); + visitor(&node.payload, &node.payload.sparse_partitions[si].seq.sequence); + } }); Ok(()) @@ -610,7 +597,7 @@ mod tests { let mut actual = BTreeMap::new(); ancestral_reconstruction_fitch(&graph, false, &partitions, |node, seq| { - actual.insert(node.name.clone(), vec_to_string(seq)); + actual.insert(node.name.clone(), vec_to_string(seq.to_owned())); })?; assert_eq!( @@ -670,7 +657,7 @@ mod tests { let mut actual = BTreeMap::new(); ancestral_reconstruction_fitch(&graph, true, &partitions, |node, seq| { - actual.insert(node.name.clone(), vec_to_string(seq)); + actual.insert(node.name.clone(), vec_to_string(seq.to_owned())); })?; assert_eq!( diff --git a/packages/treetime/src/commands/ancestral/run_ancestral.rs b/packages/treetime/src/commands/ancestral/run_ancestral.rs index 39384ced..3a5d7d42 100644 --- a/packages/treetime/src/commands/ancestral/run_ancestral.rs +++ b/packages/treetime/src/commands/ancestral/run_ancestral.rs @@ -17,7 +17,6 @@ use crate::representation::infer_dense::infer_dense; use crate::representation::partitions_likelihood::{PartitionLikelihood, PartitionLikelihoodWithAln}; use crate::representation::partitions_parsimony::PartitionParsimonyWithAln; use crate::utils::random::get_random_number_generator; -use crate::utils::string::vec_to_string; use eyre::Report; use itertools::Itertools; use serde::Serialize; @@ -89,8 +88,7 @@ pub fn run_ancestral_reconstruction(ancestral_args: &TreetimeAncestralArgs) -> R ancestral_reconstruction_fitch(&graph, *reconstruct_tip_states, &partitions, |node, seq| { let name = node.name.as_deref().unwrap_or(""); let desc = &node.desc; - // TODO: avoid converting vec to string, write vec chars directly - output_fasta.write(name, desc, vec_to_string(seq)).unwrap(); + output_fasta.write(name, desc, seq).unwrap(); })?; write_graph(outdir, &graph)?; @@ -112,8 +110,7 @@ pub fn run_ancestral_reconstruction(ancestral_args: &TreetimeAncestralArgs) -> R ancestral_reconstruction_marginal_sparse(&graph, *reconstruct_tip_states, &partitions, |node, seq| { let name = node.name.as_deref().unwrap_or(""); let desc = &node.desc; - // TODO: avoid converting vec to string, write vec chars directly - output_fasta.write(name, desc, vec_to_string(seq)).unwrap(); + output_fasta.write(name, desc, &seq).unwrap(); })?; write_graph(outdir, &graph)?; @@ -127,8 +124,7 @@ pub fn run_ancestral_reconstruction(ancestral_args: &TreetimeAncestralArgs) -> R ancestral_reconstruction_marginal_dense(&graph, *reconstruct_tip_states, |node, seq| { let name = node.name.as_deref().unwrap_or(""); let desc = &node.desc; - // TODO: avoid converting vec to string, write vec chars directly - output_fasta.write(name, desc, vec_to_string(seq)).unwrap(); + output_fasta.write(name, desc, &seq).unwrap(); })?; write_graph(outdir, &graph)?; diff --git a/packages/treetime/src/graph/graph.rs b/packages/treetime/src/graph/graph.rs index 9d2c3123..4698e8f1 100644 --- a/packages/treetime/src/graph/graph.rs +++ b/packages/treetime/src/graph/graph.rs @@ -87,8 +87,10 @@ where } } - pub fn get_exactly_one_parent(&self) -> Result<&NodeEdgePayloadPair, Report> { - get_exactly_one(&self.parents).wrap_err("Nodes with multiple parents are not yet supported") + pub fn get_exactly_one_parent(&self) -> Result, Report> { + get_exactly_one(&self.parents) + .cloned() + .wrap_err("Nodes with multiple parents are not yet supported") } } diff --git a/packages/treetime/src/io/fasta.rs b/packages/treetime/src/io/fasta.rs index 352b8a9f..20604ca9 100644 --- a/packages/treetime/src/io/fasta.rs +++ b/packages/treetime/src/io/fasta.rs @@ -253,12 +253,7 @@ impl FastaWriter { Ok(Self::new(create_file_or_stdout(filepath)?)) } - pub fn write( - &mut self, - seq_name: impl AsRef, - desc: &Option, - seq: impl AsRef, - ) -> Result<(), Report> { + pub fn write(&mut self, seq_name: impl AsRef, desc: &Option, seq: &[char]) -> Result<(), Report> { self.writer.write_all(b">")?; self.writer.write_all(seq_name.as_ref().as_bytes())?; @@ -268,7 +263,7 @@ impl FastaWriter { } self.writer.write_all(b"\n")?; - self.writer.write_all(seq.as_ref().as_bytes())?; + write_chars_chunked(&mut self.writer, seq)?; self.writer.write_all(b"\n")?; Ok(()) } @@ -279,11 +274,32 @@ impl FastaWriter { } } +fn write_chars_chunked(writer: &mut W, chars: &[char]) -> Result<(), Report> +where + W: std::io::Write, +{ + const CHUNK_SIZE: usize = 1024; + let mut buffer = [0_u8; CHUNK_SIZE]; + let mut buffer_index = 0; + for c in chars { + if buffer_index >= CHUNK_SIZE { + writer.write_all(&buffer)?; + buffer_index = 0; + } + buffer[buffer_index] = *c as u8; + buffer_index += 1; + } + if buffer_index > 0 { + writer.write_all(&buffer[..buffer_index])?; + } + Ok(()) +} + pub fn write_one_fasta( filepath: impl AsRef, seq_name: impl AsRef, desc: &Option, - seq: impl AsRef, + seq: &[char], ) -> Result<(), Report> { let mut writer = FastaWriter::from_path(&filepath)?; writer.write(seq_name, desc, seq)