Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor VPi to use a Closure #866

Merged
merged 1 commit into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions vehicle/src/Vehicle/Backend/LossFunction/JSON.hs
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,12 @@ convertValue expr = do
VUniverse {} -> resolutionError currentPass "Universe"
VLam binder closure -> do
binder' <- convertBinder binder
body' <- convertClosure binder closure
return $ Lam binder' body'
VPi binder body -> do
closure' <- convertClosure binder closure
return $ Lam binder' closure'
VPi binder closure -> do
typ' <- convertValue (typeOf binder)
body' <- addNameToContext binder $ convertValue body
return $ Pi typ' body'
closure' <- convertClosure binder closure
return $ Pi typ' closure'
VBuiltin b spine -> convertBuiltin b $ filterOutNonExplicitArgs spine
VBoundVar v spine -> do
name <- lvToProperName mempty v
Expand Down
6 changes: 3 additions & 3 deletions vehicle/src/Vehicle/Backend/LossFunction/LossCompilation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,10 @@ convertValue e = do
VBoundVar v <$> traverseArgs convertValue spine
VBuiltin b spine -> do
convertBuiltinToLoss b spine
VPi binder body -> do
VPi binder closure -> do
binder' <- traverse convertValue binder
body' <- addNameToContext binder $ convertValue body
return $ VPi binder' body'
closure' <- traverseClosure convertValue mempty binder closure
return $ VPi binder' closure'
VLam binder closure -> do
binder' <- traverse convertValue binder
closure' <- traverseClosure convertValue mempty binder closure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,12 @@ module Vehicle.Backend.LossFunction.TensorCompilation
)
where

import Control.Monad (void)
import Control.Monad.Except (MonadError (..))
import Control.Monad.Reader (MonadReader (..))
import Control.Monad.Trans.Reader (ReaderT (..))
import Data.List.NonEmpty (NonEmpty (..))
import Vehicle.Backend.LossFunction.Core (pattern VLam2)
import Vehicle.Compile.Arity (Arity)
import Vehicle.Compile.Context.Bound
import Vehicle.Compile.Context.Free.Class (MonadFreeContext, getFreeEnv)
import Vehicle.Compile.Context.Name (MonadNameContext, addNameToContext, getBinderDepth, getNameContext)
import Vehicle.Compile.Error
Expand Down Expand Up @@ -76,10 +74,11 @@ convertValue = go
VBoundVar v <$> traverseArgs go spine
VBuiltin b spine ->
convertBuiltinToTensors b spine
VPi binder body -> do
VPi binder closure -> do
binder' <- traverse go binder
body' <- addBinderToContext (void binder) $ go body
return $ VPi binder' body'
freeEnv <- getFreeEnv
closure' <- traverseClosure convertValue freeEnv binder closure
return $ VPi binder' closure'
VLam binder closure -> do
binder' <- traverse go binder
freeEnv <- getFreeEnv
Expand Down
23 changes: 18 additions & 5 deletions vehicle/src/Vehicle/Backend/LossFunction/ZeroTensorLifting.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ module Vehicle.Backend.LossFunction.ZeroTensorLifting
)
where

import Data.Tuple (swap)
import Vehicle.Compile.Context.Name
import Vehicle.Compile.Error
import Vehicle.Compile.Normalise.NBE (traverseClosure, traverseClosureGeneric)
import Vehicle.Compile.Normalise.NBE (eval, traverseClosure)
import Vehicle.Compile.Normalise.Quote (Quote (..))
import Vehicle.Compile.Prelude
import Vehicle.Compile.Print (prettyFriendly, prettyVerbose)
import Vehicle.Data.Builtin.Loss
Expand Down Expand Up @@ -37,9 +37,22 @@ liftDecl ::
(VType LossTensorBuiltin, Value LossTensorBuiltin) ->
m (VType LossTensorBuiltin, Value LossTensorBuiltin)
liftDecl (t, e) = case (t, e) of
(VPi piBinder piBody, VLam lamBinder closure) -> do
(newClosure, newPiBody) <- traverseClosureGeneric (\lamBody -> liftDecl (piBody, lamBody)) swap mempty lamBinder closure
return (VPi piBinder newPiBody, VLam lamBinder newClosure)
(VPi piBinder (Closure piEnv piBody), VLam lamBinder (Closure lamEnv lamBody)) -> do
ctx <- getBinderContext
let lv = boundCtxLv ctx
let newPiEnv = extendEnvWithBound lv piBinder piEnv
let newLamEnv = extendEnvWithBound lv lamBinder lamEnv

(resultPiBody, resultLamBody) <- addNameToContext lamBinder $ do
normPiBody <- eval mempty newPiEnv piBody
normLamBody <- eval mempty newLamEnv lamBody
liftDecl (normPiBody, normLamBody)

let finalPiBody = quote mempty (lv + 1) resultPiBody
let finalLamBody = quote mempty (lv + 1) resultLamBody

let finalEnv = boundContextToEnv ctx
return (VPi piBinder (Closure finalEnv finalPiBody), VLam lamBinder (Closure finalEnv finalLamBody))
(typ, body) -> do
let (wasZeroDimensional, newType) = liftType typ
liftedBody <- liftExpr wasZeroDimensional body
Expand Down
6 changes: 0 additions & 6 deletions vehicle/src/Vehicle/Compile/Arity.hs
Original file line number Diff line number Diff line change
@@ -1,19 +1,13 @@
module Vehicle.Compile.Arity where

import Vehicle.Data.Code.Expr
import Vehicle.Data.Code.Value
import Vehicle.Prelude (isExplicit)

type Arity = Int

class HasArity a where
arityOf :: a -> Arity

arityFromVType :: Value builtin -> Arity
arityFromVType = \case
VPi _ r -> 1 + arityFromVType r
_ -> 0

-- | This is only safe when the type is known to be in normalised type.
explicitArityFromType :: Type builtin -> Arity
explicitArityFromType = \case
Expand Down
4 changes: 2 additions & 2 deletions vehicle/src/Vehicle/Compile/Descope.hs
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,9 @@ genericDescopeValue f e = case e of
var <- S.Var p <$> lvToName f p v
args <- traverseArgs (genericDescopeValue f) spine
return $ S.normAppList var args
VPi binder body -> do
VPi binder closure -> do
binder' <- traverse (genericDescopeValue f) binder
body' <- addNameToContext binder $ genericDescopeValue f body
body' <- addNameToContext binder $ descopeClosure f binder closure
return $ S.Pi p binder' body'
VLam binder closure -> do
binder' <- traverse (genericDescopeValue f) binder
Expand Down
6 changes: 4 additions & 2 deletions vehicle/src/Vehicle/Compile/ExpandResources/Network.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import Control.Monad.Except (MonadError (..))
import Data.Map qualified as Map
import Vehicle.Compile.Error
import Vehicle.Compile.ExpandResources.Core
import Vehicle.Compile.Normalise.NBE (normaliseClosure)
import Vehicle.Compile.Prelude
import Vehicle.Compile.Print
import Vehicle.Compile.Resource
Expand Down Expand Up @@ -42,11 +43,12 @@ getNetworkType ::
GluedType Builtin ->
m NetworkType
getNetworkType decl networkType = case normalised networkType of
VPi binder result
VPi binder closure
| visibilityOf binder /= Explicit -> typingError
| otherwise -> do
inputDetails <- getTensorType Input (typeOf binder)
outputDetails <- getTensorType Output result
resultType <- normaliseClosure 0 binder closure
outputDetails <- getTensorType Output resultType
let networkDetails = NetworkType inputDetails outputDetails
return networkDetails
_ -> compilerDeveloperError "Should have caught the fact that the network type is not a function during type-checking"
Expand Down
22 changes: 18 additions & 4 deletions vehicle/src/Vehicle/Compile/Normalise/NBE.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ module Vehicle.Compile.Normalise.NBE
normaliseInEmptyEnv,
normaliseApp,
normaliseBuiltin,
normaliseClosure,
eval,
evalApp,
traverseClosure,
Expand Down Expand Up @@ -78,10 +79,25 @@ normaliseBuiltin b spine = do
freeEnv <- getFreeEnv
evalBuiltin freeEnv b spine

normaliseClosure ::
(MonadNorm builtin m, MonadFreeContext builtin m) =>
Lv ->
VBinder builtin ->
Closure builtin ->
m (Value builtin)
normaliseClosure lv binder closure = do
freeEnv <- getFreeEnv
evalClosure freeEnv closure (binder, VBoundVar lv [])

-----------------------------------------------------------------------------
-- Evaluation of closures

evalClosure :: (MonadNorm builtin m) => FreeEnv builtin -> Closure builtin -> (VBinder builtin, Value builtin) -> m (Value builtin)
evalClosure ::
(MonadNorm builtin m) =>
FreeEnv builtin ->
Closure builtin ->
(VBinder builtin, Value builtin) ->
m (Value builtin)
evalClosure freeEnv (Closure env body) (binder, arg) = do
let newEnv = extendEnvWithDefined arg binder env
eval freeEnv newEnv body
Expand Down Expand Up @@ -115,9 +131,7 @@ eval freeEnv boundEnv expr = do
return $ VLam binder' (Closure boundEnv body)
Pi _ binder body -> do
binder' <- traverse (eval freeEnv boundEnv) binder
let newBoundEnv = extendEnvWithBound (Lv $ length boundEnv) binder' boundEnv
body' <- eval freeEnv newBoundEnv body
return $ VPi binder' body'
return $ VPi binder' (Closure boundEnv body)
Let _ bound binder body -> do
binder' <- traverse (eval freeEnv boundEnv) binder
boundNormExpr <- eval freeEnv boundEnv bound
Expand Down
4 changes: 2 additions & 2 deletions vehicle/src/Vehicle/Compile/Normalise/Quote.hs
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ instance (ConvertableBuiltin builtin1 builtin2) => Quote (Value builtin1) (Expr
VBuiltin b spine -> do
let fn = convertBuiltin p b
quoteApp level p fn spine
VPi binder body -> do
VPi binder closure -> do
let quotedBinder = quote p level binder
let quotedBody = quote p (level + 1) body
let quotedBody = quoteClosure p level (binder, closure)
Pi p quotedBinder quotedBody
VLam binder closure -> do
let quotedBinder = quote p level binder
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ module Vehicle.Compile.Type.Constraint.LinearityAnnotationRestrictions
where

import Vehicle.Compile.Error
import Vehicle.Compile.Normalise.Quote (Quote (..))
import Vehicle.Compile.Normalise.Quote (Quote (..), quoteClosure)
import Vehicle.Compile.Prelude
import Vehicle.Compile.Type.Core
import Vehicle.Compile.Type.Monad
Expand All @@ -22,9 +22,9 @@ checkLinearityNetworkType (ident, p) networkType = case normalised networkType o
-- \|Decomposes the Pi types in a network type signature, checking that the
-- binders are explicit and their types are equal. Returns a function that
-- prepends the max linearity constraint.
VPi binder result -> do
VPi binder closure -> do
let inputLin = quote mempty 0 (typeOf binder)
let outputLin = quote mempty 0 result
let outputLin = quoteClosure mempty 0 (binder, closure)

-- The linearity of the output of a network is the max of 1) Linear (as outputs
-- are also variables) and 2) the linearity of its input. So prepend this
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ where

import Data.Maybe (mapMaybe)
import Vehicle.Compile.Error
import Vehicle.Compile.Normalise.NBE (normaliseClosure)
import Vehicle.Compile.Prelude
import Vehicle.Compile.Print (prettyFriendly)
import Vehicle.Compile.Type.Constraint.Core
Expand Down Expand Up @@ -57,11 +58,12 @@ solve = \case

solveQuantifierLinearity :: Quantifier -> LinearitySolver
solveQuantifierLinearity _ _ [getNMeta -> Just m, _] = blockOn [m]
solveQuantifierLinearity _ info [VPi binder body, res] = Just $ do
solveQuantifierLinearity _ info@(ctx, _) [VPi binder closure, res] = Just $ do
let varName = getBinderName binder
let domainLin = VLinearityExpr (Linear (QuantifiedVariableProvenance (provenanceOf binder) varName))
domEq <- createInstanceUnification info (typeOf binder) domainLin
resEq <- createInstanceUnification info res body
resultType <- normaliseClosure (contextDBLevel ctx) binder closure
resEq <- createInstanceUnification info res resultType
return $ Progress [domEq, resEq]
solveQuantifierLinearity _ _ _ = Nothing

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ module Vehicle.Compile.Type.Constraint.PolarityAnnotationRestrictions
where

import Vehicle.Compile.Error
import Vehicle.Compile.Normalise.Quote (Quote (..))
import Vehicle.Compile.Normalise.Quote (Quote (..), quoteClosure)
import Vehicle.Compile.Prelude
import Vehicle.Compile.Type.Core
import Vehicle.Compile.Type.Monad
Expand All @@ -22,9 +22,9 @@ checkNetworkType (_, p) networkType = case normalised networkType of
-- \|Decomposes the Pi types in a network type signature, checking that the
-- binders are explicit and their types are equal. Returns a function that
-- prepends the max linearity constraint.
VPi binder result -> do
VPi binder closure -> do
let inputPol = quote mempty 0 (typeOf binder)
let outputPol = quote mempty 0 result
let outputPol = quoteClosure mempty 0 (binder, closure)

createFreshUnificationConstraint p mempty CheckingAuxiliary (PolarityExpr p Unquantified) inputPol
createFreshUnificationConstraint p mempty CheckingAuxiliary (PolarityExpr p Unquantified) outputPol
Expand Down
16 changes: 11 additions & 5 deletions vehicle/src/Vehicle/Compile/Type/Constraint/PolaritySolver.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ where
import Control.Monad.Except (MonadError (..))
import Data.Maybe (mapMaybe)
import Vehicle.Compile.Error
import Vehicle.Compile.Normalise.NBE (normaliseClosure)
import Vehicle.Compile.Prelude
import Vehicle.Compile.Print (prettyFriendly)
import Vehicle.Compile.Type.Constraint.Core
Expand Down Expand Up @@ -66,12 +67,14 @@ solveNegPolarity info@(ctx, _) [arg1, res] = case arg1 of
solveNegPolarity _ _ = Nothing

solveQuantifierPolarity :: Quantifier -> PolaritySolver
solveQuantifierPolarity q info [lam, res] = case lam of
solveQuantifierPolarity q info@(ctx, _) [lam, res] = case lam of
(getNMeta -> Just m) -> blockOn [m]
(VPi binder resPol) -> Just $ do
binderEq <- createInstanceUnification info (typeOf binder) (VPolarityExpr Unquantified)
let tc = PolarityRelation $ AddPolarity q
(_, addConstraint) <- createSubInstance info Irrelevant (VBuiltin tc (Arg mempty Explicit Relevant <$> [resPol, res]))
let lv = contextDBLevel ctx
resultPolarity <- normaliseClosure lv binder resPol
(_, addConstraint) <- createSubInstance info Irrelevant (VBuiltin tc (explicit <$> [resultPolarity, res]))
return $ Progress [binderEq, addConstraint]
_ -> Nothing
solveQuantifierPolarity _ _c _ = Nothing
Expand Down Expand Up @@ -135,10 +138,13 @@ solveFunctionPolarity functionPosition info@(ctx, _) [arg, res] = case (arg, res
let pol3 = VPolarityExpr $ mapPolarityProvenance addFuncProv pol
resEq <- createInstanceUnification info res pol3
return $ Progress [resEq]
(VPi binder1 body1, VPi binder2 body2) -> Just $ do
(VPi binder1 closure1, VPi binder2 closure2) -> Just $ do
let tc = PolarityRelation $ FunctionPolarity functionPosition
(_, binderConstraint) <- createSubInstance info Irrelevant (VBuiltin tc (Arg mempty Explicit Relevant <$> [typeOf binder1, typeOf binder2]))
(_, bodyConstraint) <- createSubInstance info Irrelevant (VBuiltin tc (Arg mempty Explicit Relevant <$> [body1, body2]))
(_, binderConstraint) <- createSubInstance info Irrelevant (VBuiltin tc (explicit <$> [typeOf binder1, typeOf binder2]))
let lv = contextDBLevel ctx
body1 <- normaliseClosure lv binder1 closure1
body2 <- normaliseClosure lv binder2 closure2
(_, bodyConstraint) <- createSubInstance info Irrelevant (VBuiltin tc (explicit <$> [body1, body2]))
return $ Progress [binderConstraint, bodyConstraint]
_ -> Nothing
solveFunctionPolarity _ _ _ = Nothing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ where

import Control.Monad.Except (MonadError (..))
import Vehicle.Compile.Error
import Vehicle.Compile.Normalise.NBE (normaliseClosure)
import Vehicle.Compile.Prelude
import Vehicle.Compile.Type.Monad.Class
import Vehicle.Data.Builtin.Standard
Expand Down Expand Up @@ -111,12 +112,13 @@ restrictStandardNetworkType decl networkType = case normalised networkType of
-- \|Decomposes the Pi types in a network type signature, checking that the
-- binders are explicit and their types are equal. Returns a function that
-- prepends the max linearity constraint.
VPi binder result
VPi binder closure
| visibilityOf binder /= Explicit ->
throwError $ NetworkTypeHasNonExplicitArguments decl networkType binder
| otherwise -> do
checkTensorType Input (typeOf binder)
checkTensorType Output result
resultType <- normaliseClosure 0 binder closure
checkTensorType Output resultType
return $ unnormalised networkType
_ -> throwError $ NetworkTypeIsNotAFunction decl networkType
where
Expand Down
Loading
Loading