Commit a2b2010

- convert from core::arch to core::simd
- bump simd compare to 32 bytes
- import small slice compare code from the memmem crate
- try a few different probe bytes to avoid degenerate cases, but special-case 2-byte needles
1 parent c37e8fa commit a2b2010
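
The probe-byte selection mentioned in the commit message can be illustrated with a small standalone sketch. pick_second_probe_offset is a hypothetical helper written for this page (it is not part of the commit), but it mirrors the offset search the new simd_contains performs: look at the last few needle bytes for one that differs from the first byte, so that repetitive needles do not make both probes fire on every haystack position.

    fn pick_second_probe_offset(needle: &[u8]) -> Option<usize> {
        let first = needle[0];
        if needle.len() == 2 {
            // a 2-byte needle is fully covered by the two probes, no degenerate case
            return Some(1);
        }
        // prefer the latest of the last four bytes that differs from the first byte
        (needle.len().saturating_sub(4)..needle.len()).rfind(|&idx| needle[idx] != first)
    }

    fn main() {
        assert_eq!(pick_second_probe_offset(b"ab"), Some(1));
        assert_eq!(pick_second_probe_offset(b"abcda"), Some(3)); // 'd' differs from 'a'
        assert_eq!(pick_second_probe_offset(b"aaaa"), None); // caller falls back to two-way search
    }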

1 file changed: +182 -52

library/core/src/str/pattern.rs (+182 -52)

@@ -956,15 +956,20 @@ impl<'a, 'b> Pattern<'a> for &'b str {

         match self.len().cmp(&haystack.len()) {
             Ordering::Less => {
+                if self.len() == 1 {
+                    return haystack.as_bytes().contains(&self.as_bytes()[0]);
+                }
+
                 #[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
-                if self.as_bytes().len() <= 8 {
-                    return simd_contains(self, haystack);
+                if self.len() <= 32 {
+                    if let Some(result) = simd_contains(self, haystack) {
+                        return result;
+                    }
                 }

                 self.into_searcher(haystack).next_match().is_some()
             }
-            Ordering::Equal => self == haystack,
-            Ordering::Greater => false,
+            _ => self == haystack,
         }
     }

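As a reading aid for the hunk above, the following is a hedged, self-contained sketch of the new dispatch written as free functions for illustration (contains_sketch and simd_contains_sketch are invented names, not std's API): one-byte needles use the byte search, needles of up to 32 bytes try the SIMD path, and a None result falls through to the scalar search.

    use core::cmp::Ordering;

    fn contains_sketch(haystack: &str, needle: &str) -> bool {
        if needle.is_empty() {
            // the real method returns true for empty needles before reaching this match
            return true;
        }
        match needle.len().cmp(&haystack.len()) {
            Ordering::Less => {
                if needle.len() == 1 {
                    return haystack.as_bytes().contains(&needle.as_bytes()[0]);
                }
                if needle.len() <= 32 {
                    if let Some(result) = simd_contains_sketch(needle, haystack) {
                        return result;
                    }
                }
                // stand-in for the two-way searcher fallback in the real code
                haystack.as_bytes().windows(needle.len()).any(|w| w == needle.as_bytes())
            }
            // equal lengths compare directly; a longer needle can never match
            _ => needle == haystack,
        }
    }

    // hypothetical stand-in that always defers to the scalar fallback, like the
    // real simd_contains does when it cannot find two distinct probe bytes
    fn simd_contains_sketch(_needle: &str, _haystack: &str) -> Option<bool> {
        None
    }
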
@@ -1707,82 +1712,207 @@ impl TwoWayStrategy for RejectAndMatch {
     }
 }

+/// SIMD search for short needles based on
+/// Wojciech Muła's "SIMD-friendly algorithms for substring searching"[0]
+///
+/// It skips ahead by the vector width on each iteration (rather than the needle length as two-way
+/// does) by probing the first and last byte of the needle for the whole vector width
+/// and only doing full needle comparisons when the vectorized probe indicated potential matches.
+///
+/// Since the x86_64 baseline only offers SSE2 we only use u8x16 here.
+/// If we ever ship std with for x86-64-v3 or adapt this for other platforms then wider vectors
+/// should be evaluated.
+///
+/// For haystacks smaller than vector-size + needle length it falls back to
+/// a naive O(n*m) search so this implementation should not be called on larger needles.
+///
+/// [0]: http://0x80.pl/articles/simd-strfind.html#sse-avx2
 #[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
 #[inline]
-fn simd_contains(needle: &str, haystack: &str) -> bool {
+fn simd_contains(needle: &str, haystack: &str) -> Option<bool> {
     let needle = needle.as_bytes();
     let haystack = haystack.as_bytes();

-    if needle.len() == 1 {
-        return haystack.contains(&needle[0]);
-    }
-
-    const CHUNK: usize = 16;
+    debug_assert!(needle.len() > 1);
+
+    use crate::ops::BitAnd;
+    use crate::simd::mask8x16 as Mask;
+    use crate::simd::u8x16 as Block;
+    use crate::simd::{SimdPartialEq, ToBitMask};
+
+    let first_probe = needle[0];
+
+    // the offset used for the 2nd vector
+    let second_probe_offset = if needle.len() == 2 {
+        // never bail out on len=2 needles because the probes will fully cover them and have
+        // no degenerate cases.
+        1
+    } else {
+        // try a few bytes in case first and last byte of the needle are the same
+        let Some(second_probe_offset) = (needle.len().saturating_sub(4)..needle.len()).rfind(|&idx| needle[idx] != first_probe) else {
+            // fall back to other search methods if we can't find any different bytes
+            // since we could otherwise hit some degenerate cases
+            return None;
+        };
+        second_probe_offset
+    };

-    // do a naive search if if the haystack is too small to fit
-    if haystack.len() < CHUNK + needle.len() - 1 {
-        return haystack.windows(needle.len()).any(|c| c == needle);
+    // do a naive search if the haystack is too small to fit
+    if haystack.len() < Block::LANES + second_probe_offset {
+        return Some(haystack.windows(needle.len()).any(|c| c == needle));
     }

-    use crate::arch::x86_64::{
-        __m128i, _mm_and_si128, _mm_cmpeq_epi8, _mm_loadu_si128, _mm_movemask_epi8, _mm_set1_epi8,
-    };
-
-    // SAFETY: no preconditions other than sse2 being available
-    let first: __m128i = unsafe { _mm_set1_epi8(needle[0] as i8) };
-    // SAFETY: no preconditions other than sse2 being available
-    let last: __m128i = unsafe { _mm_set1_epi8(*needle.last().unwrap() as i8) };
+    let first_probe: Block = Block::splat(first_probe);
+    let second_probe: Block = Block::splat(needle[second_probe_offset]);
+    // first byte are already checked by the outer loop. to verify a match only the
+    // remainder has to be compared.
+    let trimmed_needle = &needle[1..];

+    // this #[cold] is load-bearing, benchmark before removing it...
     let check_mask = #[cold]
-    |idx, mut mask: u32| -> bool {
+    |idx, mask: u16, skip: bool| -> bool {
+        if skip {
+            return false;
+        }
+
+        // and so is this. optimizations are weird.
+        let mut mask = mask;
+
         while mask != 0 {
             let trailing = mask.trailing_zeros();
             let offset = idx + trailing as usize + 1;
-            let sub = &haystack[offset..][..needle.len() - 2];
-            let trimmed_needle = &needle[1..needle.len() - 1];
-
-            if sub == trimmed_needle {
-                return true;
+            // SAFETY: mask is between 0 and 15 trailing zeroes, we skip one additional byte that was already compared
+            // and then take trimmed_needle.len() bytes. This is within the bounds defined by the outer loop
+            unsafe {
+                let sub = haystack.get_unchecked(offset..).get_unchecked(..trimmed_needle.len());
+                if small_slice_eq(sub, trimmed_needle) {
+                    return true;
+                }
             }
             mask &= !(1 << trailing);
         }
         return false;
     };

-    let test_chunk = |i| -> bool {
-        // SAFETY: this requires at least CHUNK bytes being readable at offset i
+    let test_chunk = |idx| -> u16 {
+        // SAFETY: this requires at least LANES bytes being readable at idx
         // that is ensured by the loop ranges (see comments below)
-        let a: __m128i = unsafe { _mm_loadu_si128(haystack.as_ptr().add(i) as *const _) };
-        let b: __m128i =
-            // SAFETY: this requires CHUNK + needle.len() - 1 bytes being readable at offset i
-            unsafe { _mm_loadu_si128(haystack.as_ptr().add(i + needle.len() - 1) as *const _) };
-
-        // SAFETY: no preconditions other than sse2 being available
-        let eq_first: __m128i = unsafe { _mm_cmpeq_epi8(first, a) };
-        // SAFETY: no preconditions other than sse2 being available
-        let eq_last: __m128i = unsafe { _mm_cmpeq_epi8(last, b) };
-
-        // SAFETY: no preconditions other than sse2 being available
-        let mask: u32 = unsafe { _mm_movemask_epi8(_mm_and_si128(eq_first, eq_last)) } as u32;
+        let a: Block = unsafe { haystack.as_ptr().add(idx).cast::<Block>().read_unaligned() };
+        // SAFETY: this requires LANES + block_offset bytes being readable at idx
+        let b: Block = unsafe {
+            haystack.as_ptr().add(idx).add(second_probe_offset).cast::<Block>().read_unaligned()
+        };
+        let eq_first: Mask = a.simd_eq(first_probe);
+        let eq_last: Mask = b.simd_eq(second_probe);
+        let both = eq_first.bitand(eq_last);
+        let mask = both.to_bitmask();

-        if mask != 0 {
-            return check_mask(i, mask);
-        }
-        return false;
+        return mask;
     };

     let mut i = 0;
     let mut result = false;
-    while !result && i + CHUNK + needle.len() <= haystack.len() {
-        result |= test_chunk(i);
-        i += CHUNK;
+    // The loop condition must ensure that there's enough headroom to read LANE bytes,
+    // and not only at the current index but also at the index shifted by block_offset
+    const UNROLL: usize = 4;
+    while i + second_probe_offset + UNROLL * Block::LANES < haystack.len() && !result {
+        let mut masks = [0u16; UNROLL];
+        for j in 0..UNROLL {
+            masks[j] = test_chunk(i + j * Block::LANES);
+        }
+        for j in 0..UNROLL {
+            let mask = masks[j];
+            if mask != 0 {
+                result |= check_mask(i + j * Block::LANES, mask, result);
+            }
+        }
+        i += UNROLL * Block::LANES;
+    }
+    while i + second_probe_offset + Block::LANES < haystack.len() && !result {
+        let mask = test_chunk(i);
+        if mask != 0 {
+            result |= check_mask(i, mask, result);
+        }
+        i += Block::LANES;
     }

-    // process the tail that didn't fit into CHUNK-sized steps
-    // this simply repeats the same procedure but as right-aligned chunk instead
+    // Process the tail that didn't fit into LANES-sized steps.
+    // This simply repeats the same procedure but as right-aligned chunk instead
     // of a left-aligned one. The last byte must be exactly flush with the string end so
     // we don't miss a single byte or read out of bounds.
-    result |= test_chunk(haystack.len() + 1 - needle.len() - CHUNK);
+    let i = haystack.len() - second_probe_offset - Block::LANES;
+    let mask = test_chunk(i);
+    if mask != 0 {
+        result |= check_mask(i, mask, result);
+    }
+
+    Some(result)
+}
+
+/// Compares short slices for equality.
+///
+/// It avoids a call to libc's memcmp which is faster on long slices
+/// due to SIMD optimizations but it incurs a function call overhead.
+///
+/// # Safety
+///
+/// Both slices must have the same length.
+#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))] // only called on x86
+#[inline]
+unsafe fn small_slice_eq(x: &[u8], y: &[u8]) -> bool {
+    // This function is adapted from
+    // https://github.com/BurntSushi/memchr/blob/8037d11b4357b0f07be2bb66dc2659d9cf28ad32/src/memmem/util.rs#L32

-    return result;
+    // If we don't have enough bytes to do 4-byte at a time loads, then
+    // fall back to the naive slow version.
+    //
+    // Potential alternative: We could do a copy_nonoverlapping combined with a mask instead
+    // of a loop. Benchmark it.
+    if x.len() < 4 {
+        for (&b1, &b2) in x.iter().zip(y) {
+            if b1 != b2 {
+                return false;
+            }
+        }
+        return true;
+    }
+    // When we have 4 or more bytes to compare, then proceed in chunks of 4 at
+    // a time using unaligned loads.
+    //
+    // Also, why do 4 byte loads instead of, say, 8 byte loads? The reason is
+    // that this particular version of memcmp is likely to be called with tiny
+    // needles. That means that if we do 8 byte loads, then a higher proportion
+    // of memcmp calls will use the slower variant above. With that said, this
+    // is a hypothesis and is only loosely supported by benchmarks. There's
+    // likely some improvement that could be made here. The main thing here
+    // though is to optimize for latency, not throughput.
+
+    // SAFETY: Via the conditional above, we know that both `px` and `py`
+    // have the same length, so `px < pxend` implies that `py < pyend`.
+    // Thus, derefencing both `px` and `py` in the loop below is safe.
+    //
+    // Moreover, we set `pxend` and `pyend` to be 4 bytes before the actual
+    // end of of `px` and `py`. Thus, the final dereference outside of the
+    // loop is guaranteed to be valid. (The final comparison will overlap with
+    // the last comparison done in the loop for lengths that aren't multiples
+    // of four.)
+    //
+    // Finally, we needn't worry about alignment here, since we do unaligned
+    // loads.
+    unsafe {
+        let (mut px, mut py) = (x.as_ptr(), y.as_ptr());
+        let (pxend, pyend) = (px.add(x.len() - 4), py.add(y.len() - 4));
+        while px < pxend {
+            let vx = (px as *const u32).read_unaligned();
+            let vy = (py as *const u32).read_unaligned();
+            if vx != vy {
+                return false;
+            }
+            px = px.add(4);
+            py = py.add(4);
+        }
+        let vx = (pxend as *const u32).read_unaligned();
+        let vy = (pyend as *const u32).read_unaligned();
+        vx == vy
+    }
 }
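
To make the algorithm described in the new doc comment concrete, here is a hedged scalar model of the probe-and-verify scheme (written for this page, not taken from the commit): both probe bytes are checked for a window of positions to build a candidate bitmask, and only positions where both probes hit are compared in full. The committed code performs the 16 probe comparisons at once with u8x16 and to_bitmask(), and may pick a second probe byte earlier than the last one; this sketch fixes the second probe to the last byte to keep the bounds simple.

    fn probe_and_verify(needle: &[u8], haystack: &[u8]) -> bool {
        assert!(needle.len() >= 2 && needle.len() <= haystack.len());
        let first = needle[0];
        let second_offset = needle.len() - 1; // the committed code may choose an earlier byte
        let second = needle[second_offset];

        let mut i = 0;
        while i + second_offset < haystack.len() {
            // "vectorized" probe: the SIMD version does these comparisons 16 at a time
            // and reads the candidate positions out of a 16-bit mask
            let lanes = (haystack.len() - second_offset - i).min(16);
            let mut mask: u16 = 0;
            for lane in 0..lanes {
                if haystack[i + lane] == first && haystack[i + lane + second_offset] == second {
                    mask |= 1u16 << lane;
                }
            }
            // walk the set bits and verify each candidate in full
            while mask != 0 {
                let lane = mask.trailing_zeros() as usize;
                if &haystack[i + lane..i + lane + needle.len()] == needle {
                    return true;
                }
                mask &= mask - 1; // clear the lowest set bit
            }
            i += lanes;
        }
        false
    }

    fn main() {
        assert!(probe_and_verify(b"nana", b"bananas"));
        assert!(!probe_and_verify(b"nanna", b"bananas"));
    }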
