Skip to content

Commit 8d68f6f

Browse files
committed
feat: Implement efficient edge buffering
* cargo feature edgebuffer
1 parent ce926cb commit 8d68f6f

File tree

5 files changed

+735
-0
lines changed

5 files changed

+735
-0
lines changed

Diff for: Cargo.toml

+5
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ pkg-config = "0.3"
5050
[features]
5151
provenance = ["humantime"]
5252
derive = ["tskit-derive", "serde", "serde_json", "bincode"]
53+
edgebuffer = []
5354

5455
[package.metadata.docs.rs]
5556
all-features = true
@@ -58,3 +59,7 @@ rustdoc-args = ["--cfg", "doc_cfg"]
5859
# Not run during tests
5960
[[example]]
6061
name = "tree_traversals"
62+
63+
[[example]]
64+
name = "haploid_wright_fisher_edge_buffering"
65+
required-features = ["edgebuffer"]

Diff for: examples/haploid_wright_fisher_edge_buffering.rs

+144
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
// This is a rust implementation of the example
2+
// found in tskit-c
3+
4+
use anyhow::Result;
5+
use clap::Parser;
6+
#[cfg(test)]
7+
use proptest::prelude::*;
8+
use rand::distributions::Distribution;
9+
use rand::SeedableRng;
10+
11+
// ANCHOR: haploid_wright_fisher_edge_buffering
12+
fn simulate(
13+
seed: u64,
14+
popsize: usize,
15+
num_generations: i32,
16+
simplify_interval: i32,
17+
) -> Result<tskit::TreeSequence> {
18+
if popsize == 0 {
19+
return Err(anyhow::Error::msg("popsize must be > 0"));
20+
}
21+
if num_generations == 0 {
22+
return Err(anyhow::Error::msg("num_generations must be > 0"));
23+
}
24+
if simplify_interval == 0 {
25+
return Err(anyhow::Error::msg("simplify_interval must be > 0"));
26+
}
27+
let mut tables = tskit::TableCollection::new(1.0)?;
28+
29+
// create parental nodes
30+
let mut parents_and_children = {
31+
let mut temp = vec![];
32+
let parental_time = f64::from(num_generations);
33+
for _ in 0..popsize {
34+
let node = tables.add_node(0, parental_time, -1, -1)?;
35+
temp.push(node);
36+
}
37+
temp
38+
};
39+
40+
// allocate space for offspring nodes
41+
parents_and_children.resize(2 * parents_and_children.len(), tskit::NodeId::NULL);
42+
43+
// Construct non-overlapping mutable slices into our vector.
44+
let (mut parents, mut children) = parents_and_children.split_at_mut(popsize);
45+
46+
let parent_picker = rand::distributions::Uniform::new(0, popsize);
47+
let breakpoint_generator = rand::distributions::Uniform::new(0.0, 1.0);
48+
let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
49+
let mut buffer = tskit::EdgeBuffer::default();
50+
51+
for birth_time in (0..num_generations).rev() {
52+
for c in children.iter_mut() {
53+
let bt = f64::from(birth_time);
54+
let child = tables.add_node(0, bt, -1, -1)?;
55+
let left_parent = parents
56+
.get(parent_picker.sample(&mut rng))
57+
.ok_or_else(|| anyhow::Error::msg("invalid left_parent index"))?;
58+
let right_parent = parents
59+
.get(parent_picker.sample(&mut rng))
60+
.ok_or_else(|| anyhow::Error::msg("invalid right_parent index"))?;
61+
buffer.setup_births(&[*left_parent, *right_parent], &[child])?;
62+
let breakpoint = breakpoint_generator.sample(&mut rng);
63+
buffer.record_birth(*left_parent, child, 0., breakpoint)?;
64+
buffer.record_birth(*right_parent, child, breakpoint, 1.0)?;
65+
buffer.finalize_births();
66+
*c = child;
67+
}
68+
69+
if birth_time % simplify_interval == 0 {
70+
buffer.pre_simplification(&mut tables)?;
71+
//tables.full_sort(tskit::TableSortOptions::default())?;
72+
if let Some(idmap) =
73+
tables.simplify(children, tskit::SimplificationOptions::default(), true)?
74+
{
75+
// remap child nodes
76+
for o in children.iter_mut() {
77+
*o = idmap[usize::try_from(*o)?];
78+
}
79+
}
80+
buffer.post_simplification(children, &mut tables)?;
81+
}
82+
std::mem::swap(&mut parents, &mut children);
83+
}
84+
85+
tables.build_index()?;
86+
let treeseq = tables.tree_sequence(tskit::TreeSequenceFlags::default())?;
87+
88+
Ok(treeseq)
89+
}
90+
// ANCHOR_END: haploid_wright_fisher_edge_buffering
91+
92+
#[derive(Clone, clap::Parser)]
93+
struct SimParams {
94+
seed: u64,
95+
popsize: usize,
96+
num_generations: i32,
97+
simplify_interval: i32,
98+
treefile: Option<String>,
99+
}
100+
101+
fn main() -> Result<()> {
102+
let params = SimParams::parse();
103+
let treeseq = simulate(
104+
params.seed,
105+
params.popsize,
106+
params.num_generations,
107+
params.simplify_interval,
108+
)?;
109+
110+
if let Some(treefile) = &params.treefile {
111+
treeseq.dump(treefile, 0)?;
112+
}
113+
114+
Ok(())
115+
}
116+
117+
#[cfg(test)]
118+
proptest! {
119+
#[test]
120+
fn test_simulate_proptest(seed in any::<u64>(),
121+
num_generations in 50..100i32,
122+
simplify_interval in 1..100i32) {
123+
let ts = simulate(seed, 100, num_generations, simplify_interval).unwrap();
124+
125+
// stress test the branch length fn b/c it is not a trivial
126+
// wrapper around the C API.
127+
{
128+
use streaming_iterator::StreamingIterator;
129+
let mut x = f64::NAN;
130+
if let Ok(mut tree_iter) = ts.tree_iterator(0) {
131+
// We will only do the first tree to save time.
132+
if let Some(tree) = tree_iter.next() {
133+
let b = tree.total_branch_length(false).unwrap();
134+
let b2 = unsafe {
135+
tskit::bindings::tsk_tree_get_total_branch_length(tree.as_ptr(), -1, &mut x)
136+
};
137+
assert!(b2 >= 0, "{}", b2);
138+
assert!(f64::from(b) - x <= 1e-8);
139+
}
140+
}
141+
}
142+
}
143+
}
144+

0 commit comments

Comments
 (0)