Skip to content

Commit dc98ae7

Browse files
authored
Update and export set intersection utilities (#1040)
* Update Set.intersections to be lazier * Mark definitions INLINABLE for specialization * Add matching Intersection and intersections for IntSet * Add property tests
1 parent 2a109ad commit dc98ae7

File tree

7 files changed

+133
-11
lines changed

7 files changed

+133
-11
lines changed

containers-tests/tests/intset-properties.hs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@ import Data.List (nub,sort)
77
import qualified Data.List as List
88
import Data.Maybe (listToMaybe)
99
import Data.Monoid (mempty)
10+
#if MIN_VERSION_base(4,18,0)
11+
import Data.List.NonEmpty (NonEmpty(..))
12+
import qualified Data.List.NonEmpty as NE
13+
import qualified Data.Foldable1 as Foldable1
14+
#endif
1015
import qualified Data.Set as Set
1116
import IntSetValidity (valid)
1217
import Prelude hiding (lookup, null, map, filter, foldr, foldl, foldl')
@@ -82,6 +87,10 @@ main = defaultMain $ testGroup "intset-properties"
8287
, testProperty "prop_bitcount" prop_bitcount
8388
, testProperty "prop_alterF_list" prop_alterF_list
8489
, testProperty "prop_alterF_const" prop_alterF_const
90+
#if MIN_VERSION_base(4,18,0)
91+
, testProperty "intersections" prop_intersections
92+
, testProperty "intersections_lazy" prop_intersections_lazy
93+
#endif
8594
]
8695

8796
----------------------------------------------------------------
@@ -500,3 +509,18 @@ prop_alterF_const
500509
prop_alterF_const f k s =
501510
getConst (alterF (Const . applyFun f) k s )
502511
=== getConst (Set.alterF (Const . applyFun f) k (toSet s))
512+
513+
#if MIN_VERSION_base(4,18,0)
514+
prop_intersections :: (IntSet, [IntSet]) -> Property
515+
prop_intersections (s, ss) =
516+
intersections ss' === Foldable1.foldl1' intersection ss'
517+
where
518+
ss' = s :| ss -- Work around missing Arbitrary NonEmpty instance
519+
520+
prop_intersections_lazy :: [IntSet] -> Property
521+
prop_intersections_lazy ss = intersections ss' === empty
522+
where
523+
ss' = NE.fromList $ ss ++ [empty] ++ undefined
524+
-- ^ result will certainly be empty at this point,
525+
-- so the rest of the list should not be demanded.
526+
#endif

containers-tests/tests/set-properties.hs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@ import Control.Monad (liftM, liftM3)
1616
import Data.Functor.Identity
1717
import Data.Foldable (all)
1818
import Control.Applicative (liftA2)
19+
#if MIN_VERSION_base(4,18,0)
20+
import Data.List.NonEmpty (NonEmpty(..))
21+
import qualified Data.List.NonEmpty as NE
22+
import qualified Data.Foldable1 as Foldable1
23+
#endif
1924

2025
#if __GLASGOW_HASKELL__ >= 806
2126
import Utils.NoThunks (whnfHasNoThunks)
@@ -112,6 +117,10 @@ main = defaultMain $ testGroup "set-properties"
112117
#endif
113118
, testProperty "eq" prop_eq
114119
, testProperty "compare" prop_compare
120+
#if MIN_VERSION_base(4,18,0)
121+
, testProperty "intersections" prop_intersections
122+
, testProperty "intersections_lazy" prop_intersections_lazy
123+
#endif
115124
]
116125

117126
-- A type with a peculiar Eq instance designed to make sure keys
@@ -738,3 +747,18 @@ prop_eq s1 s2 = (s1 == s2) === (toList s1 == toList s2)
738747

739748
prop_compare :: Set Int -> Set Int -> Property
740749
prop_compare s1 s2 = compare s1 s2 === compare (toList s1) (toList s2)
750+
751+
#if MIN_VERSION_base(4,18,0)
752+
prop_intersections :: (Set Int, [Set Int]) -> Property
753+
prop_intersections (s, ss) =
754+
intersections ss' === Foldable1.foldl1' intersection ss'
755+
where
756+
ss' = s :| ss -- Work around missing Arbitrary NonEmpty instance
757+
758+
prop_intersections_lazy :: [Set Int] -> Property
759+
prop_intersections_lazy ss = intersections ss' === empty
760+
where
761+
ss' = NE.fromList $ ss ++ [empty] ++ undefined
762+
-- ^ result will certainly be empty at this point,
763+
-- so the rest of the list should not be demanded.
764+
#endif

containers/changelog.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@
3939

4040
* Add `lookupMin` and `lookupMax` for `Data.IntSet`. (Soumik Sarkar)
4141

42+
* Add `Intersection` and `intersections` for `Data.Set` and `Data.IntSet`.
43+
(Reed Mullanix, Soumik Sarkar)
44+
4245
## Unreleased with `@since` annotation for 0.7.1:
4346

4447
### Additions

containers/src/Data/IntSet.hs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,11 @@ module Data.IntSet (
109109
, difference
110110
, (\\)
111111
, intersection
112+
#if MIN_VERSION_base(4,18,0)
113+
, intersections
114+
#endif
112115
, symmetricDifference
116+
, Intersection(..)
113117

114118
-- * Filter
115119
, IS.filter

containers/src/Data/IntSet/Internal.hs

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,11 @@ module Data.IntSet.Internal (
125125
, unions
126126
, difference
127127
, intersection
128+
#if MIN_VERSION_base(4,18,0)
129+
, intersections
130+
#endif
128131
, symmetricDifference
132+
, Intersection(..)
129133

130134
-- * Filter
131135
, filter
@@ -192,11 +196,15 @@ import Control.DeepSeq (NFData(rnf))
192196
import Data.Bits
193197
import qualified Data.List as List
194198
import Data.Maybe (fromMaybe)
195-
import Data.Semigroup (Semigroup(stimes))
199+
import Data.Semigroup
200+
(Semigroup(stimes), stimesIdempotent, stimesIdempotentMonoid)
196201
#if !(MIN_VERSION_base(4,11,0))
197202
import Data.Semigroup (Semigroup((<>)))
198203
#endif
199-
import Data.Semigroup (stimesIdempotentMonoid)
204+
#if MIN_VERSION_base(4,18,0)
205+
import qualified Data.Foldable1 as Foldable1
206+
import Data.List.NonEmpty (NonEmpty(..))
207+
#endif
200208
import Utils.Containers.Internal.Prelude hiding
201209
(filter, foldr, foldl, foldl', null, map)
202210
import Prelude ()
@@ -659,6 +667,40 @@ intersection (Tip kx1 bm1) t2 = intersectBM t2
659667

660668
intersection Nil _ = Nil
661669

670+
#if MIN_VERSION_base(4,18,0)
671+
-- | The intersection of a series of sets. Intersections are performed
672+
-- left-to-right.
673+
--
674+
-- @since FIXME
675+
intersections :: Foldable1.Foldable1 f => f IntSet -> IntSet
676+
intersections ss = case Foldable1.toNonEmpty ss of
677+
s0 :| ss'
678+
| null s0 -> empty
679+
| otherwise -> List.foldr go id ss' s0
680+
where
681+
go s r acc
682+
| null acc' = empty
683+
| otherwise = r acc'
684+
where
685+
acc' = intersection acc s
686+
{-# INLINABLE intersections #-}
687+
#endif
688+
689+
-- | @IntSet@s form a 'Semigroup' under 'intersection'.
690+
--
691+
-- A @Monoid@ instance is not defined because it would be impractical to
692+
-- construct @mempty@, the @IntSet@ containing all @Int@s.
693+
--
694+
-- @since FIXME
695+
newtype Intersection = Intersection { getIntersection :: IntSet }
696+
deriving (Show, Eq, Ord)
697+
698+
instance Semigroup Intersection where
699+
Intersection s1 <> Intersection s2 = Intersection (intersection s1 s2)
700+
701+
stimes = stimesIdempotent
702+
{-# INLINABLE stimes #-}
703+
662704
{--------------------------------------------------------------------
663705
Symmetric difference
664706
--------------------------------------------------------------------}

containers/src/Data/Set.hs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,13 @@ module Data.Set (
117117
, difference
118118
, (\\)
119119
, intersection
120+
#if MIN_VERSION_base(4,18,0)
121+
, intersections
122+
#endif
120123
, symmetricDifference
121124
, cartesianProduct
122125
, disjointUnion
126+
, Intersection(..)
123127

124128
-- * Filter
125129
, S.filter

containers/src/Data/Set/Internal.hs

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,9 @@ module Data.Set.Internal (
155155
, unions
156156
, difference
157157
, intersection
158+
#if MIN_VERSION_base(4,18,0)
158159
, intersections
160+
#endif
159161
, symmetricDifference
160162
, cartesianProduct
161163
, disjointUnion
@@ -240,7 +242,6 @@ import Control.Applicative (Const(..))
240242
import qualified Data.List as List
241243
import Data.Bits (shiftL, shiftR)
242244
import Data.Semigroup (Semigroup(stimes))
243-
import Data.List.NonEmpty (NonEmpty(..))
244245
#if !(MIN_VERSION_base(4,11,0))
245246
import Data.Semigroup (Semigroup((<>)))
246247
#endif
@@ -249,6 +250,10 @@ import Data.Functor.Classes
249250
import Data.Functor.Identity (Identity)
250251
import qualified Data.Foldable as Foldable
251252
import Control.DeepSeq (NFData(rnf))
253+
#if MIN_VERSION_base(4,18,0)
254+
import qualified Data.Foldable1 as Foldable1
255+
import Data.List.NonEmpty (NonEmpty(..))
256+
#endif
252257

253258
import Utils.Containers.Internal.StrictPair
254259
import Utils.Containers.Internal.PtrEquality
@@ -894,21 +899,37 @@ intersection t1@(Bin _ x l1 r1) t2
894899
{-# INLINABLE intersection #-}
895900
#endif
896901

897-
-- | The intersection of a series of sets. Intersections are performed left-to-right.
898-
intersections :: Ord a => NonEmpty (Set a) -> Set a
899-
intersections (s0 :| ss) = List.foldr go id ss s0
900-
where
901-
go s r acc
902-
| null acc = empty
903-
| otherwise = r (intersection acc s)
902+
#if MIN_VERSION_base(4,18,0)
903+
-- | The intersection of a series of sets. Intersections are performed
904+
-- left-to-right.
905+
--
906+
-- @since FIXME
907+
intersections :: (Foldable1.Foldable1 f, Ord a) => f (Set a) -> Set a
908+
intersections ss = case Foldable1.toNonEmpty ss of
909+
s0 :| ss'
910+
| null s0 -> empty
911+
| otherwise -> List.foldr go id ss' s0
912+
where
913+
go s r acc
914+
| null acc' = empty
915+
| otherwise = r acc'
916+
where
917+
acc' = intersection acc s
918+
{-# INLINABLE intersections #-}
919+
#endif
904920

905-
-- | Sets form a 'Semigroup' under 'intersection'.
921+
-- | @Set@s form a 'Semigroup' under 'intersection'.
922+
--
923+
-- @since FIXME
906924
newtype Intersection a = Intersection { getIntersection :: Set a }
907925
deriving (Show, Eq, Ord)
908926

909927
instance (Ord a) => Semigroup (Intersection a) where
910928
(Intersection a) <> (Intersection b) = Intersection $ intersection a b
929+
{-# INLINABLE (<>) #-}
930+
911931
stimes = stimesIdempotent
932+
{-# INLINABLE stimes #-}
912933

913934
{--------------------------------------------------------------------
914935
Symmetric difference

0 commit comments

Comments
 (0)