Skip to content

Commit 1187f49

Browse files
authored
Merge pull request #5627 from unisonweb/topic/optimized-replacements
Add a mechanism for swapping builtin implementations for unison code
2 parents 63ae3dc + feaaec8 commit 1187f49

29 files changed

+1132
-663
lines changed

parser-typechecker/src/Unison/Builtin/Decls.hs

+25-1
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,10 @@ seqViewEmpty, seqViewElem :: ConstructorId
118118
noneId = Maybe.fromJust $ constructorId optionalRef "Optional.None"
119119
someId = Maybe.fromJust $ constructorId optionalRef "Optional.Some"
120120

121+
mapTip, mapBin :: ConstructorId
122+
mapTip = Maybe.fromJust $ constructorId mapRef "Map.Tip"
123+
mapBin = Maybe.fromJust $ constructorId mapRef "Map.Bin"
124+
121125
isPropagatedConstructorId = Maybe.fromJust $ constructorId isPropagatedRef "IsPropagated.IsPropagated"
122126

123127
isTestConstructorId = Maybe.fromJust $ constructorId isTestRef "IsTest.IsTest"
@@ -247,6 +251,9 @@ unRewriteSignature _ = Nothing
247251
rewritesRef :: Reference
248252
rewritesRef = lookupDeclRef "Rewrites"
249253

254+
mapRef :: Reference
255+
mapRef = lookupDeclRef "Map"
256+
250257
pattern Rewrites' :: [Term2 vt at ap v a] -> Term2 vt at ap v a
251258
pattern Rewrites' ts <- (unRewrites -> Just ts)
252259

@@ -301,7 +308,8 @@ builtinDataDecls = rs1 ++ rs
301308
(v "RewriteTerm", rewriteTerm),
302309
(v "RewriteSignature", rewriteType),
303310
(v "RewriteCase", rewriteCase),
304-
(v "Rewrites", rewrites)
311+
(v "Rewrites", rewrites),
312+
(v "Map", map)
305313
] of
306314
Right a -> a
307315
Left e -> error $ "builtinDataDecls: " <> show e
@@ -310,6 +318,7 @@ builtinDataDecls = rs1 ++ rs
310318
_ -> error "builtinDataDecls: Expected a single linkRef"
311319
v = Var.named
312320
var name = Type.var () (v name)
321+
infixr 7 `arr`
313322
arr = Type.arrow'
314323
-- see note on `hashDecls` above for why ctor must be called `Unit.Unit`.
315324
unit = DataDeclaration Structural () [] [((), v "Unit.Unit", var "Unit")]
@@ -604,6 +613,21 @@ builtinDataDecls = rs1 ++ rs
604613
[ ((), v "Link.Term", Type.termLink () `arr` var "Link"),
605614
((), v "Link.Type", Type.typeLink () `arr` var "Link")
606615
]
616+
map =
617+
DataDeclaration
618+
(Unique "s9drbo3urtmpecjn6ivkj5mn0vr11gfn")
619+
()
620+
[v "k", v "v"]
621+
let forke = Type.foralls () [v "k", v "v"]
622+
k = var "k"
623+
e = var "v"
624+
mapke = Type.apps' (var "Map") [k, e] in
625+
[ ( (),
626+
v "Map.Bin",
627+
forke $ Type.nat () `arr` k `arr` e `arr` mapke `arr` mapke `arr` mapke
628+
),
629+
((), v "Map.Tip", forke mapke)
630+
]
607631

608632
builtinEffectDecls :: [(Symbol, Reference.Id, DD.EffectDeclaration Symbol ())]
609633
builtinEffectDecls =

unison-core/src/Unison/Type.hs

+3
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,9 @@ anyRef = Reference.Builtin "Any"
362362
timeSpecRef :: TypeReference
363363
timeSpecRef = Reference.Builtin "TimeSpec"
364364

365+
hmapRef :: TypeReference
366+
hmapRef = Reference.Builtin "Map"
367+
365368
any :: (Ord v) => a -> Type v a
366369
any a = ref a anyRef
367370

unison-runtime/src/Unison/Runtime/ANF.hs

+32
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ module Unison.Runtime.ANF
7979
groupTermLinks,
8080
buildInlineMap,
8181
inline,
82+
replaceConstructors,
83+
replaceFunctions,
8284
foldGroup,
8385
foldGroupLinks,
8486
overGroup,
@@ -685,6 +687,36 @@ inline inls (Rec bs entry) = Rec (fmap go0 <$> bs) (go0 entry)
685687
Just $ TApp f (pre ++ post)
686688
| otherwise = Nothing
687689

690+
replaceConstructors ::
691+
(Var v) =>
692+
Map Reference (Map CTag ForeignFunc) ->
693+
SuperGroup v ->
694+
SuperGroup v
695+
replaceConstructors reps (Rec bs entry) =
696+
Rec (fmap go0 <$> bs) (go0 entry)
697+
where
698+
go0 (Lambda ccs body) = Lambda ccs $ ABTN.visitPure f body
699+
700+
f (TApp (FCon r c) as) = do
701+
cs <- Map.lookup r reps
702+
ff <- Map.lookup c cs
703+
pure $ TApp (FPrim (Right ff)) as
704+
f _ = Nothing
705+
706+
replaceFunctions ::
707+
(Var v) =>
708+
Map Reference Reference ->
709+
SuperGroup v ->
710+
SuperGroup v
711+
replaceFunctions reps (Rec bs entry) =
712+
Rec (fmap go0 <$> bs) (go0 entry)
713+
where
714+
go0 (Lambda ccs body) = Lambda ccs $ ABTN.visitPure f body
715+
716+
f (TApp (FComb r) as) =
717+
Map.lookup r reps <&> \r -> TApp (FComb r) as
718+
f _ = Nothing
719+
688720
addDefaultCases :: (Var v) => (Monoid a) => Text -> Term v a -> Term v a
689721
addDefaultCases = ABT.visitPure . defaultCaseVisitor
690722

unison-runtime/src/Unison/Runtime/Builtin.hs

+9
Original file line numberDiff line numberDiff line change
@@ -1276,6 +1276,15 @@ declareForeigns = do
12761276
declareForeign Untracked 2 Char_Class_is
12771277
declareForeign Untracked 1 Text_patterns_char
12781278

1279+
-- replacements
1280+
declareForeign Untracked 3 Map_insert
1281+
declareForeign Untracked 2 Map_lookup
1282+
declareForeign Untracked 1 Map_fromList
1283+
declareForeign Untracked 2 Map_eq
1284+
declareForeign Untracked 2 List_range
1285+
declareForeign Untracked 1 List_sort
1286+
1287+
12791288
foreignDeclResults :: (Map ForeignFunc (Sandbox, SuperNormal Symbol))
12801289
foreignDeclResults =
12811290
execState declareForeigns mempty

unison-runtime/src/Unison/Runtime/Decompile.hs

+6-2
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ import Unison.Runtime.Stack
4040
Val (..),
4141
pattern DataC,
4242
pattern PApV,
43+
inflateMap,
4344
)
4445
import Unison.Syntax.NamePrinter (prettyReference)
4546
import Unison.Term
@@ -65,6 +66,7 @@ import Unison.Term qualified as Term
6566
import Unison.Type
6667
( anyRef,
6768
booleanRef,
69+
hmapRef,
6870
iarrayRef,
6971
ibytearrayRef,
7072
listRef,
@@ -172,14 +174,14 @@ decompile backref topTerms = \case
172174
(PApV (CIx rf rt k) _ vs)
173175
| rf == Builtin "jumpCont" ->
174176
err Cont $ bug "<Continuation>"
175-
| Builtin nm <- rf ->
176-
apps' (builtin () nm) <$> traverse (decompile backref topTerms) vs
177177
| Just t <- topTerms rt k ->
178178
Term.etaReduceEtaVars . substitute t
179179
<$> traverse (decompile backref topTerms) vs
180180
| k > 0,
181181
Just _ <- topTerms rt 0 ->
182182
err (UnkLocal rf k) $ bug "<Unknown>"
183+
| Builtin nm <- rf ->
184+
apps' (builtin () nm) <$> traverse (decompile backref topTerms) vs
183185
| otherwise -> err (UnkComb rf) $ ref () rf
184186
(PAp (CIx rf _ _) _ _) ->
185187
err (BadPAp rf) $ bug "<Unknown>"
@@ -228,6 +230,8 @@ decompileForeign backref topTerms f
228230
(decompileBytes . By.fromWord8s $ byteArrayToList a)
229231
| Just s <- unwrapSeq f =
230232
list' () <$> traverse (decompile backref topTerms) s
233+
| Just m <- maybeUnwrapForeign hmapRef f =
234+
decompile backref topTerms . BoxedVal $ inflateMap m
231235
decompileForeign _ _ (Wrap r _) =
232236
err (BadForeign r) $ bug text
233237
where

unison-runtime/src/Unison/Runtime/Foreign.hs

+10
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import Control.Concurrent.STM (TVar)
2424
import Crypto.Hash qualified as Hash
2525
import Data.Atomics qualified as Atomic
2626
import Data.IORef (IORef)
27+
import Data.Map.Strict (Map)
2728
import Data.Tagged (Tagged (..))
2829
import Data.X509 qualified as X509
2930
import Network.Socket (Socket)
@@ -358,6 +359,15 @@ instance BuiltinForeign CharPattern where
358359
foreignName = Tagged "CharPattern"
359360
foreignRef = Tagged Ty.charClassRef
360361

362+
-- Note: this doesn't do any recursive conversion of keys/values,
363+
-- so any use of it needs to be an exact match for the map that was
364+
-- originally placed in the box. The current intention is for use
365+
-- at `Map Val Val`, but `Val` is in a module further along the
366+
-- dependency graph.
367+
instance BuiltinForeign (Map k v) where
368+
foreignName = Tagged "Map"
369+
foreignRef = Tagged Ty.hmapRef
370+
361371
wrapBuiltin :: forall f. (BuiltinForeign f) => f -> Foreign
362372
wrapBuiltin x = Wrap r x
363373
where

unison-runtime/src/Unison/Runtime/Foreign/Function.hs

+75-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ module Unison.Runtime.Foreign.Function
88
, foreignCall
99
, readsAtError
1010
, foreignConventionError
11+
, pseudoConstructors
12+
, functionReplacements
13+
, functionUnreplacements
1114
) where
1215

1316
import Control.Concurrent (ThreadId)
@@ -35,6 +38,8 @@ import Data.ByteString.Lazy qualified as L
3538
import Data.Default (def)
3639
import Data.Digest.Murmur64 (asWord64, hash64)
3740
import Data.IP (IP)
41+
import Data.Map.Strict qualified as Map
42+
import Data.Map.Strict.Internal qualified as Map
3843
import Data.PEM (PEM, pemContent, pemParseLBS)
3944
import Data.Sequence qualified as Sq
4045
import Data.Tagged (Tagged (..))
@@ -148,7 +153,8 @@ import Unison.Runtime.Crypto.Rsa qualified as Rsa
148153
import Unison.Runtime.Exception
149154
import Unison.Runtime.Foreign hiding (Failure)
150155
import Unison.Runtime.Foreign qualified as F
151-
import Unison.Runtime.Foreign.Function.Type (ForeignFunc (..))
156+
import Unison.Runtime.Foreign.Function.Type
157+
(ForeignFunc (..), foreignFuncBuiltinName)
152158
import Unison.Runtime.MCode
153159
import Unison.Runtime.Stack
154160
import Unison.Runtime.TypeTags qualified as TT
@@ -858,6 +864,24 @@ foreignCallHelper = \case
858864
Char_Class_is -> mkForeign $ \(cl, c) -> evaluate $ TPat.charPatternPred cl c
859865
Text_patterns_char -> mkForeign $ \c ->
860866
let v = TPat.cpattern (TPat.Char c) in pure v
867+
Map_tip -> mkForeign $ \() -> pure Map.empty
868+
Map_bin -> mkForeign $ \(sz :: Word64, k :: Val, v :: Val, l, r) ->
869+
pure (Map.Bin (fromIntegral sz) k v l r)
870+
Map_insert -> mkForeign $ \(k :: Val, v :: Val, m :: Map Val Val) ->
871+
evaluate $ Map.insert k v m
872+
Map_lookup -> mkForeign $ \(k :: Val, v :: Map Val Val) ->
873+
evaluate $ Map.lookup k v
874+
Map_fromList -> mkForeign $ \(l :: [(Val, Val)]) ->
875+
evaluate $ Map.fromList l
876+
Map_eq -> mkForeign $ \(l :: Map Val Val, r :: Map Val Val) ->
877+
pure $ l == r
878+
List_range -> mkForeign $ \(m :: Word64, n :: Word64) ->
879+
let sz | m < n = fromIntegral $ n - m
880+
| otherwise = 0
881+
mk i = NatVal $ m + fromIntegral i
882+
force s = foldl (\u x -> x `seq` u) s s
883+
in evaluate . force $ Sq.fromFunction sz mk
884+
List_sort -> mkForeign $ \(l :: Seq Val) -> pure $ Sq.unstableSort l
861885
where
862886
chop = reverse . dropWhile isPathSeparator . reverse
863887

@@ -1928,3 +1952,53 @@ instance {-# overlappable #-} (BuiltinForeign b) => ForeignConvention b where
19281952
encodeVal = encodeBuiltin
19291953
readAtIndex = readBuiltinAt
19301954
writeBack = writeBuiltin
1955+
1956+
pseudoConstructors :: Map Reference (Map TT.CTag ForeignFunc)
1957+
pseudoConstructors =
1958+
Map.singleton Ty.mapRef $
1959+
Map.fromList
1960+
[ (fromIntegral Ty.mapTip, Map_tip)
1961+
, (fromIntegral Ty.mapBin, Map_bin)
1962+
]
1963+
1964+
functionReplacementList :: [(Data.Text.Text, ForeignFunc)]
1965+
functionReplacementList =
1966+
[ ( "03hqp8knrcgdc733mitcunjlug4cpi9headkggu8h9d87nfgneo6e"
1967+
, Map_insert
1968+
)
1969+
, ( "03g44bb2bp3g5eld8eh07g6e8iq7oiqiplapeb6jerbs7ee3icq9s"
1970+
, Map_lookup
1971+
)
1972+
, ( "005mc1fq7ojq72c238qlm2rspjgqo2furjodf28icruv316odu6du"
1973+
, Map_fromList
1974+
)
1975+
, ( "03c559iihi2vj0qps6cln48nv31ajup2srhas4pd05b9k46ds8jvk"
1976+
, Map_eq
1977+
)
1978+
, ( "01f446li3b0j5gcnj7fa99jfqir43shs0jqu779oo0npb7v8d3v22"
1979+
, List_range
1980+
)
1981+
, ( "00jh7o3l67okqqalho1sqgl4ei9n2sdhrpqobgkf7j390v4e938km"
1982+
, List_sort
1983+
)
1984+
]
1985+
1986+
functionReplacements :: Map Reference Reference
1987+
functionReplacements =
1988+
Map.fromList $ fmap process functionReplacementList
1989+
1990+
functionUnreplacements :: Map Reference Reference
1991+
functionUnreplacements =
1992+
Map.fromList . fmap (swap . process) $ functionReplacementList
1993+
where
1994+
swap (x, y) = (y, x)
1995+
1996+
-- Note: using index 0 right now. Generalize if ever replacing
1997+
-- part of a mutually recursive group.
1998+
process :: (Data.Text.Text, ForeignFunc) -> (Reference, Reference)
1999+
process (str, ff) = case derivedBase32Hex str 0 of
2000+
Nothing -> error $ "Could not create reference for " ++ sname
2001+
Just r -> (r, Builtin name)
2002+
where
2003+
name = foreignFuncBuiltinName ff
2004+
sname = Data.Text.unpack name

unison-runtime/src/Unison/Runtime/Foreign/Function/Type.hs

+16
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,14 @@ data ForeignFunc
254254
| Char_Class_letter
255255
| Char_Class_is
256256
| Text_patterns_char
257+
| Map_tip
258+
| Map_bin
259+
| Map_insert
260+
| Map_lookup
261+
| Map_fromList
262+
| Map_eq
263+
| List_range
264+
| List_sort
257265
deriving (Show, Eq, Ord, Enum, Bounded)
258266

259267
foreignFuncBuiltinName :: ForeignFunc -> Text
@@ -504,3 +512,11 @@ foreignFuncBuiltinName = \case
504512
Char_Class_letter -> "Char.Class.letter"
505513
Char_Class_is -> "Char.Class.is"
506514
Text_patterns_char -> "Text.patterns.char"
515+
Map_tip -> "Map.Tip"
516+
Map_bin -> "Map.Bin"
517+
Map_insert -> "Map.insert"
518+
Map_lookup -> "Map.lookup"
519+
Map_fromList -> "Map.fromList"
520+
Map_eq -> "Map.=="
521+
List_range -> "List.range"
522+
List_sort -> "List.sort"

unison-runtime/src/Unison/Runtime/Interface.hs

+3
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ import Unison.Runtime.ANF.Serialize as ANF
106106
import Unison.Runtime.Builtin
107107
import Unison.Runtime.Decompile
108108
import Unison.Runtime.Exception
109+
import Unison.Runtime.Foreign.Function (functionUnreplacements)
109110
import Unison.Runtime.MCode
110111
( Args (..),
111112
CombIx (..),
@@ -872,6 +873,8 @@ backReferenceTm ::
872873
Maybe (Term Symbol)
873874
backReferenceTm ws frs irs dcm c i = do
874875
r <- EC.lookup c ws
876+
-- backmap from function replacements
877+
r <- pure $ Map.findWithDefault r r functionUnreplacements
875878
-- backmap intermediate ref to floated ref
876879
r <- Map.lookup r (backmap irs)
877880
-- backmap floated ref to original ref

0 commit comments

Comments
 (0)