From 21f238b1e11522262b455ed82ab3562855589b96 Mon Sep 17 00:00:00 2001 From: Brian Shu Date: Sat, 9 Apr 2022 11:04:01 -0400 Subject: [PATCH 01/28] fast intersection --- Data/HashMap/Internal.hs | 179 +++++++++++++++++++++++++++----- Data/HashMap/Internal/Array.hs | 9 +- Data/HashSet/Internal.hs | 2 +- tests/Properties/HashMapLazy.hs | 32 +++++- 4 files changed, 188 insertions(+), 34 deletions(-) diff --git a/Data/HashMap/Internal.hs b/Data/HashMap/Internal.hs index 102a93de..6f2950bf 100644 --- a/Data/HashMap/Internal.hs +++ b/Data/HashMap/Internal.hs @@ -152,6 +152,7 @@ import Data.Functor.Classes (Eq1 (..), Eq2 (..), Ord1 (..), Ord2 (..), import Data.Functor.Identity (Identity (..)) import Data.HashMap.Internal.List (isPermutationBy, unorderedCompare) import Data.Hashable (Hashable) +import Debug.Trace (traceId) import Data.Hashable.Lifted (Hashable1, Hashable2) import Data.Semigroup (Semigroup (..), stimesIdempotentMonoid) import GHC.Exts (Int (..), Int#, TYPE, (==#)) @@ -169,6 +170,8 @@ import qualified Data.Hashable.Lifted as H import qualified Data.List as List import qualified GHC.Exts as Exts import qualified Language.Haskell.TH.Syntax as TH +import Numeric (showIntAtBase) +import Data.Char (intToDigit) -- | A set of values. A set cannot contain duplicate values. ------------------------------------------------------------------------ @@ -178,7 +181,7 @@ hash :: H.Hashable a => a -> Hash hash = fromIntegral . H.hash data Leaf k v = L !k v - deriving (Eq) + deriving (Show, Eq) instance (NFData k, NFData v) => NFData (Leaf k v) where rnf (L k v) = rnf k `seq` rnf v @@ -210,6 +213,7 @@ data HashMap k v | Leaf !Hash !(Leaf k v) | Full !(A.Array (HashMap k v)) | Collision !Hash !(A.Array (Leaf k v)) + deriving (Show) type role HashMap nominal representational @@ -337,9 +341,9 @@ instance (Eq k, Hashable k, Read k, Read e) => Read (HashMap k e) where readListPrec = readListPrecDefault -instance (Show k, Show v) => Show (HashMap k v) where - showsPrec d m = showParen (d > 10) $ - showString "fromList " . shows (toList m) +-- instance (Show k, Show v) => Show (HashMap k v) where +-- showsPrec d m = showParen (d > 10) $ +-- showString "fromList " . shows (toList m) instance Traversable (HashMap k) where traverse f = traverseWithKey (const f) @@ -1602,10 +1606,6 @@ unionWithKey f = go 0 ary' = update32With' ary2 i $ \st2 -> go (s+bitsPerSubkey) t1 st2 in Full ary' - leafHashCode (Leaf h _) = h - leafHashCode (Collision h _) = h - leafHashCode _ = error "leafHashCode" - goDifferentHash s h1 h2 t1 t2 | m1 == m2 = BitmapIndexed m1 (A.singleton $! goDifferentHash (s+bitsPerSubkey) h1 h2 t1 t2) | m1 < m2 = BitmapIndexed (m1 .|. m2) (A.pair t1 t2) @@ -1639,7 +1639,7 @@ unionArrayBy f !b1 !b2 !ary1 !ary2 = A.run $ do A.write mary i =<< A.indexM ary2 i2 go (i+1) i1 (i2+1) b' where - m = 1 `unsafeShiftL` (countTrailingZeros b) + m = 1 `unsafeShiftL` countTrailingZeros b testBit x = x .&. m /= 0 b' = b .&. complement m go 0 0 0 bCombined @@ -1770,38 +1770,155 @@ differenceWith f a b = foldlWithKey' go empty a -- | /O(n*log m)/ Intersection of two maps. Return elements of the first -- map for keys existing in the second. -intersection :: (Eq k, Hashable k) => HashMap k v -> HashMap k w -> HashMap k v -intersection a b = foldlWithKey' go empty a - where - go m k v = case lookup k b of - Just _ -> unsafeInsert k v m - _ -> m +intersection :: (Eq k, Hashable k, Show v, Show w, Show k) => HashMap k v -> HashMap k w -> HashMap k v +intersection = intersectionWith const {-# INLINABLE intersection #-} -- | /O(n*log m)/ Intersection of two maps. If a key occurs in both maps -- the provided function is used to combine the values from the two -- maps. -intersectionWith :: (Eq k, Hashable k) => (v1 -> v2 -> v3) -> HashMap k v1 - -> HashMap k v2 -> HashMap k v3 -intersectionWith f a b = foldlWithKey' go empty a - where - go m k v = case lookup k b of - Just w -> unsafeInsert k (f v w) m - _ -> m +intersectionWith :: (Show v1, Show v2, Show v3, Show k, Eq k, Hashable k) => (v1 -> v2 -> v3) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3 +intersectionWith f = intersectionWithKey $ const f {-# INLINABLE intersectionWith #-} -- | /O(n*log m)/ Intersection of two maps. If a key occurs in both maps -- the provided function is used to combine the values from the two -- maps. -intersectionWithKey :: (Eq k, Hashable k) => (k -> v1 -> v2 -> v3) - -> HashMap k v1 -> HashMap k v2 -> HashMap k v3 -intersectionWithKey f a b = foldlWithKey' go empty a +intersectionWithKey :: (Show v1, Show v2, Show v3, Show k, Eq k, Hashable k) => (k -> v1 -> v2 -> v3) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3 +intersectionWithKey f = go 0 where - go m k v = case lookup k b of - Just w -> unsafeInsert k (f k v w) m - _ -> m + -- empty vs. anything + go !_ _ Empty = Empty + go _ Empty _ = Empty + -- leaf vs. anything + go s (Leaf h1 (L k1 v1)) t2 = lookupCont (\_ -> Empty) (\v _ -> Leaf h1 $ L k1 $ f k1 v1 v) h1 k1 s t2 + go s t1 (Leaf h2 (L k2 v2)) = lookupCont (\_ -> Empty) (\v _ -> Leaf h2 $ L k2 $ f k2 v v2) h2 k2 s t1 + -- collision vs. collision + go _ (Collision h1 ls1) (Collision h2 ls2) + | h1 == h2 = if A.length ls == 0 then Empty else Collision h1 ls + | otherwise = Empty + where + ls = intersectionUnorderedArrayWithKey (\k v1 v2 -> (# f k v1 v2 #)) ls1 ls2 + -- branch vs. branch + go s (BitmapIndexed b1 ary1) (BitmapIndexed b2 ary2) = normalize b ary + where + (b, ary) = intersectionArrayBy (go (s + bitsPerSubkey)) b1 b2 ary1 ary2 + go s (BitmapIndexed b1 ary1) (Full ary2) = normalize b ary + where + (b, ary) = intersectionArrayBy (go (s + bitsPerSubkey)) b1 fullNodeMask ary1 ary2 + go s (Full ary1) (BitmapIndexed b2 ary2) = normalize b ary + where + (b, ary) = intersectionArrayBy (go (s + bitsPerSubkey)) fullNodeMask b2 ary1 ary2 + go s (Full ary1) (Full ary2) = normalize b ary + where + (b, ary) = intersectionArrayBy (go (s + bitsPerSubkey)) fullNodeMask fullNodeMask ary1 ary2 + -- collision vs. branch + go s (BitmapIndexed b1 ary1) t2@(Collision h2 _ls2) + | b1 .&. m2 == 0 = Empty + | otherwise = go (s + bitsPerSubkey) (A.index ary1 i) t2 + where + m2 = mask h2 s + i = sparseIndex b1 m2 + go s t1@(Collision h1 _ls1) (BitmapIndexed b2 ary2) + | b2 .&. m1 == 0 = Empty + | otherwise = go (s + bitsPerSubkey) t1 (A.index ary2 i) + where + m1 = mask h1 s + i = sparseIndex b2 m1 + go s (Full ary1) t2@(Collision h2 _ls2) = go (s + bitsPerSubkey) (A.index ary1 i) t2 + where + i = index h2 s + go s t1@(Collision h1 _ls1) (Full ary2) = go (s + bitsPerSubkey) t1 (A.index ary2 i) + where + i = index h1 s + + normalize b ary + | A.length ary == 0 = Empty + | otherwise = bitmapIndexedOrFull b ary {-# INLINABLE intersectionWithKey #-} +showBitmap :: Bitmap -> String +showBitmap b = showIntAtBase 2 intToDigit b "" + +intersectionArrayBy :: (Show v1, Show v2, Show k, Show v) => (v1 -> v2 -> HashMap k v) -> Bitmap -> Bitmap -> A.Array v1 -> A.Array v2 -> (Bitmap, A.Array (HashMap k v)) +intersectionArrayBy f = intersectionArrayByFilter f $ \case Empty -> False; _ -> True +{-# INLINE intersectionArrayBy #-} + +intersectionArrayByFilter :: (Show v1, Show v2, Show v3) => (v1 -> v2 -> v3) -> (v3 -> Bool) -> Bitmap -> Bitmap -> A.Array v1 -> A.Array v2 -> (Bitmap, A.Array v3) +intersectionArrayByFilter f p !b1 !b2 !ary1 !ary2 = runST $ do + let bCombined = b1 .|. b2 + let + -- !_ = traceId $ "intersecting arrays" + -- !_ = traceId $ "b1: " ++ showBitmap b1 + -- !_ = traceId $ "b2: " ++ showBitmap b2 + -- !_ = traceId $ "bCombined: " ++ showBitmap bCombined + mary <- A.new_ $ popCount $ b1 .&. b2 + -- iterate over nonzero bits of b1 .&. b2 + let go !i !i1 !i2 !b !bFinal + | b == 0 = pure (i, bFinal) + | testBit $ b1 .&. b2 = do + x1 <- A.indexM ary1 i1 + x2 <- A.indexM ary2 i2 + let + -- !_ = traceId $ "i1: " ++ show i1 + -- !_ = traceId $ "i2: " ++ show i2 + -- !_ = traceId $ "x1: " ++ show x1 + -- !_ = traceId $ "x2: " ++ show x2 + -- !_ = traceId $ "writing " ++ show (f x1 x2) + let !x = f x1 x2 + if p x + then do + A.write mary i $! f x1 x2 + go (i + 1) (i1 + 1) (i2 + 1) b' bFinal + else go i (i1 + 1) (i2 + 1) b' (bFinal .&. complement m) + | testBit b1 = go i (i1 + 1) i2 b' bFinal + | otherwise = go i i1 (i2 + 1) b' bFinal + where + m = 1 `unsafeShiftL` countTrailingZeros b + testBit x = x .&. m /= 0 + b' = b .&. complement m + (maryLen, bFinal) <- go 0 0 0 bCombined (b1 .&. b2) + A.shrink mary maryLen + ary <- A.unsafeFreeze mary + pure (bFinal, ary) +{-# INLINE intersectionArrayByFilter #-} + +intersectionUnorderedArrayWithKey :: (Show k, Show v1, Show v2, Show v3, Eq k) => (k -> v1 -> v2 -> (# v3 #)) -> A.Array (Leaf k v1) -> A.Array (Leaf k v2) -> A.Array (Leaf k v3) +intersectionUnorderedArrayWithKey f ary1 ary2 = A.run $ do + mary2 <- A.thaw ary2 0 $ A.length ary2 + mary <- A.new_ $ A.length ary1 + A.length ary2 + let go i j + | i >= A.length ary1 || j >= A.lengthM mary2 = pure j + | otherwise = do + L k1 v1 <- A.indexM ary1 i + searchSwap k1 j mary2 >>= \case + Just (L _k2 v2) -> do + -- let !_ = traceId $ "found " ++ show k1 + let !(# v3 #) = f k1 v1 v2 + A.write mary j $ L k1 v3 + go (i + 1) (j + 1) + Nothing -> do + -- let !_ = traceId $ "did not find " ++ show k1 + go (i + 1) j + maryLen <- go 0 0 + A.shrink mary maryLen + pure mary +{-# INLINABLE intersectionUnorderedArrayWithKey #-} + +searchSwap :: Eq k => k -> Int -> A.MArray s (Leaf k v) -> ST s (Maybe (Leaf k v)) +searchSwap toFind start = go start toFind start + where + go i0 k i mary + | i >= A.lengthM mary = pure Nothing + | otherwise = do + l@(L k' _v) <- A.read mary i + if k == k' + then do + A.write mary i =<< A.read mary i0 + pure $ Just l + else go i0 k (i + 1) mary + + ------------------------------------------------------------------------ -- * Folds @@ -2282,6 +2399,12 @@ ptrEq :: a -> a -> Bool ptrEq x y = Exts.isTrue# (Exts.reallyUnsafePtrEquality# x y ==# 1#) {-# INLINE ptrEq #-} +leafHashCode :: HashMap k v -> Hash +leafHashCode (Leaf h _) = h +leafHashCode (Collision h _) = h +leafHashCode _ = error "leafHashCode" +{-# INLINE leafHashCode #-} + ------------------------------------------------------------------------ -- IsList instance instance (Eq k, Hashable k) => Exts.IsList (HashMap k v) where diff --git a/Data/HashMap/Internal/Array.hs b/Data/HashMap/Internal/Array.hs index 9d74eb03..5506fa8b 100644 --- a/Data/HashMap/Internal/Array.hs +++ b/Data/HashMap/Internal/Array.hs @@ -76,6 +76,7 @@ module Data.HashMap.Internal.Array , toList , fromList , fromList' + , shrink ) where import Control.Applicative (liftA2) @@ -90,7 +91,7 @@ import GHC.Exts (Int (..), SmallArray#, SmallMutableArray#, sizeofSmallMutableArray#, tagToEnum#, thawSmallArray#, unsafeCoerce#, unsafeFreezeSmallArray#, unsafeThawSmallArray#, - writeSmallArray#) + writeSmallArray#, shrinkSmallMutableArray#) import GHC.ST (ST (..)) import Prelude hiding (all, filter, foldMap, foldl, foldr, length, map, read, traverse) @@ -204,6 +205,12 @@ new _n@(I# n#) b = new_ :: Int -> ST s (MArray s a) new_ n = new n undefinedElem +shrink :: MArray s a -> Int -> ST s () +shrink mary (I# n#) = + ST $ \s -> case shrinkSmallMutableArray# (unMArray mary) n# s of + s' -> (# s', () #) +{-# INLINE shrink #-} + singleton :: a -> Array a singleton x = runST (singletonM x) {-# INLINE singleton #-} diff --git a/Data/HashSet/Internal.hs b/Data/HashSet/Internal.hs index 340bb742..15b2254e 100644 --- a/Data/HashSet/Internal.hs +++ b/Data/HashSet/Internal.hs @@ -391,7 +391,7 @@ difference (HashSet a) (HashSet b) = HashSet (H.difference a b) -- -- >>> HashSet.intersection (HashSet.fromList [1,2,3]) (HashSet.fromList [2,3,4]) -- fromList [2,3] -intersection :: (Eq a, Hashable a) => HashSet a -> HashSet a -> HashSet a +intersection :: (Show a, Eq a, Hashable a) => HashSet a -> HashSet a -> HashSet a intersection (HashSet a) (HashSet b) = HashSet (H.intersection a b) {-# INLINABLE intersection #-} diff --git a/tests/Properties/HashMapLazy.hs b/tests/Properties/HashMapLazy.hs index 8b712da3..52cb8862 100644 --- a/tests/Properties/HashMapLazy.hs +++ b/tests/Properties/HashMapLazy.hs @@ -1,6 +1,7 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# OPTIONS_GHC -fno-warn-orphans #-} -- because of Arbitrary (HashMap k v) +{-# LANGUAGE BangPatterns #-} -- | Tests for the 'Data.HashMap.Lazy' module. We test functions by -- comparing them to @Map@ from @containers@. @@ -17,6 +18,7 @@ import Control.Applicative (Const (..)) import Control.Monad (guard) import Data.Bifoldable import Data.Function (on) +import Debug.Trace (traceId) import Data.Functor.Identity (Identity (..)) import Data.Hashable (Hashable (hashWithSalt)) import Data.Ord (comparing) @@ -26,6 +28,7 @@ import Test.QuickCheck.Function (Fun, apply) import Test.QuickCheck.Poly (A, B) import Test.Tasty (TestTree, testGroup) import Test.Tasty.QuickCheck (testProperty) +import Test.Tasty.HUnit import qualified Data.Foldable as Foldable import qualified Data.List as List @@ -42,7 +45,7 @@ import qualified Data.Map.Lazy as M -- Key type that generates more hash collisions. newtype Key = K { unK :: Int } - deriving (Arbitrary, Eq, Ord, Read, Show) + deriving (Arbitrary, Eq, Ord, Read, Show, Num) instance Hashable Key where hashWithSalt salt k = hashWithSalt salt (unK k) `mod` 20 @@ -249,7 +252,15 @@ pSubmapDifference m1 m2 = HM.isSubmapOf (HM.difference m1 m2) m1 pNotSubmapDifference :: HashMap Key Int -> HashMap Key Int -> Property pNotSubmapDifference m1 m2 = - not (HM.null (HM.intersection m1 m2)) ==> + not (HM.null (HM.intersection m1 m2)) ==> do + + let + res = HM.intersection m1 m2 + res' = M.intersection (M.fromList $ HM.toList m1) (M.fromList $ HM.toList m2) + -- !_ = traceId $ "res: " ++ show res + -- !_ = traceId $ "res': " ++ show res' + -- !_ = traceId $ "m1: " ++ show m1 + -- !_ = traceId $ "m2: " ++ show m2 not (HM.isSubmapOf m1 (HM.difference m1 m2)) pSubmapDelete :: HashMap Key Int -> Property @@ -318,8 +329,20 @@ pDifferenceWith xs ys = M.differenceWith f (M.fromList xs) `eq_` f x y = if x == 0 then Nothing else Just (x - y) pIntersection :: [(Key, Int)] -> [(Key, Int)] -> Bool -pIntersection xs ys = M.intersection (M.fromList xs) `eq_` - HM.intersection (HM.fromList xs) $ ys +pIntersection xs ys = do + let + res' = M.intersection (M.fromList xs) (M.fromList ys) + res = HM.intersection (HM.fromList xs) (HM.fromList ys) + -- !_ = traceId $ "res': " ++ show res' + -- !_ = traceId $ "res: " ++ show res + -- !_ = traceId $ "xs: " ++ show (HM.fromList xs) + -- !_ = traceId $ "ys: " ++ show (HM.fromList ys) + M.intersection (M.fromList xs) + `eq_` HM.intersection (HM.fromList xs) + $ ys + +intersectionBad :: Assertion +intersectionBad = pIntersection [(-20, 0), (0, 0)] [(0, 0), (20, 0)] @? "should be true" pIntersectionWith :: [(Key, Int)] -> [(Key, Int)] -> Bool pIntersectionWith xs ys = M.intersectionWith (-) (M.fromList xs) `eq_` @@ -531,6 +554,7 @@ tests = [ testProperty "difference" pDifference , testProperty "differenceWith" pDifferenceWith , testProperty "intersection" pIntersection + , testCase "intersectionBad" intersectionBad , testProperty "intersectionWith" pIntersectionWith , testProperty "intersectionWithKey" pIntersectionWithKey ] From 16f1f7fc4a4828fdbdb00620665825fec32dfe4d Mon Sep 17 00:00:00 2001 From: Brian Shu Date: Sat, 9 Apr 2022 13:05:04 -0400 Subject: [PATCH 02/28] cleanup --- Data/HashMap/Internal.hs | 60 +++++++++++---------------------- tests/Properties/HashMapLazy.hs | 9 +---- 2 files changed, 20 insertions(+), 49 deletions(-) diff --git a/Data/HashMap/Internal.hs b/Data/HashMap/Internal.hs index 6f2950bf..942a7f64 100644 --- a/Data/HashMap/Internal.hs +++ b/Data/HashMap/Internal.hs @@ -143,17 +143,16 @@ import Control.Applicative (Const (..)) import Control.DeepSeq (NFData (..), NFData1 (..), NFData2 (..)) import Control.Monad.ST (ST, runST) import Data.Bifoldable (Bifoldable (..)) -import Data.Bits (complement, popCount, unsafeShiftL, - unsafeShiftR, (.&.), (.|.), countTrailingZeros) +import Data.Bits (complement, countTrailingZeros, popCount, + unsafeShiftL, unsafeShiftR, (.&.), (.|.)) import Data.Coerce (coerce) import Data.Data (Constr, Data (..), DataType) import Data.Functor.Classes (Eq1 (..), Eq2 (..), Ord1 (..), Ord2 (..), Read1 (..), Show1 (..), Show2 (..)) import Data.Functor.Identity (Identity (..)) -import Data.HashMap.Internal.List (isPermutationBy, unorderedCompare) import Data.Hashable (Hashable) -import Debug.Trace (traceId) import Data.Hashable.Lifted (Hashable1, Hashable2) +import Data.HashMap.Internal.List (isPermutationBy, unorderedCompare) import Data.Semigroup (Semigroup (..), stimesIdempotentMonoid) import GHC.Exts (Int (..), Int#, TYPE, (==#)) import GHC.Stack (HasCallStack) @@ -164,14 +163,12 @@ import Text.Read hiding (step) import qualified Data.Data as Data import qualified Data.Foldable as Foldable import qualified Data.Functor.Classes as FC -import qualified Data.HashMap.Internal.Array as A import qualified Data.Hashable as H import qualified Data.Hashable.Lifted as H +import qualified Data.HashMap.Internal.Array as A import qualified Data.List as List import qualified GHC.Exts as Exts import qualified Language.Haskell.TH.Syntax as TH -import Numeric (showIntAtBase) -import Data.Char (intToDigit) -- | A set of values. A set cannot contain duplicate values. ------------------------------------------------------------------------ @@ -1770,21 +1767,21 @@ differenceWith f a b = foldlWithKey' go empty a -- | /O(n*log m)/ Intersection of two maps. Return elements of the first -- map for keys existing in the second. -intersection :: (Eq k, Hashable k, Show v, Show w, Show k) => HashMap k v -> HashMap k w -> HashMap k v +intersection :: (Eq k, Hashable k) => HashMap k v -> HashMap k w -> HashMap k v intersection = intersectionWith const {-# INLINABLE intersection #-} -- | /O(n*log m)/ Intersection of two maps. If a key occurs in both maps -- the provided function is used to combine the values from the two -- maps. -intersectionWith :: (Show v1, Show v2, Show v3, Show k, Eq k, Hashable k) => (v1 -> v2 -> v3) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3 +intersectionWith :: (Eq k, Hashable k) => (v1 -> v2 -> v3) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3 intersectionWith f = intersectionWithKey $ const f {-# INLINABLE intersectionWith #-} -- | /O(n*log m)/ Intersection of two maps. If a key occurs in both maps -- the provided function is used to combine the values from the two -- maps. -intersectionWithKey :: (Show v1, Show v2, Show v3, Show k, Eq k, Hashable k) => (k -> v1 -> v2 -> v3) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3 +intersectionWithKey :: (Eq k, Hashable k) => (k -> v1 -> v2 -> v3) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3 intersectionWithKey f = go 0 where -- empty vs. anything @@ -1800,18 +1797,10 @@ intersectionWithKey f = go 0 where ls = intersectionUnorderedArrayWithKey (\k v1 v2 -> (# f k v1 v2 #)) ls1 ls2 -- branch vs. branch - go s (BitmapIndexed b1 ary1) (BitmapIndexed b2 ary2) = normalize b ary - where - (b, ary) = intersectionArrayBy (go (s + bitsPerSubkey)) b1 b2 ary1 ary2 - go s (BitmapIndexed b1 ary1) (Full ary2) = normalize b ary - where - (b, ary) = intersectionArrayBy (go (s + bitsPerSubkey)) b1 fullNodeMask ary1 ary2 - go s (Full ary1) (BitmapIndexed b2 ary2) = normalize b ary - where - (b, ary) = intersectionArrayBy (go (s + bitsPerSubkey)) fullNodeMask b2 ary1 ary2 - go s (Full ary1) (Full ary2) = normalize b ary - where - (b, ary) = intersectionArrayBy (go (s + bitsPerSubkey)) fullNodeMask fullNodeMask ary1 ary2 + go s (BitmapIndexed b1 ary1) (BitmapIndexed b2 ary2) = intersectionArray s b1 b2 ary1 ary2 + go s (BitmapIndexed b1 ary1) (Full ary2) = intersectionArray s b1 fullNodeMask ary1 ary2 + go s (Full ary1) (BitmapIndexed b2 ary2) = intersectionArray s fullNodeMask b2 ary1 ary2 + go s (Full ary1) (Full ary2) = intersectionArray s fullNodeMask fullNodeMask ary1 ary2 -- collision vs. branch go s (BitmapIndexed b1 ary1) t2@(Collision h2 _ls2) | b1 .&. m2 == 0 = Empty @@ -1832,26 +1821,22 @@ intersectionWithKey f = go 0 where i = index h1 s + intersectionArray s b1 b2 ary1 ary2 = normalize b ary + where + (b, ary) = intersectionArrayBy (go (s + bitsPerSubkey)) b1 b2 ary1 ary2 + normalize b ary | A.length ary == 0 = Empty | otherwise = bitmapIndexedOrFull b ary {-# INLINABLE intersectionWithKey #-} -showBitmap :: Bitmap -> String -showBitmap b = showIntAtBase 2 intToDigit b "" - -intersectionArrayBy :: (Show v1, Show v2, Show k, Show v) => (v1 -> v2 -> HashMap k v) -> Bitmap -> Bitmap -> A.Array v1 -> A.Array v2 -> (Bitmap, A.Array (HashMap k v)) +intersectionArrayBy :: (v1 -> v2 -> HashMap k v) -> Bitmap -> Bitmap -> A.Array v1 -> A.Array v2 -> (Bitmap, A.Array (HashMap k v)) intersectionArrayBy f = intersectionArrayByFilter f $ \case Empty -> False; _ -> True {-# INLINE intersectionArrayBy #-} -intersectionArrayByFilter :: (Show v1, Show v2, Show v3) => (v1 -> v2 -> v3) -> (v3 -> Bool) -> Bitmap -> Bitmap -> A.Array v1 -> A.Array v2 -> (Bitmap, A.Array v3) +intersectionArrayByFilter :: (v1 -> v2 -> v3) -> (v3 -> Bool) -> Bitmap -> Bitmap -> A.Array v1 -> A.Array v2 -> (Bitmap, A.Array v3) intersectionArrayByFilter f p !b1 !b2 !ary1 !ary2 = runST $ do let bCombined = b1 .|. b2 - let - -- !_ = traceId $ "intersecting arrays" - -- !_ = traceId $ "b1: " ++ showBitmap b1 - -- !_ = traceId $ "b2: " ++ showBitmap b2 - -- !_ = traceId $ "bCombined: " ++ showBitmap bCombined mary <- A.new_ $ popCount $ b1 .&. b2 -- iterate over nonzero bits of b1 .&. b2 let go !i !i1 !i2 !b !bFinal @@ -1859,12 +1844,6 @@ intersectionArrayByFilter f p !b1 !b2 !ary1 !ary2 = runST $ do | testBit $ b1 .&. b2 = do x1 <- A.indexM ary1 i1 x2 <- A.indexM ary2 i2 - let - -- !_ = traceId $ "i1: " ++ show i1 - -- !_ = traceId $ "i2: " ++ show i2 - -- !_ = traceId $ "x1: " ++ show x1 - -- !_ = traceId $ "x2: " ++ show x2 - -- !_ = traceId $ "writing " ++ show (f x1 x2) let !x = f x1 x2 if p x then do @@ -1883,7 +1862,7 @@ intersectionArrayByFilter f p !b1 !b2 !ary1 !ary2 = runST $ do pure (bFinal, ary) {-# INLINE intersectionArrayByFilter #-} -intersectionUnorderedArrayWithKey :: (Show k, Show v1, Show v2, Show v3, Eq k) => (k -> v1 -> v2 -> (# v3 #)) -> A.Array (Leaf k v1) -> A.Array (Leaf k v2) -> A.Array (Leaf k v3) +intersectionUnorderedArrayWithKey :: (Eq k) => (k -> v1 -> v2 -> (# v3 #)) -> A.Array (Leaf k v1) -> A.Array (Leaf k v2) -> A.Array (Leaf k v3) intersectionUnorderedArrayWithKey f ary1 ary2 = A.run $ do mary2 <- A.thaw ary2 0 $ A.length ary2 mary <- A.new_ $ A.length ary1 + A.length ary2 @@ -1893,12 +1872,10 @@ intersectionUnorderedArrayWithKey f ary1 ary2 = A.run $ do L k1 v1 <- A.indexM ary1 i searchSwap k1 j mary2 >>= \case Just (L _k2 v2) -> do - -- let !_ = traceId $ "found " ++ show k1 let !(# v3 #) = f k1 v1 v2 A.write mary j $ L k1 v3 go (i + 1) (j + 1) Nothing -> do - -- let !_ = traceId $ "did not find " ++ show k1 go (i + 1) j maryLen <- go 0 0 A.shrink mary maryLen @@ -1917,6 +1894,7 @@ searchSwap toFind start = go start toFind start A.write mary i =<< A.read mary i0 pure $ Just l else go i0 k (i + 1) mary +{-# INLINE searchSwap #-} ------------------------------------------------------------------------ diff --git a/tests/Properties/HashMapLazy.hs b/tests/Properties/HashMapLazy.hs index 52cb8862..1e32a7e0 100644 --- a/tests/Properties/HashMapLazy.hs +++ b/tests/Properties/HashMapLazy.hs @@ -329,14 +329,7 @@ pDifferenceWith xs ys = M.differenceWith f (M.fromList xs) `eq_` f x y = if x == 0 then Nothing else Just (x - y) pIntersection :: [(Key, Int)] -> [(Key, Int)] -> Bool -pIntersection xs ys = do - let - res' = M.intersection (M.fromList xs) (M.fromList ys) - res = HM.intersection (HM.fromList xs) (HM.fromList ys) - -- !_ = traceId $ "res': " ++ show res' - -- !_ = traceId $ "res: " ++ show res - -- !_ = traceId $ "xs: " ++ show (HM.fromList xs) - -- !_ = traceId $ "ys: " ++ show (HM.fromList ys) +pIntersection xs ys = M.intersection (M.fromList xs) `eq_` HM.intersection (HM.fromList xs) $ ys From bcc13fccee9fead0b45baa5bc2b3823dc92c94dc Mon Sep 17 00:00:00 2001 From: Brian Shu Date: Sat, 9 Apr 2022 13:24:25 -0400 Subject: [PATCH 03/28] add show back --- Data/HashMap/Internal.hs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/Data/HashMap/Internal.hs b/Data/HashMap/Internal.hs index 942a7f64..e5684f87 100644 --- a/Data/HashMap/Internal.hs +++ b/Data/HashMap/Internal.hs @@ -210,7 +210,6 @@ data HashMap k v | Leaf !Hash !(Leaf k v) | Full !(A.Array (HashMap k v)) | Collision !Hash !(A.Array (Leaf k v)) - deriving (Show) type role HashMap nominal representational @@ -338,9 +337,9 @@ instance (Eq k, Hashable k, Read k, Read e) => Read (HashMap k e) where readListPrec = readListPrecDefault --- instance (Show k, Show v) => Show (HashMap k v) where --- showsPrec d m = showParen (d > 10) $ --- showString "fromList " . shows (toList m) +instance (Show k, Show v) => Show (HashMap k v) where + showsPrec d m = showParen (d > 10) $ + showString "fromList " . shows (toList m) instance Traversable (HashMap k) where traverse f = traverseWithKey (const f) From d5262bf3574a72189ccd19f0b261304ce94161e9 Mon Sep 17 00:00:00 2001 From: Brian Shu Date: Sat, 9 Apr 2022 13:45:52 -0400 Subject: [PATCH 04/28] inline --- Data/HashMap/Internal.hs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Data/HashMap/Internal.hs b/Data/HashMap/Internal.hs index e5684f87..b0fba79c 100644 --- a/Data/HashMap/Internal.hs +++ b/Data/HashMap/Internal.hs @@ -178,7 +178,7 @@ hash :: H.Hashable a => a -> Hash hash = fromIntegral . H.hash data Leaf k v = L !k v - deriving (Show, Eq) + deriving (Eq) instance (NFData k, NFData v) => NFData (Leaf k v) where rnf (L k v) = rnf k `seq` rnf v @@ -1775,7 +1775,7 @@ intersection = intersectionWith const -- maps. intersectionWith :: (Eq k, Hashable k) => (v1 -> v2 -> v3) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3 intersectionWith f = intersectionWithKey $ const f -{-# INLINABLE intersectionWith #-} +{-# INLINE intersectionWith #-} -- | /O(n*log m)/ Intersection of two maps. If a key occurs in both maps -- the provided function is used to combine the values from the two @@ -1827,7 +1827,7 @@ intersectionWithKey f = go 0 normalize b ary | A.length ary == 0 = Empty | otherwise = bitmapIndexedOrFull b ary -{-# INLINABLE intersectionWithKey #-} +{-# INLINE intersectionWithKey #-} intersectionArrayBy :: (v1 -> v2 -> HashMap k v) -> Bitmap -> Bitmap -> A.Array v1 -> A.Array v2 -> (Bitmap, A.Array (HashMap k v)) intersectionArrayBy f = intersectionArrayByFilter f $ \case Empty -> False; _ -> True From a16456b107c8eeb20a3a9ee35c7d6033c2c74250 Mon Sep 17 00:00:00 2001 From: Brian Shu Date: Sat, 9 Apr 2022 14:19:18 -0400 Subject: [PATCH 05/28] debug checks --- Data/HashMap/Internal.hs | 13 +++++++++---- Data/HashMap/Internal/Array.hs | 6 +++++- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/Data/HashMap/Internal.hs b/Data/HashMap/Internal.hs index b0fba79c..f30d55b9 100644 --- a/Data/HashMap/Internal.hs +++ b/Data/HashMap/Internal.hs @@ -1820,7 +1820,10 @@ intersectionWithKey f = go 0 where i = index h1 s - intersectionArray s b1 b2 ary1 ary2 = normalize b ary + intersectionArray s b1 b2 ary1 ary2 + -- don't create an array of size zero in intersectionArrayBy + | b1 .&. b2 == 0 = Empty + | otherwise = normalize b ary where (b, ary) = intersectionArrayBy (go (s + bitsPerSubkey)) b1 b2 ary1 ary2 @@ -1835,8 +1838,7 @@ intersectionArrayBy f = intersectionArrayByFilter f $ \case Empty -> False; _ -> intersectionArrayByFilter :: (v1 -> v2 -> v3) -> (v3 -> Bool) -> Bitmap -> Bitmap -> A.Array v1 -> A.Array v2 -> (Bitmap, A.Array v3) intersectionArrayByFilter f p !b1 !b2 !ary1 !ary2 = runST $ do - let bCombined = b1 .|. b2 - mary <- A.new_ $ popCount $ b1 .&. b2 + mary <- A.new_ $ popCount bIntersect -- iterate over nonzero bits of b1 .&. b2 let go !i !i1 !i2 !b !bFinal | b == 0 = pure (i, bFinal) @@ -1855,10 +1857,13 @@ intersectionArrayByFilter f p !b1 !b2 !ary1 !ary2 = runST $ do m = 1 `unsafeShiftL` countTrailingZeros b testBit x = x .&. m /= 0 b' = b .&. complement m - (maryLen, bFinal) <- go 0 0 0 bCombined (b1 .&. b2) + (maryLen, bFinal) <- go 0 0 0 bCombined bIntersect A.shrink mary maryLen ary <- A.unsafeFreeze mary pure (bFinal, ary) + where + bCombined = b1 .|. b2 + bIntersect = b1 .&. b2 {-# INLINE intersectionArrayByFilter #-} intersectionUnorderedArrayWithKey :: (Eq k) => (k -> v1 -> v2 -> (# v3 #)) -> A.Array (Leaf k v1) -> A.Array (Leaf k v2) -> A.Array (Leaf k v3) diff --git a/Data/HashMap/Internal/Array.hs b/Data/HashMap/Internal/Array.hs index 5506fa8b..c63d170b 100644 --- a/Data/HashMap/Internal/Array.hs +++ b/Data/HashMap/Internal/Array.hs @@ -110,12 +110,14 @@ if (_k_) < 0 || (_k_) >= (_len_) then error ("Data.HashMap.Internal.Array." ++ ( # define CHECK_OP(_func_,_op_,_lhs_,_rhs_) \ if not ((_lhs_) _op_ (_rhs_)) then error ("Data.HashMap.Internal.Array." ++ (_func_) ++ ": Check failed: _lhs_ _op_ _rhs_ (" ++ show (_lhs_) ++ " vs. " ++ show (_rhs_) ++ ")") else # define CHECK_GT(_func_,_lhs_,_rhs_) CHECK_OP(_func_,>,_lhs_,_rhs_) +# define CHECK_GE(_func_,_lhs_,_rhs_) CHECK_OP(_func_,>=,_lhs_,_rhs_) # define CHECK_LE(_func_,_lhs_,_rhs_) CHECK_OP(_func_,<=,_lhs_,_rhs_) # define CHECK_EQ(_func_,_lhs_,_rhs_) CHECK_OP(_func_,==,_lhs_,_rhs_) #else # define CHECK_BOUNDS(_func_,_len_,_k_) # define CHECK_OP(_func_,_op_,_lhs_,_rhs_) # define CHECK_GT(_func_,_lhs_,_rhs_) +# define CHECK_GE(_func_,_lhs_,_rhs_) # define CHECK_LE(_func_,_lhs_,_rhs_) # define CHECK_EQ(_func_,_lhs_,_rhs_) #endif @@ -206,7 +208,9 @@ new_ :: Int -> ST s (MArray s a) new_ n = new n undefinedElem shrink :: MArray s a -> Int -> ST s () -shrink mary (I# n#) = +shrink mary _n@(I# n#) = + CHECK_GE("shrink", _n, (0 :: Int)) + CHECK_LE("shrink", _n, (lengthM mary)) ST $ \s -> case shrinkSmallMutableArray# (unMArray mary) n# s of s' -> (# s', () #) {-# INLINE shrink #-} From f72011c999ceb3b37d3192fd0b32d2af19fee47e Mon Sep 17 00:00:00 2001 From: Brian Shu Date: Sat, 9 Apr 2022 15:49:14 -0400 Subject: [PATCH 06/28] inline function --- Data/HashMap/Internal.hs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/Data/HashMap/Internal.hs b/Data/HashMap/Internal.hs index f30d55b9..f1421030 100644 --- a/Data/HashMap/Internal.hs +++ b/Data/HashMap/Internal.hs @@ -154,7 +154,7 @@ import Data.Hashable (Hashable) import Data.Hashable.Lifted (Hashable1, Hashable2) import Data.HashMap.Internal.List (isPermutationBy, unorderedCompare) import Data.Semigroup (Semigroup (..), stimesIdempotentMonoid) -import GHC.Exts (Int (..), Int#, TYPE, (==#)) +import GHC.Exts (Int (..), Int#, TYPE, (==#), inline) import GHC.Stack (HasCallStack) import Prelude hiding (filter, foldl, foldr, lookup, map, null, pred) @@ -1767,15 +1767,15 @@ differenceWith f a b = foldlWithKey' go empty a -- | /O(n*log m)/ Intersection of two maps. Return elements of the first -- map for keys existing in the second. intersection :: (Eq k, Hashable k) => HashMap k v -> HashMap k w -> HashMap k v -intersection = intersectionWith const +intersection = inline intersectionWith const {-# INLINABLE intersection #-} -- | /O(n*log m)/ Intersection of two maps. If a key occurs in both maps -- the provided function is used to combine the values from the two -- maps. intersectionWith :: (Eq k, Hashable k) => (v1 -> v2 -> v3) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3 -intersectionWith f = intersectionWithKey $ const f -{-# INLINE intersectionWith #-} +intersectionWith f = inline intersectionWithKey $ const f +{-# INLINABLE intersectionWith #-} -- | /O(n*log m)/ Intersection of two maps. If a key occurs in both maps -- the provided function is used to combine the values from the two @@ -1830,7 +1830,7 @@ intersectionWithKey f = go 0 normalize b ary | A.length ary == 0 = Empty | otherwise = bitmapIndexedOrFull b ary -{-# INLINE intersectionWithKey #-} +{-# INLINABLE intersectionWithKey #-} intersectionArrayBy :: (v1 -> v2 -> HashMap k v) -> Bitmap -> Bitmap -> A.Array v1 -> A.Array v2 -> (Bitmap, A.Array (HashMap k v)) intersectionArrayBy f = intersectionArrayByFilter f $ \case Empty -> False; _ -> True From 678a38c1a22aac474d6e68baa53cec9db2881537 Mon Sep 17 00:00:00 2001 From: Brian Shu Date: Sat, 9 Apr 2022 13:33:22 -0400 Subject: [PATCH 07/28] refactor to use snoc --- Data/HashMap/Internal.hs | 26 +++++--------------------- Data/HashMap/Internal/Array.hs | 10 ++++++++++ 2 files changed, 15 insertions(+), 21 deletions(-) diff --git a/Data/HashMap/Internal.hs b/Data/HashMap/Internal.hs index f1421030..d0964757 100644 --- a/Data/HashMap/Internal.hs +++ b/Data/HashMap/Internal.hs @@ -819,17 +819,9 @@ insertNewKey !h0 !k0 x0 !m0 = go h0 k0 x0 0 m0 in Full (update32 ary i st') where i = index h s go h k x s t@(Collision hy v) - | h == hy = Collision h (snocNewLeaf (L k x) v) + | h == hy = Collision h (A.snoc v (L k x)) | otherwise = go h k x s $ BitmapIndexed (mask hy s) (A.singleton t) - where - snocNewLeaf :: Leaf k v -> A.Array (Leaf k v) -> A.Array (Leaf k v) - snocNewLeaf leaf ary = A.run $ do - let n = A.length ary - mary <- A.new_ (n + 1) - A.copy ary 0 mary 0 n - A.write mary n leaf - return mary {-# NOINLINE insertNewKey #-} @@ -1008,12 +1000,8 @@ insertModifyingArr :: Eq k => v -> (v -> (# v #)) -> k -> A.Array (Leaf k v) insertModifyingArr x f k0 ary0 = go k0 ary0 0 (A.length ary0) where go !k !ary !i !n - | i >= n = A.run $ do - -- Not found, append to the end. - mary <- A.new_ (n + 1) - A.copy ary 0 mary 0 n - A.write mary n (L k x) - return mary + -- Not found, append to the end. + | i >= n = A.snoc ary $ L k x | otherwise = case A.index ary i of (L kx y) | k == kx -> case f y of (# y' #) -> if ptrEq y y' @@ -2263,12 +2251,8 @@ updateOrSnocWithKey :: Eq k => (k -> v -> v -> (# v #)) -> k -> v -> A.Array (Le updateOrSnocWithKey f k0 v0 ary0 = go k0 v0 ary0 0 (A.length ary0) where go !k v !ary !i !n - | i >= n = A.run $ do - -- Not found, append to the end. - mary <- A.new_ (n + 1) - A.copy ary 0 mary 0 n - A.write mary n (L k v) - return mary + -- Not found, append to the end. + | i >= n = A.snoc ary $ L k v | L kx y <- A.index ary i , k == kx , (# v2 #) <- f k v y diff --git a/Data/HashMap/Internal/Array.hs b/Data/HashMap/Internal/Array.hs index c63d170b..f23e5e34 100644 --- a/Data/HashMap/Internal/Array.hs +++ b/Data/HashMap/Internal/Array.hs @@ -34,6 +34,7 @@ module Data.HashMap.Internal.Array , new_ , singleton , singletonM + , snoc , pair -- * Basic interface @@ -223,6 +224,15 @@ singletonM :: a -> ST s (Array a) singletonM x = new 1 x >>= unsafeFreeze {-# INLINE singletonM #-} +snoc :: Array a -> a -> Array a +snoc ary x = run $ do + mary <- new (n + 1) x + copy ary 0 mary 0 n + pure mary + where + n = length ary +{-# INLINE snoc #-} + pair :: a -> a -> Array a pair x y = run $ do ary <- new 2 x From ec242154bf4b5d5b7a776878e3decfb14642c5d4 Mon Sep 17 00:00:00 2001 From: David Feuer Date: Sat, 9 Apr 2022 16:54:38 -0400 Subject: [PATCH 08/28] Try the unboxed result thing This one inlines the unboxed form into everything else, hopefully. --- Data/HashMap/Internal.hs | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/Data/HashMap/Internal.hs b/Data/HashMap/Internal.hs index d0964757..2c2260b6 100644 --- a/Data/HashMap/Internal.hs +++ b/Data/HashMap/Internal.hs @@ -1769,20 +1769,24 @@ intersectionWith f = inline intersectionWithKey $ const f -- the provided function is used to combine the values from the two -- maps. intersectionWithKey :: (Eq k, Hashable k) => (k -> v1 -> v2 -> v3) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3 -intersectionWithKey f = go 0 +intersectionWithKey f = intersectionWithKey# (\k v1 v2 -> (# f k v1 v2 #)) +{-# INLINABLE intersectionWithKey #-} + +intersectionWithKey# :: (Eq k, Hashable k) => (k -> v1 -> v2 -> (# v3 #)) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3 +intersectionWithKey# f = go 0 where -- empty vs. anything go !_ _ Empty = Empty go _ Empty _ = Empty -- leaf vs. anything - go s (Leaf h1 (L k1 v1)) t2 = lookupCont (\_ -> Empty) (\v _ -> Leaf h1 $ L k1 $ f k1 v1 v) h1 k1 s t2 - go s t1 (Leaf h2 (L k2 v2)) = lookupCont (\_ -> Empty) (\v _ -> Leaf h2 $ L k2 $ f k2 v v2) h2 k2 s t1 + go s (Leaf h1 (L k1 v1)) t2 = lookupCont (\_ -> Empty) (\v _ -> case f k1 v1 v of (# v' #) -> Leaf h1 $ L k1 v') h1 k1 s t2 + go s t1 (Leaf h2 (L k2 v2)) = lookupCont (\_ -> Empty) (\v _ -> case f k2 v v2 of (# v' #) -> Leaf h2 $ L k2 v') h2 k2 s t1 -- collision vs. collision go _ (Collision h1 ls1) (Collision h2 ls2) | h1 == h2 = if A.length ls == 0 then Empty else Collision h1 ls | otherwise = Empty where - ls = intersectionUnorderedArrayWithKey (\k v1 v2 -> (# f k v1 v2 #)) ls1 ls2 + ls = intersectionUnorderedArrayWithKey f ls1 ls2 -- branch vs. branch go s (BitmapIndexed b1 ary1) (BitmapIndexed b2 ary2) = intersectionArray s b1 b2 ary1 ary2 go s (BitmapIndexed b1 ary1) (Full ary2) = intersectionArray s b1 fullNodeMask ary1 ary2 @@ -1818,7 +1822,7 @@ intersectionWithKey f = go 0 normalize b ary | A.length ary == 0 = Empty | otherwise = bitmapIndexedOrFull b ary -{-# INLINABLE intersectionWithKey #-} +{-# INLINE intersectionWithKey# #-} intersectionArrayBy :: (v1 -> v2 -> HashMap k v) -> Bitmap -> Bitmap -> A.Array v1 -> A.Array v2 -> (Bitmap, A.Array (HashMap k v)) intersectionArrayBy f = intersectionArrayByFilter f $ \case Empty -> False; _ -> True From 767ae6e8066f6a2a1596e7cfae6a4e13300e78ec Mon Sep 17 00:00:00 2001 From: David Feuer Date: Sat, 9 Apr 2022 17:03:41 -0400 Subject: [PATCH 09/28] Remove redundant internal constraint --- Data/HashMap/Internal.hs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Data/HashMap/Internal.hs b/Data/HashMap/Internal.hs index 2c2260b6..3c36f1a3 100644 --- a/Data/HashMap/Internal.hs +++ b/Data/HashMap/Internal.hs @@ -1768,11 +1768,11 @@ intersectionWith f = inline intersectionWithKey $ const f -- | /O(n*log m)/ Intersection of two maps. If a key occurs in both maps -- the provided function is used to combine the values from the two -- maps. -intersectionWithKey :: (Eq k, Hashable k) => (k -> v1 -> v2 -> v3) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3 +intersectionWithKey :: (Eq k, Hashable k) => (k -> v1 -> v2 -> v3) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3 intersectionWithKey f = intersectionWithKey# (\k v1 v2 -> (# f k v1 v2 #)) {-# INLINABLE intersectionWithKey #-} -intersectionWithKey# :: (Eq k, Hashable k) => (k -> v1 -> v2 -> (# v3 #)) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3 +intersectionWithKey# :: Eq k => (k -> v1 -> v2 -> (# v3 #)) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3 intersectionWithKey# f = go 0 where -- empty vs. anything From fd43ba7c920b0136b8c9d0b6b6bd1368570fa6b2 Mon Sep 17 00:00:00 2001 From: Brian Shu Date: Sat, 9 Apr 2022 18:04:31 -0400 Subject: [PATCH 10/28] shrink compat --- Data/HashMap/Internal.hs | 4 +--- Data/HashMap/Internal/Array.hs | 15 ++++++++++++--- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/Data/HashMap/Internal.hs b/Data/HashMap/Internal.hs index 3c36f1a3..25c28cfc 100644 --- a/Data/HashMap/Internal.hs +++ b/Data/HashMap/Internal.hs @@ -1850,8 +1850,7 @@ intersectionArrayByFilter f p !b1 !b2 !ary1 !ary2 = runST $ do testBit x = x .&. m /= 0 b' = b .&. complement m (maryLen, bFinal) <- go 0 0 0 bCombined bIntersect - A.shrink mary maryLen - ary <- A.unsafeFreeze mary + ary <- A.unsafeFreeze =<< A.shrink mary maryLen pure (bFinal, ary) where bCombined = b1 .|. b2 @@ -1875,7 +1874,6 @@ intersectionUnorderedArrayWithKey f ary1 ary2 = A.run $ do go (i + 1) j maryLen <- go 0 0 A.shrink mary maryLen - pure mary {-# INLINABLE intersectionUnorderedArrayWithKey #-} searchSwap :: Eq k => k -> Int -> A.MArray s (Leaf k v) -> ST s (Maybe (Leaf k v)) diff --git a/Data/HashMap/Internal/Array.hs b/Data/HashMap/Internal/Array.hs index f23e5e34..b9af18a7 100644 --- a/Data/HashMap/Internal/Array.hs +++ b/Data/HashMap/Internal/Array.hs @@ -208,14 +208,23 @@ new _n@(I# n#) b = new_ :: Int -> ST s (MArray s a) new_ n = new n undefinedElem -shrink :: MArray s a -> Int -> ST s () +-- when shrinkSmallMutableArray# is available, the returned array is the same as the array given, as it is shrunk in place +-- otherwise a copy is made +shrink :: MArray s a -> Int -> ST s (MArray s a) +#if MIN_VERSION_GLASGOW_HASKELL(8, 10, 7, 0) shrink mary _n@(I# n#) = CHECK_GE("shrink", _n, (0 :: Int)) CHECK_LE("shrink", _n, (lengthM mary)) ST $ \s -> case shrinkSmallMutableArray# (unMArray mary) n# s of - s' -> (# s', () #) + s' -> (# s', mary #) +#else +shrink mary n = do + mary' <- new_ n + copyM mary 0 mary' 0 n + pure mary' +#endif {-# INLINE shrink #-} - + singleton :: a -> Array a singleton x = runST (singletonM x) {-# INLINE singleton #-} From 3612645fc05d90080524a7510fddc73e91e05f7e Mon Sep 17 00:00:00 2001 From: Brian Shu Date: Sat, 9 Apr 2022 18:09:24 -0400 Subject: [PATCH 11/28] fix import --- Data/HashMap/Internal/Array.hs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/Data/HashMap/Internal/Array.hs b/Data/HashMap/Internal/Array.hs index b9af18a7..157f915d 100644 --- a/Data/HashMap/Internal/Array.hs +++ b/Data/HashMap/Internal/Array.hs @@ -92,7 +92,8 @@ import GHC.Exts (Int (..), SmallArray#, SmallMutableArray#, sizeofSmallMutableArray#, tagToEnum#, thawSmallArray#, unsafeCoerce#, unsafeFreezeSmallArray#, unsafeThawSmallArray#, - writeSmallArray#, shrinkSmallMutableArray#) + writeSmallArray#) +import qualified GHC.Exts as Exts import GHC.ST (ST (..)) import Prelude hiding (all, filter, foldMap, foldl, foldr, length, map, read, traverse) @@ -215,7 +216,7 @@ shrink :: MArray s a -> Int -> ST s (MArray s a) shrink mary _n@(I# n#) = CHECK_GE("shrink", _n, (0 :: Int)) CHECK_LE("shrink", _n, (lengthM mary)) - ST $ \s -> case shrinkSmallMutableArray# (unMArray mary) n# s of + ST $ \s -> case Exts.shrinkSmallMutableArray# (unMArray mary) n# s of s' -> (# s', mary #) #else shrink mary n = do From b484042f511aea54705850fce702f7ca740d3490 Mon Sep 17 00:00:00 2001 From: Brian Shu Date: Sat, 9 Apr 2022 18:21:01 -0400 Subject: [PATCH 12/28] use clone --- Data/HashMap/Internal/Array.hs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/Data/HashMap/Internal/Array.hs b/Data/HashMap/Internal/Array.hs index 157f915d..eea41df1 100644 --- a/Data/HashMap/Internal/Array.hs +++ b/Data/HashMap/Internal/Array.hs @@ -219,10 +219,7 @@ shrink mary _n@(I# n#) = ST $ \s -> case Exts.shrinkSmallMutableArray# (unMArray mary) n# s of s' -> (# s', mary #) #else -shrink mary n = do - mary' <- new_ n - copyM mary 0 mary' 0 n - pure mary' +shrink mary n = cloneM marr 0 n #endif {-# INLINE shrink #-} From 9e48bc0435bf4fcd2477130b64a1f0dfbc38851e Mon Sep 17 00:00:00 2001 From: Brian Shu Date: Sat, 9 Apr 2022 18:23:24 -0400 Subject: [PATCH 13/28] oof --- Data/HashMap/Internal/Array.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Data/HashMap/Internal/Array.hs b/Data/HashMap/Internal/Array.hs index eea41df1..02ce7741 100644 --- a/Data/HashMap/Internal/Array.hs +++ b/Data/HashMap/Internal/Array.hs @@ -219,7 +219,7 @@ shrink mary _n@(I# n#) = ST $ \s -> case Exts.shrinkSmallMutableArray# (unMArray mary) n# s of s' -> (# s', mary #) #else -shrink mary n = cloneM marr 0 n +shrink mary n = cloneM mary 0 n #endif {-# INLINE shrink #-} From 48119cb9593ed12b4ea838311f4163a0e568df85 Mon Sep 17 00:00:00 2001 From: Brian Shu Date: Sat, 9 Apr 2022 19:14:14 -0400 Subject: [PATCH 14/28] don't shrink to zero --- Data/HashMap/Internal.hs | 35 +++++++++++++++++----------------- Data/HashMap/Internal/Array.hs | 4 +--- 2 files changed, 19 insertions(+), 20 deletions(-) diff --git a/Data/HashMap/Internal.hs b/Data/HashMap/Internal.hs index 25c28cfc..0e768c79 100644 --- a/Data/HashMap/Internal.hs +++ b/Data/HashMap/Internal.hs @@ -1783,10 +1783,12 @@ intersectionWithKey# f = go 0 go s t1 (Leaf h2 (L k2 v2)) = lookupCont (\_ -> Empty) (\v _ -> case f k2 v v2 of (# v' #) -> Leaf h2 $ L k2 v') h2 k2 s t1 -- collision vs. collision go _ (Collision h1 ls1) (Collision h2 ls2) - | h1 == h2 = if A.length ls == 0 then Empty else Collision h1 ls + | h1 == h2 = runST $ do + (len, mary) <- intersectionUnorderedArrayWithKey f ls1 ls2 + if len == 0 + then pure Empty + else Collision h1 <$> (A.unsafeFreeze =<< A.shrink mary len) | otherwise = Empty - where - ls = intersectionUnorderedArrayWithKey f ls1 ls2 -- branch vs. branch go s (BitmapIndexed b1 ary1) (BitmapIndexed b2 ary2) = intersectionArray s b1 b2 ary1 ary2 go s (BitmapIndexed b1 ary1) (Full ary2) = intersectionArray s b1 fullNodeMask ary1 ary2 @@ -1815,21 +1817,21 @@ intersectionWithKey# f = go 0 intersectionArray s b1 b2 ary1 ary2 -- don't create an array of size zero in intersectionArrayBy | b1 .&. b2 == 0 = Empty - | otherwise = normalize b ary - where - (b, ary) = intersectionArrayBy (go (s + bitsPerSubkey)) b1 b2 ary1 ary2 + | otherwise = runST $ do + (b, len, ary) <- intersectionArrayBy (go (s + bitsPerSubkey)) b1 b2 ary1 ary2 + -- don't shrink an array of length 0 + if len == 0 + then pure Empty + else bitmapIndexedOrFull b <$> (A.unsafeFreeze =<< A.shrink ary len) - normalize b ary - | A.length ary == 0 = Empty - | otherwise = bitmapIndexedOrFull b ary {-# INLINE intersectionWithKey# #-} -intersectionArrayBy :: (v1 -> v2 -> HashMap k v) -> Bitmap -> Bitmap -> A.Array v1 -> A.Array v2 -> (Bitmap, A.Array (HashMap k v)) +intersectionArrayBy :: (v1 -> v2 -> HashMap k v) -> Bitmap -> Bitmap -> A.Array v1 -> A.Array v2 -> ST s (Bitmap, Int, A.MArray s (HashMap k v)) intersectionArrayBy f = intersectionArrayByFilter f $ \case Empty -> False; _ -> True {-# INLINE intersectionArrayBy #-} -intersectionArrayByFilter :: (v1 -> v2 -> v3) -> (v3 -> Bool) -> Bitmap -> Bitmap -> A.Array v1 -> A.Array v2 -> (Bitmap, A.Array v3) -intersectionArrayByFilter f p !b1 !b2 !ary1 !ary2 = runST $ do +intersectionArrayByFilter :: (v1 -> v2 -> v3) -> (v3 -> Bool) -> Bitmap -> Bitmap -> A.Array v1 -> A.Array v2 -> ST s (Bitmap, Int, A.MArray s v3) +intersectionArrayByFilter f p !b1 !b2 !ary1 !ary2 = do mary <- A.new_ $ popCount bIntersect -- iterate over nonzero bits of b1 .&. b2 let go !i !i1 !i2 !b !bFinal @@ -1850,15 +1852,14 @@ intersectionArrayByFilter f p !b1 !b2 !ary1 !ary2 = runST $ do testBit x = x .&. m /= 0 b' = b .&. complement m (maryLen, bFinal) <- go 0 0 0 bCombined bIntersect - ary <- A.unsafeFreeze =<< A.shrink mary maryLen - pure (bFinal, ary) + pure (bFinal, maryLen, mary) where bCombined = b1 .|. b2 bIntersect = b1 .&. b2 {-# INLINE intersectionArrayByFilter #-} -intersectionUnorderedArrayWithKey :: (Eq k) => (k -> v1 -> v2 -> (# v3 #)) -> A.Array (Leaf k v1) -> A.Array (Leaf k v2) -> A.Array (Leaf k v3) -intersectionUnorderedArrayWithKey f ary1 ary2 = A.run $ do +intersectionUnorderedArrayWithKey :: (Eq k) => (k -> v1 -> v2 -> (# v3 #)) -> A.Array (Leaf k v1) -> A.Array (Leaf k v2) -> ST s (Int, A.MArray s (Leaf k v3)) +intersectionUnorderedArrayWithKey f ary1 ary2 = do mary2 <- A.thaw ary2 0 $ A.length ary2 mary <- A.new_ $ A.length ary1 + A.length ary2 let go i j @@ -1873,7 +1874,7 @@ intersectionUnorderedArrayWithKey f ary1 ary2 = A.run $ do Nothing -> do go (i + 1) j maryLen <- go 0 0 - A.shrink mary maryLen + pure (maryLen, mary) {-# INLINABLE intersectionUnorderedArrayWithKey #-} searchSwap :: Eq k => k -> Int -> A.MArray s (Leaf k v) -> ST s (Maybe (Leaf k v)) diff --git a/Data/HashMap/Internal/Array.hs b/Data/HashMap/Internal/Array.hs index 02ce7741..d63c03f8 100644 --- a/Data/HashMap/Internal/Array.hs +++ b/Data/HashMap/Internal/Array.hs @@ -112,14 +112,12 @@ if (_k_) < 0 || (_k_) >= (_len_) then error ("Data.HashMap.Internal.Array." ++ ( # define CHECK_OP(_func_,_op_,_lhs_,_rhs_) \ if not ((_lhs_) _op_ (_rhs_)) then error ("Data.HashMap.Internal.Array." ++ (_func_) ++ ": Check failed: _lhs_ _op_ _rhs_ (" ++ show (_lhs_) ++ " vs. " ++ show (_rhs_) ++ ")") else # define CHECK_GT(_func_,_lhs_,_rhs_) CHECK_OP(_func_,>,_lhs_,_rhs_) -# define CHECK_GE(_func_,_lhs_,_rhs_) CHECK_OP(_func_,>=,_lhs_,_rhs_) # define CHECK_LE(_func_,_lhs_,_rhs_) CHECK_OP(_func_,<=,_lhs_,_rhs_) # define CHECK_EQ(_func_,_lhs_,_rhs_) CHECK_OP(_func_,==,_lhs_,_rhs_) #else # define CHECK_BOUNDS(_func_,_len_,_k_) # define CHECK_OP(_func_,_op_,_lhs_,_rhs_) # define CHECK_GT(_func_,_lhs_,_rhs_) -# define CHECK_GE(_func_,_lhs_,_rhs_) # define CHECK_LE(_func_,_lhs_,_rhs_) # define CHECK_EQ(_func_,_lhs_,_rhs_) #endif @@ -214,7 +212,7 @@ new_ n = new n undefinedElem shrink :: MArray s a -> Int -> ST s (MArray s a) #if MIN_VERSION_GLASGOW_HASKELL(8, 10, 7, 0) shrink mary _n@(I# n#) = - CHECK_GE("shrink", _n, (0 :: Int)) + CHECK_GT("shrink", _n, (0 :: Int)) CHECK_LE("shrink", _n, (lengthM mary)) ST $ \s -> case Exts.shrinkSmallMutableArray# (unMArray mary) n# s of s' -> (# s', mary #) From d9d295dc959b51e29dc7bccf73c2add64d5bb685 Mon Sep 17 00:00:00 2001 From: Brian Shu Date: Sat, 9 Apr 2022 19:35:55 -0400 Subject: [PATCH 15/28] Leaf special case --- Data/HashMap/Internal.hs | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/Data/HashMap/Internal.hs b/Data/HashMap/Internal.hs index 0e768c79..7addc8cf 100644 --- a/Data/HashMap/Internal.hs +++ b/Data/HashMap/Internal.hs @@ -1785,9 +1785,10 @@ intersectionWithKey# f = go 0 go _ (Collision h1 ls1) (Collision h2 ls2) | h1 == h2 = runST $ do (len, mary) <- intersectionUnorderedArrayWithKey f ls1 ls2 - if len == 0 - then pure Empty - else Collision h1 <$> (A.unsafeFreeze =<< A.shrink mary len) + case len of + 0 -> pure Empty + 1 -> Leaf h1 <$> A.read mary 0 + _ -> Collision h1 <$> (A.unsafeFreeze =<< A.shrink mary len) | otherwise = Empty -- branch vs. branch go s (BitmapIndexed b1 ary1) (BitmapIndexed b2 ary2) = intersectionArray s b1 b2 ary1 ary2 @@ -1819,10 +1820,10 @@ intersectionWithKey# f = go 0 | b1 .&. b2 == 0 = Empty | otherwise = runST $ do (b, len, ary) <- intersectionArrayBy (go (s + bitsPerSubkey)) b1 b2 ary1 ary2 - -- don't shrink an array of length 0 - if len == 0 - then pure Empty - else bitmapIndexedOrFull b <$> (A.unsafeFreeze =<< A.shrink ary len) + case len of + 0 -> pure Empty + 1 -> A.read ary 0 + _ -> bitmapIndexedOrFull b <$> (A.unsafeFreeze =<< A.shrink ary len) {-# INLINE intersectionWithKey# #-} From 88a9c2c946323bdfcfed5fa35fceb8944e78fcb4 Mon Sep 17 00:00:00 2001 From: Brian Shu Date: Sun, 10 Apr 2022 12:26:02 -0400 Subject: [PATCH 16/28] add strict verisons --- Data/HashMap/Internal.hs | 38 ++++++++++++++++++--------------- Data/HashMap/Internal/Strict.hs | 13 +++-------- 2 files changed, 24 insertions(+), 27 deletions(-) diff --git a/Data/HashMap/Internal.hs b/Data/HashMap/Internal.hs index 7addc8cf..c79b2883 100644 --- a/Data/HashMap/Internal.hs +++ b/Data/HashMap/Internal.hs @@ -78,6 +78,7 @@ module Data.HashMap.Internal , intersection , intersectionWith , intersectionWithKey + , intersectionWithKey# -- * Folds , foldr' @@ -1769,7 +1770,7 @@ intersectionWith f = inline intersectionWithKey $ const f -- the provided function is used to combine the values from the two -- maps. intersectionWithKey :: (Eq k, Hashable k) => (k -> v1 -> v2 -> v3) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3 -intersectionWithKey f = intersectionWithKey# (\k v1 v2 -> (# f k v1 v2 #)) +intersectionWithKey f = intersectionWithKey# $ \k v1 v2 -> (# f k v1 v2 #) {-# INLINABLE intersectionWithKey #-} intersectionWithKey# :: Eq k => (k -> v1 -> v2 -> (# v3 #)) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3 @@ -1794,7 +1795,7 @@ intersectionWithKey# f = go 0 go s (BitmapIndexed b1 ary1) (BitmapIndexed b2 ary2) = intersectionArray s b1 b2 ary1 ary2 go s (BitmapIndexed b1 ary1) (Full ary2) = intersectionArray s b1 fullNodeMask ary1 ary2 go s (Full ary1) (BitmapIndexed b2 ary2) = intersectionArray s fullNodeMask b2 ary1 ary2 - go s (Full ary1) (Full ary2) = intersectionArray s fullNodeMask fullNodeMask ary1 ary2 + go s (Full ary1) (Full ary2) = intersectionArray s fullNodeMask fullNodeMask ary1 ary2 -- collision vs. branch go s (BitmapIndexed b1 ary1) t2@(Collision h2 _ls2) | b1 .&. m2 == 0 = Empty @@ -1814,7 +1815,7 @@ intersectionWithKey# f = go 0 go s t1@(Collision h1 _ls1) (Full ary2) = go (s + bitsPerSubkey) t1 (A.index ary2 i) where i = index h1 s - + intersectionArray s b1 b2 ary1 ary2 -- don't create an array of size zero in intersectionArrayBy | b1 .&. b2 == 0 = Empty @@ -1824,15 +1825,19 @@ intersectionWithKey# f = go 0 0 -> pure Empty 1 -> A.read ary 0 _ -> bitmapIndexedOrFull b <$> (A.unsafeFreeze =<< A.shrink ary len) - {-# INLINE intersectionWithKey# #-} -intersectionArrayBy :: (v1 -> v2 -> HashMap k v) -> Bitmap -> Bitmap -> A.Array v1 -> A.Array v2 -> ST s (Bitmap, Int, A.MArray s (HashMap k v)) -intersectionArrayBy f = intersectionArrayByFilter f $ \case Empty -> False; _ -> True -{-# INLINE intersectionArrayBy #-} - -intersectionArrayByFilter :: (v1 -> v2 -> v3) -> (v3 -> Bool) -> Bitmap -> Bitmap -> A.Array v1 -> A.Array v2 -> ST s (Bitmap, Int, A.MArray s v3) -intersectionArrayByFilter f p !b1 !b2 !ary1 !ary2 = do +intersectionArrayBy :: + ( HashMap k v1 -> + HashMap k v2 -> + HashMap k v3 + ) -> + Bitmap -> + Bitmap -> + A.Array (HashMap k v1) -> + A.Array (HashMap k v2) -> + ST s (Bitmap, Int, A.MArray s (HashMap k v3)) +intersectionArrayBy f !b1 !b2 !ary1 !ary2 = do mary <- A.new_ $ popCount bIntersect -- iterate over nonzero bits of b1 .&. b2 let go !i !i1 !i2 !b !bFinal @@ -1840,12 +1845,11 @@ intersectionArrayByFilter f p !b1 !b2 !ary1 !ary2 = do | testBit $ b1 .&. b2 = do x1 <- A.indexM ary1 i1 x2 <- A.indexM ary2 i2 - let !x = f x1 x2 - if p x - then do + case f x1 x2 of + Empty -> go i (i1 + 1) (i2 + 1) b' (bFinal .&. complement m) + _ -> do A.write mary i $! f x1 x2 go (i + 1) (i1 + 1) (i2 + 1) b' bFinal - else go i (i1 + 1) (i2 + 1) b' (bFinal .&. complement m) | testBit b1 = go i (i1 + 1) i2 b' bFinal | otherwise = go i i1 (i2 + 1) b' bFinal where @@ -1857,9 +1861,9 @@ intersectionArrayByFilter f p !b1 !b2 !ary1 !ary2 = do where bCombined = b1 .|. b2 bIntersect = b1 .&. b2 -{-# INLINE intersectionArrayByFilter #-} +{-# INLINE intersectionArrayBy #-} -intersectionUnorderedArrayWithKey :: (Eq k) => (k -> v1 -> v2 -> (# v3 #)) -> A.Array (Leaf k v1) -> A.Array (Leaf k v2) -> ST s (Int, A.MArray s (Leaf k v3)) +intersectionUnorderedArrayWithKey :: Eq k => (k -> v1 -> v2 -> (# v3 #)) -> A.Array (Leaf k v1) -> A.Array (Leaf k v2) -> ST s (Int, A.MArray s (Leaf k v3)) intersectionUnorderedArrayWithKey f ary1 ary2 = do mary2 <- A.thaw ary2 0 $ A.length ary2 mary <- A.new_ $ A.length ary1 + A.length ary2 @@ -1876,7 +1880,7 @@ intersectionUnorderedArrayWithKey f ary1 ary2 = do go (i + 1) j maryLen <- go 0 0 pure (maryLen, mary) -{-# INLINABLE intersectionUnorderedArrayWithKey #-} +{-# INLINE intersectionUnorderedArrayWithKey #-} searchSwap :: Eq k => k -> Int -> A.MArray s (Leaf k v) -> ST s (Maybe (Leaf k v)) searchSwap toFind start = go start toFind start diff --git a/Data/HashMap/Internal/Strict.hs b/Data/HashMap/Internal/Strict.hs index e5bee6e6..a8b74dc8 100644 --- a/Data/HashMap/Internal/Strict.hs +++ b/Data/HashMap/Internal/Strict.hs @@ -138,6 +138,7 @@ import Prelude hiding (lookup, map) import qualified Data.HashMap.Internal as HM import qualified Data.HashMap.Internal.Array as A import qualified Data.List as List +import GHC.Exts (inline) {- Note [Imports from Data.HashMap.Internal] @@ -616,11 +617,7 @@ differenceWith f a b = HM.foldlWithKey' go HM.empty a -- maps. intersectionWith :: (Eq k, Hashable k) => (v1 -> v2 -> v3) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3 -intersectionWith f a b = HM.foldlWithKey' go HM.empty a - where - go m k v = case HM.lookup k b of - Just w -> let !x = f v w in HM.unsafeInsert k x m - _ -> m +intersectionWith f = inline intersectionWithKey $ const f {-# INLINABLE intersectionWith #-} -- | /O(n+m)/ Intersection of two maps. If a key occurs in both maps @@ -628,11 +625,7 @@ intersectionWith f a b = HM.foldlWithKey' go HM.empty a -- maps. intersectionWithKey :: (Eq k, Hashable k) => (k -> v1 -> v2 -> v3) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3 -intersectionWithKey f a b = HM.foldlWithKey' go HM.empty a - where - go m k v = case HM.lookup k b of - Just w -> let !x = f k v w in HM.unsafeInsert k x m - _ -> m +intersectionWithKey f = HM.intersectionWithKey# $ \k v1 v2 -> let !v3 = f k v1 v2 in (# v3 #) {-# INLINABLE intersectionWithKey #-} ------------------------------------------------------------------------ From b3cdbd8591b9f152a2596d416d7be8f83433d62c Mon Sep 17 00:00:00 2001 From: oberblastmeister <61095988+oberblastmeister@users.noreply.github.com> Date: Mon, 11 Apr 2022 19:25:45 -0400 Subject: [PATCH 17/28] Update Data/HashMap/Internal.hs Co-authored-by: Simon Jakobi --- Data/HashMap/Internal.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Data/HashMap/Internal.hs b/Data/HashMap/Internal.hs index c79b2883..81884068 100644 --- a/Data/HashMap/Internal.hs +++ b/Data/HashMap/Internal.hs @@ -1863,7 +1863,7 @@ intersectionArrayBy f !b1 !b2 !ary1 !ary2 = do bIntersect = b1 .&. b2 {-# INLINE intersectionArrayBy #-} -intersectionUnorderedArrayWithKey :: Eq k => (k -> v1 -> v2 -> (# v3 #)) -> A.Array (Leaf k v1) -> A.Array (Leaf k v2) -> ST s (Int, A.MArray s (Leaf k v3)) +intersectionCollisions :: Eq k => (k -> v1 -> v2 -> (# v3 #)) -> A.Array (Leaf k v1) -> A.Array (Leaf k v2) -> ST s (Int, A.MArray s (Leaf k v3)) intersectionUnorderedArrayWithKey f ary1 ary2 = do mary2 <- A.thaw ary2 0 $ A.length ary2 mary <- A.new_ $ A.length ary1 + A.length ary2 From bf9a27f805fd3169c24828605c8f4bf7cbb90488 Mon Sep 17 00:00:00 2001 From: oberblastmeister <61095988+oberblastmeister@users.noreply.github.com> Date: Mon, 11 Apr 2022 19:25:56 -0400 Subject: [PATCH 18/28] Update Data/HashSet/Internal.hs Co-authored-by: Simon Jakobi --- Data/HashSet/Internal.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Data/HashSet/Internal.hs b/Data/HashSet/Internal.hs index 15b2254e..340bb742 100644 --- a/Data/HashSet/Internal.hs +++ b/Data/HashSet/Internal.hs @@ -391,7 +391,7 @@ difference (HashSet a) (HashSet b) = HashSet (H.difference a b) -- -- >>> HashSet.intersection (HashSet.fromList [1,2,3]) (HashSet.fromList [2,3,4]) -- fromList [2,3] -intersection :: (Show a, Eq a, Hashable a) => HashSet a -> HashSet a -> HashSet a +intersection :: (Eq a, Hashable a) => HashSet a -> HashSet a -> HashSet a intersection (HashSet a) (HashSet b) = HashSet (H.intersection a b) {-# INLINABLE intersection #-} From 1c2073957211a5417dd6a30e92d4de2f8976c3dc Mon Sep 17 00:00:00 2001 From: Brian Shu Date: Mon, 11 Apr 2022 19:29:07 -0400 Subject: [PATCH 19/28] naming --- Data/HashMap/Internal.hs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Data/HashMap/Internal.hs b/Data/HashMap/Internal.hs index 81884068..74fffb27 100644 --- a/Data/HashMap/Internal.hs +++ b/Data/HashMap/Internal.hs @@ -1785,7 +1785,7 @@ intersectionWithKey# f = go 0 -- collision vs. collision go _ (Collision h1 ls1) (Collision h2 ls2) | h1 == h2 = runST $ do - (len, mary) <- intersectionUnorderedArrayWithKey f ls1 ls2 + (len, mary) <- intersectionCollisions f ls1 ls2 case len of 0 -> pure Empty 1 -> Leaf h1 <$> A.read mary 0 @@ -1864,7 +1864,7 @@ intersectionArrayBy f !b1 !b2 !ary1 !ary2 = do {-# INLINE intersectionArrayBy #-} intersectionCollisions :: Eq k => (k -> v1 -> v2 -> (# v3 #)) -> A.Array (Leaf k v1) -> A.Array (Leaf k v2) -> ST s (Int, A.MArray s (Leaf k v3)) -intersectionUnorderedArrayWithKey f ary1 ary2 = do +intersectionCollisions f ary1 ary2 = do mary2 <- A.thaw ary2 0 $ A.length ary2 mary <- A.new_ $ A.length ary1 + A.length ary2 let go i j @@ -1880,7 +1880,7 @@ intersectionUnorderedArrayWithKey f ary1 ary2 = do go (i + 1) j maryLen <- go 0 0 pure (maryLen, mary) -{-# INLINE intersectionUnorderedArrayWithKey #-} +{-# INLINE intersectionCollisions #-} searchSwap :: Eq k => k -> Int -> A.MArray s (Leaf k v) -> ST s (Maybe (Leaf k v)) searchSwap toFind start = go start toFind start From 92e4b2a4b246e2bd2c65c6e1bec079bf3a5ebdff Mon Sep 17 00:00:00 2001 From: Brian Shu Date: Mon, 11 Apr 2022 19:31:07 -0400 Subject: [PATCH 20/28] Exts.inline --- Data/HashMap/Internal.hs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Data/HashMap/Internal.hs b/Data/HashMap/Internal.hs index 74fffb27..5840a618 100644 --- a/Data/HashMap/Internal.hs +++ b/Data/HashMap/Internal.hs @@ -155,7 +155,7 @@ import Data.Hashable (Hashable) import Data.Hashable.Lifted (Hashable1, Hashable2) import Data.HashMap.Internal.List (isPermutationBy, unorderedCompare) import Data.Semigroup (Semigroup (..), stimesIdempotentMonoid) -import GHC.Exts (Int (..), Int#, TYPE, (==#), inline) +import GHC.Exts (Int (..), Int#, TYPE, (==#)) import GHC.Stack (HasCallStack) import Prelude hiding (filter, foldl, foldr, lookup, map, null, pred) @@ -1756,14 +1756,14 @@ differenceWith f a b = foldlWithKey' go empty a -- | /O(n*log m)/ Intersection of two maps. Return elements of the first -- map for keys existing in the second. intersection :: (Eq k, Hashable k) => HashMap k v -> HashMap k w -> HashMap k v -intersection = inline intersectionWith const +intersection = Exts.inline intersectionWith const {-# INLINABLE intersection #-} -- | /O(n*log m)/ Intersection of two maps. If a key occurs in both maps -- the provided function is used to combine the values from the two -- maps. intersectionWith :: (Eq k, Hashable k) => (v1 -> v2 -> v3) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3 -intersectionWith f = inline intersectionWithKey $ const f +intersectionWith f = Exts.inline intersectionWithKey $ const f {-# INLINABLE intersectionWith #-} -- | /O(n*log m)/ Intersection of two maps. If a key occurs in both maps From 5a439cc754a863c05ef15be3403802957814b61e Mon Sep 17 00:00:00 2001 From: Brian Shu Date: Tue, 12 Apr 2022 07:30:03 -0400 Subject: [PATCH 21/28] add haddocks for searchSwap --- Data/HashMap/Internal.hs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/Data/HashMap/Internal.hs b/Data/HashMap/Internal.hs index 5840a618..7261ea91 100644 --- a/Data/HashMap/Internal.hs +++ b/Data/HashMap/Internal.hs @@ -1882,6 +1882,15 @@ intersectionCollisions f ary1 ary2 = do pure (maryLen, mary) {-# INLINE intersectionCollisions #-} +-- | Say we have +-- @ +-- 1 2 3 4 +-- @ +-- and we search for @3@. Then we can mutate the array to +-- @ +-- undefined 2 1 4 +-- @ +-- We don't actually need to write undefined, we just have to make sure that the next search starts 1 after the current one. searchSwap :: Eq k => k -> Int -> A.MArray s (Leaf k v) -> ST s (Maybe (Leaf k v)) searchSwap toFind start = go start toFind start where From 1c118c4d8e98a997edf84c9f54115df43ea4215e Mon Sep 17 00:00:00 2001 From: Brian Shu Date: Tue, 12 Apr 2022 14:42:42 -0400 Subject: [PATCH 22/28] cleanup --- Data/HashMap/Internal.hs | 12 +++++------- benchmarks/Benchmarks.hs | 6 +++++- tests/Properties/HashMapLazy.hs | 15 +-------------- 3 files changed, 11 insertions(+), 22 deletions(-) diff --git a/Data/HashMap/Internal.hs b/Data/HashMap/Internal.hs index 7261ea91..d15e9b05 100644 --- a/Data/HashMap/Internal.hs +++ b/Data/HashMap/Internal.hs @@ -1591,6 +1591,10 @@ unionWithKey f = go 0 ary' = update32With' ary2 i $ \st2 -> go (s+bitsPerSubkey) t1 st2 in Full ary' + leafHashCode (Leaf h _) = h + leafHashCode (Collision h _) = h + leafHashCode _ = error "leafHashCode" + goDifferentHash s h1 h2 t1 t2 | m1 == m2 = BitmapIndexed m1 (A.singleton $! goDifferentHash (s+bitsPerSubkey) h1 h2 t1 t2) | m1 < m2 = BitmapIndexed (m1 .|. m2) (A.pair t1 t2) @@ -1866,7 +1870,7 @@ intersectionArrayBy f !b1 !b2 !ary1 !ary2 = do intersectionCollisions :: Eq k => (k -> v1 -> v2 -> (# v3 #)) -> A.Array (Leaf k v1) -> A.Array (Leaf k v2) -> ST s (Int, A.MArray s (Leaf k v3)) intersectionCollisions f ary1 ary2 = do mary2 <- A.thaw ary2 0 $ A.length ary2 - mary <- A.new_ $ A.length ary1 + A.length ary2 + mary <- A.new_ $ min (A.length ary1) (A.length ary2) let go i j | i >= A.length ary1 || j >= A.lengthM mary2 = pure j | otherwise = do @@ -2382,12 +2386,6 @@ ptrEq :: a -> a -> Bool ptrEq x y = Exts.isTrue# (Exts.reallyUnsafePtrEquality# x y ==# 1#) {-# INLINE ptrEq #-} -leafHashCode :: HashMap k v -> Hash -leafHashCode (Leaf h _) = h -leafHashCode (Collision h _) = h -leafHashCode _ = error "leafHashCode" -{-# INLINE leafHashCode #-} - ------------------------------------------------------------------------ -- IsList instance instance (Eq k, Hashable k) => Exts.IsList (HashMap k v) where diff --git a/benchmarks/Benchmarks.hs b/benchmarks/Benchmarks.hs index c0f7f550..ae05c422 100644 --- a/benchmarks/Benchmarks.hs +++ b/benchmarks/Benchmarks.hs @@ -318,13 +318,17 @@ main = do [ bench "Int" $ whnf (HM.union hmi) hmi2 , bench "ByteString" $ whnf (HM.union hmbs) hmbsSubset ] + + , bgroup "intersection" + [ bench "Int" $ whnf (HM.intersection hmi) hmi2 + , bench "ByteString" $ whnf (HM.intersection hmbs) hmbsSubset + ] -- Transformations , bench "map" $ whnf (HM.map (\ v -> v + 1)) hmi -- * Difference and intersection , bench "difference" $ whnf (HM.difference hmi) hmi2 - , bench "intersection" $ whnf (HM.intersection hmi) hmi2 -- Folds , bench "foldl'" $ whnf (HM.foldl' (+) 0) hmi diff --git a/tests/Properties/HashMapLazy.hs b/tests/Properties/HashMapLazy.hs index 1e32a7e0..933f3a9e 100644 --- a/tests/Properties/HashMapLazy.hs +++ b/tests/Properties/HashMapLazy.hs @@ -18,7 +18,6 @@ import Control.Applicative (Const (..)) import Control.Monad (guard) import Data.Bifoldable import Data.Function (on) -import Debug.Trace (traceId) import Data.Functor.Identity (Identity (..)) import Data.Hashable (Hashable (hashWithSalt)) import Data.Ord (comparing) @@ -252,15 +251,7 @@ pSubmapDifference m1 m2 = HM.isSubmapOf (HM.difference m1 m2) m1 pNotSubmapDifference :: HashMap Key Int -> HashMap Key Int -> Property pNotSubmapDifference m1 m2 = - not (HM.null (HM.intersection m1 m2)) ==> do - - let - res = HM.intersection m1 m2 - res' = M.intersection (M.fromList $ HM.toList m1) (M.fromList $ HM.toList m2) - -- !_ = traceId $ "res: " ++ show res - -- !_ = traceId $ "res': " ++ show res' - -- !_ = traceId $ "m1: " ++ show m1 - -- !_ = traceId $ "m2: " ++ show m2 + not (HM.null (HM.intersection m1 m2)) ==> not (HM.isSubmapOf m1 (HM.difference m1 m2)) pSubmapDelete :: HashMap Key Int -> Property @@ -334,9 +325,6 @@ pIntersection xs ys = `eq_` HM.intersection (HM.fromList xs) $ ys -intersectionBad :: Assertion -intersectionBad = pIntersection [(-20, 0), (0, 0)] [(0, 0), (20, 0)] @? "should be true" - pIntersectionWith :: [(Key, Int)] -> [(Key, Int)] -> Bool pIntersectionWith xs ys = M.intersectionWith (-) (M.fromList xs) `eq_` HM.intersectionWith (-) (HM.fromList xs) $ ys @@ -547,7 +535,6 @@ tests = [ testProperty "difference" pDifference , testProperty "differenceWith" pDifferenceWith , testProperty "intersection" pIntersection - , testCase "intersectionBad" intersectionBad , testProperty "intersectionWith" pIntersectionWith , testProperty "intersectionWithKey" pIntersectionWithKey ] From 1256cf3138e6ea63ab28a0ba6be57b44100857f7 Mon Sep 17 00:00:00 2001 From: oberblastmeister <61095988+oberblastmeister@users.noreply.github.com> Date: Tue, 12 Apr 2022 14:45:18 -0400 Subject: [PATCH 23/28] Update Data/HashMap/Internal/Array.hs Co-authored-by: Simon Jakobi --- Data/HashMap/Internal/Array.hs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Data/HashMap/Internal/Array.hs b/Data/HashMap/Internal/Array.hs index d63c03f8..05fadd9b 100644 --- a/Data/HashMap/Internal/Array.hs +++ b/Data/HashMap/Internal/Array.hs @@ -207,8 +207,8 @@ new _n@(I# n#) b = new_ :: Int -> ST s (MArray s a) new_ n = new n undefinedElem --- when shrinkSmallMutableArray# is available, the returned array is the same as the array given, as it is shrunk in place --- otherwise a copy is made +-- | When 'Exts.shrinkSmallMutableArray#' is available, the returned array is the same as the array given, as it is shrunk in place. +-- Otherwise a copy is made. shrink :: MArray s a -> Int -> ST s (MArray s a) #if MIN_VERSION_GLASGOW_HASKELL(8, 10, 7, 0) shrink mary _n@(I# n#) = From b0210c86e238e28b6ca72b44335fda73c2cb2647 Mon Sep 17 00:00:00 2001 From: oberblastmeister <61095988+oberblastmeister@users.noreply.github.com> Date: Tue, 12 Apr 2022 15:05:25 -0400 Subject: [PATCH 24/28] Update Data/HashMap/Internal/Array.hs Co-authored-by: Simon Jakobi --- Data/HashMap/Internal/Array.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Data/HashMap/Internal/Array.hs b/Data/HashMap/Internal/Array.hs index 05fadd9b..6fb52664 100644 --- a/Data/HashMap/Internal/Array.hs +++ b/Data/HashMap/Internal/Array.hs @@ -210,7 +210,7 @@ new_ n = new n undefinedElem -- | When 'Exts.shrinkSmallMutableArray#' is available, the returned array is the same as the array given, as it is shrunk in place. -- Otherwise a copy is made. shrink :: MArray s a -> Int -> ST s (MArray s a) -#if MIN_VERSION_GLASGOW_HASKELL(8, 10, 7, 0) +#if __GLASGOW_HASKELL__ >= 810 shrink mary _n@(I# n#) = CHECK_GT("shrink", _n, (0 :: Int)) CHECK_LE("shrink", _n, (lengthM mary)) From 69f8f286faa491e341c1081872e80ba195e17441 Mon Sep 17 00:00:00 2001 From: Brian Shu Date: Wed, 13 Apr 2022 17:00:49 -0400 Subject: [PATCH 25/28] refactor --- Data/HashMap/Internal.hs | 55 ++++++++++++++------------------- tests/Properties/HashMapLazy.hs | 1 - 2 files changed, 24 insertions(+), 32 deletions(-) diff --git a/Data/HashMap/Internal.hs b/Data/HashMap/Internal.hs index d15e9b05..1d4e7e1b 100644 --- a/Data/HashMap/Internal.hs +++ b/Data/HashMap/Internal.hs @@ -1787,19 +1787,12 @@ intersectionWithKey# f = go 0 go s (Leaf h1 (L k1 v1)) t2 = lookupCont (\_ -> Empty) (\v _ -> case f k1 v1 v of (# v' #) -> Leaf h1 $ L k1 v') h1 k1 s t2 go s t1 (Leaf h2 (L k2 v2)) = lookupCont (\_ -> Empty) (\v _ -> case f k2 v v2 of (# v' #) -> Leaf h2 $ L k2 v') h2 k2 s t1 -- collision vs. collision - go _ (Collision h1 ls1) (Collision h2 ls2) - | h1 == h2 = runST $ do - (len, mary) <- intersectionCollisions f ls1 ls2 - case len of - 0 -> pure Empty - 1 -> Leaf h1 <$> A.read mary 0 - _ -> Collision h1 <$> (A.unsafeFreeze =<< A.shrink mary len) - | otherwise = Empty + go _ (Collision h1 ls1) (Collision h2 ls2) = intersectionCollisions f h1 h2 ls1 ls2 -- branch vs. branch - go s (BitmapIndexed b1 ary1) (BitmapIndexed b2 ary2) = intersectionArray s b1 b2 ary1 ary2 - go s (BitmapIndexed b1 ary1) (Full ary2) = intersectionArray s b1 fullNodeMask ary1 ary2 - go s (Full ary1) (BitmapIndexed b2 ary2) = intersectionArray s fullNodeMask b2 ary1 ary2 - go s (Full ary1) (Full ary2) = intersectionArray s fullNodeMask fullNodeMask ary1 ary2 + go s (BitmapIndexed b1 ary1) (BitmapIndexed b2 ary2) = intersectionArrayBy (go (s + bitsPerSubkey)) b1 b2 ary1 ary2 + go s (BitmapIndexed b1 ary1) (Full ary2) = intersectionArrayBy (go (s + bitsPerSubkey)) b1 fullNodeMask ary1 ary2 + go s (Full ary1) (BitmapIndexed b2 ary2) = intersectionArrayBy (go (s + bitsPerSubkey)) fullNodeMask b2 ary1 ary2 + go s (Full ary1) (Full ary2) = intersectionArrayBy (go (s + bitsPerSubkey)) fullNodeMask fullNodeMask ary1 ary2 -- collision vs. branch go s (BitmapIndexed b1 ary1) t2@(Collision h2 _ls2) | b1 .&. m2 == 0 = Empty @@ -1819,16 +1812,6 @@ intersectionWithKey# f = go 0 go s t1@(Collision h1 _ls1) (Full ary2) = go (s + bitsPerSubkey) t1 (A.index ary2 i) where i = index h1 s - - intersectionArray s b1 b2 ary1 ary2 - -- don't create an array of size zero in intersectionArrayBy - | b1 .&. b2 == 0 = Empty - | otherwise = runST $ do - (b, len, ary) <- intersectionArrayBy (go (s + bitsPerSubkey)) b1 b2 ary1 ary2 - case len of - 0 -> pure Empty - 1 -> A.read ary 0 - _ -> bitmapIndexedOrFull b <$> (A.unsafeFreeze =<< A.shrink ary len) {-# INLINE intersectionWithKey# #-} intersectionArrayBy :: @@ -1840,10 +1823,12 @@ intersectionArrayBy :: Bitmap -> A.Array (HashMap k v1) -> A.Array (HashMap k v2) -> - ST s (Bitmap, Int, A.MArray s (HashMap k v3)) -intersectionArrayBy f !b1 !b2 !ary1 !ary2 = do + HashMap k v3 +intersectionArrayBy f !b1 !b2 !ary1 !ary2 + | b1 .&. b2 == 0 = Empty + | otherwise = runST $ do mary <- A.new_ $ popCount bIntersect - -- iterate over nonzero bits of b1 .&. b2 + -- iterate over nonzero bits of b1 .|. b2 let go !i !i1 !i2 !b !bFinal | b == 0 = pure (i, bFinal) | testBit $ b1 .&. b2 = do @@ -1860,15 +1845,19 @@ intersectionArrayBy f !b1 !b2 !ary1 !ary2 = do m = 1 `unsafeShiftL` countTrailingZeros b testBit x = x .&. m /= 0 b' = b .&. complement m - (maryLen, bFinal) <- go 0 0 0 bCombined bIntersect - pure (bFinal, maryLen, mary) + (len, bFinal) <- go 0 0 0 bCombined bIntersect + case len of + 0 -> pure Empty + 1 -> A.read mary 0 + _ -> bitmapIndexedOrFull bFinal <$> (A.unsafeFreeze =<< A.shrink mary len) where bCombined = b1 .|. b2 bIntersect = b1 .&. b2 {-# INLINE intersectionArrayBy #-} -intersectionCollisions :: Eq k => (k -> v1 -> v2 -> (# v3 #)) -> A.Array (Leaf k v1) -> A.Array (Leaf k v2) -> ST s (Int, A.MArray s (Leaf k v3)) -intersectionCollisions f ary1 ary2 = do +intersectionCollisions :: Eq k => (k -> v1 -> v2 -> (# v3 #)) -> Hash -> Hash -> A.Array (Leaf k v1) -> A.Array (Leaf k v2) -> HashMap k v3 +intersectionCollisions f h1 h2 ary1 ary2 + | h1 == h2 = runST $ do mary2 <- A.thaw ary2 0 $ A.length ary2 mary <- A.new_ $ min (A.length ary1) (A.length ary2) let go i j @@ -1882,8 +1871,12 @@ intersectionCollisions f ary1 ary2 = do go (i + 1) (j + 1) Nothing -> do go (i + 1) j - maryLen <- go 0 0 - pure (maryLen, mary) + len <- go 0 0 + case len of + 0 -> pure Empty + 1 -> Leaf h1 <$> A.read mary 0 + _ -> Collision h1 <$> (A.unsafeFreeze =<< A.shrink mary len) + | otherwise = Empty {-# INLINE intersectionCollisions #-} -- | Say we have diff --git a/tests/Properties/HashMapLazy.hs b/tests/Properties/HashMapLazy.hs index 933f3a9e..f25fb42d 100644 --- a/tests/Properties/HashMapLazy.hs +++ b/tests/Properties/HashMapLazy.hs @@ -27,7 +27,6 @@ import Test.QuickCheck.Function (Fun, apply) import Test.QuickCheck.Poly (A, B) import Test.Tasty (TestTree, testGroup) import Test.Tasty.QuickCheck (testProperty) -import Test.Tasty.HUnit import qualified Data.Foldable as Foldable import qualified Data.List as List From 06cc511f311d8b2c3610dc7f828a14f11be9fa2e Mon Sep 17 00:00:00 2001 From: Brian Shu Date: Wed, 13 Apr 2022 17:20:46 -0400 Subject: [PATCH 26/28] formatting --- Data/HashMap/Internal.hs | 96 ++++++++++++++++++++++------------------ 1 file changed, 52 insertions(+), 44 deletions(-) diff --git a/Data/HashMap/Internal.hs b/Data/HashMap/Internal.hs index 1d4e7e1b..7d9a4ee2 100644 --- a/Data/HashMap/Internal.hs +++ b/Data/HashMap/Internal.hs @@ -1784,8 +1784,16 @@ intersectionWithKey# f = go 0 go !_ _ Empty = Empty go _ Empty _ = Empty -- leaf vs. anything - go s (Leaf h1 (L k1 v1)) t2 = lookupCont (\_ -> Empty) (\v _ -> case f k1 v1 v of (# v' #) -> Leaf h1 $ L k1 v') h1 k1 s t2 - go s t1 (Leaf h2 (L k2 v2)) = lookupCont (\_ -> Empty) (\v _ -> case f k2 v v2 of (# v' #) -> Leaf h2 $ L k2 v') h2 k2 s t1 + go s (Leaf h1 (L k1 v1)) t2 = + lookupCont + (\_ -> Empty) + (\v _ -> case f k1 v1 v of (# v' #) -> Leaf h1 $ L k1 v') + h1 k1 s t2 + go s t1 (Leaf h2 (L k2 v2)) = + lookupCont + (\_ -> Empty) + (\v _ -> case f k2 v v2 of (# v' #) -> Leaf h2 $ L k2 v') + h2 k2 s t1 -- collision vs. collision go _ (Collision h1 ls1) (Collision h2 ls2) = intersectionCollisions f h1 h2 ls1 ls2 -- branch vs. branch @@ -1827,29 +1835,29 @@ intersectionArrayBy :: intersectionArrayBy f !b1 !b2 !ary1 !ary2 | b1 .&. b2 == 0 = Empty | otherwise = runST $ do - mary <- A.new_ $ popCount bIntersect - -- iterate over nonzero bits of b1 .|. b2 - let go !i !i1 !i2 !b !bFinal - | b == 0 = pure (i, bFinal) - | testBit $ b1 .&. b2 = do - x1 <- A.indexM ary1 i1 - x2 <- A.indexM ary2 i2 - case f x1 x2 of - Empty -> go i (i1 + 1) (i2 + 1) b' (bFinal .&. complement m) - _ -> do - A.write mary i $! f x1 x2 - go (i + 1) (i1 + 1) (i2 + 1) b' bFinal - | testBit b1 = go i (i1 + 1) i2 b' bFinal - | otherwise = go i i1 (i2 + 1) b' bFinal - where - m = 1 `unsafeShiftL` countTrailingZeros b - testBit x = x .&. m /= 0 - b' = b .&. complement m - (len, bFinal) <- go 0 0 0 bCombined bIntersect - case len of - 0 -> pure Empty - 1 -> A.read mary 0 - _ -> bitmapIndexedOrFull bFinal <$> (A.unsafeFreeze =<< A.shrink mary len) + mary <- A.new_ $ popCount bIntersect + -- iterate over nonzero bits of b1 .|. b2 + let go !i !i1 !i2 !b !bFinal + | b == 0 = pure (i, bFinal) + | testBit $ b1 .&. b2 = do + x1 <- A.indexM ary1 i1 + x2 <- A.indexM ary2 i2 + case f x1 x2 of + Empty -> go i (i1 + 1) (i2 + 1) b' (bFinal .&. complement m) + _ -> do + A.write mary i $! f x1 x2 + go (i + 1) (i1 + 1) (i2 + 1) b' bFinal + | testBit b1 = go i (i1 + 1) i2 b' bFinal + | otherwise = go i i1 (i2 + 1) b' bFinal + where + m = 1 `unsafeShiftL` countTrailingZeros b + testBit x = x .&. m /= 0 + b' = b .&. complement m + (len, bFinal) <- go 0 0 0 bCombined bIntersect + case len of + 0 -> pure Empty + 1 -> A.read mary 0 + _ -> bitmapIndexedOrFull bFinal <$> (A.unsafeFreeze =<< A.shrink mary len) where bCombined = b1 .|. b2 bIntersect = b1 .&. b2 @@ -1857,25 +1865,25 @@ intersectionArrayBy f !b1 !b2 !ary1 !ary2 intersectionCollisions :: Eq k => (k -> v1 -> v2 -> (# v3 #)) -> Hash -> Hash -> A.Array (Leaf k v1) -> A.Array (Leaf k v2) -> HashMap k v3 intersectionCollisions f h1 h2 ary1 ary2 - | h1 == h2 = runST $ do - mary2 <- A.thaw ary2 0 $ A.length ary2 - mary <- A.new_ $ min (A.length ary1) (A.length ary2) - let go i j - | i >= A.length ary1 || j >= A.lengthM mary2 = pure j - | otherwise = do - L k1 v1 <- A.indexM ary1 i - searchSwap k1 j mary2 >>= \case - Just (L _k2 v2) -> do - let !(# v3 #) = f k1 v1 v2 - A.write mary j $ L k1 v3 - go (i + 1) (j + 1) - Nothing -> do - go (i + 1) j - len <- go 0 0 - case len of - 0 -> pure Empty - 1 -> Leaf h1 <$> A.read mary 0 - _ -> Collision h1 <$> (A.unsafeFreeze =<< A.shrink mary len) + | h1 == h2 = runST $ do + mary2 <- A.thaw ary2 0 $ A.length ary2 + mary <- A.new_ $ min (A.length ary1) (A.length ary2) + let go i j + | i >= A.length ary1 || j >= A.lengthM mary2 = pure j + | otherwise = do + L k1 v1 <- A.indexM ary1 i + searchSwap k1 j mary2 >>= \case + Just (L _k2 v2) -> do + let !(# v3 #) = f k1 v1 v2 + A.write mary j $ L k1 v3 + go (i + 1) (j + 1) + Nothing -> do + go (i + 1) j + len <- go 0 0 + case len of + 0 -> pure Empty + 1 -> Leaf h1 <$> A.read mary 0 + _ -> Collision h1 <$> (A.unsafeFreeze =<< A.shrink mary len) | otherwise = Empty {-# INLINE intersectionCollisions #-} From d9a50d70dc64aac7a605e464cea8072970d73cad Mon Sep 17 00:00:00 2001 From: Brian Shu Date: Wed, 13 Apr 2022 17:22:33 -0400 Subject: [PATCH 27/28] breakup lines --- Data/HashMap/Internal.hs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/Data/HashMap/Internal.hs b/Data/HashMap/Internal.hs index 7d9a4ee2..37cc63ae 100644 --- a/Data/HashMap/Internal.hs +++ b/Data/HashMap/Internal.hs @@ -1797,10 +1797,14 @@ intersectionWithKey# f = go 0 -- collision vs. collision go _ (Collision h1 ls1) (Collision h2 ls2) = intersectionCollisions f h1 h2 ls1 ls2 -- branch vs. branch - go s (BitmapIndexed b1 ary1) (BitmapIndexed b2 ary2) = intersectionArrayBy (go (s + bitsPerSubkey)) b1 b2 ary1 ary2 - go s (BitmapIndexed b1 ary1) (Full ary2) = intersectionArrayBy (go (s + bitsPerSubkey)) b1 fullNodeMask ary1 ary2 - go s (Full ary1) (BitmapIndexed b2 ary2) = intersectionArrayBy (go (s + bitsPerSubkey)) fullNodeMask b2 ary1 ary2 - go s (Full ary1) (Full ary2) = intersectionArrayBy (go (s + bitsPerSubkey)) fullNodeMask fullNodeMask ary1 ary2 + go s (BitmapIndexed b1 ary1) (BitmapIndexed b2 ary2) = + intersectionArrayBy (go (s + bitsPerSubkey)) b1 b2 ary1 ary2 + go s (BitmapIndexed b1 ary1) (Full ary2) = + intersectionArrayBy (go (s + bitsPerSubkey)) b1 fullNodeMask ary1 ary2 + go s (Full ary1) (BitmapIndexed b2 ary2) = + intersectionArrayBy (go (s + bitsPerSubkey)) fullNodeMask b2 ary1 ary2 + go s (Full ary1) (Full ary2) = + intersectionArrayBy (go (s + bitsPerSubkey)) fullNodeMask fullNodeMask ary1 ary2 -- collision vs. branch go s (BitmapIndexed b1 ary1) t2@(Collision h2 _ls2) | b1 .&. m2 == 0 = Empty From d24cc1f48cef1c7f6bc041e7e73ee40a1fc7b731 Mon Sep 17 00:00:00 2001 From: Brian Shu Date: Thu, 14 Apr 2022 16:54:32 -0400 Subject: [PATCH 28/28] use Exts.inline --- Data/HashMap/Internal/Strict.hs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Data/HashMap/Internal/Strict.hs b/Data/HashMap/Internal/Strict.hs index a8b74dc8..798292db 100644 --- a/Data/HashMap/Internal/Strict.hs +++ b/Data/HashMap/Internal/Strict.hs @@ -128,17 +128,17 @@ import Data.Bits ((.&.), (.|.)) import Data.Coerce (coerce) import Data.Functor.Identity (Identity (..)) -- See Note [Imports from Data.HashMap.Internal] +import Data.Hashable (Hashable) import Data.HashMap.Internal (Hash, HashMap (..), Leaf (..), LookupRes (..), bitsPerSubkey, fullNodeMask, hash, index, mask, ptrEq, sparseIndex) -import Data.Hashable (Hashable) import Prelude hiding (lookup, map) -- See Note [Imports from Data.HashMap.Internal] import qualified Data.HashMap.Internal as HM import qualified Data.HashMap.Internal.Array as A import qualified Data.List as List -import GHC.Exts (inline) +import qualified GHC.Exts as Exts {- Note [Imports from Data.HashMap.Internal] @@ -617,7 +617,7 @@ differenceWith f a b = HM.foldlWithKey' go HM.empty a -- maps. intersectionWith :: (Eq k, Hashable k) => (v1 -> v2 -> v3) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3 -intersectionWith f = inline intersectionWithKey $ const f +intersectionWith f = Exts.inline intersectionWithKey $ const f {-# INLINABLE intersectionWith #-} -- | /O(n+m)/ Intersection of two maps. If a key occurs in both maps