Skip to content

Commit 02e09da

Browse files
authored
Merge pull request #5557 from unisonweb/topic/pattern-compilation
Rework data pattern matching to use default cases
2 parents abc9720 + cd81713 commit 02e09da

File tree

3 files changed

+192
-41
lines changed

3 files changed

+192
-41
lines changed

unison-runtime/src/Unison/Runtime/Machine.hs

Lines changed: 117 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -683,14 +683,13 @@ eval env !denv !activeThreads !stk !k r (Match i br) = do
683683
n <- peekOffN stk i
684684
eval env denv activeThreads stk k r $ selectBranch n br
685685
eval env !denv !activeThreads !stk !k r (DMatch mr i br) = do
686-
(t, stk) <- dumpDataNoTag mr stk =<< peekOff stk i
687-
eval env denv activeThreads stk k r $
688-
selectBranch (maskTags t) br
686+
(nx, stk) <- dataBranch mr stk br =<< bpeekOff stk i
687+
eval env denv activeThreads stk k r nx
689688
eval env !denv !activeThreads !stk !k r (NMatch _mr i br) = do
690689
n <- peekOffN stk i
691690
eval env denv activeThreads stk k r $ selectBranch n br
692691
eval env !denv !activeThreads !stk !k r (RMatch i pu br) = do
693-
(t, stk) <- dumpDataNoTag Nothing stk =<< peekOff stk i
692+
(t, stk) <- dumpDataValNoTag stk =<< peekOff stk i
694693
if t == TT.pureEffectTag
695694
then eval env denv activeThreads stk k r pu
696695
else case ANF.unpackTags t of
@@ -1000,46 +999,41 @@ buildData !stk !r !t (VArgV i) = do
1000999
l = fsize stk - i
10011000
{-# INLINE buildData #-}
10021001

1002+
dumpDataValNoTag ::
1003+
Stack ->
1004+
Val ->
1005+
IO (PackedTag, Stack)
1006+
dumpDataValNoTag stk (BoxedVal c) =
1007+
(closureTag c,) <$> dumpDataNoTag Nothing stk c
1008+
dumpDataValNoTag _ v =
1009+
die $ "dumpDataValNoTag: unboxed val: " ++ show v
1010+
{-# inline dumpDataValNoTag #-}
1011+
10031012
-- Dumps a data type closure to the stack without writing its tag.
10041013
-- Instead, the tag is returned for direct case analysis.
10051014
dumpDataNoTag ::
10061015
Maybe Reference ->
10071016
Stack ->
1008-
Val ->
1009-
IO (PackedTag, Stack)
1017+
Closure ->
1018+
IO Stack
10101019
dumpDataNoTag !mr !stk = \case
10111020
-- Normally we want to avoid dumping unboxed values since it's unnecessary, but sometimes we don't know the type of
10121021
-- the incoming value and end up dumping unboxed values, so we just push them back to the stack as-is. e.g. in type-casts/coercions
1013-
val@(UnboxedVal _ t) -> do
1022+
Enum _ _ -> pure stk
1023+
Data1 _ _ x -> do
10141024
stk <- bump stk
1015-
poke stk val
1016-
pure (unboxedPackedTag t, stk)
1017-
BoxedVal clos -> case clos of
1018-
(Enum _ t) -> pure (t, stk)
1019-
(Data1 _ t x) -> do
1020-
stk <- bump stk
1021-
poke stk x
1022-
pure (t, stk)
1023-
(Data2 _ t x y) -> do
1024-
stk <- bumpn stk 2
1025-
pokeOff stk 1 y
1026-
poke stk x
1027-
pure (t, stk)
1028-
(DataG _ t seg) -> do
1029-
stk <- dumpSeg stk seg S
1030-
pure (t, stk)
1031-
clo ->
1032-
die $
1033-
"dumpDataNoTag: bad closure: "
1034-
++ show clo
1035-
++ maybe "" (\r -> "\nexpected type: " ++ show r) mr
1036-
where
1037-
unboxedPackedTag :: UnboxedTypeTag -> PackedTag
1038-
unboxedPackedTag = \case
1039-
CharTag -> TT.charTag
1040-
FloatTag -> TT.floatTag
1041-
IntTag -> TT.intTag
1042-
NatTag -> TT.natTag
1025+
poke stk x
1026+
pure stk
1027+
Data2 _ _ x y -> do
1028+
stk <- bumpn stk 2
1029+
pokeOff stk 1 y
1030+
stk <$ poke stk x
1031+
DataG _ _ seg -> dumpSeg stk seg S
1032+
clo ->
1033+
die $
1034+
"dumpDataNoTag: bad closure: "
1035+
++ show clo
1036+
++ maybe "" (\r -> "\nexpected type: " ++ show r) mr
10431037
{-# INLINE dumpDataNoTag #-}
10441038

10451039
-- Note: although the representation allows it, it is impossible
@@ -1995,6 +1989,94 @@ selectBranch t (TestW df cs) = lookupWithDefault df t cs
19951989
selectBranch _ (TestT {}) = error "impossible"
19961990
{-# INLINE selectBranch #-}
19971991

1992+
-- Combined branch selection and field dumping function for data types.
1993+
-- Fields should only be dumped on _matches_, not default cases, because
1994+
-- default cases potentially cover many constructors which could result
1995+
-- in a variable number of values being put on the stack. Default cases
1996+
-- uniformly expect _no_ values to be added to the stack.
1997+
dataBranch
1998+
:: Maybe Reference -> Stack -> MBranch -> Closure -> IO (MSection, Stack)
1999+
dataBranch mrf stk (Test1 u cu df) = \case
2000+
Enum _ t
2001+
| maskTags t == u -> pure (cu, stk)
2002+
| otherwise -> pure (df, stk)
2003+
Data1 _ t x
2004+
| maskTags t == u -> do
2005+
stk <- bump stk
2006+
(cu, stk) <$ poke stk x
2007+
| otherwise -> pure (df, stk)
2008+
Data2 _ t x y
2009+
| maskTags t == u -> do
2010+
stk <- bumpn stk 2
2011+
pokeOff stk 1 y
2012+
(cu, stk) <$ poke stk x
2013+
| otherwise -> pure (df, stk)
2014+
DataG _ t seg
2015+
| maskTags t == u -> (cu,) <$> dumpSeg stk seg S
2016+
| otherwise -> pure (df, stk)
2017+
clo -> dataBranchClosureError mrf clo
2018+
dataBranch mrf stk (Test2 u cu v cv df) = \case
2019+
Enum _ t
2020+
| maskTags t == u -> pure (cu, stk)
2021+
| maskTags t == v -> pure (cv, stk)
2022+
| otherwise -> pure (df, stk)
2023+
Data1 _ t x
2024+
| maskTags t == u -> do
2025+
stk <- bump stk
2026+
(cu, stk) <$ poke stk x
2027+
| maskTags t == v -> do
2028+
stk <- bump stk
2029+
(cv, stk) <$ poke stk x
2030+
| otherwise -> pure (df, stk)
2031+
Data2 _ t x y
2032+
| maskTags t == u -> do
2033+
stk <- bumpn stk 2
2034+
pokeOff stk 1 y
2035+
(cu, stk) <$ poke stk x
2036+
| maskTags t == v -> do
2037+
stk <- bumpn stk 2
2038+
pokeOff stk 1 y
2039+
(cv, stk) <$ poke stk x
2040+
| otherwise -> pure (df, stk)
2041+
DataG _ t seg
2042+
| maskTags t == u -> (cu,) <$> dumpSeg stk seg S
2043+
| maskTags t == v -> (cv,) <$> dumpSeg stk seg S
2044+
| otherwise -> pure (df, stk)
2045+
clo -> dataBranchClosureError mrf clo
2046+
dataBranch mrf stk (TestW df bs) = \case
2047+
Enum _ t
2048+
| Just ca <- EC.lookup (maskTags t) bs -> pure (ca, stk)
2049+
| otherwise -> pure (df, stk)
2050+
Data1 _ t x
2051+
| Just ca <- EC.lookup (maskTags t) bs -> do
2052+
stk <- bump stk
2053+
(ca, stk) <$ poke stk x
2054+
| otherwise -> pure (df, stk)
2055+
Data2 _ t x y
2056+
| Just ca <- EC.lookup (maskTags t) bs -> do
2057+
stk <- bumpn stk 2
2058+
pokeOff stk 1 y
2059+
(ca, stk) <$ poke stk x
2060+
| otherwise -> pure (df, stk)
2061+
DataG _ t seg
2062+
| Just ca <- EC.lookup (maskTags t) bs ->
2063+
(ca,) <$> dumpSeg stk seg S
2064+
| otherwise -> pure (df, stk)
2065+
clo -> dataBranchClosureError mrf clo
2066+
dataBranch _ _ br = \_ ->
2067+
dataBranchBranchError br
2068+
{-# inline dataBranch #-}
2069+
2070+
dataBranchClosureError :: Maybe Reference -> Closure -> IO a
2071+
dataBranchClosureError mrf clo =
2072+
die $ "dataBranch: bad closure: "
2073+
++ show clo
2074+
++ maybe "" (\ r -> "\nexpected type: " ++ show r) mrf
2075+
2076+
dataBranchBranchError :: MBranch -> IO a
2077+
dataBranchBranchError br =
2078+
die $ "dataBranch: unexpected branch: " ++ show br
2079+
19982080
-- Splits off a portion of the continuation up to a given prompt.
19992081
--
20002082
-- The main procedure walks along the 'code' stack `k`, keeping track of how

unison-runtime/src/Unison/Runtime/Pattern.hs

Lines changed: 64 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ module Unison.Runtime.Pattern
1313
where
1414

1515
import Control.Monad.State (State, evalState, modify, runState, state)
16+
import Data.Containers.ListUtils (nubOrd)
1617
import Data.List (transpose)
1718
import Data.Map.Strict
1819
( fromListWith,
@@ -92,6 +93,11 @@ builtinDataSpec = Map.fromList decls
9293
| (_, x, y) <- builtinEffectDecls
9394
]
9495

96+
findPattern :: Eq v => v -> PatternRow v -> Maybe (Pattern v)
97+
findPattern v (PR ms _ _)
98+
| (_, p : _) <- break ((== v) . loc) ms = Just p
99+
| otherwise = Nothing
100+
95101
-- A pattern compilation matrix is just a list of rows. There is
96102
-- no need for the rows to have uniform length; the variable
97103
-- annotations on the patterns in the rows keep track of what
@@ -125,8 +131,11 @@ refutable (P.Unbound _) = False
125131
refutable (P.Var _) = False
126132
refutable _ = True
127133

128-
rowIrrefutable :: PatternRow v -> Bool
129-
rowIrrefutable (PR ps _ _) = null ps
134+
noMatches :: PatternRow v -> Bool
135+
noMatches (PR ps _ _) = null ps
136+
137+
rowRefutable :: PatternRow v -> Bool
138+
rowRefutable (PR ps g _) = isJust g || not (null ps)
130139

131140
firstRow :: ([P.Pattern v] -> Maybe v) -> Heuristic v
132141
firstRow f (PM (r : _)) = f $ matches r
@@ -481,6 +490,19 @@ splitMatrix v rf cons (PM rs) =
481490
where
482491
mmap = fmap (\(t, fs) -> (t, splitRow v rf t fs =<< rs)) cons
483492

493+
-- Eliminates a variable from a matrix, keeping the rows that are
494+
-- _not_ specific matches on that variable (so, would potentially
495+
-- occur in a default case).
496+
antiSplitMatrix ::
497+
(Var v) =>
498+
v ->
499+
PatternMatrix v ->
500+
PatternMatrix v
501+
antiSplitMatrix v (PM rs) = PM (f =<< rs)
502+
where
503+
-- keep rows that do not have a refutable pattern for v
504+
f r = [ r | isNothing $ findPattern v r ]
505+
484506
-- Monad for pattern preparation. It is a state monad carrying a fresh
485507
-- variable source, the list of variables bound by the pattern being
486508
-- prepared, and a variable renaming mapping.
@@ -596,7 +618,7 @@ compile _ _ (PM []) = apps' bu [text () "pattern match failure"]
596618
where
597619
bu = ref () (Builtin "bug")
598620
compile spec ctx m@(PM (r : rs))
599-
| rowIrrefutable r =
621+
| noMatches r =
600622
case guard r of
601623
Nothing -> body r
602624
Just g -> iff mempty g (body r) $ compile spec ctx (PM rs)
@@ -614,8 +636,11 @@ compile spec ctx m@(PM (r : rs))
614636
case lookupData rf spec of
615637
Right cons ->
616638
match () (var () v) $
617-
buildCase spec rf False cons ctx
618-
<$> splitMatrix v (Just rf) (numberCons cons) m
639+
(buildCase spec rf False cons ctx
640+
<$> splitMatrix v (Just rf) ncons m)
641+
++ buildDefaultCase spec False needDefault ctx dm
642+
where
643+
needDefault = length ncons < length cons
619644
Left err -> internalBug err
620645
| PReq rfs <- ty =
621646
match () (var () v) $
@@ -631,7 +656,29 @@ compile spec ctx m@(PM (r : rs))
631656
internalBug "unknown pattern compilation type"
632657
where
633658
v = choose heuristics m
659+
ncons = relevantConstructors m v
634660
ty = Map.findWithDefault Unknown v ctx
661+
dm = antiSplitMatrix v m
662+
663+
-- Calculates the data constructors—with their arities—that should be
664+
-- matched on when splitting a matrix on a given variable. This
665+
-- includes
666+
relevantConstructors :: Ord v => PatternMatrix v -> v -> [(Int, Int)]
667+
relevantConstructors (PM rows) v = search [] rows
668+
where
669+
search acc (row : rows)
670+
| rowRefutable row = case findPattern v row of
671+
Just (P.Constructor _ (ConstructorReference _ t) sps) ->
672+
search ((fromIntegral t, length sps) : acc) rows
673+
Just (P.Boolean _ b) ->
674+
search ((if b then 1 else 0, 0) : acc) rows
675+
Just p ->
676+
internalBug $ "unexpected data pattern: " ++ show p
677+
-- if the pattern is not found, it must have been irrefutable,
678+
-- so contributes no relevant constructor.
679+
_ -> search acc rows
680+
-- irrefutable row, or no rows left
681+
search acc _ = nubOrd $ reverse acc
635682

636683
buildCaseBuiltin ::
637684
(Var v) =>
@@ -677,6 +724,18 @@ buildCase spec r eff cons ctx0 (t, vts, m) =
677724
vs = ((),) . fst <$> vts
678725
ctx = Map.fromList vts <> ctx0
679726

727+
buildDefaultCase ::
728+
(Var v) =>
729+
DataSpec ->
730+
Bool ->
731+
Bool ->
732+
Ctx v ->
733+
PatternMatrix v ->
734+
[MatchCase () (Term v)]
735+
buildDefaultCase spec _eff needed ctx pm
736+
| needed = [MatchCase (Unbound ()) Nothing $ compile spec ctx pm]
737+
| otherwise = []
738+
680739
mkRow ::
681740
(Var v) =>
682741
v ->

unison-runtime/src/Unison/Runtime/Stack.hs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ module Unison.Runtime.Stack
2222
BlackHole,
2323
UnboxedTypeTag
2424
),
25+
closureTag,
2526
UnboxedTypeTag (..),
2627
unboxedTypeTagToInt,
2728
unboxedTypeTagFromInt,
@@ -153,7 +154,7 @@ module Unison.Runtime.Stack
153154
)
154155
where
155156

156-
import Control.Exception (throwIO)
157+
import Control.Exception (throw, throwIO)
157158
import Control.Monad.Primitive
158159
import Data.Char qualified as Char
159160
import Data.IORef (IORef)
@@ -371,6 +372,15 @@ splitData = \case
371372
(DataG r t seg) -> Just (r, t, segToList seg)
372373
_ -> Nothing
373374

375+
closureTag :: Closure -> PackedTag
376+
closureTag (Enum _ t) = t
377+
closureTag (Data1 _ t _) = t
378+
closureTag (Data2 _ t _ _) = t
379+
closureTag (DataG _ t _) = t
380+
closureTag c =
381+
throw $ Panic "closureTag: unexpected closure" (Just $ BoxedVal c)
382+
{-# inline closureTag #-}
383+
374384
-- | Converts a list of integers representing an unboxed segment back into the
375385
-- appropriate segment. Segments are stored backwards in the runtime, so this
376386
-- reverses the list.

0 commit comments

Comments
 (0)