diff --git a/packages/nextclade/src/tree/tree.rs b/packages/nextclade/src/tree/tree.rs index fc6676c7e..ba45891ff 100644 --- a/packages/nextclade/src/tree/tree.rs +++ b/packages/nextclade/src/tree/tree.rs @@ -18,6 +18,7 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer}; use std::collections::BTreeMap; use std::path::Path; use std::slice::Iter; +use serde_json::json; use traversal::{Bft, DftPost, DftPre}; use validator::Validate; @@ -282,6 +283,11 @@ impl AuspiceGraphNodePayload { .and_then(|val| val.as_str()) } + /// Sets clade-like node attribute + pub fn set_clade_node_attr(&mut self, key: impl AsRef, value: impl AsRef) { + self.node_attrs.other[key.as_ref()] = json!({ "value": value.as_ref() }); + } + /// Extracts clade-like node attributes, given a list of key descriptions pub fn get_clade_node_attrs(&self, clade_node_attr_descs: &[CladeNodeAttrKeyDesc]) -> BTreeMap { clade_node_attr_descs @@ -292,6 +298,13 @@ impl AuspiceGraphNodePayload { }) .collect() } + + /// Sets clade-like node attributes. Inserts if key does not exist and overwrites existing. + pub fn set_clade_node_attrs(&mut self, attrs: BTreeMap) { + for (key, val) in attrs { + self.set_clade_node_attr(key, val); + } + } } impl GraphNode for AuspiceGraphNodePayload {} diff --git a/packages/nextclade/src/tree/tree_builder.rs b/packages/nextclade/src/tree/tree_builder.rs index 3b3cf6799..6c8e0179a 100644 --- a/packages/nextclade/src/tree/tree_builder.rs +++ b/packages/nextclade/src/tree/tree_builder.rs @@ -8,13 +8,16 @@ use crate::coord::range::NucRefGlobalRange; use crate::graph::node::{GraphNodeKey, Node}; use crate::tree::params::TreeBuilderParams; use crate::tree::split_muts::{difference_of_muts, split_muts, union_of_muts, SplitMutsResult}; -use crate::tree::tree::{AuspiceGraph, AuspiceGraphEdgePayload, AuspiceGraphNodePayload, TreeBranchAttrsLabels}; +use crate::tree::tree::{ + AuspiceGraph, AuspiceGraphEdgePayload, AuspiceGraphNodePayload, TreeBranchAttrsLabels, TreeNodeAttr, +}; use crate::tree::tree_attach_new_nodes::create_new_auspice_node; use crate::tree::tree_preprocess::add_auspice_metadata_in_place; use crate::types::outputs::NextcladeOutputs; use crate::utils::collections::concat_to_vec; +use crate::utils::stats::mode; use eyre::{Report, WrapErr}; -use itertools::Itertools; +use itertools::{chain, Itertools}; use std::collections::BTreeMap; pub fn graph_attach_new_nodes_in_place( @@ -361,7 +364,7 @@ pub fn knit_into_graph( // generate new internal node // add private mutations, divergence, name and branch attrs to new internal node let new_internal_node = { - let mut new_internal_node: AuspiceGraphNodePayload = target_node_auspice.clone(); + let mut new_internal_node: AuspiceGraphNodePayload = target_node_auspice.to_owned(); new_internal_node.tmp.private_mutations = muts_common_branch; new_internal_node.node_attrs.div = Some(divergence_middle_node); new_internal_node.branch_attrs.mutations = @@ -378,6 +381,27 @@ pub fn knit_into_graph( format!("nextclade__copy_of_{target_name}_for_placement_of_{qry_name}_#{qry_index}") }; + // Vote for the most plausible clade + let (clade, should_relabel) = vote_for_clade(graph, target_node, result); + new_internal_node.node_attrs.clade_membership = clade.as_deref().map(TreeNodeAttr::new); + + // Vote for the most plausible clade-like attrs + let clade_attrs = vote_for_clade_like_attrs(graph, target_node, result); + new_internal_node.set_clade_node_attrs(clade_attrs); + + // If decided, then move the clade label from target node to the internal node + if should_relabel { + let target_node = graph.get_node_mut(target_key)?; + if let Some(target_labels) = &mut target_node.payload_mut().branch_attrs.labels { + target_labels.clade = None; + new_internal_node + .branch_attrs + .labels + .get_or_insert(TreeBranchAttrsLabels::default()) + .clade = clade; + } + } + new_internal_node }; @@ -433,3 +457,54 @@ fn set_branch_attrs_aa_labels(node: &mut AuspiceGraphNodePayload) { }); } } + +// Vote for the most plausible clade for the new internal node +fn vote_for_clade( + graph: &AuspiceGraph, + target_node: &Node, + result: &NextcladeOutputs, +) -> (Option, bool) { + let query_clade = &result.clade; + + let parent_node = &graph.parent_of(target_node); + let parent_clade = &parent_node.and_then(|node| node.payload().clade()); + + let target_clade = &target_node.payload().clade(); + + let possible_clades = [parent_clade, query_clade, target_clade].into_iter().flatten(); // exclude None + let clade = mode(possible_clades).cloned(); + + // We will need to change branch label if both: + // - clade transition happens from parent to the new node + // AND + // - when the target node's clade wins the vote + let should_relabel = (parent_clade != &clade) && (target_clade.is_some() && target_clade == &clade); + + (clade, should_relabel) +} + +// Vote for the most plausible clade-like attribute values, for the new internal node +fn vote_for_clade_like_attrs( + graph: &AuspiceGraph, + target_node: &Node, + result: &NextcladeOutputs, +) -> BTreeMap { + let attr_descs = graph.data.meta.clade_node_attr_descs(); + + let query_attrs: &BTreeMap = &result.custom_node_attributes; + + let parent_node = &graph.parent_of(target_node); + let parent_attrs: &BTreeMap = &parent_node + .map(|node| node.payload().get_clade_node_attrs(attr_descs)) + .unwrap_or_default(); + + let target_attrs: &BTreeMap = &target_node.payload().get_clade_node_attrs(attr_descs); + + chain!(query_attrs.iter(), parent_attrs.iter(), target_attrs.iter()) + .into_group_map() + .into_iter() + .filter_map(|(key, grouped_values)| { + mode(grouped_values.into_iter().cloned()).map(|most_common| (key.clone(), most_common)) + }) + .collect() +} diff --git a/packages/nextclade/src/utils/mod.rs b/packages/nextclade/src/utils/mod.rs index 1b6f01f18..276f73aa8 100644 --- a/packages/nextclade/src/utils/mod.rs +++ b/packages/nextclade/src/utils/mod.rs @@ -10,6 +10,7 @@ pub mod info; pub mod map; pub mod num; pub mod option; +pub mod stats; pub mod string; pub mod vec2d; pub mod wraparound; diff --git a/packages/nextclade/src/utils/stats.rs b/packages/nextclade/src/utils/stats.rs new file mode 100644 index 000000000..35b82db0c --- /dev/null +++ b/packages/nextclade/src/utils/stats.rs @@ -0,0 +1,13 @@ +use itertools::Itertools; +use std::hash::Hash; + +/// Calculate mode (the most frequently occurring element) of an iterator. +/// In case of a tie, the first occurrence is returned. Returns `None` if the iterator is empty. +pub fn mode(items: impl IntoIterator) -> Option { + items + .into_iter() + .counts() + .into_iter() + .max_by_key(|&(_, count)| count) + .map(|(item, _)| item) +}