Skip to content

Commit

Permalink
More refactoring in preparation for working domain implementation (#840)
Browse files Browse the repository at this point in the history
  • Loading branch information
MatthewDaggitt authored Aug 23, 2024
1 parent 04c45c1 commit 9f28d24
Show file tree
Hide file tree
Showing 80 changed files with 789 additions and 572 deletions.
4 changes: 4 additions & 0 deletions vehicle-syntax/src/Vehicle/Syntax/Builtin/BasicOperations.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ module Vehicle.Syntax.Builtin.BasicOperations
orderOpName,
Strictness (..),
isStrict,
isForward,
flipStrictness,
flipOrder,
chainable,
Expand Down Expand Up @@ -112,6 +113,9 @@ orderOpName = \case
isStrict :: OrderOp -> Bool
isStrict order = order == Lt || order == Gt

isForward :: OrderOp -> Bool
isForward order = order == Lt || order == Le

flipStrictness :: OrderOp -> OrderOp
flipStrictness = \case
Le -> Lt
Expand Down
10 changes: 5 additions & 5 deletions vehicle/src/Vehicle/Backend/Agda/CapitaliseTypeNames.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import Vehicle.Data.Expr.Interface
-- convention. This pass identifies all such defined functions and capitalises
-- all references to them. Cannot be done during the main compilation pass as we
-- need to be able to distinguish between free and bound variables.
capitaliseTypeNames :: (HasStandardTypes builtin) => Prog var builtin -> Prog var builtin
capitaliseTypeNames :: (BuiltinHasStandardTypes builtin) => Prog var builtin -> Prog var builtin
capitaliseTypeNames prog = evalState (cap prog) mempty

--------------------------------------------------------------------------------
Expand All @@ -28,10 +28,10 @@ type MonadCapitalise m = MonadState (Set Identifier) m
class CapitaliseTypes a where
cap :: (MonadCapitalise m) => a -> m a

instance (HasStandardTypes builtin) => CapitaliseTypes (Prog var builtin) where
instance (BuiltinHasStandardTypes builtin) => CapitaliseTypes (Prog var builtin) where
cap (Main ds) = Main <$> traverse cap ds

instance (HasStandardTypes builtin) => CapitaliseTypes (Decl var builtin) where
instance (BuiltinHasStandardTypes builtin) => CapitaliseTypes (Decl var builtin) where
cap d = case d of
DefAbstract p ident r t ->
DefAbstract p <$> capitaliseIdentifier ident <*> pure r <*> cap t
Expand Down Expand Up @@ -72,14 +72,14 @@ capitaliseIdentifier ident@(Identifier m s) = do
then capitaliseFirstLetter s
else s

isTypeDef :: (HasStandardTypes builtin) => Expr var builtin -> Bool
isTypeDef :: (BuiltinHasStandardTypes builtin) => Expr var builtin -> Bool
isTypeDef t = case t of
-- We don't capitalise things of type `Bool` because they will be lifted
-- to the type level, only things of type `X -> Bool`.
Pi _ _ result -> go result
_ -> False
where
go :: (HasStandardTypes builtin) => Expr var builtin -> Bool
go :: (BuiltinHasStandardTypes builtin) => Expr var builtin -> Bool
go (IBoolType _) = True
go (Pi _ _ res) = go res
go _ = False
28 changes: 27 additions & 1 deletion vehicle/src/Vehicle/Backend/LossFunction/Core.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ import GHC.Generics (Generic)
import Vehicle.Compile.Prelude
import Vehicle.Data.Builtin.Loss.Core
import Vehicle.Data.Builtin.Standard.Core (Builtin)
import Vehicle.Data.Expr.Normalised (BoundEnv, Spine, VBinder, VDecl, Value, WHNFClosure (..))
import Vehicle.Data.Builtin.Tensor (TensorBuiltin)
import Vehicle.Data.Builtin.Tensor qualified as T
import Vehicle.Data.Expr.Value (BoundEnv, NFValue, Spine, VBinder, VDecl, Value (..), WHNFClosure (..))
import Vehicle.Libraries.StandardLibrary.Definitions (StdLibFunction (..))

--------------------------------------------------------------------------------
Expand All @@ -17,6 +19,27 @@ import Vehicle.Libraries.StandardLibrary.Definitions (StdLibFunction (..))
data LossClosure
= LossClosure (BoundEnv MixedClosure LossBuiltin) (Expr Ix LossBuiltin)

-- | Okay, so closures for loss functions are complicated. How compilation
-- currently works is that we first do standard normalisation on the `Builtin`
-- type, then recurse down to the closures converting `Builtin` to `LossBuiltin`.
--
-- This means that we need the standard closures from the first half.
-- However, when we convert a standard `Builtin` to the equivalent loss
-- expression, we need we apply the translated expression to the previous arguments.
--
-- subst evalApp
-- e.g _<=_ a b --------> (\x y -> x - y) a b ---------> a - b
--
-- During the evalApp phase we may need to form closures over expressions that have
-- already been translated to loss functions. This means that we need the loss
-- closure constructor as well. In theory, if all arguments to the builtin are present
-- (i.e. the boolean operator is fully-applied) then these loss closures are only
-- formed temporarily and by the time `evalApp` has finally finished executing
-- they will no longer exist.
--
-- However, during the intermediary computation of `evalApp` both closures can exist
-- within the same expression. Maybe there's a nicer way of encoding this all in the
-- types that avoids this, but I haven't found it yet.
data MixedClosure
= StandardClos (WHNFClosure Builtin)
| LossClos LossClosure
Expand Down Expand Up @@ -68,3 +91,6 @@ preservedStdLibOps =
Set.fromList
[ StdForeachIndex
]

constRatTensor :: Rational -> NFValue TensorBuiltin
constRatTensor v = VBuiltin (T.ConstRatTensor $ T.convertRat v) [explicit (VBuiltin T.NilList [])]
131 changes: 91 additions & 40 deletions vehicle/src/Vehicle/Backend/LossFunction/Domain.hs
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
{-
module Vehicle.Backend.LossFunction.Domain where
module Vehicle.Backend.LossFunction.Domain
( -- extractSearchDomain,
-- Domain (..),
)
where

{-
import Control.Applicative (Applicative (..))
import Control.Monad (when)
import Control.Monad.Except (MonadError (..), runExceptT, void)
import Control.Monad.Reader (MonadReader (..), ReaderT (..))
import Data.Either (partitionEithers)
import Data.Map (Map)
import Data.Map qualified as Map
import Data.Proxy (Proxy (..))
Expand All @@ -11,21 +19,19 @@ import Vehicle.Compile.Boolean.LowerNot (lowerNot)
import Vehicle.Compile.Boolean.Unblock (ReduceVectorVars, UnblockingActions (..))
import Vehicle.Compile.Boolean.Unblock qualified as Unblocking
import Vehicle.Compile.Context.Bound (MonadBoundContext, getNamedBoundCtx)
import Vehicle.Compile.Error (CompileError (..), MonadCompile)
import Vehicle.Compile.Context.Free (MonadFreeContext)
import Vehicle.Compile.Error (CompileError (..), MonadCompile, unexpectedExprError)
import Vehicle.Compile.Normalise.NBE (normaliseInEnv)
import Vehicle.Compile.Prelude
import Vehicle.Compile.Print (prettyVerbose)
import Vehicle.Compile.Rational.LinearExpr
import Vehicle.Compile.Variable (createUserVar)
import Vehicle.Data.Builtin.Standard
import Vehicle.Data.Builtin.Tensor (TensorBuiltin (..))
import Vehicle.Data.Expr.Interface
import Vehicle.Data.Expr.Normalised
import Vehicle.Compile.Context.Free (MonadFreeContext)
import Control.Applicative (Applicative(..))
import Control.Monad.Reader (MonadReader (..), ReaderT (..))
import Vehicle.Data.QuantifiedVariable
import Control.Monad (when)
import Vehicle.Data.Expr.Linear (addExprs, rearrangeExprToSolveFor)
import Data.Either (partitionEithers)
import Vehicle.Data.Expr.Value
import Vehicle.Data.QuantifiedVariable
type MonadDomain m =
( MonadCompile m,
Expand All @@ -34,15 +40,15 @@ type MonadDomain m =
)
type MonadSearch m =
( MonadDomain m
, MonadReader VariableInfo m
( MonadDomain m,
MonadReader VariableInfo m
)
-- | Information for the variable whose domain we are trying to find.
data VariableInfo = VariableInfo
{ variableLv :: Lv
, vectorExpr :: WHNFValue Builtin
, reducedVars :: [(Lv, UserRationalVariable)]
{ variableLv :: Lv,
vectorExpr :: WHNFValue Builtin,
reducedVars :: [(Lv, UserRationalVariable)]
}
extractSearchDomain ::
Expand All @@ -51,7 +57,7 @@ extractSearchDomain ::
MixedLossBinder ->
Lv ->
WHNFClosure Builtin ->
m (Maybe (Domain, WHNFValue Builtin))
m (Domain, WHNFValue Builtin)
extractSearchDomain propertyProv binder lv (WHNFClosure env expr) = do
-- Convert the binder
namedCtx <- getNamedBoundCtx (Proxy @MixedLossValue)
Expand All @@ -69,24 +75,24 @@ extractSearchDomain propertyProv binder lv (WHNFClosure env expr) = do
case maybeConstraints of
Nothing -> throwError $ NoQuantifierDomainFound propertyProv (void binder) Nothing
Just constraints -> do
let maybeDomain = extractDomainFromConstraints constraints (userTensorVarDimensions userVar) reducedUseVars
let maybeDomain = extractDomainFromConstraints constraints reducedUseVars
case maybeDomain of
Left missingCostraints -> throwError $ NoQuantifierDomainFound propertyProv (void binder) (Just missingCostraints)
Right domain -> return $ Just (domain, remainder)
Right domain -> return (domain, remainder)
--------------------------------------------------------------------------------
-- Constraints
data VariableConstraints = VariableConstraints
{ lowerBounds :: Map Lv (WHNFValue Builtin),
upperBounds :: Map Lv (WHNFValue Builtin)
{ lowerBounds :: Map Lv (NFValue TensorBuiltin),
upperBounds :: Map Lv (NFValue TensorBuiltin)
}
instance Semigroup VariableConstraints where
x <> y =
VariableConstraints
{ lowerBounds = Map.unionWith IMax (lowerBounds x) (lowerBounds y),
upperBounds = Map.unionWith IMin (upperBounds x) (upperBounds y)
{ lowerBounds = Map.unionWith (\u v -> VBuiltin MaxRatTensor (explicit <$> [u, v])) (lowerBounds x) (lowerBounds y),
upperBounds = Map.unionWith (\u v -> VBuiltin MinRatTensor (explicit <$> [u, v])) (upperBounds x) (upperBounds y)
}
instance Monoid VariableConstraints where
Expand All @@ -106,16 +112,15 @@ updateConstrainedValue originalExpr = \case
-- Domain
data Domain = Domain
{ lowerBound :: WHNFValue Builtin,
upperBound :: WHNFValue Builtin
{ lowerBound :: NFValue TensorBuiltin,
upperBound :: NFValue TensorBuiltin
}
extractDomainFromConstraints ::
VariableConstraints ->
TensorShape ->
[(Lv, UserRationalVariable)] ->
Either [(UserRationalVariable, UnderConstrainedVariableStatus)] Domain
extractDomainFromConstraints VariableConstraints{..} tensorShape allVariables = do
extractDomainFromConstraints VariableConstraints {..} allVariables = do
let lowerBoundExprs = flip map allVariables $ \(lv, var) ->
case (Map.lookup lv lowerBounds, Map.lookup lv upperBounds) of
(Just x, Just y) -> Right (x, y)
Expand All @@ -127,9 +132,10 @@ extractDomainFromConstraints VariableConstraints{..} tensorShape allVariables =
if not $ null missingVars
then Left missingVars
else do
let n = length allVariables
let (lowerBoundElements, upperBoundElements) = unzip presentVarBounds
let lowerBoundExpr = tensorLikeToExpr id tensorShape lowerBoundElements
let upperBoundExpr = tensorLikeToExpr id tensorShape upperBoundElements
let lowerBoundExpr = VBuiltin (StackRatTensor n) (explicit <$> lowerBoundElements)
let upperBoundExpr = VBuiltin (StackRatTensor n) (explicit <$> upperBoundElements)
Right $ Domain lowerBoundExpr upperBoundExpr
--------------------------------------------------------------------------------
Expand Down Expand Up @@ -194,19 +200,21 @@ tryPurifyAssertion value whenPure = do
Right (Left purified) -> updateConstrainedValue value <$> findConstraints purified
unblockBoolExpr ::
(MonadDomain m) =>
(MonadSearch m) =>
WHNFValue Builtin ->
m ConstrainedValue
unblockBoolExpr value = do
ctx <- getNamedBoundCtx (Proxy @Builtin)
ctx <- getNamedBoundCtx (Proxy @MixedLossValue)
result <- runExceptT (Unblocking.unblockBoolExpr ctx unblockingActions value)
case result of
Left {} -> return (Nothing, value)
Right unblockedValue -> do
constrainedValue <- findConstraints unblockedValue
return $ updateConstrainedValue value constrainedValue
unblockingActions :: (MonadError (WHNFValue Builtin) m) => UnblockingActions m
unblockingActions ::
(MonadError (WHNFValue Builtin) m, MonadReader VariableInfo m) =>
UnblockingActions m
unblockingActions =
UnblockingActions
{ unblockFreeVectorVar = unblockFreeVectorVariable,
Expand All @@ -228,28 +236,71 @@ unblockBoundVectorVariable ::
Lv ->
m (WHNFValue Builtin)
unblockBoundVectorVariable lv = do
VariableInfo{..} <- ask
VariableInfo {..} <- ask
when (lv /= variableLv) $
throwError $ VBoundVar lv []
throwError $
VBoundVar lv []
return vectorExpr
--------------------------------------------------------------------------------
-- Compilation of inequalities
handleRatInequality ::
(MonadDomain m) =>
(MonadSearch m) =>
OrderOp ->
WHNFValue Builtin ->
WHNFValue Builtin ->
m ConstrainedValue
handleRatInequality op e1 e2 = do
result <- compileRatLinearRelation _ e1 e2
let e = ISub SubRat e1 e2
result <- runExceptT (compileRatLinearExpr e)
case result of
Left NonLinearity -> return (Nothing, IOrderRat op e1 e2)
Right (le1, le2) -> do
let le = addExprs 1 (-1) le1 le2
let (_, rearrangedExpr) = rearrangeExprToSolveFor _ le
return _
Right (Linear a b) -> do
case op of
Le -> return (_, ITrueExpr mempty)
_ -> return (Nothing, IOrderRat op e1 e2)
data Bound
= Constant (NFValue TensorBuiltin)
| Linear (NFValue TensorBuiltin) (NFValue TensorBuiltin)
compileRatLinearExpr ::
forall m.
(MonadLogger m, MonadError NonLinearity m) =>
WHNFValue Builtin ->
m Bound
compileRatLinearExpr = go
where
go :: WHNFValue Builtin -> m Bound
go e = case e of
----------------
-- Base cases --
----------------
IRatLiteral _ l -> return $ Constant _
VBoundVar lv [] -> return $ Linear _ _
---------------------
-- Inductive cases --
---------------------
INeg NegRat v -> scaleExpr (-1) <$> go v
IAdd AddRat e1 e2 -> addExprs 1 1 <$> go e1 <*> go e2
ISub SubRat e1 e2 -> addExprs 1 (-1) <$> go e1 <*> go e2
IMul MulRat e1 e2 -> do
e1' <- go e1
e2' <- go e2
case (e1', e2') of
(Constant c1, Constant c2) -> return $ scaleExpr c1 e2'
(_, Just c2) -> return $ scaleExpr c2 e1'
_ -> throwError NonLinearity
IDiv DivRat e1 e2 -> do
e1' <- go e1
e2' <- go e2
case isConstant e2' of
(Just c2) -> return $ scaleExpr (1 / c2) e1'
_ -> throwError NonLinearity
-----------------
-- Error cases --
-----------------
ex -> unexpectedExprError "compile linear rational expression" $ prettyVerbose ex
-}
15 changes: 3 additions & 12 deletions vehicle/src/Vehicle/Backend/LossFunction/LogicCompilation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ module Vehicle.Backend.LossFunction.LogicCompilation
( compileLogic,
convertToLossBuiltins,
normStandardExprToLoss,
normLossExprToLoss,
MonadLogicCtx,
runMonadLogicT,
)
Expand All @@ -29,7 +28,7 @@ import Vehicle.Compile.Prelude
import Vehicle.Compile.Print (prettyFriendly, prettyVerbose)
import Vehicle.Data.Builtin.Loss
import Vehicle.Data.Expr.Interface (pattern INot)
import Vehicle.Data.Expr.Normalised (VBinder, Value (..), WHNFBoundEnv, WHNFClosure (..), WHNFValue, boundContextToEnv, extendEnvWithBound, extendEnvWithDefined)
import Vehicle.Data.Expr.Value (VBinder, Value (..), WHNFBoundEnv, WHNFClosure (..), WHNFValue, boundContextToEnv, extendEnvWithBound, extendEnvWithDefined)
import Vehicle.Libraries.StandardLibrary.Definitions (StdLibFunction (..))
import Vehicle.Syntax.Builtin (Builtin)
import Vehicle.Syntax.Builtin qualified as V
Expand Down Expand Up @@ -194,15 +193,7 @@ normStandardExprToLoss ::
m MixedLossValue
normStandardExprToLoss boundEnv expr = do
standardValue <- normaliseInEnv boundEnv expr
result <- convertToLossBuiltins standardValue
return result

normLossExprToLoss ::
(MonadLogic m) =>
MixedBoundEnv ->
Expr Ix LossBuiltin ->
m MixedLossValue
normLossExprToLoss = eval mempty
convertToLossBuiltins standardValue

convertToLossBuiltins ::
forall m.
Expand Down Expand Up @@ -273,7 +264,7 @@ convertBuiltin builtin spine = convert builtin
V.LIndex x -> unchangedBuiltin (Index x)
V.LNat x -> unchangedBuiltin (Nat x)
V.LRat x -> unchangedBuiltin (Rat x)
V.LVec _x -> unchangedBuiltin Vector
V.LVec n -> unchangedBuiltin $ Vector n

convertBuiltinFunction :: V.BuiltinFunction -> m MixedLossValue
convertBuiltinFunction b = case b of
Expand Down
Loading

0 comments on commit 9f28d24

Please sign in to comment.