diff --git a/compiler/rustc_transmute/src/layout/dfa.rs b/compiler/rustc_transmute/src/layout/dfa.rs index af568171f911c..bb909c54d2bc3 100644 --- a/compiler/rustc_transmute/src/layout/dfa.rs +++ b/compiler/rustc_transmute/src/layout/dfa.rs @@ -1,19 +1,18 @@ use std::fmt; use std::sync::atomic::{AtomicU32, Ordering}; -use tracing::instrument; - -use super::{Byte, Nfa, Ref, nfa}; +use super::{Byte, Ref, Tree, Uninhabited}; use crate::Map; -#[derive(PartialEq, Clone, Debug)] +#[derive(PartialEq)] +#[cfg_attr(test, derive(Clone))] pub(crate) struct Dfa where R: Ref, { pub(crate) transitions: Map>, pub(crate) start: State, - pub(crate) accepting: State, + pub(crate) accept: State, } #[derive(PartialEq, Clone, Debug)] @@ -34,35 +33,15 @@ where } } -impl Transitions -where - R: Ref, -{ - #[cfg(test)] - fn insert(&mut self, transition: Transition, state: State) { - match transition { - Transition::Byte(b) => { - self.byte_transitions.insert(b, state); - } - Transition::Ref(r) => { - self.ref_transitions.insert(r, state); - } - } - } -} - -/// The states in a `Nfa` represent byte offsets. +/// The states in a [`Dfa`] represent byte offsets. 
#[derive(Hash, Eq, PartialEq, PartialOrd, Ord, Copy, Clone)] -pub(crate) struct State(u32); +pub(crate) struct State(pub(crate) u32); -#[cfg(test)] -#[derive(Hash, Eq, PartialEq, Clone, Copy)] -pub(crate) enum Transition -where - R: Ref, -{ - Byte(Byte), - Ref(R), +impl State { + pub(crate) fn new() -> Self { + static COUNTER: AtomicU32 = AtomicU32::new(0); + Self(COUNTER.fetch_add(1, Ordering::SeqCst)) + } } impl fmt::Debug for State { @@ -71,19 +50,6 @@ impl fmt::Debug for State { } } -#[cfg(test)] -impl fmt::Debug for Transition -where - R: Ref, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match &self { - Self::Byte(b) => b.fmt(f), - Self::Ref(r) => r.fmt(f), - } - } -} - impl Dfa where R: Ref, @@ -92,60 +58,167 @@ where pub(crate) fn bool() -> Self { let mut transitions: Map> = Map::default(); let start = State::new(); - let accepting = State::new(); + let accept = State::new(); - transitions.entry(start).or_default().insert(Transition::Byte(Byte::Init(0x00)), accepting); + transitions.entry(start).or_default().byte_transitions.insert(Byte::Init(0x00), accept); - transitions.entry(start).or_default().insert(Transition::Byte(Byte::Init(0x01)), accepting); + transitions.entry(start).or_default().byte_transitions.insert(Byte::Init(0x01), accept); - Self { transitions, start, accepting } + Self { transitions, start, accept } } - #[instrument(level = "debug")] - pub(crate) fn from_nfa(nfa: Nfa) -> Self { - let Nfa { transitions: nfa_transitions, start: nfa_start, accepting: nfa_accepting } = nfa; + pub(crate) fn unit() -> Self { + let transitions: Map> = Map::default(); + let start = State::new(); + let accept = start; + + Self { transitions, start, accept } + } - let mut dfa_transitions: Map> = Map::default(); - let mut nfa_to_dfa: Map = Map::default(); - let dfa_start = State::new(); - nfa_to_dfa.insert(nfa_start, dfa_start); + pub(crate) fn from_byte(byte: Byte) -> Self { + let mut transitions: Map> = Map::default(); + let start = 
State::new(); + let accept = State::new(); - let mut queue = vec![(nfa_start, dfa_start)]; + transitions.entry(start).or_default().byte_transitions.insert(byte, accept); - while let Some((nfa_state, dfa_state)) = queue.pop() { - if nfa_state == nfa_accepting { - continue; - } + Self { transitions, start, accept } + } - for (nfa_transition, next_nfa_states) in nfa_transitions[&nfa_state].iter() { - let dfa_transitions = - dfa_transitions.entry(dfa_state).or_insert_with(Default::default); - - let mapped_state = next_nfa_states.iter().find_map(|x| nfa_to_dfa.get(x).copied()); - - let next_dfa_state = match nfa_transition { - &nfa::Transition::Byte(b) => *dfa_transitions - .byte_transitions - .entry(b) - .or_insert_with(|| mapped_state.unwrap_or_else(State::new)), - &nfa::Transition::Ref(r) => *dfa_transitions - .ref_transitions - .entry(r) - .or_insert_with(|| mapped_state.unwrap_or_else(State::new)), - }; - - for &next_nfa_state in next_nfa_states { - nfa_to_dfa.entry(next_nfa_state).or_insert_with(|| { - queue.push((next_nfa_state, next_dfa_state)); - next_dfa_state - }); + pub(crate) fn from_ref(r: R) -> Self { + let mut transitions: Map> = Map::default(); + let start = State::new(); + let accept = State::new(); + + transitions.entry(start).or_default().ref_transitions.insert(r, accept); + + Self { transitions, start, accept } + } + + pub(crate) fn from_tree(tree: Tree) -> Result { + Ok(match tree { + Tree::Byte(b) => Self::from_byte(b), + Tree::Ref(r) => Self::from_ref(r), + Tree::Alt(alts) => { + // Convert and filter the inhabited alternatives. + let mut alts = alts.into_iter().map(Self::from_tree).filter_map(Result::ok); + // If there are no alternatives, return `Uninhabited`. + let dfa = alts.next().ok_or(Uninhabited)?; + // Combine the remaining alternatives with `dfa`. 
+ alts.fold(dfa, |dfa, alt| dfa.union(alt, State::new)) + } + Tree::Seq(elts) => { + let mut dfa = Self::unit(); + for elt in elts.into_iter().map(Self::from_tree) { + dfa = dfa.concat(elt?); } + dfa } + }) + } + + /// Concatenate two `Dfa`s. + pub(crate) fn concat(self, other: Self) -> Self { + if self.start == self.accept { + return other; + } else if other.start == other.accept { + return self; } - let dfa_accepting = nfa_to_dfa[&nfa_accepting]; + let start = self.start; + let accept = other.accept; + + let mut transitions: Map> = self.transitions; - Self { transitions: dfa_transitions, start: dfa_start, accepting: dfa_accepting } + for (source, transition) in other.transitions { + let fix_state = |state| if state == other.start { self.accept } else { state }; + let entry = transitions.entry(fix_state(source)).or_default(); + for (edge, destination) in transition.byte_transitions { + entry.byte_transitions.insert(edge, fix_state(destination)); + } + for (edge, destination) in transition.ref_transitions { + entry.ref_transitions.insert(edge, fix_state(destination)); + } + } + + Self { transitions, start, accept } + } + + /// Compute the union of two `Dfa`s. + pub(crate) fn union(self, other: Self, mut new_state: impl FnMut() -> State) -> Self { + // We implement `union` by lazily initializing a set of states + // corresponding to the product of states in `self` and `other`, and + // then add transitions between these states that correspond to where + // they exist between `self` and `other`. + + let a = self; + let b = other; + + let accept = new_state(); + + let mut mapping: Map<(Option, Option), State> = Map::default(); + + let mut mapped = |(a_state, b_state)| { + if Some(a.accept) == a_state || Some(b.accept) == b_state { + // If either `a_state` or `b_state` are accepting, map to a + // common `accept` state. 
+ accept + } else { + *mapping.entry((a_state, b_state)).or_insert_with(&mut new_state) + } + }; + + let start = mapped((Some(a.start), Some(b.start))); + let mut transitions: Map> = Map::default(); + let mut queue = vec![(Some(a.start), Some(b.start))]; + let empty_transitions = Transitions::default(); + + while let Some((a_src, b_src)) = queue.pop() { + let a_transitions = + a_src.and_then(|a_src| a.transitions.get(&a_src)).unwrap_or(&empty_transitions); + let b_transitions = + b_src.and_then(|b_src| b.transitions.get(&b_src)).unwrap_or(&empty_transitions); + + let byte_transitions = + a_transitions.byte_transitions.keys().chain(b_transitions.byte_transitions.keys()); + + for byte_transition in byte_transitions { + let a_dst = a_transitions.byte_transitions.get(byte_transition).copied(); + let b_dst = b_transitions.byte_transitions.get(byte_transition).copied(); + + assert!(a_dst.is_some() || b_dst.is_some()); + + let src = mapped((a_src, b_src)); + let dst = mapped((a_dst, b_dst)); + + transitions.entry(src).or_default().byte_transitions.insert(*byte_transition, dst); + + if !transitions.contains_key(&dst) { + queue.push((a_dst, b_dst)) + } + } + + let ref_transitions = + a_transitions.ref_transitions.keys().chain(b_transitions.ref_transitions.keys()); + + for ref_transition in ref_transitions { + let a_dst = a_transitions.ref_transitions.get(ref_transition).copied(); + let b_dst = b_transitions.ref_transitions.get(ref_transition).copied(); + + assert!(a_dst.is_some() || b_dst.is_some()); + + let src = mapped((a_src, b_src)); + let dst = mapped((a_dst, b_dst)); + + transitions.entry(src).or_default().ref_transitions.insert(*ref_transition, dst); + + if !transitions.contains_key(&dst) { + queue.push((a_dst, b_dst)) + } + } + } + + Self { transitions, start, accept } } pub(crate) fn bytes_from(&self, start: State) -> Option<&Map> { @@ -159,24 +232,48 @@ where pub(crate) fn refs_from(&self, start: State) -> Option<&Map> { 
Some(&self.transitions.get(&start)?.ref_transitions) } -} -impl State { - pub(crate) fn new() -> Self { - static COUNTER: AtomicU32 = AtomicU32::new(0); - Self(COUNTER.fetch_add(1, Ordering::SeqCst)) + #[cfg(test)] + pub(crate) fn from_edges>( + start: u32, + accept: u32, + edges: &[(u32, B, u32)], + ) -> Self { + let start = State(start); + let accept = State(accept); + let mut transitions: Map> = Map::default(); + + for &(src, edge, dst) in edges { + let src = State(src); + let dst = State(dst); + let old = transitions.entry(src).or_default().byte_transitions.insert(edge.into(), dst); + assert!(old.is_none()); + } + + Self { start, accept, transitions } } } -#[cfg(test)] -impl From> for Transition +/// Serialize the DFA using the Graphviz DOT format. +impl fmt::Debug for Dfa where R: Ref, { - fn from(nfa_transition: nfa::Transition) -> Self { - match nfa_transition { - nfa::Transition::Byte(byte) => Transition::Byte(byte), - nfa::Transition::Ref(r) => Transition::Ref(r), + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!(f, "digraph {{")?; + writeln!(f, " {:?} [shape = doublecircle]", self.start)?; + writeln!(f, " {:?} [shape = doublecircle]", self.accept)?; + + for (src, transitions) in self.transitions.iter() { + for (t, dst) in transitions.byte_transitions.iter() { + writeln!(f, " {src:?} -> {dst:?} [label=\"{t:?}\"]")?; + } + + for (t, dst) in transitions.ref_transitions.iter() { + writeln!(f, " {src:?} -> {dst:?} [label=\"{t:?}\"]")?; + } } + + writeln!(f, "}}") } } diff --git a/compiler/rustc_transmute/src/layout/mod.rs b/compiler/rustc_transmute/src/layout/mod.rs index c4c01a8fac31f..c940f7c42a82f 100644 --- a/compiler/rustc_transmute/src/layout/mod.rs +++ b/compiler/rustc_transmute/src/layout/mod.rs @@ -4,9 +4,6 @@ use std::hash::Hash; pub(crate) mod tree; pub(crate) use tree::Tree; -pub(crate) mod nfa; -pub(crate) use nfa::Nfa; - pub(crate) mod dfa; pub(crate) use dfa::Dfa; @@ -29,6 +26,13 @@ impl fmt::Debug for Byte { } } 
+#[cfg(test)] +impl From for Byte { + fn from(src: u8) -> Self { + Self::Init(src) + } +} + pub(crate) trait Def: Debug + Hash + Eq + PartialEq + Copy + Clone { fn has_safety_invariants(&self) -> bool; } diff --git a/compiler/rustc_transmute/src/layout/nfa.rs b/compiler/rustc_transmute/src/layout/nfa.rs deleted file mode 100644 index 9c21fd94f03ec..0000000000000 --- a/compiler/rustc_transmute/src/layout/nfa.rs +++ /dev/null @@ -1,169 +0,0 @@ -use std::fmt; -use std::sync::atomic::{AtomicU32, Ordering}; - -use super::{Byte, Ref, Tree, Uninhabited}; -use crate::{Map, Set}; - -/// A non-deterministic finite automaton (NFA) that represents the layout of a type. -/// The transmutability of two given types is computed by comparing their `Nfa`s. -#[derive(PartialEq, Debug)] -pub(crate) struct Nfa -where - R: Ref, -{ - pub(crate) transitions: Map, Set>>, - pub(crate) start: State, - pub(crate) accepting: State, -} - -/// The states in a `Nfa` represent byte offsets. -#[derive(Hash, Eq, PartialEq, PartialOrd, Ord, Copy, Clone)] -pub(crate) struct State(u32); - -/// The transitions between states in a `Nfa` reflect bit validity. 
-#[derive(Hash, Eq, PartialEq, Clone, Copy)] -pub(crate) enum Transition -where - R: Ref, -{ - Byte(Byte), - Ref(R), -} - -impl fmt::Debug for State { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "S_{}", self.0) - } -} - -impl fmt::Debug for Transition -where - R: Ref, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match &self { - Self::Byte(b) => b.fmt(f), - Self::Ref(r) => r.fmt(f), - } - } -} - -impl Nfa -where - R: Ref, -{ - pub(crate) fn unit() -> Self { - let transitions: Map, Set>> = Map::default(); - let start = State::new(); - let accepting = start; - - Nfa { transitions, start, accepting } - } - - pub(crate) fn from_byte(byte: Byte) -> Self { - let mut transitions: Map, Set>> = Map::default(); - let start = State::new(); - let accepting = State::new(); - - let source = transitions.entry(start).or_default(); - let edge = source.entry(Transition::Byte(byte)).or_default(); - edge.insert(accepting); - - Nfa { transitions, start, accepting } - } - - pub(crate) fn from_ref(r: R) -> Self { - let mut transitions: Map, Set>> = Map::default(); - let start = State::new(); - let accepting = State::new(); - - let source = transitions.entry(start).or_default(); - let edge = source.entry(Transition::Ref(r)).or_default(); - edge.insert(accepting); - - Nfa { transitions, start, accepting } - } - - pub(crate) fn from_tree(tree: Tree) -> Result { - Ok(match tree { - Tree::Byte(b) => Self::from_byte(b), - Tree::Ref(r) => Self::from_ref(r), - Tree::Alt(alts) => { - let mut alts = alts.into_iter().map(Self::from_tree); - let mut nfa = alts.next().ok_or(Uninhabited)??; - for alt in alts { - nfa = nfa.union(alt?); - } - nfa - } - Tree::Seq(elts) => { - let mut nfa = Self::unit(); - for elt in elts.into_iter().map(Self::from_tree) { - nfa = nfa.concat(elt?); - } - nfa - } - }) - } - - /// Concatenate two `Nfa`s. 
- pub(crate) fn concat(self, other: Self) -> Self { - if self.start == self.accepting { - return other; - } else if other.start == other.accepting { - return self; - } - - let start = self.start; - let accepting = other.accepting; - - let mut transitions: Map, Set>> = self.transitions; - - for (source, transition) in other.transitions { - let fix_state = |state| if state == other.start { self.accepting } else { state }; - let entry = transitions.entry(fix_state(source)).or_default(); - for (edge, destinations) in transition { - let entry = entry.entry(edge).or_default(); - for destination in destinations { - entry.insert(fix_state(destination)); - } - } - } - - Self { transitions, start, accepting } - } - - /// Compute the union of two `Nfa`s. - pub(crate) fn union(self, other: Self) -> Self { - let start = self.start; - let accepting = self.accepting; - - let mut transitions: Map, Set>> = self.transitions.clone(); - - for (&(mut source), transition) in other.transitions.iter() { - // if source is starting state of `other`, replace with starting state of `self` - if source == other.start { - source = self.start; - } - let entry = transitions.entry(source).or_default(); - for (edge, destinations) in transition { - let entry = entry.entry(*edge).or_default(); - for &(mut destination) in destinations { - // if dest is accepting state of `other`, replace with accepting state of `self` - if destination == other.accepting { - destination = self.accepting; - } - entry.insert(destination); - } - } - } - Self { transitions, start, accepting } - } -} - -impl State { - pub(crate) fn new() -> Self { - static COUNTER: AtomicU32 = AtomicU32::new(0); - Self(COUNTER.fetch_add(1, Ordering::SeqCst)) - } -} diff --git a/compiler/rustc_transmute/src/lib.rs b/compiler/rustc_transmute/src/lib.rs index 00928137d2976..76fa6ceabe7e7 100644 --- a/compiler/rustc_transmute/src/lib.rs +++ b/compiler/rustc_transmute/src/lib.rs @@ -2,7 +2,7 @@ #![feature(never_type)] // tidy-alphabetical-end 
-pub(crate) use rustc_data_structures::fx::{FxIndexMap as Map, FxIndexSet as Set}; +pub(crate) use rustc_data_structures::fx::FxIndexMap as Map; pub mod layout; mod maybe_transmutable; diff --git a/compiler/rustc_transmute/src/maybe_transmutable/mod.rs b/compiler/rustc_transmute/src/maybe_transmutable/mod.rs index 63fabc9c83d93..db0e1ab8e986a 100644 --- a/compiler/rustc_transmute/src/maybe_transmutable/mod.rs +++ b/compiler/rustc_transmute/src/maybe_transmutable/mod.rs @@ -4,7 +4,7 @@ pub(crate) mod query_context; #[cfg(test)] mod tests; -use crate::layout::{self, Byte, Def, Dfa, Nfa, Ref, Tree, Uninhabited, dfa}; +use crate::layout::{self, Byte, Def, Dfa, Ref, Tree, Uninhabited, dfa}; use crate::maybe_transmutable::query_context::QueryContext; use crate::{Answer, Condition, Map, Reason}; @@ -73,7 +73,7 @@ where /// Answers whether a `Tree` is transmutable into another `Tree`. /// /// This method begins by de-def'ing `src` and `dst`, and prunes private paths from `dst`, - /// then converts `src` and `dst` to `Nfa`s, and computes an answer using those NFAs. + /// then converts `src` and `dst` to `Dfa`s, and computes an answer using those DFAs. #[inline(always)] #[instrument(level = "debug", skip(self), fields(src = ?self.src, dst = ?self.dst))] pub(crate) fn answer(self) -> Answer<::Ref> { @@ -105,22 +105,22 @@ where trace!(?dst, "pruned dst"); - // Convert `src` from a tree-based representation to an NFA-based + // Convert `src` from a tree-based representation to a DFA-based // representation. If the conversion fails because `src` is uninhabited, // conclude that the transmutation is acceptable, because instances of // the `src` type do not exist. - let src = match Nfa::from_tree(src) { + let src = match Dfa::from_tree(src) { Ok(src) => src, Err(Uninhabited) => return Answer::Yes, }; - // Convert `dst` from a tree-based representation to an NFA-based + // Convert `dst` from a tree-based representation to a DFA-based // representation. 
If the conversion fails because `src` is uninhabited, // conclude that the transmutation is unacceptable. Valid instances of // the `dst` type do not exist, either because it's genuinely // uninhabited, or because there are no branches of the tree that are // free of safety invariants. - let dst = match Nfa::from_tree(dst) { + let dst = match Dfa::from_tree(dst) { Ok(dst) => dst, Err(Uninhabited) => return Answer::No(Reason::DstMayHaveSafetyInvariants), }; @@ -129,23 +129,6 @@ where } } -impl MaybeTransmutableQuery::Ref>, C> -where - C: QueryContext, -{ - /// Answers whether a `Nfa` is transmutable into another `Nfa`. - /// - /// This method converts `src` and `dst` to DFAs, then computes an answer using those DFAs. - #[inline(always)] - #[instrument(level = "debug", skip(self), fields(src = ?self.src, dst = ?self.dst))] - pub(crate) fn answer(self) -> Answer<::Ref> { - let Self { src, dst, assume, context } = self; - let src = Dfa::from_nfa(src); - let dst = Dfa::from_nfa(dst); - MaybeTransmutableQuery { src, dst, assume, context }.answer() - } -} - impl MaybeTransmutableQuery::Ref>, C> where C: QueryContext, @@ -173,7 +156,7 @@ where src_transitions_len = self.src.transitions.len(), dst_transitions_len = self.dst.transitions.len() ); - let answer = if dst_state == self.dst.accepting { + let answer = if dst_state == self.dst.accept { // truncation: `size_of(Src) >= size_of(Dst)` // // Why is truncation OK to do? Because even though the Src is bigger, all we care about @@ -190,7 +173,7 @@ where // that none of the actually-used data can introduce an invalid state for Dst's type, we // are able to safely transmute, even with truncation. 
Answer::Yes - } else if src_state == self.src.accepting { + } else if src_state == self.src.accept { // extension: `size_of(Src) >= size_of(Dst)` if let Some(dst_state_prime) = self.dst.byte_from(dst_state, Byte::Uninit) { self.answer_memo(cache, src_state, dst_state_prime) diff --git a/compiler/rustc_transmute/src/maybe_transmutable/tests.rs b/compiler/rustc_transmute/src/maybe_transmutable/tests.rs index 69a6b1b77f4b0..cc6a4dce17b63 100644 --- a/compiler/rustc_transmute/src/maybe_transmutable/tests.rs +++ b/compiler/rustc_transmute/src/maybe_transmutable/tests.rs @@ -126,7 +126,7 @@ mod bool { let into_set = |alts: Vec<_>| { #[cfg(feature = "rustc")] - let mut set = crate::Set::default(); + let mut set = rustc_data_structures::fx::FxIndexSet::default(); #[cfg(not(feature = "rustc"))] let mut set = std::collections::HashSet::new(); set.extend(alts); @@ -174,3 +174,32 @@ mod bool { } } } + +mod union { + use super::*; + + #[test] + fn union() { + let [a, b, c, d] = [0, 1, 2, 3]; + let s = Dfa::from_edges(a, d, &[(a, 0, b), (b, 0, d), (a, 1, c), (c, 1, d)]); + + let t = Dfa::from_edges(a, c, &[(a, 1, b), (b, 0, c)]); + + let mut ctr = 0; + let new_state = || { + let state = crate::layout::dfa::State(ctr); + ctr += 1; + state + }; + + let u = s.clone().union(t.clone(), new_state); + + let expected_u = + Dfa::from_edges(b, a, &[(b, 0, c), (b, 1, d), (d, 1, a), (d, 0, a), (c, 0, a)]); + + assert_eq!(u, expected_u); + + assert_eq!(is_transmutable(&s, &u, Assume::default()), Answer::Yes); + assert_eq!(is_transmutable(&t, &u, Assume::default()), Answer::Yes); + } +}