Skip to content

Commit 0c9f4d6

Browse files
Remove closure type parameter from Value (#856)
1 parent 3f27a4f commit 0c9f4d6

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+637
-695
lines changed

vehicle/src/Vehicle/Backend/LossFunction/Core.hs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import GHC.Generics (Generic)
88
import Vehicle.Backend.Prelude (DifferentiableLogicID)
99
import Vehicle.Compile.Prelude
1010
import Vehicle.Data.Builtin.Loss
11-
import Vehicle.Data.Code.Value (Value (..), WHNFBinder, WHNFBoundEnv, WHNFClosure (..), WHNFValue)
11+
import Vehicle.Data.Code.Value (BoundEnv, Closure (..), VBinder, Value (..))
1212
import Vehicle.Libraries.StandardLibrary.Definitions (StdLibFunction (..))
1313

1414
--------------------------------------------------------------------------------
@@ -56,7 +56,7 @@ instance Pretty TensorDifferentiableLogicField where
5656
pretty = pretty . show
5757

5858
type DifferentiableLogicImplementation =
59-
Map TensorDifferentiableLogicField (WHNFValue LossTensorBuiltin)
59+
Map TensorDifferentiableLogicField (Value LossTensorBuiltin)
6060

6161
type CompiledDifferentiableLogic = (DifferentiableLogicID, DifferentiableLogicImplementation)
6262

@@ -130,7 +130,7 @@ data RatTensorView expr
130130
| VReduceMaxRatTensor expr
131131
| VSearchRatTensor expr expr expr expr
132132
133-
fromRatTensorView :: (BuiltinHasRatTensor builtin, BuiltinHasDimensionData builtin) => RatTensorView (WHNFValue builtin) -> WHNFValue builtin
133+
fromRatTensorView :: (BuiltinHasRatTensor builtin, BuiltinHasDimensionData builtin) => RatTensorView (Value builtin) -> Value builtin
134134
fromRatTensorView = \case
135135
VRatTensor y -> INullaryRatTensorOp (RatTensor y)
136136
VNegRatTensor x -> IRatTensorOp NegRatTensor (explicit <$> [x])
@@ -149,7 +149,7 @@ fromRatTensorView = \case
149149
VSearchRatTensor reduce lower upper fn -> IRatTensorOp SearchRatTensor (explicit <$> [reduce, lower, upper, fn])
150150
VRatTensorVar v -> VBoundVar v []
151151
152-
toRatTensorView :: (BuiltinHasRatTensor builtin, BuiltinHasDimensionData builtin) => WHNFValue builtin -> RatTensorView (WHNFValue builtin)
152+
toRatTensorView :: (BuiltinHasRatTensor builtin, BuiltinHasDimensionData builtin) => Value builtin -> RatTensorView (Value builtin)
153153
toRatTensorView expr = case getRatTensorOp expr of
154154
Just (RatTensor b, []) -> VRatTensor b
155155
Just (NegRatTensor, [x]) -> VNegRatTensor (argExpr x)
@@ -181,5 +181,5 @@ preservedStdLibOps =
181181
[ StdForeachIndex
182182
]
183183

184-
pattern VLam2 :: WHNFBinder builtin -> WHNFBoundEnv builtin -> Binder builtin -> Expr builtin -> WHNFValue builtin
185-
pattern VLam2 binder1 env binder2 body <- VLam binder1 (WHNFClosure env (Lam _ binder2 body))
184+
pattern VLam2 :: VBinder builtin -> BoundEnv builtin -> Binder builtin -> Expr builtin -> Value builtin
185+
pattern VLam2 binder1 env binder2 body <- VLam binder1 (Closure env (Lam _ binder2 body))

vehicle/src/Vehicle/Backend/LossFunction/Domain.hs

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -48,17 +48,17 @@ type MonadDomain m =
4848
)
4949

5050
data Domain = Domain
51-
{ lowerBound :: WHNFValue TensorBuiltin,
52-
upperBound :: WHNFValue TensorBuiltin
51+
{ lowerBound :: Value TensorBuiltin,
52+
upperBound :: Value TensorBuiltin
5353
}
5454

5555
extractSearchDomain ::
5656
(MonadDomain m) =>
5757
DeclProvenance ->
58-
WHNFBinder TensorBuiltin ->
58+
VBinder TensorBuiltin ->
5959
Lv ->
60-
WHNFValue TensorBuiltin ->
61-
m (Domain, WHNFValue TensorBuiltin)
60+
Value TensorBuiltin ->
61+
m (Domain, Value TensorBuiltin)
6262
extractSearchDomain _propertyProv _binder _lv value = do
6363
{-
6464
_variableInfo <- case typeOf binder of
@@ -95,9 +95,9 @@ extractSearchDomain _propertyProv _binder _lv value = do
9595
--------------------------------------------------------------------------------
9696
-- Constraints
9797
98-
type TensorElementInequality = Inequality UserElementVariable (WHNFValue TensorBuiltin)
98+
type TensorElementInequality = Inequality UserElementVariable (Value TensorBuiltin)
9999
100-
type TensorInequality = Inequality Name (WHNFValue TensorBuiltin)
100+
type TensorInequality = Inequality Name (Value TensorBuiltin)
101101
102102
type VariableConstraint = Either TensorElementInequality TensorInequality
103103
@@ -109,20 +109,20 @@ splitConstraints = partitionEithers
109109
pattern NoConstraints :: VariableConstraints
110110
pattern NoConstraints = []
111111
112-
type ConstrainedValue = (VariableConstraints, WHNFValue TensorBuiltin)
112+
type ConstrainedValue = (VariableConstraints, Value TensorBuiltin)
113113
114-
unconstrained :: WHNFValue TensorBuiltin -> ConstrainedValue
114+
unconstrained :: Value TensorBuiltin -> ConstrainedValue
115115
unconstrained = (NoConstraints,)
116116
117117
updateConstrainedValue ::
118-
WHNFValue TensorBuiltin ->
118+
Value TensorBuiltin ->
119119
ConstrainedValue ->
120120
ConstrainedValue
121121
updateConstrainedValue originalExpr = \case
122122
constr@(_ : _, _) -> constr
123123
([], _) -> ([], originalExpr)
124124
125-
instance IsConstant (WHNFValue TensorBuiltin) where
125+
instance IsConstant (Value TensorBuiltin) where
126126
isZero = \case
127127
-- This is only semi-decidable, probably need to think harder about what
128128
-- to do here.
@@ -189,18 +189,18 @@ extractDomainFromConstraints VariableInfo{..} constraints = do
189189
190190
extractVarBounds ::
191191
(MonadCompile m) =>
192-
([TensorElementInequality], [(UserElementVariable, Bounds UserElementVariable (WHNFValue TensorBuiltin))]) ->
192+
([TensorElementInequality], [(UserElementVariable, Bounds UserElementVariable (Value TensorBuiltin))]) ->
193193
UserElementVariable ->
194-
m ([TensorElementInequality], [(UserElementVariable, Bounds UserElementVariable (WHNFValue TensorBuiltin))])
194+
m ([TensorElementInequality], [(UserElementVariable, Bounds UserElementVariable (Value TensorBuiltin))])
195195
extractVarBounds (currentConstraints, solutions) var = do
196196
(bounds, newInequalities) <- fourierMotzkinElimination var currentConstraints
197197
return (newInequalities, (var, bounds) : solutions)
198198
199199
convertBoundToExpr ::
200200
Map UserElementVariable Lv ->
201-
(WHNFValue TensorBuiltin -> WHNFValue TensorBuiltin -> WHNFValue TensorBuiltin) ->
202-
NonEmpty (Bound UserElementVariable (WHNFValue TensorBuiltin)) ->
203-
WHNFValue TensorBuiltin
201+
(Value TensorBuiltin -> Value TensorBuiltin -> Value TensorBuiltin) ->
202+
NonEmpty (Bound UserElementVariable (Value TensorBuiltin)) ->
203+
Value TensorBuiltin
204204
convertBoundToExpr varMap op bounds = foldr1 _ (fmap convertBound bounds)
205205
where
206206
-- Ignore strictness for the moment.
@@ -214,7 +214,7 @@ convertBoundToExpr varMap op bounds = foldr1 _ (fmap convertBound bounds)
214214
--------------------------------------------------------------------------------
215215
-- Constraint search
216216
217-
findConstraints :: (MonadSearch m) => WHNFValue TensorBuiltin -> m ConstrainedValue
217+
findConstraints :: (MonadSearch m) => Value TensorBuiltin -> m ConstrainedValue
218218
findConstraints expr = case toBoolTensorView expr of
219219
-------------------------
220220
-- Unuseful base cases --
@@ -246,15 +246,15 @@ findConstraints expr = case toBoolTensorView expr of
246246
247247
handleNot ::
248248
forall m . (MonadSearch m) =>
249-
WHNFValue TensorBuiltin ->
249+
Value TensorBuiltin ->
250250
m ConstrainedValue
251251
handleNot expr = do
252252
loweredExpr <- lowerBoolTensor expr
253253
case loweredExpr of
254254
INot {} -> return $ unconstrained expr
255255
newExpr -> updateConstrainedValue expr <$> findConstraints newExpr
256256
where
257-
lowerBoolTensor :: WHNFValue TensorBuiltin -> m (WHNFValue TensorBuiltin)
257+
lowerBoolTensor :: Value TensorBuiltin -> m (Value TensorBuiltin)
258258
lowerBoolTensor e = fromBoolTensorView <$> case toBoolTensorView e of
259259
----------------
260260
-- Base cases --
@@ -278,15 +278,15 @@ handleNot expr = do
278278
VReduceAndTensor {} -> return $ VNotTensor e
279279
VReduceOrTensor {} -> return $ VNotTensor e
280280
281-
lowerBool :: WHNFValue TensorBuiltin -> m (WHNFValue TensorBuiltin)
281+
lowerBool :: Value TensorBuiltin -> m (Value TensorBuiltin)
282282
lowerBool = \case
283283
INullaryBoolTensorOp (BoolLiteral b) -> return $ INullaryBoolTensorOp (BoolLiteral b)
284284
e -> developerError $ "Unexpected expression of type Bool:" <+> prettyVerbose e
285285
286286
unfoldEquality ::
287-
WHNFValue TensorBuiltin ->
288-
WHNFValue TensorBuiltin ->
289-
WHNFValue TensorBuiltin
287+
Value TensorBuiltin ->
288+
Value TensorBuiltin ->
289+
Value TensorBuiltin
290290
unfoldEquality x y = IAnd (IOrderRat Le x y) (IOrderRat Ge x y)
291291
292292
--------------------------------------------------------------------------------
@@ -299,8 +299,8 @@ unfoldEquality x y = IAnd (IOrderRat Le x y) (IOrderRat Ge x y)
299299
handleRatInequality ::
300300
(MonadSearch m) =>
301301
OrderOp ->
302-
WHNFValue TensorBuiltin ->
303-
WHNFValue TensorBuiltin ->
302+
Value TensorBuiltin ->
303+
Value TensorBuiltin ->
304304
m ConstrainedValue
305305
handleRatInequality op e1 e2 = do
306306
result <- compileRatLinearRelation (mkInequality op) e1 e2
@@ -321,25 +321,25 @@ handleRatInequality op e1 e2 = do
321321
compileRatLinearRelation ::
322322
(MonadLogger m, MonadReader VariableInfo m) =>
323323
(LinearExp -> LinearExp -> relation) ->
324-
WHNFValue TensorBuiltin ->
325-
WHNFValue TensorBuiltin ->
326-
m (Either (WHNFValue TensorBuiltin) relation)
324+
Value TensorBuiltin ->
325+
Value TensorBuiltin ->
326+
m (Either (Value TensorBuiltin) relation)
327327
compileRatLinearRelation mkRelation x y = do
328328
runExceptT $ do
329329
x' <- compileRatLinearExpr x
330330
y' <- compileRatLinearExpr y
331331
return $ mkRelation x' y'
332332
333-
type LinearExp = LinearExpr UserElementVariable (WHNFValue TensorBuiltin)
333+
type LinearExp = LinearExpr UserElementVariable (Value TensorBuiltin)
334334
335335
compileRatLinearExpr ::
336336
forall m.
337-
(MonadLogger m, MonadReader VariableInfo m, MonadError (WHNFValue TensorBuiltin) m) =>
338-
WHNFValue TensorBuiltin ->
337+
(MonadLogger m, MonadReader VariableInfo m, MonadError (Value TensorBuiltin) m) =>
338+
Value TensorBuiltin ->
339339
m LinearExp
340340
compileRatLinearExpr = go
341341
where
342-
go :: WHNFValue TensorBuiltin -> m LinearExp
342+
go :: Value TensorBuiltin -> m LinearExp
343343
go expr = case toRatTensorView expr of
344344
----------------
345345
-- Base cases --

vehicle/src/Vehicle/Backend/LossFunction/JSON.hs

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -154,17 +154,17 @@ convertDecl = \case
154154
expr' <- convertExpr mempty body
155155
return $ JDecl (nameOf ident) typ' expr'
156156

157-
convertExpr :: (MonadJSON m) => WHNFBoundEnv LossTensorBuiltin -> Expr LossTensorBuiltin -> m JExpr
157+
convertExpr :: (MonadJSON m) => BoundEnv LossTensorBuiltin -> Expr LossTensorBuiltin -> m JExpr
158158
convertExpr env body = convertValue =<< eval mempty env body
159159

160-
convertValue :: (MonadJSON m) => WHNFValue LossTensorBuiltin -> m JExpr
160+
convertValue :: (MonadJSON m) => Value LossTensorBuiltin -> m JExpr
161161
convertValue expr = do
162162
showEntry expr
163163
result <- case expr of
164164
VMeta {} -> resolutionError currentPass "VMeta"
165165
VFreeVar {} -> resolutionError currentPass "VFreeVar"
166166
VUniverse {} -> resolutionError currentPass "Universe"
167-
VLam binder (WHNFClosure env body) -> do
167+
VLam binder (Closure env body) -> do
168168
let name = getBinderName binder
169169
typ' <- convertValue (typeOf binder)
170170
lv <- getBinderDepth
@@ -183,7 +183,7 @@ convertValue expr = do
183183
showExit result
184184
return result
185185

186-
convertBuiltin :: (MonadJSON m) => LossTensorBuiltin -> [WHNFValue LossTensorBuiltin] -> m JExpr
186+
convertBuiltin :: (MonadJSON m) => LossTensorBuiltin -> [Value LossTensorBuiltin] -> m JExpr
187187
convertBuiltin b spine = case b of
188188
LossTensorRat op -> case op of
189189
L.RatTensor t -> convertNullaryOp b (RatTensor $ mapTensor toRat t) spine
@@ -216,42 +216,42 @@ convertBuiltin b spine = case b of
216216
L.DimensionIndexType -> convertIndexType spine
217217
L.TensorType -> convertTensorType spine
218218

219-
convertNullaryOp :: (MonadJSON m) => LossTensorBuiltin -> JExpr -> [WHNFValue LossTensorBuiltin] -> m JExpr
219+
convertNullaryOp :: (MonadJSON m) => LossTensorBuiltin -> JExpr -> [Value LossTensorBuiltin] -> m JExpr
220220
convertNullaryOp b fn = \case
221221
[] -> return fn
222222
spine -> arityError b 0 spine
223223

224-
convertUnaryOp :: (MonadJSON m) => LossTensorBuiltin -> (JExpr -> JExpr) -> [WHNFValue LossTensorBuiltin] -> m JExpr
224+
convertUnaryOp :: (MonadJSON m) => LossTensorBuiltin -> (JExpr -> JExpr) -> [Value LossTensorBuiltin] -> m JExpr
225225
convertUnaryOp b fn = \case
226226
[x] -> fn <$> convertValue x
227227
spine -> arityError b 1 spine
228228

229-
convertBinaryOp :: (MonadJSON m) => LossTensorBuiltin -> (JExpr -> JExpr -> JExpr) -> [WHNFValue LossTensorBuiltin] -> m JExpr
229+
convertBinaryOp :: (MonadJSON m) => LossTensorBuiltin -> (JExpr -> JExpr -> JExpr) -> [Value LossTensorBuiltin] -> m JExpr
230230
convertBinaryOp b fn = \case
231231
[x, y] -> fn <$> convertValue x <*> convertValue y
232232
spine -> arityError b 2 spine
233233

234-
convertNaryOp :: (MonadJSON m) => LossTensorBuiltin -> Int -> ([JExpr] -> JExpr) -> [WHNFValue LossTensorBuiltin] -> m JExpr
234+
convertNaryOp :: (MonadJSON m) => LossTensorBuiltin -> Int -> ([JExpr] -> JExpr) -> [Value LossTensorBuiltin] -> m JExpr
235235
convertNaryOp b n fn spine
236236
| length spine == n = fn <$> traverse convertValue spine
237237
| otherwise = arityError b n spine
238238

239-
convertTensorType :: (MonadJSON m) => [WHNFValue LossTensorBuiltin] -> m JExpr
239+
convertTensorType :: (MonadJSON m) => [Value LossTensorBuiltin] -> m JExpr
240240
convertTensorType = \case
241241
[tElem, _dims] -> TensorType <$> convertValue tElem
242242
spine -> arityError (LossTensorDimType L.TensorType) 2 spine
243243

244-
convertIndexType :: (MonadJSON m) => [WHNFValue LossTensorBuiltin] -> m JExpr
244+
convertIndexType :: (MonadJSON m) => [Value LossTensorBuiltin] -> m JExpr
245245
convertIndexType = \case
246246
[_dim] -> return DimensionIndexType
247247
spine -> arityError (LossTensorDimType L.DimensionIndexType) 1 spine
248248

249-
convertSearch :: (MonadJSON m) => [WHNFValue LossTensorBuiltin] -> m JExpr
249+
convertSearch :: (MonadJSON m) => [Value LossTensorBuiltin] -> m JExpr
250250
convertSearch = \case
251251
[unaryOp, lowerBound, upperBound, fn] -> SearchRatTensor <$> convertValue unaryOp <*> convertValue lowerBound <*> convertValue upperBound <*> convertValue fn
252252
spine -> arityError (show L.SearchRatTensor) 5 spine
253253

254-
arityError :: (MonadCompile m, Pretty fn) => fn -> Arity -> [WHNFValue LossTensorBuiltin] -> m a
254+
arityError :: (MonadCompile m, Pretty fn) => fn -> Arity -> [Value LossTensorBuiltin] -> m a
255255
arityError fun arity explicitArgs =
256256
compilerDeveloperError $
257257
"Number of args is different from expected arity:"
@@ -271,7 +271,7 @@ arityError fun arity explicitArgs =
271271
<+> prettyVerbose explicitArgs
272272
)
273273

274-
showEntry :: (MonadJSON m) => WHNFValue LossTensorBuiltin -> m ()
274+
showEntry :: (MonadJSON m) => Value LossTensorBuiltin -> m ()
275275
showEntry e = do
276276
logDebug MaxDetail $ "json-enter:" <+> prettyVerbose e
277277
incrCallDepth

0 commit comments

Comments
 (0)