diff --git a/Cargo.toml b/Cargo.toml index e2b3af6b..f87e9493 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,23 +2,24 @@ authors = ["Max Willsey "] categories = ["data-structures"] description = "An implementation of egraphs" -edition = "2018" +edition = "2021" keywords = ["e-graphs"] license = "MIT" name = "egg" readme = "README.md" repository = "https://github.com/egraphs-good/egg" version = "0.9.5" +rust-version = "1.63.0" [dependencies] env_logger = { version = "0.9.0", default-features = false } fxhash = "0.2.1" -hashbrown = "0.12.1" +hashbrown = { version = "0.14.3", default-features = false, features = ["inline-more"] } indexmap = "1.8.1" instant = "0.1.12" log = "0.4.17" smallvec = { version = "1.8.0", features = ["union", "const_generics"] } -symbol_table = { version = "0.2.0", features = ["global"] } +symbol_table = { version = "0.3.0", features = ["global"] } symbolic_expressions = "5.0.3" thiserror = "1.0.31" @@ -49,9 +50,11 @@ serde-1 = [ "vectorize", ] wasm-bindgen = ["instant/wasm-bindgen"] +push-pop-alt = [] # private features for testing test-explanations = [] +test-push-pop = ["deterministic"] [package.metadata.docs.rs] all-features = true diff --git a/Makefile b/Makefile index 229977bf..279518cb 100644 --- a/Makefile +++ b/Makefile @@ -6,6 +6,8 @@ test: cargo test --release --features=lp # don't run examples in proof-production mode cargo test --release --features "test-explanations" + cargo test --release --features "test-push-pop" --features "test-explanations" + cargo test --release --features "test-push-pop" --features "push-pop-alt" .PHONY: nits diff --git a/rust-toolchain b/rust-toolchain index 2fef84a8..6cb4a6fe 100644 --- a/rust-toolchain +++ b/rust-toolchain @@ -1 +1 @@ -1.60 \ No newline at end of file +1.63 \ No newline at end of file diff --git a/src/dot.rs b/src/dot.rs index cefaf440..b68028ce 100644 --- a/src/dot.rs +++ b/src/dot.rs @@ -1,7 +1,7 @@ /*! 
EGraph visualization with [GraphViz] -Use the [`Dot`] struct to visualize an [`EGraph`] +Use the [`Dot`] struct to visualize an [`EGraph`](crate::EGraph) [GraphViz]: https://graphviz.gitlab.io/ !*/ @@ -11,13 +11,13 @@ use std::fmt::{self, Debug, Display, Formatter}; use std::io::{Error, ErrorKind, Result, Write}; use std::path::Path; -use crate::{egraph::EGraph, Analysis, Language}; +use crate::{raw::EGraphResidual, Language}; /** -A wrapper for an [`EGraph`] that can output [GraphViz] for +A wrapper for an [`EGraphResidual`] that can output [GraphViz] for visualization. -The [`EGraph::dot`](EGraph::dot()) method creates `Dot`s. +The [`EGraphResidual::dot`] method creates `Dot`s. # Example @@ -50,8 +50,8 @@ instead of to its own eclass. [GraphViz]: https://graphviz.gitlab.io/ **/ -pub struct Dot<'a, L: Language, N: Analysis> { - pub(crate) egraph: &'a EGraph, +pub struct Dot<'a, L: Language> { + pub(crate) egraph: &'a EGraphResidual, /// A list of strings to be output top part of the dot file. pub config: Vec, /// Whether or not to anchor the edges in the output. @@ -59,10 +59,9 @@ pub struct Dot<'a, L: Language, N: Analysis> { pub use_anchors: bool, } -impl<'a, L, N> Dot<'a, L, N> +impl<'a, L> Dot<'a, L> where L: Language + Display, - N: Analysis, { /// Writes the `Dot` to a .dot file with the given filename. /// Does _not_ require a `dot` binary. 
@@ -170,16 +169,15 @@ where } } -impl<'a, L: Language, N: Analysis> Debug for Dot<'a, L, N> { +impl<'a, L: Language> Debug for Dot<'a, L> { fn fmt(&self, f: &mut Formatter) -> fmt::Result { f.debug_tuple("Dot").field(self.egraph).finish() } } -impl<'a, L, N> Display for Dot<'a, L, N> +impl<'a, L> Display for Dot<'a, L> where L: Language + Display, - N: Analysis, { fn fmt(&self, f: &mut Formatter) -> fmt::Result { writeln!(f, "digraph egraph {{")?; @@ -192,17 +190,19 @@ where writeln!(f, " {}", line)?; } + let classes = self.egraph.generate_class_nodes(); + // define all the nodes, clustered by eclass - for class in self.egraph.classes() { - writeln!(f, " subgraph cluster_{} {{", class.id)?; + for (&id, class) in &classes { + writeln!(f, " subgraph cluster_{} {{", id)?; writeln!(f, " style=dotted")?; for (i, node) in class.iter().enumerate() { - writeln!(f, " {}.{}[label = \"{}\"]", class.id, i, node)?; + writeln!(f, " {}.{}[label = \"{}\"]", id, i, node)?; } writeln!(f, " }}")?; } - for class in self.egraph.classes() { + for (&id, class) in &classes { for (i_in_class, node) in class.iter().enumerate() { let mut arg_i = 0; node.try_for_each(|child| { @@ -210,19 +210,19 @@ where let (anchor, label) = self.edge(arg_i, node.len()); let child_leader = self.egraph.find(child); - if child_leader == class.id { + if child_leader == id { writeln!( f, // {}.0 to pick an arbitrary node in the cluster " {}.{}{} -> {}.{}:n [lhead = cluster_{}, {}]", - class.id, i_in_class, anchor, class.id, i_in_class, class.id, label + id, i_in_class, anchor, id, i_in_class, id, label )?; } else { writeln!( f, // {}.0 to pick an arbitrary node in the cluster " {}.{}{} -> {}.0 [lhead = cluster_{}, {}]", - class.id, i_in_class, anchor, child, child_leader, label + id, i_in_class, anchor, child, child_leader, label )?; } arg_i += 1; diff --git a/src/eclass.rs b/src/eclass.rs index 5f74b2c2..fdb47f0b 100644 --- a/src/eclass.rs +++ b/src/eclass.rs @@ -1,15 +1,13 @@ -use std::fmt::Debug; +use 
std::fmt::{Debug, Formatter}; use std::iter::ExactSizeIterator; use crate::*; -/// An equivalence class of enodes. +/// The additional data required to turn a [`raw::RawEClass`] into a [`EClass`] #[non_exhaustive] -#[derive(Debug, Clone)] +#[derive(Clone)] #[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] -pub struct EClass { - /// This eclass's id. - pub id: Id, +pub struct EClassData { /// The equivalent enodes in this equivalence class. pub nodes: Vec, /// The analysis data associated with this eclass. @@ -17,10 +15,19 @@ pub struct EClass { /// Modifying this field will _not_ cause changes to propagate through the e-graph. /// Prefer [`EGraph::set_analysis_data`] instead. pub data: D, - /// The parent enodes and their original Ids. - pub(crate) parents: Vec<(L, Id)>, } +impl Debug for EClassData { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let mut nodes = self.nodes.clone(); + nodes.sort(); + write!(f, "({:?}): {:?}", self.data, nodes) + } +} + +/// An equivalence class of enodes +pub type EClass = raw::RawEClass>; + impl EClass { /// Returns `true` if the `eclass` is empty. pub fn is_empty(&self) -> bool { @@ -28,6 +35,7 @@ impl EClass { } /// Returns the number of enodes in this eclass. + #[allow(clippy::len_without_is_empty)] // https://github.com/rust-lang/rust-clippy/issues/11165 pub fn len(&self) -> usize { self.nodes.len() } @@ -36,11 +44,6 @@ impl EClass { pub fn iter(&self) -> impl ExactSizeIterator { self.nodes.iter() } - - /// Iterates over the parent enodes of this eclass. 
- pub fn parents(&self) -> impl ExactSizeIterator { - self.parents.iter().map(|(node, id)| (node, *id)) - } } impl EClass { diff --git a/src/egraph.rs b/src/egraph.rs index 6af452b2..4e67c0de 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -1,14 +1,23 @@ use crate::*; -use std::{ - borrow::BorrowMut, - fmt::{self, Debug, Display}, -}; +use std::fmt::{self, Debug, Display}; +use std::mem; +use std::ops::Deref; #[cfg(feature = "serde-1")] use serde::{Deserialize, Serialize}; +use crate::eclass::EClassData; +use crate::raw::{EGraphResidual, RawEGraph}; use log::*; +#[cfg(feature = "push-pop-alt")] +use raw::semi_persistent1 as sp; + +#[cfg(not(feature = "push-pop-alt"))] +use raw::semi_persistent2 as sp; + +use sp::UndoLog; +type PushInfo = (sp::PushInfo, explain::PushInfo, usize); /** A data structure to keep track of equalities between expressions. Check out the [background tutorial](crate::tutorials::_01_background) @@ -48,7 +57,7 @@ You must call [`EGraph::rebuild`] after deserializing an e-graph! [dot]: Dot [extract]: Extractor [sound]: https://itinerarium.github.io/phoneme-synthesis/?w=/'igraf/ -**/ + **/ #[derive(Clone)] #[cfg_attr(feature = "serde-1", derive(Serialize, Deserialize))] pub struct EGraph> { @@ -56,16 +65,7 @@ pub struct EGraph> { pub analysis: N, /// The `Explain` used to explain equivalences in this `EGraph`. pub(crate) explain: Option>, - unionfind: UnionFind, - /// Stores each enode's `Id`, not the `Id` of the eclass. - /// Enodes in the memo are canonicalized at each rebuild, but after rebuilding new - /// unions can cause them to become out of date. - #[cfg_attr(feature = "serde-1", serde(with = "vectorize"))] - memo: HashMap, - /// Nodes which need to be processed for rebuilding. The `Id` is the `Id` of the enode, - /// not the canonical id of the eclass. 
- pending: Vec<(L, Id)>, - analysis_pending: UniqueQueue<(L, Id)>, + analysis_pending: UniqueQueue, #[cfg_attr( feature = "serde-1", serde(bound( @@ -73,7 +73,7 @@ pub struct EGraph> { deserialize = "N::Data: for<'a> Deserialize<'a>", )) )] - pub(crate) classes: HashMap>, + pub(crate) inner: RawEGraph, Option>, #[cfg_attr(feature = "serde-1", serde(skip))] #[cfg_attr(feature = "serde-1", serde(default = "default_classes_by_op"))] pub(crate) classes_by_op: HashMap>, @@ -84,6 +84,8 @@ pub struct EGraph> { /// Only manually set it if you know what you're doing. #[cfg_attr(feature = "serde-1", serde(skip))] pub clean: bool, + push_log: Vec, + data_history: Vec<(Id, N::Data)>, } #[cfg(feature = "serde-1")] @@ -100,10 +102,16 @@ impl + Default> Default for EGraph { // manual debug impl to avoid L: Language bound on EGraph defn impl> Debug for EGraph { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("EGraph") - .field("memo", &self.memo) - .field("classes", &self.classes) - .finish() + self.inner.fmt(f) + } +} + +impl> Deref for EGraph { + type Target = EGraphResidual; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.inner } } @@ -112,38 +120,24 @@ impl> EGraph { pub fn new(analysis: N) -> Self { Self { analysis, - classes: Default::default(), - unionfind: Default::default(), clean: false, explain: None, - pending: Default::default(), - memo: Default::default(), + inner: Default::default(), analysis_pending: Default::default(), classes_by_op: Default::default(), + push_log: Default::default(), + data_history: Default::default(), } } /// Returns an iterator over the eclasses in the egraph. pub fn classes(&self) -> impl ExactSizeIterator> { - self.classes.values() + self.inner.classes() } /// Returns an mutating iterator over the eclasses in the egraph. 
pub fn classes_mut(&mut self) -> impl ExactSizeIterator> { - self.classes.values_mut() - } - - /// Returns `true` if the egraph is empty - /// # Example - /// ``` - /// use egg::{*, SymbolLang as S}; - /// let mut egraph = EGraph::::default(); - /// assert!(egraph.is_empty()); - /// egraph.add(S::leaf("foo")); - /// assert!(!egraph.is_empty()); - /// ``` - pub fn is_empty(&self) -> bool { - self.memo.is_empty() + self.inner.classes_mut().0 } /// Returns the number of enodes in the `EGraph`. @@ -163,7 +157,7 @@ impl> EGraph { /// assert_eq!(egraph.number_of_classes(), 1); /// ``` pub fn total_size(&self) -> usize { - self.memo.len() + self.inner.total_size() } /// Iterates over the classes, returning the total number of nodes. @@ -173,7 +167,7 @@ impl> EGraph { /// Returns the number of eclasses in the egraph. pub fn number_of_classes(&self) -> usize { - self.classes.len() + self.classes().len() } /// Enable explanations for this `EGraph`. @@ -186,7 +180,11 @@ impl> EGraph { if self.total_size() > 0 { panic!("Need to set explanations enabled before adding any expressions to the egraph."); } - self.explain = Some(Explain::new()); + let mut explain = Explain::new(); + if self.inner.has_undo_log() { + explain.enable_undo_log() + } + self.explain = Some(explain); self } @@ -212,14 +210,38 @@ impl> EGraph { } } + /// Enable [`push`](EGraph::push) and [`pop`](EGraph::pop) for this `EGraph`. + /// This allows the egraph to revert to an earlier state + pub fn with_push_pop_enabled(mut self) -> Self { + if self.inner.has_undo_log() { + return self; + } + self.inner.set_undo_log(Some(UndoLog::default())); + if let Some(explain) = &mut self.explain { + explain.enable_undo_log() + } + self + } + + /// Disable [`push`](EGraph::push) and [`pop`](EGraph::pop) for this `EGraph`. 
+ pub fn with_push_pop_disabled(mut self) -> Self { + self.inner.set_undo_log(None); + if let Some(explain) = &mut self.explain { + explain.disable_undo_log() + } + self + } + /// Make a copy of the egraph with the same nodes, but no unions between them. pub fn copy_without_unions(&self, analysis: N) -> Self { - if let Some(explain) = &self.explain { - let egraph = Self::new(analysis); - explain.populate_enodes(egraph) - } else { + if self.explain.is_none() { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get a copied egraph without unions"); } + let mut egraph = Self::new(analysis); + for (_, node) in self.uncanonical_nodes() { + egraph.add(node.clone()); + } + egraph } /// Performs the union between two egraphs. @@ -310,8 +332,8 @@ impl> EGraph { product_map: &mut HashMap<(Id, Id), Id>, ) { let res_id = Self::get_product_id(class1, class2, product_map); - for node1 in &self.classes[&class1].nodes { - for node2 in &other.classes[&class2].nodes { + for node1 in &self[class1].nodes { + for node2 in &other[class2].nodes { if node1.matches(node2) { let children1 = node1.children(); let children2 = node2.children(); @@ -333,38 +355,41 @@ impl> EGraph { } } - /// Pick a representative term for a given Id. 
- /// - /// Calling this function on an uncanonical `Id` returns a representative based on the how it - /// was obtained (see [`add_uncanoncial`](EGraph::add_uncanonical), - /// [`add_expr_uncanonical`](EGraph::add_expr_uncanonical)) - pub fn id_to_expr(&self, id: Id) -> RecExpr { - if let Some(explain) = &self.explain { - explain.node_to_recexpr(id) - } else { - panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get unique expressions per id"); - } - } - - /// Like [`id_to_expr`](EGraph::id_to_expr) but only goes one layer deep - pub fn id_to_node(&self, id: Id) -> &L { - if let Some(explain) = &self.explain { - explain.node(id) - } else { - panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get unique expressions per id"); - } - } - - /// Like [`id_to_expr`](EGraph::id_to_expr), but creates a pattern instead of a term. + /// Like [`id_to_expr`](EGraphResidual::id_to_expr), but creates a pattern instead of a term. /// When an eclass listed in the given substitutions is found, it creates a variable. /// It also adds this variable and the corresponding Id value to the resulting [`Subst`] - /// Otherwise it behaves like [`id_to_expr`](EGraph::id_to_expr). + /// Otherwise it behaves like [`id_to_expr`](EGraphResidual::id_to_expr). 
pub fn id_to_pattern(&self, id: Id, substitutions: &HashMap) -> (Pattern, Subst) { - if let Some(explain) = &self.explain { - explain.node_to_pattern(id, substitutions) - } else { - panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get unique patterns per id"); + let mut res = Default::default(); + let mut subst = Default::default(); + let mut cache = Default::default(); + self.id_to_pattern_internal(&mut res, id, substitutions, &mut subst, &mut cache); + (Pattern::new(res), subst) + } + + fn id_to_pattern_internal( + &self, + res: &mut PatternAst, + node_id: Id, + var_substitutions: &HashMap, + subst: &mut Subst, + cache: &mut HashMap, + ) -> Id { + if let Some(existing) = cache.get(&node_id) { + return *existing; } + let res_id = if let Some(existing) = var_substitutions.get(&node_id) { + let var = format!("?{}", node_id).parse().unwrap(); + subst.insert(var, *existing); + res.add(ENodeOrVar::Var(var)) + } else { + let new_node = self.id_to_node(node_id).clone().map_children(|child| { + self.id_to_pattern_internal(res, child, var_substitutions, subst, cache) + }); + res.add(ENodeOrVar::ENode(new_node)) + }; + cache.insert(node_id, res_id); + res_id } /// Get all the unions ever found in the egraph in terms of enode ids. @@ -390,8 +415,8 @@ impl> EGraph { /// Get the number of congruences between nodes in the egraph. /// Only available when explanations are enabled. pub fn get_num_congr(&mut self) -> usize { - if let Some(explain) = &self.explain { - explain.get_num_congr::(&self.classes, &self.unionfind) + if let Some(explain) = &mut self.explain { + explain.with_raw_egraph(&self.inner).get_num_congr() } else { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.") } @@ -399,11 +424,7 @@ impl> EGraph { /// Get the number of nodes in the egraph used for explanations. 
pub fn get_explanation_num_nodes(&mut self) -> usize { - if let Some(explain) = &self.explain { - explain.get_num_nodes() - } else { - panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.") - } + self.number_of_uncanonical_nodes() } /// When explanations are enabled, this function @@ -423,10 +444,10 @@ impl> EGraph { self.explain_id_equivalence(left, right) } - /// Equivalent to calling [`explain_equivalence`](EGraph::explain_equivalence)`(`[`id_to_expr`](EGraph::id_to_expr)`(left),` - /// [`id_to_expr`](EGraph::id_to_expr)`(right))` but more efficient + /// Equivalent to calling [`explain_equivalence`](EGraph::explain_equivalence)`(`[`id_to_expr`](EGraphResidual::id_to_expr)`(left),` + /// [`id_to_expr`](EGraphResidual::id_to_expr)`(right))` but more efficient /// - /// This function picks representatives using [`id_to_expr`](EGraph::id_to_expr) so choosing + /// This function picks representatives using [`id_to_expr`](EGraphResidual::id_to_expr) so choosing /// `Id`s returned by functions like [`add_uncanonical`](EGraph::add_uncanonical) is important /// to control explanations pub fn explain_id_equivalence(&mut self, left: Id, right: Id) -> Explanation { @@ -438,7 +459,9 @@ impl> EGraph { ); } if let Some(explain) = &mut self.explain { - explain.explain_equivalence::(left, right, &mut self.unionfind, &self.classes) + explain + .with_raw_egraph(&self.inner) + .explain_equivalence(left, right) } else { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.") } @@ -457,11 +480,11 @@ impl> EGraph { self.explain_existance_id(id) } - /// Equivalent to calling [`explain_existance`](EGraph::explain_existance)`(`[`id_to_expr`](EGraph::id_to_expr)`(id))` + /// Equivalent to calling [`explain_existance`](EGraph::explain_existance)`(`[`id_to_expr`](EGraphResidual::id_to_expr)`(id))` /// but more efficient fn explain_existance_id(&mut self, 
id: Id) -> Explanation { if let Some(explain) = &mut self.explain { - explain.explain_existance(id) + explain.with_raw_egraph(&self.inner).explain_existance(id) } else { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.") } @@ -475,7 +498,7 @@ impl> EGraph { ) -> Explanation { let id = self.add_instantiation_noncanonical(pattern, subst); if let Some(explain) = &mut self.explain { - explain.explain_existance(id) + explain.with_raw_egraph(&self.inner).explain_existance(id) } else { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.") } @@ -498,58 +521,20 @@ impl> EGraph { ); } if let Some(explain) = &mut self.explain { - explain.explain_equivalence::(left, right, &mut self.unionfind, &self.classes) + explain + .with_raw_egraph(&self.inner) + .explain_equivalence(left, right) } else { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations."); } } - - /// Canonicalizes an eclass id. - /// - /// This corresponds to the `find` operation on the egraph's - /// underlying unionfind data structure. - /// - /// # Example - /// ``` - /// use egg::{*, SymbolLang as S}; - /// let mut egraph = EGraph::::default(); - /// let x = egraph.add(S::leaf("x")); - /// let y = egraph.add(S::leaf("y")); - /// assert_ne!(egraph.find(x), egraph.find(y)); - /// - /// egraph.union(x, y); - /// egraph.rebuild(); - /// assert_eq!(egraph.find(x), egraph.find(y)); - /// ``` - pub fn find(&self, id: Id) -> Id { - self.unionfind.find(id) - } - - /// This is private, but internals should use this whenever - /// possible because it does path compression. - fn find_mut(&mut self, id: Id) -> Id { - self.unionfind.find_mut(id) - } - - /// Creates a [`Dot`] to visualize this egraph. See [`Dot`]. 
- /// - pub fn dot(&self) -> Dot { - Dot { - egraph: self, - config: vec![], - use_anchors: true, - } - } } /// Given an `Id` using the `egraph[id]` syntax, retrieve the e-class. impl> std::ops::Index for EGraph { type Output = EClass; fn index(&self, id: Id) -> &Self::Output { - let id = self.find(id); - self.classes - .get(&id) - .unwrap_or_else(|| panic!("Invalid id {}", id)) + self.inner.get_class(id) } } @@ -557,10 +542,7 @@ impl> std::ops::Index for EGraph { /// reference to the e-class. impl> std::ops::IndexMut for EGraph { fn index_mut(&mut self, id: Id) -> &mut Self::Output { - let id = self.find_mut(id); - self.classes - .get_mut(&id) - .unwrap_or_else(|| panic!("Invalid id {}", id)) + self.inner.get_class_mut(id).0 } } @@ -586,16 +568,16 @@ impl> EGraph { /// Similar to [`add_expr`](EGraph::add_expr) but the `Id` returned may not be canonical /// - /// Calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` return a copy of `expr` + /// Calling [`id_to_expr`](EGraphResidual::id_to_expr) on this `Id` return a copy of `expr` when explanations are enabled pub fn add_expr_uncanonical(&mut self, expr: &RecExpr) -> Id { let nodes = expr.as_ref(); let mut new_ids = Vec::with_capacity(nodes.len()); let mut new_node_q = Vec::with_capacity(nodes.len()); for node in nodes { let new_node = node.clone().map_children(|i| new_ids[usize::from(i)]); - let size_before = self.unionfind.size(); + let size_before = self.inner.number_of_uncanonical_nodes(); let next_id = self.add_uncanonical(new_node); - if self.unionfind.size() > size_before { + if self.inner.number_of_uncanonical_nodes() > size_before { new_node_q.push(true); } else { new_node_q.push(false); @@ -624,7 +606,7 @@ impl> EGraph { /// canonical /// /// Like [`add_uncanonical`](EGraph::add_uncanonical), when explanations are enabled calling - /// Calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` return an corrispond to the + /// Calling [`id_to_expr`](EGraphResidual::id_to_expr) on this `Id` return an 
correspond to the /// instantiation of the pattern fn add_instantiation_noncanonical(&mut self, pat: &PatternAst, subst: &Subst) -> Id { let nodes = pat.as_ref(); @@ -639,9 +621,9 @@ impl> EGraph { } ENodeOrVar::ENode(node) => { let new_node = node.clone().map_children(|i| new_ids[usize::from(i)]); - let size_before = self.unionfind.size(); + let size_before = self.inner.number_of_uncanonical_nodes(); let next_id = self.add_uncanonical(new_node); - if self.unionfind.size() > size_before { + if self.inner.number_of_uncanonical_nodes() > size_before { new_node_q.push(true); } else { new_node_q.push(false); @@ -661,67 +643,6 @@ impl> EGraph { *new_ids.last().unwrap() } - /// Lookup the eclass of the given enode. - /// - /// You can pass in either an owned enode or a `&mut` enode, - /// in which case the enode's children will be canonicalized. - /// - /// # Example - /// ``` - /// # use egg::*; - /// let mut egraph: EGraph = Default::default(); - /// let a = egraph.add(SymbolLang::leaf("a")); - /// let b = egraph.add(SymbolLang::leaf("b")); - /// - /// // lookup will find this node if its in the egraph - /// let mut node_f_ab = SymbolLang::new("f", vec![a, b]); - /// assert_eq!(egraph.lookup(node_f_ab.clone()), None); - /// let id = egraph.add(node_f_ab.clone()); - /// assert_eq!(egraph.lookup(node_f_ab.clone()), Some(id)); - /// - /// // if the query node isn't canonical, and its passed in by &mut instead of owned, - /// // its children will be canonicalized - /// egraph.union(a, b); - /// egraph.rebuild(); - /// assert_eq!(egraph.lookup(&mut node_f_ab), Some(id)); - /// assert_eq!(node_f_ab, SymbolLang::new("f", vec![a, a])); - /// ``` - pub fn lookup(&self, enode: B) -> Option - where - B: BorrowMut, - { - self.lookup_internal(enode).map(|id| self.find(id)) - } - - fn lookup_internal(&self, mut enode: B) -> Option - where - B: BorrowMut, - { - let enode = enode.borrow_mut(); - enode.update_children(|id| self.find(id)); - self.memo.get(enode).copied() - } - - /// 
Lookup the eclass of the given [`RecExpr`]. - /// - /// Equivalent to the last value in [`EGraph::lookup_expr_ids`]. - pub fn lookup_expr(&self, expr: &RecExpr) -> Option { - self.lookup_expr_ids(expr) - .and_then(|ids| ids.last().copied()) - } - - /// Lookup the eclasses of all the nodes in the given [`RecExpr`]. - pub fn lookup_expr_ids(&self, expr: &RecExpr) -> Option> { - let nodes = expr.as_ref(); - let mut new_ids = Vec::with_capacity(nodes.len()); - for node in nodes { - let node = node.clone().map_children(|i| new_ids[usize::from(i)]); - let id = self.lookup(node)?; - new_ids.push(id) - } - Some(new_ids) - } - /// Adds an enode to the [`EGraph`]. /// /// When adding an enode, to the egraph, [`add`] it performs @@ -741,10 +662,10 @@ impl> EGraph { /// Similar to [`add`](EGraph::add) but the `Id` returned may not be canonical /// - /// When explanations are enabled calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` will + /// When explanations are enabled calling [`id_to_expr`](EGraphResidual::id_to_expr) on this `Id` will /// correspond to the parameter `enode` /// - /// # Example + /// ## Example /// ``` /// # use egg::*; /// let mut egraph: EGraph = EGraph::default().with_explanations_enabled(); @@ -759,60 +680,63 @@ impl> EGraph { /// assert_eq!(egraph.id_to_expr(fa), "(f a)".parse().unwrap()); /// assert_eq!(egraph.id_to_expr(fb), "(f b)".parse().unwrap()); /// ``` - pub fn add_uncanonical(&mut self, mut enode: L) -> Id { - let original = enode.clone(); - if let Some(existing_id) = self.lookup_internal(&mut enode) { - let id = self.find(existing_id); - // when explanations are enabled, we need a new representative for this expr - if let Some(explain) = self.explain.as_mut() { - if let Some(existing_explain) = explain.uncanon_memo.get(&original) { - *existing_explain + /// + /// When explanations are not enabled calling [`id_to_expr`](EGraphResidual::id_to_expr) on this `Id` will + /// produce an expression with equivalent but not necessarily 
identical children + /// + /// # Example + /// ``` + /// # use egg::*; + /// let mut egraph: EGraph = EGraph::default().with_explanations_disabled(); + /// let a = egraph.add_uncanonical(SymbolLang::leaf("a")); + /// let b = egraph.add_uncanonical(SymbolLang::leaf("b")); + /// egraph.union(a, b); + /// egraph.rebuild(); + /// + /// let fa = egraph.add_uncanonical(SymbolLang::new("f", vec![a])); + /// let fb = egraph.add_uncanonical(SymbolLang::new("f", vec![b])); + /// + /// assert_eq!(egraph.id_to_expr(fa), "(f a)".parse().unwrap()); + /// assert_eq!(egraph.id_to_expr(fb), "(f a)".parse().unwrap()); + /// ``` + pub fn add_uncanonical(&mut self, enode: L) -> Id { + let mut added = false; + let id = RawEGraph::raw_add( + self, + |x| &mut x.inner, + enode, + |this, existing_id, enode| { + if let Some(explain) = this.explain.as_mut() { + explain.uncanon_memo.get(enode).copied() } else { - let new_id = self.unionfind.make_set(); - explain.add(original, new_id, new_id); - self.unionfind.union(id, new_id); + Some(existing_id) + } + }, + |this, existing_id, new_id| { + if let Some(explain) = this.explain.as_mut() { + explain.add(this.inner.id_to_node(new_id).clone(), new_id, new_id); explain.union(existing_id, new_id, Justification::Congruence, true); - new_id } - } else { - existing_id - } - } else { - let id = self.make_new_eclass(enode); + }, + |this, id, _| { + added = true; + let node = this.id_to_node(id).clone(); + let data = N::make(this, &node); + EClassData { + nodes: vec![node], + data, + } + }, + ); + if added { if let Some(explain) = self.explain.as_mut() { - explain.add(original, id, id); + explain.add(self.inner.id_to_node(id).clone(), id, id); } // now that we updated explanations, run the analysis for the new eclass N::modify(self, id); self.clean = false; - id } - } - - /// This function makes a new eclass in the egraph (but doesn't touch explanations) - fn make_new_eclass(&mut self, enode: L) -> Id { - let id = self.unionfind.make_set(); - log::trace!(" 
...adding to {}", id); - let class = EClass { - id, - nodes: vec![enode.clone()], - data: N::make(self, &enode), - parents: Default::default(), - }; - - // add this enode to the parent lists of its children - enode.for_each(|child| { - let tup = (enode.clone(), id); - self[child].parents.push(tup); - }); - - // TODO is this needed? - self.pending.push((enode.clone(), id)); - - self.classes.insert(id, class); - assert!(self.memo.insert(enode, id).is_none()); - id } @@ -858,9 +782,9 @@ impl> EGraph { rule_name: impl Into, ) -> (Id, bool) { let id1 = self.add_instantiation_noncanonical(from_pat, subst); - let size_before = self.unionfind.size(); + let size_before = self.number_of_uncanonical_nodes(); let id2 = self.add_instantiation_noncanonical(to_pat, subst); - let rhs_new = self.unionfind.size() > size_before; + let rhs_new = self.number_of_uncanonical_nodes() > size_before; let did_union = self.perform_union( id1, @@ -873,7 +797,7 @@ impl> EGraph { /// Unions two e-classes, using a given reason to justify it. 
/// - /// This function picks representatives using [`id_to_expr`](EGraph::id_to_expr) so choosing + /// This function picks representatives using [`id_to_expr`](EGraphResidual::id_to_expr) so choosing /// `Id`s returned by functions like [`add_uncanonical`](EGraph::add_uncanonical) is important /// to control explanations pub fn union_trusted(&mut self, from: Id, to: Id, reason: impl Into) -> bool { @@ -914,49 +838,42 @@ impl> EGraph { N::pre_union(self, enode_id1, enode_id2, &rule); self.clean = false; - let mut id1 = self.find_mut(enode_id1); - let mut id2 = self.find_mut(enode_id2); - if id1 == id2 { - if let Some(Justification::Rule(_)) = rule { - if let Some(explain) = &mut self.explain { - explain.alternate_rewrite(enode_id1, enode_id2, rule.unwrap()); - } + let mut new_root = None; + let has_undo_log = self.inner.has_undo_log(); + self.inner.raw_union(enode_id1, enode_id2, |info| { + new_root = Some(info.id1); + if has_undo_log && mem::size_of::() > 0 { + self.data_history.push((info.id1, info.data1.data.clone())); + self.data_history.push((info.id2, info.data2.data.clone())); } - return false; - } - // make sure class2 has fewer parents - let class1_parents = self.classes[&id1].parents.len(); - let class2_parents = self.classes[&id2].parents.len(); - if class1_parents < class2_parents { - std::mem::swap(&mut id1, &mut id2); - } - if let Some(explain) = &mut self.explain { - explain.union(enode_id1, enode_id2, rule.unwrap(), any_new_rhs); - } - - // make id1 the new root - self.unionfind.union(id1, id2); + let did_merge = self.analysis.merge(&mut info.data1.data, info.data2.data); + if did_merge.0 { + self.analysis_pending + .extend(info.parents1.into_iter().copied()); + } + if did_merge.1 { + self.analysis_pending + .extend(info.parents2.into_iter().copied()); + } - assert_ne!(id1, id2); - let class2 = self.classes.remove(&id2).unwrap(); - let class1 = self.classes.get_mut(&id1).unwrap(); - assert_eq!(id1, class1.id); + concat_vecs(&mut info.data1.nodes, 
info.data2.nodes); + }); + if let Some(id) = new_root { + if let Some(explain) = &mut self.explain { + explain.union(enode_id1, enode_id2, rule.unwrap(), any_new_rhs); + } + N::modify(self, id); - self.pending.extend(class2.parents.iter().cloned()); - let did_merge = self.analysis.merge(&mut class1.data, class2.data); - if did_merge.0 { - self.analysis_pending.extend(class1.parents.iter().cloned()); - } - if did_merge.1 { - self.analysis_pending.extend(class2.parents.iter().cloned()); + true + } else { + if let Some(Justification::Rule(_)) = rule { + if let Some(explain) = &mut self.explain { + explain.alternate_rewrite(enode_id1, enode_id2, rule.unwrap()) + } + } + false } - - concat_vecs(&mut class1.nodes, class2.nodes); - concat_vecs(&mut class1.parents, class2.parents); - - N::modify(self, id1); - true } /// Update the analysis data of an e-class. @@ -965,10 +882,13 @@ impl> EGraph { /// so [`Analysis::make`] and [`Analysis::merge`] will get /// called for other parts of the e-graph on rebuild. 
pub fn set_analysis_data(&mut self, id: Id, new_data: N::Data) { - let id = self.find_mut(id); - let class = self.classes.get_mut(&id).unwrap(); - class.data = new_data; - self.analysis_pending.extend(class.parents.iter().cloned()); + let mut canon = id; + let class = self.inner.get_class_mut(&mut canon).0; + let old_data = mem::replace(&mut class.data, new_data); + self.analysis_pending.extend(class.parents()); + if self.inner.has_undo_log() && mem::size_of::() > 0 { + self.data_history.push((canon, old_data)) + } N::modify(self, id) } @@ -981,7 +901,7 @@ impl> EGraph { /// /// [`Debug`]: std::fmt::Debug pub fn dump(&self) -> impl Debug + '_ { - EGraphDump(self) + self.inner.dump_classes() } } @@ -1020,9 +940,9 @@ impl> EGraph { classes_by_op.values_mut().for_each(|ids| ids.clear()); let mut trimmed = 0; - let uf = &mut self.unionfind; + let (classes, uf) = self.inner.classes_mut(); - for class in self.classes.values_mut() { + for class in classes { let old_len = class.len(); class .nodes @@ -1068,8 +988,8 @@ impl> EGraph { fn check_memo(&self) -> bool { let mut test_memo = HashMap::default(); - for (&id, class) in self.classes.iter() { - assert_eq!(class.id, id); + for class in self.classes() { + let id = class.id; for node in &class.nodes { if let Some(old) = test_memo.insert(node, id) { assert_eq!( @@ -1088,7 +1008,7 @@ impl> EGraph { assert_eq!(e, self.find(e)); assert_eq!( Some(e), - self.memo.get(n).map(|id| self.find(*id)), + self.lookup(n.clone()), "Entry for {:?} at {} in test_memo was incorrect", n, e @@ -1102,34 +1022,35 @@ impl> EGraph { fn process_unions(&mut self) -> usize { let mut n_unions = 0; - while !self.pending.is_empty() || !self.analysis_pending.is_empty() { - while let Some((mut node, class)) = self.pending.pop() { - node.update_children(|id| self.find_mut(id)); - if let Some(memo_class) = self.memo.insert(node, class) { - let did_something = self.perform_union( - memo_class, - class, - Some(Justification::Congruence), - false, - ); + while 
!self.inner.is_clean() || !self.analysis_pending.is_empty() { + RawEGraph::raw_rebuild( + self, + |this| &mut this.inner, + |this, id1, id2| { + let did_something = + this.perform_union(id1, id2, Some(Justification::Congruence), false); n_unions += did_something as usize; - } - } + }, + |_, _, _| {}, + ); - while let Some((node, class_id)) = self.analysis_pending.pop() { - let class_id = self.find_mut(class_id); + while let Some(mut class_id) = self.analysis_pending.pop() { + let node = self.id_to_node(class_id).clone(); let node_data = N::make(self, &node); - let class = self.classes.get_mut(&class_id).unwrap(); - + let has_undo_log = self.inner.has_undo_log(); + let class = self.inner.get_class_mut(&mut class_id).0; + if has_undo_log && mem::size_of::() > 0 { + self.data_history.push((class.id, class.data.clone())); + } let did_merge = self.analysis.merge(&mut class.data, node_data); if did_merge.0 { - self.analysis_pending.extend(class.parents.iter().cloned()); + self.analysis_pending.extend(class.parents()); N::modify(self, class_id) } } } - assert!(self.pending.is_empty()); + assert!(self.inner.is_clean()); assert!(self.analysis_pending.is_empty()); n_unions @@ -1173,7 +1094,7 @@ impl> EGraph { /// assert_eq!(egraph.find(ax), egraph.find(ay)); /// ``` pub fn rebuild(&mut self) -> usize { - let old_hc_size = self.memo.len(); + let old_hc_size = self.total_size(); let old_n_eclasses = self.number_of_classes(); let start = Instant::now(); @@ -1193,7 +1114,7 @@ impl> EGraph { elapsed.subsec_millis(), old_hc_size, old_n_eclasses, - self.memo.len(), + self.total_size(), self.number_of_classes(), n_unions, trimmed_nodes, @@ -1204,27 +1125,184 @@ impl> EGraph { n_unions } - pub(crate) fn check_each_explain(&self, rules: &[&Rewrite]) -> bool { - if let Some(explain) = &self.explain { - explain.check_each_explain(rules) + pub(crate) fn check_each_explain(&mut self, rules: &[&Rewrite]) -> bool { + if let Some(explain) = &mut self.explain { + explain + 
.with_raw_egraph(&self.inner) + .check_each_explain(rules) } else { panic!("Can't check explain when explanations are off"); } } + + /// Remove all nodes from this egraph + pub fn clear(&mut self) { + self.push_log.clear(); + self.inner.clear(); + self.clean = true; + if let Some(explain) = &mut self.explain { + explain.clear() + } + self.analysis_pending.clear(); + self.data_history.clear(); + } } -struct EGraphDump<'a, L: Language, N: Analysis>(&'a EGraph); +impl> EGraph +where + N::Data: Default, +{ + /// Push the current egraph off the stack + /// Requires that the egraph is clean + /// + /// See [`EGraph::pop`] + pub fn push(&mut self) { + assert!( + self.analysis_pending.is_empty() && self.inner.is_clean(), + "`push` can only be called on clean egraphs" + ); + if !self.inner.has_undo_log() { + panic!("Use egraph.with_push_pop_enabled() before running to call push"); + } + N::pre_push(self); + let exp_push_info = self.explain.as_ref().map(Explain::push).unwrap_or_default(); + #[cfg(feature = "push-pop-alt")] + let raw_push_info = self.inner.push1(); + #[cfg(not(feature = "push-pop-alt"))] + let raw_push_info = self.inner.push2(); + self.push_log + .push((raw_push_info, exp_push_info, self.data_history.len())) + } + + /// Pop the current egraph off the stack, replacing + /// it with the previously [`push`](EGraph::push)ed egraph + /// + /// ``` + /// use egg::{EGraph, SymbolLang}; + /// let mut egraph = EGraph::new(()).with_push_pop_enabled(); + /// let a = egraph.add_uncanonical(SymbolLang::leaf("a")); + /// let b = egraph.add_uncanonical(SymbolLang::leaf("b")); + /// egraph.rebuild(); + /// egraph.push(); + /// egraph.union(a, b); + /// assert_eq!(egraph.find(a), egraph.find(b)); + /// egraph.pop(); + /// assert_ne!(egraph.find(a), egraph.find(b)); + /// ``` + pub fn pop(&mut self) { + self.pop_n(1) + } -impl<'a, L: Language, N: Analysis> Debug for EGraphDump<'a, L, N> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut ids: Vec = 
self.0.classes().map(|c| c.id).collect(); - ids.sort(); - for id in ids { - let mut nodes = self.0[id].nodes.clone(); - nodes.sort(); - writeln!(f, "{} ({:?}): {:?}", id, self.0[id].data, nodes)? + /// Equivalent to calling [`pop`](EGraph::pop) `n` times but possibly more efficient + pub fn pop_n(&mut self, n: usize) { + if !self.inner.has_undo_log() { + panic!("Use egraph.with_push_pop_enabled() before running to call pop"); + } + if n > self.push_log.len() { + self.clear() + } + let mut info = None; + for _ in 0..n { + info = self.push_log.pop() + } + if let Some(info) = info { + self.pop_internal(info); + N::post_pop_n(self, n); + } + } + + #[cfg(not(feature = "push-pop-alt"))] + fn pop_internal(&mut self, (raw_info, exp_info, data_history_len): PushInfo) { + if let Some(explain) = &mut self.explain { + explain.pop( + exp_info, + raw_info.number_of_uncanonical_nodes(), + &self.inner, + ) + } + self.analysis_pending.clear(); + + let mut has_dirty_parents = Vec::new(); + let mut dirty_status = HashMap::default(); + self.inner.raw_pop2( + raw_info, + &mut dirty_status, + |dirty_status, data, id, _| { + dirty_status.insert(id, false); + data.nodes.clear(); + }, + |dirty_status, id, _| { + has_dirty_parents.push(id); + dirty_status.insert(id, false); + EClassData { + nodes: vec![], + data: Default::default(), + } + }, + |_, data, id, ctx| data.nodes.push(ctx.id_to_node(id).clone()), + ); + for id in has_dirty_parents { + for parent in self.inner.get_class_with_cannon(id).parents() { + dirty_status.entry(self.find(parent)).or_insert(true); + } + } + for (id, needs_reset) in dirty_status { + if needs_reset { + let mut nodes = mem::take(&mut self.inner.get_class_mut_with_cannon(id).0.nodes); + nodes.clear(); + self.inner + .undo_ctx() + .equivalent_nodes(id, |eqv| nodes.push(self.id_to_node(eqv).clone())); + self.inner.get_class_mut_with_cannon(id).0.nodes = nodes; + } + let (class, residual) = self.inner.get_class_mut_with_cannon(id); + for node in &mut class.nodes { + 
node.update_children(|id| residual.find(id)); + } + class.nodes.sort_unstable(); + class.nodes.dedup(); + } + + for (id, data) in self.data_history.drain(data_history_len..).rev() { + if usize::from(id) < self.inner.number_of_uncanonical_nodes() { + self.inner.get_class_mut_with_cannon(id).0.data = data; + } + } + + self.clean = true; + } + + #[cfg(feature = "push-pop-alt")] + fn pop_internal(&mut self, (raw_info, exp_info, data_history_len): PushInfo) { + if let Some(explain) = &mut self.explain { + explain.pop( + exp_info, + raw_info.number_of_uncanonical_nodes(), + &self.inner, + ) + } + self.analysis_pending.clear(); + + self.inner.raw_pop1(raw_info, |_, _, _| EClassData { + nodes: vec![], + data: Default::default(), + }); + + for class in self.classes_mut() { + class.nodes.clear() + } + + for id in self.uncanonical_ids() { + let node = self.id_to_node(id).clone().map_children(|x| self.find(x)); + self[id].nodes.push(node) + } + + for (id, data) in self.data_history.drain(data_history_len..).rev() { + if usize::from(id) < self.inner.number_of_uncanonical_nodes() { + self.inner.get_class_mut_with_cannon(id).0.data = data; + } } - Ok(()) + self.rebuild_classes(); } } @@ -1265,6 +1343,9 @@ mod tests { de(&egraph); let json_rep = serde_json::to_string_pretty(&egraph).unwrap(); + let egraph2: EGraph = serde_json::from_str(&json_rep).unwrap(); + let json_rep2 = serde_json::to_string_pretty(&egraph2).unwrap(); + assert_eq!(json_rep, json_rep2); println!("{}", json_rep); } } diff --git a/src/explain.rs b/src/explain.rs index 187aecfc..76ae5e4e 100644 --- a/src/explain.rs +++ b/src/explain.rs @@ -1,14 +1,20 @@ +mod semi_persistent; +pub(crate) use semi_persistent::PushInfo; + use crate::Symbol; use crate::{ - util::pretty_print, Analysis, EClass, EGraph, ENodeOrVar, FromOp, HashMap, HashSet, Id, - Language, Pattern, PatternAst, RecExpr, Rewrite, Subst, UnionFind, Var, + util::pretty_print, Analysis, ENodeOrVar, FromOp, HashMap, HashSet, Id, Language, PatternAst, + 
RecExpr, Rewrite, UnionFind, Var, }; use saturating::Saturating; use std::cmp::Ordering; use std::collections::{BinaryHeap, VecDeque}; use std::fmt::{self, Debug, Display, Formatter}; +use std::mem; +use std::ops::{Deref, DerefMut}; use std::rc::Rc; +use crate::raw::RawEGraph; use symbolic_expressions::Sexp; type ProofCost = Saturating; @@ -36,10 +42,21 @@ struct Connection { is_rewrite_forward: bool, } +impl Connection { + #[inline] + fn end(node: Id) -> Self { + Connection { + next: node, + current: node, + justification: Justification::Congruence, + is_rewrite_forward: false, + } + } +} + #[derive(Debug, Clone)] #[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] -struct ExplainNode { - node: L, +struct ExplainNode { // neighbors includes parent connections neighbors: Vec, parent_connection: Connection, @@ -54,8 +71,15 @@ struct ExplainNode { #[derive(Debug, Clone)] #[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] pub struct Explain { - explainfind: Vec>, + explainfind: Vec, #[cfg_attr(feature = "serde-1", serde(with = "vectorize"))] + #[cfg_attr( + feature = "serde-1", + serde(bound( + serialize = "L: serde::Serialize", + deserialize = "L: serde::Deserialize<'de>", + )) + )] pub uncanon_memo: HashMap, /// By default, egg uses a greedy algorithm to find shorter explanations when they are extracted. 
pub optimize_explanation_lengths: bool, @@ -67,6 +91,12 @@ pub struct Explain { // That is, less than or equal to the result of `distance_between` #[cfg_attr(feature = "serde-1", serde(skip))] shortest_explanation_memo: HashMap<(Id, Id), (ProofCost, Id)>, + undo_log: semi_persistent::UndoLog, +} + +pub(crate) struct ExplainWith<'a, L: Language, X> { + explain: &'a mut Explain, + raw: X, } #[derive(Default)] @@ -883,97 +913,6 @@ impl PartialOrd for HeapState { } impl Explain { - pub(crate) fn node(&self, node_id: Id) -> &L { - &self.explainfind[usize::from(node_id)].node - } - fn node_to_explanation( - &self, - node_id: Id, - cache: &mut NodeExplanationCache, - ) -> Rc> { - if let Some(existing) = cache.get(&node_id) { - existing.clone() - } else { - let node = self.node(node_id).clone(); - let children = node.fold(vec![], |mut sofar, child| { - sofar.push(vec![self.node_to_explanation(child, cache)]); - sofar - }); - let res = Rc::new(TreeTerm::new(node, children)); - cache.insert(node_id, res.clone()); - res - } - } - - pub(crate) fn node_to_recexpr(&self, node_id: Id) -> RecExpr { - let mut res = Default::default(); - let mut cache = Default::default(); - self.node_to_recexpr_internal(&mut res, node_id, &mut cache); - res - } - fn node_to_recexpr_internal( - &self, - res: &mut RecExpr, - node_id: Id, - cache: &mut HashMap, - ) { - let new_node = self.node(node_id).clone().map_children(|child| { - if let Some(existing) = cache.get(&child) { - *existing - } else { - self.node_to_recexpr_internal(res, child, cache); - Id::from(res.as_ref().len() - 1) - } - }); - res.add(new_node); - } - - pub(crate) fn node_to_pattern( - &self, - node_id: Id, - substitutions: &HashMap, - ) -> (Pattern, Subst) { - let mut res = Default::default(); - let mut subst = Default::default(); - let mut cache = Default::default(); - self.node_to_pattern_internal(&mut res, node_id, substitutions, &mut subst, &mut cache); - (Pattern::new(res), subst) - } - - fn node_to_pattern_internal( - 
&self, - res: &mut PatternAst, - node_id: Id, - var_substitutions: &HashMap, - subst: &mut Subst, - cache: &mut HashMap, - ) { - if let Some(existing) = var_substitutions.get(&node_id) { - let var = format!("?{}", node_id).parse().unwrap(); - res.add(ENodeOrVar::Var(var)); - subst.insert(var, *existing); - } else { - let new_node = self.node(node_id).clone().map_children(|child| { - if let Some(existing) = cache.get(&child) { - *existing - } else { - self.node_to_pattern_internal(res, child, var_substitutions, subst, cache); - Id::from(res.as_ref().len() - 1) - } - }); - res.add(ENodeOrVar::ENode(new_node)); - } - } - - fn node_to_flat_explanation(&self, node_id: Id) -> FlatTerm { - let node = self.node(node_id).clone(); - let children = node.fold(vec![], |mut sofar, child| { - sofar.push(self.node_to_flat_explanation(child)); - sofar - }); - FlatTerm::new(node, children) - } - fn make_rule_table<'a, N: Analysis>( rules: &[&'a Rewrite], ) -> HashMap> { @@ -983,58 +922,13 @@ impl Explain { } table } - - pub fn check_each_explain>(&self, rules: &[&Rewrite]) -> bool { - let rule_table = Explain::make_rule_table(rules); - for i in 0..self.explainfind.len() { - let explain_node = &self.explainfind[i]; - - // check that explanation reasons never form a cycle - let mut existance = i; - let mut seen_existance: HashSet = Default::default(); - loop { - seen_existance.insert(existance); - let next = usize::from(self.explainfind[existance].existance_node); - if existance == next { - break; - } - existance = next; - if seen_existance.contains(&existance) { - panic!("Cycle in existance!"); - } - } - - if explain_node.parent_connection.next != Id::from(i) { - let mut current_explanation = self.node_to_flat_explanation(Id::from(i)); - let mut next_explanation = - self.node_to_flat_explanation(explain_node.parent_connection.next); - if let Justification::Rule(rule_name) = - &explain_node.parent_connection.justification - { - if let Some(rule) = rule_table.get(rule_name) { - if 
!explain_node.parent_connection.is_rewrite_forward { - std::mem::swap(&mut current_explanation, &mut next_explanation); - } - if !Explanation::check_rewrite( - ¤t_explanation, - &next_explanation, - rule, - ) { - return false; - } - } - } - } - } - true - } - pub fn new() -> Self { Explain { explainfind: vec![], uncanon_memo: Default::default(), shortest_explanation_memo: Default::default(), optimize_explanation_lengths: true, + undo_log: None, } } @@ -1044,34 +938,44 @@ impl Explain { pub(crate) fn add(&mut self, node: L, set: Id, existance_node: Id) -> Id { assert_eq!(self.explainfind.len(), usize::from(set)); - self.uncanon_memo.insert(node.clone(), set); + self.uncanon_memo.entry(node).or_insert(set); + // If the node already in uncanon memo keep the old version so it's easier to revert the add self.explainfind.push(ExplainNode { - node, neighbors: vec![], - parent_connection: Connection { - justification: Justification::Congruence, - is_rewrite_forward: false, - next: set, - current: set, - }, + parent_connection: Connection::end(set), existance_node, }); set } + /// Reorient connections to make `node` the leader (Used for testing push/pop) + pub(crate) fn test_mk_root(&mut self, node: Id) { + self.set_parent(node, Connection::end(node)) + } + // reverse edges recursively to make this node the leader - fn make_leader(&mut self, node: Id) { - let next = self.explainfind[usize::from(node)].parent_connection.next; - if next != node { - self.make_leader(next); - let node_connection = &self.explainfind[usize::from(node)].parent_connection; + fn set_parent(&mut self, node: Id, parent: Connection) { + let mut prev = node; + let mut curr = mem::replace( + &mut self.explainfind[usize::from(prev)].parent_connection, + parent, + ); + let mut count = 0; + while prev != curr.next { let pconnection = Connection { - justification: node_connection.justification.clone(), - is_rewrite_forward: !node_connection.is_rewrite_forward, - next: node, - current: next, + justification: 
curr.justification, + is_rewrite_forward: !curr.is_rewrite_forward, + next: prev, + current: curr.next, }; - self.explainfind[usize::from(next)].parent_connection = pconnection; + let next = mem::replace( + &mut self.explainfind[usize::from(curr.next)].parent_connection, + pconnection, + ); + prev = curr.next; + curr = next; + count += 1; + assert!(count < 1000); } } @@ -1109,6 +1013,7 @@ impl Explain { .insert((node1, node2), (Saturating(1), node2)); self.shortest_explanation_memo .insert((node2, node1), (Saturating(1), node1)); + self.undo_log_union(node1); } pub(crate) fn union( @@ -1119,15 +1024,12 @@ impl Explain { new_rhs: bool, ) { if let Justification::Congruence = justification { - assert!(self.node(node1).matches(self.node(node2))); + // assert!(self.node(node1).matches(self.node(node2))); } if new_rhs { self.set_existance_reason(node2, node1) } - self.make_leader(node1); - self.explainfind[usize::from(node1)].parent_connection.next = node2; - if let Justification::Rule(_) = justification { self.shortest_explanation_memo .insert((node1, node2), (Saturating(1), node2)); @@ -1153,9 +1055,11 @@ impl Explain { self.explainfind[usize::from(node2)] .neighbors .push(other_pconnection); - self.explainfind[usize::from(node1)].parent_connection = pconnection; - } + self.set_parent(node1, pconnection); + + self.undo_log_union(node1); + } pub(crate) fn get_union_equalities(&self) -> UnionEqualities { let mut equalities = vec![]; for node in &self.explainfind { @@ -1170,24 +1074,105 @@ impl Explain { equalities } - pub(crate) fn populate_enodes>(&self, mut egraph: EGraph) -> EGraph { - for i in 0..self.explainfind.len() { - let node = &self.explainfind[i]; - egraph.add(node.node.clone()); + pub(crate) fn with_raw_egraph(&mut self, raw: X) -> ExplainWith<'_, L, X> { + ExplainWith { explain: self, raw } + } +} + +impl<'a, L: Language, X> Deref for ExplainWith<'a, L, X> { + type Target = Explain; + + fn deref(&self) -> &Self::Target { + self.explain + } +} + +impl<'a, L: 
Language, X> DerefMut for ExplainWith<'a, L, X> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut *self.explain + } +} + +impl<'x, L: Language, D, U> ExplainWith<'x, L, &'x RawEGraph> { + pub(crate) fn node(&self, node_id: Id) -> &L { + self.raw.id_to_node(node_id) + } + fn node_to_explanation( + &self, + node_id: Id, + cache: &mut NodeExplanationCache, + ) -> Rc> { + if let Some(existing) = cache.get(&node_id) { + existing.clone() + } else { + let node = self.node(node_id).clone(); + let children = node.fold(vec![], |mut sofar, child| { + sofar.push(vec![self.node_to_explanation(child, cache)]); + sofar + }); + let res = Rc::new(TreeTerm::new(node, children)); + cache.insert(node_id, res.clone()); + res } + } - egraph + fn node_to_flat_explanation(&self, node_id: Id) -> FlatTerm { + let node = self.node(node_id).clone(); + let children = node.fold(vec![], |mut sofar, child| { + sofar.push(self.node_to_flat_explanation(child)); + sofar + }); + FlatTerm::new(node, children) } - pub(crate) fn explain_equivalence>( - &mut self, - left: Id, - right: Id, - unionfind: &mut UnionFind, - classes: &HashMap>, - ) -> Explanation { + pub fn check_each_explain>(&self, rules: &[&Rewrite]) -> bool { + let rule_table = Explain::make_rule_table(rules); + for i in 0..self.explainfind.len() { + let explain_node = &self.explainfind[i]; + + // check that explanation reasons never form a cycle + let mut existance = i; + let mut seen_existance: HashSet = Default::default(); + loop { + seen_existance.insert(existance); + let next = usize::from(self.explainfind[existance].existance_node); + if existance == next { + break; + } + existance = next; + if seen_existance.contains(&existance) { + panic!("Cycle in existance!"); + } + } + + if explain_node.parent_connection.next != Id::from(i) { + let mut current_explanation = self.node_to_flat_explanation(Id::from(i)); + let mut next_explanation = + self.node_to_flat_explanation(explain_node.parent_connection.next); + if let 
Justification::Rule(rule_name) = + &explain_node.parent_connection.justification + { + if let Some(rule) = rule_table.get(rule_name) { + if !explain_node.parent_connection.is_rewrite_forward { + std::mem::swap(&mut current_explanation, &mut next_explanation); + } + if !Explanation::check_rewrite( + ¤t_explanation, + &next_explanation, + rule, + ) { + return false; + } + } + } + } + } + true + } + + pub(crate) fn explain_equivalence(&mut self, left: Id, right: Id) -> Explanation { if self.optimize_explanation_lengths { - self.calculate_shortest_explanations::(left, right, classes, unionfind); + self.calculate_shortest_explanations(left, right); } let mut cache = Default::default(); @@ -1328,7 +1313,7 @@ impl Explain { let mut new_rest_of_proof = (*self.node_to_explanation(existance, enode_cache)).clone(); let mut index_of_child = 0; let mut found = false; - existance_node.node.for_each(|child| { + self.node(existance).for_each(|child| { if found { return; } @@ -1625,12 +1610,7 @@ impl Explain { distance_memo.parent_distance[usize::from(enode)].1 } - fn find_congruence_neighbors>( - &self, - classes: &HashMap>, - congruence_neighbors: &mut [Vec], - unionfind: &UnionFind, - ) { + fn find_congruence_neighbors(&self, congruence_neighbors: &mut [Vec]) { let mut counter = 0; // add the normal congruence edges first for node in &self.explainfind { @@ -1643,15 +1623,15 @@ impl Explain { } } - 'outer: for eclass in classes.keys() { - let enodes = self.find_all_enodes(*eclass); + 'outer: for eclass in self.raw.classes().map(|x| x.id) { + let enodes = self.find_all_enodes(eclass); // find all congruence nodes let mut cannon_enodes: HashMap> = Default::default(); for enode in &enodes { let cannon = self .node(*enode) .clone() - .map_children(|child| unionfind.find(child)); + .map_children(|child| self.raw.find(child)); if let Some(others) = cannon_enodes.get_mut(&cannon) { for other in others.iter() { congruence_neighbors[usize::from(*enode)].push(*other); @@ -1671,13 +1651,9 
@@ impl Explain { } } - pub fn get_num_congr>( - &self, - classes: &HashMap>, - unionfind: &UnionFind, - ) -> usize { + pub fn get_num_congr(&self) -> usize { let mut congruence_neighbors = vec![vec![]; self.explainfind.len()]; - self.find_congruence_neighbors::(classes, &mut congruence_neighbors, unionfind); + self.find_congruence_neighbors(&mut congruence_neighbors); let mut count = 0; for v in congruence_neighbors { count += v.len(); @@ -1686,10 +1662,6 @@ impl Explain { count / 2 } - pub fn get_num_nodes(&self) -> usize { - self.explainfind.len() - } - fn shortest_path_modulo_congruence( &mut self, start: Id, @@ -1888,11 +1860,7 @@ impl Explain { self.explainfind[usize::from(enode)].parent_connection.next } - fn calculate_common_ancestor>( - &self, - classes: &HashMap>, - congruence_neighbors: &[Vec], - ) -> HashMap<(Id, Id), Id> { + fn calculate_common_ancestor(&self, congruence_neighbors: &[Vec]) -> HashMap<(Id, Id), Id> { let mut common_ancestor_queries = HashMap::default(); for (s_int, others) in congruence_neighbors.iter().enumerate() { let start = &Id::from(s_int); @@ -1924,8 +1892,8 @@ impl Explain { unionfind.make_set(); ancestor.push(Id::from(i)); } - for (eclass, _) in classes.iter() { - let enodes = self.find_all_enodes(*eclass); + for eclass in self.raw.classes().map(|x| x.id) { + let enodes = self.find_all_enodes(eclass); let mut children: HashMap> = HashMap::default(); for enode in &enodes { children.insert(*enode, vec![]); @@ -1956,15 +1924,9 @@ impl Explain { common_ancestor } - fn calculate_shortest_explanations>( - &mut self, - start: Id, - end: Id, - classes: &HashMap>, - unionfind: &UnionFind, - ) { + fn calculate_shortest_explanations(&mut self, start: Id, end: Id) { let mut congruence_neighbors = vec![vec![]; self.explainfind.len()]; - self.find_congruence_neighbors::(classes, &mut congruence_neighbors, unionfind); + self.find_congruence_neighbors(&mut congruence_neighbors); let mut parent_distance = vec![(Id::from(0), Saturating(0)); 
self.explainfind.len()]; for (i, entry) in parent_distance.iter_mut().enumerate() { entry.0 = Id::from(i); @@ -1972,7 +1934,7 @@ impl Explain { let mut distance_memo = DistanceMemo { parent_distance, - common_ancestor: self.calculate_common_ancestor::(classes, &congruence_neighbors), + common_ancestor: self.calculate_common_ancestor(&congruence_neighbors), tree_depth: self.calculate_tree_depths(), }; @@ -2092,7 +2054,7 @@ mod tests { #[test] fn simple_explain_union_trusted() { - use crate::SymbolLang; + use crate::{EGraph, SymbolLang}; crate::init_logger(); let mut egraph = EGraph::new(()).with_explanations_enabled(); diff --git a/src/explain/semi_persistent.rs b/src/explain/semi_persistent.rs new file mode 100644 index 00000000..744e8243 --- /dev/null +++ b/src/explain/semi_persistent.rs @@ -0,0 +1,75 @@ +use crate::explain::{Connection, Explain}; +use crate::raw::EGraphResidual; +use crate::{Id, Language}; + +pub(super) type UndoLog = Option>; + +#[derive(Default, Clone, Debug)] +#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] +pub(crate) struct PushInfo(usize); + +impl Explain { + pub(super) fn undo_log_union(&mut self, node: Id) { + if let Some(x) = &mut self.undo_log { + x.push(node) + } + } + pub(crate) fn enable_undo_log(&mut self) { + assert_eq!(self.explainfind.len(), 0); + self.undo_log = Some(Vec::new()); + } + + pub(crate) fn disable_undo_log(&mut self) { + self.undo_log = None + } + + pub(crate) fn push(&self) -> PushInfo { + PushInfo(self.undo_log.as_ref().unwrap().len()) + } + + pub(crate) fn pop( + &mut self, + info: PushInfo, + number_of_uncanon_nodes: usize, + egraph: &EGraphResidual, + ) { + for id in self.undo_log.as_mut().unwrap().drain(info.0..).rev() { + let node1 = &mut self.explainfind[usize::from(id)]; + let id2 = node1.neighbors.pop().unwrap().next; + if node1.parent_connection.next == id2 { + node1.parent_connection = Connection::end(id); + } + let node2 = &mut self.explainfind[usize::from(id2)]; + let id1 
= node2.neighbors.pop().unwrap().next; + assert_eq!(id, id1); + if node2.parent_connection.next == id1 { + node2.parent_connection = Connection::end(id2); + } + } + self.explainfind.truncate(number_of_uncanon_nodes); + // We can't easily undo memoize operations, so we just clear them + self.shortest_explanation_memo.clear(); + for (id, node) in egraph + .uncanonical_nodes() + .skip(number_of_uncanon_nodes) + .rev() + { + if *self.uncanon_memo.get(node).unwrap() == id { + self.uncanon_memo.remove(node).unwrap(); + } + } + } + + pub(crate) fn clear_memo(&mut self) { + self.shortest_explanation_memo.clear() + } + + pub(crate) fn clear(&mut self) { + if let Some(v) = &mut self.undo_log { + v.clear() + } + self.explainfind.clear(); + self.uncanon_memo.clear(); + self.shortest_explanation_memo.clear(); + } +} diff --git a/src/language.rs b/src/language.rs index 6414c63a..40072358 100644 --- a/src/language.rs +++ b/src/language.rs @@ -698,7 +698,7 @@ assert_eq!(runner.egraph.find(runner.roots[0]), runner.egraph.find(just_foo)); */ pub trait Analysis: Sized { /// The per-[`EClass`] data for this analysis. - type Data: Debug; + type Data: Debug + Clone; /// Makes a new [`Analysis`] data for a given e-node. /// @@ -761,6 +761,14 @@ pub trait Analysis: Sized { /// `Analysis::merge` when unions are performed. 
#[allow(unused_variables)] fn modify(egraph: &mut EGraph, id: Id) {} + + /// A hook called at the start of [`EGraph::push`] + #[allow(unused_variables)] + fn pre_push(egraph: &mut EGraph) {} + + /// A hook called at the end of [`EGraph::pop_n`] + #[allow(unused_variables)] + fn post_pop_n(egraph: &mut EGraph, n: usize) {} } impl Analysis for () { diff --git a/src/lib.rs b/src/lib.rs index 5a293a58..7f8853fa 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -48,10 +48,12 @@ mod lp_extract; mod machine; mod multipattern; mod pattern; + +/// Lower level egraph API +pub mod raw; mod rewrite; mod run; mod subst; -mod unionfind; mod util; /// A key to identify [`EClass`]es within an @@ -85,11 +87,11 @@ impl std::fmt::Display for Id { } } -pub(crate) use {explain::Explain, unionfind::UnionFind}; +pub(crate) use {explain::Explain, raw::UnionFind}; pub use { dot::Dot, - eclass::EClass, + eclass::{EClass, EClassData}, egraph::EGraph, explain::{ Explanation, FlatExplanation, FlatTerm, Justification, TreeExplanation, TreeTerm, diff --git a/src/raw.rs b/src/raw.rs new file mode 100644 index 00000000..444f2016 --- /dev/null +++ b/src/raw.rs @@ -0,0 +1,17 @@ +mod dhashmap; +mod eclass; +mod egraph; +mod semi_persistent; + +/// One variant of semi_persistence +pub mod semi_persistent1; + +/// Another variant of semi_persistence +pub mod semi_persistent2; +mod unionfind; + +pub use eclass::RawEClass; +pub use egraph::{EGraphResidual, RawEGraph, UnionInfo}; +use semi_persistent::Sealed; +pub use semi_persistent::{AsUnwrap, UndoLogT}; +pub use unionfind::UnionFind; diff --git a/src/raw/dhashmap.rs b/src/raw/dhashmap.rs new file mode 100644 index 00000000..bc2ac8eb --- /dev/null +++ b/src/raw/dhashmap.rs @@ -0,0 +1,199 @@ +use std::fmt::{Debug, Formatter}; +use std::hash::{BuildHasher, Hash}; +use std::iter; +use std::iter::FromIterator; + +use hashbrown::hash_table; + +pub(super) type DHMIdx = u32; + +/// Similar to [`HashMap`](std::collections::HashMap) but with deterministic iteration 
order +/// +/// Accessing individual elements has similar performance to a [`HashMap`](std::collections::HashMap) +/// (faster than an `IndexMap`), but iteration requires allocation +/// +#[derive(Clone)] +pub(super) struct DHashMap { + data: hash_table::HashTable<(K, V, DHMIdx)>, + hasher: S, +} + +impl Default for DHashMap { + fn default() -> Self { + DHashMap { + data: Default::default(), + hasher: Default::default(), + } + } +} + +pub(super) struct VacantEntry<'a, K, V> { + len: DHMIdx, + entry: hash_table::VacantEntry<'a, (K, V, DHMIdx)>, + k: K, +} + +impl<'a, K, V> VacantEntry<'a, K, V> { + pub(super) fn insert(self, v: V) { + self.entry.insert((self.k, v, self.len)); + } +} + +pub(super) enum Entry<'a, K, V> { + Occupied((K, &'a mut V)), + Vacant(VacantEntry<'a, K, V>), +} + +#[inline] +fn hash_one(hasher: &impl BuildHasher, hash: impl Hash) -> u64 { + use core::hash::Hasher; + let mut hasher = hasher.build_hasher(); + hash.hash(&mut hasher); + hasher.finish() +} + +#[inline] +fn eq(k: &K) -> impl Fn(&(K, V, DHMIdx)) -> bool + '_ { + move |x| &x.0 == k +} + +#[inline] +fn hasher_fn(hasher: &S) -> impl Fn(&(K, V, DHMIdx)) -> u64 + '_ { + move |x| hash_one(hasher, &x.0) +} + +impl DHashMap { + #[inline] + pub(super) fn entry(&mut self, k: K) -> (Entry<'_, K, V>, u64) { + let hash = hash_one(&self.hasher, &k); + let len = self.data.len() as DHMIdx; + let entry = match self.data.entry(hash, eq(&k), hasher_fn(&self.hasher)) { + hash_table::Entry::Occupied(entry) => Entry::Occupied((k, &mut entry.into_mut().1)), + hash_table::Entry::Vacant(entry) => Entry::Vacant(VacantEntry { len, entry, k }), + }; + (entry, hash) + } + + #[inline] + pub(super) fn insert_with_hash(&mut self, hash: u64, k: K, v: V) { + debug_assert!({ + let (v, hash2) = self.get(&k); + v.is_none() && hash == hash2 + }); + let len = self.data.len() as DHMIdx; + self.data + .insert_unique(hash, (k, v, len), hasher_fn(&self.hasher)); + } + + #[inline] + pub(super) fn remove_nth(&mut self, hash: u64, 
idx: usize) { + debug_assert_eq!(self.data.len() - 1, idx); + let idx = idx as DHMIdx; + match self.data.find_entry(hash, |x| x.2 == idx) { + Ok(x) => x.remove(), + Err(_) => unreachable!(), + }; + } + + #[inline] + pub(super) fn len(&self) -> usize { + self.data.len() + } + + #[inline] + pub(super) fn get(&self, k: &K) -> (Option<&V>, u64) { + let hash = hash_one(&self.hasher, k); + (self.data.find(hash, eq(k)).map(|x| &x.1), hash) + } + + pub(super) fn clear(&mut self) { + self.data.clear() + } +} + +impl<'a, K, V, S> IntoIterator for &'a DHashMap { + type Item = (&'a K, &'a V); + + // TODO replace with TAIT + type IntoIter = iter::Map< + std::vec::IntoIter>, + fn(Option<(&'a K, &'a V)>) -> (&'a K, &'a V), + >; + + #[inline(never)] + fn into_iter(self) -> Self::IntoIter { + let mut data: Vec<_> = iter::repeat(None).take(self.data.len()).collect(); + for (k, v, i) in &self.data { + data[*i as usize] = Some((k, v)) + } + data.into_iter().map(Option::unwrap) + } +} + +impl FromIterator<(K, V)> for DHashMap { + fn from_iter>(iter: T) -> Self { + let mut res = Self::default(); + iter.into_iter().for_each(|(k, v)| { + let hash = hash_one(&res.hasher, &k); + res.insert_with_hash(hash, k, v) + }); + res + } +} + +impl Debug for DHashMap { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_map().entries(self).finish() + } +} + +#[cfg(test)] +mod test { + use crate::raw::dhashmap::DHashMap; + use std::fmt::Debug; + use std::hash::{Hash, Hasher}; + + #[derive(Eq, PartialEq, Debug, Clone)] + struct BadHash(T); + + #[allow(clippy::derive_hash_xor_eq)] // We explicitly want to test a bad implementation + impl Hash for BadHash { + fn hash(&self, _: &mut H) {} + } + + fn test(arr: [(K, V); N]) { + let mut map: DHashMap = DHashMap::default(); + let mut hashes = Vec::new(); + for (k, v) in arr.iter().cloned() { + let (r, hash) = map.get(&k); + assert!(r.is_none()); + hashes.push(hash); + map.insert_with_hash(hash, k, v) + } + assert_eq!(map.len(), N); + for (i, 
(k, v)) in arr.iter().enumerate().rev() { + let (r, hash) = map.get(k); + assert_eq!(Some(hash), hashes.pop()); + assert_eq!(r, Some(v)); + map.remove_nth(hash, i); + let (r2, hash2) = map.get(k); + assert_eq!(hash2, hash); + assert_eq!(r2, None); + assert_eq!(map.len(), i); + } + } + + #[test] + fn test_base() { + test([('a', "a"), ('b', "b"), ('c', "c")]) + } + + #[test] + fn test_bad_hash() { + test([ + (BadHash('a'), "a"), + (BadHash('b'), "b"), + (BadHash('c'), "c"), + ]) + } +} diff --git a/src/raw/eclass.rs b/src/raw/eclass.rs new file mode 100644 index 00000000..dd6e43be --- /dev/null +++ b/src/raw/eclass.rs @@ -0,0 +1,43 @@ +use crate::Id; +use std::fmt::Debug; +use std::iter::ExactSizeIterator; +use std::ops::{Deref, DerefMut}; + +/// An equivalence class of enodes. +#[non_exhaustive] +#[derive(Debug, Clone)] +#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] +pub struct RawEClass { + /// This eclass's id. + pub id: Id, + /// Arbitrary data associated with this eclass. + pub(super) raw_data: D, + /// The original Ids of parent enodes. + pub(super) parents: Vec, +} + +impl RawEClass { + /// Iterates over the non-canonical ids of parent enodes of this eclass. 
+ pub fn parents(&self) -> impl ExactSizeIterator + '_ { + self.parents.iter().copied() + } + + /// Consumes `self` returning the stored data and an iterator similar to [`parents`](RawEClass::parents) + pub fn destruct(self) -> (D, impl ExactSizeIterator) { + (self.raw_data, self.parents.into_iter()) + } +} + +impl Deref for RawEClass { + type Target = D; + + fn deref(&self) -> &D { + &self.raw_data + } +} + +impl DerefMut for RawEClass { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.raw_data + } +} diff --git a/src/raw/egraph.rs b/src/raw/egraph.rs new file mode 100644 index 00000000..9296f006 --- /dev/null +++ b/src/raw/egraph.rs @@ -0,0 +1,796 @@ +use crate::{raw::RawEClass, Dot, HashMap, Id, Language, RecExpr, UnionFind}; +use std::collections::BTreeMap; +use std::convert::Infallible; +use std::ops::{Deref, DerefMut}; +use std::{ + borrow::BorrowMut, + fmt::{self, Debug}, + iter, slice, +}; + +use crate::raw::dhashmap::*; +use crate::raw::UndoLogT; +#[cfg(feature = "serde-1")] +use serde::{Deserialize, Serialize}; + +pub struct Parents<'a>(&'a [Id]); + +impl<'a> IntoIterator for Parents<'a> { + type Item = Id; + type IntoIter = iter::Copied>; + + fn into_iter(self) -> Self::IntoIter { + self.0.iter().copied() + } +} + +/// A [`RawEGraph`] without its classes that can be obtained by dereferencing a [`RawEGraph`]. +/// +/// It exists as a separate type so that it can still be used while mutably borrowing a [`RawEClass`] +/// +/// See [`RawEGraph::classes_mut`], [`RawEGraph::get_class_mut`] +#[derive(Clone)] +#[cfg_attr(feature = "serde-1", derive(Serialize, Deserialize))] +pub struct EGraphResidual { + pub(super) unionfind: UnionFind, + /// Stores the original node represented by each non-canonical id + pub(super) nodes: Vec, + /// Stores each enode's `Id`, not the `Id` of the eclass. + /// Enodes in the memo are canonicalized at each rebuild, but after rebuilding new + /// unions can cause them to become out of date. 
+ #[cfg_attr(feature = "serde-1", serde(with = "vectorize"))] + pub(super) memo: DHashMap, +} + +impl EGraphResidual { + /// Pick a representative term for a given Id. + /// + /// Calling this function on an uncanonical `Id` returns a representative based on how it + /// was obtained + pub fn id_to_expr(&self, id: Id) -> RecExpr { + let mut res = Default::default(); + let mut cache = Default::default(); + self.id_to_expr_internal(&mut res, id, &mut cache); + res + } + + fn id_to_expr_internal( + &self, + res: &mut RecExpr, + node_id: Id, + cache: &mut HashMap, + ) -> Id { + if let Some(existing) = cache.get(&node_id) { + return *existing; + } + let new_node = self + .id_to_node(node_id) + .clone() + .map_children(|child| self.id_to_expr_internal(res, child, cache)); + let res_id = res.add(new_node); + cache.insert(node_id, res_id); + res_id + } + + /// Like [`id_to_expr`](EGraphResidual::id_to_expr) but only goes one layer deep + pub fn id_to_node(&self, id: Id) -> &L { + &self.nodes[usize::from(id)] + } + + /// Canonicalizes an eclass id. + /// + /// This corresponds to the `find` operation on the egraph's + /// underlying unionfind data structure. 
+ /// + /// # Example + /// ``` + /// use egg::{raw::*, SymbolLang as S}; + /// let mut egraph = RawEGraph::::default(); + /// let x = egraph.add_uncanonical(S::leaf("x")); + /// let y = egraph.add_uncanonical(S::leaf("y")); + /// assert_ne!(egraph.find(x), egraph.find(y)); + /// + /// egraph.union(x, y); + /// egraph.rebuild(); + /// assert_eq!(egraph.find(x), egraph.find(y)); + /// ``` + pub fn find(&self, id: Id) -> Id { + self.unionfind.find(id) + } + + /// Same as [`find`](EGraphResidual::find) but requires mutable access since it does path compression + pub fn find_mut(&mut self, id: Id) -> Id { + self.unionfind.find_mut(id) + } + + /// Returns `true` if the egraph is empty + /// # Example + /// ``` + /// use egg::{raw::*, SymbolLang as S}; + /// let mut egraph = RawEGraph::::default(); + /// assert!(egraph.is_empty()); + /// egraph.add_uncanonical(S::leaf("foo")); + /// assert!(!egraph.is_empty()); + /// ``` + pub fn is_empty(&self) -> bool { + self.nodes.is_empty() + } + + /// Returns the number of uncanonical enodes in the `EGraph`. 
+ /// + /// # Example + /// ``` + /// use egg::{raw::*, SymbolLang as S}; + /// let mut egraph = RawEGraph::::default(); + /// let x = egraph.add_uncanonical(S::leaf("x")); + /// let y = egraph.add_uncanonical(S::leaf("y")); + /// let fx = egraph.add_uncanonical(S::new("f", vec![x])); + /// let fy = egraph.add_uncanonical(S::new("f", vec![y])); + /// // only one eclass + /// egraph.union(x, y); + /// egraph.rebuild(); + /// + /// assert_eq!(egraph.number_of_uncanonical_nodes(), 4); + /// assert_eq!(egraph.number_of_classes(), 2); + /// ``` + pub fn number_of_uncanonical_nodes(&self) -> usize { + self.nodes.len() + } + + /// Returns an iterator over the uncanonical ids in the egraph and the node + /// that would be obtained by calling [`id_to_node`](EGraphResidual::id_to_node) on each of them + pub fn uncanonical_nodes( + &self, + ) -> impl ExactSizeIterator + DoubleEndedIterator { + self.nodes + .iter() + .enumerate() + .map(|(id, node)| (Id::from(id), node)) + } + + /// Returns an iterator over all the uncanonical ids + pub fn uncanonical_ids(&self) -> impl ExactSizeIterator + 'static { + (0..self.number_of_uncanonical_nodes()) + .into_iter() + .map(Id::from) + } + + /// Returns the number of enodes in the `EGraph`. + /// + /// Actually returns the size of the hashcons index. + /// # Example + /// ``` + /// use egg::{raw::*, SymbolLang as S}; + /// let mut egraph = RawEGraph::::default(); + /// let x = egraph.add_uncanonical(S::leaf("x")); + /// let y = egraph.add_uncanonical(S::leaf("y")); + /// // only one eclass + /// egraph.union(x, y); + /// egraph.rebuild(); + /// + /// assert_eq!(egraph.total_size(), 2); + /// assert_eq!(egraph.number_of_classes(), 1); + /// ``` + pub fn total_size(&self) -> usize { + self.memo.len() + } + + /// Lookup the eclass of the given enode. + /// + /// You can pass in either an owned enode or a `&mut` enode, + /// in which case the enode's children will be canonicalized. 
+ /// + /// # Example + /// ``` + /// # use egg::{SymbolLang, raw::*}; + /// let mut egraph: RawEGraph = Default::default(); + /// let a = egraph.add_uncanonical(SymbolLang::leaf("a")); + /// let b = egraph.add_uncanonical(SymbolLang::leaf("b")); + /// + /// // lookup will find this node if its in the egraph + /// let mut node_f_ab = SymbolLang::new("f", vec![a, b]); + /// assert_eq!(egraph.lookup(node_f_ab.clone()), None); + /// let id = egraph.add_uncanonical(node_f_ab.clone()); + /// assert_eq!(egraph.lookup(node_f_ab.clone()), Some(id)); + /// + /// // if the query node isn't canonical, and its passed in by &mut instead of owned, + /// // its children will be canonicalized + /// egraph.union(a, b); + /// egraph.rebuild(); + /// assert_eq!(egraph.lookup(&mut node_f_ab), Some(id)); + /// assert_eq!(node_f_ab, SymbolLang::new("f", vec![a, a])); + /// ``` + pub fn lookup(&self, enode: B) -> Option + where + B: BorrowMut, + { + self.lookup_internal(enode).map(|id| self.find(id)) + } + + #[inline] + fn lookup_internal(&self, mut enode: B) -> Option + where + B: BorrowMut, + { + let enode = enode.borrow_mut(); + enode.update_children(|id| self.find(id)); + self.memo.get(enode).0.copied() + } + + /// Lookup the eclass of the given [`RecExpr`]. + /// + /// Equivalent to the last value in [`EGraphResidual::lookup_expr_ids`]. + pub fn lookup_expr(&self, expr: &RecExpr) -> Option { + self.lookup_expr_ids(expr) + .and_then(|ids| ids.last().copied()) + } + + /// Lookup the eclasses of all the nodes in the given [`RecExpr`]. 
+ pub fn lookup_expr_ids(&self, expr: &RecExpr) -> Option> { + let nodes = expr.as_ref(); + let mut new_ids = Vec::with_capacity(nodes.len()); + for node in nodes { + let node = node.clone().map_children(|i| new_ids[usize::from(i)]); + let id = self.lookup(node)?; + new_ids.push(id) + } + Some(new_ids) + } + + /// Generate a mapping from canonical ids to the list of nodes they represent + pub fn generate_class_nodes(&self) -> HashMap> { + let mut classes = HashMap::default(); + let find = |id| self.find(id); + for (id, node) in self.uncanonical_nodes() { + let id = find(id); + let node = node.clone().map_children(find); + match classes.get_mut(&id) { + None => { + classes.insert(id, vec![node]); + } + Some(x) => x.push(node), + } + } + + // define all the nodes, clustered by eclass + for class in classes.values_mut() { + class.sort_unstable(); + class.dedup(); + } + classes + } + + /// Returns a more debug-able representation of the egraph focusing on its uncanonical ids and nodes. + /// + /// [`RawEGraph`]s implement [`Debug`], but it's not pretty. It + /// prints a lot of stuff you probably don't care about. + /// This method returns a wrapper that implements [`Debug`] in a + /// slightly nicer way, just dumping enodes in each eclass. + /// + /// [`Debug`]: std::fmt::Debug + pub fn dump_uncanonical(&self) -> impl Debug + '_ { + EGraphUncanonicalDump(self) + } + + /// Creates a [`Dot`] to visualize this egraph. See [`Dot`]. + pub fn dot(&self) -> Dot<'_, L> { + Dot { + egraph: self, + config: vec![], + use_anchors: true, + } + } +} + +// manual debug impl to avoid L: Language bound on EGraph defn +impl Debug for EGraphResidual { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("EGraphResidual") + .field("unionfind", &self.unionfind) + .field("nodes", &self.nodes) + .field("memo", &self.memo) + .finish() + } +} + +/** A data structure to keep track of equalities between expressions. 
+ +Check out the [background tutorial](crate::tutorials::_01_background) +for more information on e-graphs in general. + +# E-graphs in `egg::raw` + +In `egg::raw`, the main types associated with e-graphs are +[`RawEGraph`], [`RawEClass`], [`Language`], and [`Id`]. + +[`RawEGraph`] and [`RawEClass`] are both generic over a +[`Language`], meaning that types actually floating around in the +egraph are all user-defined. +In particular, the e-nodes are elements of your [`Language`]. +[`RawEGraph`]s and [`RawEClass`]es are additionally parameterized by some +arbitrary data associated with each e-class. + +Many methods of [`RawEGraph`] deal with [`Id`]s, which represent e-classes. +Because eclasses are frequently merged, many [`Id`]s will refer to the +same e-class. + +[`RawEGraph`] provides a low-level API for dealing with egraphs, in particular with handling the data +stored in each [`RawEClass`], so users will likely want to implement wrappers around +[`raw_add`](RawEGraph::raw_add), [`raw_union`](RawEGraph::raw_union), and [`raw_rebuild`](RawEGraph::raw_rebuild) +to properly handle this data. + **/ +#[derive(Clone)] +#[cfg_attr(feature = "serde-1", derive(Serialize, Deserialize))] +pub struct RawEGraph { + #[cfg_attr(feature = "serde-1", serde(flatten))] + pub(super) residual: EGraphResidual, + /// Nodes which need to be processed for rebuilding. The `Id` is the `Id` of the enode, + /// not the canonical id of the eclass. 
+ pub(super) pending: Vec, + pub(super) classes: HashMap>, + pub(super) undo_log: U, +} + +impl Default for RawEGraph { + fn default() -> Self { + let residual = EGraphResidual { + unionfind: Default::default(), + nodes: Default::default(), + memo: Default::default(), + }; + RawEGraph { + residual, + pending: Default::default(), + classes: Default::default(), + undo_log: Default::default(), + } + } +} + +impl Deref for RawEGraph { + type Target = EGraphResidual; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.residual + } +} + +impl DerefMut for RawEGraph { + #[inline] + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.residual + } +} + +// manual debug impl to avoid L: Language bound on EGraph defn +impl Debug for RawEGraph { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let classes: BTreeMap<_, _> = self + .classes + .iter() + .map(|(x, y)| { + let mut parents = y.parents.clone(); + parents.sort_unstable(); + ( + *x, + RawEClass { + id: y.id, + raw_data: &y.raw_data, + parents, + }, + ) + }) + .collect(); + f.debug_struct("EGraph") + .field("memo", &self.residual.memo) + .field("classes", &classes) + .finish() + } +} + +impl RawEGraph { + /// Returns an iterator over the eclasses in the egraph. + pub fn classes(&self) -> impl ExactSizeIterator> { + self.classes.iter().map(|(id, class)| { + debug_assert_eq!(*id, class.id); + class + }) + } + + /// Returns a mutating iterator over the eclasses in the egraph. + /// Also returns the [`EGraphResidual`] so it can still be used while `self` is borrowed + pub fn classes_mut( + &mut self, + ) -> ( + impl ExactSizeIterator>, + &mut EGraphResidual, + ) { + let iter = self.classes.iter_mut().map(|(id, class)| { + debug_assert_eq!(*id, class.id); + class + }); + (iter, &mut self.residual) + } + + /// Returns the number of eclasses in the egraph. 
+ pub fn number_of_classes(&self) -> usize { + self.classes().len() + } + + /// Returns the eclass corresponding to `id` + pub fn get_class>(&self, mut id: I) -> &RawEClass { + let id = id.borrow_mut(); + *id = self.find(*id); + self.get_class_with_cannon(*id) + } + + /// Like [`get_class`](RawEGraph::get_class) but panics if `id` is not canonical + pub fn get_class_with_cannon(&self, id: Id) -> &RawEClass { + self.classes + .get(&id) + .unwrap_or_else(|| panic!("Invalid id {}", id)) + } + + /// Returns the eclass corresponding to `id` + /// Also returns the [`EGraphResidual`] so it can still be used while `self` is borrowed + pub fn get_class_mut>( + &mut self, + mut id: I, + ) -> (&mut RawEClass, &mut EGraphResidual) { + let id = id.borrow_mut(); + *id = self.find_mut(*id); + self.get_class_mut_with_cannon(*id) + } + + /// Like [`get_class_mut`](RawEGraph::get_class_mut) but panics if `id` is not canonical + pub fn get_class_mut_with_cannon( + &mut self, + id: Id, + ) -> (&mut RawEClass, &mut EGraphResidual) { + ( + self.classes + .get_mut(&id) + .unwrap_or_else(|| panic!("Invalid id {}", id)), + &mut self.residual, + ) + } + + /// Returns whether `self` is congruently closed + /// + /// This will always be true after calling [`raw_rebuild`](RawEGraph::raw_rebuild) + pub fn is_clean(&self) -> bool { + self.pending.is_empty() + } +} + +/// Information about a call to [`RawEGraph::raw_union`] +pub struct UnionInfo { + /// The canonical id of the newly merged class + pub new_id: Id, + /// The number of parents that were in the newly merged class before it was merged + pub parents_cut: usize, + /// The id that used to canonically represent the class that was merged into `new_id` + pub old_id: Id, + /// The data that was in the class represented by `old_id` + pub old_data: D, +} + +/// Information for [`RawEGraph::raw_union`] callback +#[non_exhaustive] +pub struct MergeInfo<'a, D: 'a> { + /// id that will be the root for the newly merged eclass + pub id1: Id, + /// 
data associated with `id1` that can be modified to reflect `data2` being merged into it + pub data1: &'a mut D, + /// parents of `id1` before the merge + pub parents1: &'a [Id], + /// id that used to be a root but will now be in `id1` eclass + pub id2: Id, + /// data associated with `id2` + pub data2: D, + /// parents of `id2` before the merge + pub parents2: &'a [Id], + /// true if `id1` was the root of the second id passed to [`RawEGraph::raw_union`] + /// false if `id1` was the root of the first id passed to [`RawEGraph::raw_union`] + pub swapped_ids: bool, +} + +impl> RawEGraph { + /// Adds `enode` to a [`RawEGraph`] contained within a wrapper type `T` + /// + /// ## Parameters + /// + /// ### `get_self` + /// Called to extract the [`RawEGraph`] from the wrapper type, and should not perform any mutation. + /// + /// This will likely be a simple field access or just the identity function if there is no wrapper type. + /// + /// ### `handle_equiv` + /// When there already exists a node that is congruently equivalent to `enode` in the egraph + /// this function is called with the uncanonical id of an equivalent node, and a reference to `enode` + /// + /// Returning `Some(id)` will cause `raw_add` to immediately return `id` + /// (in this case `id` should represent an enode that is equivalent to the one being inserted). + /// + /// Returning `None` will cause `raw_add` to create a new id for `enode`, union it to the equivalent node, + /// and then return it. 
+ /// + /// ### `handle_union` + /// Called after `handle_equiv` returns `None` with the uncanonical id of the equivalent node + /// and the new `id` assigned to `enode` + /// + /// Calling [`id_to_node`](EGraphResidual::id_to_node) on the new `id` will return a reference to `enode` + /// + /// ### `mk_data` + /// When there does not already exist a node that is congruently equivalent to `enode` in the egraph + /// this function is called with the new `id` assigned to `enode` and a reference to the canonicalized version of + /// `enode` to create the data that will be stored in the [`RawEClass`] associated with it + /// + /// Calling [`id_to_node`](EGraphResidual::id_to_node) on the new `id` will return a reference to `enode` + /// + /// Calling [`get_class`](RawEGraph::get_class) on the new `id` will cause a panic since the [`RawEClass`] is + /// still being built + #[inline] + pub fn raw_add( + outer: &mut T, + get_self: impl Fn(&mut T) -> &mut Self, + original: L, + handle_equiv: impl FnOnce(&mut T, Id, &L) -> Option, + handle_union: impl FnOnce(&mut T, Id, Id), + mk_data: impl FnOnce(&mut T, Id, &L) -> D, + ) -> Id { + let this = get_self(outer); + let enode = original.clone().map_children(|x| this.find(x)); + let (existing_id, hash) = this.residual.memo.get(&enode); + if let Some(&existing_id) = existing_id { + let canon_id = this.find(existing_id); + // when explanations are enabled, we need a new representative for this expr + if let Some(existing_id) = handle_equiv(outer, existing_id, &original) { + existing_id + } else { + let this = get_self(outer); + let new_id = this.residual.unionfind.make_set(); + this.undo_log.add_node(&original, &[], new_id); + this.undo_log.union(canon_id, new_id, Vec::new()); + debug_assert_eq!(Id::from(this.nodes.len()), new_id); + this.residual.nodes.push(original); + this.residual.unionfind.union(canon_id, new_id); + handle_union(outer, existing_id, new_id); + new_id + } + } else { + let id = this.residual.unionfind.make_set(); + 
this.undo_log.add_node(&original, enode.children(), id); + debug_assert_eq!(Id::from(this.nodes.len()), id); + this.residual.nodes.push(original); + + log::trace!(" ...adding to {}", id); + let class = RawEClass { + id, + raw_data: mk_data(outer, id, &enode), + parents: Default::default(), + }; + let this = get_self(outer); + + // add this enode to the parent lists of its children + enode.for_each(|child| { + this.get_class_mut(child).0.parents.push(id); + }); + + // TODO is this needed? + this.pending.push(id); + + this.classes.insert(id, class); + this.residual.memo.insert_with_hash(hash, enode, id); + this.undo_log.insert_memo(hash); + + id + } + } + + /// Unions two eclasses given their ids. + /// + /// The given ids need not be canonical. + /// + /// If a union occurs, `merge` is called with the data, id, and parents of the two eclasses being merged + #[inline] + pub fn raw_union( + &mut self, + enode_id1: Id, + enode_id2: Id, + merge: impl FnOnce(MergeInfo<'_, D>), + ) { + let mut id1 = self.find_mut(enode_id1); + let mut id2 = self.find_mut(enode_id2); + if id1 == id2 { + return; + } + // make sure class2 has fewer parents + let class1_parents = self.classes[&id1].parents.len(); + let class2_parents = self.classes[&id2].parents.len(); + let mut swapped = false; + if class1_parents < class2_parents { + swapped = true; + std::mem::swap(&mut id1, &mut id2); + } + + // make id1 the new root + self.residual.unionfind.union(id1, id2); + + assert_ne!(id1, id2); + let class2 = self.classes.remove(&id2).unwrap(); + let class1 = self.classes.get_mut(&id1).unwrap(); + assert_eq!(id1, class1.id); + + let info = MergeInfo { + id1: class1.id, + data1: &mut class1.raw_data, + parents1: &class1.parents, + id2: class2.id, + data2: class2.raw_data, + parents2: &class2.parents, + swapped_ids: swapped, + }; + merge(info); + + self.pending.extend(&class2.parents); + + class1.parents.extend(&class2.parents); + + self.undo_log.union(id1, id2, class2.parents); + } + + /// Rebuild 
the [`RawEGraph`] to restore congruence closure + /// + /// ## Parameters + /// + /// ### `get_self` + /// Called to extract the [`RawEGraph`] from the wrapper type, and should not perform any mutation. + /// + /// This will likely be a simple field access or just the identity function if there is no wrapper type. + /// + /// ### `perform_union` + /// Called on each pair of ids that needs to be unioned + /// + /// In order to be correct `perform_union` should call [`raw_union`](RawEGraph::raw_union) + /// + /// ### `handle_pending` + /// Called with the uncanonical id of each enode whose canonical children have changed, along with a canonical + /// version of it + #[inline] + pub fn raw_rebuild( + outer: &mut T, + get_self: impl Fn(&mut T) -> &mut Self, + mut perform_union: impl FnMut(&mut T, Id, Id), + handle_pending: impl FnMut(&mut T, Id, &L), + ) { + let _: Result<(), Infallible> = RawEGraph::try_raw_rebuild( + outer, + get_self, + |this, id1, id2| Ok(perform_union(this, id1, id2)), + handle_pending, + ); + } + + /// Similar to [`raw_rebuild`](RawEGraph::raw_rebuild) but allows for the union operation to fail and abort the rebuild + #[inline] + pub fn try_raw_rebuild( + outer: &mut T, + get_self: impl Fn(&mut T) -> &mut Self, + mut perform_union: impl FnMut(&mut T, Id, Id) -> Result<(), E>, + mut handle_pending: impl FnMut(&mut T, Id, &L), + ) -> Result<(), E> { + loop { + let this = get_self(outer); + if let Some(class) = this.pending.pop() { + let mut node = this.id_to_node(class).clone(); + node.update_children(|id| this.find_mut(id)); + handle_pending(outer, class, &node); + let this = get_self(outer); + let (entry, hash) = this.residual.memo.entry(node); + match entry { + Entry::Occupied((_, id)) => { + let memo_class = *id; + match perform_union(outer, memo_class, class) { + Ok(()) => {} + Err(e) => { + get_self(outer).pending.push(class); + return Err(e); + } + } + } + Entry::Vacant(vac) => { + this.undo_log.insert_memo(hash); + vac.insert(class); + } + } + } else { + break 
Ok(()); + } + } + } + + /// Returns a more debug-able representation of the egraph focusing on its classes. + /// + /// [`RawEGraph`]s implement [`Debug`], but it's not pretty. It + /// prints a lot of stuff you probably don't care about. + /// This method returns a wrapper that implements [`Debug`] in a + /// slightly nicer way, just dumping enodes in each eclass. + /// + /// [`Debug`]: std::fmt::Debug + pub fn dump_classes(&self) -> impl Debug + '_ + where + D: Debug, + { + EGraphDump(self) + } + + /// Remove all nodes from this egraph + pub fn clear(&mut self) { + self.residual.nodes.clear(); + self.residual.memo.clear(); + self.residual.unionfind.clear(); + self.pending.clear(); + self.undo_log.clear(); + } +} + +impl> RawEGraph { + /// Simplified version of [`raw_add`](RawEGraph::raw_add) for egraphs without eclass data + pub fn add_uncanonical(&mut self, enode: L) -> Id { + Self::raw_add( + self, + |x| x, + enode, + |_, id, _| Some(id), + |_, _, _| {}, + |_, _, _| (), + ) + } + + /// Simplified version of [`raw_union`](RawEGraph::raw_union) for egraphs without eclass data + pub fn union(&mut self, id1: Id, id2: Id) -> bool { + let mut unioned = false; + self.raw_union(id1, id2, |_| { + unioned = true; + }); + unioned + } + + /// Simplified version of [`raw_rebuild`](RawEGraph::raw_rebuild) for egraphs without eclass data + pub fn rebuild(&mut self) { + Self::raw_rebuild( + self, + |x| x, + |this, id1, id2| { + this.union(id1, id2); + }, + |_, _, _| {}, + ); + } +} + +struct EGraphUncanonicalDump<'a, L: Language>(&'a EGraphResidual); + +impl<'a, L: Language> Debug for EGraphUncanonicalDump<'a, L> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + for (id, node) in self.0.uncanonical_nodes() { + writeln!(f, "{}: {:?} (root={})", id, node, self.0.find(id))? 
+ } + Ok(()) + } +} + +struct EGraphDump<'a, L: Language, D, U>(&'a RawEGraph); + +impl<'a, L: Language, D: Debug, U> Debug for EGraphDump<'a, L, D, U> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut ids: Vec = self.0.classes().map(|c| c.id).collect(); + ids.sort(); + for id in ids { + writeln!(f, "{} {:?}", id, self.0.get_class(id).raw_data)? + } + Ok(()) + } +} diff --git a/src/raw/semi_persistent.rs b/src/raw/semi_persistent.rs new file mode 100644 index 00000000..87112186 --- /dev/null +++ b/src/raw/semi_persistent.rs @@ -0,0 +1,128 @@ +use crate::raw::RawEGraph; +use crate::{Id, Language}; +use std::fmt::Debug; + +pub trait Sealed {} +impl Sealed for () {} +impl Sealed for Option {} + +/// A sealed trait for types that can be used for `push`/`pop` APIs +/// It is trivially implemented for `()` +pub trait UndoLogT: Default + Debug + Sealed { + #[doc(hidden)] + fn add_node(&mut self, node: &L, canon_children: &[Id], node_id: Id); + + #[doc(hidden)] + fn union(&mut self, id1: Id, id2: Id, id2_parents: Vec); + + #[doc(hidden)] + fn insert_memo(&mut self, hash: u64); + + #[doc(hidden)] + fn clear(&mut self); + + #[doc(hidden)] + fn is_enabled(&self) -> bool; +} + +impl UndoLogT for () { + #[inline] + fn add_node(&mut self, _: &L, _: &[Id], _: Id) {} + + #[inline] + fn union(&mut self, _: Id, _: Id, _: Vec) {} + + #[inline] + fn insert_memo(&mut self, _: u64) {} + + #[inline] + fn clear(&mut self) {} + + fn is_enabled(&self) -> bool { + false + } +} + +impl> UndoLogT for Option { + #[inline] + fn add_node(&mut self, node: &L, canon_children: &[Id], node_id: Id) { + if let Some(undo) = self { + undo.add_node(node, canon_children, node_id) + } + } + + #[inline] + fn union(&mut self, id1: Id, id2: Id, id2_parents: Vec) { + if let Some(undo) = self { + undo.union(id1, id2, id2_parents) + } + } + + #[inline] + fn insert_memo(&mut self, hash: u64) { + if let Some(undo) = self { + undo.insert_memo(hash) + } + } + + #[inline] + fn clear(&mut self) { 
+ if let Some(undo) = self { + undo.clear() + } + } + + #[inline] + fn is_enabled(&self) -> bool { + self.as_ref().map(U::is_enabled).unwrap_or(false) + } +} + +/// Trait implemented for `T` and `Option` used to provide bounds for push/pop impls +pub trait AsUnwrap { + #[doc(hidden)] + fn as_unwrap(&self) -> &T; + + #[doc(hidden)] + fn as_mut_unwrap(&mut self) -> &mut T; +} + +impl AsUnwrap for T { + #[inline] + fn as_unwrap(&self) -> &T { + self + } + + #[inline] + fn as_mut_unwrap(&mut self) -> &mut T { + self + } +} +impl AsUnwrap for Option { + #[inline] + fn as_unwrap(&self) -> &T { + self.as_ref().unwrap() + } + + #[inline] + fn as_mut_unwrap(&mut self) -> &mut T { + self.as_mut().unwrap() + } +} + +impl> RawEGraph { + /// Change the [`UndoLogT`] being used + /// + /// If the new [`UndoLogT`] is enabled then the egraph must be empty + pub fn set_undo_log(&mut self, undo_log: U) { + if !self.is_empty() && undo_log.is_enabled() { + panic!("Need to set undo log enabled before adding any expressions to the egraph.") + } + self.undo_log = undo_log + } + + /// Check if the [`UndoLogT`] being used is enabled + pub fn has_undo_log(&self) -> bool { + self.undo_log.is_enabled() + } +} diff --git a/src/raw/semi_persistent1.rs b/src/raw/semi_persistent1.rs new file mode 100644 index 00000000..eddf5526 --- /dev/null +++ b/src/raw/semi_persistent1.rs @@ -0,0 +1,228 @@ +use crate::raw::{AsUnwrap, RawEClass, RawEGraph, Sealed, UndoLogT, UnionFind}; +use crate::{Id, Language}; +use std::fmt::Debug; + +/// Stored information required to restore the egraph to a previous state +/// +/// see [`push1`](RawEGraph::push1) and [`pop1`](RawEGraph::pop1) +#[derive(Debug, Clone)] +#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] +pub struct PushInfo { + node_count: usize, + union_count: usize, + memo_log_count: usize, + pop_parents_count: usize, +} + +impl PushInfo { + /// Returns the result of 
[`EGraphResidual::number_of_uncanonical_nodes`](super::EGraphResidual::number_of_uncanonical_nodes) + /// from the state where `self` was created + pub fn number_of_uncanonical_nodes(&self) -> usize { + self.node_count + } +} + +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] +struct UnionInfo { + old_id: Id, + old_parents: Vec, + added_after: u32, +} + +/// Value for [`UndoLogT`] that enables [`push1`](RawEGraph::push1) and [`raw_pop1`](RawEGraph::raw_pop1) +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] +pub struct UndoLog { + // Mirror of the union find without path compression + undo_find: UnionFind, + pop_parents: Vec, + union_log: Vec, + memo_log: Vec, +} + +impl Default for UndoLog { + fn default() -> Self { + UndoLog { + undo_find: Default::default(), + pop_parents: Default::default(), + union_log: vec![UnionInfo { + old_id: Id::from(0), + old_parents: vec![], + added_after: 0, + }], + memo_log: Default::default(), + } + } +} + +impl Sealed for UndoLog {} + +impl UndoLogT for UndoLog { + fn add_node(&mut self, _: &L, canon_children: &[Id], node_id: Id) { + let new = self.undo_find.make_set(); + debug_assert_eq!(new, node_id); + self.pop_parents.extend(canon_children); + self.union_log.last_mut().unwrap().added_after += canon_children.len() as u32; + } + + fn union(&mut self, id1: Id, id2: Id, old_parents: Vec) { + self.undo_find.union(id1, id2); + self.union_log.push(UnionInfo { + old_id: id2, + added_after: 0, + old_parents, + }) + } + + fn insert_memo(&mut self, hash: u64) { + self.memo_log.push(hash); + } + + fn clear(&mut self) { + self.union_log.truncate(1); + self.union_log[0].added_after = 0; + self.memo_log.clear(); + self.undo_find.clear(); + } + + #[inline] + fn is_enabled(&self) -> bool { + true + } +} + +impl> RawEGraph { + /// Create a [`PushInfo`] representing the current state of the egraph + /// which can later be passed into 
[`raw_pop1`](RawEGraph::raw_pop1) + /// + /// Requires [`self.is_clean()`](RawEGraph::is_clean) + /// + /// # Example + /// ``` + /// use egg::{raw::*, SymbolLang as S}; + /// use egg::raw::semi_persistent1::UndoLog; + /// let mut egraph = RawEGraph::::default(); + /// let a = egraph.add_uncanonical(S::leaf("a")); + /// let fa = egraph.add_uncanonical(S::new("f", vec![a])); + /// let c = egraph.add_uncanonical(S::leaf("c")); + /// egraph.rebuild(); + /// let old = egraph.clone(); + /// let restore_point = egraph.push1(); + /// let b = egraph.add_uncanonical(S::leaf("b")); + /// let _fb = egraph.add_uncanonical(S::new("g", vec![b])); + /// egraph.union(b, a); + /// egraph.union(b, c); + /// egraph.rebuild(); + /// egraph.pop1(restore_point); + /// assert_eq!(format!("{:#?}", egraph.dump_uncanonical()), format!("{:#?}", old.dump_uncanonical())); + /// assert_eq!(format!("{:#?}", egraph), format!("{:#?}", old)); + /// ``` + pub fn push1(&self) -> PushInfo { + assert!(self.is_clean()); + let undo = self.undo_log.as_unwrap(); + PushInfo { + node_count: self.number_of_uncanonical_nodes(), + union_count: undo.union_log.len(), + memo_log_count: undo.memo_log.len(), + pop_parents_count: undo.pop_parents.len(), + } + } + + /// Mostly restores the egraph to the state it was it when it called [`push1`](RawEGraph::push1) + /// to create `info` + /// + /// Invalidates all [`PushInfo`]s that were created after `info` + /// + /// The `raw_data` fields of the [`RawEClass`]s are not properly restored + /// Instead, `split` is called to undo each union with a mutable reference to the merged data, and the two ids + /// that were merged to create the data for the eclass of the second `id` (the eclass of the first `id` will + /// be what's left of the merged data after the call) + pub fn raw_pop1(&mut self, info: PushInfo, split: impl FnMut(&mut D, Id, Id) -> D) { + let PushInfo { + node_count, + union_count, + memo_log_count, + pop_parents_count, + } = info; + self.pending.clear(); + 
self.pop_memo1(memo_log_count); + self.pop_unions1(union_count, pop_parents_count, split); + self.pop_nodes1(node_count); + } + + /// Return the direct parent from the union find without path compression + pub fn find_direct_parent(&self, id: Id) -> Id { + self.undo_log.as_unwrap().undo_find.parent(id) + } + + fn pop_memo1(&mut self, old_count: usize) { + assert!(self.memo.len() >= old_count); + let memo_log = &mut self.undo_log.as_mut_unwrap().memo_log; + let len = memo_log.len(); + for (hash, idx) in memo_log.drain(old_count..).zip(old_count..len).rev() { + self.residual.memo.remove_nth(hash, idx); + } + } + + fn pop_unions1( + &mut self, + old_count: usize, + pop_parents_count: usize, + mut split: impl FnMut(&mut D, Id, Id) -> D, + ) { + let undo = self.undo_log.as_mut_unwrap(); + assert!(self.residual.number_of_uncanonical_nodes() >= old_count); + for info in undo.union_log.drain(old_count..).rev() { + for _ in 0..info.added_after { + let id = undo.pop_parents.pop().unwrap(); + self.classes.get_mut(&id).unwrap().parents.pop(); + } + let old_id = info.old_id; + let new_id = undo.undo_find.parent(old_id); + debug_assert_ne!(new_id, old_id); + debug_assert_eq!(undo.undo_find.find(new_id), new_id); + *undo.undo_find.parent_mut(old_id) = old_id; + let new_class = &mut self.classes.get_mut(&new_id).unwrap(); + let cut = new_class.parents.len() - info.old_parents.len(); + debug_assert_eq!(&new_class.parents[cut..], &info.old_parents); + new_class.parents.truncate(cut); + let old_data = split(&mut new_class.raw_data, new_id, old_id); + self.classes.insert( + old_id, + RawEClass { + id: old_id, + raw_data: old_data, + parents: info.old_parents, + }, + ); + } + let rem = undo.pop_parents.len() - pop_parents_count; + for _ in 0..rem { + let id = undo.pop_parents.pop().unwrap(); + self.classes.get_mut(&id).unwrap().parents.pop(); + } + undo.union_log.last_mut().unwrap().added_after -= rem as u32; + } + + fn pop_nodes1(&mut self, old_count: usize) { + 
assert!(self.number_of_uncanonical_nodes() >= old_count); + let undo = self.undo_log.as_mut_unwrap(); + undo.undo_find.parents.truncate(old_count); + self.residual + .unionfind + .parents + .clone_from(&undo.undo_find.parents); + for id in (old_count..self.number_of_uncanonical_nodes()).map(Id::from) { + self.classes.remove(&id); + } + self.residual.nodes.truncate(old_count); + } +} + +impl> RawEGraph { + /// Simplified version of [`raw_pop1`](RawEGraph::raw_pop1) for egraphs without eclass data + pub fn pop1(&mut self, info: PushInfo) { + self.raw_pop1(info, |_, _, _| ()) + } +} diff --git a/src/raw/semi_persistent2.rs b/src/raw/semi_persistent2.rs new file mode 100644 index 00000000..31639515 --- /dev/null +++ b/src/raw/semi_persistent2.rs @@ -0,0 +1,343 @@ +use crate::raw::{AsUnwrap, RawEClass, RawEGraph, Sealed, UndoLogT}; +use crate::util::{Entry, HashSet}; +use crate::{Id, Language}; +use std::fmt::Debug; + +#[derive(Debug, Clone, Default)] +#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] +struct UndoNode { + /// Other ENodes that were unioned with this ENode and chose it as their representative + representative_of: Vec, + /// Non-canonical Id's of direct parents of this non-canonical node + parents: Vec, +} + +fn visit_undo_node(id: Id, undo_find: &[UndoNode], f: &mut impl FnMut(Id, &UndoNode)) { + let node = &undo_find[usize::from(id)]; + f(id, node); + node.representative_of + .iter() + .for_each(|&id| visit_undo_node(id, undo_find, &mut *f)) +} + +/// Stored information required to restore the egraph to a previous state +/// +/// see [`push2`](RawEGraph::push2) and [`pop2`](RawEGraph::pop2) +#[derive(Debug, Clone)] +#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] +pub struct PushInfo { + node_count: usize, + union_count: usize, + memo_log_count: usize, + pop_parents_count: usize, +} + +impl PushInfo { + /// Returns the result of 
[`EGraphResidual::number_of_uncanonical_nodes`](super::EGraphResidual::number_of_uncanonical_nodes) + /// from the state where `self` was created + pub fn number_of_uncanonical_nodes(&self) -> usize { + self.node_count + } +} + +/// Value for [`UndoLogT`] that enables [`push2`](RawEGraph::push2) and [`raw_pop2`](RawEGraph::raw_pop2) +#[derive(Clone, Debug, Default)] +#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] +pub struct UndoLog { + undo_find: Vec, + union_log: Vec, + memo_log: Vec, + pop_parents: Vec, + // Scratch space, should be empty other that when inside `pop` + #[cfg_attr(feature = "serde-1", serde(skip))] + dirty: HashSet, +} + +impl Sealed for UndoLog {} + +impl UndoLogT for UndoLog { + fn add_node(&mut self, node: &L, canon: &[Id], node_id: Id) { + debug_assert_eq!(self.undo_find.len(), usize::from(node_id)); + self.undo_find.push(UndoNode::default()); + if !canon.is_empty() { + // this node's children shouldn't since it was equivalent when it was added + for id in node.children() { + self.undo_find[usize::from(*id)].parents.push(node_id) + } + } + self.pop_parents.extend(canon) + } + + fn union(&mut self, id1: Id, id2: Id, _: Vec) { + self.undo_find[usize::from(id1)].representative_of.push(id2); + self.union_log.push(id1) + } + + fn insert_memo(&mut self, hash: u64) { + self.memo_log.push(hash); + } + + fn clear(&mut self) { + self.union_log.clear(); + self.memo_log.clear(); + self.undo_find.clear(); + } + + fn is_enabled(&self) -> bool { + true + } +} + +impl> RawEGraph { + /// Create a [`PushInfo`] representing the current state of the egraph + /// which can later be passed into [`raw_pop2`](RawEGraph::raw_pop2) + /// + /// Requires [`self.is_clean()`](RawEGraph::is_clean) + /// + /// # Example + /// ``` + /// use egg::{raw::*, SymbolLang as S}; + /// use egg::raw::semi_persistent2::UndoLog; + /// let mut egraph = RawEGraph::::default(); + /// let a = egraph.add_uncanonical(S::leaf("a")); + /// let fa = 
egraph.add_uncanonical(S::new("f", vec![a])); + /// let c = egraph.add_uncanonical(S::leaf("c")); + /// egraph.rebuild(); + /// assert_eq!(egraph.number_of_classes(), 3); + /// assert_eq!(egraph.number_of_uncanonical_nodes(), 3); + /// assert_eq!(egraph.total_size(), 3); + /// let restore_point = egraph.push2(); + /// let b = egraph.add_uncanonical(S::leaf("b")); + /// let _fb = egraph.add_uncanonical(S::new("g", vec![b])); + /// egraph.union(b, a); + /// egraph.union(b, c); + /// egraph.rebuild(); + /// assert_eq!(egraph.find(a), b); + /// assert_eq!(egraph.number_of_classes(), 3); + /// assert_eq!(egraph.number_of_uncanonical_nodes(), 5); + /// assert_eq!(egraph.total_size(), 6); + /// egraph.pop2(restore_point); + /// assert_ne!(egraph.find(a), b); + /// assert_eq!(egraph.lookup(S::leaf("a")), Some(a)); + /// assert_eq!(egraph.lookup(S::new("f", vec![a])), Some(fa)); + /// assert_eq!(egraph.lookup(S::leaf("b")), None); + /// assert_eq!(egraph.number_of_classes(), 3); + /// assert_eq!(egraph.number_of_uncanonical_nodes(), 3); + /// assert_eq!(egraph.total_size(), 3); + /// ``` + pub fn push2(&self) -> PushInfo { + assert!(self.is_clean()); + let undo = self.undo_log.as_unwrap(); + PushInfo { + node_count: undo.undo_find.len(), + union_count: undo.union_log.len(), + memo_log_count: undo.memo_log.len(), + pop_parents_count: undo.pop_parents.len(), + } + } + + /// Mostly restores the egraph to the state it was it when it called [`push2`](RawEGraph::push2) + /// to create `info` + /// + /// Invalidates all [`PushInfo`]s that were created after `info` + /// + /// The `raw_data` fields of the [`RawEClass`]s are not properly restored + /// Instead all eclasses that have were merged into another eclass are recreated with `mk_data` and + /// `clear` is called eclass that had another eclass merged into them + /// + /// After each call to either `mk_data` or `clear`, `handle_eqv` is called on each id that is in + /// the eclass (that was handled by `mk_data` or `clear` + 
/// + /// The `state` parameter represents arbitrary state that be accessed in any of the closures + pub fn raw_pop2( + &mut self, + info: PushInfo, + state: &mut T, + clear: impl FnMut(&mut T, &mut D, Id, UndoCtx<'_, L>), + mk_data: impl FnMut(&mut T, Id, UndoCtx<'_, L>) -> D, + handle_eqv: impl FnMut(&mut T, &mut D, Id, UndoCtx<'_, L>), + ) { + let PushInfo { + node_count, + union_count, + memo_log_count, + pop_parents_count, + } = info; + self.pending.clear(); + self.pop_memo2(memo_log_count); + self.pop_parents2(pop_parents_count, node_count); + self.pop_unions2(union_count, node_count, state, clear, mk_data, handle_eqv); + self.pop_nodes2(node_count); + } + + fn pop_memo2(&mut self, old_count: usize) { + assert!(self.memo.len() >= old_count); + let memo_log = &mut self.undo_log.as_mut_unwrap().memo_log; + let len = memo_log.len(); + for (hash, idx) in memo_log.drain(old_count..).zip(old_count..len).rev() { + self.residual.memo.remove_nth(hash, idx); + } + } + + fn pop_parents2(&mut self, old_count: usize, node_count: usize) { + // Pop uncanonical parents within undo find + let undo = self.undo_log.as_mut_unwrap(); + for (id, node) in self + .residual + .nodes + .iter() + .enumerate() + .skip(node_count) + .rev() + { + for child in node.children() { + let parents = &mut undo.undo_find[usize::from(*child)].parents; + if parents.last().copied() == Some(Id::from(id)) { + // Otherwise this id's children never had it added to its parents + // since it was already equivalent to another node when it was added + parents.pop(); + } + } + } + // Pop canonical parents from classes in egraph + // Note, if `id` is not canonical then its class must have been merged into another class so it's parents will + // be rebuilt anyway + // If another class was merged into `id` we will be popping an incorrect parent, but again it's parents will + // be rebuilt anyway + for id in undo.pop_parents.drain(old_count..) 
{ + if let Some(x) = self.classes.get_mut(&id) { + x.parents.pop(); + } + } + } + + fn pop_unions2( + &mut self, + old_count: usize, + node_count: usize, + state: &mut T, + mut clear: impl FnMut(&mut T, &mut D, Id, UndoCtx<'_, L>), + mut mk_data: impl FnMut(&mut T, Id, UndoCtx<'_, L>) -> D, + mut handle_eqv: impl FnMut(&mut T, &mut D, Id, UndoCtx<'_, L>), + ) { + let undo = self.undo_log.as_mut_unwrap(); + assert!(undo.union_log.len() >= old_count); + for id in undo.union_log.drain(old_count..) { + let id2 = undo.undo_find[usize::from(id)] + .representative_of + .pop() + .unwrap(); + for id in [id, id2] { + if usize::from(id) < node_count { + undo.dirty.insert(id); + } + } + } + let ctx = UndoCtx { + nodes: &self.residual.nodes[..node_count], + undo_find: &undo.undo_find[..node_count], + }; + for root in undo.dirty.iter().copied() { + let union_find = &mut self.residual.unionfind; + let class = match self.classes.entry(root) { + Entry::Vacant(vac) => { + let default = RawEClass { + id: root, + raw_data: mk_data(state, root, ctx), + parents: Default::default(), + }; + vac.insert(default) + } + Entry::Occupied(occ) => { + let res = occ.into_mut(); + clear(state, &mut res.raw_data, root, ctx); + res.parents.clear(); + res + } + }; + class.parents.clear(); + let parents = &mut class.parents; + let data = &mut class.raw_data; + visit_undo_node(root, &undo.undo_find, &mut |id, node| { + union_find.parents[usize::from(id)] = root; + parents.extend(&node.parents); + handle_eqv(state, data, id, ctx) + }); + // If we call pop again we need parents added more recently at the end + parents.sort_unstable(); + } + undo.dirty.clear(); + } + + fn pop_nodes2(&mut self, old_count: usize) { + assert!(self.number_of_uncanonical_nodes() >= old_count); + for id in (old_count..self.number_of_uncanonical_nodes()).map(Id::from) { + if self.find(id) == id { + self.classes.remove(&id); + } + } + self.residual.nodes.truncate(old_count); + 
self.undo_log.as_mut_unwrap().undo_find.truncate(old_count); + self.residual.unionfind.parents.truncate(old_count); + } + + /// Returns the [`UndoCtx`] corresponding to the current egraph + pub fn undo_ctx(&self) -> UndoCtx<'_, L> { + UndoCtx { + nodes: &self.nodes, + undo_find: &self.undo_log.as_unwrap().undo_find, + } + } +} + +/// The egraph is in a partially broken state during a call to [`RawEGraph::raw_pop2`] so the passed in closures +/// are given this struct which represents the aspects of the egraph that are currently usable +pub struct UndoCtx<'a, L> { + nodes: &'a [L], + undo_find: &'a [UndoNode], +} + +impl<'a, L> Copy for UndoCtx<'a, L> {} + +impl<'a, L> Clone for UndoCtx<'a, L> { + fn clone(&self) -> Self { + *self + } +} + +impl<'a, L> UndoCtx<'a, L> { + /// Calls `f` on all nodes that are equivalent to `id` + /// + /// Requires `id` to be canonical + pub fn equivalent_nodes(self, id: Id, mut f: impl FnMut(Id)) { + visit_undo_node(id, self.undo_find, &mut |id, _| f(id)) + } + + /// Returns an iterator of the uncanonical ids of nodes that contain the uncanonical id `id` + pub fn direct_parents(self, id: Id) -> impl ExactSizeIterator + 'a { + self.undo_find[usize::from(id)].parents.iter().copied() + } + + /// See [`EGraphResidual::id_to_node`](super::EGraphResidual::id_to_node) + pub fn id_to_node(self, id: Id) -> &'a L { + &self.nodes[usize::from(id)] + } + + /// See [`EGraphResidual::number_of_uncanonical_nodes`](super::EGraphResidual::number_of_uncanonical_nodes) + pub fn number_of_uncanonical_nodes(self) -> usize { + self.nodes.len() + } +} + +impl> RawEGraph { + /// Simplified version of [`raw_pop2`](RawEGraph::raw_pop2) for egraphs without eclass data + pub fn pop2(&mut self, info: PushInfo) { + self.raw_pop2( + info, + &mut (), + |_, _, _, _| {}, + |_, _, _| (), + |_, _, _, _| {}, + ) + } +} diff --git a/src/unionfind.rs b/src/raw/unionfind.rs similarity index 74% rename from src/unionfind.rs rename to src/raw/unionfind.rs index 
39e9bc58..32fc8e0c 100644 --- a/src/unionfind.rs +++ b/src/raw/unionfind.rs @@ -3,29 +3,33 @@ use std::fmt::Debug; #[derive(Debug, Clone, Default)] #[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] +/// Data structure that stores disjoint sets of `Id`s each with a representative pub struct UnionFind { - parents: Vec, + pub(super) parents: Vec, } impl UnionFind { + /// Creates a singleton set and returns its representative pub fn make_set(&mut self) -> Id { let id = Id::from(self.parents.len()); self.parents.push(id); id } + /// Returns the number of ids in all the sets pub fn size(&self) -> usize { self.parents.len() } - fn parent(&self, query: Id) -> Id { + pub(super) fn parent(&self, query: Id) -> Id { self.parents[usize::from(query)] } - fn parent_mut(&mut self, query: Id) -> &mut Id { + pub(super) fn parent_mut(&mut self, query: Id) -> &mut Id { &mut self.parents[usize::from(query)] } + /// Returns the representative of the set `current` belongs to pub fn find(&self, mut current: Id) -> Id { while current != self.parent(current) { current = self.parent(current) @@ -33,6 +37,7 @@ impl UnionFind { current } + /// Equivalent to [`find`](UnionFind::find) but preforms path-compression to optimize further calls pub fn find_mut(&mut self, mut current: Id) -> Id { while current != self.parent(current) { let grandparent = self.parent(self.parent(current)); @@ -42,11 +47,16 @@ impl UnionFind { current } - /// Given two leader ids, unions the two eclasses making root1 the leader. + /// Given two representative ids, unions the two eclasses making root1 the representative. pub fn union(&mut self, root1: Id, root2: Id) -> Id { *self.parent_mut(root2) = root1; root1 } + + /// Resets the union find + pub fn clear(&mut self) { + self.parents.clear() + } } #[cfg(test)] diff --git a/src/test.rs b/src/test.rs index 10815d66..1a730930 100644 --- a/src/test.rs +++ b/src/test.rs @@ -3,6 +3,8 @@ These are not considered part of the public api. 
*/ +use std::cell::RefCell; +use std::rc::Rc; use std::{fmt::Display, fs::File, io::Write, path::PathBuf}; use saturating::Saturating; @@ -37,11 +39,19 @@ pub fn test_runner( should_check: bool, ) where L: Language + Display + FromOp + 'static, - A: Analysis + Default, + A: Analysis + Default + Clone + 'static, + A::Data: Default + Clone, { let _ = env_logger::builder().is_test(true).try_init(); let mut runner = runner.unwrap_or_default(); + let nodes: Vec<_> = runner + .egraph + .uncanonical_nodes() + .map(|(_, n)| n.clone()) + .collect(); + runner.egraph.clear(); + if let Some(lim) = env_var("EGG_NODE_LIMIT") { runner = runner.with_node_limit(lim) } @@ -57,6 +67,22 @@ pub fn test_runner( runner = runner.with_explanations_enabled(); } + let history = Rc::new(RefCell::new(Vec::new())); + let history2 = history.clone(); + // Test push if feature is on + if cfg!(feature = "test-push-pop") { + runner.egraph = runner.egraph.with_push_pop_enabled(); + runner = runner.with_hook(move |runner| { + runner.egraph.push(); + history2.borrow_mut().push(EGraph::clone(&runner.egraph)); + Ok(()) + }); + } + + for node in nodes { + runner.egraph.add_uncanonical(node); + } + runner = runner.with_expr(&start); // NOTE this is a bit of hack, we rely on the fact that the // initial root is the last expr added by the runner. 
We can't @@ -118,6 +144,30 @@ pub fn test_runner( if let Some(check_fn) = check_fn { check_fn(runner) + } else if cfg!(feature = "test-push-pop") { + let mut egraph = runner.egraph; + let _ = runner.hooks; + for mut old in history.borrow().iter().cloned().rev() { + egraph.pop(); + assert_eq!( + format!("{:#?}", old.dump_uncanonical()), + format!("{:#?}", egraph.dump_uncanonical()), + ); + assert_eq!(format!("{:#?}", old), format!("{:#?}", egraph)); + assert_eq!( + format!("{:#?}", old.dump()), + format!("{:#?}", egraph.dump()), + ); + if let Some(explain) = &mut egraph.explain { + let old_explain = old.explain.as_mut().unwrap(); + old_explain.clear_memo(); + for class in egraph.inner.classes_mut().0 { + explain.test_mk_root(class.id); + old_explain.test_mk_root(class.id); + } + assert_eq!(format!("{:#?}", old_explain), format!("{:#?}", explain)); + } + } } } } diff --git a/src/util.rs b/src/util.rs index 0e9051ee..90eadece 100644 --- a/src/util.rs +++ b/src/util.rs @@ -53,12 +53,15 @@ pub(crate) use hashmap::*; mod hashmap { pub(crate) type HashMap = super::IndexMap; pub(crate) type HashSet = super::IndexSet; + + pub(crate) type Entry<'a, K, V> = indexmap::map::Entry<'a, K, V>; } #[cfg(not(feature = "deterministic"))] mod hashmap { use super::BuildHasher; pub(crate) type HashMap = hashbrown::HashMap; pub(crate) type HashSet = hashbrown::HashSet; + pub(crate) type Entry<'a, K, V> = hashbrown::hash_map::Entry<'a, K, V, BuildHasher>; } pub(crate) type IndexMap = indexmap::IndexMap; @@ -125,7 +128,7 @@ pub(crate) struct UniqueQueue where T: Eq + std::hash::Hash + Clone, { - set: hashbrown::HashSet, + set: hashbrown::HashSet, queue: std::collections::VecDeque, } @@ -171,4 +174,9 @@ where debug_assert_eq!(r, self.set.is_empty()); r } + + pub fn clear(&mut self) { + self.queue.clear(); + self.set.clear(); + } } diff --git a/tests/lambda.rs b/tests/lambda.rs index 80ea4fbd..0ec3acb8 100644 --- a/tests/lambda.rs +++ b/tests/lambda.rs @@ -33,10 +33,10 @@ impl Lambda { type 
EGraph = egg::EGraph; -#[derive(Default)] +#[derive(Default, Clone)] struct LambdaAnalysis; -#[derive(Debug)] +#[derive(Debug, Clone, Default)] struct Data { free: HashSet, constant: Option<(Lambda, PatternAst)>, diff --git a/tests/math.rs b/tests/math.rs index a0d8c07a..3d9dc5db 100644 --- a/tests/math.rs +++ b/tests/math.rs @@ -45,7 +45,7 @@ impl egg::CostFunction for MathCostFn { } } -#[derive(Default)] +#[derive(Default, Clone)] pub struct ConstantFold; impl Analysis for ConstantFold { type Data = Option<(Constant, PatternAst)>; @@ -102,6 +102,14 @@ impl Analysis for ConstantFold { egraph[id].assert_unique_leaves(); } } + + fn post_pop_n(egraph: &mut EGraph, _: usize) { + for class in egraph.classes_mut() { + if class.data.is_some() { + class.nodes.retain(|x| x.is_leaf()) + } + } + } } fn is_const_or_distinct_var(v: &str, w: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { @@ -275,7 +283,6 @@ egg::test_fn! { .with_time_limit(std::time::Duration::from_secs(10)) .with_iter_limit(60) .with_node_limit(100_000) - .with_explanations_enabled() // HACK this needs to "see" the end expression .with_expr(&"(* x (- (* 3 x) 14))".parse().unwrap()), "(d x (- (pow x 3) (* 7 (pow x 2))))"