@@ -956,15 +956,20 @@ impl<'a, 'b> Pattern<'a> for &'b str {
 
         match self.len().cmp(&haystack.len()) {
             Ordering::Less => {
+                if self.len() == 1 {
+                    return haystack.as_bytes().contains(&self.as_bytes()[0]);
+                }
+
                 #[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
-                if self.as_bytes().len() <= 8 {
-                    return simd_contains(self, haystack);
+                if self.len() <= 32 {
+                    if let Some(result) = simd_contains(self, haystack) {
+                        return result;
+                    }
                 }
 
                 self.into_searcher(haystack).next_match().is_some()
             }
-            Ordering::Equal => self == haystack,
-            Ordering::Greater => false,
+            _ => self == haystack,
         }
     }
 
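The first hunk only touches the dispatch in `is_contained_in`: single-byte needles go straight to a byte search, needles of 2 to 32 bytes may take the new SIMD path (which can decline and return `None`, falling through to the two-way searcher), and the `Equal`/`Greater` arms collapse into `_ => self == haystack`, which stays correct because a needle longer than the haystack can never compare equal to it. A minimal sketch of the observable behaviour, using made-up strings and only the public `str::contains` API:

fn main() {
    let haystack = "the quick brown fox jumps over the lazy dog";
    assert!(haystack.contains("q"));          // len == 1: byte search fast path
    assert!(haystack.contains("brown fox"));  // 2..=32 bytes: SIMD-eligible path
    assert!(haystack.contains(haystack));     // equal length: direct comparison
    assert!(!"fox".contains(haystack));       // needle longer than haystack: false
}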
@@ -1707,82 +1712,207 @@ impl TwoWayStrategy for RejectAndMatch {
     }
 }
 
+/// SIMD search for short needles based on
+/// Wojciech Muła's "SIMD-friendly algorithms for substring searching"[0]
+///
+/// It skips ahead by the vector width on each iteration (rather than the needle length as two-way
+/// does) by probing the first and last byte of the needle for the whole vector width
+/// and only doing full needle comparisons when the vectorized probe indicated potential matches.
+///
+/// Since the x86_64 baseline only offers SSE2 we only use u8x16 here.
+/// If we ever ship std builds for x86-64-v3 or adapt this for other platforms then wider vectors
+/// should be evaluated.
+///
+/// For haystacks smaller than vector-size + needle length it falls back to
+/// a naive O(n*m) search so this implementation should not be called on larger needles.
+///
+/// [0]: http://0x80.pl/articles/simd-strfind.html#sse-avx2
 #[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
 #[inline]
-fn simd_contains(needle: &str, haystack: &str) -> bool {
+fn simd_contains(needle: &str, haystack: &str) -> Option<bool> {
     let needle = needle.as_bytes();
     let haystack = haystack.as_bytes();
 
-    if needle.len() == 1 {
-        return haystack.contains(&needle[0]);
-    }
-
-    const CHUNK: usize = 16;
+    debug_assert!(needle.len() > 1);
+
+    use crate::ops::BitAnd;
+    use crate::simd::mask8x16 as Mask;
+    use crate::simd::u8x16 as Block;
+    use crate::simd::{SimdPartialEq, ToBitMask};
+
+    let first_probe = needle[0];
+
+    // the offset used for the 2nd vector
+    let second_probe_offset = if needle.len() == 2 {
+        // never bail out on len=2 needles because the probes will fully cover them and have
+        // no degenerate cases.
+        1
+    } else {
+        // try a few bytes in case first and last byte of the needle are the same
+        let Some(second_probe_offset) = (needle.len().saturating_sub(4)..needle.len()).rfind(|&idx| needle[idx] != first_probe) else {
+            // fall back to other search methods if we can't find any different bytes
+            // since we could otherwise hit some degenerate cases
+            return None;
+        };
+        second_probe_offset
+    };
 
-    // do a naive search if if the haystack is too small to fit
-    if haystack.len() < CHUNK + needle.len() - 1 {
-        return haystack.windows(needle.len()).any(|c| c == needle);
+    // do a naive search if the haystack is too small to fit
+    if haystack.len() < Block::LANES + second_probe_offset {
+        return Some(haystack.windows(needle.len()).any(|c| c == needle));
     }
 
-    use crate::arch::x86_64::{
-        __m128i, _mm_and_si128, _mm_cmpeq_epi8, _mm_loadu_si128, _mm_movemask_epi8, _mm_set1_epi8,
-    };
-
-    // SAFETY: no preconditions other than sse2 being available
-    let first: __m128i = unsafe { _mm_set1_epi8(needle[0] as i8) };
-    // SAFETY: no preconditions other than sse2 being available
-    let last: __m128i = unsafe { _mm_set1_epi8(*needle.last().unwrap() as i8) };
+    let first_probe: Block = Block::splat(first_probe);
+    let second_probe: Block = Block::splat(needle[second_probe_offset]);
+    // the first byte is already checked by the outer loop. to verify a match only the
+    // remainder has to be compared.
+    let trimmed_needle = &needle[1..];
 
+    // this #[cold] is load-bearing, benchmark before removing it...
     let check_mask = #[cold]
-    |idx, mut mask: u32| -> bool {
+    |idx, mask: u16, skip: bool| -> bool {
+        if skip {
+            return false;
+        }
+
+        // and so is this. optimizations are weird.
+        let mut mask = mask;
+
         while mask != 0 {
             let trailing = mask.trailing_zeros();
             let offset = idx + trailing as usize + 1;
-            let sub = &haystack[offset..][..needle.len() - 2];
-            let trimmed_needle = &needle[1..needle.len() - 1];
-
-            if sub == trimmed_needle {
-                return true;
+            // SAFETY: mask is between 0 and 15 trailing zeroes, we skip one additional byte that was already compared
+            // and then take trimmed_needle.len() bytes. This is within the bounds defined by the outer loop
+            unsafe {
+                let sub = haystack.get_unchecked(offset..).get_unchecked(..trimmed_needle.len());
+                if small_slice_eq(sub, trimmed_needle) {
+                    return true;
+                }
             }
             mask &= !(1 << trailing);
         }
         return false;
     };
 
-    let test_chunk = |i| -> bool {
-        // SAFETY: this requires at least CHUNK bytes being readable at offset i
+    let test_chunk = |idx| -> u16 {
+        // SAFETY: this requires at least LANES bytes being readable at idx
         // that is ensured by the loop ranges (see comments below)
-        let a: __m128i = unsafe { _mm_loadu_si128(haystack.as_ptr().add(i) as *const _) };
-        let b: __m128i =
-            // SAFETY: this requires CHUNK + needle.len() - 1 bytes being readable at offset i
-            unsafe { _mm_loadu_si128(haystack.as_ptr().add(i + needle.len() - 1) as *const _) };
-
-        // SAFETY: no preconditions other than sse2 being available
-        let eq_first: __m128i = unsafe { _mm_cmpeq_epi8(first, a) };
-        // SAFETY: no preconditions other than sse2 being available
-        let eq_last: __m128i = unsafe { _mm_cmpeq_epi8(last, b) };
-
-        // SAFETY: no preconditions other than sse2 being available
-        let mask: u32 = unsafe { _mm_movemask_epi8(_mm_and_si128(eq_first, eq_last)) } as u32;
+        let a: Block = unsafe { haystack.as_ptr().add(idx).cast::<Block>().read_unaligned() };
+        // SAFETY: this requires LANES + second_probe_offset bytes being readable at idx
+        let b: Block = unsafe {
+            haystack.as_ptr().add(idx).add(second_probe_offset).cast::<Block>().read_unaligned()
+        };
+        let eq_first: Mask = a.simd_eq(first_probe);
+        let eq_last: Mask = b.simd_eq(second_probe);
+        let both = eq_first.bitand(eq_last);
+        let mask = both.to_bitmask();
 
-        if mask != 0 {
-            return check_mask(i, mask);
-        }
-        return false;
+        return mask;
     };
 
     let mut i = 0;
     let mut result = false;
-    while !result && i + CHUNK + needle.len() <= haystack.len() {
-        result |= test_chunk(i);
-        i += CHUNK;
+    // The loop condition must ensure that there's enough headroom to read LANES bytes,
+    // and not only at the current index but also at the index shifted by second_probe_offset
+    const UNROLL: usize = 4;
+    while i + second_probe_offset + UNROLL * Block::LANES < haystack.len() && !result {
+        let mut masks = [0u16; UNROLL];
+        for j in 0..UNROLL {
+            masks[j] = test_chunk(i + j * Block::LANES);
+        }
+        for j in 0..UNROLL {
+            let mask = masks[j];
+            if mask != 0 {
+                result |= check_mask(i + j * Block::LANES, mask, result);
+            }
+        }
+        i += UNROLL * Block::LANES;
+    }
+    while i + second_probe_offset + Block::LANES < haystack.len() && !result {
+        let mask = test_chunk(i);
+        if mask != 0 {
+            result |= check_mask(i, mask, result);
+        }
+        i += Block::LANES;
     }
 
-    // process the tail that didn't fit into CHUNK-sized steps
-    // this simply repeats the same procedure but as right-aligned chunk instead
+    // Process the tail that didn't fit into LANES-sized steps.
+    // This simply repeats the same procedure but as right-aligned chunk instead
     // of a left-aligned one. The last byte must be exactly flush with the string end so
     // we don't miss a single byte or read out of bounds.
-    result |= test_chunk(haystack.len() + 1 - needle.len() - CHUNK);
+    let i = haystack.len() - second_probe_offset - Block::LANES;
+    let mask = test_chunk(i);
+    if mask != 0 {
+        result |= check_mask(i, mask, result);
+    }
+
+    Some(result)
+}
+
+/// Compares short slices for equality.
+///
+/// It avoids a call to libc's memcmp, which is faster on long slices
+/// due to SIMD optimizations but incurs a function call overhead.
+///
+/// # Safety
+///
+/// Both slices must have the same length.
+#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))] // only called on x86
+#[inline]
+unsafe fn small_slice_eq(x: &[u8], y: &[u8]) -> bool {
+    // This function is adapted from
+    // https://github.com/BurntSushi/memchr/blob/8037d11b4357b0f07be2bb66dc2659d9cf28ad32/src/memmem/util.rs#L32
 
-    return result;
+    // If we don't have enough bytes to do 4-byte at a time loads, then
+    // fall back to the naive slow version.
+    //
+    // Potential alternative: We could do a copy_nonoverlapping combined with a mask instead
+    // of a loop. Benchmark it.
+    if x.len() < 4 {
+        for (&b1, &b2) in x.iter().zip(y) {
+            if b1 != b2 {
+                return false;
+            }
+        }
+        return true;
+    }
+    // When we have 4 or more bytes to compare, then proceed in chunks of 4 at
+    // a time using unaligned loads.
+    //
+    // Also, why do 4 byte loads instead of, say, 8 byte loads? The reason is
+    // that this particular version of memcmp is likely to be called with tiny
+    // needles. That means that if we do 8 byte loads, then a higher proportion
+    // of memcmp calls will use the slower variant above. With that said, this
+    // is a hypothesis and is only loosely supported by benchmarks. There's
+    // likely some improvement that could be made here. The main thing here
+    // though is to optimize for latency, not throughput.
+
+    // SAFETY: Via the conditional above, we know that both `px` and `py`
+    // have the same length, so `px < pxend` implies that `py < pyend`.
+    // Thus, dereferencing both `px` and `py` in the loop below is safe.
+    //
+    // Moreover, we set `pxend` and `pyend` to be 4 bytes before the actual
+    // end of `px` and `py`. Thus, the final dereference outside of the
+    // loop is guaranteed to be valid. (The final comparison will overlap with
+    // the last comparison done in the loop for lengths that aren't multiples
+    // of four.)
+    //
+    // Finally, we needn't worry about alignment here, since we do unaligned
+    // loads.
+    unsafe {
+        let (mut px, mut py) = (x.as_ptr(), y.as_ptr());
+        let (pxend, pyend) = (px.add(x.len() - 4), py.add(y.len() - 4));
+        while px < pxend {
+            let vx = (px as *const u32).read_unaligned();
+            let vy = (py as *const u32).read_unaligned();
+            if vx != vy {
+                return false;
+            }
+            px = px.add(4);
+            py = py.add(4);
+        }
+        let vx = (pxend as *const u32).read_unaligned();
+        let vy = (pyend as *const u32).read_unaligned();
+        vx == vy
+    }
 }
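To make the probe-and-verify idea in `simd_contains` easier to follow, here is a scalar model of it. This is a sketch only: the function name is made up, the second probe is fixed to the last needle byte for simplicity (the real code may pick an earlier byte to avoid degenerate cases), and the real code evaluates 16 candidate positions per iteration, gathering the probe results into a bitmask before verifying survivors with `small_slice_eq`:

fn scalar_probe_contains(needle: &[u8], haystack: &[u8]) -> bool {
    debug_assert!(needle.len() > 1 && needle.len() <= haystack.len());
    let first = needle[0];
    let second_offset = needle.len() - 1; // illustrative choice of second probe byte
    let second = needle[second_offset];
    for i in 0..=(haystack.len() - needle.len()) {
        // cheap probe: compare only two bytes of the candidate window
        if haystack[i] == first && haystack[i + second_offset] == second {
            // expensive verify: compare the rest of the needle
            if haystack[i + 1..i + needle.len()] == needle[1..] {
                return true;
            }
        }
    }
    false
}

The verification path uses `small_slice_eq` instead of `==` to skip the `memcmp` call overhead. Its tail handling relies on overlapping reads: for a 6-byte comparison the loop covers bytes 0..4 and the final unaligned load re-checks bytes 2..6, so no scalar remainder loop is needed.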