Slower but better results from a priority queue
TrevorHansen committed Jan 5, 2024
1 parent 80d9c63 commit 132a538
Showing 1 changed file with 39 additions and 51 deletions.
90 changes: 39 additions & 51 deletions src/extract/faster_greedy_dag.rs
@@ -2,6 +2,9 @@
 // For example (+ (* x x ) (* x x )) has one mulitplication
 // included in the cost.
 
+use std::cmp::Reverse;
+use std::collections::BinaryHeap;
+
 use super::*;
 use rustc_hash::{FxHashMap, FxHashSet};
 
@@ -92,7 +95,7 @@ impl Extractor for FasterGreedyDagExtractor {
     fn extract(&self, egraph: &EGraph, _roots: &[ClassId]) -> ExtractionResult {
         let mut parents = IndexMap::<ClassId, Vec<NodeId>>::with_capacity(egraph.classes().len());
         let n2c = |nid: &NodeId| egraph.nid_to_cid(nid);
-        let mut analysis_pending = UniqueQueue::default();
+        let mut analysis_pending = MostlyUniquePriorityQueue::default();
 
         for class in egraph.classes().values() {
             parents.insert(class.id.clone(), Vec::new());
@@ -107,7 +110,7 @@
 
                 // start the analysis from leaves
                 if egraph[node].is_leaf() {
-                    analysis_pending.insert(node.clone());
+                    analysis_pending.insert(node.clone(), egraph[node].cost);
                 }
             }
         }
@@ -120,18 +123,20 @@
 
         while let Some(node_id) = analysis_pending.pop() {
             let class_id = n2c(&node_id);
-            let node = &egraph[&node_id];
-            if node.children.iter().all(|c| costs.contains_key(n2c(c))) {
-                let lookup = costs.get(class_id);
-                let mut prev_cost = INFINITY;
-                if lookup.is_some() {
-                    prev_cost = lookup.unwrap().total;
-                }
-
-                let cost_set = Self::calculate_cost_set(egraph, node_id.clone(), &costs, prev_cost);
-                if cost_set.total < prev_cost {
-                    costs.insert(class_id.clone(), cost_set);
-                    analysis_pending.extend(parents[class_id].iter().cloned());
+            let lookup = costs.get(class_id);
+            let prev_cost = lookup.map_or(INFINITY, |v| v.total);
+
+            let cost_set = Self::calculate_cost_set(egraph, node_id.clone(), &costs, prev_cost);
+            if cost_set.total < prev_cost {
+                costs.insert(class_id.clone(), cost_set);
+                for e in &parents[class_id] {
+                    if egraph[e]
+                        .children
+                        .iter()
+                        .all(|c| costs.contains_key(n2c(c)))
+                    {
+                        analysis_pending.insert(e.clone(), egraph[e].cost);
+                    }
                 }
             }
         }
@@ -150,57 +155,40 @@ Notably, insert/pop operations have O(1) expected amortized runtime complexity.
 Thanks @Bastacyclop for the implementation!
 */
 
 #[derive(Clone)]
 #[cfg_attr(feature = "serde-1", derive(Serialize, Deserialize))]
-pub(crate) struct UniqueQueue<T>
-where
-    T: Eq + std::hash::Hash + Clone,
-{
-    set: FxHashSet<T>, // hashbrown::
-    queue: std::collections::VecDeque<T>,
+pub(crate) struct MostlyUniquePriorityQueue {
+    set: HashMap<NodeId, Cost>,
+    queue: BinaryHeap<Reverse<(Cost, NodeId)>>,
 }
 
-impl<T> Default for UniqueQueue<T>
-where
-    T: Eq + std::hash::Hash + Clone,
-{
+impl Default for MostlyUniquePriorityQueue {
     fn default() -> Self {
-        UniqueQueue {
+        MostlyUniquePriorityQueue {
             set: Default::default(),
-            queue: std::collections::VecDeque::new(),
+            queue: BinaryHeap::new(),
         }
     }
 }
 
-impl<T> UniqueQueue<T>
-where
-    T: Eq + std::hash::Hash + Clone,
-{
-    pub fn insert(&mut self, t: T) {
-        if self.set.insert(t.clone()) {
-            self.queue.push_back(t);
+impl MostlyUniquePriorityQueue {
+
+    // Note there can be duplicates inserted, but that's fine.
+    pub fn insert(&mut self, node_id: NodeId, cost: Cost) {
+        let old = self.set.get(&node_id);
+        if old.is_some() && *old.unwrap() <= cost {
+            // Skip if the existing cost is lower or equal.
+            return;
         }
-    }
-
-    pub fn extend<I>(&mut self, iter: I)
-    where
-        I: IntoIterator<Item = T>,
-    {
-        for t in iter.into_iter() {
-            self.insert(t);
-        }
+        self.set.insert(node_id.clone(), cost.clone());
+        self.queue.push(Reverse((cost, node_id.clone())));
     }
 
-    pub fn pop(&mut self) -> Option<T> {
-        let res = self.queue.pop_front();
-        res.as_ref().map(|t| self.set.remove(t));
+    pub fn pop(&mut self) -> Option<NodeId> {
+        let res = self.queue.pop().map(|Reverse(t)| t.1);
+        res.as_ref().map(|node_id| self.set.remove(&node_id));
         res
     }
-
-    #[allow(dead_code)]
-    pub fn is_empty(&self) -> bool {
-        let r = self.queue.is_empty();
-        debug_assert_eq!(r, self.set.is_empty());
-        r
-    }
 }
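
For context on the pattern the new queue relies on, here is a minimal, self-contained sketch: std::collections::BinaryHeap is a max-heap, so wrapping (cost, id) pairs in std::cmp::Reverse turns it into a min-heap; a side map records the cheapest cost queued per id, and stale, more expensive duplicate entries are simply tolerated until they are popped. The MinCostQueue name, the u64 cost, the String id, and the main driver below are illustrative stand-ins, not the crate's Cost and NodeId types or code from this repository.

use std::cmp::Reverse;
use std::collections::{BinaryHeap, HashMap};

// Illustrative stand-ins: `u64` plays the role of the crate's `Cost`,
// `String` the role of `NodeId`.
struct MinCostQueue {
    best: HashMap<String, u64>,               // cheapest cost queued per id
    heap: BinaryHeap<Reverse<(u64, String)>>, // min-heap; may hold stale duplicates
}

impl MinCostQueue {
    fn new() -> Self {
        MinCostQueue {
            best: HashMap::new(),
            heap: BinaryHeap::new(),
        }
    }

    // Queue `id` at `cost`, unless it is already queued at an equal or lower cost.
    // Re-queuing at a lower cost leaves the old, more expensive entry in the heap.
    fn insert(&mut self, id: String, cost: u64) {
        if self.best.get(&id).map_or(false, |&c| c <= cost) {
            return;
        }
        self.best.insert(id.clone(), cost);
        self.heap.push(Reverse((cost, id)));
    }

    // Pop the cheapest entry; `Reverse` flips `BinaryHeap`'s max-heap order.
    // Stale duplicates come out again later, mirroring the "mostly unique"
    // behaviour of the queue in the diff above.
    fn pop(&mut self) -> Option<String> {
        let Reverse((_, id)) = self.heap.pop()?;
        self.best.remove(&id);
        Some(id)
    }
}

fn main() {
    let mut q = MinCostQueue::new();
    q.insert("mul".to_string(), 5);
    q.insert("add".to_string(), 1);
    q.insert("mul".to_string(), 3); // cheaper, so it is queued again
    while let Some(id) = q.pop() {
        println!("{id}"); // add, mul (cost 3), then the stale cost-5 mul entry
    }
}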
