diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 25fa754..ec08494 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -403,8 +403,7 @@ checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" [[package]] name = "egglog" version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b2e053dcef2cd92b0dee0acfad08d83a98c1c1e55b4c55b8eefa2c0e3f4d055" +source = "git+https://github.com/egraphs-good/egglog.git?rev=b27cd225ba3d646ff4327d076c7def665eadb232#b27cd225ba3d646ff4327d076c7def665eadb232" dependencies = [ "chrono", "clap", @@ -435,8 +434,7 @@ dependencies = [ [[package]] name = "egglog-add-primitive" version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eea7ffb4d889c382918cfa139256f2a4d3f46de6e08d9c071e33543393bd3663" +source = "git+https://github.com/egraphs-good/egglog.git?rev=b27cd225ba3d646ff4327d076c7def665eadb232#b27cd225ba3d646ff4327d076c7def665eadb232" dependencies = [ "quote", "syn 2.0.118", @@ -445,8 +443,7 @@ dependencies = [ [[package]] name = "egglog-ast" version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c67f2aa84cdabd48b266cff0ad4bc6d4cb4386ca51d8c8d93e089f182bc162b" +source = "git+https://github.com/egraphs-good/egglog.git?rev=b27cd225ba3d646ff4327d076c7def665eadb232#b27cd225ba3d646ff4327d076c7def665eadb232" dependencies = [ "ordered-float", ] @@ -454,8 +451,7 @@ dependencies = [ [[package]] name = "egglog-bridge" version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aaf9e5c5f59f5edb456367dd1e7e30fe5ecfd5b434954ec7ad2d90b3169d9a2f" +source = "git+https://github.com/egraphs-good/egglog.git?rev=b27cd225ba3d646ff4327d076c7def665eadb232#b27cd225ba3d646ff4327d076c7def665eadb232" dependencies = [ "anyhow", "dyn-clone", @@ -469,7 +465,6 @@ dependencies = [ "num-rational", "once_cell", "ordered-float", - "petgraph", "rayon", "smallvec", "thiserror", @@ -479,8 +474,7 @@ dependencies = [ [[package]] name = "egglog-concurrency" version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5cf71f21545d7b14aceb21e51fdc16b1077e5063df03d57253fc13f098ea2cd" +source = "git+https://github.com/egraphs-good/egglog.git?rev=b27cd225ba3d646ff4327d076c7def665eadb232#b27cd225ba3d646ff4327d076c7def665eadb232" dependencies = [ "arc-swap", "egglog-numeric-id", @@ -491,8 +485,7 @@ dependencies = [ [[package]] name = "egglog-core-relations" version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7d477debbbe77c0fb24df98c37560cce6b8dcd8d05918a009926204cbf13c73" +source = "git+https://github.com/egraphs-good/egglog.git?rev=b27cd225ba3d646ff4327d076c7def665eadb232#b27cd225ba3d646ff4327d076c7def665eadb232" dependencies = [ "anyhow", "bumpalo", @@ -510,7 +503,6 @@ dependencies = [ "log", "num", "once_cell", - "petgraph", "rand 0.9.4", "rayon", "rustc-hash", @@ -522,8 +514,7 @@ dependencies = [ [[package]] name = "egglog-numeric-id" version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b207a4ee3241bc8c84bb314d8c6802fa4974c4742a9b1af5e7646ec6ddb5ae5" +source = "git+https://github.com/egraphs-good/egglog.git?rev=b27cd225ba3d646ff4327d076c7def665eadb232#b27cd225ba3d646ff4327d076c7def665eadb232" dependencies = [ "rayon", ] @@ -531,8 +522,7 @@ dependencies = [ [[package]] name = "egglog-reports" version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03a58e0caea9c3d0fe0486245784f579645576f3b4a8be97fb5828655b281208" +source = "git+https://github.com/egraphs-good/egglog.git?rev=b27cd225ba3d646ff4327d076c7def665eadb232#b27cd225ba3d646ff4327d076c7def665eadb232" dependencies = [ "clap", "hashbrown 0.16.1", @@ -546,8 +536,7 @@ dependencies = [ [[package]] name = "egglog-union-find" version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "175f8e99129dcdac7f96b12cc4703c4145bb74a1a53accc46073a19ace10b9f9" +source = "git+https://github.com/egraphs-good/egglog.git?rev=b27cd225ba3d646ff4327d076c7def665eadb232#b27cd225ba3d646ff4327d076c7def665eadb232" dependencies = [ "crossbeam", "egglog-concurrency", @@ -631,12 +620,6 @@ version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" -[[package]] -name = "foldhash" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" - [[package]] name = "foldhash" version = "0.2.0" @@ -717,15 +700,6 @@ version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" -[[package]] -name = "hashbrown" -version = "0.15.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" -dependencies = [ - "foldhash 0.1.5", -] - [[package]] name = "hashbrown" version = "0.16.1" @@ -734,7 +708,7 @@ checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" dependencies = [ "allocator-api2", "equivalent", - "foldhash 0.2.0", + "foldhash", "serde", "serde_core", ] @@ -1099,18 +1073,6 @@ dependencies = [ "sha2", ] -[[package]] -name = "petgraph" -version = "0.8.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8701b58ea97060d5e5b155d383a69952a60943f0e6dfe30b04c287beb0b27455" -dependencies = [ - "fixedbitset", - "hashbrown 0.15.5", - "indexmap 2.14.0", - "serde", -] - [[package]] name = "pin-project-lite" version = "0.2.17" diff --git a/rust/lira_trs/Cargo.toml b/rust/lira_trs/Cargo.toml index 83493f6..2b791ab 100644 --- a/rust/lira_trs/Cargo.toml +++ b/rust/lira_trs/Cargo.toml @@ -5,7 +5,7 @@ edition = "2024" [dependencies] lira = { path = "../lira" } -egglog = "2.0.0" +egglog = { git = "https://github.com/egraphs-good/egglog.git", rev = "b27cd225ba3d646ff4327d076c7def665eadb232" } du_utils_slist = { git = "https://github.com/belovdv/du_utils.git", rev = "7813f1156c67ceb806682e18a61ce805982388a7" } log = { workspace = true } anyhow = { workspace = true } diff --git a/rust/lira_trs/src/impl_dfg3egraph.rs b/rust/lira_trs/src/impl_dfg3egraph.rs index 4e48d65..fff1eb8 100644 --- a/rust/lira_trs/src/impl_dfg3egraph.rs +++ b/rust/lira_trs/src/impl_dfg3egraph.rs @@ -2,6 +2,8 @@ use std::rc::Rc; use ahash::AHashMap; use du_utils_slist::{SList, slist}; +use egglog::*; +use egglog::{extract::DefaultCost, prelude::*}; use crate::{ dfg::{self, Statement}, @@ -26,7 +28,7 @@ impl EGraph { for i in 0..4 { r.add_constructor(format!("out{i}"), &vec!["i64"; i], "Outputs"); } - for i in 0..6 { + for i in 0..10 { r.add_constructor(format!("in{i}"), &vec!["Selector"; i], "Inputs"); } @@ -38,57 +40,6 @@ impl EGraph { } pub fn add_dfg(&mut self, state: &dfg::State, name: &str) { - struct Ser<'a> { - eg: &'a mut EGraph, - cache: AHashMap<*const Statement, String>, - } - impl Ser<'_> { - fn state(&mut self, state: &dfg::State) -> SList { - match state { - dfg::State::Initial => slist![(state_initial)], - dfg::State::After(statement) => { - let bound = self.stmt(statement); - slist![(state_after (universe_stmt {Quoted(bound)}))] - } - } - } - fn sel(&mut self, sel: &dfg::Selector) -> SList { - let name = self.stmt(&sel.stmt); - slist![(sel {sel.output} (universe_stmt {Quoted(name)}))] - } - fn stmt(&mut self, stmt: &Rc) -> String { - if let Some(name) = self.cache.get(&Rc::as_ptr(stmt)) { - return name.to_string(); - } - let shape = match &stmt.shape.lanes_mult { - Some(mult) => slist![(shape_dynamic {stmt.shape.lanes_base} {mult})], - None => slist![(shape {stmt.shape.lanes_base})], - }; - let outputs = slist![ - ({ format!("out{}", stmt.outputs.len()) } - [stmt.outputs.iter().map(|o| slist![{ o }])]) - ]; - let inputs: Vec<_> = stmt.inputs.iter().map(|sel| self.sel(sel)).collect(); - let inputs = slist![({ format!("in{}", inputs.len()) }[inputs])]; - let implicit = match &stmt.implicit { - Some(imp) => slist![(implicit {self.state(imp)})], - None => slist![(pure)], - }; - - let name = self.eg.gen_temp_name("add_dfg_stmt"); - self.eg.execute(slist![ - (set (universe_stmt {Quoted(&name)}) - (stmt - [[shape, outputs]] - {Quoted(&stmt.kind)} {Quoted(&stmt.spec)} - [[inputs, implicit]] - ) - ) - ]); - self.cache.insert(Rc::as_ptr(stmt), name.clone()); - name - } - } let mut ser = Ser { eg: self, cache: Default::default(), @@ -97,151 +48,277 @@ impl EGraph { self.execute(slist![(set(universe_state {Quoted(name)})[[finish]])]); } + // Note: it's very inefficient to call this function many times. pub fn get_dfg(&mut self, name: &str) -> dfg::State { - use egglog::prelude::*; - use egglog::*; + self.extract(|_| 1).get_dfg(name) + } + // Note: it's very inefficient to call this function many times. + pub fn get_dfg_with(&mut self, name: &str, f: impl Fn(&str) -> u64 + 'static) -> dfg::State { + self.extract(f).get_dfg(name) + } + + pub fn extract(&mut self, f: impl Fn(&str) -> u64 + 'static) -> Extracted { + let extractor = egglog::extract::Extractor::compute_costs_from_rootsorts( + None, + self._inner(), + CostModel(f), + ); + + let rel = "universe_state"; + let func = self._inner().get_function(rel).unwrap(); + assert_eq!(func.schema().input.len(), 1); + let sort_input = func.schema().input[0].clone(); + let sort_output = func.schema().output.clone(); + + let results = query( + self._inner_mut(), + &[("k", sort_input.clone()), ("v", sort_output.clone())], + facts![(= (universe_state k) v)], + ) + .unwrap(); + + let mut td = egglog::TermDag::default(); + let mut index = AHashMap::new(); + for row in results.iter() { + let [key, value] = row.try_into().unwrap(); + let (_, tid) = extractor + .extract_best_with_sort(self._inner(), &mut td, value, sort_output.clone()) + .unwrap(); + let key = self + ._inner() + .value_to_base::(key) + .to_string(); + index.insert(key, tid); + } + + Extracted { td, index } + } +} + +pub struct Extracted { + td: TermDag, + index: AHashMap, +} - let expr = exprs::call("universe_state", vec![exprs::string(name)]); - let (sort, value) = self._inner_mut().eval_expr(&expr).unwrap(); - let (td, t, _) = self._inner().extract_value(&sort, value).unwrap(); +impl Extracted { + pub fn get_dfg(&self, name: &str) -> dfg::State { + let &term_id = self.index.get(name).unwrap_or_else(|| panic!("no {name}")); let mut des = Des { - td: &td, + td: &self.td, stmt_cache: AHashMap::new(), }; + des.state(term_id) + } +} - struct Des<'a> { - td: &'a TermDag, - stmt_cache: AHashMap>, +struct CostModel u64>(F); +impl u64> egglog::extract::CostModel for CostModel { + fn fold(&self, _: &str, children_cost: &[DefaultCost], head_cost: DefaultCost) -> DefaultCost { + use egglog::extract::Cost as _; + children_cost.iter().fold(head_cost, |a, b| a.combine(b)) + } + fn enode_cost( + &self, + _: &prelude::EGraph, + func: &Function, + _: &egglog::FunctionRow, + ) -> DefaultCost { + (self.0)(func.name()) + } + fn base_value_cost( + &self, + egraph: &prelude::EGraph, + sort: &ArcSort, + value: Value, + ) -> DefaultCost { + use egglog::extract::Cost as _; + if sort.name() == "String" { + let s: egglog::sort::S = egraph.value_to_base(value); + return (self.0)(s.as_str()); } + DefaultCost::unit() + } +} - impl<'a> Des<'a> { - fn state(&mut self, t: TermId) -> dfg::State { - match self.td.get(t) { - Term::App(sym, children) if sym.as_str() == "state_initial" => { - dfg::State::Initial - } - Term::App(sym, children) if sym.as_str() == "state_after" => { - let stmt = self.stmt(children[0]); - dfg::State::After(stmt) - } - other => panic!("expected state, got {:?}", other), - } +struct Ser<'a> { + eg: &'a mut EGraph, + cache: AHashMap<*const Statement, String>, +} + +impl Ser<'_> { + fn state(&mut self, state: &dfg::State) -> SList { + match state { + dfg::State::Initial => slist![(state_initial)], + dfg::State::After(statement) => { + let bound = self.stmt(statement); + slist![(state_after (universe_stmt {Quoted(bound)}))] } + } + } + fn sel(&mut self, sel: &dfg::Selector) -> SList { + let name = self.stmt(&sel.stmt); + slist![(sel {sel.output} (universe_stmt {Quoted(name)}))] + } + fn stmt(&mut self, stmt: &Rc) -> String { + if let Some(name) = self.cache.get(&Rc::as_ptr(stmt)) { + return name.to_string(); + } + let shape = match &stmt.shape.lanes_mult { + Some(mult) => slist![(shape_dynamic {stmt.shape.lanes_base} {Quoted(mult)})], + None => slist![(shape {stmt.shape.lanes_base})], + }; + let outputs = slist![ + ({ format!("out{}", stmt.outputs.len()) }[stmt.outputs.iter().map(|o| slist![{ o }])]) + ]; + let inputs: Vec<_> = stmt.inputs.iter().map(|sel| self.sel(sel)).collect(); + let inputs = slist![({ format!("in{}", inputs.len()) }[inputs])]; + let implicit = match &stmt.implicit { + Some(imp) => slist![(implicit {self.state(imp)})], + None => slist![(pure)], + }; - fn stmt(&mut self, t: TermId) -> Rc { - if let Some(stmt) = self.stmt_cache.get(&t) { - return stmt.clone(); - } + let name = self.eg.gen_temp_name("add_dfg_stmt"); + self.eg.execute(slist![ + (set (universe_stmt {Quoted(&name)}) + (stmt + [[shape, outputs]] + {Quoted(&stmt.kind)} {Quoted(&stmt.spec)} + [[inputs, implicit]] + ) + ) + ]); + self.cache.insert(Rc::as_ptr(stmt), name.clone()); + name + } +} - match self.td.get(t) { - Term::App(sym, children) if sym.as_str() == "stmt" => { - let shape = self.shape(children[0]); - let outputs = self.outputs(children[1]); - let kind = self.lit_string(children[2]); - let spec = self.lit_string(children[3]); - let inputs = self.inputs(children[4]); - let implicit = self.implicit(children[5]); - - let statement = Rc::new(dfg::Statement { - shape, - outputs, - kind, - spec, - inputs, - implicit, - }); - self.stmt_cache.insert(t, statement.clone()); - statement - } - other => panic!("expected stmt, got {:?}", other), - } +struct Des<'a> { + td: &'a TermDag, + stmt_cache: AHashMap>, +} + +impl<'a> Des<'a> { + fn state(&mut self, t: TermId) -> dfg::State { + match self.td.get(t) { + Term::App(sym, children) if sym.as_str() == "state_initial" => dfg::State::Initial, + Term::App(sym, children) if sym.as_str() == "state_after" => { + let stmt = self.stmt(children[0]); + dfg::State::After(stmt) } + other => panic!("expected state, got {:?}", other), + } + } - fn shape(&mut self, t: TermId) -> lira::Shape { - match self.td.get(t) { - Term::App(sym, children) if sym.as_str() == "shape" => { - let lanes_base = self.lit_usize(children[0]); - lira::Shape { - lanes_base, - lanes_mult: None, - } - } - Term::App(sym, children) if sym.as_str() == "shape_dynamic" => { - let lanes_base = self.lit_usize(children[0]); - let lanes_mult = Some(self.lit_string(children[1])); - lira::Shape { - lanes_base, - lanes_mult, - } - } - other => panic!("expected shape, got {:?}", other), - } + fn stmt(&mut self, t: TermId) -> Rc { + if let Some(stmt) = self.stmt_cache.get(&t) { + return stmt.clone(); + } + + match self.td.get(t) { + Term::App(sym, children) if sym.as_str() == "stmt" => { + let shape = self.shape(children[0]); + let outputs = self.outputs(children[1]); + let kind = self.lit_string(children[2]); + let spec = self.lit_string(children[3]); + let inputs = self.inputs(children[4]); + let implicit = self.implicit(children[5]); + + let statement = Rc::new(dfg::Statement { + shape, + outputs, + kind, + spec, + inputs, + implicit, + }); + self.stmt_cache.insert(t, statement.clone()); + statement } + other => panic!("expected stmt, got {:?}", other), + } + } - fn outputs(&mut self, t: TermId) -> Vec { - match self.td.get(t) { - Term::App(sym, children) if sym.as_str().starts_with("out") => { - let mut out = Vec::new(); - for &child in children { - out.push(self.lit_usize(child)); - } - out - } - other => panic!("expected outputs, got {:?}", other), + fn shape(&mut self, t: TermId) -> lira::Shape { + match self.td.get(t) { + Term::App(sym, children) if sym.as_str() == "shape" => { + let lanes_base = self.lit_usize(children[0]); + lira::Shape { + lanes_base, + lanes_mult: None, } } - - fn inputs(&mut self, t: TermId) -> Vec { - match self.td.get(t) { - Term::App(sym, children) if sym.as_str().starts_with("in") => { - let mut inputs = Vec::new(); - for &child in children { - inputs.push(self.selector(child)); - } - inputs - } - other => panic!("expected inputs, got {:?}", other), + Term::App(sym, children) if sym.as_str() == "shape_dynamic" => { + let lanes_base = self.lit_usize(children[0]); + let lanes_mult = Some(self.lit_string(children[1])); + lira::Shape { + lanes_base, + lanes_mult, } } + other => panic!("expected shape, got {:?}", other), + } + } - fn implicit(&mut self, t: TermId) -> Option { - match self.td.get(t) { - Term::App(sym, children) if sym.as_str() == "pure" => None, - Term::App(sym, children) if sym.as_str() == "implicit" => { - let state = self.state(children[0]); - Some(state) - } - other => panic!("expected implicit, got {:?}", other), + fn outputs(&mut self, t: TermId) -> Vec { + match self.td.get(t) { + Term::App(sym, children) if sym.as_str().starts_with("out") => { + let mut out = Vec::new(); + for &child in children { + out.push(self.lit_usize(child)); } + out } + other => panic!("expected outputs, got {:?}", other), + } + } - fn selector(&mut self, t: TermId) -> dfg::Selector { - match self.td.get(t) { - Term::App(sym, children) if sym.as_str() == "sel" => { - let output = self.lit_usize(children[0]); - let stmt = self.stmt(children[1]); - dfg::Selector { output, stmt } - } - other => panic!("expected sel, got {:?}", other), + fn inputs(&mut self, t: TermId) -> Vec { + match self.td.get(t) { + Term::App(sym, children) if sym.as_str().starts_with("in") => { + let mut inputs = Vec::new(); + for &child in children { + inputs.push(self.selector(child)); } + inputs } + other => panic!("expected inputs, got {:?}", other), + } + } - fn lit_usize(&self, t: TermId) -> usize { - match self.td.get(t) { - Term::Lit(ast::Literal::Int(n)) => *n as usize, - other => panic!("expected i64 literal, got {:?}", other), - } + fn implicit(&mut self, t: TermId) -> Option { + match self.td.get(t) { + Term::App(sym, children) if sym.as_str() == "pure" => None, + Term::App(sym, children) if sym.as_str() == "implicit" => { + let state = self.state(children[0]); + Some(state) } + other => panic!("expected implicit, got {:?}", other), + } + } - fn lit_string(&self, t: TermId) -> String { - match self.td.get(t) { - Term::Lit(ast::Literal::String(s)) => s.to_string(), - other => panic!("expected string literal, got {:?}", other), - } + fn selector(&mut self, t: TermId) -> dfg::Selector { + match self.td.get(t) { + Term::App(sym, children) if sym.as_str() == "sel" => { + let output = self.lit_usize(children[0]); + let stmt = self.stmt(children[1]); + dfg::Selector { output, stmt } } + other => panic!("expected sel, got {:?}", other), + } + } + + fn lit_usize(&self, t: TermId) -> usize { + match self.td.get(t) { + Term::Lit(ast::Literal::Int(n)) => *n as usize, + other => panic!("expected i64 literal, got {:?}", other), } + } - des.state(t) + fn lit_string(&self, t: TermId) -> String { + match self.td.get(t) { + Term::Lit(ast::Literal::String(s)) => s.to_string(), + other => panic!("expected string literal, got {:?}", other), + } } } diff --git a/rust/lira_trs/tests/rewrite.rs b/rust/lira_trs/tests/opt.rs similarity index 100% rename from rust/lira_trs/tests/rewrite.rs rename to rust/lira_trs/tests/opt.rs diff --git a/rust/lira_trs/tests/replace.egg b/rust/lira_trs/tests/replace.egg new file mode 100644 index 0000000..a59cc97 --- /dev/null +++ b/rust/lira_trs/tests/replace.egg @@ -0,0 +1,4 @@ +(ruleset replace) + +(rewrite (stmt h o k "a" i m) (stmt h o k "b" i m) :ruleset replace) +(rewrite (stmt h o k "b" i m) (stmt h o k "c" i m) :ruleset replace) diff --git a/rust/lira_trs/tests/replace.rs b/rust/lira_trs/tests/replace.rs new file mode 100644 index 0000000..88549e2 --- /dev/null +++ b/rust/lira_trs/tests/replace.rs @@ -0,0 +1,32 @@ +use lira_trs::{dfg, egraph::EGraph}; + +#[test] +fn replace() { + let text = "1 = get a;\n"; + + let ir = lira::StatementSeq::parse(text).unwrap(); + let dfg = dfg::lira2dfg(&ir, |_| false); + + let test = |best: &'static str| { + let mut eg = EGraph::new_dfg(); + eg.add_dfg(&dfg, "test"); + eg._inner_mut() + .parse_and_run_program(None, include_str!("./replace.egg")) + .unwrap(); + eg.run_ruleset_saturate("replace"); + let dfg = eg.get_dfg_with("test", move |s| { + if ["a", "b", "c"].contains(&s) && s != best { + 2 + } else { + 1 + } + }); + let ir = dfg::dfg2lira(&dfg); + let text2 = ir.to_string(); + assert_eq!(text2, format!("1 = get {best};\n")); + }; + + test("a"); + test("b"); + test("c"); +}