Skip to content

Make intersections much faster #406

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
21f238b
fast intersection
oberblastmeister Apr 9, 2022
16f1f7f
cleanup
oberblastmeister Apr 9, 2022
bcc13fc
add show back
oberblastmeister Apr 9, 2022
d5262bf
inline
oberblastmeister Apr 9, 2022
a16456b
debug checks
oberblastmeister Apr 9, 2022
f72011c
inline function
oberblastmeister Apr 9, 2022
678a38c
refactor to use snoc
oberblastmeister Apr 9, 2022
ec24215
Try the unboxed result thing
treeowl Apr 9, 2022
767ae6e
Remove redundant internal constraint
treeowl Apr 9, 2022
72510b4
Merge pull request #3 from treeowl/unboxedness
oberblastmeister Apr 9, 2022
fd43ba7
shrink compat
oberblastmeister Apr 9, 2022
3612645
fix import
oberblastmeister Apr 9, 2022
b484042
use clone
oberblastmeister Apr 9, 2022
9e48bc0
oof
oberblastmeister Apr 9, 2022
48119cb
don't shrink to zero
oberblastmeister Apr 9, 2022
d9d295d
Leaf special case
oberblastmeister Apr 9, 2022
88a9c2c
add strict versions
oberblastmeister Apr 10, 2022
b3cdbd8
Update Data/HashMap/Internal.hs
oberblastmeister Apr 11, 2022
bf9a27f
Update Data/HashSet/Internal.hs
oberblastmeister Apr 11, 2022
1c20739
naming
oberblastmeister Apr 11, 2022
92e4b2a
Exts.inline
oberblastmeister Apr 11, 2022
5a439cc
add haddocks for searchSwap
oberblastmeister Apr 12, 2022
1c118c4
cleanup
oberblastmeister Apr 12, 2022
1256cf3
Update Data/HashMap/Internal/Array.hs
oberblastmeister Apr 12, 2022
b0210c8
Update Data/HashMap/Internal/Array.hs
oberblastmeister Apr 12, 2022
69f8f28
refactor
oberblastmeister Apr 13, 2022
06cc511
formatting
oberblastmeister Apr 13, 2022
d9a50d7
breakup lines
oberblastmeister Apr 13, 2022
d24cc1f
use Exts.inline
oberblastmeister Apr 14, 2022
64f3f2f
Merge branch 'master' into fast-intersection
oberblastmeister Apr 15, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 147 additions & 22 deletions Data/HashMap/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ module Data.HashMap.Internal
, intersection
, intersectionWith
, intersectionWithKey
, intersectionWithKey#

-- * Folds
, foldr'
Expand Down Expand Up @@ -150,9 +151,9 @@ import Data.Data (Constr, Data (..), DataType)
import Data.Functor.Classes (Eq1 (..), Eq2 (..), Ord1 (..), Ord2 (..),
Read1 (..), Show1 (..), Show2 (..))
import Data.Functor.Identity (Identity (..))
import Data.HashMap.Internal.List (isPermutationBy, unorderedCompare)
import Data.Hashable (Hashable)
import Data.Hashable.Lifted (Hashable1, Hashable2)
import Data.HashMap.Internal.List (isPermutationBy, unorderedCompare)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW, the changed sorting of imports is probably due to haskell/stylish-haskell#385, which was recently released.

import Data.Semigroup (Semigroup (..), stimesIdempotentMonoid)
import GHC.Exts (Int (..), Int#, TYPE, (==#))
import GHC.Stack (HasCallStack)
Expand All @@ -163,9 +164,9 @@ import Text.Read hiding (step)
import qualified Data.Data as Data
import qualified Data.Foldable as Foldable
import qualified Data.Functor.Classes as FC
import qualified Data.HashMap.Internal.Array as A
import qualified Data.Hashable as H
import qualified Data.Hashable.Lifted as H
import qualified Data.HashMap.Internal.Array as A
import qualified Data.List as List
import qualified GHC.Exts as Exts
import qualified Language.Haskell.TH.Syntax as TH
Expand Down Expand Up @@ -1627,7 +1628,7 @@ unionArrayBy f !b1 !b2 !ary1 !ary2 = A.run $ do
A.write mary i =<< A.indexM ary2 i2
go (i+1) i1 (i2+1) b'
where
m = 1 `unsafeShiftL` (countTrailingZeros b)
m = 1 `unsafeShiftL` countTrailingZeros b
testBit x = x .&. m /= 0
b' = b .&. complement m
go 0 0 0 bCombined
Expand Down Expand Up @@ -1759,37 +1760,161 @@ differenceWith f a b = foldlWithKey' go empty a
-- | \(O(n \log m)\) Intersection of two maps. Return elements of the first
-- map for keys existing in the second.
intersection :: (Eq k, Hashable k) => HashMap k v -> HashMap k w -> HashMap k v
intersection a b = foldlWithKey' go empty a
where
go m k v = case lookup k b of
Just _ -> unsafeInsert k v m
_ -> m
intersection = Exts.inline intersectionWith const
{-# INLINABLE intersection #-}

-- | \(O(n \log m)\) Intersection of two maps. If a key occurs in both maps
-- the provided function is used to combine the values from the two
-- maps.
intersectionWith :: (Eq k, Hashable k) => (v1 -> v2 -> v3) -> HashMap k v1
-> HashMap k v2 -> HashMap k v3
intersectionWith f a b = foldlWithKey' go empty a
where
go m k v = case lookup k b of
Just w -> unsafeInsert k (f v w) m
_ -> m
intersectionWith :: (Eq k, Hashable k) => (v1 -> v2 -> v3) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3
intersectionWith f = Exts.inline intersectionWithKey $ const f
{-# INLINABLE intersectionWith #-}

-- | \(O(n \log m)\) Intersection of two maps. If a key occurs in both maps
-- the provided function is used to combine the values from the two
-- maps.
intersectionWithKey :: (Eq k, Hashable k) => (k -> v1 -> v2 -> v3)
-> HashMap k v1 -> HashMap k v2 -> HashMap k v3
intersectionWithKey f a b = foldlWithKey' go empty a
where
go m k v = case lookup k b of
Just w -> unsafeInsert k (f k v w) m
_ -> m
intersectionWithKey :: (Eq k, Hashable k) => (k -> v1 -> v2 -> v3) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3
intersectionWithKey f = intersectionWithKey# $ \k v1 v2 -> (# f k v1 v2 #)
{-# INLINABLE intersectionWithKey #-}

-- | Worker shared by the lazy and strict @intersectionWithKey@ wrappers.
--
-- The combining function returns its result in an unboxed 1-tuple, so the
-- caller decides whether the combined value is forced before it is stored:
-- a wrapper may put a thunk inside the tuple (lazy) or bind the value with a
-- bang first (strict).
--
-- The two trees are walked simultaneously. The @s@ argument of @go@ is the
-- value handed to 'mask', 'index' and 'lookupCont' for the current level; it
-- starts at 0 and grows by 'bitsPerSubkey' on every descent.
--
-- NOTE: the order of the @go@ clauses matters, earlier clauses shadow later
-- ones (e.g. the 'Empty' and 'Leaf' clauses run before any branch clause).
intersectionWithKey# :: Eq k => (k -> v1 -> v2 -> (# v3 #)) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3
intersectionWithKey# f = go 0
where
-- empty vs. anything: the intersection with an empty map is empty
go !_ _ Empty = Empty
go _ Empty _ = Empty
-- leaf vs. anything: look the single key up in the other tree.
-- presumably 'lookupCont' calls its first continuation when the key is
-- absent and its second with the found value when present -- confirm
-- against its definition, which is outside this view.
go s (Leaf h1 (L k1 v1)) t2 =
lookupCont
(\_ -> Empty)
(\v _ -> case f k1 v1 v of (# v' #) -> Leaf h1 $ L k1 v')
h1 k1 s t2
-- Mirror image of the clause above. Note that the resulting leaf stores
-- k2, i.e. the key object taken from the second map.
go s t1 (Leaf h2 (L k2 v2)) =
lookupCont
(\_ -> Empty)
(\v _ -> case f k2 v v2 of (# v' #) -> Leaf h2 $ L k2 v')
h2 k2 s t1
-- collision vs. collision
go _ (Collision h1 ls1) (Collision h2 ls2) = intersectionCollisions f h1 h2 ls1 ls2
-- branch vs. branch: a 'Full' node is treated as a bitmap node whose
-- bitmap is 'fullNodeMask'
go s (BitmapIndexed b1 ary1) (BitmapIndexed b2 ary2) =
intersectionArrayBy (go (s + bitsPerSubkey)) b1 b2 ary1 ary2
go s (BitmapIndexed b1 ary1) (Full ary2) =
intersectionArrayBy (go (s + bitsPerSubkey)) b1 fullNodeMask ary1 ary2
go s (Full ary1) (BitmapIndexed b2 ary2) =
intersectionArrayBy (go (s + bitsPerSubkey)) fullNodeMask b2 ary1 ary2
go s (Full ary1) (Full ary2) =
intersectionArrayBy (go (s + bitsPerSubkey)) fullNodeMask fullNodeMask ary1 ary2
-- collision vs. branch: only the child selected by the collision's hash can
-- contribute; if the bitmap does not contain that bit the result is empty,
-- otherwise recurse into that single child with the collision node unchanged
go s (BitmapIndexed b1 ary1) t2@(Collision h2 _ls2)
| b1 .&. m2 == 0 = Empty
| otherwise = go (s + bitsPerSubkey) (A.index ary1 i) t2
where
m2 = mask h2 s
i = sparseIndex b1 m2
go s t1@(Collision h1 _ls1) (BitmapIndexed b2 ary2)
| b2 .&. m1 == 0 = Empty
| otherwise = go (s + bitsPerSubkey) t1 (A.index ary2 i)
where
m1 = mask h1 s
i = sparseIndex b2 m1
-- collision vs. full node: no bitmap test is made here, the child is
-- indexed directly with 'index'
go s (Full ary1) t2@(Collision h2 _ls2) = go (s + bitsPerSubkey) (A.index ary1 i) t2
where
i = index h2 s
go s t1@(Collision h1 _ls1) (Full ary2) = go (s + bitsPerSubkey) t1 (A.index ary2 i)
where
i = index h1 s
{-# INLINE intersectionWithKey# #-}

-- | Intersect the child arrays of two branch nodes.
--
-- @b1@ and @b2@ are the bitmaps describing which slots are populated in
-- @ary1@ and @ary2@ respectively. Only slots present in both bitmaps can
-- contribute; for each of them the children are combined with @f@.
-- Children whose combination is 'Empty' are dropped and their bit is cleared
-- from the resulting bitmap.
--
-- The result is
--
-- * 'Empty' if no child survives,
-- * the surviving child itself if exactly one survives and it is a 'Leaf'
--   or a 'Collision' (those may live at any depth),
-- * a branch node otherwise. In particular a single surviving
--   'BitmapIndexed' or 'Full' child is kept wrapped in a one-element
--   'BitmapIndexed': its own bitmap refers to the hash bits of the /next/
--   level, so hoisting it up one level would corrupt the tree.
intersectionArrayBy ::
  ( HashMap k v1 ->
    HashMap k v2 ->
    HashMap k v3
  ) ->
  Bitmap ->
  Bitmap ->
  A.Array (HashMap k v1) ->
  A.Array (HashMap k v2) ->
  HashMap k v3
intersectionArrayBy f !b1 !b2 !ary1 !ary2
  | b1 .&. b2 == 0 = Empty
  | otherwise = runST $ do
      -- Upper bound on the number of surviving children.
      mary <- A.new_ $ popCount bIntersect
      -- Iterate over the nonzero bits of b1 .|. b2, lowest bit first.
      -- i  : next free slot in mary
      -- i1 : current position in ary1
      -- i2 : current position in ary2
      let go !i !i1 !i2 !b !bFinal
            | b == 0 = pure (i, bFinal)
            | testBit $ b1 .&. b2 = do
                x1 <- A.indexM ary1 i1
                x2 <- A.indexM ary2 i2
                -- Evaluate the combined child exactly once and reuse it.
                case f x1 x2 of
                  Empty -> go i (i1 + 1) (i2 + 1) b' (bFinal .&. complement m)
                  t -> do
                    A.write mary i t
                    go (i + 1) (i1 + 1) (i2 + 1) b' bFinal
            | testBit b1 = go i (i1 + 1) i2 b' bFinal
            | otherwise = go i i1 (i2 + 1) b' bFinal
            where
              m = 1 `unsafeShiftL` countTrailingZeros b
              testBit x = x .&. m /= 0
              b' = b .&. complement m
      (len, bFinal) <- go 0 0 0 bCombined bIntersect
      case len of
        0 -> pure Empty
        1 -> do
          t <- A.read mary 0
          case t of
            -- Leaves and collisions carry their full hash, so they can be
            -- moved up a level safely.
            Leaf{} -> pure t
            Collision{} -> pure t
            -- A branch node must stay at its original depth.
            _ -> BitmapIndexed bFinal <$> (A.unsafeFreeze =<< A.shrink mary 1)
        _ -> bitmapIndexedOrFull bFinal <$> (A.unsafeFreeze =<< A.shrink mary len)
  where
    bCombined = b1 .|. b2
    bIntersect = b1 .&. b2
{-# INLINE intersectionArrayBy #-}

-- | Intersect the leaf arrays of two 'Collision' nodes.
--
-- If the two hashes differ the result is 'Empty'. Otherwise every leaf of
-- @ary1@ is searched for in a thawed, mutable version of @ary2@ using
-- 'searchSwap'; matches are combined with @f@ and keep the key taken from
-- @ary1@. The result is 'Empty', a single 'Leaf' or a 'Collision' depending
-- on how many keys matched.
intersectionCollisions :: Eq k => (k -> v1 -> v2 -> (# v3 #)) -> Hash -> Hash -> A.Array (Leaf k v1) -> A.Array (Leaf k v2) -> HashMap k v3
intersectionCollisions f h1 h2 ary1 ary2
  | h1 /= h2 = Empty
  | otherwise = runST $ do
      scratch <- A.thaw ary2 0 $ A.length ary2
      out <- A.new_ $ min (A.length ary1) (A.length ary2)
      -- ix      : position in ary1
      -- matched : number of matches so far; also the next free slot in
      --           @out@ and the start of the still-unmatched part of @scratch@
      let loop ix matched
            | ix >= A.length ary1 || matched >= A.lengthM scratch = pure matched
            | otherwise = do
                L k1 v1 <- A.indexM ary1 ix
                found <- searchSwap k1 matched scratch
                case found of
                  Nothing -> loop (ix + 1) matched
                  Just (L _k2 v2) -> do
                    let !(# v3 #) = f k1 v1 v2
                    A.write out matched $ L k1 v3
                    loop (ix + 1) (matched + 1)
      total <- loop 0 0
      case total of
        0 -> pure Empty
        1 -> Leaf h1 <$> A.read out 0
        _ -> Collision h1 <$> (A.unsafeFreeze =<< A.shrink out total)
{-# INLINE intersectionCollisions #-}

-- | Linear search for a key in a mutable array of leaves, beginning at index
-- @start@.
--
-- On a hit the leaf stored at @start@ is copied over the slot where the key
-- was found, and the found leaf is returned. The slot at @start@ is then
-- considered consumed: the caller is expected to begin its next search at
-- @start + 1@, so nothing needs to be written there.
--
-- For example, searching for @3@ from index 0 in
--
-- @
-- 1 2 3 4
-- @
--
-- leaves the array logically as
--
-- @
-- undefined 2 1 4
-- @
--
-- On a miss the array is left untouched and 'Nothing' is returned.
searchSwap :: Eq k => k -> Int -> A.MArray s (Leaf k v) -> ST s (Maybe (Leaf k v))
searchSwap toFind start = scan start
  where
    scan pos mary
      | pos >= A.lengthM mary = pure Nothing
      | otherwise = do
          hit@(L key _v) <- A.read mary pos
          if toFind == key
            then do
              -- Move the leaf at @start@ into the slot we just matched.
              displaced <- A.read mary start
              A.write mary pos displaced
              pure (Just hit)
            else scan (pos + 1) mary
{-# INLINE searchSwap #-}


------------------------------------------------------------------------
-- * Folds

Expand Down
16 changes: 16 additions & 0 deletions Data/HashMap/Internal/Array.hs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ module Data.HashMap.Internal.Array
, toList
, fromList
, fromList'
, shrink
) where

import Control.Applicative (liftA2)
Expand All @@ -92,6 +93,7 @@ import GHC.Exts (Int (..), SmallArray#, SmallMutableArray#,
thawSmallArray#, unsafeCoerce#,
unsafeFreezeSmallArray#, unsafeThawSmallArray#,
writeSmallArray#)
import qualified GHC.Exts as Exts
import GHC.ST (ST (..))
import Prelude hiding (all, filter, foldMap, foldl, foldr, length,
map, read, traverse)
Expand Down Expand Up @@ -205,6 +207,20 @@ new _n@(I# n#) b =
-- | Like 'new', but with 'undefinedElem' supplied as the element argument.
-- Presumably this yields an array of the given size whose slots must be
-- written before they are read -- confirm against 'new' and 'undefinedElem'.
new_ :: Int -> ST s (MArray s a)
new_ n = new n undefinedElem

-- | Shrink a mutable array to the given number of elements.
--
-- On GHC >= 8.10 this uses 'Exts.shrinkSmallMutableArray#': the array is
-- shrunk in place and the very same array is returned. On older compilers
-- the result comes from 'cloneM' over the first @n@ elements instead.
--
-- Callers should therefore always continue with the returned array rather
-- than the argument.
--
-- The GHC >= 8.10 branch passes the bounds @n > 0@ and @n <= lengthM mary@
-- to the CHECK_GT / CHECK_LE macros (defined outside this view; presumably
-- only active in debug builds -- confirm).
shrink :: MArray s a -> Int -> ST s (MArray s a)
#if __GLASGOW_HASKELL__ >= 810
shrink mary _n@(I# n#) =
CHECK_GT("shrink", _n, (0 :: Int))
CHECK_LE("shrink", _n, (lengthM mary))
ST $ \s -> case Exts.shrinkSmallMutableArray# (unMArray mary) n# s of
s' -> (# s', mary #)
#else
shrink mary n = cloneM mary 0 n
#endif
{-# INLINE shrink #-}

-- | Build an immutable array from a single element by running 'singletonM'
-- inside 'runST'.
singleton :: a -> Array a
singleton x = runST (singletonM x)
{-# INLINE singleton #-}
Expand Down
15 changes: 4 additions & 11 deletions Data/HashMap/Internal/Strict.hs
Original file line number Diff line number Diff line change
Expand Up @@ -128,16 +128,17 @@ import Data.Bits ((.&.), (.|.))
import Data.Coerce (coerce)
import Data.Functor.Identity (Identity (..))
-- See Note [Imports from Data.HashMap.Internal]
import Data.Hashable (Hashable)
import Data.HashMap.Internal (Hash, HashMap (..), Leaf (..), LookupRes (..),
bitsPerSubkey, fullNodeMask, hash, index, mask,
ptrEq, sparseIndex)
import Data.Hashable (Hashable)
import Prelude hiding (lookup, map)

-- See Note [Imports from Data.HashMap.Internal]
import qualified Data.HashMap.Internal as HM
import qualified Data.HashMap.Internal.Array as A
import qualified Data.List as List
import qualified GHC.Exts as Exts

{-
Note [Imports from Data.HashMap.Internal]
Expand Down Expand Up @@ -616,23 +617,15 @@ differenceWith f a b = HM.foldlWithKey' go HM.empty a
-- maps.
intersectionWith :: (Eq k, Hashable k) => (v1 -> v2 -> v3) -> HashMap k v1
-> HashMap k v2 -> HashMap k v3
intersectionWith f a b = HM.foldlWithKey' go HM.empty a
where
go m k v = case HM.lookup k b of
Just w -> let !x = f v w in HM.unsafeInsert k x m
_ -> m
intersectionWith f = Exts.inline intersectionWithKey $ const f
{-# INLINABLE intersectionWith #-}

-- | \(O(n+m)\) Intersection of two maps. If a key occurs in both maps
-- the provided function is used to combine the values from the two
-- maps.
intersectionWithKey :: (Eq k, Hashable k) => (k -> v1 -> v2 -> v3)
-> HashMap k v1 -> HashMap k v2 -> HashMap k v3
intersectionWithKey f a b = HM.foldlWithKey' go HM.empty a
where
go m k v = case HM.lookup k b of
Just w -> let !x = f k v w in HM.unsafeInsert k x m
_ -> m
intersectionWithKey f = HM.intersectionWithKey# $ \k v1 v2 -> let !v3 = f k v1 v2 in (# v3 #)
{-# INLINABLE intersectionWithKey #-}

------------------------------------------------------------------------
Expand Down
6 changes: 5 additions & 1 deletion benchmarks/Benchmarks.hs
Original file line number Diff line number Diff line change
Expand Up @@ -318,13 +318,17 @@ main = do
[ bench "Int" $ whnf (HM.union hmi) hmi2
, bench "ByteString" $ whnf (HM.union hmbs) hmbsSubset
]

, bgroup "intersection"
[ bench "Int" $ whnf (HM.intersection hmi) hmi2
, bench "ByteString" $ whnf (HM.intersection hmbs) hmbsSubset
]

-- Transformations
, bench "map" $ whnf (HM.map (\ v -> v + 1)) hmi

-- * Difference and intersection
, bench "difference" $ whnf (HM.difference hmi) hmi2
, bench "intersection" $ whnf (HM.intersection hmi) hmi2

-- Folds
, bench "foldl'" $ whnf (HM.foldl' (+) 0) hmi
Expand Down
9 changes: 6 additions & 3 deletions tests/Properties/HashMapLazy.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# OPTIONS_GHC -fno-warn-orphans #-} -- because of Arbitrary (HashMap k v)
{-# LANGUAGE BangPatterns #-}

-- | Tests for the 'Data.HashMap.Lazy' module. We test functions by
-- comparing them to @Map@ from @containers@.
Expand Down Expand Up @@ -42,7 +43,7 @@ import qualified Data.Map.Lazy as M

-- Key type that generates more hash collisions.
newtype Key = K { unK :: Int }
deriving (Arbitrary, Eq, Ord, Read, Show)
deriving (Arbitrary, Eq, Ord, Read, Show, Num)

instance Hashable Key where
hashWithSalt salt k = hashWithSalt salt (unK k) `mod` 20
Expand Down Expand Up @@ -318,8 +319,10 @@ pDifferenceWith xs ys = M.differenceWith f (M.fromList xs) `eq_`
f x y = if x == 0 then Nothing else Just (x - y)

pIntersection :: [(Key, Int)] -> [(Key, Int)] -> Bool
pIntersection xs ys = M.intersection (M.fromList xs) `eq_`
HM.intersection (HM.fromList xs) $ ys
pIntersection xs ys =
M.intersection (M.fromList xs)
`eq_` HM.intersection (HM.fromList xs)
$ ys

pIntersectionWith :: [(Key, Int)] -> [(Key, Int)] -> Bool
pIntersectionWith xs ys = M.intersectionWith (-) (M.fromList xs) `eq_`
Expand Down