Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: clade mismatch against nearest node #1526

Merged
merged 9 commits into from
Oct 29, 2024
49 changes: 47 additions & 2 deletions packages/nextclade/src/tree/tree_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,14 @@ use crate::coord::range::NucRefGlobalRange;
use crate::graph::node::{GraphNodeKey, Node};
use crate::tree::params::TreeBuilderParams;
use crate::tree::split_muts::{difference_of_muts, split_muts, union_of_muts, SplitMutsResult};
use crate::tree::tree::{AuspiceGraph, AuspiceGraphEdgePayload, AuspiceGraphNodePayload, TreeBranchAttrsLabels};
use crate::tree::tree::{
AuspiceGraph, AuspiceGraphEdgePayload, AuspiceGraphNodePayload, TreeBranchAttrsLabels, TreeNodeAttr,
};
use crate::tree::tree_attach_new_nodes::create_new_auspice_node;
use crate::tree::tree_preprocess::add_auspice_metadata_in_place;
use crate::types::outputs::NextcladeOutputs;
use crate::utils::collections::concat_to_vec;
use crate::utils::stats::mode;
use eyre::{Report, WrapErr};
use itertools::Itertools;
use std::collections::BTreeMap;
Expand Down Expand Up @@ -361,7 +364,7 @@ pub fn knit_into_graph(
// generate new internal node
// add private mutations, divergence, name and branch attrs to new internal node
let new_internal_node = {
let mut new_internal_node: AuspiceGraphNodePayload = target_node_auspice.clone();
let mut new_internal_node: AuspiceGraphNodePayload = target_node_auspice.to_owned();
new_internal_node.tmp.private_mutations = muts_common_branch;
new_internal_node.node_attrs.div = Some(divergence_middle_node);
new_internal_node.branch_attrs.mutations =
Expand All @@ -378,6 +381,23 @@ pub fn knit_into_graph(
format!("nextclade__copy_of_{target_name}_for_placement_of_{qry_name}_#{qry_index}")
};

// Vote for the most plausible clade
let (clade, should_relabel) = vote_for_clade(graph, target_node, result);
new_internal_node.node_attrs.clade_membership = clade.as_deref().map(TreeNodeAttr::new);

// If decided, then move the clade label from target node to the internal node
if should_relabel {
let target_node = graph.get_node_mut(target_key)?;
if let Some(target_labels) = &mut target_node.payload_mut().branch_attrs.labels {
target_labels.clade = None;
new_internal_node
.branch_attrs
.labels
.get_or_insert(TreeBranchAttrsLabels::default())
.clade = clade;
}
}

new_internal_node
};

Expand Down Expand Up @@ -433,3 +453,28 @@ fn set_branch_attrs_aa_labels(node: &mut AuspiceGraphNodePayload) {
});
}
}

// Vote for the most plausible clade for the new internal node
fn vote_for_clade(
graph: &AuspiceGraph,
target_node: &Node<AuspiceGraphNodePayload>,
result: &NextcladeOutputs,
) -> (Option<String>, bool) {
let query_clade = &result.clade;

let parent_node = &graph.parent_of(target_node);
let parent_clade = &parent_node.and_then(|node| node.payload().clade());

let target_clade = &target_node.payload().clade();

let possible_clades = [parent_clade, query_clade, target_clade].into_iter().flatten(); // exclude None
let clade = mode(possible_clades).cloned();

// We will need to change branch label if both:
// - clade transition happens from parent to the new node
// AND
// - when the target node's clade wins the vote
let should_relabel = (parent_clade != &clade) && (target_clade.is_some() && target_clade == &clade);

(clade, should_relabel)
}
1 change: 1 addition & 0 deletions packages/nextclade/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ pub mod info;
pub mod map;
pub mod num;
pub mod option;
pub mod stats;
pub mod string;
pub mod vec2d;
pub mod wraparound;
Expand Down
13 changes: 13 additions & 0 deletions packages/nextclade/src/utils/stats.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
use itertools::Itertools;
use std::hash::Hash;

/// Calculate mode (the most frequently occurring element) of an iterator.
/// In case of a tie, the first occurrence is returned. Returns `None` if the iterator is empty.
pub fn mode<T: Hash + Eq + Clone>(items: impl IntoIterator<Item = T>) -> Option<T> {
items
.into_iter()
.counts()
.into_iter()
.max_by_key(|&(_, count)| count)
.map(|(item, _)| item)
}