diff --git a/Cargo.toml b/Cargo.toml index 2544018b3..66be29a09 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tskit" -version = "0.12.0" +version = "0.13.0-alpha.0" authors = ["tskit developers "] build = "build.rs" edition = "2021" @@ -50,6 +50,7 @@ pkg-config = "0.3" [features] provenance = ["humantime"] derive = ["tskit-derive", "serde", "serde_json", "bincode"] +edgebuffer = [] [package.metadata.docs.rs] all-features = true @@ -58,3 +59,7 @@ rustdoc-args = ["--cfg", "doc_cfg"] # Not run during tests [[example]] name = "tree_traversals" + +[[example]] +name = "haploid_wright_fisher_edge_buffering" +required-features = ["edgebuffer"] diff --git a/examples/haploid_wright_fisher.rs b/examples/haploid_wright_fisher.rs index 1d10b7fa8..27053a20d 100644 --- a/examples/haploid_wright_fisher.rs +++ b/examples/haploid_wright_fisher.rs @@ -8,12 +8,30 @@ use proptest::prelude::*; use rand::distributions::Distribution; use rand::SeedableRng; +fn rotate_edges(bookmark: &tskit::types::Bookmark, tables: &mut tskit::TableCollection) { + let num_edges = tables.edges().num_rows().as_usize(); + let left = + unsafe { std::slice::from_raw_parts_mut((*tables.as_mut_ptr()).edges.left, num_edges) }; + let right = + unsafe { std::slice::from_raw_parts_mut((*tables.as_mut_ptr()).edges.right, num_edges) }; + let parent = + unsafe { std::slice::from_raw_parts_mut((*tables.as_mut_ptr()).edges.parent, num_edges) }; + let child = + unsafe { std::slice::from_raw_parts_mut((*tables.as_mut_ptr()).edges.child, num_edges) }; + let mid = bookmark.edges().as_usize(); + left.rotate_left(mid); + right.rotate_left(mid); + parent.rotate_left(mid); + child.rotate_left(mid); +} + // ANCHOR: haploid_wright_fisher fn simulate( seed: u64, popsize: usize, num_generations: i32, simplify_interval: i32, + update_bookmark: bool, ) -> Result { if popsize == 0 { return Err(anyhow::Error::msg("popsize must be > 0")); @@ -46,6 +64,7 @@ fn simulate( let parent_picker = rand::distributions::Uniform::new(0, popsize); let breakpoint_generator = rand::distributions::Uniform::new(0.0, 1.0); let mut rng = rand::rngs::StdRng::seed_from_u64(seed); + let mut bookmark = tskit::types::Bookmark::new(); for birth_time in (0..num_generations).rev() { for c in children.iter_mut() { @@ -64,7 +83,10 @@ fn simulate( } if birth_time % simplify_interval == 0 { - tables.full_sort(tskit::TableSortOptions::default())?; + tables.sort(&bookmark, tskit::TableSortOptions::default())?; + if update_bookmark { + rotate_edges(&bookmark, &mut tables); + } if let Some(idmap) = tables.simplify(children, tskit::SimplificationOptions::default(), true)? { @@ -73,6 +95,9 @@ fn simulate( *o = idmap[usize::try_from(*o)?]; } } + if update_bookmark { + bookmark.set_edges(tables.edges().num_rows()); + } } std::mem::swap(&mut parents, &mut children); } @@ -91,6 +116,8 @@ struct SimParams { num_generations: i32, simplify_interval: i32, treefile: Option, + #[clap(short, long, help = "Use bookmark to avoid sorting entire edge table.")] + bookmark: bool, } fn main() -> Result<()> { @@ -100,6 +127,7 @@ fn main() -> Result<()> { params.popsize, params.num_generations, params.simplify_interval, + params.bookmark, )?; if let Some(treefile) = ¶ms.treefile { @@ -114,8 +142,9 @@ proptest! { #[test] fn test_simulate_proptest(seed in any::(), num_generations in 50..100i32, - simplify_interval in 1..100i32) { - let ts = simulate(seed, 100, num_generations, simplify_interval).unwrap(); + simplify_interval in 1..100i32, + bookmark in proptest::bool::ANY) { + let ts = simulate(seed, 100, num_generations, simplify_interval, bookmark).unwrap(); // stress test the branch length fn b/c it is not a trivial // wrapper around the C API. diff --git a/examples/haploid_wright_fisher_edge_buffering.rs b/examples/haploid_wright_fisher_edge_buffering.rs new file mode 100644 index 000000000..b718ae271 --- /dev/null +++ b/examples/haploid_wright_fisher_edge_buffering.rs @@ -0,0 +1,158 @@ +// This is a rust implementation of the example +// found in tskit-c + +use anyhow::Result; +use clap::Parser; +#[cfg(test)] +use proptest::prelude::*; +use rand::distributions::Distribution; +use rand::SeedableRng; +use tskit::NodeId; + +// ANCHOR: haploid_wright_fisher_edge_buffering +fn simulate( + seed: u64, + popsize: usize, + num_generations: i32, + simplify_interval: i32, +) -> Result { + if popsize == 0 { + return Err(anyhow::Error::msg("popsize must be > 0")); + } + if num_generations == 0 { + return Err(anyhow::Error::msg("num_generations must be > 0")); + } + if simplify_interval == 0 { + return Err(anyhow::Error::msg("simplify_interval must be > 0")); + } + let mut tables = tskit::TableCollection::new(1.0)?; + + // create parental nodes + let mut parents_and_children = { + let mut temp = vec![]; + let parental_time = f64::from(num_generations); + for _ in 0..popsize { + let node = tables.add_node(0, parental_time, -1, -1)?; + temp.push(node); + } + temp + }; + + // allocate space for offspring nodes + parents_and_children.resize(2 * parents_and_children.len(), tskit::NodeId::NULL); + + // Construct non-overlapping mutable slices into our vector. + let (mut parents, mut children) = parents_and_children.split_at_mut(popsize); + + let parent_picker = rand::distributions::Uniform::new(0, popsize); + let breakpoint_generator = rand::distributions::Uniform::new(0.0, 1.0); + let mut rng = rand::rngs::StdRng::seed_from_u64(seed); + let mut buffer = tskit::EdgeBuffer::default(); + let mut node_map: Vec = vec![]; + + for birth_time in (0..num_generations).rev() { + for c in children.iter_mut() { + let bt = f64::from(birth_time); + let child = tables.add_node(0, bt, -1, -1)?; + let left_parent = parents + .get(parent_picker.sample(&mut rng)) + .ok_or_else(|| anyhow::Error::msg("invalid left_parent index"))?; + let right_parent = parents + .get(parent_picker.sample(&mut rng)) + .ok_or_else(|| anyhow::Error::msg("invalid right_parent index"))?; + //buffer.setup_births(&[*left_parent, *right_parent], &[child])?; + let breakpoint = breakpoint_generator.sample(&mut rng); + buffer.buffer_birth(*left_parent, child, 0., breakpoint)?; + buffer.buffer_birth(*right_parent, child, breakpoint, 1.0)?; + //buffer.finalize_births(); + *c = child; + } + + if birth_time % simplify_interval == 0 { + //buffer.pre_simplification(&mut tables)?; + //tables.full_sort(tskit::TableSortOptions::default())?; + node_map.resize(tables.nodes().num_rows().as_usize(), tskit::NodeId::NULL); + tskit::simplfify_from_buffer( + children, + tskit::SimplificationOptions::default(), + &mut tables, + &mut buffer, + Some(&mut node_map), + )?; + for o in children.iter_mut() { + assert!(o.as_usize() < node_map.len()); + *o = node_map[usize::try_from(*o)?]; + assert!(!o.is_null()); + } + //if let Some(idmap) = + // tables.simplify(children, tskit::SimplificationOptions::default(), true)? + //{ + // // remap child nodes + // for o in children.iter_mut() { + // *o = idmap[usize::try_from(*o)?]; + // } + //} + buffer.post_simplification(children, &mut tables)?; + } + std::mem::swap(&mut parents, &mut children); + } + + tables.build_index()?; + let treeseq = tables.tree_sequence(tskit::TreeSequenceFlags::default())?; + + Ok(treeseq) +} +// ANCHOR_END: haploid_wright_fisher_edge_buffering + +#[derive(Clone, clap::Parser)] +struct SimParams { + seed: u64, + popsize: usize, + num_generations: i32, + simplify_interval: i32, + treefile: Option, +} + +fn main() -> Result<()> { + let params = SimParams::parse(); + let treeseq = simulate( + params.seed, + params.popsize, + params.num_generations, + params.simplify_interval, + )?; + + if let Some(treefile) = ¶ms.treefile { + treeseq.dump(treefile, 0)?; + } + + Ok(()) +} + +#[cfg(test)] +proptest! { +#[test] + fn test_simulate_proptest(seed in any::(), + num_generations in 50..100i32, + simplify_interval in 1..100i32) { + let ts = simulate(seed, 100, num_generations, simplify_interval).unwrap(); + + // stress test the branch length fn b/c it is not a trivial + // wrapper around the C API. + { + use streaming_iterator::StreamingIterator; + let mut x = f64::NAN; + if let Ok(mut tree_iter) = ts.tree_iterator(0) { + // We will only do the first tree to save time. + if let Some(tree) = tree_iter.next() { + let b = tree.total_branch_length(false).unwrap(); + let b2 = unsafe { + tskit::bindings::tsk_tree_get_total_branch_length(tree.as_ptr(), -1, &mut x) + }; + assert!(b2 >= 0, "{}", b2); + assert!(f64::from(b) - x <= 1e-8); + } + } + } + } +} diff --git a/src/edgebuffer.rs b/src/edgebuffer.rs new file mode 100644 index 000000000..200a2a62f --- /dev/null +++ b/src/edgebuffer.rs @@ -0,0 +1,791 @@ +use crate::NodeId; +use crate::Position; +use crate::TableCollection; +use crate::TskitError; + +// Design considerations: +// +// We should be able to do better than +// the fwdpp implementation by taking a +// time-sorted list of alive nodes and inserting +// their edges. +// After insertion, we can truncate the input +// edge table, eliminating all edges corresponding +// to the set of alive nodes. +// This procedure would only be done AFTER +// simplification, such that the copied +// edges are guaranteed correct. +// We'd need to hash the existence of these alive nodes. +// Then, when going over the edge buffer, we can ask +// if an edge parent is in the hashed set. +// We would also keep track of the smallest +// edge id, and that (maybe minus 1?) is our truncation point. + +fn swap_with_empty(vec: &mut Vec) { + let mut t = vec![]; + std::mem::swap(&mut t, vec); +} + +#[derive(Copy, Clone)] +struct AliveNodeTimes { + min: f64, + max: f64, +} + +impl AliveNodeTimes { + fn new(min: f64, max: f64) -> Self { + Self { min, max } + } + + fn non_overlapping(&self) -> bool { + self.min == self.max + } +} + +#[derive(Debug)] +struct PreExistingEdge { + first: usize, + last: usize, +} + +impl PreExistingEdge { + fn new(first: usize, last: usize) -> Self { + assert!(last > first); + Self { first, last } + } +} + +#[derive(Debug)] +struct Segment { + left: Position, + right: Position, +} + +type ChildSegments = std::collections::HashMap>; + +#[derive(Default, Debug)] +struct BufferedBirths { + children: Vec, + segments: std::collections::HashMap, +} + +impl BufferedBirths { + fn initialize(&mut self, parents: &[NodeId], children: &[NodeId]) -> Result<(), TskitError> { + self.children = children.to_vec(); + self.children.sort(); + self.children.dedup(); + self.segments.clear(); + // FIXME: don't do this work if the parent already exists + for p in parents { + let mut segments = ChildSegments::default(); + for c in &self.children { + if segments.insert(*c, vec![]).is_some() { + return Err(TskitError::LibraryError("redundant child ids".to_owned())); + } + } + self.segments.insert(*p, segments); + } + Ok(()) + } +} + +#[derive(Default, Debug)] +pub struct EdgeBuffer { + left: Vec, + right: Vec, + child: Vec, + // TODO: this should be + // an option so that we can use take. + buffered_births: BufferedBirths, + // NOTE: these vectors are wasteful: + // 1. usize is more than we need, + // but it is more convenient. + // 2. Worse, these vectors will + // contain N elements, where + // N is the total number of nodes, + // but likely many fewer nodes than that + // have actually had offspring. + // It is hard to fix this -- we cannot + // guarantee that parents are entered + // in any specific order. + // 3. Performance IMPROVES MEASURABLY + // if we use u32 here. But tsk_size_t + // is u64. + head: Vec, + tail: Vec, + next: Vec, +} + +impl EdgeBuffer { + fn insert_new_parent(&mut self, parent: usize, child: NodeId, left: Position, right: Position) { + self.left.push(left); + self.right.push(right); + self.child.push(child); + self.head[parent] = self.left.len() - 1; + self.tail[parent] = self.head[parent]; + self.next.push(usize::MAX); + } + + fn extend_parent(&mut self, parent: usize, child: NodeId, left: Position, right: Position) { + self.left.push(left); + self.right.push(right); + self.child.push(child); + let t = self.tail[parent]; + self.tail[parent] = self.left.len() - 1; + self.next[t] = self.left.len() - 1; + self.next.push(usize::MAX); + } + + fn clear(&mut self) { + self.left.clear(); + self.right.clear(); + self.child.clear(); + self.head.clear(); + self.tail.clear(); + self.next.clear(); + } + + fn release_memory(&mut self) { + swap_with_empty(&mut self.head); + swap_with_empty(&mut self.next); + swap_with_empty(&mut self.left); + swap_with_empty(&mut self.right); + swap_with_empty(&mut self.child); + swap_with_empty(&mut self.tail); + } + + fn extract_buffered_births(&mut self) -> BufferedBirths { + let mut b = BufferedBirths::default(); + std::mem::swap(&mut self.buffered_births, &mut b); + b + } + + // Should Err if prents/children not unique + pub fn setup_births( + &mut self, + parents: &[NodeId], + children: &[NodeId], + ) -> Result<(), TskitError> { + self.buffered_births.initialize(parents, children) + } + + pub fn finalize_births(&mut self) { + let buffered_births = self.extract_buffered_births(); + for (p, children) in buffered_births.segments.iter() { + for c in buffered_births.children.iter() { + if let Some(segs) = children.get(c) { + for s in segs { + self.buffer_birth(*p, *c, s.left, s.right).unwrap(); + } + } else { + // should be error + panic!(); + } + } + } + } + + pub fn record_birth( + &mut self, + parent: P, + child: C, + left: L, + right: R, + ) -> Result<(), TskitError> + where + P: Into, + C: Into, + L: Into, + R: Into, + { + let parent = parent.into(); + + let child = child.into(); + if let Some(parent_buffer) = self.buffered_births.segments.get_mut(&parent) { + if let Some(v) = parent_buffer.get_mut(&child) { + let left = left.into(); + let right = right.into(); + v.push(Segment { left, right }); + } else { + // should be an error + panic!(); + } + } else { + // should be an error + panic!(); + } + + Ok(()) + } + + // NOTE: tskit is overly strict during simplification, + // enforcing sorting requirements on the edge table + // that are not strictly necessary. + pub fn buffer_birth( + &mut self, + parent: P, + child: C, + left: L, + right: R, + ) -> Result<(), TskitError> + where + P: Into, + C: Into, + L: Into, + R: Into, + { + let parent = parent.into(); + if parent < 0 { + return Err(TskitError::IndexError); + } + + let parent = parent.as_usize(); + + if parent >= self.head.len() { + self.head.resize(parent + 1, usize::MAX); + self.tail.resize(parent + 1, usize::MAX); + } + + if self.head[parent] == usize::MAX { + self.insert_new_parent(parent, child.into(), left.into(), right.into()); + } else { + self.extend_parent(parent, child.into(), left.into(), right.into()); + } + Ok(()) + } + + // NOTE: we can probably have this function not error: + // the head array is populated by i32 converted to usize, + // so if things are getting out of range, we should be + // in trouble before this point. + // NOTE: we need a bitflags here for other options, like sorting the head + // contents based on birth time. + pub fn pre_simplification(&mut self, tables: &mut TableCollection) -> Result<(), TskitError> { + let num_input_edges = tables.edges().num_rows().as_usize(); + let mut head_index: Vec = self + .head + .iter() + .enumerate() + .filter(|(_, j)| **j != usize::MAX) + .map(|(i, _)| i) + .collect(); + + let node_time = tables.nodes().time_slice(); + head_index.sort_by(|a, b| node_time[*a].partial_cmp(&node_time[*b]).unwrap()); + //for (i, h) in self.head.iter().rev().enumerate() { + for h in head_index.into_iter() { + let parent = match i32::try_from(h) { + Ok(value) => value, + Err(_) => { + return Err(TskitError::RangeError( + "usize to i32 conversion failed".to_owned(), + )) + } + }; + tables.add_edge( + self.left[self.head[h]], + self.right[self.head[h]], + parent, + self.child[self.head[h]], + )?; + + let mut next = self.next[self.head[h]]; + while next != usize::MAX { + tables.add_edge(self.left[next], self.right[next], parent, self.child[next])?; + next = self.next[next]; + } + } + + self.release_memory(); + + // This assert is redundant b/c TableCollection + // works via MBox/NonNull. + assert!(!tables.as_ptr().is_null()); + // SAFETY: table collection pointer is not null and num_edges + // is the right length. + let num_edges = tables.edges().num_rows().as_usize(); + let edge_left = + unsafe { std::slice::from_raw_parts_mut((*tables.as_mut_ptr()).edges.left, num_edges) }; + let edge_right = unsafe { + std::slice::from_raw_parts_mut((*tables.as_mut_ptr()).edges.right, num_edges) + }; + let edge_parent = unsafe { + std::slice::from_raw_parts_mut((*tables.as_mut_ptr()).edges.parent, num_edges) + }; + let edge_child = unsafe { + std::slice::from_raw_parts_mut((*tables.as_mut_ptr()).edges.child, num_edges) + }; + edge_left.rotate_left(num_input_edges); + edge_right.rotate_left(num_input_edges); + edge_parent.rotate_left(num_input_edges); + edge_child.rotate_left(num_input_edges); + Ok(()) + } + + fn alive_node_times(&self, alive: &[NodeId], tables: &mut TableCollection) -> AliveNodeTimes { + let node_times = tables.nodes().time_slice_raw(); + let mut max_alive_node_time = 0.0; + let mut min_alive_node_time = f64::MAX; + + for a in alive { + let time = node_times[a.as_usize()]; + max_alive_node_time = if time > max_alive_node_time { + time + } else { + max_alive_node_time + }; + min_alive_node_time = if time < min_alive_node_time { + time + } else { + min_alive_node_time + }; + } + AliveNodeTimes::new(min_alive_node_time, max_alive_node_time) + } + + // The method here ends up creating a problem: + // we are buffering nodes with increasing node id + // that are also more ancient. This is the opposite + // order from what happens during a forward-time simulation. + // NOTE: the mechanics of this fn differ if we use + // "regular" simplification or streaming! + // For the former case, we have to do the setup/finalize + // business. For the latter, WE DO NOT. + // This differences suggests there are actually two types/impls + // being discussed here. + fn buffer_existing_edges( + &mut self, + pre_existing_edges: Vec, + tables: &mut TableCollection, + ) -> Result { + let parent = tables.edges().parent_slice(); + let child = tables.edges().child_slice(); + let left = tables.edges().left_slice(); + let right = tables.edges().right_slice(); + let mut rv = 0; + for pre in pre_existing_edges.iter() { + self.setup_births(&[parent[pre.first]], &child[pre.first..pre.last])?; + for e in pre.first..pre.last { + assert_eq!(parent[e], parent[pre.first]); + self.record_birth(parent[e], child[e], left[e], right[e])?; + rv += 1; + } + self.finalize_births(); + } + + Ok(rv) + } + + // FIXME: clean up commented-out code + // if we decide we don't need it. + fn collect_pre_existing_edges( + &self, + alive_node_times: AliveNodeTimes, + tables: &mut TableCollection, + ) -> Vec { + let mut edges = vec![]; + let mut i = 0; + let parent = tables.edges().parent_slice(); + //let child = tables.edges().child_slice(); + let node_time = tables.nodes().time_slice(); + while i < parent.len() { + let p = parent[i]; + // let c = child[i]; + if node_time[p.as_usize()] <= alive_node_times.max + //|| (node_time[c.as_usize()] < alive_node_times.max + // && node_time[p.as_usize()] > alive_node_times.max) + { + let mut j = 0_usize; + while i + j < parent.len() && parent[i + j] == p { + j += 1; + } + edges.push(PreExistingEdge::new(i, i + j)); + i += j; + } else { + break; + } + } + edges + } + + // FIXME: + // + // 1. If min/max parent alive times are equal, return. + // DONE + // 2. Else, we need to do a rotation at min_edge + // before truncation. + // DONE + // 3. However, we also have to respect our API + // and process each parent carefully, + // setting up the birth/death epochs. + // We need to use setup_births and finalize_births + // to get this right. + // DONE + // 4. We are doing this in the wrong temporal order. + // We need to pre-process all existing edge intervals, + // cache them, then go backwards through them, + // so that we buffer them present-to-past. + // DONE + // 5. This step should be EARLY in a recording epoch, + // so that we avoid the gotcha of stealing edges + // from the last generation of a simulation. + pub fn post_simplification( + &mut self, + alive: &[NodeId], + tables: &mut TableCollection, + ) -> Result<(), TskitError> { + self.clear(); + + let alive_node_times = self.alive_node_times(alive, tables); + if alive_node_times.non_overlapping() { + // There can be no overlap between current + // edges and births that are about to happen, + // so we get out. + return Ok(()); + } + + let pre_existing_edges = self.collect_pre_existing_edges(alive_node_times, tables); + let min_edge = self.buffer_existing_edges(pre_existing_edges, tables)?; + let num_edges = tables.edges().num_rows().as_usize(); + let edge_left = + unsafe { std::slice::from_raw_parts_mut((*tables.as_mut_ptr()).edges.left, num_edges) }; + let edge_right = unsafe { + std::slice::from_raw_parts_mut((*tables.as_mut_ptr()).edges.right, num_edges) + }; + let edge_parent = unsafe { + std::slice::from_raw_parts_mut((*tables.as_mut_ptr()).edges.parent, num_edges) + }; + let edge_child = unsafe { + std::slice::from_raw_parts_mut((*tables.as_mut_ptr()).edges.child, num_edges) + }; + edge_left.rotate_left(min_edge); + edge_right.rotate_left(min_edge); + edge_parent.rotate_left(min_edge); + edge_child.rotate_left(min_edge); + // SAFETY: ????? + let rv = unsafe { + crate::bindings::tsk_edge_table_truncate( + &mut (*tables.as_mut_ptr()).edges, + (num_edges - min_edge) as crate::bindings::tsk_size_t, + ) + }; + handle_tsk_return_value!(rv, ()) + } +} + +struct StreamingSimplifier { + simplifier: crate::bindings::tsk_streaming_simplifier_t, +} + +impl StreamingSimplifier { + fn new>( + samples: &[NodeId], + options: O, + tables: &mut TableCollection, + ) -> Result { + let mut simplifier = + std::mem::MaybeUninit::::uninit(); + let num_samples = samples.len() as crate::bindings::tsk_size_t; + match unsafe { + crate::bindings::tsk_streaming_simplifier_init( + simplifier.as_mut_ptr(), + tables.as_mut_ptr(), + samples.as_ptr().cast::(), + num_samples, + options.into().bits(), + ) + } { + code if code < 0 => Err(TskitError::ErrorCode { code }), + _ => Ok(Self { + simplifier: unsafe { simplifier.assume_init() }, + }), + } + } + + fn add_edge( + &mut self, + left: Position, + right: Position, + parent: NodeId, // FIXME: shouldn't be here + child: NodeId, + ) -> Result<(), TskitError> { + let code = unsafe { + crate::bindings::tsk_streaming_simplifier_add_edge( + &mut self.simplifier, + left.into(), + right.into(), + parent.into(), + child.into(), + ) + }; + handle_tsk_return_value!(code, ()) + } + + fn merge_ancestors(&mut self, parent: NodeId) -> Result<(), TskitError> { + let code = unsafe { + crate::bindings::tsk_streaming_simplifier_merge_ancestors( + &mut self.simplifier, + parent.into(), + ) + }; + handle_tsk_return_value!(code, ()) + } + + // FIXME: need to be able to validate that node_map is correct length! + fn finalise(&mut self, node_map: Option<&mut [NodeId]>) -> Result<(), TskitError> { + let n = match node_map { + Some(x) => x.as_mut_ptr().cast::(), + None => std::ptr::null_mut(), + }; + let code = + unsafe { crate::bindings::tsk_streaming_simplifier_finalise(&mut self.simplifier, n) }; + handle_tsk_return_value!(code, ()) + } + + fn input_num_edges(&self) -> usize { + unsafe { + crate::bindings::tsk_streaming_simplifier_get_num_input_edges(&self.simplifier) as usize + } + } + + fn input_left(&self) -> &[Position] { + unsafe { + std::slice::from_raw_parts( + crate::bindings::tsk_streaming_simplifier_get_input_left(&self.simplifier) + .cast::(), + self.input_num_edges(), + ) + } + } + + fn input_right(&self) -> &[Position] { + unsafe { + std::slice::from_raw_parts( + crate::bindings::tsk_streaming_simplifier_get_input_right(&self.simplifier) + .cast::(), + self.input_num_edges(), + ) + } + } + + fn input_parent(&self) -> &[NodeId] { + unsafe { + std::slice::from_raw_parts( + crate::bindings::tsk_streaming_simplifier_get_input_parent(&self.simplifier) + .cast::(), + self.input_num_edges(), + ) + } + } + + fn input_child(&self) -> &[NodeId] { + unsafe { + std::slice::from_raw_parts( + crate::bindings::tsk_streaming_simplifier_get_input_child(&self.simplifier) + .cast::(), + self.input_num_edges(), + ) + } + } + + fn get_input_parent(&self, u: usize) -> NodeId { + assert!(u < self.input_num_edges()); + self.input_parent()[u] + } + fn get_input_child(&self, u: usize) -> NodeId { + assert!(u < self.input_num_edges()); + self.input_child()[u] + } + fn get_input_left(&self, u: usize) -> Position { + assert!(u < self.input_num_edges()); + self.input_left()[u] + } + fn get_input_right(&self, u: usize) -> Position { + assert!(u < self.input_num_edges()); + self.input_right()[u] + } +} + +impl Drop for StreamingSimplifier { + fn drop(&mut self) { + let code = unsafe { crate::bindings::tsk_streaming_simplifier_free(&mut self.simplifier) }; + assert_eq!(code, 0); + } +} + +// TODO: +// 1. The edge buffer API is wrong here. +// We need to encapsulate the existing type, +// and make one whose public API does what we need. +// 2. If this works out, it means we need to extract +// the core buffer ops out to a private type +// and make public newtypes using it. +// FIXME: this function is unsafe b/c of how tskit-c +// messes w/pointers behind the scenes. +// Solution is to take ownership of the tables? +pub fn simplfify_from_buffer>( + samples: &[NodeId], + options: O, + tables: &mut TableCollection, + buffer: &mut EdgeBuffer, + node_map: Option<&mut [NodeId]>, +) -> Result<(), TskitError> { + // have to take copies of the current members of + // the edge table. + let mut last_parent_time = -1.0; + let mut head_index: Vec = buffer + .head + .iter() + .enumerate() + .filter(|(_, j)| **j != usize::MAX) + .map(|(i, _)| i) + .collect(); + + let node_time = tables.nodes().time_slice(); + head_index.sort_by(|a, b| node_time[*a].partial_cmp(&node_time[*b]).unwrap()); + let mut simplifier = StreamingSimplifier::new(samples, options, tables)?; + // Simplify the most recent births + //for (i, h) in buffer.head.iter().rev().enumerate() { + for h in head_index.into_iter() { + let parent = i32::try_from(h).unwrap(); + simplifier.add_edge( + buffer.left[buffer.head[h]], + buffer.right[buffer.head[h]], + parent.into(), + buffer.child[buffer.head[h]], + )?; + let mut next = buffer.next[buffer.head[h]]; + assert!(parent >= 0); + while next != usize::MAX { + assert!(next < buffer.left.len()); + simplifier.add_edge( + buffer.left[next], + buffer.right[next], + parent.into(), + buffer.child[next], + )?; + next = buffer.next[next]; + } + simplifier.merge_ancestors(parent.into())?; + + // major stress-test -- delete later + //{ + // let l = tables.edges().left_slice(); + // let p = tables.edges().parent_slice(); + // let c = tables.edges().child_slice(); + // let mut i = 0; + // while i < l.len() { + // let pi = p[i]; + // while i < l.len() && p[i] == pi { + // if i > 0 && c[i] == c[i - 1] { + // assert_ne!( + // l[i], + // l[i - 1], + // "{:?},{:?} | {:?},{:?} | {:?},{:?} => {:?}", + // p[i], + // p[i - 1], + // c[i], + // c[i - 1], + // l[i], + // l[i - 1], + // edge_check + // ); + // } + // i += 1; + // } + // } + //} + } + buffer.release_memory(); + + // Simplify pre-existing edges. + //let mut i = 0; + //let num_input_edges = simplifier.input_num_edges(); + //while i < num_input_edges { + // let p = simplifier.get_input_parent(i); + // //let mut edge_check: Vec<(NodeId, Position)> = vec![]; + // while i < num_input_edges && simplifier.get_input_parent(i) == p { + // //assert!(!edge_check.iter().any(|x| *x == (child[i], left[i]))); + // simplifier.add_edge( + // simplifier.get_input_left(i), + // simplifier.get_input_right(i), + // simplifier.get_input_parent(i), + // simplifier.get_input_child(i), + // )?; + // //edge_check.push((child[i], left[i])); + // i += 1; + // } + // simplifier.merge_ancestors(p)?; + // // major stress-test -- delete later + // //{ + // // let l = tables.edges().left_slice(); + // // let p = tables.edges().parent_slice(); + // // let c = tables.edges().child_slice(); + // // let mut i = 0; + // // while i < l.len() { + // // let pi = p[i]; + // // while i < l.len() && p[i] == pi { + // // if i > 0 && c[i] == c[i - 1] { + // // assert_ne!( + // // l[i], + // // l[i - 1], + // // "{:?},{:?} | {:?},{:?} | {:?},{:?} => {:?}", + // // p[i], + // // p[i - 1], + // // c[i], + // // c[i - 1], + // // l[i], + // // l[i - 1], + // // edge_check + // // ); + // // } + // // i += 1; + // // } + // // } + // //} + //} + + simplifier.finalise(node_map)?; + Ok(()) +} + +#[test] +fn test_pre_simplification() { + let mut tables = TableCollection::new(10.).unwrap(); + let mut buffer = EdgeBuffer::default(); + let p0 = tables.add_node(0, 1.0, -1, -1).unwrap(); + let p1 = tables.add_node(0, 1.0, -1, -1).unwrap(); + let c0 = tables.add_node(0, 0.0, -1, -1).unwrap(); + let c1 = tables.add_node(0, 0.0, -1, -1).unwrap(); + buffer.setup_births(&[p0, p1], &[c0, c1]).unwrap(); + + // Record data in a way that intentionally + // breaks what tskit wants: + // * children are not sorted in increading order + // of id. + buffer.record_birth(0, 3, 5.0, 10.0).unwrap(); + buffer.record_birth(0, 2, 0.0, 5.0).unwrap(); + buffer.record_birth(1, 3, 0.0, 5.0).unwrap(); + buffer.record_birth(1, 2, 5.0, 10.0).unwrap(); + buffer.finalize_births(); + buffer.pre_simplification(&mut tables).unwrap(); + assert_eq!(tables.edges().num_rows(), 4); + tables.simplify(&[2, 3], 0, false).unwrap(); + assert_eq!(tables.edges().num_rows(), 0); +} + +#[test] +fn test_post_simplification() { + let mut tables = TableCollection::new(10.).unwrap(); + let p0 = tables.add_node(0, 1.0, -1, -1).unwrap(); + let p1 = tables.add_node(0, 1.0, -1, -1).unwrap(); + let c0 = tables.add_node(0, 0.0, -1, -1).unwrap(); + let c1 = tables.add_node(0, 0.0, -1, -1).unwrap(); + let _e0 = tables.add_edge(0.0, 10.0, p0, c0).unwrap(); + let _e1 = tables.add_edge(0.0, 10.0, p1, c1).unwrap(); + assert_eq!(tables.edges().num_rows(), 2); + let alive = vec![c0, c1]; // the children have replaced the parents + let mut buffer = EdgeBuffer::default(); + buffer.post_simplification(&alive, &mut tables).unwrap(); + assert_eq!(tables.edges().num_rows(), 2); +} diff --git a/src/lib.rs b/src/lib.rs index 81e5b1228..6ac5bfb85 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -140,6 +140,17 @@ pub use trees::{Tree, TreeSequence}; #[cfg_attr(doc_cfg, doc(cfg(feature = "provenance")))] pub mod provenance; +#[cfg(feature = "edgebuffer")] +mod edgebuffer; + +#[cfg(feature = "edgebuffer")] +#[cfg_attr(doc_cfg, doc(cfg(feature = "edgebuffer")))] +pub use edgebuffer::EdgeBuffer; + +#[cfg(feature = "edgebuffer")] +#[cfg_attr(doc_cfg, doc(cfg(feature = "edgebuffer")))] +pub use edgebuffer::simplfify_from_buffer; + /// Handles return codes from low-level tskit functions. /// /// When an error from the tskit C API is detected, diff --git a/subprojects/tskit/tskit/tables.c b/subprojects/tskit/tskit/tables.c index ee628d24f..52c84e333 100644 --- a/subprojects/tskit/tskit/tables.c +++ b/subprojects/tskit/tskit/tables.c @@ -23,6 +23,7 @@ * SOFTWARE. */ +#include "tables.h" #include #include #include @@ -12870,3 +12871,106 @@ tsk_squash_edges(tsk_edge_t *edges, tsk_size_t num_edges, tsk_size_t *num_output out: return ret; } + +// KRT's latest madness + +typedef struct __tsk_streaming_simplifier_impl_t { + simplifier_t simplifier; +} tsk_streaming_simplifier_impl_t; + +int tsk_streaming_simplifier_init(tsk_streaming_simplifier_t * self, + tsk_table_collection_t *tables, const tsk_id_t *samples, + tsk_size_t num_samples, tsk_flags_t options) { + int ret = 0; + self->pimpl = tsk_malloc(sizeof(tsk_streaming_simplifier_impl_t)); + if (self->pimpl == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + ret = simplifier_init(&self->pimpl->simplifier, samples, num_samples, tables, options); + if (ret != 0) { + goto out; + } + +out: + return ret; +} + +// FIXME: parent not used +int tsk_streaming_simplifier_add_edge(tsk_streaming_simplifier_t * self, + double left, double right, tsk_id_t parent, tsk_id_t child) { + int ret = simplifier_extract_ancestry(&self->pimpl->simplifier, left, right, child); + return ret; +} + +int tsk_streaming_simplifier_merge_ancestors(tsk_streaming_simplifier_t * self, tsk_id_t parent) { + int ret = simplifier_merge_ancestors(&self->pimpl->simplifier, parent); + self->pimpl->simplifier.segment_queue_size = 0; + return ret; +} + +int tsk_streaming_simplifier_finalise(tsk_streaming_simplifier_t * self, tsk_id_t * node_map) { + int ret = 0; + simplifier_t * simplifier = &self->pimpl->simplifier; + ret = simplifier_run(simplifier, node_map); + //if (simplifier->options & TSK_SIMPLIFY_KEEP_INPUT_ROOTS) { + // ret = simplifier_insert_input_roots(simplifier); + // if (ret != 0) { + // goto out; + // } + //} + //ret = simplifier_output_sites(simplifier); + //if (ret != 0) { + // goto out; + //} + //ret = simplifier_finalise_references(simplifier); + //if (ret != 0) { + // goto out; + //} + //if (node_map != NULL) { + // /* Finally, output the new IDs for the nodes, if required. */ + // tsk_memcpy(node_map, simplifier->node_id_map, + // simplifier->input_tables.nodes.num_rows * sizeof(tsk_id_t)); + //} + //if (simplifier->edge_sort_offset != TSK_NULL) { + // tsk_bug_assert(simplifier->options & TSK_SIMPLIFY_KEEP_INPUT_ROOTS); + // ret = simplifier_sort_edges(simplifier); + // if (ret != 0) { + // goto out; + // } + //} +//out: + return ret; +} + +const tsk_id_t * tsk_streaming_simplifier_get_input_parent(const tsk_streaming_simplifier_t * self) +{ + return self->pimpl->simplifier.input_tables.edges.parent; +} +const tsk_id_t * tsk_streaming_simplifier_get_input_child(const tsk_streaming_simplifier_t * self) +{ + return self->pimpl->simplifier.input_tables.edges.child; +} +const double * tsk_streaming_simplifier_get_input_left(const tsk_streaming_simplifier_t * self) +{ + return self->pimpl->simplifier.input_tables.edges.left; +} +const double * tsk_streaming_simplifier_get_input_right(const tsk_streaming_simplifier_t * self) +{ + return self->pimpl->simplifier.input_tables.edges.right; +} + +tsk_size_t tsk_streaming_simplifier_get_num_input_edges(const tsk_streaming_simplifier_t * self) { + return self->pimpl->simplifier.input_tables.edges.num_rows; +} + +int tsk_streaming_simplifier_free(tsk_streaming_simplifier_t * self) { + int ret = 0; + ret = simplifier_free(&self->pimpl->simplifier); + if (ret != 0) { + goto out; + } + tsk_safe_free(self->pimpl); +out: + return ret; +} diff --git a/subprojects/tskit/tskit/tables.h b/subprojects/tskit/tskit/tables.h index bab354622..397c8c8ac 100644 --- a/subprojects/tskit/tskit/tables.h +++ b/subprojects/tskit/tskit/tables.h @@ -670,6 +670,32 @@ typedef struct { bool store_pairs; } tsk_identity_segments_t; +// KRT's latest insanity +typedef struct { + /* don't leak private types into public API */ + struct __tsk_streaming_simplifier_impl_t * pimpl; +} tsk_streaming_simplifier_t; + +int tsk_streaming_simplifier_init(tsk_streaming_simplifier_t * self, + tsk_table_collection_t *tables, const tsk_id_t *samples, + tsk_size_t num_samples, tsk_flags_t options); +int tsk_streaming_simplifier_free(tsk_streaming_simplifier_t * self); +// metadata... +int tsk_streaming_simplifier_add_edge(tsk_streaming_simplifier_t * self, + double left, double right, tsk_id_t parent, tsk_id_t child); +int tsk_streaming_simplifier_merge_ancestors(tsk_streaming_simplifier_t * self, tsk_id_t parent); + +// runs the simplifier, thus processing ancient edges +// present in the input edge table. +int tsk_streaming_simplifier_finalise(tsk_streaming_simplifier_t * self, tsk_id_t *node_map); + +// None of this is needed anymore. +const tsk_id_t * tsk_streaming_simplifier_get_input_parent(const tsk_streaming_simplifier_t * self); +const tsk_id_t * tsk_streaming_simplifier_get_input_child(const tsk_streaming_simplifier_t * self); +const double * tsk_streaming_simplifier_get_input_left(const tsk_streaming_simplifier_t * self); +const double * tsk_streaming_simplifier_get_input_right(const tsk_streaming_simplifier_t * self); +tsk_size_t tsk_streaming_simplifier_get_num_input_edges(const tsk_streaming_simplifier_t * self); + /****************************************************************************/ /* Common function options */ /****************************************************************************/ diff --git a/tests/test_edge_buffer.rs b/tests/test_edge_buffer.rs new file mode 100644 index 000000000..ab27673dd --- /dev/null +++ b/tests/test_edge_buffer.rs @@ -0,0 +1,314 @@ +#![cfg(feature = "edgebuffer")] + +use proptest::prelude::*; +use rand::distributions::Distribution; +use rand::SeedableRng; + +use tskit::EdgeBuffer; +use tskit::NodeId; +use tskit::TableCollection; +use tskit::TreeSequence; +use tskit::TskitError; + +trait Recording { + fn add_node(&mut self, flags: u32, time: f64) -> Result; + fn add_edge( + &mut self, + left: f64, + right: f64, + parent: NodeId, + child: NodeId, + ) -> Result<(), TskitError>; + + fn simplify(&mut self, samples: &mut [NodeId]) -> Result<(), TskitError>; + fn post_simplify(&mut self, _samples: &mut [NodeId]) -> Result<(), TskitError> { + Ok(()) + } + fn start_recording(&mut self, _parents: &[NodeId], _child: &[NodeId]) {} + fn end_recording(&mut self) {} +} + +struct TableCollectionWithBuffer { + tables: TableCollection, + buffer: EdgeBuffer, +} + +impl TableCollectionWithBuffer { + fn new() -> Self { + Self { + tables: TableCollection::new(1.0).unwrap(), + buffer: EdgeBuffer::default(), + } + } +} + +impl Recording for TableCollectionWithBuffer { + fn add_node(&mut self, flags: u32, time: f64) -> Result { + self.tables.add_node(flags, time, -1, -1) + } + + fn add_edge( + &mut self, + left: f64, + right: f64, + parent: NodeId, + child: NodeId, + ) -> Result<(), TskitError> { + self.buffer.record_birth(parent, child, left, right) + } + + fn start_recording(&mut self, parents: &[NodeId], children: &[NodeId]) { + self.buffer.setup_births(parents, children).unwrap() + } + + fn end_recording(&mut self) { + self.buffer.finalize_births() + } + + fn simplify(&mut self, samples: &mut [NodeId]) -> Result<(), TskitError> { + self.buffer.pre_simplification(&mut self.tables).unwrap(); + match self.tables.simplify(samples, 0, true) { + Ok(Some(idmap)) => { + for s in samples.iter_mut() { + *s = idmap[s.as_usize()]; + } + Ok(()) + } + Ok(None) => panic!(), + Err(e) => Err(e), + } + } + + fn post_simplify(&mut self, samples: &mut [NodeId]) -> Result<(), TskitError> { + self.buffer.post_simplification(samples, &mut self.tables) + } +} + +impl From for TreeSequence { + fn from(value: TableCollectionWithBuffer) -> Self { + value + .tables + .tree_sequence(tskit::TreeSequenceFlags::BUILD_INDEXES) + .unwrap() + } +} + +#[repr(transparent)] +struct StandardTableCollection(TableCollection); + +impl StandardTableCollection { + fn new() -> Self { + Self(TableCollection::new(1.0).unwrap()) + } +} + +struct TableCollectionWithBufferForStreaming { + tables: TableCollection, + buffer: EdgeBuffer, + node_map: Vec, +} + +impl TableCollectionWithBufferForStreaming { + fn new() -> Self { + Self { + tables: TableCollection::new(1.0).unwrap(), + buffer: EdgeBuffer::default(), + node_map: vec![], + } + } +} + +impl Recording for TableCollectionWithBufferForStreaming { + fn add_node(&mut self, flags: u32, time: f64) -> Result { + self.tables.add_node(flags, time, -1, -1) + } + + fn add_edge( + &mut self, + left: f64, + right: f64, + parent: NodeId, + child: NodeId, + ) -> Result<(), TskitError> { + self.buffer.buffer_birth(parent, child, left, right) + } + + fn simplify(&mut self, samples: &mut [NodeId]) -> Result<(), TskitError> { + self.node_map.resize( + self.tables.nodes().num_rows().as_usize(), + tskit::NodeId::NULL, + ); + tskit::simplfify_from_buffer( + samples, + tskit::SimplificationOptions::default(), + &mut self.tables, + &mut self.buffer, + Some(&mut self.node_map), + ) + .unwrap(); + for o in samples.iter_mut() { + assert!(o.as_usize() < self.node_map.len()); + *o = self.node_map[usize::try_from(*o).unwrap()]; + assert!(!o.is_null()); + } + Ok(()) + } + + fn post_simplify(&mut self, samples: &mut [NodeId]) -> Result<(), TskitError> { + self.buffer.post_simplification(samples, &mut self.tables) + } +} + +impl From for TreeSequence { + fn from(value: TableCollectionWithBufferForStreaming) -> Self { + value + .tables + .tree_sequence(tskit::TreeSequenceFlags::BUILD_INDEXES) + .unwrap() + } +} + +impl Recording for StandardTableCollection { + fn add_node(&mut self, flags: u32, time: f64) -> Result { + self.0.add_node(flags, time, -1, -1) + } + fn add_edge( + &mut self, + left: f64, + right: f64, + parent: NodeId, + child: NodeId, + ) -> Result<(), TskitError> { + match self.0.add_edge(left, right, parent, child) { + Ok(_) => Ok(()), + Err(e) => Err(e), + } + } + fn simplify(&mut self, samples: &mut [NodeId]) -> Result<(), TskitError> { + self.0.full_sort(0).unwrap(); + match self.0.simplify(samples, 0, true) { + Ok(Some(idmap)) => { + for s in samples { + *s = idmap[s.as_usize()]; + } + Ok(()) + } + Ok(None) => panic!("need to remap input sample nodes"), + Err(e) => Err(e), + } + } +} + +impl From for TreeSequence { + fn from(value: StandardTableCollection) -> Self { + let mut value = value; + value.0.build_index().unwrap(); + value.0.tree_sequence(0.into()).unwrap() + } +} + +fn overlapping_generations(seed: u64, pdeath: f64, simplify: i32, recorder: T) -> TreeSequence +where + T: Into + Recording, +{ + let mut rng = rand::rngs::StdRng::seed_from_u64(seed); + + let popsize = 10; + let nsteps = 10; + + let mut parents = vec![]; + + let mut recorder = recorder; + + for _ in 0..popsize { + let node = recorder.add_node(0, nsteps as f64).unwrap(); + parents.push(node); + } + + let death = rand::distributions::Uniform::new(0., 1.0); + let parent_picker = rand::distributions::Uniform::new(0, popsize); + let breakpoint_generator = rand::distributions::Uniform::new(0.0, 1.0); + + for birth_time in (0..nsteps).rev() { + let mut replacements = vec![]; + for i in 0..parents.len() { + if death.sample(&mut rng) <= pdeath { + replacements.push(i); + } + } + let mut births = vec![]; + + for _ in 0..replacements.len() { + let parent_index = parent_picker.sample(&mut rng); + let parent = parents[parent_index]; + let parent_index = parent_picker.sample(&mut rng); + let parent2 = parents[parent_index]; + let child = recorder.add_node(0, birth_time as f64).unwrap(); + births.push(child); + let breakpoint = breakpoint_generator.sample(&mut rng); + recorder.start_recording(&[parent, parent2], &[child]); + recorder.add_edge(0., breakpoint, parent, child).unwrap(); + recorder.add_edge(breakpoint, 1., parent2, child).unwrap(); + recorder.end_recording(); + } + for (r, b) in replacements.iter().zip(births.iter()) { + assert!(*r < parents.len()); + parents[*r] = *b; + } + if birth_time % simplify == 0 { + recorder.simplify(&mut parents).unwrap(); + if birth_time > 0 { + recorder.post_simplify(&mut parents).unwrap(); + } + } + } + recorder.into() +} + +fn compare_treeseqs(a: &TreeSequence, b: &TreeSequence) { + use streaming_iterator::StreamingIterator; + assert_eq!(a.edges().num_rows(), b.edges().num_rows()); + assert_eq!(a.nodes().num_rows(), b.nodes().num_rows()); + assert_eq!(a.num_trees(), b.num_trees()); + + let mut trees_a = a.tree_iterator(0).unwrap(); + let mut trees_b = b.tree_iterator(0).unwrap(); + + while let Some(tree_a) = trees_a.next() { + let tree_b = trees_b.next().unwrap(); + assert_eq!(tree_a.interval(), tree_b.interval()); + assert_eq!( + tree_a.total_branch_length(true).unwrap(), + tree_b.total_branch_length(true).unwrap() + ); + } +} + +fn run_overlapping_generations_test(seed: u64, pdeath: f64, simplify_interval: i32) { + let standard = StandardTableCollection::new(); + let standard_treeseq = overlapping_generations(seed, pdeath, simplify_interval, standard); + let with_buffer = TableCollectionWithBuffer::new(); + let standard_with_buffer = + overlapping_generations(seed, pdeath, simplify_interval, with_buffer); + let with_buffer_streaming = TableCollectionWithBufferForStreaming::new(); + let standard_with_buffer_streaming = + overlapping_generations(seed, pdeath, simplify_interval, with_buffer_streaming); + + compare_treeseqs(&standard_treeseq, &standard_with_buffer); + compare_treeseqs(&standard_treeseq, &standard_with_buffer_streaming); +} + +#[test] +fn failing_test_params() { + run_overlapping_generations_test(3491384373429438832, 0.49766542321295254, 1); +} + +#[cfg(test)] +proptest! { + #[test] + fn test_edge_buffer_overlapping_generations(seed in any::(), + pdeath in 0.05..1.0, + simplify_interval in 1..100i32) { + run_overlapping_generations_test(seed, pdeath, simplify_interval) + } +}