Skip to content

Commit 0f7903c

Browse files
committed
add naive edge buffer
1 parent 5ed5b4b commit 0f7903c

File tree

2 files changed

+199
-0
lines changed

2 files changed

+199
-0
lines changed

Cargo.toml

+3
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,6 @@ rustdoc-args = ["--cfg", "doc_cfg"]
5858
# Not run during tests
5959
[[example]]
6060
name = "tree_traversals"
61+
62+
[[example]]
63+
name = "haploid_wright_fisher_simple_edge_buffer"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
// This is a rust implementation of the example
2+
// found in tskit-c
3+
4+
use std::collections::HashMap;
5+
6+
use anyhow::Result;
7+
use clap::Parser;
8+
#[cfg(test)]
9+
use rand::distributions::Distribution;
10+
use rand::prelude::*;
11+
use rand::SeedableRng;
12+
13+
#[derive(Debug)]
14+
struct Edge {
15+
left: tskit::Position,
16+
right: tskit::Position,
17+
child: tskit::NodeId,
18+
previous: Option<usize>,
19+
}
20+
21+
#[derive(Default)]
22+
struct EdgeBuffer {
23+
parent: Vec<tskit::NodeId>,
24+
last: HashMap<tskit::NodeId, usize>,
25+
edges: Vec<Edge>,
26+
}
27+
28+
impl EdgeBuffer {
29+
fn buffer_edge(
30+
&mut self,
31+
left: tskit::Position,
32+
right: tskit::Position,
33+
parent: tskit::NodeId,
34+
child: tskit::NodeId,
35+
) -> Result<()> {
36+
if let Some(last) = self.last.get_mut(&parent) {
37+
let pchild = self.edges[*last].child;
38+
assert!(child >= pchild);
39+
self.edges.push(Edge {
40+
left,
41+
right,
42+
child,
43+
previous: Some(*last),
44+
});
45+
*last = self.edges.len() - 1;
46+
} else {
47+
self.edges.push(Edge {
48+
left,
49+
right,
50+
child,
51+
previous: None,
52+
});
53+
self.last.insert(parent, self.edges.len() - 1);
54+
self.parent.push(parent);
55+
}
56+
Ok(())
57+
}
58+
59+
fn clear(&mut self) {
60+
self.parent.clear();
61+
self.last.clear();
62+
self.edges.clear();
63+
}
64+
}
65+
66+
fn rotate_edges(bookmark: &tskit::types::Bookmark, tables: &mut tskit::TableCollection) {
67+
let num_edges = tables.edges().num_rows().as_usize();
68+
let left =
69+
unsafe { std::slice::from_raw_parts_mut((*tables.as_mut_ptr()).edges.left, num_edges) };
70+
let right =
71+
unsafe { std::slice::from_raw_parts_mut((*tables.as_mut_ptr()).edges.right, num_edges) };
72+
let parent =
73+
unsafe { std::slice::from_raw_parts_mut((*tables.as_mut_ptr()).edges.parent, num_edges) };
74+
let child =
75+
unsafe { std::slice::from_raw_parts_mut((*tables.as_mut_ptr()).edges.child, num_edges) };
76+
let mid = bookmark.edges().as_usize();
77+
left.rotate_left(mid);
78+
right.rotate_left(mid);
79+
parent.rotate_left(mid);
80+
child.rotate_left(mid);
81+
}
82+
83+
// ANCHOR: haploid_wright_fisher
84+
fn simulate(
85+
seed: u64,
86+
popsize: usize,
87+
num_generations: i32,
88+
simplify_interval: i32,
89+
) -> Result<tskit::TreeSequence> {
90+
if popsize == 0 {
91+
return Err(anyhow::Error::msg("popsize must be > 0"));
92+
}
93+
if num_generations == 0 {
94+
return Err(anyhow::Error::msg("num_generations must be > 0"));
95+
}
96+
if simplify_interval == 0 {
97+
return Err(anyhow::Error::msg("simplify_interval must be > 0"));
98+
}
99+
let mut tables = tskit::TableCollection::new(1.0)?;
100+
101+
// create parental nodes
102+
let mut parents_and_children = {
103+
let mut temp = vec![];
104+
let parental_time = f64::from(num_generations);
105+
for _ in 0..popsize {
106+
let node = tables.add_node(0, parental_time, -1, -1)?;
107+
temp.push(node);
108+
}
109+
temp
110+
};
111+
112+
// allocate space for offspring nodes
113+
parents_and_children.resize(2 * parents_and_children.len(), tskit::NodeId::NULL);
114+
115+
// Construct non-overlapping mutable slices into our vector.
116+
let (mut parents, mut children) = parents_and_children.split_at_mut(popsize);
117+
118+
let parent_picker = rand::distributions::Uniform::new(0, popsize);
119+
let breakpoint_generator = rand::distributions::Uniform::new(0.0, 1.0);
120+
let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
121+
let mut bookmark = tskit::types::Bookmark::default();
122+
123+
let mut buffer = EdgeBuffer::default();
124+
for birth_time in (0..num_generations).rev() {
125+
for c in children.iter_mut() {
126+
let bt = f64::from(birth_time);
127+
let child = tables.add_node(0, bt, -1, -1)?;
128+
let left_parent = parents
129+
.get(parent_picker.sample(&mut rng))
130+
.ok_or_else(|| anyhow::Error::msg("invalid left_parent index"))?;
131+
let right_parent = parents
132+
.get(parent_picker.sample(&mut rng))
133+
.ok_or_else(|| anyhow::Error::msg("invalid right_parent index"))?;
134+
let breakpoint = breakpoint_generator.sample(&mut rng);
135+
buffer.buffer_edge(0_f64.into(), breakpoint.into(), *left_parent, child)?;
136+
buffer.buffer_edge(breakpoint.into(), 1_f64.into(), *right_parent, child)?;
137+
*c = child;
138+
}
139+
140+
if birth_time % simplify_interval == 0 {
141+
for &parent in buffer.parent.iter().rev() {
142+
let mut last = buffer.last.get(&parent).cloned();
143+
while let Some(previous) = last {
144+
let edge = &buffer.edges[previous];
145+
tables.add_edge(edge.left, edge.right, parent, edge.child)?;
146+
last = edge.previous;
147+
}
148+
}
149+
buffer.clear();
150+
rotate_edges(&bookmark, &mut tables);
151+
if let Some(idmap) =
152+
tables.simplify(children, tskit::SimplificationOptions::default(), true)?
153+
{
154+
// remap child nodes
155+
for o in children.iter_mut() {
156+
*o = idmap[usize::try_from(*o)?];
157+
}
158+
}
159+
bookmark.set_edges(tables.edges().num_rows());
160+
}
161+
std::mem::swap(&mut parents, &mut children);
162+
}
163+
164+
tables.build_index()?;
165+
let treeseq = tables.tree_sequence(tskit::TreeSequenceFlags::default())?;
166+
167+
Ok(treeseq)
168+
}
169+
// ANCHOR_END: haploid_wright_fisher
170+
171+
#[derive(Clone, clap::Parser)]
172+
struct SimParams {
173+
seed: u64,
174+
popsize: usize,
175+
num_generations: i32,
176+
simplify_interval: i32,
177+
treefile: Option<String>,
178+
#[clap(short, long, help = "Use bookmark to avoid sorting entire edge table.")]
179+
bookmark: bool,
180+
}
181+
182+
fn main() -> Result<()> {
183+
let params = SimParams::parse();
184+
let treeseq = simulate(
185+
params.seed,
186+
params.popsize,
187+
params.num_generations,
188+
params.simplify_interval,
189+
)?;
190+
191+
if let Some(treefile) = &params.treefile {
192+
treeseq.dump(treefile, 0)?;
193+
}
194+
195+
Ok(())
196+
}

0 commit comments

Comments
 (0)