@@ -48,17 +48,17 @@ type MonadDomain m =
48
48
)
49
49
50
50
data Domain = Domain
51
- { lowerBound :: WHNFValue TensorBuiltin ,
52
- upperBound :: WHNFValue TensorBuiltin
51
+ { lowerBound :: Value TensorBuiltin ,
52
+ upperBound :: Value TensorBuiltin
53
53
}
54
54
55
55
extractSearchDomain ::
56
56
(MonadDomain m ) =>
57
57
DeclProvenance ->
58
- WHNFBinder TensorBuiltin ->
58
+ VBinder TensorBuiltin ->
59
59
Lv ->
60
- WHNFValue TensorBuiltin ->
61
- m (Domain , WHNFValue TensorBuiltin )
60
+ Value TensorBuiltin ->
61
+ m (Domain , Value TensorBuiltin )
62
62
extractSearchDomain _propertyProv _binder _lv value = do
63
63
{-
64
64
_variableInfo <- case typeOf binder of
@@ -95,9 +95,9 @@ extractSearchDomain _propertyProv _binder _lv value = do
95
95
--------------------------------------------------------------------------------
96
96
-- Constraints
97
97
98
- type TensorElementInequality = Inequality UserElementVariable (WHNFValue TensorBuiltin)
98
+ type TensorElementInequality = Inequality UserElementVariable (Value TensorBuiltin)
99
99
100
- type TensorInequality = Inequality Name (WHNFValue TensorBuiltin)
100
+ type TensorInequality = Inequality Name (Value TensorBuiltin)
101
101
102
102
type VariableConstraint = Either TensorElementInequality TensorInequality
103
103
@@ -109,20 +109,20 @@ splitConstraints = partitionEithers
109
109
pattern NoConstraints :: VariableConstraints
110
110
pattern NoConstraints = []
111
111
112
- type ConstrainedValue = (VariableConstraints, WHNFValue TensorBuiltin)
112
+ type ConstrainedValue = (VariableConstraints, Value TensorBuiltin)
113
113
114
- unconstrained :: WHNFValue TensorBuiltin -> ConstrainedValue
114
+ unconstrained :: Value TensorBuiltin -> ConstrainedValue
115
115
unconstrained = (NoConstraints,)
116
116
117
117
updateConstrainedValue ::
118
- WHNFValue TensorBuiltin ->
118
+ Value TensorBuiltin ->
119
119
ConstrainedValue ->
120
120
ConstrainedValue
121
121
updateConstrainedValue originalExpr = \case
122
122
constr@(_ : _, _) -> constr
123
123
([], _) -> ([], originalExpr)
124
124
125
- instance IsConstant (WHNFValue TensorBuiltin) where
125
+ instance IsConstant (Value TensorBuiltin) where
126
126
isZero = \case
127
127
-- This is only semi-decidable, probably need to think harder about what
128
128
-- to do here.
@@ -189,18 +189,18 @@ extractDomainFromConstraints VariableInfo{..} constraints = do
189
189
190
190
extractVarBounds ::
191
191
(MonadCompile m) =>
192
- ([TensorElementInequality], [(UserElementVariable, Bounds UserElementVariable (WHNFValue TensorBuiltin))]) ->
192
+ ([TensorElementInequality], [(UserElementVariable, Bounds UserElementVariable (Value TensorBuiltin))]) ->
193
193
UserElementVariable ->
194
- m ([TensorElementInequality], [(UserElementVariable, Bounds UserElementVariable (WHNFValue TensorBuiltin))])
194
+ m ([TensorElementInequality], [(UserElementVariable, Bounds UserElementVariable (Value TensorBuiltin))])
195
195
extractVarBounds (currentConstraints, solutions) var = do
196
196
(bounds, newInequalities) <- fourierMotzkinElimination var currentConstraints
197
197
return (newInequalities, (var, bounds) : solutions)
198
198
199
199
convertBoundToExpr ::
200
200
Map UserElementVariable Lv ->
201
- (WHNFValue TensorBuiltin -> WHNFValue TensorBuiltin -> WHNFValue TensorBuiltin) ->
202
- NonEmpty (Bound UserElementVariable (WHNFValue TensorBuiltin)) ->
203
- WHNFValue TensorBuiltin
201
+ (Value TensorBuiltin -> Value TensorBuiltin -> Value TensorBuiltin) ->
202
+ NonEmpty (Bound UserElementVariable (Value TensorBuiltin)) ->
203
+ Value TensorBuiltin
204
204
convertBoundToExpr varMap op bounds = foldr1 _ (fmap convertBound bounds)
205
205
where
206
206
-- Ignore strictness for the moment.
@@ -214,7 +214,7 @@ convertBoundToExpr varMap op bounds = foldr1 _ (fmap convertBound bounds)
214
214
--------------------------------------------------------------------------------
215
215
-- Constraint search
216
216
217
- findConstraints :: (MonadSearch m) => WHNFValue TensorBuiltin -> m ConstrainedValue
217
+ findConstraints :: (MonadSearch m) => Value TensorBuiltin -> m ConstrainedValue
218
218
findConstraints expr = case toBoolTensorView expr of
219
219
-------------------------
220
220
-- Unuseful base cases --
@@ -246,15 +246,15 @@ findConstraints expr = case toBoolTensorView expr of
246
246
247
247
handleNot ::
248
248
forall m . (MonadSearch m) =>
249
- WHNFValue TensorBuiltin ->
249
+ Value TensorBuiltin ->
250
250
m ConstrainedValue
251
251
handleNot expr = do
252
252
loweredExpr <- lowerBoolTensor expr
253
253
case loweredExpr of
254
254
INot {} -> return $ unconstrained expr
255
255
newExpr -> updateConstrainedValue expr <$> findConstraints newExpr
256
256
where
257
- lowerBoolTensor :: WHNFValue TensorBuiltin -> m (WHNFValue TensorBuiltin)
257
+ lowerBoolTensor :: Value TensorBuiltin -> m (Value TensorBuiltin)
258
258
lowerBoolTensor e = fromBoolTensorView <$> case toBoolTensorView e of
259
259
----------------
260
260
-- Base cases --
@@ -278,15 +278,15 @@ handleNot expr = do
278
278
VReduceAndTensor {} -> return $ VNotTensor e
279
279
VReduceOrTensor {} -> return $ VNotTensor e
280
280
281
- lowerBool :: WHNFValue TensorBuiltin -> m (WHNFValue TensorBuiltin)
281
+ lowerBool :: Value TensorBuiltin -> m (Value TensorBuiltin)
282
282
lowerBool = \case
283
283
INullaryBoolTensorOp (BoolLiteral b) -> return $ INullaryBoolTensorOp (BoolLiteral b)
284
284
e -> developerError $ "Unexpected expression of type Bool:" <+> prettyVerbose e
285
285
286
286
unfoldEquality ::
287
- WHNFValue TensorBuiltin ->
288
- WHNFValue TensorBuiltin ->
289
- WHNFValue TensorBuiltin
287
+ Value TensorBuiltin ->
288
+ Value TensorBuiltin ->
289
+ Value TensorBuiltin
290
290
unfoldEquality x y = IAnd (IOrderRat Le x y) (IOrderRat Ge x y)
291
291
292
292
--------------------------------------------------------------------------------
@@ -299,8 +299,8 @@ unfoldEquality x y = IAnd (IOrderRat Le x y) (IOrderRat Ge x y)
299
299
handleRatInequality ::
300
300
(MonadSearch m) =>
301
301
OrderOp ->
302
- WHNFValue TensorBuiltin ->
303
- WHNFValue TensorBuiltin ->
302
+ Value TensorBuiltin ->
303
+ Value TensorBuiltin ->
304
304
m ConstrainedValue
305
305
handleRatInequality op e1 e2 = do
306
306
result <- compileRatLinearRelation (mkInequality op) e1 e2
@@ -321,25 +321,25 @@ handleRatInequality op e1 e2 = do
321
321
compileRatLinearRelation ::
322
322
(MonadLogger m, MonadReader VariableInfo m) =>
323
323
(LinearExp -> LinearExp -> relation) ->
324
- WHNFValue TensorBuiltin ->
325
- WHNFValue TensorBuiltin ->
326
- m (Either (WHNFValue TensorBuiltin) relation)
324
+ Value TensorBuiltin ->
325
+ Value TensorBuiltin ->
326
+ m (Either (Value TensorBuiltin) relation)
327
327
compileRatLinearRelation mkRelation x y = do
328
328
runExceptT $ do
329
329
x' <- compileRatLinearExpr x
330
330
y' <- compileRatLinearExpr y
331
331
return $ mkRelation x' y'
332
332
333
- type LinearExp = LinearExpr UserElementVariable (WHNFValue TensorBuiltin)
333
+ type LinearExp = LinearExpr UserElementVariable (Value TensorBuiltin)
334
334
335
335
compileRatLinearExpr ::
336
336
forall m.
337
- (MonadLogger m, MonadReader VariableInfo m, MonadError (WHNFValue TensorBuiltin) m) =>
338
- WHNFValue TensorBuiltin ->
337
+ (MonadLogger m, MonadReader VariableInfo m, MonadError (Value TensorBuiltin) m) =>
338
+ Value TensorBuiltin ->
339
339
m LinearExp
340
340
compileRatLinearExpr = go
341
341
where
342
- go :: WHNFValue TensorBuiltin -> m LinearExp
342
+ go :: Value TensorBuiltin -> m LinearExp
343
343
go expr = case toRatTensorView expr of
344
344
----------------
345
345
-- Base cases --
0 commit comments