diff --git a/Data/HashMap/Internal.hs b/Data/HashMap/Internal.hs index c5912a7a..49808824 100644 --- a/Data/HashMap/Internal.hs +++ b/Data/HashMap/Internal.hs @@ -78,6 +78,7 @@ module Data.HashMap.Internal , intersection , intersectionWith , intersectionWithKey + , intersectionWithKey# -- * Folds , foldr' @@ -150,9 +151,9 @@ import Data.Data (Constr, Data (..), DataType) import Data.Functor.Classes (Eq1 (..), Eq2 (..), Ord1 (..), Ord2 (..), Read1 (..), Show1 (..), Show2 (..)) import Data.Functor.Identity (Identity (..)) -import Data.HashMap.Internal.List (isPermutationBy, unorderedCompare) import Data.Hashable (Hashable) import Data.Hashable.Lifted (Hashable1, Hashable2) +import Data.HashMap.Internal.List (isPermutationBy, unorderedCompare) import Data.Semigroup (Semigroup (..), stimesIdempotentMonoid) import GHC.Exts (Int (..), Int#, TYPE, (==#)) import GHC.Stack (HasCallStack) @@ -163,9 +164,9 @@ import Text.Read hiding (step) import qualified Data.Data as Data import qualified Data.Foldable as Foldable import qualified Data.Functor.Classes as FC -import qualified Data.HashMap.Internal.Array as A import qualified Data.Hashable as H import qualified Data.Hashable.Lifted as H +import qualified Data.HashMap.Internal.Array as A import qualified Data.List as List import qualified GHC.Exts as Exts import qualified Language.Haskell.TH.Syntax as TH @@ -1627,7 +1628,7 @@ unionArrayBy f !b1 !b2 !ary1 !ary2 = A.run $ do A.write mary i =<< A.indexM ary2 i2 go (i+1) i1 (i2+1) b' where - m = 1 `unsafeShiftL` (countTrailingZeros b) + m = 1 `unsafeShiftL` countTrailingZeros b testBit x = x .&. m /= 0 b' = b .&. complement m go 0 0 0 bCombined @@ -1759,37 +1760,161 @@ differenceWith f a b = foldlWithKey' go empty a -- | \(O(n \log m)\) Intersection of two maps. Return elements of the first -- map for keys existing in the second. intersection :: (Eq k, Hashable k) => HashMap k v -> HashMap k w -> HashMap k v -intersection a b = foldlWithKey' go empty a - where - go m k v = case lookup k b of - Just _ -> unsafeInsert k v m - _ -> m +intersection = Exts.inline intersectionWith const {-# INLINABLE intersection #-} -- | \(O(n \log m)\) Intersection of two maps. If a key occurs in both maps -- the provided function is used to combine the values from the two -- maps. -intersectionWith :: (Eq k, Hashable k) => (v1 -> v2 -> v3) -> HashMap k v1 - -> HashMap k v2 -> HashMap k v3 -intersectionWith f a b = foldlWithKey' go empty a - where - go m k v = case lookup k b of - Just w -> unsafeInsert k (f v w) m - _ -> m +intersectionWith :: (Eq k, Hashable k) => (v1 -> v2 -> v3) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3 +intersectionWith f = Exts.inline intersectionWithKey $ const f {-# INLINABLE intersectionWith #-} -- | \(O(n \log m)\) Intersection of two maps. If a key occurs in both maps -- the provided function is used to combine the values from the two -- maps. -intersectionWithKey :: (Eq k, Hashable k) => (k -> v1 -> v2 -> v3) - -> HashMap k v1 -> HashMap k v2 -> HashMap k v3 -intersectionWithKey f a b = foldlWithKey' go empty a - where - go m k v = case lookup k b of - Just w -> unsafeInsert k (f k v w) m - _ -> m +intersectionWithKey :: (Eq k, Hashable k) => (k -> v1 -> v2 -> v3) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3 +intersectionWithKey f = intersectionWithKey# $ \k v1 v2 -> (# f k v1 v2 #) {-# INLINABLE intersectionWithKey #-} +intersectionWithKey# :: Eq k => (k -> v1 -> v2 -> (# v3 #)) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3 +intersectionWithKey# f = go 0 + where + -- empty vs. anything + go !_ _ Empty = Empty + go _ Empty _ = Empty + -- leaf vs. anything + go s (Leaf h1 (L k1 v1)) t2 = + lookupCont + (\_ -> Empty) + (\v _ -> case f k1 v1 v of (# v' #) -> Leaf h1 $ L k1 v') + h1 k1 s t2 + go s t1 (Leaf h2 (L k2 v2)) = + lookupCont + (\_ -> Empty) + (\v _ -> case f k2 v v2 of (# v' #) -> Leaf h2 $ L k2 v') + h2 k2 s t1 + -- collision vs. collision + go _ (Collision h1 ls1) (Collision h2 ls2) = intersectionCollisions f h1 h2 ls1 ls2 + -- branch vs. branch + go s (BitmapIndexed b1 ary1) (BitmapIndexed b2 ary2) = + intersectionArrayBy (go (s + bitsPerSubkey)) b1 b2 ary1 ary2 + go s (BitmapIndexed b1 ary1) (Full ary2) = + intersectionArrayBy (go (s + bitsPerSubkey)) b1 fullNodeMask ary1 ary2 + go s (Full ary1) (BitmapIndexed b2 ary2) = + intersectionArrayBy (go (s + bitsPerSubkey)) fullNodeMask b2 ary1 ary2 + go s (Full ary1) (Full ary2) = + intersectionArrayBy (go (s + bitsPerSubkey)) fullNodeMask fullNodeMask ary1 ary2 + -- collision vs. branch + go s (BitmapIndexed b1 ary1) t2@(Collision h2 _ls2) + | b1 .&. m2 == 0 = Empty + | otherwise = go (s + bitsPerSubkey) (A.index ary1 i) t2 + where + m2 = mask h2 s + i = sparseIndex b1 m2 + go s t1@(Collision h1 _ls1) (BitmapIndexed b2 ary2) + | b2 .&. m1 == 0 = Empty + | otherwise = go (s + bitsPerSubkey) t1 (A.index ary2 i) + where + m1 = mask h1 s + i = sparseIndex b2 m1 + go s (Full ary1) t2@(Collision h2 _ls2) = go (s + bitsPerSubkey) (A.index ary1 i) t2 + where + i = index h2 s + go s t1@(Collision h1 _ls1) (Full ary2) = go (s + bitsPerSubkey) t1 (A.index ary2 i) + where + i = index h1 s +{-# INLINE intersectionWithKey# #-} + +intersectionArrayBy :: + ( HashMap k v1 -> + HashMap k v2 -> + HashMap k v3 + ) -> + Bitmap -> + Bitmap -> + A.Array (HashMap k v1) -> + A.Array (HashMap k v2) -> + HashMap k v3 +intersectionArrayBy f !b1 !b2 !ary1 !ary2 + | b1 .&. b2 == 0 = Empty + | otherwise = runST $ do + mary <- A.new_ $ popCount bIntersect + -- iterate over nonzero bits of b1 .|. b2 + let go !i !i1 !i2 !b !bFinal + | b == 0 = pure (i, bFinal) + | testBit $ b1 .&. b2 = do + x1 <- A.indexM ary1 i1 + x2 <- A.indexM ary2 i2 + case f x1 x2 of + Empty -> go i (i1 + 1) (i2 + 1) b' (bFinal .&. complement m) + _ -> do + A.write mary i $! f x1 x2 + go (i + 1) (i1 + 1) (i2 + 1) b' bFinal + | testBit b1 = go i (i1 + 1) i2 b' bFinal + | otherwise = go i i1 (i2 + 1) b' bFinal + where + m = 1 `unsafeShiftL` countTrailingZeros b + testBit x = x .&. m /= 0 + b' = b .&. complement m + (len, bFinal) <- go 0 0 0 bCombined bIntersect + case len of + 0 -> pure Empty + 1 -> A.read mary 0 + _ -> bitmapIndexedOrFull bFinal <$> (A.unsafeFreeze =<< A.shrink mary len) + where + bCombined = b1 .|. b2 + bIntersect = b1 .&. b2 +{-# INLINE intersectionArrayBy #-} + +intersectionCollisions :: Eq k => (k -> v1 -> v2 -> (# v3 #)) -> Hash -> Hash -> A.Array (Leaf k v1) -> A.Array (Leaf k v2) -> HashMap k v3 +intersectionCollisions f h1 h2 ary1 ary2 + | h1 == h2 = runST $ do + mary2 <- A.thaw ary2 0 $ A.length ary2 + mary <- A.new_ $ min (A.length ary1) (A.length ary2) + let go i j + | i >= A.length ary1 || j >= A.lengthM mary2 = pure j + | otherwise = do + L k1 v1 <- A.indexM ary1 i + searchSwap k1 j mary2 >>= \case + Just (L _k2 v2) -> do + let !(# v3 #) = f k1 v1 v2 + A.write mary j $ L k1 v3 + go (i + 1) (j + 1) + Nothing -> do + go (i + 1) j + len <- go 0 0 + case len of + 0 -> pure Empty + 1 -> Leaf h1 <$> A.read mary 0 + _ -> Collision h1 <$> (A.unsafeFreeze =<< A.shrink mary len) + | otherwise = Empty +{-# INLINE intersectionCollisions #-} + +-- | Say we have +-- @ +-- 1 2 3 4 +-- @ +-- and we search for @3@. Then we can mutate the array to +-- @ +-- undefined 2 1 4 +-- @ +-- We don't actually need to write undefined, we just have to make sure that the next search starts 1 after the current one. +searchSwap :: Eq k => k -> Int -> A.MArray s (Leaf k v) -> ST s (Maybe (Leaf k v)) +searchSwap toFind start = go start toFind start + where + go i0 k i mary + | i >= A.lengthM mary = pure Nothing + | otherwise = do + l@(L k' _v) <- A.read mary i + if k == k' + then do + A.write mary i =<< A.read mary i0 + pure $ Just l + else go i0 k (i + 1) mary +{-# INLINE searchSwap #-} + + ------------------------------------------------------------------------ -- * Folds diff --git a/Data/HashMap/Internal/Array.hs b/Data/HashMap/Internal/Array.hs index 83a927a6..f7696a26 100644 --- a/Data/HashMap/Internal/Array.hs +++ b/Data/HashMap/Internal/Array.hs @@ -77,6 +77,7 @@ module Data.HashMap.Internal.Array , toList , fromList , fromList' + , shrink ) where import Control.Applicative (liftA2) @@ -92,6 +93,7 @@ import GHC.Exts (Int (..), SmallArray#, SmallMutableArray#, thawSmallArray#, unsafeCoerce#, unsafeFreezeSmallArray#, unsafeThawSmallArray#, writeSmallArray#) +import qualified GHC.Exts as Exts import GHC.ST (ST (..)) import Prelude hiding (all, filter, foldMap, foldl, foldr, length, map, read, traverse) @@ -205,6 +207,20 @@ new _n@(I# n#) b = new_ :: Int -> ST s (MArray s a) new_ n = new n undefinedElem +-- | When 'Exts.shrinkSmallMutableArray#' is available, the returned array is the same as the array given, as it is shrunk in place. +-- Otherwise a copy is made. +shrink :: MArray s a -> Int -> ST s (MArray s a) +#if __GLASGOW_HASKELL__ >= 810 +shrink mary _n@(I# n#) = + CHECK_GT("shrink", _n, (0 :: Int)) + CHECK_LE("shrink", _n, (lengthM mary)) + ST $ \s -> case Exts.shrinkSmallMutableArray# (unMArray mary) n# s of + s' -> (# s', mary #) +#else +shrink mary n = cloneM mary 0 n +#endif +{-# INLINE shrink #-} + singleton :: a -> Array a singleton x = runST (singletonM x) {-# INLINE singleton #-} diff --git a/Data/HashMap/Internal/Strict.hs b/Data/HashMap/Internal/Strict.hs index b25266a4..2d8fb374 100644 --- a/Data/HashMap/Internal/Strict.hs +++ b/Data/HashMap/Internal/Strict.hs @@ -128,16 +128,17 @@ import Data.Bits ((.&.), (.|.)) import Data.Coerce (coerce) import Data.Functor.Identity (Identity (..)) -- See Note [Imports from Data.HashMap.Internal] +import Data.Hashable (Hashable) import Data.HashMap.Internal (Hash, HashMap (..), Leaf (..), LookupRes (..), bitsPerSubkey, fullNodeMask, hash, index, mask, ptrEq, sparseIndex) -import Data.Hashable (Hashable) import Prelude hiding (lookup, map) -- See Note [Imports from Data.HashMap.Internal] import qualified Data.HashMap.Internal as HM import qualified Data.HashMap.Internal.Array as A import qualified Data.List as List +import qualified GHC.Exts as Exts {- Note [Imports from Data.HashMap.Internal] @@ -616,11 +617,7 @@ differenceWith f a b = HM.foldlWithKey' go HM.empty a -- maps. intersectionWith :: (Eq k, Hashable k) => (v1 -> v2 -> v3) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3 -intersectionWith f a b = HM.foldlWithKey' go HM.empty a - where - go m k v = case HM.lookup k b of - Just w -> let !x = f v w in HM.unsafeInsert k x m - _ -> m +intersectionWith f = Exts.inline intersectionWithKey $ const f {-# INLINABLE intersectionWith #-} -- | \(O(n+m)\) Intersection of two maps. If a key occurs in both maps @@ -628,11 +625,7 @@ intersectionWith f a b = HM.foldlWithKey' go HM.empty a -- maps. intersectionWithKey :: (Eq k, Hashable k) => (k -> v1 -> v2 -> v3) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3 -intersectionWithKey f a b = HM.foldlWithKey' go HM.empty a - where - go m k v = case HM.lookup k b of - Just w -> let !x = f k v w in HM.unsafeInsert k x m - _ -> m +intersectionWithKey f = HM.intersectionWithKey# $ \k v1 v2 -> let !v3 = f k v1 v2 in (# v3 #) {-# INLINABLE intersectionWithKey #-} ------------------------------------------------------------------------ diff --git a/benchmarks/Benchmarks.hs b/benchmarks/Benchmarks.hs index c0f7f550..ae05c422 100644 --- a/benchmarks/Benchmarks.hs +++ b/benchmarks/Benchmarks.hs @@ -318,13 +318,17 @@ main = do [ bench "Int" $ whnf (HM.union hmi) hmi2 , bench "ByteString" $ whnf (HM.union hmbs) hmbsSubset ] + + , bgroup "intersection" + [ bench "Int" $ whnf (HM.intersection hmi) hmi2 + , bench "ByteString" $ whnf (HM.intersection hmbs) hmbsSubset + ] -- Transformations , bench "map" $ whnf (HM.map (\ v -> v + 1)) hmi -- * Difference and intersection , bench "difference" $ whnf (HM.difference hmi) hmi2 - , bench "intersection" $ whnf (HM.intersection hmi) hmi2 -- Folds , bench "foldl'" $ whnf (HM.foldl' (+) 0) hmi diff --git a/tests/Properties/HashMapLazy.hs b/tests/Properties/HashMapLazy.hs index 8b712da3..f25fb42d 100644 --- a/tests/Properties/HashMapLazy.hs +++ b/tests/Properties/HashMapLazy.hs @@ -1,6 +1,7 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# OPTIONS_GHC -fno-warn-orphans #-} -- because of Arbitrary (HashMap k v) +{-# LANGUAGE BangPatterns #-} -- | Tests for the 'Data.HashMap.Lazy' module. We test functions by -- comparing them to @Map@ from @containers@. @@ -42,7 +43,7 @@ import qualified Data.Map.Lazy as M -- Key type that generates more hash collisions. newtype Key = K { unK :: Int } - deriving (Arbitrary, Eq, Ord, Read, Show) + deriving (Arbitrary, Eq, Ord, Read, Show, Num) instance Hashable Key where hashWithSalt salt k = hashWithSalt salt (unK k) `mod` 20 @@ -318,8 +319,10 @@ pDifferenceWith xs ys = M.differenceWith f (M.fromList xs) `eq_` f x y = if x == 0 then Nothing else Just (x - y) pIntersection :: [(Key, Int)] -> [(Key, Int)] -> Bool -pIntersection xs ys = M.intersection (M.fromList xs) `eq_` - HM.intersection (HM.fromList xs) $ ys +pIntersection xs ys = + M.intersection (M.fromList xs) + `eq_` HM.intersection (HM.fromList xs) + $ ys pIntersectionWith :: [(Key, Int)] -> [(Key, Int)] -> Bool pIntersectionWith xs ys = M.intersectionWith (-) (M.fromList xs) `eq_`