From c08bc94e9292c16c94018cc6ad03573f8d7c1122 Mon Sep 17 00:00:00 2001 From: MatthewDaggitt Date: Thu, 28 Nov 2024 16:31:47 +0800 Subject: [PATCH] Refactor VPi to actually use a closure --- .../src/Vehicle/Backend/LossFunction/JSON.hs | 10 +++---- .../Backend/LossFunction/LossCompilation.hs | 6 ++-- .../Backend/LossFunction/TensorCompilation.hs | 9 +++--- .../Backend/LossFunction/ZeroTensorLifting.hs | 23 ++++++++++---- vehicle/src/Vehicle/Compile/Arity.hs | 6 ---- vehicle/src/Vehicle/Compile/Descope.hs | 4 +-- .../Compile/ExpandResources/Network.hs | 6 ++-- vehicle/src/Vehicle/Compile/Normalise/NBE.hs | 22 +++++++++++--- .../src/Vehicle/Compile/Normalise/Quote.hs | 4 +-- .../LinearityAnnotationRestrictions.hs | 6 ++-- .../Type/Constraint/LinearitySolver.hs | 6 ++-- .../PolarityAnnotationRestrictions.hs | 6 ++-- .../Compile/Type/Constraint/PolaritySolver.hs | 16 ++++++---- .../StandardAnnotationRestrictions.hs | 6 ++-- .../Type/Constraint/UnificationSolver.hs | 30 +++++-------------- .../src/Vehicle/Compile/Type/Meta/Variable.hs | 7 +++-- vehicle/src/Vehicle/Data/Code/Value.hs | 2 +- 17 files changed, 95 insertions(+), 74 deletions(-) diff --git a/vehicle/src/Vehicle/Backend/LossFunction/JSON.hs b/vehicle/src/Vehicle/Backend/LossFunction/JSON.hs index 644e786f7..314e7c1e0 100644 --- a/vehicle/src/Vehicle/Backend/LossFunction/JSON.hs +++ b/vehicle/src/Vehicle/Backend/LossFunction/JSON.hs @@ -174,12 +174,12 @@ convertValue expr = do VUniverse {} -> resolutionError currentPass "Universe" VLam binder closure -> do binder' <- convertBinder binder - body' <- convertClosure binder closure - return $ Lam binder' body' - VPi binder body -> do + closure' <- convertClosure binder closure + return $ Lam binder' closure' + VPi binder closure -> do typ' <- convertValue (typeOf binder) - body' <- addNameToContext binder $ convertValue body - return $ Pi typ' body' + closure' <- convertClosure binder closure + return $ Pi typ' closure' VBuiltin b spine -> convertBuiltin b $ filterOutNonExplicitArgs spine VBoundVar v spine -> do name <- lvToProperName mempty v diff --git a/vehicle/src/Vehicle/Backend/LossFunction/LossCompilation.hs b/vehicle/src/Vehicle/Backend/LossFunction/LossCompilation.hs index a0344984a..90dbf0dca 100644 --- a/vehicle/src/Vehicle/Backend/LossFunction/LossCompilation.hs +++ b/vehicle/src/Vehicle/Backend/LossFunction/LossCompilation.hs @@ -99,10 +99,10 @@ convertValue e = do VBoundVar v <$> traverseArgs convertValue spine VBuiltin b spine -> do convertBuiltinToLoss b spine - VPi binder body -> do + VPi binder closure -> do binder' <- traverse convertValue binder - body' <- addNameToContext binder $ convertValue body - return $ VPi binder' body' + closure' <- traverseClosure convertValue mempty binder closure + return $ VPi binder' closure' VLam binder closure -> do binder' <- traverse convertValue binder closure' <- traverseClosure convertValue mempty binder closure diff --git a/vehicle/src/Vehicle/Backend/LossFunction/TensorCompilation.hs b/vehicle/src/Vehicle/Backend/LossFunction/TensorCompilation.hs index b799e678c..e89242347 100644 --- a/vehicle/src/Vehicle/Backend/LossFunction/TensorCompilation.hs +++ b/vehicle/src/Vehicle/Backend/LossFunction/TensorCompilation.hs @@ -7,14 +7,12 @@ module Vehicle.Backend.LossFunction.TensorCompilation ) where -import Control.Monad (void) import Control.Monad.Except (MonadError (..)) import Control.Monad.Reader (MonadReader (..)) import Control.Monad.Trans.Reader (ReaderT (..)) import Data.List.NonEmpty (NonEmpty (..)) import Vehicle.Backend.LossFunction.Core (pattern VLam2) import Vehicle.Compile.Arity (Arity) -import Vehicle.Compile.Context.Bound import Vehicle.Compile.Context.Free.Class (MonadFreeContext, getFreeEnv) import Vehicle.Compile.Context.Name (MonadNameContext, addNameToContext, getBinderDepth, getNameContext) import Vehicle.Compile.Error @@ -76,10 +74,11 @@ convertValue = go VBoundVar v <$> traverseArgs go spine VBuiltin b spine -> convertBuiltinToTensors b spine - VPi binder body -> do + VPi binder closure -> do binder' <- traverse go binder - body' <- addBinderToContext (void binder) $ go body - return $ VPi binder' body' + freeEnv <- getFreeEnv + closure' <- traverseClosure convertValue freeEnv binder closure + return $ VPi binder' closure' VLam binder closure -> do binder' <- traverse go binder freeEnv <- getFreeEnv diff --git a/vehicle/src/Vehicle/Backend/LossFunction/ZeroTensorLifting.hs b/vehicle/src/Vehicle/Backend/LossFunction/ZeroTensorLifting.hs index bdd9540dd..181aec91d 100644 --- a/vehicle/src/Vehicle/Backend/LossFunction/ZeroTensorLifting.hs +++ b/vehicle/src/Vehicle/Backend/LossFunction/ZeroTensorLifting.hs @@ -3,10 +3,10 @@ module Vehicle.Backend.LossFunction.ZeroTensorLifting ) where -import Data.Tuple (swap) import Vehicle.Compile.Context.Name import Vehicle.Compile.Error -import Vehicle.Compile.Normalise.NBE (traverseClosure, traverseClosureGeneric) +import Vehicle.Compile.Normalise.NBE (eval, traverseClosure) +import Vehicle.Compile.Normalise.Quote (Quote (..)) import Vehicle.Compile.Prelude import Vehicle.Compile.Print (prettyFriendly, prettyVerbose) import Vehicle.Data.Builtin.Loss @@ -37,9 +37,22 @@ liftDecl :: (VType LossTensorBuiltin, Value LossTensorBuiltin) -> m (VType LossTensorBuiltin, Value LossTensorBuiltin) liftDecl (t, e) = case (t, e) of - (VPi piBinder piBody, VLam lamBinder closure) -> do - (newClosure, newPiBody) <- traverseClosureGeneric (\lamBody -> liftDecl (piBody, lamBody)) swap mempty lamBinder closure - return (VPi piBinder newPiBody, VLam lamBinder newClosure) + (VPi piBinder (Closure piEnv piBody), VLam lamBinder (Closure lamEnv lamBody)) -> do + ctx <- getBinderContext + let lv = boundCtxLv ctx + let newPiEnv = extendEnvWithBound lv piBinder piEnv + let newLamEnv = extendEnvWithBound lv lamBinder lamEnv + + (resultPiBody, resultLamBody) <- addNameToContext lamBinder $ do + normPiBody <- eval mempty newPiEnv piBody + normLamBody <- eval mempty newLamEnv lamBody + liftDecl (normPiBody, normLamBody) + + let finalPiBody = quote mempty (lv + 1) resultPiBody + let finalLamBody = quote mempty (lv + 1) resultLamBody + + let finalEnv = boundContextToEnv ctx + return (VPi piBinder (Closure finalEnv finalPiBody), VLam lamBinder (Closure finalEnv finalLamBody)) (typ, body) -> do let (wasZeroDimensional, newType) = liftType typ liftedBody <- liftExpr wasZeroDimensional body diff --git a/vehicle/src/Vehicle/Compile/Arity.hs b/vehicle/src/Vehicle/Compile/Arity.hs index a3b49d45c..0a6eb1491 100644 --- a/vehicle/src/Vehicle/Compile/Arity.hs +++ b/vehicle/src/Vehicle/Compile/Arity.hs @@ -1,7 +1,6 @@ module Vehicle.Compile.Arity where import Vehicle.Data.Code.Expr -import Vehicle.Data.Code.Value import Vehicle.Prelude (isExplicit) type Arity = Int @@ -9,11 +8,6 @@ type Arity = Int class HasArity a where arityOf :: a -> Arity -arityFromVType :: Value builtin -> Arity -arityFromVType = \case - VPi _ r -> 1 + arityFromVType r - _ -> 0 - -- | This is only safe when the type is known to be in normalised type. explicitArityFromType :: Type builtin -> Arity explicitArityFromType = \case diff --git a/vehicle/src/Vehicle/Compile/Descope.hs b/vehicle/src/Vehicle/Compile/Descope.hs index 0537c1022..d79d90489 100644 --- a/vehicle/src/Vehicle/Compile/Descope.hs +++ b/vehicle/src/Vehicle/Compile/Descope.hs @@ -120,9 +120,9 @@ genericDescopeValue f e = case e of var <- S.Var p <$> lvToName f p v args <- traverseArgs (genericDescopeValue f) spine return $ S.normAppList var args - VPi binder body -> do + VPi binder closure -> do binder' <- traverse (genericDescopeValue f) binder - body' <- addNameToContext binder $ genericDescopeValue f body + body' <- addNameToContext binder $ descopeClosure f binder closure return $ S.Pi p binder' body' VLam binder closure -> do binder' <- traverse (genericDescopeValue f) binder diff --git a/vehicle/src/Vehicle/Compile/ExpandResources/Network.hs b/vehicle/src/Vehicle/Compile/ExpandResources/Network.hs index f6903422e..77bb12c73 100644 --- a/vehicle/src/Vehicle/Compile/ExpandResources/Network.hs +++ b/vehicle/src/Vehicle/Compile/ExpandResources/Network.hs @@ -7,6 +7,7 @@ import Control.Monad.Except (MonadError (..)) import Data.Map qualified as Map import Vehicle.Compile.Error import Vehicle.Compile.ExpandResources.Core +import Vehicle.Compile.Normalise.NBE (normaliseClosure) import Vehicle.Compile.Prelude import Vehicle.Compile.Print import Vehicle.Compile.Resource @@ -42,11 +43,12 @@ getNetworkType :: GluedType Builtin -> m NetworkType getNetworkType decl networkType = case normalised networkType of - VPi binder result + VPi binder closure | visibilityOf binder /= Explicit -> typingError | otherwise -> do inputDetails <- getTensorType Input (typeOf binder) - outputDetails <- getTensorType Output result + resultType <- normaliseClosure 0 binder closure + outputDetails <- getTensorType Output resultType let networkDetails = NetworkType inputDetails outputDetails return networkDetails _ -> compilerDeveloperError "Should have caught the fact that the network type is not a function during type-checking" diff --git a/vehicle/src/Vehicle/Compile/Normalise/NBE.hs b/vehicle/src/Vehicle/Compile/Normalise/NBE.hs index a45d584cc..aa3c1c9fc 100644 --- a/vehicle/src/Vehicle/Compile/Normalise/NBE.hs +++ b/vehicle/src/Vehicle/Compile/Normalise/NBE.hs @@ -6,6 +6,7 @@ module Vehicle.Compile.Normalise.NBE normaliseInEmptyEnv, normaliseApp, normaliseBuiltin, + normaliseClosure, eval, evalApp, traverseClosure, @@ -78,10 +79,25 @@ normaliseBuiltin b spine = do freeEnv <- getFreeEnv evalBuiltin freeEnv b spine +normaliseClosure :: + (MonadNorm builtin m, MonadFreeContext builtin m) => + Lv -> + VBinder builtin -> + Closure builtin -> + m (Value builtin) +normaliseClosure lv binder closure = do + freeEnv <- getFreeEnv + evalClosure freeEnv closure (binder, VBoundVar lv []) + ----------------------------------------------------------------------------- -- Evaluation of closures -evalClosure :: (MonadNorm builtin m) => FreeEnv builtin -> Closure builtin -> (VBinder builtin, Value builtin) -> m (Value builtin) +evalClosure :: + (MonadNorm builtin m) => + FreeEnv builtin -> + Closure builtin -> + (VBinder builtin, Value builtin) -> + m (Value builtin) evalClosure freeEnv (Closure env body) (binder, arg) = do let newEnv = extendEnvWithDefined arg binder env eval freeEnv newEnv body @@ -115,9 +131,7 @@ eval freeEnv boundEnv expr = do return $ VLam binder' (Closure boundEnv body) Pi _ binder body -> do binder' <- traverse (eval freeEnv boundEnv) binder - let newBoundEnv = extendEnvWithBound (Lv $ length boundEnv) binder' boundEnv - body' <- eval freeEnv newBoundEnv body - return $ VPi binder' body' + return $ VPi binder' (Closure boundEnv body) Let _ bound binder body -> do binder' <- traverse (eval freeEnv boundEnv) binder boundNormExpr <- eval freeEnv boundEnv bound diff --git a/vehicle/src/Vehicle/Compile/Normalise/Quote.hs b/vehicle/src/Vehicle/Compile/Normalise/Quote.hs index 8904894d2..0d606c2c1 100644 --- a/vehicle/src/Vehicle/Compile/Normalise/Quote.hs +++ b/vehicle/src/Vehicle/Compile/Normalise/Quote.hs @@ -50,9 +50,9 @@ instance (ConvertableBuiltin builtin1 builtin2) => Quote (Value builtin1) (Expr VBuiltin b spine -> do let fn = convertBuiltin p b quoteApp level p fn spine - VPi binder body -> do + VPi binder closure -> do let quotedBinder = quote p level binder - let quotedBody = quote p (level + 1) body + let quotedBody = quoteClosure p level (binder, closure) Pi p quotedBinder quotedBody VLam binder closure -> do let quotedBinder = quote p level binder diff --git a/vehicle/src/Vehicle/Compile/Type/Constraint/LinearityAnnotationRestrictions.hs b/vehicle/src/Vehicle/Compile/Type/Constraint/LinearityAnnotationRestrictions.hs index 2d702ea31..f285c9531 100644 --- a/vehicle/src/Vehicle/Compile/Type/Constraint/LinearityAnnotationRestrictions.hs +++ b/vehicle/src/Vehicle/Compile/Type/Constraint/LinearityAnnotationRestrictions.hs @@ -5,7 +5,7 @@ module Vehicle.Compile.Type.Constraint.LinearityAnnotationRestrictions where import Vehicle.Compile.Error -import Vehicle.Compile.Normalise.Quote (Quote (..)) +import Vehicle.Compile.Normalise.Quote (Quote (..), quoteClosure) import Vehicle.Compile.Prelude import Vehicle.Compile.Type.Core import Vehicle.Compile.Type.Monad @@ -22,9 +22,9 @@ checkLinearityNetworkType (ident, p) networkType = case normalised networkType o -- \|Decomposes the Pi types in a network type signature, checking that the -- binders are explicit and their types are equal. Returns a function that -- prepends the max linearity constraint. - VPi binder result -> do + VPi binder closure -> do let inputLin = quote mempty 0 (typeOf binder) - let outputLin = quote mempty 0 result + let outputLin = quoteClosure mempty 0 (binder, closure) -- The linearity of the output of a network is the max of 1) Linear (as outputs -- are also variables) and 2) the linearity of its input. So prepend this diff --git a/vehicle/src/Vehicle/Compile/Type/Constraint/LinearitySolver.hs b/vehicle/src/Vehicle/Compile/Type/Constraint/LinearitySolver.hs index 6f4f7cf25..7a1c9ec39 100644 --- a/vehicle/src/Vehicle/Compile/Type/Constraint/LinearitySolver.hs +++ b/vehicle/src/Vehicle/Compile/Type/Constraint/LinearitySolver.hs @@ -5,6 +5,7 @@ where import Data.Maybe (mapMaybe) import Vehicle.Compile.Error +import Vehicle.Compile.Normalise.NBE (normaliseClosure) import Vehicle.Compile.Prelude import Vehicle.Compile.Print (prettyFriendly) import Vehicle.Compile.Type.Constraint.Core @@ -57,11 +58,12 @@ solve = \case solveQuantifierLinearity :: Quantifier -> LinearitySolver solveQuantifierLinearity _ _ [getNMeta -> Just m, _] = blockOn [m] -solveQuantifierLinearity _ info [VPi binder body, res] = Just $ do +solveQuantifierLinearity _ info@(ctx, _) [VPi binder closure, res] = Just $ do let varName = getBinderName binder let domainLin = VLinearityExpr (Linear (QuantifiedVariableProvenance (provenanceOf binder) varName)) domEq <- createInstanceUnification info (typeOf binder) domainLin - resEq <- createInstanceUnification info res body + resultType <- normaliseClosure (contextDBLevel ctx) binder closure + resEq <- createInstanceUnification info res resultType return $ Progress [domEq, resEq] solveQuantifierLinearity _ _ _ = Nothing diff --git a/vehicle/src/Vehicle/Compile/Type/Constraint/PolarityAnnotationRestrictions.hs b/vehicle/src/Vehicle/Compile/Type/Constraint/PolarityAnnotationRestrictions.hs index 39041fa6f..3caf6a13b 100644 --- a/vehicle/src/Vehicle/Compile/Type/Constraint/PolarityAnnotationRestrictions.hs +++ b/vehicle/src/Vehicle/Compile/Type/Constraint/PolarityAnnotationRestrictions.hs @@ -5,7 +5,7 @@ module Vehicle.Compile.Type.Constraint.PolarityAnnotationRestrictions where import Vehicle.Compile.Error -import Vehicle.Compile.Normalise.Quote (Quote (..)) +import Vehicle.Compile.Normalise.Quote (Quote (..), quoteClosure) import Vehicle.Compile.Prelude import Vehicle.Compile.Type.Core import Vehicle.Compile.Type.Monad @@ -22,9 +22,9 @@ checkNetworkType (_, p) networkType = case normalised networkType of -- \|Decomposes the Pi types in a network type signature, checking that the -- binders are explicit and their types are equal. Returns a function that -- prepends the max linearity constraint. - VPi binder result -> do + VPi binder closure -> do let inputPol = quote mempty 0 (typeOf binder) - let outputPol = quote mempty 0 result + let outputPol = quoteClosure mempty 0 (binder, closure) createFreshUnificationConstraint p mempty CheckingAuxiliary (PolarityExpr p Unquantified) inputPol createFreshUnificationConstraint p mempty CheckingAuxiliary (PolarityExpr p Unquantified) outputPol diff --git a/vehicle/src/Vehicle/Compile/Type/Constraint/PolaritySolver.hs b/vehicle/src/Vehicle/Compile/Type/Constraint/PolaritySolver.hs index 52971fe42..e34a2701a 100644 --- a/vehicle/src/Vehicle/Compile/Type/Constraint/PolaritySolver.hs +++ b/vehicle/src/Vehicle/Compile/Type/Constraint/PolaritySolver.hs @@ -6,6 +6,7 @@ where import Control.Monad.Except (MonadError (..)) import Data.Maybe (mapMaybe) import Vehicle.Compile.Error +import Vehicle.Compile.Normalise.NBE (normaliseClosure) import Vehicle.Compile.Prelude import Vehicle.Compile.Print (prettyFriendly) import Vehicle.Compile.Type.Constraint.Core @@ -66,12 +67,14 @@ solveNegPolarity info@(ctx, _) [arg1, res] = case arg1 of solveNegPolarity _ _ = Nothing solveQuantifierPolarity :: Quantifier -> PolaritySolver -solveQuantifierPolarity q info [lam, res] = case lam of +solveQuantifierPolarity q info@(ctx, _) [lam, res] = case lam of (getNMeta -> Just m) -> blockOn [m] (VPi binder resPol) -> Just $ do binderEq <- createInstanceUnification info (typeOf binder) (VPolarityExpr Unquantified) let tc = PolarityRelation $ AddPolarity q - (_, addConstraint) <- createSubInstance info Irrelevant (VBuiltin tc (Arg mempty Explicit Relevant <$> [resPol, res])) + let lv = contextDBLevel ctx + resultPolarity <- normaliseClosure lv binder resPol + (_, addConstraint) <- createSubInstance info Irrelevant (VBuiltin tc (explicit <$> [resultPolarity, res])) return $ Progress [binderEq, addConstraint] _ -> Nothing solveQuantifierPolarity _ _c _ = Nothing @@ -135,10 +138,13 @@ solveFunctionPolarity functionPosition info@(ctx, _) [arg, res] = case (arg, res let pol3 = VPolarityExpr $ mapPolarityProvenance addFuncProv pol resEq <- createInstanceUnification info res pol3 return $ Progress [resEq] - (VPi binder1 body1, VPi binder2 body2) -> Just $ do + (VPi binder1 closure1, VPi binder2 closure2) -> Just $ do let tc = PolarityRelation $ FunctionPolarity functionPosition - (_, binderConstraint) <- createSubInstance info Irrelevant (VBuiltin tc (Arg mempty Explicit Relevant <$> [typeOf binder1, typeOf binder2])) - (_, bodyConstraint) <- createSubInstance info Irrelevant (VBuiltin tc (Arg mempty Explicit Relevant <$> [body1, body2])) + (_, binderConstraint) <- createSubInstance info Irrelevant (VBuiltin tc (explicit <$> [typeOf binder1, typeOf binder2])) + let lv = contextDBLevel ctx + body1 <- normaliseClosure lv binder1 closure1 + body2 <- normaliseClosure lv binder2 closure2 + (_, bodyConstraint) <- createSubInstance info Irrelevant (VBuiltin tc (explicit <$> [body1, body2])) return $ Progress [binderConstraint, bodyConstraint] _ -> Nothing solveFunctionPolarity _ _ _ = Nothing diff --git a/vehicle/src/Vehicle/Compile/Type/Constraint/StandardAnnotationRestrictions.hs b/vehicle/src/Vehicle/Compile/Type/Constraint/StandardAnnotationRestrictions.hs index e2f4b5761..971d3d760 100644 --- a/vehicle/src/Vehicle/Compile/Type/Constraint/StandardAnnotationRestrictions.hs +++ b/vehicle/src/Vehicle/Compile/Type/Constraint/StandardAnnotationRestrictions.hs @@ -8,6 +8,7 @@ where import Control.Monad.Except (MonadError (..)) import Vehicle.Compile.Error +import Vehicle.Compile.Normalise.NBE (normaliseClosure) import Vehicle.Compile.Prelude import Vehicle.Compile.Type.Monad.Class import Vehicle.Data.Builtin.Standard @@ -111,12 +112,13 @@ restrictStandardNetworkType decl networkType = case normalised networkType of -- \|Decomposes the Pi types in a network type signature, checking that the -- binders are explicit and their types are equal. Returns a function that -- prepends the max linearity constraint. - VPi binder result + VPi binder closure | visibilityOf binder /= Explicit -> throwError $ NetworkTypeHasNonExplicitArguments decl networkType binder | otherwise -> do checkTensorType Input (typeOf binder) - checkTensorType Output result + resultType <- normaliseClosure 0 binder closure + checkTensorType Output resultType return $ unnormalised networkType _ -> throwError $ NetworkTypeIsNotAFunction decl networkType where diff --git a/vehicle/src/Vehicle/Compile/Type/Constraint/UnificationSolver.hs b/vehicle/src/Vehicle/Compile/Type/Constraint/UnificationSolver.hs index 59e00a815..03e93f11d 100644 --- a/vehicle/src/Vehicle/Compile/Type/Constraint/UnificationSolver.hs +++ b/vehicle/src/Vehicle/Compile/Type/Constraint/UnificationSolver.hs @@ -156,10 +156,10 @@ unification info@(constraint, _) = \case VBuiltin b1 spine1 :~: VBuiltin b2 spine2 | b1 == b2 -> solveSpine info spine1 spine2 | isConstructor b1 && isConstructor b2 -> return $ HardFailure [constraint] - VPi binder1 body1 :~: VPi binder2 body2 - | visibilityMatches binder1 binder2 -> solvePi info (binder1, body1) (binder2, body2) - VLam binder1 body1 :~: VLam binder2 body2 -> - solveLam info (binder1, body1) (binder2, body2) + VPi binder1 closure1 :~: VPi binder2 closure2 + | visibilityMatches binder1 binder2 -> solveClosure info (binder1, closure1) (binder2, closure2) + VLam binder1 closure1 :~: VLam binder2 closure2 -> + solveClosure info (binder1, closure1) (binder2, closure2) --------------------- -- Flex-flex cases -- --------------------- @@ -205,13 +205,13 @@ solveSpine info@(constraint, _) args1 args2 | length args1 /= length args2 = return $ HardFailure [constraint] | otherwise = mconcat <$> traverse (solveArg info) (zip args1 args2) -solveLam :: +solveClosure :: (MonadUnify builtin m) => ConstraintInfo builtin -> (VBinder builtin, Closure builtin) -> (VBinder builtin, Closure builtin) -> m (UnificationResult builtin) -solveLam info@(WithContext constraint ctx, blockingMeta) (binder1, Closure env1 body1) (binder2, Closure env2 body2) = do +solveClosure info@(WithContext constraint ctx, blockingMeta) (binder1, Closure env1 body1) (binder2, Closure env2 body2) = do -- Unify binder constraints binderConstraint <- subUnify info (typeOf binder1, typeOf binder2) @@ -232,20 +232,6 @@ solveLam info@(WithContext constraint ctx, blockingMeta) (binder1, Closure env1 -- Return the result return $ binderConstraint <> bodyConstraint -solvePi :: - (MonadUnify builtin m) => - ConstraintInfo builtin -> - (VBinder builtin, Value builtin) -> - (VBinder builtin, Value builtin) -> - m (UnificationResult builtin) -solvePi info (binder1, body1) (binder2, body2) = do - -- !!TODO!! Block until binders are solved - -- One possible implementation, blocked metas = set of sets where outer is conjunction and inner is disjunction - -- BOB: this effectively blocks until the binders are solved, because we usually just try to eagerly solve problems - binderConstraint <- subUnify info (typeOf binder1, typeOf binder2) - bodyConstraint <- subUnify info (body1, body2) - return $ binderConstraint <> bodyConstraint - solveFlexFlex :: (MonadUnify builtin m) => ConstraintInfo builtin -> @@ -334,10 +320,10 @@ pruneMetaDependencies ctx (solvingMetaID, solvingMetaSpine) attemptedSolution = VBuiltin b spine -> VBuiltin b <$> traverse (traverse go) spine VBoundVar v spine -> VBoundVar v <$> traverse (traverse go) spine VFreeVar v spine -> VFreeVar v <$> traverse (traverse go) spine - VPi binder result -> VPi <$> traverse go binder <*> go result -- Definitely going to have come back and fix this one later. -- Can't inspect the metas in the environment, as not every variable - -- in the environment will be used? + -- in the environment will be used? But maybe we can? + VPi {} -> return expr -- VPi <$> traverse go binder <*> go result VLam {} -> return expr createMetaWithRestrictedDependencies :: diff --git a/vehicle/src/Vehicle/Compile/Type/Meta/Variable.hs b/vehicle/src/Vehicle/Compile/Type/Meta/Variable.hs index 12c04118f..2d683413b 100644 --- a/vehicle/src/Vehicle/Compile/Type/Meta/Variable.hs +++ b/vehicle/src/Vehicle/Compile/Type/Meta/Variable.hs @@ -125,8 +125,11 @@ instance HasMetas (Value builtin) where VBuiltin _ spine -> findMetas spine VFreeVar _ spine -> findMetas spine VBoundVar _ spine -> findMetas spine - VPi binder result -> do findMetas binder; findMetas result - VLam {} -> compilerDeveloperError "Finding metas in lambda not yet supported." + VPi binder closure -> do findMetas binder; findMetas closure + VLam binder closure -> do findMetas binder; findMetas closure + +instance HasMetas (Closure builtin) where + findMetas (Closure env expr) = do findMetas (fmap snd env); findMetas expr instance (HasMetas expr) => HasMetas (GenericArg expr) where findMetas = mapM_ findMetas diff --git a/vehicle/src/Vehicle/Data/Code/Value.hs b/vehicle/src/Vehicle/Data/Code/Value.hs index f88452619..43afbd384 100644 --- a/vehicle/src/Vehicle/Data/Code/Value.hs +++ b/vehicle/src/Vehicle/Data/Code/Value.hs @@ -32,7 +32,7 @@ data Value builtin | VBoundVar !Lv !(Spine builtin) | VBuiltin !builtin !(Spine builtin) | VLam !(VBinder builtin) !(Closure builtin) - | VPi !(VBinder builtin) !(Value builtin) + | VPi !(VBinder builtin) !(Closure builtin) deriving (Eq, Show, Generic) type VType builtin = Value builtin