diff --git a/packages/treetime/src/alphabet/alphabet.rs b/packages/treetime/src/alphabet/alphabet.rs index 7d894def..bdc77d20 100644 --- a/packages/treetime/src/alphabet/alphabet.rs +++ b/packages/treetime/src/alphabet/alphabet.rs @@ -1,17 +1,17 @@ use crate::io::json::{json_write_str, JsonPretty}; -use crate::make_error; +use crate::representation::bitset128::BitSet128; +use crate::representation::state_set::StateSet; use crate::utils::string::quote; +use crate::{make_error, stateset}; use clap::ArgEnum; use color_eyre::{Section, SectionExt}; use eyre::{Report, WrapErr}; -use indexmap::{indexmap, IndexMap, IndexSet}; +use indexmap::{indexmap, IndexMap}; use itertools::{chain, Itertools}; -use maplit::btreeset; use ndarray::{stack, Array1, Array2, Axis}; use serde::{Deserialize, Serialize}; use smart_default::SmartDefault; use std::borrow::Borrow; -use std::collections::BTreeSet; use std::fmt::Display; use std::iter::once; use strum_macros::Display; @@ -32,13 +32,21 @@ pub type ProfileMap = IndexMap>; #[derive(Clone, Debug, Serialize, Deserialize)] pub struct Alphabet { - all: IndexSet, - canonical: IndexSet, + all: StateSet, + canonical: StateSet, ambiguous: IndexMap>, + ambiguous_keys: StateSet, + determined: StateSet, + undetermined: StateSet, unknown: char, gap: char, treat_gap_as_unknown: bool, profile_map: ProfileMap, + + #[serde(skip)] + char_to_index: Vec>, + #[serde(skip)] + index_to_char: Vec, } impl Default for Alphabet { @@ -95,26 +103,36 @@ impl Alphabet { treat_gap_as_unknown, } = cfg; - let canonical: IndexSet = canonical.iter().copied().collect(); + let canonical: StateSet = canonical.iter().copied().collect(); if canonical.is_empty() { return make_error!("When creating alphabet: canonical set of characters is empty. This is not allowed."); } let ambiguous: IndexMap> = ambiguous.to_owned(); + let ambiguous_keys = ambiguous.keys().collect(); - let all: IndexSet = chain!( - canonical.iter().copied(), - ambiguous.keys().copied(), - [*unknown, *gap].into_iter(), - ) - .collect(); + let undetermined = stateset! {*unknown, *gap}; + let determined = StateSet::from_union([canonical, ambiguous_keys]); + let all = StateSet::from_union([canonical, ambiguous_keys, undetermined]); let profile_map = cfg.create_profile_map()?; + let mut char_to_index = vec![None; 128]; + let mut index_to_char = Vec::with_capacity(canonical.len()); + for (i, c) in canonical.iter().enumerate() { + char_to_index[c as usize] = Some(i); + index_to_char.push(c); + } + Ok(Self { all, + char_to_index, + index_to_char, canonical, ambiguous, + ambiguous_keys, + determined, + undetermined, unknown: *unknown, gap: *gap, treat_gap_as_unknown: *treat_gap_as_unknown, @@ -123,7 +141,7 @@ impl Alphabet { } /// Resolve possible ambiguity of the given character to the set of canonical chars - pub fn disambiguate(&self, c: char) -> BTreeSet { + pub fn disambiguate(&self, c: char) -> StateSet { // If unknown then could be any canonical (e.g. N => { A, C, G, T }) if self.is_unknown(c) { self.canonical().collect() @@ -138,42 +156,43 @@ impl Alphabet { } } - /// Map a set of canonical characters back to the smallest set of ambiguous characters - /// - /// NOTE: Reverse of `disambiguate()` - pub fn ambiguate(&self, chars: &BTreeSet) -> BTreeSet { - let mut chars: IndexSet = chars.iter().flat_map(|c| self.disambiguate(*c).into_iter()).collect(); - assert!(chars.iter().all(|c| self.canonical.contains(c))); - - if self.canonical.is_subset(&chars) { - return once(self.unknown).collect(); - } - - // Attempt to cover the set using the least number of ambiguous characters - let ambiguous = self - .ambiguous - .iter() - .map(|(amb_char, amb_set)| { - let set: IndexSet<_> = amb_set.iter().copied().collect(); - (amb_char, set) - }) - .sorted_by_key(|(_, set)| -(set.len() as isize)); - - let mut result = btreeset! {}; - for (amb_char, amb_set) in ambiguous { - if chars.is_superset(&amb_set) { - result.insert(*amb_char); - chars = &chars - &amb_set; - if chars.is_empty() { - break; - } - } - } - - result.extend(&chars); - - result - } + // /// Map a set of canonical characters back to the smallest set of ambiguous characters + // /// + // /// NOTE: Reverse of `disambiguate()` + // pub fn ambiguate(&self, mut chars: StateSet) -> StateSet { + // // assert!( + // // chars.iter().all(|c| self.canonical.contains(c)), + // // "Expected only canonical characters ({}), but found: {chars}", + // // self.canonical + // // ); + // + // if self.canonical.is_subset(&chars) { + // return once(self.unknown).collect(); + // } + // + // // Attempt to cover the set using the least number of ambiguous characters + // let ambiguous = self + // .ambiguous + // .iter() + // .map(|(amb_char, amb_set)| { + // let set: StateSet = amb_set.iter().copied().collect(); + // (amb_char, set) + // }) + // .sorted_by_key(|(_, set)| -(set.len() as isize)); + // + // let mut result = stateset! {}; + // for (amb_char, amb_set) in ambiguous { + // if chars.is_superset(&amb_set) { + // result.insert(*amb_char); + // chars -= amb_set; + // if chars.is_empty() { + // break; + // } + // } + // } + // + // result + chars + // } #[inline] pub fn get_profile(&self, c: char) -> &Array1 { @@ -198,7 +217,7 @@ impl Alphabet { let mut profile = Array1::::zeros(self.n_canonical()); for c in chars { let chars = self.disambiguate(*c.borrow()); - for c in chars { + for c in chars.iter() { let index = self.index(c); profile[index] = 1.0; } @@ -216,16 +235,16 @@ impl Alphabet { .unwrap() } - pub fn sequence_to_indices<'a>(&'a self, chars: impl Iterator + 'a) -> impl Iterator + 'a { - chars.map(|c| self.index(c)) - } - - pub fn indices_to_sequence<'a>( - &'a self, - indices: impl Iterator + 'a, - ) -> impl Iterator + 'a { - indices.map(|i| self.char(i)) - } + // pub fn sequence_to_indices<'a>(&'a self, chars: impl Iterator + 'a) -> impl Iterator + 'a { + // chars.map(|c| self.index(c)) + // } + // + // pub fn indices_to_sequence<'a>( + // &'a self, + // indices: impl Iterator + 'a, + // ) -> impl Iterator + 'a { + // indices.map(|i| self.char(i)) + // } #[allow(single_use_lifetimes)] // TODO: remove when anonymous lifetimes in `impl Trait` are stabilized pub fn seq2prof<'a>(&self, chars: impl IntoIterator) -> Result, Report> { @@ -238,22 +257,22 @@ impl Alphabet { /// All existing characters (including 'unknown' and 'gap') pub fn chars(&self) -> impl Iterator + '_ { - self.all.iter().copied() + self.all.iter() } /// Get char by index (indexed in the same order as given by `.chars()`) pub fn char(&self, index: usize) -> char { - self.all[index] + self.index_to_char[index] } /// Get index of a character (indexed in the same order as given by `.chars()`) pub fn index(&self, c: char) -> usize { - self.all.get_index_of(&c).unwrap() + self.char_to_index[c as usize].unwrap() } /// Check if character is in alphabet (including 'unknown' and 'gap') pub fn contains(&self, c: char) -> bool { - self.all.contains(&c) + self.all.contains(c) } pub fn n_chars(&self) -> usize { @@ -262,12 +281,12 @@ impl Alphabet { /// Canonical (unambiguous) characters (e.g. 'A', 'C', 'G', 'T' in nuc alphabet) pub fn canonical(&self) -> impl Iterator + '_ { - self.canonical.iter().copied() + self.canonical.iter() } /// Check is character is canonical pub fn is_canonical(&self, c: char) -> bool { - self.canonical().contains(&c) + self.canonical.contains(c) } pub fn n_canonical(&self) -> usize { @@ -276,12 +295,12 @@ impl Alphabet { /// Ambiguous characters (e.g. 'R', 'S' etc. in nuc alphabet) pub fn ambiguous(&self) -> impl Iterator + '_ { - self.ambiguous.keys().copied() + self.ambiguous_keys.iter() } /// Check if character is ambiguous (e.g. 'R', 'S' etc. in nuc alphabet) pub fn is_ambiguous(&self, c: char) -> bool { - self.ambiguous().contains(&c) + self.ambiguous_keys.contains(c) } pub fn n_ambiguous(&self) -> usize { @@ -290,28 +309,28 @@ impl Alphabet { /// Determined characters: canonical or ambiguous pub fn determined(&self) -> impl Iterator + '_ { - chain!(self.canonical(), self.ambiguous()) + self.determined.iter() } pub fn is_determined(&self, c: char) -> bool { - self.determined().contains(&c) + self.determined.contains(c) } pub fn n_determined(&self) -> usize { - self.determined().count() + self.determined.len() } /// Undetermined characters: gap or unknown pub fn undetermined(&self) -> impl Iterator + '_ { - [self.gap(), self.unknown()].into_iter() + self.undetermined.iter() } pub fn is_undetermined(&self, c: char) -> bool { - self.undetermined().contains(&c) + self.undetermined.contains(c) } pub fn n_undetermined(&self) -> usize { - self.undetermined().count() + self.undetermined.len() } /// Get 'unknown' character @@ -435,36 +454,36 @@ impl AlphabetConfig { } } - let canonical: IndexSet<_> = canonical.iter().copied().collect(); - let ambiguous_keys: IndexSet<_> = ambiguous.keys().copied().collect(); - let ambiguous_set_map: IndexMap> = ambiguous + let canonical: StateSet = canonical.iter().copied().collect(); + let ambiguous_keys: StateSet = ambiguous.keys().copied().collect(); + let ambiguous_set_map: IndexMap = ambiguous .iter() .map(|(key, vals)| (*key, vals.iter().copied().collect())) .collect(); { - let canonical_inter_ambig: IndexSet<_> = canonical.intersection(&ambiguous_keys).copied().collect(); + let canonical_inter_ambig: StateSet = canonical.intersection(&ambiguous_keys); if !canonical_inter_ambig.is_empty() { - let msg = canonical_inter_ambig.into_iter().join(", "); + let msg = canonical_inter_ambig.iter().join(", "); return make_error!("Canonical and ambiguous sets must be disjoint, but these characters are shared: {msg}"); } } - if canonical.contains(gap) { + if canonical.contains(*gap) { let msg = canonical.iter().map(quote).join(", "); return make_error!("Canonical set contains 'gap' character: {msg}"); } - if canonical.contains(unknown) { + if canonical.contains(*unknown) { let msg = canonical.iter().map(quote).join(", "); return make_error!("Canonical set contains 'unknown' character: {msg}"); } - if ambiguous.keys().contains(&gap) { + if ambiguous_keys.contains(*gap) { let msg = ambiguous.keys().map(quote).join(", "); return make_error!("Ambiguous set contains 'gap' character: {msg}"); } - if ambiguous.keys().contains(&gap) { + if ambiguous_keys.contains(*gap) { let msg = ambiguous.keys().map(quote).join(", "); return make_error!("Ambiguous set contains 'unknown' character: {msg}"); } @@ -472,9 +491,10 @@ impl AlphabetConfig { { let ambig_gaps = ambiguous_set_map .iter() - .map(|(key, vals)| (key, vals.difference(&canonical).collect::>())) + .map(|(key, vals)| (key, vals.difference(&canonical))) .filter(|(key, extra)| !extra.is_empty()) .collect_vec(); + if !ambig_gaps.is_empty() { let msg = ambig_gaps .iter() @@ -493,88 +513,86 @@ mod tests { use super::*; use eyre::Report; use indoc::indoc; - use maplit::btreeset; - use ndarray::array; use pretty_assertions::assert_eq; - #[test] - fn test_alphabet_sequence_to_indices() -> Result<(), Report> { - let actual = Alphabet::new(AlphabetName::Nuc, false)? - .sequence_to_indices(array!['A', 'G', 'T', 'G', '-', 'G', 'N', 'G', 'C'].into_iter()) - .collect_vec(); - let expected = vec![0, 2, 3, 2, 15, 2, 14, 2, 1]; - assert_eq!(expected, actual); - Ok(()) - } - - #[test] - fn test_alphabet_indices_to_sequence() -> Result<(), Report> { - let actual = Alphabet::new(AlphabetName::Nuc, false)? - .indices_to_sequence(array![0, 2, 3, 2, 15, 2, 14, 2, 1].into_iter()) - .collect_vec(); - let expected = vec!['A', 'G', 'T', 'G', '-', 'G', 'N', 'G', 'C']; - assert_eq!(expected, actual); - Ok(()) - } + // #[test] + // fn test_alphabet_sequence_to_indices() -> Result<(), Report> { + // let actual = Alphabet::new(AlphabetName::Nuc, false)? + // .sequence_to_indices(array!['A', 'G', 'T', 'G', '-', 'G', 'N', 'G', 'C'].into_iter()) + // .collect_vec(); + // let expected = vec![0, 2, 3, 2, 15, 2, 14, 2, 1]; + // assert_eq!(expected, actual); + // Ok(()) + // } + + // #[test] + // fn test_alphabet_indices_to_sequence() -> Result<(), Report> { + // let actual = Alphabet::new(AlphabetName::Nuc, false)? + // .indices_to_sequence(array![0, 2, 3, 2, 15, 2, 14, 2, 1].into_iter()) + // .collect_vec(); + // let expected = vec!['A', 'G', 'T', 'G', '-', 'G', 'N', 'G', 'C']; + // assert_eq!(expected, actual); + // Ok(()) + // } #[test] fn test_disambiguate() -> Result<(), Report> { let alphabet = Alphabet::new(AlphabetName::Nuc, false)?; - assert_eq!(btreeset! {'A', 'G'}, alphabet.disambiguate('R')); - assert_eq!(btreeset! {'A', 'C', 'G', 'T'}, alphabet.disambiguate('N')); - assert_eq!(btreeset! {'C'}, alphabet.disambiguate('C')); - assert_eq!(btreeset! {alphabet.gap()}, alphabet.disambiguate(alphabet.gap())); + assert_eq!(stateset! {'A', 'G'}, alphabet.disambiguate('R')); + assert_eq!(stateset! {'A', 'C', 'G', 'T'}, alphabet.disambiguate('N')); + assert_eq!(stateset! {'C'}, alphabet.disambiguate('C')); + assert_eq!(stateset! {alphabet.gap()}, alphabet.disambiguate(alphabet.gap())); Ok(()) } - #[test] - fn test_ambiguate_empty_set() { - let alphabet = Alphabet::new(AlphabetName::Nuc, false).unwrap(); - let empty_set = btreeset! {}; - let result = alphabet.ambiguate(&empty_set); - assert!(result.is_empty()); - } - - #[test] - fn test_ambiguate_single_canonical_char() { - let alphabet = Alphabet::new(AlphabetName::Nuc, false).unwrap(); - let single_char_set = btreeset! {'A'}; - let result = alphabet.ambiguate(&single_char_set); - assert_eq!(result, btreeset! {'A'}); - } - - #[test] - fn test_ambiguate_all_canonical_chars() { - let alphabet = Alphabet::new(AlphabetName::Nuc, false).unwrap(); - let all_canonical = btreeset! {'A', 'C', 'G', 'T'}; - let result = alphabet.ambiguate(&all_canonical); - assert_eq!(result, btreeset! {alphabet.unknown()}); - } - - #[test] - fn test_ambiguate_single_ambiguous_char() { - let alphabet = Alphabet::new(AlphabetName::Nuc, false).unwrap(); - let ambiguous_set = btreeset! {'A', 'G'}; - let result = alphabet.ambiguate(&ambiguous_set); - assert_eq!(result, btreeset! {'R'}); - } - - #[test] - fn test_ambiguate_complex_case() { - let alphabet = Alphabet::new(AlphabetName::Nuc, false).unwrap(); - let complex_set = btreeset! {'A', 'C', 'T', 'R', 'Y'}; - let result = alphabet.ambiguate(&complex_set); - assert_eq!(result, btreeset! {alphabet.unknown()}); - } - - #[test] - fn test_ambiguate_mixture_of_ambiguous_and_unambiguous_chars_aa() { - let alphabet = Alphabet::new(AlphabetName::Aa, false).unwrap(); - // B is N or D, partial overlap with canonicals N, D, E, Q, and explicit canonicals K, R - let mixed_set = btreeset! {'N', 'B', 'E', 'Q', 'K', 'R'}; - let result = alphabet.ambiguate(&mixed_set); - assert_eq!(result, btreeset! {'B', 'Z', 'K', 'R'}); - } + // #[test] + // fn test_ambiguate_empty_set() { + // let alphabet = Alphabet::new(AlphabetName::Nuc, false).unwrap(); + // let empty_set = stateset! {}; + // let result = alphabet.ambiguate(empty_set); + // assert!(result.is_empty()); + // } + // + // #[test] + // fn test_ambiguate_single_canonical_char() { + // let alphabet = Alphabet::new(AlphabetName::Nuc, false).unwrap(); + // let single_char_set = stateset! {'A'}; + // let result = alphabet.ambiguate(single_char_set); + // assert_eq!(result, stateset! {'A'}); + // } + // + // #[test] + // fn test_ambiguate_all_canonical_chars() { + // let alphabet = Alphabet::new(AlphabetName::Nuc, false).unwrap(); + // let all_canonical = stateset! {'A', 'C', 'G', 'T'}; + // let result = alphabet.ambiguate(all_canonical); + // assert_eq!(result, stateset! {alphabet.unknown()}); + // } + // + // #[test] + // fn test_ambiguate_single_ambiguous_char() { + // let alphabet = Alphabet::new(AlphabetName::Nuc, false).unwrap(); + // let ambiguous_set = stateset! {'A', 'G'}; + // let result = alphabet.ambiguate(ambiguous_set); + // assert_eq!(result, stateset! {'R'}); + // } + // + // #[test] + // fn test_ambiguate_complex_case() { + // let alphabet = Alphabet::new(AlphabetName::Nuc, false).unwrap(); + // let complex_set = stateset! {'A', 'C', 'T', 'R', 'Y'}; + // let result = alphabet.ambiguate(complex_set); + // assert_eq!(result, stateset! {alphabet.unknown()}); + // } + // + // #[test] + // fn test_ambiguate_mixture_of_ambiguous_and_unambiguous_chars_aa() { + // let alphabet = Alphabet::new(AlphabetName::Aa, false).unwrap(); + // // B is N or D, partial overlap with canonicals N, D, E, Q, and explicit canonicals K, R + // let mixed_set = stateset! {'N', 'B', 'E', 'Q', 'K', 'R'}; + // let result = alphabet.ambiguate(mixed_set); + // assert_eq!(result, stateset! {'B', 'Z', 'K', 'R'}); + // } #[test] fn test_alphabet_nuc() -> Result<(), Report> { @@ -582,22 +600,22 @@ mod tests { let actual = json_write_str(&alphabet, JsonPretty(true))?; let expected = indoc! { /* language=json */ r#"{ "all": [ + "-", "A", + "B", "C", + "D", "G", - "T", - "R", - "Y", - "S", - "W", + "H", "K", "M", - "D", - "H", - "B", - "V", "N", - "-" + "R", + "S", + "T", + "V", + "W", + "Y" ], "canonical": [ "A", @@ -651,6 +669,38 @@ mod tests { "G" ] }, + "ambiguous_keys": [ + "B", + "D", + "H", + "K", + "M", + "R", + "S", + "V", + "W", + "Y" + ], + "determined": [ + "A", + "B", + "C", + "D", + "G", + "H", + "K", + "M", + "R", + "S", + "T", + "V", + "W", + "Y" + ], + "undetermined": [ + "-", + "N" + ], "unknown": "N", "gap": "-", "treat_gap_as_unknown": false, @@ -847,7 +897,10 @@ mod tests { let actual = json_write_str(&alphabet, JsonPretty(true))?; let expected = indoc! { /* language=json */ r#"{ "all": [ + "*", + "-", "A", + "B", "C", "D", "E", @@ -855,6 +908,7 @@ mod tests { "G", "H", "I", + "J", "K", "L", "M", @@ -866,15 +920,12 @@ mod tests { "T", "V", "W", - "Y", - "*", - "B", - "Z", - "J", "X", - "-" + "Y", + "Z" ], "canonical": [ + "*", "A", "C", "D", @@ -894,8 +945,7 @@ mod tests { "T", "V", "W", - "Y", - "*" + "Y" ], "ambiguous": { "B": [ @@ -911,6 +961,41 @@ mod tests { "I" ] }, + "ambiguous_keys": [ + "B", + "J", + "Z" + ], + "determined": [ + "*", + "A", + "B", + "C", + "D", + "E", + "F", + "G", + "H", + "I", + "J", + "K", + "L", + "M", + "N", + "P", + "Q", + "R", + "S", + "T", + "V", + "W", + "Y", + "Z" + ], + "undetermined": [ + "-", + "X" + ], "unknown": "X", "gap": "-", "treat_gap_as_unknown": false, diff --git a/packages/treetime/src/commands/ancestral/fitch.rs b/packages/treetime/src/commands/ancestral/fitch.rs index 3dd0d83b..eaba7e29 100644 --- a/packages/treetime/src/commands/ancestral/fitch.rs +++ b/packages/treetime/src/commands/ancestral/fitch.rs @@ -1,4 +1,5 @@ #![allow(dead_code)] + use crate::alphabet::alphabet::{FILL_CHAR, NON_CHAR, VARIABLE_CHAR}; use crate::graph::breadth_first::GraphTraversalContinuation; use crate::io::fasta::FastaRecord; @@ -14,10 +15,9 @@ use crate::utils::interval::range::range_contains; use crate::utils::interval::range_complement::range_complement; use crate::utils::interval::range_difference::range_difference; use crate::utils::interval::range_intersection::{range_intersection, range_intersection_iter}; -use crate::utils::manyzip::Manyzip; use crate::{make_error, make_internal_report, make_report}; use eyre::{Report, WrapErr}; -use itertools::{izip, Itertools}; +use itertools::Itertools; use maplit::btreemap; use ndarray::AssignElem; @@ -135,7 +135,7 @@ fn fitch_backwards(graph: &SparseGraph, sparse_partitions: &[PartitionParsimony] return None; // this position does not have character state information } let state = match child.fitch.variable.get(&pos) { - Some(var_pos) => var_pos.clone(), + Some(var_pos) => *var_pos, None => StateSet::from_char(child.sequence[pos]), }; Some(state) @@ -166,26 +166,21 @@ fn fitch_backwards(graph: &SparseGraph, sparse_partitions: &[PartitionParsimony] } // Process all positions where the children are fixed or completely unknown in some children. - - // Gather state sets for each position across child sequences - // TODO(perf): avoid copying and allocations - let child_state_sets = Manyzip(children.iter().map(|(c, e)| c.sequence.iter().copied()).collect_vec()); - - // Zip these states with node sequence - let state_zip = izip!(sequence.iter_mut(), child_state_sets.into_iter()); - for (pos, (nuc, child_states)) in state_zip.enumerate() { - if *nuc != FILL_CHAR { + for (pos, parent_state) in sequence.iter_mut().enumerate() { + if *parent_state != FILL_CHAR { continue; } - let determined_states = child_states - .into_iter() - .filter(|&c| alphabet.is_canonical(c)) - .unique() - .collect_vec(); + let mut determined_states = Vec::with_capacity(alphabet.n_chars()); + for &(child, _) in &children { + let state = child.sequence[pos]; + if !determined_states.contains(&state) && alphabet.is_canonical(state) { + determined_states.push(state); + } + } // Find the state of the current node at this position - *nuc = match determined_states.as_slice() { + *parent_state = match determined_states.as_slice() { [state] => { // All children have the same state, that will be the state of the current node *state @@ -197,7 +192,7 @@ fn fitch_backwards(graph: &SparseGraph, sparse_partitions: &[PartitionParsimony] states => { // Child states differ. This is variable state. // Save child states and postpone the decision until forward pass. - let dis = StateSet::from_chars(states); + let dis = states.iter().collect(); seq_dis.variable.insert(pos, dis); VARIABLE_CHAR } @@ -513,8 +508,9 @@ pub fn ancestral_reconstruction_fitch( seq[r.0..r.1].fill(alphabet.unknown()); } - for (pos, states) in &mut node.fitch.variable { - seq[*pos] = alphabet.ambiguate(&states.inner()).first().copied().unwrap(); + for (&pos, &states) in &node.fitch.variable { + seq[pos] = states.first().unwrap(); + // seq[pos] = alphabet.ambiguate(states).first().unwrap(); } node.sequence = seq.clone(); @@ -981,12 +977,10 @@ mod tests { "sequence": "TCGGCCGTGTRTTG--", "fitch": { "variable": { - "10": { - "data": [ - "A", - "G" - ] - } + "10": [ + "A", + "G" + ] }, "variable_indel": {}, "composition": { diff --git a/packages/treetime/src/representation/bitset128.rs b/packages/treetime/src/representation/bitset128.rs new file mode 100644 index 00000000..c528adda --- /dev/null +++ b/packages/treetime/src/representation/bitset128.rs @@ -0,0 +1,554 @@ +use itertools::Itertools; +use serde::{Deserialize, Serialize}; +use std::borrow::Borrow; +use std::ops::{Add, AddAssign, BitAndAssign, BitOrAssign, BitXorAssign, Sub, SubAssign}; + +#[allow(variant_size_differences)] +#[derive(Clone, Debug)] +pub enum Bitset128Status { + Empty, + Unambiguous(char), + Ambiguous(BitSet128), +} + +#[must_use] +#[derive(Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Default)] +pub struct BitSet128 { + bits: u128, +} + +impl BitSet128 { + pub fn new() -> Self { + Self { bits: 0 } + } + + pub fn from_char(c: char) -> Self { + Self { bits: 1 << (c as u32) } + } + + pub fn from_slice(chars: &[char]) -> Self { + chars.iter().copied().collect() + } + + pub fn is_empty(&self) -> bool { + self.bits == 0 + } + + pub fn len(&self) -> usize { + self.bits.count_ones() as usize + } + + pub fn clear(&mut self) { + self.bits = 0; + } + + pub fn contains(&self, c: char) -> bool { + (self.bits & (1 << (c as u32))) != 0 + } + + pub fn insert(&mut self, c: char) { + let mask = 1 << (c as u32); + self.bits |= mask; + } + + pub fn remove(&mut self, c: char) { + let mask = 1 << (c as u32); + self.bits &= !mask; + } + + pub fn union(&self, other: &Self) -> Self { + Self { + bits: self.bits | other.bits, + } + } + + pub fn intersection(&self, other: &Self) -> Self { + Self { + bits: self.bits & other.bits, + } + } + + pub fn difference(&self, other: &Self) -> Self { + Self { + bits: self.bits & !other.bits, + } + } + + pub fn symmetric_difference(set1: &Self, set2: &Self) -> Self { + Self { + bits: set1.bits ^ set2.bits, + } + } + + pub fn from_union(sets: I) -> Self + where + I: IntoIterator, + I::Item: Borrow, + { + let bits = sets.into_iter().fold(0, |acc, set| acc | set.borrow().bits); + Self { bits } + } + + pub fn from_intersection(sets: I) -> Self + where + I: IntoIterator, + I::Item: Borrow, + { + sets + .into_iter() + .map(|set| set.borrow().bits) + .reduce(|acc, set| acc & set) + .map_or_else(Self::new, |bits| Self { bits }) + } + + pub fn is_disjoint(&self, other: &Self) -> bool { + (self.bits & other.bits) == 0 + } + + pub fn is_subset(&self, other: &Self) -> bool { + (self.bits & other.bits) == self.bits + } + + pub fn is_superset(&self, other: &Self) -> bool { + other.is_subset(self) + } + + pub fn iter(&self) -> impl Iterator + '_ { + (0..128) + .filter(|&i| (self.bits & (1 << i)) != 0) + .map(|i| char::from_u32(i).unwrap()) + } + + pub fn chars(&self) -> impl Iterator + '_ { + self.iter() + } + + pub fn get(&self) -> Bitset128Status { + match self.bits.count_ones() { + 0 => Bitset128Status::Empty, + 1 => Bitset128Status::Unambiguous(self.get_one()), + _ => Bitset128Status::Ambiguous(*self), + } + } + + pub fn first(&self) -> Option { + (!self.is_empty()).then_some(char::from_u32(self.bits.trailing_zeros()).unwrap()) + } + + pub fn last(&self) -> Option { + (!self.is_empty()).then_some(char::from_u32(127 - self.bits.leading_zeros()).unwrap()) + } + + pub fn get_one_maybe(&self) -> Option { + self.first() + } + + pub fn get_one(&self) -> char { + self.get_one_maybe().expect("BitSet128 is empty") + } + + pub fn get_one_exactly(&self) -> char { + assert_eq!(self.len(), 1, "expected exactly one element"); + self.get_one() + } + + pub fn to_vec(&self) -> Vec { + self.iter().collect() + } + + pub fn from_vec(chars: Vec) -> Self { + Self::from_iter(chars) + } +} + +impl Add for BitSet128 { + type Output = Self; + + fn add(self, other: Self) -> Self::Output { + self.union(&other) + } +} + +impl Sub for BitSet128 { + type Output = Self; + + fn sub(self, other: Self) -> Self::Output { + self.difference(&other) + } +} + +impl AddAssign for BitSet128 { + #[allow(clippy::suspicious_op_assign_impl)] + fn add_assign(&mut self, other: Self) { + self.bits |= other.bits; + } +} + +impl SubAssign for BitSet128 { + fn sub_assign(&mut self, other: Self) { + self.bits &= !other.bits; + } +} + +impl BitAndAssign for BitSet128 { + fn bitand_assign(&mut self, other: Self) { + self.bits &= other.bits; + } +} + +impl BitOrAssign for BitSet128 { + fn bitor_assign(&mut self, other: Self) { + self.bits |= other.bits; + } +} + +impl BitXorAssign for BitSet128 { + fn bitxor_assign(&mut self, other: Self) { + self.bits ^= other.bits; + } +} + +impl> Extend for BitSet128 { + fn extend(&mut self, iter: I) + where + I: IntoIterator, + { + for c in iter { + self.insert(*c.borrow()); + } + } +} + +impl> FromIterator for BitSet128 { + fn from_iter(iter: I) -> Self + where + I: IntoIterator, + { + let bits = iter.into_iter().fold(0, |acc, c| acc | (1 << (*c.borrow() as u32))); + Self { bits } + } +} + +impl std::fmt::Display for BitSet128 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let chars: String = (0..128) + .filter(|&i| (self.bits & (1 << i)) != 0) + .map(|i| char::from_u32(i).unwrap()) + .join(", "); + write!(f, "{{{chars}}}") + } +} + +impl Serialize for BitSet128 { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let chars: Vec = self.to_vec(); + chars.serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for BitSet128 { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let chars: Vec = Vec::deserialize(deserializer)?; + Ok(Self::from_vec(chars)) + } +} + +impl std::fmt::Debug for BitSet128 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{self}") + } +} + +#[macro_export] +macro_rules! bitset128 { + ($($char:expr),* $(,)?) => { + { + let chars = [$($char),*]; + BitSet128::from_slice(&chars) + } + }; +} + +#[cfg(test)] +mod tests { + use super::*; + use pretty_assertions::{assert_eq, assert_ne}; + use std::hash::{DefaultHasher, Hash, Hasher}; + + #[test] + fn test_bitset128_new() { + let actual = BitSet128::new(); + let expected = bitset128! {}; + assert_eq!(actual, expected); + } + + #[test] + fn test_bitset128_from_iter() { + let actual = BitSet128::from_iter(['a', 'b', 'c']); + let expected = bitset128! {'a', 'b', 'c'}; + assert_eq!(actual, expected); + } + + #[test] + fn test_bitset128_from_slice() { + let actual = BitSet128::from_slice(&['x', 'y', 'z']); + let expected = bitset128! {'x', 'y', 'z'}; + assert_eq!(actual, expected); + } + + #[test] + fn test_bitset128_is_empty() { + let a = BitSet128::new(); + assert!(a.is_empty()); + let a = bitset128! {'a'}; + assert!(!a.is_empty()); + } + + #[test] + fn test_bitset128_len() { + let actual = bitset128! {'a', 'b', 'c'}; + let expected_len = 3; + assert_eq!(actual.len(), expected_len); + + let actual = BitSet128::new(); + let expected_len = 0; + assert_eq!(actual.len(), expected_len); + } + + #[test] + fn test_bitset128_clear() { + let mut actual = bitset128! {'a', 'b', 'c'}; + actual.clear(); + let expected = BitSet128::new(); + assert_eq!(actual, expected); + } + + #[test] + fn test_bitset128_insert() { + let mut actual = bitset128! {}; + actual.insert('a'); + actual.insert('a'); + actual.insert('b'); + let expected = bitset128! {'a', 'b'}; + assert_eq!(actual, expected); + } + + #[test] + fn test_bitset128_remove() { + let mut actual = bitset128! {'a', 'b', 'c'}; + actual.remove('b'); + actual.remove('b'); + let expected = bitset128! {'a', 'c'}; + assert_eq!(actual, expected); + } + + #[test] + fn test_bitset128_union() { + let a = bitset128! {'a', 'b', 'y'}; + let b = bitset128! {'b', 'z', 'x'}; + let actual = a.union(&b); + let expected = bitset128! {'a', 'b', 'x', 'y', 'z'}; + assert_eq!(actual, expected); + } + + #[test] + fn test_bitset128_union_with_empty() { + let a = bitset128! {'a', 'b', 'c'}; + let b = bitset128! {}; + let actual = a.union(&b); + let expected = bitset128! {'a', 'b', 'c'}; + assert_eq!(actual, expected); + } + + #[test] + fn test_bitset128_intersection() { + let a = bitset128! {'a', 'b', 'c'}; + let b = bitset128! {'b', 'c', 'x'}; + let actual = a.intersection(&b); + let expected = bitset128! {'b', 'c'}; + assert_eq!(actual, expected); + } + + #[test] + fn test_bitset128_intersection_with_empty() { + let a = bitset128! {'a', 'b', 'c'}; + let b = bitset128! {}; + let actual = a.intersection(&b); + let expected = bitset128! {}; + assert_eq!(actual, expected); + } + + #[test] + fn test_bitset128_from_union() { + let a = bitset128! {'a', 'b', 'x', 'z'}; + let b = bitset128! {'y', 'x', 'a', 'z'}; + let c = bitset128! {'p', 'q', 'y', 'a'}; + let actual = BitSet128::from_union([a, b, c]); + let expected = bitset128! {'a', 'b', 'p', 'q', 'x', 'y', 'z'}; + assert_eq!(actual, expected); + } + + #[test] + fn test_bitset128_from_union_with_empty() { + let a = bitset128! {'a', 'b', 'x'}; + let b = bitset128! {}; + let c = bitset128! {'p', 'q', 'y'}; + let actual = BitSet128::from_union([a, b, c]); + let expected = bitset128! {'a', 'b', 'p', 'q', 'x', 'y'}; + assert_eq!(actual, expected); + } + + #[test] + fn test_bitset128_from_union_of_all_empty() { + let a = bitset128! {}; + let b = bitset128! {}; + let c = bitset128! {}; + let actual = BitSet128::from_union([a, b, c]); + let expected = bitset128! {}; + assert_eq!(actual, expected); + } + + #[test] + fn test_bitset128_from_intersection() { + let a = bitset128! {'a', 'b', 'x', 'z'}; + let b = bitset128! {'y', 'x', 'a', 'z'}; + let c = bitset128! {'x', 'q', 'y', 'a'}; + let actual = BitSet128::from_intersection([a, b, c]); + let expected = bitset128! {'a', 'x'}; + assert_eq!(actual, expected); + } + + #[test] + fn test_bitset128_from_intersection_with_empty() { + let a = bitset128! {'a', 'b', 'x'}; + let b = bitset128! {}; + let c = bitset128! {'p', 'q', 'y'}; + let actual = BitSet128::from_intersection([a, b, c]); + let expected = bitset128! {}; + assert_eq!(actual, expected); + } + + #[test] + fn test_bitset128_from_intersection_of_all_empty() { + let a = bitset128! {}; + let b = bitset128! {}; + let c = bitset128! {}; + let actual = BitSet128::from_intersection([a, b, c]); + let expected = bitset128! {}; + assert_eq!(actual, expected); + } + + #[test] + fn test_bitset128_difference() { + let a = bitset128! {'a', 'b', 'c'}; + let b = bitset128! {'b', 'c', 'x'}; + let actual = a.difference(&b); + let expected = bitset128! {'a'}; + assert_eq!(actual, expected); + } + + #[test] + fn test_bitset128_is_disjoint() { + let a = bitset128! {'a', 'b', 'c'}; + let b = bitset128! {'x', 'y', 'z'}; + assert!(a.is_disjoint(&b)); + } + + #[test] + fn test_bitset128_is_subset() { + let a = bitset128! {'a', 'b'}; + let b = bitset128! {'a', 'b', 'c'}; + assert!(a.is_subset(&b)); + } + + #[test] + fn test_bitset128_is_superset() { + let a = bitset128! {'a', 'b', 'c'}; + let b = bitset128! {'a', 'b'}; + assert!(a.is_superset(&b)); + } + + #[test] + fn test_bitset128_display() { + let a = bitset128! {'a', 'b', 'c'}; + let actual = a.to_string(); + let expected = "{a, b, c}"; + assert_eq!(actual, expected); + } + + #[test] + fn test_bitset128_debug() { + let a = bitset128! {'a', 'b', 'c'}; + let actual = format!("{a:?}"); + let expected = "{a, b, c}"; + assert_eq!(actual, expected); + } + + #[test] + fn test_bitset128_eq() { + let a = bitset128! {'a', 'b', 'c'}; + let b = bitset128! {'a', 'b', 'c'}; + let c = bitset128! {'x', 'y', 'z'}; + assert_eq!(a, b); + assert_ne!(a, c); + } + + #[test] + fn test_bitset128_hash() { + fn calculate_hash(t: &T) -> u64 { + let mut hasher = DefaultHasher::new(); + t.hash(&mut hasher); + hasher.finish() + } + + let a = bitset128! {'a', 'b', 'c'}; + let b = bitset128! {'a', 'b', 'c'}; + let c = bitset128! {'x', 'y', 'z'}; + assert_eq!(calculate_hash(&a), calculate_hash(&b)); + assert_ne!(calculate_hash(&a), calculate_hash(&c)); + } + + #[test] + fn test_bitset128_get_empty() { + let set = BitSet128::new(); + assert!(matches!(set.get(), Bitset128Status::Empty)); + } + + #[test] + fn test_bitset128_get_unambiguous() { + let set = BitSet128::from_char('A'); + assert!(matches!(set.get(), Bitset128Status::Unambiguous('A'))); + } + + #[test] + fn test_bitset128_get_ambiguous() { + let set = BitSet128::from_iter(vec!['A', 'C']); + assert!(matches!(set.get(), Bitset128Status::Ambiguous(_))); + if let Bitset128Status::Ambiguous(actual) = set.get() { + let expected = bitset128! {'A', 'C'}; + assert_eq!(actual, expected); + } + } + + #[test] + fn test_bitset128_get_one() { + let set = BitSet128::from_iter(['T', 'A']); + let actual = set.get_one(); + let expected = 'A'; + assert_eq!(actual, expected); + } + + #[test] + fn test_bitset128_get_one_exactly() { + let set = BitSet128::from_iter(['T']); + let actual = set.get_one(); + let expected = 'T'; + assert_eq!(actual, expected); + } +} diff --git a/packages/treetime/src/representation/graph_sparse.rs b/packages/treetime/src/representation/graph_sparse.rs index 76bc2d79..f6a4d15a 100644 --- a/packages/treetime/src/representation/graph_sparse.rs +++ b/packages/treetime/src/representation/graph_sparse.rs @@ -94,7 +94,7 @@ impl SparseSeqNode { .iter() .enumerate() .filter(|(_, &c)| alphabet.is_ambiguous(c)) - .map(|(pos, &c)| (pos, StateSet::from_chars(alphabet.disambiguate(c)))) + .map(|(pos, &c)| (pos, alphabet.disambiguate(c))) .collect(); let seq_dis = ParsimonySeqDis { diff --git a/packages/treetime/src/representation/mod.rs b/packages/treetime/src/representation/mod.rs index aefa033c..58e3eb20 100644 --- a/packages/treetime/src/representation/mod.rs +++ b/packages/treetime/src/representation/mod.rs @@ -4,3 +4,4 @@ pub mod infer_dense; pub mod partitions_likelihood; pub mod partitions_parsimony; pub mod state_set; +pub mod bitset128; diff --git a/packages/treetime/src/representation/state_set.rs b/packages/treetime/src/representation/state_set.rs index cfb7d9a2..27378928 100644 --- a/packages/treetime/src/representation/state_set.rs +++ b/packages/treetime/src/representation/state_set.rs @@ -1,221 +1,12 @@ -use serde::{Deserialize, Serialize}; -use std::borrow::Borrow; -use std::collections::btree_set::Iter; -use std::collections::BTreeSet; +pub use crate::representation::bitset128::BitSet128; +pub use crate::representation::bitset128::Bitset128Status; -#[derive(Clone, Debug)] -pub enum StateSetStatus<'a> { - Empty, - Unambiguous(char), - Ambiguous(Iter<'a, char>), -} - -#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] -pub struct StateSet { - data: BTreeSet, -} - -impl StateSet { - /// Create an empty set - pub fn new() -> Self { - Self { data: BTreeSet::new() } - } - - /// Create a set containing a given chars - pub fn from_chars(chars: I) -> Self - where - I: IntoIterator, - T: Borrow, - { - Self { - data: chars.into_iter().map(|c| *c.borrow()).collect(), - } - } - - /// Create a set containing a single given char - pub fn from_char(c: char) -> Self { - Self::from_chars([c]) - } - - /// Create a set from intersection of given sets - pub fn from_intersection(sets: &[StateSet]) -> StateSet { - let intersection = sets - .iter() - .map(|set| &set.data) - .fold(None, |acc: Option>, set| { - if let Some(a) = acc { - Some(a.intersection(set).copied().collect()) - } else { - Some(set.clone()) - } - }) - .unwrap_or_else(BTreeSet::new); - - StateSet { data: intersection } - } - - /// Create a set from union of given sets - pub fn from_union(sets: &[StateSet]) -> StateSet { - let data = sets - .iter() - .flat_map(|set| set.data.iter()) - .copied() - .collect::>(); - StateSet { data } - } - - /// Get contents - empty, unique element or multiple ambiguous elements - pub fn get(&self) -> StateSetStatus { - match self.data.len() { - 0 => StateSetStatus::Empty, - 1 => StateSetStatus::Unambiguous(*self.data.iter().next().unwrap()), - _ => StateSetStatus::Ambiguous(self.data.iter()), - } - } - - /// Get one of the characters from the set, or return None. - /// Note: which of the multiple character gets retrieved is not specified and is not to be relied on. - pub fn get_one_maybe(&self) -> Option { - self.data.iter().next().copied() - } - - /// Get one of the characters from the set, or panic if the set is empty. - /// Note: which of the multiple character gets retrieved is not specified and is not to be relied on. - pub fn get_one(&self) -> char { - assert!(self.data.len() > 0); - self.get_one_maybe().unwrap() - } - - /// Get exactly one element from the set, or panic if there's not exactly one element. - /// Note: which of the multiple character gets retrieved is not specified and is not to be relied on. - pub fn get_one_exactly(&self) -> char { - assert_eq!(self.data.len(), 1); - self.get_one_maybe().unwrap() - } - - /// Check if set contains a character - pub fn contains(&self, c: char) -> bool { - self.data.contains(&c) - } - - /// Access characters - #[allow(clippy::needless_lifetimes)] - pub fn chars<'a>(&'a self) -> impl Iterator + 'a { - self.data.iter().copied() - } - - /// Access internal set implementation - pub fn inner(&self) -> BTreeSet { - self.data.clone() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use maplit::btreeset; - use pretty_assertions::assert_eq; - - #[test] - fn test_stateset_new() { - let set = StateSet::new(); - let actual = set.inner(); - let expected: BTreeSet = BTreeSet::new(); - assert_eq!(expected, actual); - } - - #[test] - fn test_stateset_from_chars() { - let chars = btreeset!['A', 'C', 'G', 'T']; - let set = StateSet::from_chars(&chars); - let actual = &set.inner(); - let expected = &chars; - assert_eq!(expected, actual); - } - - #[test] - fn test_stateset_from_char() { - let set = StateSet::from_char('A'); - let actual = set.inner(); - let expected = btreeset! {'A'}; - assert_eq!(expected, actual); - } - - #[test] - fn test_stateset_from_intersection() { - let set1 = StateSet::from_chars(vec!['A', 'C', 'G']); - let set2 = StateSet::from_chars(vec!['C', 'G', 'T']); - let result = StateSet::from_intersection(&[set1, set2]); - let actual = result.inner(); - let expected = btreeset! {'C', 'G'}; - assert_eq!(expected, actual); - - let empty_intersection = StateSet::from_intersection(&[]); - let actual = empty_intersection.inner(); - let expected: BTreeSet = BTreeSet::new(); - assert_eq!(expected, actual); - } - - #[test] - fn test_stateset_from_union() { - let set1 = StateSet::from_chars(vec!['A', 'C']); - let set2 = StateSet::from_chars(vec!['G', 'T']); - let result = StateSet::from_union(&[set1, set2]); - let actual = result.inner(); - let expected = btreeset! {'A', 'C', 'G', 'T'}; - assert_eq!(expected, actual); - - let empty_union = StateSet::from_union(&[]); - let actual = empty_union.inner(); - let expected: BTreeSet = BTreeSet::new(); - assert_eq!(expected, actual); - } - - #[test] - fn test_stateset_get_empty() { - let set = StateSet::new(); - assert!(matches!(set.get(), StateSetStatus::Empty)); - } - - #[test] - fn test_stateset_get_unambiguous() { - let set = StateSet::from_char('A'); - assert!(matches!(set.get(), StateSetStatus::Unambiguous('A'))); - } - - #[test] - fn test_stateset_get_ambiguous() { - let set = StateSet::from_chars(vec!['A', 'C']); - assert!(matches!(set.get(), StateSetStatus::Ambiguous(_))); - if let StateSetStatus::Ambiguous(iter) = set.get() { - let actual: BTreeSet = iter.copied().collect(); - let expected = btreeset! {'A', 'C'}; - assert_eq!(expected, actual); - } - } - - #[test] - fn test_stateset_get_one() { - let set = StateSet::from_chars(['T', 'A']); - let actual = set.get_one(); - let expected = 'A'; - assert_eq!(expected, actual); - } - - #[test] - fn test_stateset_get_one_exactly() { - let set = StateSet::from_chars(['T']); - let actual = set.get_one(); - let expected = 'T'; - assert_eq!(expected, actual); - } +pub type StateSetStatus = Bitset128Status; +pub type StateSet = BitSet128; - #[test] - fn test_stateset_inner() { - let chars = vec!['A', 'G', 'T']; - let set = StateSet::from_chars(chars.clone()); - let actual = set.inner(); - let expected = chars.into_iter().collect::>(); - assert_eq!(expected, actual); - } +#[macro_export] +macro_rules! stateset { + ($($args:tt)*) => { + $crate::bitset128!($($args)*) + }; }