diff --git a/regex-syntax/src/hir/interval.rs b/regex-syntax/src/hir/interval.rs index d507ee724..f76c940d7 100644 --- a/regex-syntax/src/hir/interval.rs +++ b/regex-syntax/src/hir/interval.rs @@ -1,4 +1,4 @@ -use core::{char, cmp, fmt::Debug, slice}; +use core::{char, cmp, fmt::Debug, mem}; use alloc::vec::Vec; @@ -81,23 +81,26 @@ impl IntervalSet { /// Add a new interval to this set. pub fn push(&mut self, interval: I) { - // TODO: This could be faster. e.g., Push the interval such that - // it preserves canonicalization. - self.ranges.push(interval); - self.canonicalize(); + // Use a binary search to try to find the approximate place this + // interval should go + let point = match self.ranges.binary_search(&interval) { + // We lucked out, this interval already exists in the set. + Ok(_) => return, + Err(point) => point, + }; + + // TODO: A more efficient implementation is possible here, one which + // avoids the unconditional insert and searches only the range covered + // by `interval` when performing the union. + self.ranges.insert(point, interval); + union_sorted(&mut self.ranges); + // We don't know whether the new interval added here is considered // case folded, so we conservatively assume that the entire set is // no longer case folded if it was previously. self.folded = false; } - /// Return an iterator over all intervals in this set. - /// - /// The iterator yields intervals in ascending order. - pub fn iter(&self) -> IntervalSetIter<'_, I> { - IntervalSetIter(self.ranges.iter()) - } - /// Return an immutable slice of intervals in this set. /// /// The sequence returned is in canonical ordering. @@ -134,9 +137,44 @@ impl IntervalSet { if other.ranges.is_empty() || self.ranges == other.ranges { return; } - // This could almost certainly be done more efficiently. 
- self.ranges.extend(&other.ranges); - self.canonicalize(); + + if self.ranges.is_empty() { + self.clone_from(other); + return; + } + + // If our allocated capacity is sufficient to hold both ourself and + // the new range, we just merge the ranges in-place, then canonicalize. + if self.ranges.capacity() >= self.ranges.len() + other.ranges.len() { + merge_sorted_into(&mut self.ranges, other.ranges.iter().copied()); + union_sorted(&mut self.ranges) + } + // Otherwise, build a new vector by merging the two ranges and unioning + // them as we go + else { + // No way to know what the new size will be, so for now we assume that + // in typical cases, the union of a set of classes won't have many + // overlaps. + let mut ranges = + Vec::with_capacity(self.ranges.len() + other.ranges.len()); + + let final_range = + MergeIter::new(self.ranges.iter(), other.ranges.iter()) + .copied() + .reduce(|range, next_range| { + range.union_right(&next_range).unwrap_or_else(|| { + ranges.push(range); + next_range + }) + }); + + if let Some(final_range) = final_range { + ranges.push(final_range); + } + + self.ranges = ranges; + } + self.folded = self.folded && other.folded; } @@ -294,36 +332,70 @@ impl IntervalSet { /// For all `x` where `x` is any element, if `x` was in this set, then it /// will not be in this set after negation. pub fn negate(&mut self) { - if self.ranges.is_empty() { - let (min, max) = (I::Bound::min_value(), I::Bound::max_value()); + let Some(first_range) = self.ranges.first() else { + let (min, max) = (I::Bound::MIN, I::Bound::MAX); self.ranges.push(I::create(min, max)); // The set containing everything must case folded. self.folded = true; return; - } + }; + + // The basic algorithm: replace each interval `[low..high]` with + // `[margin..low]`, and record `high` as the new `margin`. Do all of + // that without making off-by-one errors, and take care that there may + // not be a new interval at [0..?] or at [?..max]. 
+ + // First, take care of the first range; if it's 0.., it has no leftward + // negation, so it's skipped, and its upper bound is used as the first + // leftward margin. + let (margin, skip_first) = match first_range.lower() == I::Bound::MIN { + false => (I::Bound::MIN, false), + true => match first_range.upper().try_increment() { + Some(bound) => (bound, true), + // The current range covers everything, so its negation is the empty set + None => { + self.ranges.clear(); + self.folded = true; + return; + } + }, + }; - // There should be a way to do this in-place with constant memory, - // but I couldn't figure out a simple way to do it. So just append - // the negation to the end of this range, and then drain it before - // we're done. - let drain_end = self.ranges.len(); + let mut left_margin = Some(margin); + let mut left_margin_ref = left_margin.as_ref().unwrap(); + + // Again, we're replacing each range with its leftward negation. If + // skip_first is true, then the first range HAS no leftward negation, + // and everything else is shifted to the left one slot. + let start = if skip_first { 1 } else { 0 }; + + for index in start..self.ranges.len() { + let dest_index = if skip_first { index - 1 } else { index }; + + let start = *left_margin_ref; + let end = self.ranges[index].lower().decrement(); + left_margin = self.ranges[index].upper().try_increment(); + + self.ranges[dest_index] = I::create(start, end); - // We do checked arithmetic below because of the canonical ordering - // invariant. 
- if self.ranges[0].lower() > I::Bound::min_value() { - let upper = self.ranges[0].lower().decrement(); - self.ranges.push(I::create(I::Bound::min_value(), upper)); + left_margin_ref = match left_margin.as_ref() { + Some(margin) => margin, + None => break, + } } - for i in 1..drain_end { - let lower = self.ranges[i - 1].upper().increment(); - let upper = self.ranges[i].lower().decrement(); - self.ranges.push(I::create(lower, upper)); + + // If we skipped the first range, then all of the subsequent ranges + // were stored one slot to the left. Pop the last slot. + if skip_first { + self.ranges.pop(); } - if self.ranges[drain_end - 1].upper() < I::Bound::max_value() { - let lower = self.ranges[drain_end - 1].upper().increment(); - self.ranges.push(I::create(lower, I::Bound::max_value())); + + // If there's a final margin, we need to add an extra rightward negation, + // covering everything on that side. + if let Some(left_margin) = left_margin { + self.ranges.push(I::create(left_margin, I::Bound::MAX)); } - self.ranges.drain(..drain_end); + // We don't need to update whether this set is folded or not, because // it is conservatively preserved through negation. Namely, if a set // is not folded, then it is possible that its negation is folded, for @@ -344,27 +416,10 @@ impl IntervalSet { if self.is_canonical() { return; } - self.ranges.sort(); - assert!(!self.ranges.is_empty()); - // Is there a way to do this in-place with constant memory? I couldn't - // figure out a way to do it. So just append the canonicalization to - // the end of this range, and then drain it before we're done. - let drain_end = self.ranges.len(); - for oldi in 0..drain_end { - // If we've added at least one new range, then check if we can - // merge this range in the previously added range. 
- if self.ranges.len() > drain_end { - let (last, rest) = self.ranges.split_last_mut().unwrap(); - if let Some(union) = last.union(&rest[oldi]) { - *last = union; - continue; - } - } - let range = self.ranges[oldi]; - self.ranges.push(range); - } - self.ranges.drain(..drain_end); + self.ranges.sort_unstable(); + assert!(!self.ranges.is_empty()); + union_sorted(&mut self.ranges) } /// Returns true if and only if this class is in a canonical ordering. @@ -381,18 +436,6 @@ impl IntervalSet { } } -/// An iterator over intervals. -#[derive(Debug)] -pub struct IntervalSetIter<'a, I>(slice::Iter<'a, I>); - -impl<'a, I> Iterator for IntervalSetIter<'a, I> { - type Item = &'a I; - - fn next(&mut self) -> Option<&'a I> { - self.0.next() - } -} - pub trait Interval: Clone + Copy + Debug + Default + Eq + PartialEq + PartialOrd + Ord { @@ -420,16 +463,21 @@ pub trait Interval: int } - /// Union the given overlapping range into this range. + /// Union the given overlapping range into this range, assuming that + /// self.begin <= right.begin. Useful for performing a series of unions + /// on a sorted list of intervals. /// - /// If the two ranges aren't contiguous, then this returns `None`. - fn union(&self, other: &Self) -> Option { - if !self.is_contiguous(other) { - return None; + /// If the ranges aren't contiguous, this returns `None`. + /// Returns unspecified garbage if self.begin > right.begin. + fn union_right(&self, right: &Self) -> Option { + if self.upper().as_u32().saturating_add(1) < right.lower().as_u32() { + None + } else { + Some(Self::create( + self.lower(), + cmp::max(self.upper(), right.upper()), + )) } - let lower = cmp::min(self.lower(), other.lower()); - let upper = cmp::max(self.upper(), other.upper()); - Some(Self::create(lower, upper)) } /// Intersect this range with the given range and return the result. 
@@ -510,55 +558,210 @@ pub trait Interval: pub trait Bound: Copy + Clone + Debug + Eq + PartialEq + PartialOrd + Ord { - fn min_value() -> Self; - fn max_value() -> Self; + const MIN: Self; + const MAX: Self; + fn as_u32(self) -> u32; - fn increment(self) -> Self; - fn decrement(self) -> Self; -} + fn try_increment(self) -> Option; + fn try_decrement(self) -> Option; -impl Bound for u8 { - fn min_value() -> Self { - u8::MIN + fn increment(self) -> Self { + self.try_increment().unwrap() } - fn max_value() -> Self { - u8::MAX + + fn decrement(self) -> Self { + self.try_decrement().unwrap() } +} + +impl Bound for u8 { + const MIN: Self = u8::MIN; + const MAX: Self = u8::MAX; + fn as_u32(self) -> u32 { u32::from(self) } - fn increment(self) -> Self { - self.checked_add(1).unwrap() + + fn try_increment(self) -> Option { + self.checked_add(1) } - fn decrement(self) -> Self { - self.checked_sub(1).unwrap() + + fn try_decrement(self) -> Option { + self.checked_sub(1) } } impl Bound for char { - fn min_value() -> Self { - '\x00' - } - fn max_value() -> Self { - '\u{10FFFF}' - } + const MIN: Self = '\x00'; + const MAX: Self = '\u{10FFFF}'; + fn as_u32(self) -> u32 { u32::from(self) } - fn increment(self) -> Self { + fn try_increment(self) -> Option { match self { - '\u{D7FF}' => '\u{E000}', - c => char::from_u32(u32::from(c).checked_add(1).unwrap()).unwrap(), + '\u{D7FF}' => Some('\u{E000}'), + c => char::from_u32(u32::from(c).checked_add(1)?), } } - fn decrement(self) -> Self { + fn try_decrement(self) -> Option { match self { - '\u{E000}' => '\u{D7FF}', - c => char::from_u32(u32::from(c).checked_sub(1).unwrap()).unwrap(), + '\u{E000}' => Some('\u{D7FF}'), + c => char::from_u32(u32::from(c).checked_sub(1)?), + } + } +} + +/// Iterator that, given a pair of sorted iterators, merges their items into +/// a single sorted sequence +struct MergeIter { + left: I, + right: I, + + state: MergeIterState, +} + +impl MergeIter +where + I: Iterator, + I::Item: Ord, +{ + pub fn new(mut 
left: I, right: I) -> Self { + let state = match left.next() { + Some(item) => MergeIterState::LeftItem(item), + None => MergeIterState::LeftExhausted, + }; + + Self { left, right, state } + } +} + +enum MergeIterState { + LeftExhausted, + RightExhausted, + LeftItem(T), + RightItem(T), +} + +impl MergeIterState { + // Get the current state, and if it's an item, replace the state with + // the appropriate exhaustion state for the other side. + fn step(&mut self) -> Self { + use MergeIterState::*; + + match *self { + LeftItem(_) => mem::replace(self, RightExhausted), + RightItem(_) => mem::replace(self, LeftExhausted), + LeftExhausted => LeftExhausted, + RightExhausted => RightExhausted, + } + } +} + +impl Iterator for MergeIter +where + I: Iterator, + I::Item: Ord, +{ + type Item = I::Item; + + fn next(&mut self) -> Option { + use MergeIterState::*; + + let (left, right) = match self.state.step() { + LeftExhausted => return self.right.next(), + RightExhausted => return self.left.next(), + + LeftItem(left) => match self.right.next() { + Some(right) => (left, right), + None => return Some(left), + }, + RightItem(right) => match self.left.next() { + Some(left) => (left, right), + None => return Some(right), + }, + }; + + let (item, state) = match left <= right { + true => (left, RightItem(right)), + false => (right, LeftItem(left)), + }; + + self.state = state; + Some(item) + } + + fn size_hint(&self) -> (usize, Option) { + use MergeIterState::*; + + match self.state { + LeftExhausted => return self.right.size_hint(), + RightExhausted => return self.left.size_hint(), + LeftItem(_) | RightItem(_) => {} + } + + // Fundamentally this is a spicy concatenation, so add the sizes together + let (min1, max1) = self.left.size_hint(); + let (min2, max2) = self.right.size_hint(); + + let min = min1.saturating_add(min2).saturating_add(1); + let max = max1 + .and_then(|max| max.checked_add(max2?)) + .and_then(|max| max.checked_add(1)); + + (min, max) + } +} + +/// Given a pair of 
sorted lists, merge them into `dest` so that `dest` +/// remains sorted +fn merge_sorted_into( + dest: &mut Vec, + others: impl DoubleEndedIterator + ExactSizeIterator, +) { + let mut dest_len = dest.len(); + let mut insert_idx = dest.len() + others.len(); + + dest.resize_with(dest.len() + others.len(), Default::default); + + others.rev().for_each(|new_item| { + // First, shift all the items that are ``> new_item`` rightward + // in the vec + for item_idx in (0..dest_len).rev() { + if dest[item_idx] > new_item { + dest_len -= 1; + insert_idx -= 1; + dest.swap(item_idx, insert_idx); + } else { + break; + } + } + + // Then insert this item + insert_idx -= 1; + dest[insert_idx] = new_item; + }); +} + +// Given a sorted list of intervals, union them together into a canonical form. +fn union_sorted(ranges: &mut Vec) { + // `merge_idx` is the range into which we're merging contiguous ranges. + let mut merge_idx = 0; + + for i in 1..ranges.len() { + if let Some(union) = ranges[merge_idx].union_right(&ranges[i]) { + ranges[merge_idx] = union; + } else { + merge_idx += 1; + ranges[merge_idx] = ranges[i]; } } + + // At this point, `merge_idx` is the index of the last range that was + // merged into, so we truncate. + ranges.truncate(merge_idx + 1); } // Tests for interval sets are written in src/hir.rs against the public API. diff --git a/regex-syntax/src/hir/mod.rs b/regex-syntax/src/hir/mod.rs index ae3ba318e..cb27ef067 100644 --- a/regex-syntax/src/hir/mod.rs +++ b/regex-syntax/src/hir/mod.rs @@ -17,7 +17,7 @@ equivalent regex pattern string, it is unlikely to look like the original due to its simplified structure. */ -use core::{char, cmp}; +use core::{char, cmp, slice}; use alloc::{ boxed::Box, @@ -29,7 +29,7 @@ use alloc::{ use crate::{ ast::Span, - hir::interval::{Interval, IntervalSet, IntervalSetIter}, + hir::interval::{Interval, IntervalSet}, unicode, }; @@ -1086,7 +1086,7 @@ impl ClassUnicode { /// /// The iterator yields ranges in ascending order. 
pub fn iter(&self) -> ClassUnicodeIter<'_> { - ClassUnicodeIter(self.set.iter()) + ClassUnicodeIter(self.set.intervals().iter()) } /// Return the underlying ranges as a slice. @@ -1227,7 +1227,7 @@ impl ClassUnicode { /// /// The lifetime `'a` refers to the lifetime of the underlying class. #[derive(Debug)] -pub struct ClassUnicodeIter<'a>(IntervalSetIter<'a, ClassUnicodeRange>); +pub struct ClassUnicodeIter<'a>(slice::Iter<'a, ClassUnicodeRange>); impl<'a> Iterator for ClassUnicodeIter<'a> { type Item = &'a ClassUnicodeRange; @@ -1235,6 +1235,10 @@ impl<'a> Iterator for ClassUnicodeIter<'a> { fn next(&mut self) -> Option<&'a ClassUnicodeRange> { self.0.next() } + + fn size_hint(&self) -> (usize, Option) { + self.0.size_hint() + } } /// A single range of characters represented by Unicode scalar values. @@ -1385,7 +1389,7 @@ impl ClassBytes { /// /// The iterator yields ranges in ascending order. pub fn iter(&self) -> ClassBytesIter<'_> { - ClassBytesIter(self.set.iter()) + ClassBytesIter(self.set.intervals().iter()) } /// Return the underlying ranges as a slice. @@ -1505,7 +1509,7 @@ impl ClassBytes { /// /// The lifetime `'a` refers to the lifetime of the underlying class. #[derive(Debug)] -pub struct ClassBytesIter<'a>(IntervalSetIter<'a, ClassBytesRange>); +pub struct ClassBytesIter<'a>(slice::Iter<'a, ClassBytesRange>); impl<'a> Iterator for ClassBytesIter<'a> { type Item = &'a ClassBytesRange; @@ -1513,6 +1517,10 @@ impl<'a> Iterator for ClassBytesIter<'a> { fn next(&mut self) -> Option<&'a ClassBytesRange> { self.0.next() } + + fn size_hint(&self) -> (usize, Option) { + self.0.size_hint() + } } /// A single range of characters represented by arbitrary bytes.