diff --git a/vehicle-syntax/src/Vehicle/Syntax/Builtin/BasicOperations.hs b/vehicle-syntax/src/Vehicle/Syntax/Builtin/BasicOperations.hs index 78a0fbcc6..6ccd50419 100644 --- a/vehicle-syntax/src/Vehicle/Syntax/Builtin/BasicOperations.hs +++ b/vehicle-syntax/src/Vehicle/Syntax/Builtin/BasicOperations.hs @@ -9,6 +9,7 @@ module Vehicle.Syntax.Builtin.BasicOperations orderOpName, Strictness (..), isStrict, + isForward, flipStrictness, flipOrder, chainable, @@ -112,6 +113,9 @@ orderOpName = \case isStrict :: OrderOp -> Bool isStrict order = order == Lt || order == Gt +isForward :: OrderOp -> Bool +isForward order = order == Lt || order == Le + flipStrictness :: OrderOp -> OrderOp flipStrictness = \case Le -> Lt diff --git a/vehicle/src/Vehicle/Backend/Agda/CapitaliseTypeNames.hs b/vehicle/src/Vehicle/Backend/Agda/CapitaliseTypeNames.hs index 72c0e38b6..237bb9c84 100644 --- a/vehicle/src/Vehicle/Backend/Agda/CapitaliseTypeNames.hs +++ b/vehicle/src/Vehicle/Backend/Agda/CapitaliseTypeNames.hs @@ -17,7 +17,7 @@ import Vehicle.Data.Expr.Interface -- convention. This pass identifies all such defined functions and capitalises -- all references to them. Cannot be done during the main compilation pass as we -- need to be able to distinguish between free and bound variables. -capitaliseTypeNames :: (HasStandardTypes builtin) => Prog var builtin -> Prog var builtin +capitaliseTypeNames :: (BuiltinHasStandardTypes builtin) => Prog var builtin -> Prog var builtin capitaliseTypeNames prog = evalState (cap prog) mempty -------------------------------------------------------------------------------- @@ -28,10 +28,10 @@ type MonadCapitalise m = MonadState (Set Identifier) m class CapitaliseTypes a where cap :: (MonadCapitalise m) => a -> m a -instance (HasStandardTypes builtin) => CapitaliseTypes (Prog var builtin) where +instance (BuiltinHasStandardTypes builtin) => CapitaliseTypes (Prog var builtin) where cap (Main ds) = Main <$> traverse cap ds -instance (HasStandardTypes builtin) => CapitaliseTypes (Decl var builtin) where +instance (BuiltinHasStandardTypes builtin) => CapitaliseTypes (Decl var builtin) where cap d = case d of DefAbstract p ident r t -> DefAbstract p <$> capitaliseIdentifier ident <*> pure r <*> cap t @@ -72,14 +72,14 @@ capitaliseIdentifier ident@(Identifier m s) = do then capitaliseFirstLetter s else s -isTypeDef :: (HasStandardTypes builtin) => Expr var builtin -> Bool +isTypeDef :: (BuiltinHasStandardTypes builtin) => Expr var builtin -> Bool isTypeDef t = case t of -- We don't capitalise things of type `Bool` because they will be lifted -- to the type level, only things of type `X -> Bool`. Pi _ _ result -> go result _ -> False where - go :: (HasStandardTypes builtin) => Expr var builtin -> Bool + go :: (BuiltinHasStandardTypes builtin) => Expr var builtin -> Bool go (IBoolType _) = True go (Pi _ _ res) = go res go _ = False diff --git a/vehicle/src/Vehicle/Backend/LossFunction/Core.hs b/vehicle/src/Vehicle/Backend/LossFunction/Core.hs index f0079a867..87718e007 100644 --- a/vehicle/src/Vehicle/Backend/LossFunction/Core.hs +++ b/vehicle/src/Vehicle/Backend/LossFunction/Core.hs @@ -8,7 +8,9 @@ import GHC.Generics (Generic) import Vehicle.Compile.Prelude import Vehicle.Data.Builtin.Loss.Core import Vehicle.Data.Builtin.Standard.Core (Builtin) -import Vehicle.Data.Expr.Normalised (BoundEnv, Spine, VBinder, VDecl, Value, WHNFClosure (..)) +import Vehicle.Data.Builtin.Tensor (TensorBuiltin) +import Vehicle.Data.Builtin.Tensor qualified as T +import Vehicle.Data.Expr.Value (BoundEnv, NFValue, Spine, VBinder, VDecl, Value (..), WHNFClosure (..)) import Vehicle.Libraries.StandardLibrary.Definitions (StdLibFunction (..)) -------------------------------------------------------------------------------- @@ -17,6 +19,27 @@ import Vehicle.Libraries.StandardLibrary.Definitions (StdLibFunction (..)) data LossClosure = LossClosure (BoundEnv MixedClosure LossBuiltin) (Expr Ix LossBuiltin) +-- | Okay, so closures for loss functions are complicated. How compilation +-- currently works is that we first do standard normalisation on the `Builtin` +-- type, then recurse down to the closures converting `Builtin` to `LossBuiltin`. +-- +-- This means that we need the standard closures from the first half. +-- However, when we convert a standard `Builtin` to the equivalent loss +-- expression, we need we apply the translated expression to the previous arguments. +-- +-- subst evalApp +-- e.g _<=_ a b --------> (\x y -> x - y) a b ---------> a - b +-- +-- During the evalApp phase we may need to form closures over expressions that have +-- already been translated to loss functions. This means that we need the loss +-- closure constructor as well. In theory, if all arguments to the builtin are present +-- (i.e. the boolean operator is fully-applied) then these loss closures are only +-- formed temporarily and by the time `evalApp` has finally finished executing +-- they will no longer exist. +-- +-- However, during the intermediary computation of `evalApp` both closures can exist +-- within the same expression. Maybe there's a nicer way of encoding this all in the +-- types that avoids this, but I haven't found it yet. data MixedClosure = StandardClos (WHNFClosure Builtin) | LossClos LossClosure @@ -68,3 +91,6 @@ preservedStdLibOps = Set.fromList [ StdForeachIndex ] + +constRatTensor :: Rational -> NFValue TensorBuiltin +constRatTensor v = VBuiltin (T.ConstRatTensor $ T.convertRat v) [explicit (VBuiltin T.NilList [])] diff --git a/vehicle/src/Vehicle/Backend/LossFunction/Domain.hs b/vehicle/src/Vehicle/Backend/LossFunction/Domain.hs index 828278058..c535f07e2 100644 --- a/vehicle/src/Vehicle/Backend/LossFunction/Domain.hs +++ b/vehicle/src/Vehicle/Backend/LossFunction/Domain.hs @@ -1,7 +1,15 @@ -{- -module Vehicle.Backend.LossFunction.Domain where +module Vehicle.Backend.LossFunction.Domain + ( -- extractSearchDomain, + -- Domain (..), + ) +where +{- +import Control.Applicative (Applicative (..)) +import Control.Monad (when) import Control.Monad.Except (MonadError (..), runExceptT, void) +import Control.Monad.Reader (MonadReader (..), ReaderT (..)) +import Data.Either (partitionEithers) import Data.Map (Map) import Data.Map qualified as Map import Data.Proxy (Proxy (..)) @@ -11,21 +19,19 @@ import Vehicle.Compile.Boolean.LowerNot (lowerNot) import Vehicle.Compile.Boolean.Unblock (ReduceVectorVars, UnblockingActions (..)) import Vehicle.Compile.Boolean.Unblock qualified as Unblocking import Vehicle.Compile.Context.Bound (MonadBoundContext, getNamedBoundCtx) -import Vehicle.Compile.Error (CompileError (..), MonadCompile) +import Vehicle.Compile.Context.Free (MonadFreeContext) +import Vehicle.Compile.Error (CompileError (..), MonadCompile, unexpectedExprError) import Vehicle.Compile.Normalise.NBE (normaliseInEnv) import Vehicle.Compile.Prelude +import Vehicle.Compile.Print (prettyVerbose) import Vehicle.Compile.Rational.LinearExpr import Vehicle.Compile.Variable (createUserVar) import Vehicle.Data.Builtin.Standard +import Vehicle.Data.Builtin.Tensor (TensorBuiltin (..)) import Vehicle.Data.Expr.Interface -import Vehicle.Data.Expr.Normalised -import Vehicle.Compile.Context.Free (MonadFreeContext) -import Control.Applicative (Applicative(..)) -import Control.Monad.Reader (MonadReader (..), ReaderT (..)) -import Vehicle.Data.QuantifiedVariable -import Control.Monad (when) import Vehicle.Data.Expr.Linear (addExprs, rearrangeExprToSolveFor) -import Data.Either (partitionEithers) +import Vehicle.Data.Expr.Value +import Vehicle.Data.QuantifiedVariable type MonadDomain m = ( MonadCompile m, @@ -34,15 +40,15 @@ type MonadDomain m = ) type MonadSearch m = - ( MonadDomain m - , MonadReader VariableInfo m + ( MonadDomain m, + MonadReader VariableInfo m ) -- | Information for the variable whose domain we are trying to find. data VariableInfo = VariableInfo - { variableLv :: Lv - , vectorExpr :: WHNFValue Builtin - , reducedVars :: [(Lv, UserRationalVariable)] + { variableLv :: Lv, + vectorExpr :: WHNFValue Builtin, + reducedVars :: [(Lv, UserRationalVariable)] } extractSearchDomain :: @@ -51,7 +57,7 @@ extractSearchDomain :: MixedLossBinder -> Lv -> WHNFClosure Builtin -> - m (Maybe (Domain, WHNFValue Builtin)) + m (Domain, WHNFValue Builtin) extractSearchDomain propertyProv binder lv (WHNFClosure env expr) = do -- Convert the binder namedCtx <- getNamedBoundCtx (Proxy @MixedLossValue) @@ -69,24 +75,24 @@ extractSearchDomain propertyProv binder lv (WHNFClosure env expr) = do case maybeConstraints of Nothing -> throwError $ NoQuantifierDomainFound propertyProv (void binder) Nothing Just constraints -> do - let maybeDomain = extractDomainFromConstraints constraints (userTensorVarDimensions userVar) reducedUseVars + let maybeDomain = extractDomainFromConstraints constraints reducedUseVars case maybeDomain of Left missingCostraints -> throwError $ NoQuantifierDomainFound propertyProv (void binder) (Just missingCostraints) - Right domain -> return $ Just (domain, remainder) + Right domain -> return (domain, remainder) -------------------------------------------------------------------------------- -- Constraints data VariableConstraints = VariableConstraints - { lowerBounds :: Map Lv (WHNFValue Builtin), - upperBounds :: Map Lv (WHNFValue Builtin) + { lowerBounds :: Map Lv (NFValue TensorBuiltin), + upperBounds :: Map Lv (NFValue TensorBuiltin) } instance Semigroup VariableConstraints where x <> y = VariableConstraints - { lowerBounds = Map.unionWith IMax (lowerBounds x) (lowerBounds y), - upperBounds = Map.unionWith IMin (upperBounds x) (upperBounds y) + { lowerBounds = Map.unionWith (\u v -> VBuiltin MaxRatTensor (explicit <$> [u, v])) (lowerBounds x) (lowerBounds y), + upperBounds = Map.unionWith (\u v -> VBuiltin MinRatTensor (explicit <$> [u, v])) (upperBounds x) (upperBounds y) } instance Monoid VariableConstraints where @@ -106,16 +112,15 @@ updateConstrainedValue originalExpr = \case -- Domain data Domain = Domain - { lowerBound :: WHNFValue Builtin, - upperBound :: WHNFValue Builtin + { lowerBound :: NFValue TensorBuiltin, + upperBound :: NFValue TensorBuiltin } extractDomainFromConstraints :: VariableConstraints -> - TensorShape -> [(Lv, UserRationalVariable)] -> Either [(UserRationalVariable, UnderConstrainedVariableStatus)] Domain -extractDomainFromConstraints VariableConstraints{..} tensorShape allVariables = do +extractDomainFromConstraints VariableConstraints {..} allVariables = do let lowerBoundExprs = flip map allVariables $ \(lv, var) -> case (Map.lookup lv lowerBounds, Map.lookup lv upperBounds) of (Just x, Just y) -> Right (x, y) @@ -127,9 +132,10 @@ extractDomainFromConstraints VariableConstraints{..} tensorShape allVariables = if not $ null missingVars then Left missingVars else do + let n = length allVariables let (lowerBoundElements, upperBoundElements) = unzip presentVarBounds - let lowerBoundExpr = tensorLikeToExpr id tensorShape lowerBoundElements - let upperBoundExpr = tensorLikeToExpr id tensorShape upperBoundElements + let lowerBoundExpr = VBuiltin (StackRatTensor n) (explicit <$> lowerBoundElements) + let upperBoundExpr = VBuiltin (StackRatTensor n) (explicit <$> upperBoundElements) Right $ Domain lowerBoundExpr upperBoundExpr -------------------------------------------------------------------------------- @@ -194,11 +200,11 @@ tryPurifyAssertion value whenPure = do Right (Left purified) -> updateConstrainedValue value <$> findConstraints purified unblockBoolExpr :: - (MonadDomain m) => + (MonadSearch m) => WHNFValue Builtin -> m ConstrainedValue unblockBoolExpr value = do - ctx <- getNamedBoundCtx (Proxy @Builtin) + ctx <- getNamedBoundCtx (Proxy @MixedLossValue) result <- runExceptT (Unblocking.unblockBoolExpr ctx unblockingActions value) case result of Left {} -> return (Nothing, value) @@ -206,7 +212,9 @@ unblockBoolExpr value = do constrainedValue <- findConstraints unblockedValue return $ updateConstrainedValue value constrainedValue -unblockingActions :: (MonadError (WHNFValue Builtin) m) => UnblockingActions m +unblockingActions :: + (MonadError (WHNFValue Builtin) m, MonadReader VariableInfo m) => + UnblockingActions m unblockingActions = UnblockingActions { unblockFreeVectorVar = unblockFreeVectorVariable, @@ -228,10 +236,11 @@ unblockBoundVectorVariable :: Lv -> m (WHNFValue Builtin) unblockBoundVectorVariable lv = do - VariableInfo{..} <- ask + VariableInfo {..} <- ask when (lv /= variableLv) $ - throwError $ VBoundVar lv [] + throwError $ + VBoundVar lv [] return vectorExpr @@ -239,17 +248,59 @@ unblockBoundVectorVariable lv = do -- Compilation of inequalities handleRatInequality :: - (MonadDomain m) => + (MonadSearch m) => OrderOp -> WHNFValue Builtin -> WHNFValue Builtin -> m ConstrainedValue handleRatInequality op e1 e2 = do - result <- compileRatLinearRelation _ e1 e2 + let e = ISub SubRat e1 e2 + result <- runExceptT (compileRatLinearExpr e) case result of - Left NonLinearity -> return (Nothing, IOrderRat op e1 e2) - Right (le1, le2) -> do - let le = addExprs 1 (-1) le1 le2 - let (_, rearrangedExpr) = rearrangeExprToSolveFor _ le - return _ + Right (Linear a b) -> do + case op of + Le -> return (_, ITrueExpr mempty) + _ -> return (Nothing, IOrderRat op e1 e2) + +data Bound + = Constant (NFValue TensorBuiltin) + | Linear (NFValue TensorBuiltin) (NFValue TensorBuiltin) + +compileRatLinearExpr :: + forall m. + (MonadLogger m, MonadError NonLinearity m) => + WHNFValue Builtin -> + m Bound +compileRatLinearExpr = go + where + go :: WHNFValue Builtin -> m Bound + go e = case e of + ---------------- + -- Base cases -- + ---------------- + IRatLiteral _ l -> return $ Constant _ + VBoundVar lv [] -> return $ Linear _ _ + --------------------- + -- Inductive cases -- + --------------------- + INeg NegRat v -> scaleExpr (-1) <$> go v + IAdd AddRat e1 e2 -> addExprs 1 1 <$> go e1 <*> go e2 + ISub SubRat e1 e2 -> addExprs 1 (-1) <$> go e1 <*> go e2 + IMul MulRat e1 e2 -> do + e1' <- go e1 + e2' <- go e2 + case (e1', e2') of + (Constant c1, Constant c2) -> return $ scaleExpr c1 e2' + (_, Just c2) -> return $ scaleExpr c2 e1' + _ -> throwError NonLinearity + IDiv DivRat e1 e2 -> do + e1' <- go e1 + e2' <- go e2 + case isConstant e2' of + (Just c2) -> return $ scaleExpr (1 / c2) e1' + _ -> throwError NonLinearity + ----------------- + -- Error cases -- + ----------------- + ex -> unexpectedExprError "compile linear rational expression" $ prettyVerbose ex -} diff --git a/vehicle/src/Vehicle/Backend/LossFunction/LogicCompilation.hs b/vehicle/src/Vehicle/Backend/LossFunction/LogicCompilation.hs index 73fdefd96..82bf83462 100644 --- a/vehicle/src/Vehicle/Backend/LossFunction/LogicCompilation.hs +++ b/vehicle/src/Vehicle/Backend/LossFunction/LogicCompilation.hs @@ -4,7 +4,6 @@ module Vehicle.Backend.LossFunction.LogicCompilation ( compileLogic, convertToLossBuiltins, normStandardExprToLoss, - normLossExprToLoss, MonadLogicCtx, runMonadLogicT, ) @@ -29,7 +28,7 @@ import Vehicle.Compile.Prelude import Vehicle.Compile.Print (prettyFriendly, prettyVerbose) import Vehicle.Data.Builtin.Loss import Vehicle.Data.Expr.Interface (pattern INot) -import Vehicle.Data.Expr.Normalised (VBinder, Value (..), WHNFBoundEnv, WHNFClosure (..), WHNFValue, boundContextToEnv, extendEnvWithBound, extendEnvWithDefined) +import Vehicle.Data.Expr.Value (VBinder, Value (..), WHNFBoundEnv, WHNFClosure (..), WHNFValue, boundContextToEnv, extendEnvWithBound, extendEnvWithDefined) import Vehicle.Libraries.StandardLibrary.Definitions (StdLibFunction (..)) import Vehicle.Syntax.Builtin (Builtin) import Vehicle.Syntax.Builtin qualified as V @@ -194,15 +193,7 @@ normStandardExprToLoss :: m MixedLossValue normStandardExprToLoss boundEnv expr = do standardValue <- normaliseInEnv boundEnv expr - result <- convertToLossBuiltins standardValue - return result - -normLossExprToLoss :: - (MonadLogic m) => - MixedBoundEnv -> - Expr Ix LossBuiltin -> - m MixedLossValue -normLossExprToLoss = eval mempty + convertToLossBuiltins standardValue convertToLossBuiltins :: forall m. @@ -273,7 +264,7 @@ convertBuiltin builtin spine = convert builtin V.LIndex x -> unchangedBuiltin (Index x) V.LNat x -> unchangedBuiltin (Nat x) V.LRat x -> unchangedBuiltin (Rat x) - V.LVec _x -> unchangedBuiltin Vector + V.LVec n -> unchangedBuiltin $ Vector n convertBuiltinFunction :: V.BuiltinFunction -> m MixedLossValue convertBuiltinFunction b = case b of diff --git a/vehicle/src/Vehicle/Backend/LossFunction/TensorCompilation.hs b/vehicle/src/Vehicle/Backend/LossFunction/TensorCompilation.hs index 2f9299af3..15a73c575 100644 --- a/vehicle/src/Vehicle/Backend/LossFunction/TensorCompilation.hs +++ b/vehicle/src/Vehicle/Backend/LossFunction/TensorCompilation.hs @@ -12,6 +12,7 @@ import Data.Maybe (fromMaybe) import Data.Proxy (Proxy (..)) import Data.Ratio import Vehicle.Backend.LossFunction.Core +-- import Vehicle.Backend.LossFunction.Domain (Domain (..), extractSearchDomain) import Vehicle.Backend.LossFunction.LogicCompilation import Vehicle.Backend.Prelude (DifferentiableLogicID) import Vehicle.Compile.Context.Bound @@ -23,7 +24,7 @@ import Vehicle.Data.Builtin.Loss (LossBuiltin) import Vehicle.Data.Builtin.Loss qualified as L import Vehicle.Data.Builtin.Tensor (TensorBuiltin) import Vehicle.Data.Builtin.Tensor qualified as T -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value import Vehicle.Data.Tensor import Vehicle.Libraries.StandardLibrary.Definitions (StdLibFunction (StdForeachIndex), pattern TensorIdent) import Vehicle.Prelude.Warning @@ -119,10 +120,8 @@ convertClosure lv binder closure = case closure of StandardClos (WHNFClosure env standardExpr) -> do let newEnv = extendEnvWithBound lv binder env convertExprToTensorValue newEnv standardExpr - LossClos (LossClosure env lossExpr) -> do - let newEnv = extendEnvWithBound lv binder env - normLossExpr <- switchToMonadLogic $ normLossExprToLoss newEnv lossExpr - convertLossToTensorValue normLossExpr + LossClos (LossClosure _env _lossExpr) -> do + compilerDeveloperError "Impossible" convertBuiltins :: (MonadTensor m) => LossBuiltin -> MixedLossSpine -> m (NFValue TensorBuiltin) convertBuiltins b args = do @@ -144,8 +143,8 @@ convertBuiltins b args = do L.Index i -> VBuiltin (T.Index i) <$> normArgs L.Bool v -> return $ T.VBoolTensor (Tensor [] [v]) L.Nat v -> VBuiltin (T.Nat v) <$> normArgs - L.Rat v -> return $ constRatTensor (T.convertRat v) - L.Vector -> convertVector =<< normArgs + L.Rat v -> return $ constRatTensor v + L.Vector {} -> convertVector =<< normArgs ---------------- -- Operations -- ---------------- @@ -167,17 +166,44 @@ convertBuiltins b args = do L.FoldList -> VBuiltin T.FoldList <$> normArgs L.MapList -> VBuiltin T.MapList <$> normArgs L.ForeachIndex -> convertForeachIndex =<< normArgs - ---------------------- - -- Other operations -- - ---------------------- - L.Search -> do - let op = T.SearchRatTensor - boundCtx <- getNamedBoundCtx (Proxy @MixedLossValue) - let namedCtx = fmap (fromMaybe "") boundCtx - VBuiltin (op namedCtx) <$> normArgs + L.Search -> convertSearch args where unsupportedTypeError op = compilerDeveloperError $ "Conversion of" <+> pretty op <+> "not yet supported" +convertSearch :: (MonadTensor m) => MixedLossSpine -> m (NFValue TensorBuiltin) +convertSearch args = do + boundCtx <- getNamedBoundCtx (Proxy @MixedLossValue) + let namedCtx = fmap (fromMaybe "") boundCtx + VBuiltin (T.SearchRatTensor namedCtx) <$> traverseArgs convertLossToTensorValue args + +{- +case args of + [unionOp, argExpr -> VLam binder (StandardClos closure)] -> do + -- Extract the context + boundCtx <- getNamedBoundCtx (Proxy @MixedLossValue) + let namedCtx = fmap (fromMaybe "") boundCtx + + -- Convert the union operation (for combining search results) and the binder. + tensorUnionOp <- traverse convertLossToTensorValue unionOp + tensorBinder <- traverse convertLossToTensorValue binder + + -- Extract the domain for the search + declProv <- getDeclProvenance + (Domain {..}, newBody) <- extractSearchDomain declProv binder (boundCtxLv boundCtx) closure + let tensorLowerBounds = explicit lowerBound + let tensorUpperBounds = explicit upperBound + + -- Convert the new body of the predicate + lossBody <- switchToMonadLogic $ convertToLossBuiltins newBody + tensorPredicate <- explicit . VLam tensorBinder . NFClosure <$> convertLossToTensorValue lossBody + + -- Extract the domain for the search + let newArgs = [tensorUnionOp, tensorLowerBounds, tensorUpperBounds, tensorPredicate] + + return $ VBuiltin (T.SearchRatTensor namedCtx) newArgs + _ -> unexpectedExprError currentPass (prettyVerbose $ VBuiltin L.Search args) +-} + convertVectorType :: (MonadTensor m) => [NFArg TensorBuiltin] -> m (NFValue TensorBuiltin) convertVectorType = \case [argExpr -> elemType, argExpr -> size] -> do @@ -219,9 +245,6 @@ convertVector args = case args of Just constantTensors -> mk $ stack constantTensors Nothing -> VBuiltin (T.StackRatTensor (length (a : as))) args -constRatTensor :: T.Rat -> NFValue TensorBuiltin -constRatTensor v = VBuiltin (T.ConstRatTensor v) [explicit (VBuiltin T.NilList [])] - extendConstRatTensor :: T.Rat -> NFValue TensorBuiltin -> NFArg TensorBuiltin -> NFValue TensorBuiltin extendConstRatTensor x dim dims = do let newDims = VBuiltin T.ConsList [explicit dim, dims] diff --git a/vehicle/src/Vehicle/Backend/Queries.hs b/vehicle/src/Vehicle/Backend/Queries.hs index 4942728cf..6a1b46322 100644 --- a/vehicle/src/Vehicle/Backend/Queries.hs +++ b/vehicle/src/Vehicle/Backend/Queries.hs @@ -24,7 +24,7 @@ import Vehicle.Compile.Print.Warning () import Vehicle.Data.Builtin.Standard import Vehicle.Data.Expr.Boolean import Vehicle.Data.Expr.Interface -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value import Vehicle.Prelude.Warning (CompileWarning (..)) import Vehicle.Verify.Core import Vehicle.Verify.QueryFormat diff --git a/vehicle/src/Vehicle/Backend/Queries/UserVariableElimination.hs b/vehicle/src/Vehicle/Backend/Queries/UserVariableElimination.hs index 50ab0278b..720154e6e 100644 --- a/vehicle/src/Vehicle/Backend/Queries/UserVariableElimination.hs +++ b/vehicle/src/Vehicle/Backend/Queries/UserVariableElimination.hs @@ -33,7 +33,7 @@ import Vehicle.Data.Builtin.Standard import Vehicle.Data.Expr.Boolean import Vehicle.Data.Expr.Interface import Vehicle.Data.Expr.Linear (LinearExpr) -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value import Vehicle.Data.QuantifiedVariable import Vehicle.Libraries.StandardLibrary.Definitions (StdLibFunction (StdEqualsVector, StdNotEqualsVector)) import Vehicle.Verify.Core (NetworkContextInfo (..), QuerySetNegationStatus) diff --git a/vehicle/src/Vehicle/Backend/Queries/UserVariableElimination/Core.hs b/vehicle/src/Vehicle/Backend/Queries/UserVariableElimination/Core.hs index d3867c996..143336b73 100644 --- a/vehicle/src/Vehicle/Backend/Queries/UserVariableElimination/Core.hs +++ b/vehicle/src/Vehicle/Backend/Queries/UserVariableElimination/Core.hs @@ -25,7 +25,7 @@ import Vehicle.Compile.Resource (NetworkType (..), dimensions) import Vehicle.Data.Expr.Boolean import Vehicle.Data.Expr.Interface import Vehicle.Data.Expr.Linear -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value import Vehicle.Data.Hashing () import Vehicle.Data.QuantifiedVariable import Vehicle.Data.Tensor diff --git a/vehicle/src/Vehicle/Compile/Arity.hs b/vehicle/src/Vehicle/Compile/Arity.hs index d5f91251b..57d5b829f 100644 --- a/vehicle/src/Vehicle/Compile/Arity.hs +++ b/vehicle/src/Vehicle/Compile/Arity.hs @@ -1,6 +1,6 @@ module Vehicle.Compile.Arity where -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value import Vehicle.Syntax.AST type Arity = Int diff --git a/vehicle/src/Vehicle/Compile/Boolean/LiftIf.hs b/vehicle/src/Vehicle/Compile/Boolean/LiftIf.hs index 359ad8bc0..ac68df7dc 100644 --- a/vehicle/src/Vehicle/Compile/Boolean/LiftIf.hs +++ b/vehicle/src/Vehicle/Compile/Boolean/LiftIf.hs @@ -11,7 +11,7 @@ import Vehicle.Compile.Normalise.NBE import Vehicle.Compile.Prelude import Vehicle.Data.Builtin.Standard import Vehicle.Data.Expr.Interface -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value -------------------------------------------------------------------------------- -- If lifting diff --git a/vehicle/src/Vehicle/Compile/Boolean/LowerNot.hs b/vehicle/src/Vehicle/Compile/Boolean/LowerNot.hs index 89141ffbc..6f1181265 100644 --- a/vehicle/src/Vehicle/Compile/Boolean/LowerNot.hs +++ b/vehicle/src/Vehicle/Compile/Boolean/LowerNot.hs @@ -9,7 +9,7 @@ import Vehicle.Compile.Prelude import Vehicle.Compile.Print (prettyVerbose) import Vehicle.Data.Builtin.Standard () import Vehicle.Data.Expr.Interface -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value import Vehicle.Libraries.StandardLibrary.Definitions (StdLibFunction (StdNotBoolOp2)) import Vehicle.Syntax.Builtin diff --git a/vehicle/src/Vehicle/Compile/Boolean/Unblock.hs b/vehicle/src/Vehicle/Compile/Boolean/Unblock.hs index b53f625ac..c25dfc4c8 100644 --- a/vehicle/src/Vehicle/Compile/Boolean/Unblock.hs +++ b/vehicle/src/Vehicle/Compile/Boolean/Unblock.hs @@ -15,7 +15,7 @@ import Vehicle.Compile.Prelude import Vehicle.Compile.Print (prettyFriendly, prettyVerbose) import Vehicle.Data.Builtin.Standard import Vehicle.Data.Expr.Interface -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value import Vehicle.Libraries.StandardLibrary.Definitions -------------------------------------------------------------------------------- diff --git a/vehicle/src/Vehicle/Compile/Context/Bound/Class.hs b/vehicle/src/Vehicle/Compile/Context/Bound/Class.hs index 9dabd2a53..019592894 100644 --- a/vehicle/src/Vehicle/Compile/Context/Bound/Class.hs +++ b/vehicle/src/Vehicle/Compile/Context/Bound/Class.hs @@ -8,7 +8,7 @@ import Vehicle.Compile.Context.Bound.Core import Vehicle.Compile.Error (MonadCompile, lookupIxInBoundCtx, lookupLvInBoundCtx) import Vehicle.Compile.Normalise.Quote qualified as Quote (unnormalise) import Vehicle.Compile.Prelude -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value -------------------------------------------------------------------------------- -- Context monad class diff --git a/vehicle/src/Vehicle/Compile/Context/Free.hs b/vehicle/src/Vehicle/Compile/Context/Free.hs index 480df609c..08d5c5a5f 100644 --- a/vehicle/src/Vehicle/Compile/Context/Free.hs +++ b/vehicle/src/Vehicle/Compile/Context/Free.hs @@ -13,7 +13,7 @@ import Vehicle.Compile.Context.Free.Instance as X import Vehicle.Compile.Normalise.Builtin (NormalisableBuiltin) import Vehicle.Compile.Normalise.NBE import Vehicle.Compile.Prelude -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value import Vehicle.Libraries.StandardLibrary.Definitions appHiddenStdlibDef :: diff --git a/vehicle/src/Vehicle/Compile/Context/Free/Class.hs b/vehicle/src/Vehicle/Compile/Context/Free/Class.hs index 7d236ea2b..f4fb64996 100644 --- a/vehicle/src/Vehicle/Compile/Context/Free/Class.hs +++ b/vehicle/src/Vehicle/Compile/Context/Free/Class.hs @@ -1,5 +1,6 @@ module Vehicle.Compile.Context.Free.Class where +import Control.Monad.Except (ExceptT, mapExceptT) import Control.Monad.Identity (IdentityT, mapIdentityT) import Control.Monad.Reader (ReaderT (..), mapReaderT) import Control.Monad.State (StateT (..), mapStateT) @@ -11,7 +12,7 @@ import Vehicle.Compile.Context.Free.Core import Vehicle.Compile.Error (MonadCompile, lookupInFreeCtx) import Vehicle.Compile.Prelude import Vehicle.Compile.Print (PrintableBuiltin) -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value import Vehicle.Libraries.StandardLibrary.Definitions -------------------------------------------------------------------------------- @@ -69,6 +70,12 @@ instance (MonadFreeContext builtin m) => MonadFreeContext builtin (SupplyT s m) getHiddenStdLibDecl p = lift . getHiddenStdLibDecl p hideStdLibDecls p = mapSupplyT . hideStdLibDecls p +instance (MonadFreeContext builtin m) => MonadFreeContext builtin (ExceptT s m) where + addDeclEntryToContext = mapExceptT . addDeclEntryToContext + getFreeCtx = lift . getFreeCtx + getHiddenStdLibDecl p = lift . getHiddenStdLibDecl p + hideStdLibDecls p = mapExceptT . hideStdLibDecls p + -------------------------------------------------------------------------------- -- Operations diff --git a/vehicle/src/Vehicle/Compile/Context/Free/Core.hs b/vehicle/src/Vehicle/Compile/Context/Free/Core.hs index 6d4861684..eb59af307 100644 --- a/vehicle/src/Vehicle/Compile/Context/Free/Core.hs +++ b/vehicle/src/Vehicle/Compile/Context/Free/Core.hs @@ -2,7 +2,7 @@ module Vehicle.Compile.Context.Free.Core where import Data.Map (Map) import Vehicle.Data.DeBruijn (Ix) -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value import Vehicle.Syntax.AST -- | Stores information associated with the declarations that are currently in diff --git a/vehicle/src/Vehicle/Compile/Descope.hs b/vehicle/src/Vehicle/Compile/Descope.hs index 89db96765..d54e7e66f 100644 --- a/vehicle/src/Vehicle/Compile/Descope.hs +++ b/vehicle/src/Vehicle/Compile/Descope.hs @@ -16,9 +16,9 @@ import Vehicle.Backend.LossFunction.Core (LossClosure (..), MixedClosure (..)) import Vehicle.Compile.Prelude import Vehicle.Data.Builtin.Interface (ConvertableBuiltin, convertBuiltin) import Vehicle.Data.Builtin.Standard.Core (Builtin) -import Vehicle.Data.Expr.Normalised import Vehicle.Data.Expr.Relevant (RelBinder, RelExpr) import Vehicle.Data.Expr.Relevant qualified as R +import Vehicle.Data.Expr.Value -------------------------------------------------------------------------------- -- Interface diff --git a/vehicle/src/Vehicle/Compile/Error.hs b/vehicle/src/Vehicle/Compile/Error.hs index 0c07765c1..32c58c62c 100644 --- a/vehicle/src/Vehicle/Compile/Error.hs +++ b/vehicle/src/Vehicle/Compile/Error.hs @@ -12,13 +12,13 @@ import Vehicle.Backend.LossFunction.Core (DifferentiableLogicField) import Vehicle.Backend.Prelude import Vehicle.Compile.Prelude import Vehicle.Compile.Type.Core -import Vehicle.Data.Builtin.Interface (HasStandardData, PrintableBuiltin) +import Vehicle.Data.Builtin.Interface (BuiltinHasStandardData, PrintableBuiltin) import Vehicle.Data.Builtin.Linearity.Core import Vehicle.Data.Builtin.Polarity.Core import Vehicle.Data.Builtin.Standard.Core import Vehicle.Data.Builtin.Tensor import Vehicle.Data.DeBruijn -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value import Vehicle.Data.QuantifiedVariable (UnderConstrainedVariableStatus, UserRationalVariable) import Vehicle.Syntax.Parse (ParseError, ParseLocation) import Vehicle.Verify.QueryFormat.Core @@ -46,7 +46,7 @@ data CompileError | -- Type checking errors UnresolvedHole Provenance Name | forall builtin. - (PrintableBuiltin builtin, Show builtin, HasStandardData builtin) => + (PrintableBuiltin builtin, Show builtin, BuiltinHasStandardData builtin) => TypingError (TypingError builtin) | UnsolvedMetas (NonEmpty (MetaID, Provenance)) | RelevantUseOfIrrelevantVariable Provenance Name diff --git a/vehicle/src/Vehicle/Compile/EtaConversion.hs b/vehicle/src/Vehicle/Compile/EtaConversion.hs index 5488fba86..9e81f4015 100644 --- a/vehicle/src/Vehicle/Compile/EtaConversion.hs +++ b/vehicle/src/Vehicle/Compile/EtaConversion.hs @@ -15,11 +15,11 @@ import Vehicle.Compile.Prelude import Vehicle.Compile.Print (prettyVerbose) import Vehicle.Data.Builtin.Interface import Vehicle.Data.DeBruijn (liftDBIndices) -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value etaExpandProg :: forall m builtin. - (MonadCompile m, PrintableBuiltin builtin, HasStandardData builtin) => + (MonadCompile m, PrintableBuiltin builtin, BuiltinHasStandardData builtin) => Prog Ix builtin -> m (Prog Ix builtin) etaExpandProg (Main ds) = diff --git a/vehicle/src/Vehicle/Compile/ExpandResources.hs b/vehicle/src/Vehicle/Compile/ExpandResources.hs index 56d9f35f4..43c178c7d 100644 --- a/vehicle/src/Vehicle/Compile/ExpandResources.hs +++ b/vehicle/src/Vehicle/Compile/ExpandResources.hs @@ -25,7 +25,7 @@ import Vehicle.Compile.Prelude import Vehicle.Compile.Print.Warning () import Vehicle.Data.Builtin.Standard.Core import Vehicle.Data.Expr.Interface -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value import Vehicle.Prelude.Warning (CompileWarning (..)) -- | Calculates the context for external resources, reading them from disk and diff --git a/vehicle/src/Vehicle/Compile/ExpandResources/Core.hs b/vehicle/src/Vehicle/Compile/ExpandResources/Core.hs index e2c6c3dd6..1c01aed5e 100644 --- a/vehicle/src/Vehicle/Compile/ExpandResources/Core.hs +++ b/vehicle/src/Vehicle/Compile/ExpandResources/Core.hs @@ -8,7 +8,7 @@ import Data.Map qualified as Map import Vehicle.Compile.Context.Free (MonadFreeContext) import Vehicle.Compile.Error import Vehicle.Compile.Prelude -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value import Vehicle.Syntax.Builtin (Builtin) import Vehicle.Verify.Core diff --git a/vehicle/src/Vehicle/Compile/ExpandResources/Dataset.hs b/vehicle/src/Vehicle/Compile/ExpandResources/Dataset.hs index e95c681fc..85e1209e7 100644 --- a/vehicle/src/Vehicle/Compile/ExpandResources/Dataset.hs +++ b/vehicle/src/Vehicle/Compile/ExpandResources/Dataset.hs @@ -12,7 +12,7 @@ import Vehicle.Compile.ExpandResources.Core import Vehicle.Compile.ExpandResources.Dataset.IDX (readIDX) import Vehicle.Compile.Prelude import Vehicle.Data.Builtin.Standard -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value -------------------------------------------------------------------------------- -- Dataset parsing diff --git a/vehicle/src/Vehicle/Compile/ExpandResources/Dataset/IDX.hs b/vehicle/src/Vehicle/Compile/ExpandResources/Dataset/IDX.hs index 77b8a2f54..6512256a7 100644 --- a/vehicle/src/Vehicle/Compile/ExpandResources/Dataset/IDX.hs +++ b/vehicle/src/Vehicle/Compile/ExpandResources/Dataset/IDX.hs @@ -23,7 +23,7 @@ import Vehicle.Compile.Prelude import Vehicle.Compile.Print import Vehicle.Data.Builtin.Standard import Vehicle.Data.Expr.Interface -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value -- | Reads the IDX dataset from the provided file, checking that the user type -- matches the type of the stored data. diff --git a/vehicle/src/Vehicle/Compile/ExpandResources/Network.hs b/vehicle/src/Vehicle/Compile/ExpandResources/Network.hs index 05b886a64..a4c0e23cc 100644 --- a/vehicle/src/Vehicle/Compile/ExpandResources/Network.hs +++ b/vehicle/src/Vehicle/Compile/ExpandResources/Network.hs @@ -12,7 +12,7 @@ import Vehicle.Compile.Print import Vehicle.Compile.Resource import Vehicle.Data.Builtin.Standard import Vehicle.Data.Expr.Interface -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value import Vehicle.Verify.Core (NetworkContextInfo (..)) -------------------------------------------------------------------------------- diff --git a/vehicle/src/Vehicle/Compile/ExpandResources/Parameter.hs b/vehicle/src/Vehicle/Compile/ExpandResources/Parameter.hs index 5795f84d4..c023bc385 100644 --- a/vehicle/src/Vehicle/Compile/ExpandResources/Parameter.hs +++ b/vehicle/src/Vehicle/Compile/ExpandResources/Parameter.hs @@ -14,7 +14,7 @@ import Vehicle.Compile.Prelude import Vehicle.Compile.Print import Vehicle.Data.Builtin.Standard import Vehicle.Data.Expr.Interface -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value -------------------------------------------------------------------------------- -- Parameter parsing diff --git a/vehicle/src/Vehicle/Compile/Monomorphisation.hs b/vehicle/src/Vehicle/Compile/Monomorphisation.hs index def032e6d..8431e79a8 100644 --- a/vehicle/src/Vehicle/Compile/Monomorphisation.hs +++ b/vehicle/src/Vehicle/Compile/Monomorphisation.hs @@ -54,7 +54,7 @@ import Vehicle.Libraries.StandardLibrary.Definitions -- (e.g. naturals, rationals and tensors) monomorphise :: forall m builtin. - (MonadCompile m, Hashable builtin, PrintableBuiltin builtin, HasStandardData builtin) => + (MonadCompile m, Hashable builtin, PrintableBuiltin builtin, BuiltinHasStandardData builtin) => (Decl Ix builtin -> Bool) -> Text -> Prog Ix builtin -> diff --git a/vehicle/src/Vehicle/Compile/Normalise/Builtin.hs b/vehicle/src/Vehicle/Compile/Normalise/Builtin.hs index 7b735e780..e210ace5d 100644 --- a/vehicle/src/Vehicle/Compile/Normalise/Builtin.hs +++ b/vehicle/src/Vehicle/Compile/Normalise/Builtin.hs @@ -9,9 +9,9 @@ import Data.Foldable (foldrM) import Vehicle.Compile.Error (MonadCompile) import Vehicle.Compile.Prelude import Vehicle.Compile.Print (PrintableBuiltin) -import Vehicle.Data.Builtin.Interface (HasStandardData (..)) +import Vehicle.Data.Builtin.Interface (BuiltinHasStandardData (..)) import Vehicle.Data.Expr.Interface -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value import Vehicle.Syntax.Builtin -- Okay so the important thing to remember about this module is that we have @@ -65,7 +65,7 @@ class (PrintableBuiltin builtin) => NormalisableBuiltin builtin where m (Value closure builtin) evalTypeClassOp :: - (MonadLogger m, HasStandardData builtin) => + (MonadLogger m, BuiltinHasStandardData builtin, Show builtin) => (BuiltinFunction -> EvalBuiltin (Value closure builtin) m) -> EvalApp (Value closure builtin) m -> Value closure builtin -> @@ -149,7 +149,7 @@ evalIf originalExpr = \case _ -> originalExpr -- TODO define in terms of language. The problem is the polarity checking... -evalImplies :: (HasStandardDataExpr expr, HasBoolLits expr) => EvalSimpleBuiltin expr +evalImplies :: (HasStandardData expr, HasBoolLits expr) => EvalSimpleBuiltin expr evalImplies originalExpr = \case [e1, e2] -> do let ne1 = evalNot (INot (argExpr e1)) [e1] diff --git a/vehicle/src/Vehicle/Compile/Normalise/NBE.hs b/vehicle/src/Vehicle/Compile/Normalise/NBE.hs index 658fcb3e9..0981b352f 100644 --- a/vehicle/src/Vehicle/Compile/Normalise/NBE.hs +++ b/vehicle/src/Vehicle/Compile/Normalise/NBE.hs @@ -22,7 +22,7 @@ import Vehicle.Compile.Normalise.Builtin (NormalisableBuiltin (..), filterOutIrr import Vehicle.Compile.Prelude import Vehicle.Compile.Print import Vehicle.Data.Builtin.Standard.Core (Builtin) -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value -- import Control.Monad (when) @@ -213,27 +213,42 @@ currentPass :: Doc () currentPass = "normalisation by evaluation" showEntry :: (MonadNorm closure builtin m) => BoundEnv closure builtin -> Expr Ix builtin -> m () -showEntry _boundEnv _expr = do +showEntry _ _ = return () + +showExit :: (MonadNorm closure builtin m) => BoundEnv closure builtin -> Value closure builtin -> m () +showExit _ _ = return () + +{- +showEntry :: (MonadNorm closure builtin m) => BoundEnv closure builtin -> Expr Ix builtin -> m () +showEntry _boundEnv expr = do -- logDebug MidDetail $ "nbe-entry" <+> prettyFriendly (WithContext expr (fmap fst boundEnv)) <+> " { boundEnv=" <+> hang 0 (prettyVerbose boundEnv) <+> "}" - -- logDebug MidDetail $ "nbe-entry" <+> prettyVerbose expr -- <+> " { boundEnv=" <+> prettyVerbose boundEnv <+> "}" - -- incrCallDepth + logDebug MidDetail $ "nbe-entry" <+> prettyVerbose expr -- <+> " { boundEnv=" <+> prettyVerbose boundEnv <+> "}" + incrCallDepth return () showExit :: (MonadNorm closure builtin m) => BoundEnv closure builtin -> Value closure builtin -> m () -showExit _boundEnv _result = do - -- decrCallDepth - -- logDebug MidDetail $ "nbe-exit" <+> prettyVerbose result +showExit _boundEnv result = do + decrCallDepth + logDebug MidDetail $ "nbe-exit" <+> prettyVerbose result -- logDebug MidDetail $ "nbe-exit" <+> prettyFriendly (WithContext result (fmap fst boundEnv)) return () +-} +showApp :: (MonadNorm closure builtin m) => Value closure builtin -> Spine closure builtin -> m () +showApp _ _ = return () + +showAppExit :: (MonadNorm closure builtin m) => Value closure builtin -> m () +showAppExit _ = return () +{- showApp :: (MonadNorm closure builtin m) => Value closure builtin -> Spine closure builtin -> m () -showApp _fun _spine = do - -- logDebug MaxDetail $ "nbe-app:" <+> prettyVerbose fun <+> "@" <+> prettyVerbose spine - -- incrCallDepth +showApp fun spine = do + logDebug MaxDetail $ "nbe-app:" <+> prettyVerbose fun <+> "@" <+> prettyVerbose spine + incrCallDepth return () showAppExit :: (MonadNorm closure builtin m) => Value closure builtin -> m () -showAppExit _result = do - -- decrCallDepth - -- logDebug MaxDetail $ "nbe-app-exit:" <+> prettyVerbose result +showAppExit result = do + decrCallDepth + logDebug MaxDetail $ "nbe-app-exit:" <+> prettyVerbose result return () +-} diff --git a/vehicle/src/Vehicle/Compile/Normalise/Quote.hs b/vehicle/src/Vehicle/Compile/Normalise/Quote.hs index 0b48bc804..74867d1ac 100644 --- a/vehicle/src/Vehicle/Compile/Normalise/Quote.hs +++ b/vehicle/src/Vehicle/Compile/Normalise/Quote.hs @@ -5,7 +5,7 @@ import Vehicle.Compile.Prelude import Vehicle.Data.Builtin.Interface (ConvertableBuiltin (..)) import Vehicle.Data.Builtin.Standard.Core (Builtin) import Vehicle.Data.DeBruijn -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value -- | Converts from a normalised representation to an unnormalised representation. -- Do not call except for logging and debug purposes, very expensive with nested diff --git a/vehicle/src/Vehicle/Compile/Print.hs b/vehicle/src/Vehicle/Compile/Print.hs index d457d6e50..797819f99 100644 --- a/vehicle/src/Vehicle/Compile/Print.hs +++ b/vehicle/src/Vehicle/Compile/Print.hs @@ -34,7 +34,7 @@ import Vehicle.Compile.Type.Meta.Map (MetaMap (..)) import Vehicle.Data.Builtin.Interface import Vehicle.Data.Builtin.Standard.Core import Vehicle.Data.Expr.Boolean -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value import Vehicle.Syntax.Print -------------------------------------------------------------------------------- diff --git a/vehicle/src/Vehicle/Compile/Print/Error.hs b/vehicle/src/Vehicle/Compile/Print/Error.hs index 7fe86916d..437f23ba8 100644 --- a/vehicle/src/Vehicle/Compile/Print/Error.hs +++ b/vehicle/src/Vehicle/Compile/Print/Error.hs @@ -23,7 +23,7 @@ import Vehicle.Data.Builtin.Standard.Core import Vehicle.Data.DSL import Vehicle.Data.DeBruijn (substDBInto) import Vehicle.Data.Expr.Interface -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value import Vehicle.Data.QuantifiedVariable (prettyUnderConstrainedVariables) import Vehicle.Libraries.StandardLibrary.Definitions (pattern TensorIdent) import Vehicle.Syntax.Parse (ParseError (..)) @@ -1234,7 +1234,7 @@ datasetDimensionsFix feature ident file = <+> "is in the format you were expecting." unsupportedAnnotationTypeDescription :: - (PrintableBuiltin builtin) => + (Eq builtin, PrintableBuiltin builtin) => Doc a -> Identifier -> GluedType builtin -> diff --git a/vehicle/src/Vehicle/Compile/Rational/LinearExpr.hs b/vehicle/src/Vehicle/Compile/Rational/LinearExpr.hs index f6a0a216b..46d2bd429 100644 --- a/vehicle/src/Vehicle/Compile/Rational/LinearExpr.hs +++ b/vehicle/src/Vehicle/Compile/Rational/LinearExpr.hs @@ -16,7 +16,7 @@ import Vehicle.Compile.Print (prettyVerbose) import Vehicle.Data.Builtin.Standard import Vehicle.Data.Expr.Interface import Vehicle.Data.Expr.Linear (LinearExpr, addExprs, constantExpr, isConstant, scaleExpr, singletonVarExpr) -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value import Vehicle.Data.QuantifiedVariable import Vehicle.Data.Tensor (RationalTensor, Tensor (..), zeroTensor) import Prelude hiding (Applicative (..)) @@ -28,6 +28,9 @@ type MonadCompileLinearExpr m = data NonLinearity = NonLinearity +-------------------------------------------------------------------------------- +-- Rational expression + compileRatLinearRelation :: (MonadLogger m) => (Lv -> ExceptT NonLinearity m RationalVariable) -> diff --git a/vehicle/src/Vehicle/Compile/Type.hs b/vehicle/src/Vehicle/Compile/Type.hs index 397189927..bd3dcba26 100644 --- a/vehicle/src/Vehicle/Compile/Type.hs +++ b/vehicle/src/Vehicle/Compile/Type.hs @@ -25,7 +25,7 @@ import Vehicle.Compile.Type.Meta.Set qualified as MetaSet import Vehicle.Compile.Type.Monad import Vehicle.Compile.Type.Monad.Class import Vehicle.Data.Builtin.Standard.Core -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value ------------------------------------------------------------------------------- -- Algorithm diff --git a/vehicle/src/Vehicle/Compile/Type/Bidirectional.hs b/vehicle/src/Vehicle/Compile/Type/Bidirectional.hs index 51c80c398..109f215ae 100644 --- a/vehicle/src/Vehicle/Compile/Type/Bidirectional.hs +++ b/vehicle/src/Vehicle/Compile/Type/Bidirectional.hs @@ -19,7 +19,7 @@ import Vehicle.Compile.Type.Core import Vehicle.Compile.Type.Monad import Vehicle.Data.Builtin.Interface (TypableBuiltin (..)) import Vehicle.Data.DeBruijn -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value import Prelude hiding (pi) -------------------------------------------------------------------------------- diff --git a/vehicle/src/Vehicle/Compile/Type/Constraint/Core.hs b/vehicle/src/Vehicle/Compile/Type/Constraint/Core.hs index 424773853..5ccca10b2 100644 --- a/vehicle/src/Vehicle/Compile/Type/Constraint/Core.hs +++ b/vehicle/src/Vehicle/Compile/Type/Constraint/Core.hs @@ -24,7 +24,7 @@ import Vehicle.Compile.Type.Meta (MetaSet) import Vehicle.Compile.Type.Meta.Set qualified as MetaSet import Vehicle.Compile.Type.Monad (MonadTypeChecker, TCM, copyContext, freshMetaIdAndExpr, trackSolvedMetas) import Vehicle.Data.DSL -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value -- | Attempts to solve as many constraints as possible. Takes in -- the set of meta-variables solved since the solver was last run and outputs diff --git a/vehicle/src/Vehicle/Compile/Type/Constraint/IndexSolver.hs b/vehicle/src/Vehicle/Compile/Type/Constraint/IndexSolver.hs index 306633093..1dd70946f 100644 --- a/vehicle/src/Vehicle/Compile/Type/Constraint/IndexSolver.hs +++ b/vehicle/src/Vehicle/Compile/Type/Constraint/IndexSolver.hs @@ -11,7 +11,7 @@ import Vehicle.Compile.Type.Meta (MetaSet) import Vehicle.Compile.Type.Meta.Set qualified as MetaSet import Vehicle.Compile.Type.Monad import Vehicle.Data.Expr.Interface -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value import Vehicle.Syntax.Builtin solveIndexConstraint :: diff --git a/vehicle/src/Vehicle/Compile/Type/Constraint/InstanceDefaultSolver.hs b/vehicle/src/Vehicle/Compile/Type/Constraint/InstanceDefaultSolver.hs index 665e245ff..0324d98ff 100644 --- a/vehicle/src/Vehicle/Compile/Type/Constraint/InstanceDefaultSolver.hs +++ b/vehicle/src/Vehicle/Compile/Type/Constraint/InstanceDefaultSolver.hs @@ -16,7 +16,7 @@ import Vehicle.Compile.Type.Meta.Set qualified as MetaSet import Vehicle.Compile.Type.Meta.Variable import Vehicle.Compile.Type.Monad import Vehicle.Data.Builtin.Interface -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value class HasInstanceDefaults builtin where getCandidatesFromConstraint :: diff --git a/vehicle/src/Vehicle/Compile/Type/Constraint/InstanceSolver.hs b/vehicle/src/Vehicle/Compile/Type/Constraint/InstanceSolver.hs index 259a51efd..bae16580e 100644 --- a/vehicle/src/Vehicle/Compile/Type/Constraint/InstanceSolver.hs +++ b/vehicle/src/Vehicle/Compile/Type/Constraint/InstanceSolver.hs @@ -19,7 +19,7 @@ import Vehicle.Compile.Type.Constraint.Core import Vehicle.Compile.Type.Core import Vehicle.Compile.Type.Monad import Vehicle.Data.DeBruijn (dbLevelToIndex, substDBInto) -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value -------------------------------------------------------------------------------- -- Public interface diff --git a/vehicle/src/Vehicle/Compile/Type/Constraint/UnificationSolver.hs b/vehicle/src/Vehicle/Compile/Type/Constraint/UnificationSolver.hs index 674af33c5..b6a28db8b 100644 --- a/vehicle/src/Vehicle/Compile/Type/Constraint/UnificationSolver.hs +++ b/vehicle/src/Vehicle/Compile/Type/Constraint/UnificationSolver.hs @@ -27,7 +27,7 @@ import Vehicle.Compile.Type.Meta.Set qualified as MetaSet (null, singleton) import Vehicle.Compile.Type.Monad import Vehicle.Compile.Type.Monad.Class import Vehicle.Data.DeBruijn -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value -------------------------------------------------------------------------------- -- Unification solver diff --git a/vehicle/src/Vehicle/Compile/Type/Core.hs b/vehicle/src/Vehicle/Compile/Type/Core.hs index 5dca1b187..bb3a8d82e 100644 --- a/vehicle/src/Vehicle/Compile/Type/Core.hs +++ b/vehicle/src/Vehicle/Compile/Type/Core.hs @@ -9,7 +9,7 @@ import Vehicle.Compile.Prelude import Vehicle.Compile.Type.Meta.Map (MetaMap (..)) import Vehicle.Compile.Type.Meta.Set (MetaSet) import Vehicle.Compile.Type.Meta.Set qualified as MetaSet -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value -------------------------------------------------------------------------------- -- Typing errors diff --git a/vehicle/src/Vehicle/Compile/Type/Force.hs b/vehicle/src/Vehicle/Compile/Type/Force.hs index bd2db3c22..f36eeda46 100644 --- a/vehicle/src/Vehicle/Compile/Type/Force.hs +++ b/vehicle/src/Vehicle/Compile/Type/Force.hs @@ -13,14 +13,14 @@ import Vehicle.Compile.Type.Core import Vehicle.Compile.Type.Meta (MetaSet) import Vehicle.Compile.Type.Meta.Map qualified as MetaMap (lookup) import Vehicle.Compile.Type.Meta.Set qualified as MetaSet (singleton, unions) -import Vehicle.Data.Builtin.Interface (HasStandardData (getBuiltinFunction)) -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Builtin.Interface (BuiltinHasStandardData (getBuiltinFunction)) +import Vehicle.Data.Expr.Value ----------------------------------------------------------------------------- -- Meta-variable forcing type ForcableBuiltin builtin = - (HasStandardData builtin, NormalisableBuiltin builtin) + (BuiltinHasStandardData builtin, NormalisableBuiltin builtin) -- | Recursively forces the evaluation of any meta-variables at the head -- of the expresson. diff --git a/vehicle/src/Vehicle/Compile/Type/Irrelevance.hs b/vehicle/src/Vehicle/Compile/Type/Irrelevance.hs index f1eabe623..47e5bdf9c 100644 --- a/vehicle/src/Vehicle/Compile/Type/Irrelevance.hs +++ b/vehicle/src/Vehicle/Compile/Type/Irrelevance.hs @@ -16,7 +16,7 @@ import Vehicle.Data.Expr.Interface -- | Removes all irrelevant code from the program/expression. removeIrrelevantCodeFromProg :: - (MonadCompile m, HasStandardData builtin, PrintableBuiltin builtin) => + (MonadCompile m, BuiltinHasStandardData builtin, PrintableBuiltin builtin) => Prog Ix builtin -> m (Prog Ix builtin) removeIrrelevantCodeFromProg x = do @@ -44,7 +44,7 @@ instance (RemoveIrrelevantCode m expr) => RemoveIrrelevantCode m (GenericProg ex instance (RemoveIrrelevantCode m expr) => RemoveIrrelevantCode m (GenericDecl expr) where remove = traverse remove -instance (HasStandardData builtin) => RemoveIrrelevantCode m (Expr Ix builtin) where +instance (BuiltinHasStandardData builtin) => RemoveIrrelevantCode m (Expr Ix builtin) where remove expr = do -- showRemoveEntry expr result <- case expr of diff --git a/vehicle/src/Vehicle/Compile/Type/Meta/Substitution.hs b/vehicle/src/Vehicle/Compile/Type/Meta/Substitution.hs index 2f0b14bc7..f10b2cf08 100644 --- a/vehicle/src/Vehicle/Compile/Type/Meta/Substitution.hs +++ b/vehicle/src/Vehicle/Compile/Type/Meta/Substitution.hs @@ -13,7 +13,7 @@ import Vehicle.Compile.Type.Core import Vehicle.Compile.Type.Meta.Map (MetaMap (..)) import Vehicle.Compile.Type.Meta.Map qualified as MetaMap import Vehicle.Compile.Type.Meta.Variable (MetaInfo (..)) -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value -------------------------------------------------------------------------------- -- Substitution operation diff --git a/vehicle/src/Vehicle/Compile/Type/Meta/Variable.hs b/vehicle/src/Vehicle/Compile/Type/Meta/Variable.hs index 7fcfad3c8..1445089a6 100644 --- a/vehicle/src/Vehicle/Compile/Type/Meta/Variable.hs +++ b/vehicle/src/Vehicle/Compile/Type/Meta/Variable.hs @@ -19,7 +19,7 @@ import Vehicle.Compile.Prelude import Vehicle.Compile.Type.Core import Vehicle.Compile.Type.Meta.Set (MetaSet) import Vehicle.Compile.Type.Meta.Set qualified as MetaSet -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value -- Eventually when metas make into the builtins, this should module -- should also contain the definition of meta-variables themselves. diff --git a/vehicle/src/Vehicle/Compile/Type/Monad.hs b/vehicle/src/Vehicle/Compile/Type/Monad.hs index d2e66f360..647a7fa81 100644 --- a/vehicle/src/Vehicle/Compile/Type/Monad.hs +++ b/vehicle/src/Vehicle/Compile/Type/Monad.hs @@ -50,7 +50,7 @@ import Vehicle.Compile.Prelude import Vehicle.Compile.Type.Core import Vehicle.Compile.Type.Monad.Class import Vehicle.Compile.Type.Monad.Instance -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value -- | The type-checking monad. type TCM builtin m = diff --git a/vehicle/src/Vehicle/Compile/Type/Monad/Class.hs b/vehicle/src/Vehicle/Compile/Type/Monad/Class.hs index 43834f331..c6bb735a9 100644 --- a/vehicle/src/Vehicle/Compile/Type/Monad/Class.hs +++ b/vehicle/src/Vehicle/Compile/Type/Monad/Class.hs @@ -27,7 +27,7 @@ import Vehicle.Compile.Type.Meta.Set qualified as MetaSet import Vehicle.Compile.Type.Meta.Substitution as MetaSubstitution (MetaSubstitutable (..)) import Vehicle.Data.Builtin.Interface import Vehicle.Data.Builtin.Standard.Core -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value -------------------------------------------------------------------------------- -- Solved meta-state @@ -132,7 +132,7 @@ instance (MonadTypeChecker builtin m) => MonadTypeChecker builtin (BoundContextT -- Abstract interface for a type system. -- | A class that provides an abstract interface for a set of builtins. -class (HasStandardData builtin, TypableBuiltin builtin) => HasTypeSystem builtin where +class (Eq builtin, BuiltinHasStandardData builtin, TypableBuiltin builtin) => HasTypeSystem builtin where convertFromStandardBuiltins :: (MonadTypeChecker builtin m) => BuiltinUpdate m Ix Builtin builtin diff --git a/vehicle/src/Vehicle/Compile/Type/Monad/Instance.hs b/vehicle/src/Vehicle/Compile/Type/Monad/Instance.hs index dd4b8bc27..c013cd8e2 100644 --- a/vehicle/src/Vehicle/Compile/Type/Monad/Instance.hs +++ b/vehicle/src/Vehicle/Compile/Type/Monad/Instance.hs @@ -24,7 +24,7 @@ import Vehicle.Compile.Prelude import Vehicle.Compile.Print import Vehicle.Compile.Type.Core import Vehicle.Compile.Type.Monad.Class -import Vehicle.Data.Builtin.Interface (HasStandardData) +import Vehicle.Data.Builtin.Interface (BuiltinHasStandardData) -------------------------------------------------------------------------------- -- Implementation @@ -74,13 +74,13 @@ mapTypeCheckerT f m = TypeCheckerT (mapFreeContextT (mapReaderT (mapStateT f)) ( -------------------------------------------------------------------------------- -- Instances that TypeCheckerT satisfies -instance (PrintableBuiltin builtin, HasStandardData builtin, MonadCompile m) => MonadFreeContext builtin (TypeCheckerT builtin m) where +instance (PrintableBuiltin builtin, BuiltinHasStandardData builtin, MonadCompile m) => MonadFreeContext builtin (TypeCheckerT builtin m) where addDeclEntryToContext entry = TypeCheckerT . addDeclEntryToContext entry . unTypeCheckerT getFreeCtx = TypeCheckerT . getFreeCtx hideStdLibDecls p f = TypeCheckerT . hideStdLibDecls p f . unTypeCheckerT getHiddenStdLibDecl p = TypeCheckerT . getHiddenStdLibDecl p -instance (PrintableBuiltin builtin, HasStandardData builtin, NormalisableBuiltin builtin, MonadCompile m) => MonadTypeChecker builtin (TypeCheckerT builtin m) where +instance (PrintableBuiltin builtin, BuiltinHasStandardData builtin, NormalisableBuiltin builtin, MonadCompile m) => MonadTypeChecker builtin (TypeCheckerT builtin m) where getMetaState = TypeCheckerT get modifyMetaCtx f = TypeCheckerT $ modify f getFreshName typ = TypeCheckerT $ getFreshNameInternal typ diff --git a/vehicle/src/Vehicle/Compile/Type/Subsystem.hs b/vehicle/src/Vehicle/Compile/Type/Subsystem.hs index 4ffe95d16..6c93d6370 100644 --- a/vehicle/src/Vehicle/Compile/Type/Subsystem.hs +++ b/vehicle/src/Vehicle/Compile/Type/Subsystem.hs @@ -40,7 +40,7 @@ typeCheckWithSubsystem instanceCandidates errorHandler prog = do resolveInstanceArguments :: forall m builtin. - (MonadCompile m, HasStandardData builtin, Show builtin) => + (MonadCompile m, BuiltinHasStandardData builtin, Show builtin) => Prog Ix builtin -> m (Prog Ix builtin) resolveInstanceArguments prog = diff --git a/vehicle/src/Vehicle/Compile/Type/Subsystem/InputOutputInsertion.hs b/vehicle/src/Vehicle/Compile/Type/Subsystem/InputOutputInsertion.hs index fc419771a..5f7d3f33d 100644 --- a/vehicle/src/Vehicle/Compile/Type/Subsystem/InputOutputInsertion.hs +++ b/vehicle/src/Vehicle/Compile/Type/Subsystem/InputOutputInsertion.hs @@ -9,7 +9,7 @@ import Vehicle.Compile.Type.Meta.Map (MetaMap (..)) import Vehicle.Compile.Type.Meta.Map qualified as MetaMap import Vehicle.Compile.Type.Monad (TCM, createFreshInstanceConstraint, freshMetaExpr) import Vehicle.Data.Expr.Interface -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value import Vehicle.Syntax.Builtin ------------------------------------------------------------------------------- diff --git a/vehicle/src/Vehicle/Compile/Variable.hs b/vehicle/src/Vehicle/Compile/Variable.hs index ba7f60021..c86db1c93 100644 --- a/vehicle/src/Vehicle/Compile/Variable.hs +++ b/vehicle/src/Vehicle/Compile/Variable.hs @@ -10,7 +10,7 @@ import Vehicle.Compile.Error import Vehicle.Compile.Prelude import Vehicle.Data.Builtin.Standard import Vehicle.Data.Expr.Interface -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value import Vehicle.Data.QuantifiedVariable import Prelude hiding (Applicative (..)) diff --git a/vehicle/src/Vehicle/Data/Builtin/Interface.hs b/vehicle/src/Vehicle/Data/Builtin/Interface.hs index 4784d2b02..7dde627a7 100644 --- a/vehicle/src/Vehicle/Data/Builtin/Interface.hs +++ b/vehicle/src/Vehicle/Data/Builtin/Interface.hs @@ -19,11 +19,97 @@ import Prelude hiding (pi) -- (e.g. normalisation) once, rather than once for each builtin type. -------------------------------------------------------------------------------- --- HasStandardData +-- Converting builtins + +class ConvertableBuiltin builtin1 builtin2 where + convertBuiltin :: + Provenance -> + builtin1 -> + Expr var builtin2 + +instance ConvertableBuiltin builtin builtin where + convertBuiltin = Builtin + +-------------------------------------------------------------------------------- +-- Printing builtins + +class (Show builtin, ConvertableBuiltin builtin Builtin) => PrintableBuiltin builtin where + -- | Convert expressions with the builtin back to expressions with the standard + -- builtin type. Used for printing. + isCoercion :: builtin -> Bool + +-------------------------------------------------------------------------------- +-- Typable builtin + +class (PrintableBuiltin builtin) => TypableBuiltin builtin where + -- | Construct a type for the builtin + typeBuiltin :: Provenance -> builtin -> Type Ix builtin + +-------------------------------------------------------------------------------- +-- Interface to content of standard builtins +-------------------------------------------------------------------------------- +-- In these classes we need to separate out the types from the literals, as +-- various sets of builtins may have the literals but not the types (e.g. +-- `LinearityBuiltin`) +-------------------------------------------------------------------------------- +-- HasBool + +class BuiltinHasBoolLiterals builtin where + mkBoolBuiltinLit :: Bool -> builtin + getBoolBuiltinLit :: builtin -> Maybe Bool + +-------------------------------------------------------------------------------- +-- HasIndex + +class BuiltinHasIndexLiterals builtin where + mkIndexBuiltinLit :: Int -> builtin + getIndexBuiltinLit :: builtin -> Maybe Int + +-------------------------------------------------------------------------------- +-- HasNat + +class BuiltinHasNatLiterals builtin where + mkNatBuiltinLit :: Int -> builtin + getNatBuiltinLit :: builtin -> Maybe Int + +-------------------------------------------------------------------------------- +-- HasRat + +class BuiltinHasRatLiterals builtin where + mkRatBuiltinLit :: Rational -> builtin + getRatBuiltinLit :: builtin -> Maybe Rational + +class (BuiltinHasRatLiterals builtin) => HasRatTypeBuiltin builtin where + mkRatBuiltinType :: builtin + isRatBuiltinType :: builtin -> Bool + +-------------------------------------------------------------------------------- +-- HasList + +class BuiltinHasListLiterals builtin where + mkBuiltinNil :: builtin + isBuiltinNil :: builtin -> Bool + + mkBuiltinCons :: builtin + isBuiltinCons :: builtin -> Bool + +-------------------------------------------------------------------------------- +-- HasVector + +class BuiltinHasVecLiterals builtin where + mkVecBuiltinLit :: Int -> builtin + getVecBuiltinLit :: builtin -> Maybe Int + +class (BuiltinHasVecLiterals builtin) => HasVecTypeBuiltin builtin where + mkVecBuiltinType :: builtin + isVecBuiltinType :: builtin -> Bool + +-------------------------------------------------------------------------------- +-- BuiltinHasStandardData -- | Indicates that this set of builtins has the standard builtin constructors -- and functions. -class (Show builtin) => HasStandardData builtin where +class BuiltinHasStandardData builtin where mkBuiltinConstructor :: BuiltinConstructor -> builtin getBuiltinConstructor :: builtin -> Maybe BuiltinConstructor @@ -37,95 +123,30 @@ class (Show builtin) => HasStandardData builtin where Just {} -> True Nothing -> False -instance HasStandardData Builtin where - mkBuiltinFunction = BuiltinFunction - getBuiltinFunction = \case - BuiltinFunction c -> Just c - _ -> Nothing - - mkBuiltinConstructor = BuiltinConstructor - getBuiltinConstructor = \case - BuiltinConstructor c -> Just c - _ -> Nothing - - getBuiltinTypeClassOp = \case - TypeClassOp op -> Just op - _ -> Nothing - -------------------------------------------------------------------------------- --- HasStandardTypes +-- BuiltinHasStandardTypes -- | Indicates that this set of builtins has the standard set of types. -class HasStandardTypes builtin where +class BuiltinHasStandardTypes builtin where mkBuiltinType :: BuiltinType -> builtin getBuiltinType :: builtin -> Maybe BuiltinType mkNatInDomainConstraint :: builtin -instance HasStandardTypes Builtin where - mkBuiltinType = BuiltinType - getBuiltinType = \case - BuiltinType c -> Just c - _ -> Nothing - - mkNatInDomainConstraint = NatInDomainConstraint - -------------------------------------------------------------------------------- -- HasStandardBuiltins -- | Indicates that this set of builtins has the standard set of constructors, -- functions and types. -class HasStandardTypeClasses builtin where +class BuiltinHasStandardTypeClasses builtin where mkBuiltinTypeClass :: TypeClass -> builtin -instance HasStandardTypeClasses Builtin where - mkBuiltinTypeClass = TypeClass - -------------------------------------------------------------------------------- -- HasStandardBuiltins -- | Indicates that this set of builtins has the standard set of constructors, -- functions and types. type HasStandardBuiltins builtin = - ( HasStandardTypes builtin, - HasStandardData builtin + ( BuiltinHasStandardTypes builtin, + BuiltinHasStandardData builtin ) - --------------------------------------------------------------------------------- --- Converting builtins - -class ConvertableBuiltin builtin1 builtin2 where - convertBuiltin :: - Provenance -> - builtin1 -> - Expr var builtin2 - -instance ConvertableBuiltin builtin builtin where - convertBuiltin = Builtin - --------------------------------------------------------------------------------- --- Printing builtins - -class (Show builtin, Eq builtin, ConvertableBuiltin builtin Builtin) => PrintableBuiltin builtin where - -- | Convert expressions with the builtin back to expressions with the standard - -- builtin type. Used for printing. - isCoercion :: - builtin -> - Bool - -instance PrintableBuiltin Builtin where - isCoercion = \case - BuiltinFunction FromNat {} -> True - BuiltinFunction FromRat {} -> True - TypeClassOp FromNatTC {} -> True - TypeClassOp FromRatTC {} -> True - TypeClassOp FromVecTC {} -> True - _ -> False - --------------------------------------------------------------------------------- --- Typable builtin - -class (PrintableBuiltin builtin) => TypableBuiltin builtin where - -- | Construct a type for the builtin - typeBuiltin :: - Provenance -> builtin -> Type Ix builtin diff --git a/vehicle/src/Vehicle/Data/Builtin/Linearity.hs b/vehicle/src/Vehicle/Data/Builtin/Linearity.hs index dfa64b91a..6fa2bf54d 100644 --- a/vehicle/src/Vehicle/Data/Builtin/Linearity.hs +++ b/vehicle/src/Vehicle/Data/Builtin/Linearity.hs @@ -16,7 +16,7 @@ import Vehicle.Data.Builtin.Linearity.Core as Core import Vehicle.Data.Builtin.Linearity.Eval () import Vehicle.Data.Builtin.Linearity.LinearitySolver import Vehicle.Data.Builtin.Linearity.Type -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value import Vehicle.Syntax.Builtin hiding (Builtin (..)) import Vehicle.Syntax.Builtin qualified as S diff --git a/vehicle/src/Vehicle/Data/Builtin/Linearity/AnnotationRestrictions.hs b/vehicle/src/Vehicle/Data/Builtin/Linearity/AnnotationRestrictions.hs index 40cffc9a6..8551b7956 100644 --- a/vehicle/src/Vehicle/Data/Builtin/Linearity/AnnotationRestrictions.hs +++ b/vehicle/src/Vehicle/Data/Builtin/Linearity/AnnotationRestrictions.hs @@ -10,7 +10,7 @@ import Vehicle.Compile.Prelude import Vehicle.Compile.Type.Core import Vehicle.Compile.Type.Monad import Vehicle.Data.Builtin.Linearity.Core -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value checkNetworkType :: forall m. diff --git a/vehicle/src/Vehicle/Data/Builtin/Linearity/Core.hs b/vehicle/src/Vehicle/Data/Builtin/Linearity/Core.hs index 4d3945046..77186a282 100644 --- a/vehicle/src/Vehicle/Data/Builtin/Linearity/Core.hs +++ b/vehicle/src/Vehicle/Data/Builtin/Linearity/Core.hs @@ -9,8 +9,7 @@ import GHC.Generics (Generic) import Vehicle.Compile.Prelude import Vehicle.Data.Builtin.Interface import Vehicle.Data.DSL -import Vehicle.Data.Expr.Interface -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value import Vehicle.Syntax.Builtin hiding (Builtin (BuiltinConstructor, BuiltinFunction)) import Vehicle.Syntax.Builtin qualified as S @@ -164,7 +163,7 @@ instance ConvertableBuiltin LinearityBuiltin S.Builtin where instance PrintableBuiltin LinearityBuiltin where isCoercion = const False -instance HasStandardData LinearityBuiltin where +instance BuiltinHasStandardData LinearityBuiltin where mkBuiltinFunction = BuiltinFunction getBuiltinFunction = \case BuiltinFunction c -> Just c @@ -177,29 +176,29 @@ instance HasStandardData LinearityBuiltin where getBuiltinTypeClassOp = const Nothing -instance HasBoolLits (Value closure LinearityBuiltin) where - mkBoolLit _p b = VBuiltin (BuiltinConstructor (LBool b)) [] - getBoolLit = \case - VBuiltin (BuiltinConstructor (LBool b)) [] -> Just (mempty, b) +instance BuiltinHasBoolLiterals LinearityBuiltin where + mkBoolBuiltinLit b = BuiltinConstructor (LBool b) + getBoolBuiltinLit = \case + BuiltinConstructor (LBool b) -> Just b _ -> Nothing -instance HasIndexLits (Value closure LinearityBuiltin) where - getIndexLit e = case e of - VBuiltin (BuiltinConstructor (LIndex n)) [] -> Just (mempty, n) +instance BuiltinHasIndexLiterals LinearityBuiltin where + getIndexBuiltinLit e = case e of + BuiltinConstructor (LIndex n) -> Just n _ -> Nothing - mkIndexLit _p x = VBuiltin (BuiltinConstructor (LIndex x)) mempty + mkIndexBuiltinLit x = BuiltinConstructor (LIndex x) -instance HasNatLits (Value closure LinearityBuiltin) where - getNatLit e = case e of - VBuiltin (BuiltinConstructor (LNat b)) [] -> Just (mempty, b) +instance BuiltinHasNatLiterals LinearityBuiltin where + getNatBuiltinLit e = case e of + BuiltinConstructor (LNat b) -> Just b _ -> Nothing - mkNatLit _p x = VBuiltin (BuiltinConstructor (LNat x)) mempty + mkNatBuiltinLit x = BuiltinConstructor (LNat x) -instance HasRatLits (Value closure LinearityBuiltin) where - getRatLit e = case e of - VBuiltin (BuiltinConstructor (LRat b)) [] -> Just (mempty, b) +instance BuiltinHasRatLiterals LinearityBuiltin where + getRatBuiltinLit e = case e of + BuiltinConstructor (LRat b) -> Just b _ -> Nothing - mkRatLit _p x = VBuiltin (BuiltinConstructor (LRat x)) mempty + mkRatBuiltinLit x = BuiltinConstructor (LRat x) ----------------------------------------------------------------------------- -- Patterns diff --git a/vehicle/src/Vehicle/Data/Builtin/Linearity/Eval.hs b/vehicle/src/Vehicle/Data/Builtin/Linearity/Eval.hs index 506d144fd..8b6454d4e 100644 --- a/vehicle/src/Vehicle/Data/Builtin/Linearity/Eval.hs +++ b/vehicle/src/Vehicle/Data/Builtin/Linearity/Eval.hs @@ -7,7 +7,7 @@ import Vehicle.Compile.Prelude (GenericArg (..), MonadLogger, explicit) import Vehicle.Data.Builtin.Linearity.Core (LinearityBuiltin) import Vehicle.Data.Builtin.Linearity.Core qualified as L import Vehicle.Data.Builtin.Standard.Core -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value import Vehicle.Prelude instance NormalisableBuiltin LinearityBuiltin where diff --git a/vehicle/src/Vehicle/Data/Builtin/Linearity/LinearitySolver.hs b/vehicle/src/Vehicle/Data/Builtin/Linearity/LinearitySolver.hs index 403f7444c..53079b9be 100644 --- a/vehicle/src/Vehicle/Data/Builtin/Linearity/LinearitySolver.hs +++ b/vehicle/src/Vehicle/Data/Builtin/Linearity/LinearitySolver.hs @@ -16,7 +16,7 @@ import Vehicle.Compile.Type.Monad (MonadTypeChecker) import Vehicle.Compile.Type.Monad.Class (addConstraints, solveMeta, substMetas) import Vehicle.Data.Builtin.Linearity.Core import Vehicle.Data.Expr.Interface -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value import Vehicle.Syntax.Builtin solveLinearityConstraint :: diff --git a/vehicle/src/Vehicle/Data/Builtin/Loss/Core.hs b/vehicle/src/Vehicle/Data/Builtin/Loss/Core.hs index 3e38ab12e..0aac727c6 100644 --- a/vehicle/src/Vehicle/Data/Builtin/Loss/Core.hs +++ b/vehicle/src/Vehicle/Data/Builtin/Loss/Core.hs @@ -1,9 +1,7 @@ module Vehicle.Data.Builtin.Loss.Core where import GHC.Generics (Generic) -import Vehicle.Data.Builtin.Interface (ConvertableBuiltin (..), PrintableBuiltin (..)) -import Vehicle.Data.Expr.Interface -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Builtin.Interface import Vehicle.Data.Expr.Standard (cheatConvertBuiltin) import Vehicle.Libraries.StandardLibrary.Definitions (StdLibFunction (StdForeachIndex)) import Vehicle.Prelude (Pretty (..)) @@ -34,7 +32,7 @@ data LossBuiltin | Rat Rational | NilList | ConsList - | Vector + | Vector Int | ---------------- -- Operations -- ---------------- @@ -75,7 +73,7 @@ instance ConvertableBuiltin LossBuiltin V.Builtin where Rat vs -> builtinConstructor $ V.LRat vs NilList -> builtinConstructor V.Nil ConsList -> builtinConstructor V.Cons - Vector -> builtinConstructor (V.LVec (-1)) + Vector n -> builtinConstructor (V.LVec n) -- Numeric operations Neg dom -> builtinFunction (V.Neg dom) Add dom -> builtinFunction (V.Add dom) @@ -104,37 +102,37 @@ instance ConvertableBuiltin LossBuiltin V.Builtin where instance PrintableBuiltin LossBuiltin where isCoercion = const False -instance HasIndexLits (Value closure LossBuiltin) where - getIndexLit e = case e of - VBuiltin (Index n) [] -> Just (mempty, n) +instance BuiltinHasIndexLiterals LossBuiltin where + getIndexBuiltinLit e = case e of + Index n -> Just n _ -> Nothing - mkIndexLit _p x = VBuiltin (Index x) mempty + mkIndexBuiltinLit = Index -instance HasNatLits (Value closure LossBuiltin) where - getNatLit e = case e of - VBuiltin (Nat b) [] -> Just (mempty, b) +instance BuiltinHasNatLiterals LossBuiltin where + getNatBuiltinLit e = case e of + Nat b -> Just b _ -> Nothing - mkNatLit _p x = VBuiltin (Nat x) mempty + mkNatBuiltinLit = Nat -instance HasRatLits (Value closure LossBuiltin) where - getRatLit e = case e of - VBuiltin (Rat b) [] -> Just (mempty, b) +instance BuiltinHasRatLiterals LossBuiltin where + getRatBuiltinLit e = case e of + Rat b -> Just b _ -> Nothing - mkRatLit _p x = VBuiltin (Rat x) mempty + mkRatBuiltinLit = Rat -instance HasStandardVecLits (Value closure LossBuiltin) where - mkHomoVector t xs = VBuiltin Vector (t : xs) - getHomoVector = \case - VBuiltin Vector (t : xs) -> Just (t, xs) - _ -> Nothing +instance BuiltinHasListLiterals LossBuiltin where + isBuiltinNil e = case e of + NilList -> True + _ -> False + mkBuiltinNil = NilList -instance HasStandardListLits (Value closure LossBuiltin) where - getNil e = case e of - VBuiltin NilList [t] -> Just (mempty, t) - _ -> Nothing - mkNil t = VBuiltin NilList [t] + isBuiltinCons e = case e of + ConsList -> True + _ -> False + mkBuiltinCons = ConsList - getCons e = case e of - VBuiltin ConsList [t, x, xs] -> Just (t, x, xs) +instance BuiltinHasVecLiterals LossBuiltin where + getVecBuiltinLit e = case e of + Vector n -> Just n _ -> Nothing - mkCons t x xs = VBuiltin ConsList [t, x, xs] + mkVecBuiltinLit = Vector diff --git a/vehicle/src/Vehicle/Data/Builtin/Loss/Eval.hs b/vehicle/src/Vehicle/Data/Builtin/Loss/Eval.hs index b06812841..93e1a4f6f 100644 --- a/vehicle/src/Vehicle/Data/Builtin/Loss/Eval.hs +++ b/vehicle/src/Vehicle/Data/Builtin/Loss/Eval.hs @@ -24,7 +24,7 @@ import Vehicle.Compile.Normalise.Builtin ) import Vehicle.Data.Builtin.Loss.Core import Vehicle.Data.Builtin.Standard.Core () -import Vehicle.Data.Expr.Normalised (Value (..)) +import Vehicle.Data.Expr.Value (Value (..)) import Vehicle.Syntax.Builtin qualified as V instance NormalisableBuiltin LossBuiltin where @@ -42,7 +42,7 @@ instance NormalisableBuiltin LossBuiltin where Rat {} -> return unchanged NilList -> return unchanged ConsList -> return unchanged - Vector -> return unchanged + Vector {} -> return unchanged Search {} -> return unchanged -- Numeric operations Neg V.NegRat -> return $ evalNegRat unchanged args diff --git a/vehicle/src/Vehicle/Data/Builtin/Polarity.hs b/vehicle/src/Vehicle/Data/Builtin/Polarity.hs index f84736ef4..2379ddeb1 100644 --- a/vehicle/src/Vehicle/Data/Builtin/Polarity.hs +++ b/vehicle/src/Vehicle/Data/Builtin/Polarity.hs @@ -17,7 +17,7 @@ import Vehicle.Data.Builtin.Polarity.Core as Core hiding (BuiltinFunction) import Vehicle.Data.Builtin.Polarity.Eval () import Vehicle.Data.Builtin.Polarity.PolaritySolver import Vehicle.Data.Builtin.Polarity.Type -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value import Vehicle.Syntax.Builtin hiding (Builtin (..)) import Vehicle.Syntax.Builtin qualified as S diff --git a/vehicle/src/Vehicle/Data/Builtin/Polarity/AnnotationRestrictions.hs b/vehicle/src/Vehicle/Data/Builtin/Polarity/AnnotationRestrictions.hs index 521c7fc71..c6c73f0dc 100644 --- a/vehicle/src/Vehicle/Data/Builtin/Polarity/AnnotationRestrictions.hs +++ b/vehicle/src/Vehicle/Data/Builtin/Polarity/AnnotationRestrictions.hs @@ -10,7 +10,7 @@ import Vehicle.Compile.Prelude import Vehicle.Compile.Type.Core import Vehicle.Compile.Type.Monad import Vehicle.Data.Builtin.Polarity.Core -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value checkNetworkType :: forall m. diff --git a/vehicle/src/Vehicle/Data/Builtin/Polarity/Core.hs b/vehicle/src/Vehicle/Data/Builtin/Polarity/Core.hs index d2e02321e..b16c63508 100644 --- a/vehicle/src/Vehicle/Data/Builtin/Polarity/Core.hs +++ b/vehicle/src/Vehicle/Data/Builtin/Polarity/Core.hs @@ -6,11 +6,9 @@ import Data.List.NonEmpty (NonEmpty) import Data.Serialize (Serialize) import GHC.Generics (Generic) import Prettyprinter (Pretty (..), (<+>)) -import Vehicle.Compile.Print (PrintableBuiltin (..)) -import Vehicle.Data.Builtin.Interface (ConvertableBuiltin (..), HasStandardData (..)) +import Vehicle.Data.Builtin.Interface import Vehicle.Data.DSL -import Vehicle.Data.Expr.Interface -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value import Vehicle.Prelude (layoutAsText) import Vehicle.Syntax.AST import Vehicle.Syntax.Builtin hiding (Builtin (BuiltinConstructor, BuiltinFunction)) @@ -128,7 +126,7 @@ instance Pretty PolarityBuiltin where Polarity l -> pretty l PolarityRelation c -> pretty c -instance HasStandardData PolarityBuiltin where +instance BuiltinHasStandardData PolarityBuiltin where mkBuiltinFunction = BuiltinFunction getBuiltinFunction = \case BuiltinFunction c -> Just c @@ -141,29 +139,29 @@ instance HasStandardData PolarityBuiltin where getBuiltinTypeClassOp = const Nothing -instance HasBoolLits (Value closure PolarityBuiltin) where - mkBoolLit _p b = VBuiltin (BuiltinConstructor (LBool b)) [] - getBoolLit = \case - VBuiltin (BuiltinConstructor (LBool b)) [] -> Just (mempty, b) +instance BuiltinHasBoolLiterals PolarityBuiltin where + mkBoolBuiltinLit b = BuiltinConstructor (LBool b) + getBoolBuiltinLit = \case + BuiltinConstructor (LBool b) -> Just b _ -> Nothing -instance HasIndexLits (Value closure PolarityBuiltin) where - getIndexLit e = case e of - VBuiltin (BuiltinConstructor (LIndex n)) [] -> Just (mempty, n) +instance BuiltinHasIndexLiterals PolarityBuiltin where + getIndexBuiltinLit e = case e of + BuiltinConstructor (LIndex n) -> Just n _ -> Nothing - mkIndexLit _p x = VBuiltin (BuiltinConstructor (LIndex x)) mempty + mkIndexBuiltinLit x = BuiltinConstructor (LIndex x) -instance HasNatLits (Value closure PolarityBuiltin) where - getNatLit e = case e of - VBuiltin (BuiltinConstructor (LNat b)) [] -> Just (mempty, b) +instance BuiltinHasNatLiterals PolarityBuiltin where + getNatBuiltinLit e = case e of + BuiltinConstructor (LNat b) -> Just b _ -> Nothing - mkNatLit _p x = VBuiltin (BuiltinConstructor (LNat x)) mempty + mkNatBuiltinLit x = BuiltinConstructor (LNat x) -instance HasRatLits (Value closure PolarityBuiltin) where - getRatLit e = case e of - VBuiltin (BuiltinConstructor (LRat b)) [] -> Just (mempty, b) +instance BuiltinHasRatLiterals PolarityBuiltin where + getRatBuiltinLit e = case e of + BuiltinConstructor (LRat b) -> Just b _ -> Nothing - mkRatLit _p x = VBuiltin (BuiltinConstructor (LRat x)) mempty + mkRatBuiltinLit x = BuiltinConstructor (LRat x) instance ConvertableBuiltin PolarityBuiltin S.Builtin where convertBuiltin p = \case diff --git a/vehicle/src/Vehicle/Data/Builtin/Polarity/Eval.hs b/vehicle/src/Vehicle/Data/Builtin/Polarity/Eval.hs index be994c510..ac960f50c 100644 --- a/vehicle/src/Vehicle/Data/Builtin/Polarity/Eval.hs +++ b/vehicle/src/Vehicle/Data/Builtin/Polarity/Eval.hs @@ -7,7 +7,7 @@ import Vehicle.Compile.Prelude (GenericArg (..), MonadLogger, explicit) import Vehicle.Data.Builtin.Polarity.Core (PolarityBuiltin) import Vehicle.Data.Builtin.Polarity.Core qualified as P import Vehicle.Data.Builtin.Standard.Core -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value import Vehicle.Prelude instance NormalisableBuiltin PolarityBuiltin where diff --git a/vehicle/src/Vehicle/Data/Builtin/Polarity/PolaritySolver.hs b/vehicle/src/Vehicle/Data/Builtin/Polarity/PolaritySolver.hs index deb69fc10..06fe47bca 100644 --- a/vehicle/src/Vehicle/Data/Builtin/Polarity/PolaritySolver.hs +++ b/vehicle/src/Vehicle/Data/Builtin/Polarity/PolaritySolver.hs @@ -13,7 +13,7 @@ import Vehicle.Compile.Type.Core import Vehicle.Compile.Type.Monad import Vehicle.Data.Builtin.Polarity.Core import Vehicle.Data.Expr.Interface -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value import Vehicle.Syntax.Builtin solvePolarityConstraint :: diff --git a/vehicle/src/Vehicle/Data/Builtin/Standard.hs b/vehicle/src/Vehicle/Data/Builtin/Standard.hs index 53a2e7dcb..1311de0ae 100644 --- a/vehicle/src/Vehicle/Data/Builtin/Standard.hs +++ b/vehicle/src/Vehicle/Data/Builtin/Standard.hs @@ -20,7 +20,7 @@ import Vehicle.Data.Builtin.Standard.Core as Core import Vehicle.Data.Builtin.Standard.Eval () import Vehicle.Data.Builtin.Standard.InstanceDefaults () import Vehicle.Data.Builtin.Standard.Type -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value import Prelude hiding (pi) ----------------------------------------------------------------------------- diff --git a/vehicle/src/Vehicle/Data/Builtin/Standard/AnnotationRestrictions.hs b/vehicle/src/Vehicle/Data/Builtin/Standard/AnnotationRestrictions.hs index 09444d496..e304ec7c1 100644 --- a/vehicle/src/Vehicle/Data/Builtin/Standard/AnnotationRestrictions.hs +++ b/vehicle/src/Vehicle/Data/Builtin/Standard/AnnotationRestrictions.hs @@ -12,7 +12,7 @@ import Vehicle.Compile.Prelude import Vehicle.Compile.Type.Monad import Vehicle.Data.Builtin.Standard.Core import Vehicle.Data.Expr.Interface -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value -------------------------------------------------------------------------------- -- Property diff --git a/vehicle/src/Vehicle/Data/Builtin/Standard/Core.hs b/vehicle/src/Vehicle/Data/Builtin/Standard/Core.hs index c14d49394..50d0b05b3 100644 --- a/vehicle/src/Vehicle/Data/Builtin/Standard/Core.hs +++ b/vehicle/src/Vehicle/Data/Builtin/Standard/Core.hs @@ -5,96 +5,84 @@ module Vehicle.Data.Builtin.Standard.Core ) where -import Data.List.NonEmpty (NonEmpty (..)) -import Vehicle.Data.Expr.Interface -import Vehicle.Data.Expr.Normalised -import Vehicle.Syntax.AST +import Vehicle.Data.Builtin.Interface import Vehicle.Syntax.Builtin as Syntax ----------------------------------------------------------------------------- --- Literal instances for `Value` +-- Classes -instance HasBoolLits (Value closure Builtin) where - mkBoolLit _p b = VBuiltin (BuiltinConstructor (LBool b)) [] - getBoolLit = \case - VBuiltin (BuiltinConstructor (LBool b)) [] -> Just (mempty, b) +instance BuiltinHasBoolLiterals Builtin where + mkBoolBuiltinLit b = BuiltinConstructor (LBool b) + getBoolBuiltinLit = \case + BuiltinConstructor (LBool b) -> Just b _ -> Nothing -instance HasIndexLits (Value closure Builtin) where - getIndexLit e = case e of - VBuiltin (BuiltinConstructor (LIndex n)) [] -> Just (mempty, n) +instance BuiltinHasIndexLiterals Builtin where + getIndexBuiltinLit e = case e of + BuiltinConstructor (LIndex n) -> Just n _ -> Nothing - mkIndexLit _p x = VBuiltin (BuiltinConstructor (LIndex x)) mempty + mkIndexBuiltinLit x = BuiltinConstructor (LIndex x) -instance HasNatLits (Value closure Builtin) where - getNatLit e = case e of - VBuiltin (BuiltinConstructor (LNat b)) [] -> Just (mempty, b) +instance BuiltinHasNatLiterals Builtin where + getNatBuiltinLit e = case e of + BuiltinConstructor (LNat b) -> Just b _ -> Nothing - mkNatLit _p x = VBuiltin (BuiltinConstructor (LNat x)) mempty + mkNatBuiltinLit x = BuiltinConstructor (LNat x) -instance HasRatLits (Value closure Builtin) where - getRatLit e = case e of - VBuiltin (BuiltinConstructor (LRat b)) [] -> Just (mempty, b) +instance BuiltinHasRatLiterals Builtin where + getRatBuiltinLit e = case e of + BuiltinConstructor (LRat b) -> Just b _ -> Nothing - mkRatLit _p x = VBuiltin (BuiltinConstructor (LRat x)) mempty - -instance HasStandardVecLits (Value closure Builtin) where - mkHomoVector t xs = VBuiltin (BuiltinConstructor (LVec (length xs))) (t : xs) - getHomoVector = \case - VBuiltin (BuiltinConstructor (LVec _)) (t : xs) -> Just (t, xs) - _ -> Nothing - -instance HasStandardListLits (Value closure Builtin) where - getNil e = case getConstructor e of - Just (p, Nil, [t]) -> Just (p, t) + mkRatBuiltinLit x = BuiltinConstructor (LRat x) + +instance BuiltinHasListLiterals Builtin where + isBuiltinNil e = case e of + BuiltinConstructor Nil -> True + _ -> False + mkBuiltinNil = BuiltinConstructor Nil + + isBuiltinCons e = case e of + BuiltinConstructor Cons -> True + _ -> False + mkBuiltinCons = BuiltinConstructor Cons + +instance BuiltinHasVecLiterals Builtin where + getVecBuiltinLit e = case e of + BuiltinConstructor (LVec n) -> Just n _ -> Nothing - mkNil t = mkConstructor mempty Nil [t] - - getCons e = case getConstructor e of - Just (_p, Cons, [t, x, xs]) -> Just (t, x, xs) + mkVecBuiltinLit n = BuiltinConstructor (LVec n) + +instance PrintableBuiltin Builtin where + isCoercion = \case + BuiltinFunction FromNat {} -> True + BuiltinFunction FromRat {} -> True + TypeClassOp FromNatTC {} -> True + TypeClassOp FromRatTC {} -> True + TypeClassOp FromVecTC {} -> True + _ -> False + +instance BuiltinHasStandardTypeClasses Builtin where + mkBuiltinTypeClass = TypeClass + +instance BuiltinHasStandardTypes Builtin where + mkBuiltinType = BuiltinType + getBuiltinType = \case + BuiltinType c -> Just c _ -> Nothing - mkCons t x xs = mkConstructor mempty Cons [t, x, xs] ------------------------------------------------------------------------------ --- Literal intstances for `Expr` - -instance HasBoolLits (Expr var Builtin) where - getBoolLit e = case e of - Builtin _ (BuiltinConstructor (LBool b)) -> Just (mempty, b) - _ -> Nothing - mkBoolLit p x = Builtin p (BuiltinConstructor (LBool x)) - -instance HasIndexLits (Expr var Builtin) where - getIndexLit e = case e of - BuiltinExpr _ (BuiltinConstructor (LIndex n)) [] -> Just (mempty, n) - _ -> Nothing - mkIndexLit p x = Builtin p (BuiltinConstructor (LIndex x)) - -instance HasNatLits (Expr var Builtin) where - getNatLit e = case e of - Builtin _ (BuiltinConstructor (LNat b)) -> Just (mempty, b) - _ -> Nothing - mkNatLit p x = Builtin p (BuiltinConstructor (LNat x)) - -instance HasRatLits (Expr var Builtin) where - getRatLit e = case e of - Builtin _ (BuiltinConstructor (LRat b)) -> Just (mempty, b) - _ -> Nothing - mkRatLit p x = Builtin p (BuiltinConstructor (LRat x)) + mkNatInDomainConstraint = NatInDomainConstraint -instance HasStandardVecLits (Expr var Builtin) where - mkHomoVector t xs = BuiltinExpr mempty (BuiltinConstructor (LVec (length xs))) (t :| xs) - getHomoVector = \case - BuiltinExpr _ (BuiltinConstructor (LVec _)) (t :| xs) -> Just (t, xs) +instance BuiltinHasStandardData Builtin where + mkBuiltinFunction = BuiltinFunction + getBuiltinFunction = \case + BuiltinFunction c -> Just c _ -> Nothing -instance HasStandardListLits (Expr var Builtin) where - getNil e = case getConstructor e of - Just (p, Nil, [t]) -> Just (p, t) + mkBuiltinConstructor = BuiltinConstructor + getBuiltinConstructor = \case + BuiltinConstructor c -> Just c _ -> Nothing - mkNil t = mkConstructor mempty Nil [t] - getCons e = case getConstructor e of - Just (_p, Cons, [t, x, xs]) -> Just (t, x, xs) + getBuiltinTypeClassOp = \case + TypeClassOp op -> Just op _ -> Nothing - mkCons t x xs = mkConstructor mempty Cons [t, x, xs] diff --git a/vehicle/src/Vehicle/Data/Builtin/Standard/Eval.hs b/vehicle/src/Vehicle/Data/Builtin/Standard/Eval.hs index 0f44ba335..7d9b995e9 100644 --- a/vehicle/src/Vehicle/Data/Builtin/Standard/Eval.hs +++ b/vehicle/src/Vehicle/Data/Builtin/Standard/Eval.hs @@ -4,8 +4,7 @@ module Vehicle.Data.Builtin.Standard.Eval where import Vehicle.Compile.Normalise.Builtin import Vehicle.Data.Builtin.Standard.Core -import Vehicle.Data.Expr.Interface -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value instance NormalisableBuiltin Builtin where evalBuiltinApp = evalTypeClassOp evalBuiltinFunction diff --git a/vehicle/src/Vehicle/Data/Builtin/Standard/InstanceDefaults.hs b/vehicle/src/Vehicle/Data/Builtin/Standard/InstanceDefaults.hs index bcbeee40f..80b31ef62 100644 --- a/vehicle/src/Vehicle/Data/Builtin/Standard/InstanceDefaults.hs +++ b/vehicle/src/Vehicle/Data/Builtin/Standard/InstanceDefaults.hs @@ -9,7 +9,7 @@ import Vehicle.Compile.Type.Constraint.InstanceDefaultSolver import Vehicle.Compile.Type.Core import Vehicle.Data.Builtin.Standard.Core import Vehicle.Data.Expr.Interface -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value instance HasInstanceDefaults Builtin where getCandidatesFromConstraint = getCandidates diff --git a/vehicle/src/Vehicle/Data/Builtin/Standard/Type.hs b/vehicle/src/Vehicle/Data/Builtin/Standard/Type.hs index 6c6909ea0..01a04a490 100644 --- a/vehicle/src/Vehicle/Data/Builtin/Standard/Type.hs +++ b/vehicle/src/Vehicle/Data/Builtin/Standard/Type.hs @@ -39,7 +39,7 @@ typeStandardBuiltin p b = fromDSL p $ case b of -------------------------------------------------------------------------------- -- Type classes -typeOfTypeClass :: (HasStandardTypes builtin) => TypeClass -> DSLExpr builtin +typeOfTypeClass :: (BuiltinHasStandardTypes builtin) => TypeClass -> DSLExpr builtin typeOfTypeClass tc = case tc of HasEq {} -> type0 ~> type0 ~> type0 HasOrd {} -> type0 ~> type0 ~> type0 @@ -56,7 +56,7 @@ typeOfTypeClass tc = case tc of HasRatLits -> type0 ~> type0 HasVecLits {} -> tNat ~> type0 ~> type0 -typeOfTypeClassOp :: (HasStandardBuiltins builtin, HasStandardTypeClasses builtin) => TypeClassOp -> DSLExpr builtin +typeOfTypeClassOp :: (HasStandardBuiltins builtin, BuiltinHasStandardTypeClasses builtin) => TypeClassOp -> DSLExpr builtin typeOfTypeClassOp b = case b of NegTC -> typeOfTCOp1 hasNeg AddTC -> typeOfTCOp2 hasAdd @@ -154,7 +154,7 @@ typeOfBuiltinConstructor = \case LRat {} -> tRat LVec n -> typeOfVectorLiteral n -typeOfIf :: (HasStandardTypes builtin) => DSLExpr builtin +typeOfIf :: (BuiltinHasStandardTypes builtin) => DSLExpr builtin typeOfIf = forAll "A" type0 $ \t -> tBool ~> t ~> t ~> t @@ -172,16 +172,16 @@ typeOfTCOp2 constraint = forAll "C" type0 $ \t3 -> constraint t1 t2 t3 ~~~> t1 ~> t2 ~> t3 -typeOfTCComparisonOp :: (HasStandardTypes builtin) => (DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin) -> DSLExpr builtin +typeOfTCComparisonOp :: (BuiltinHasStandardTypes builtin) => (DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin) -> DSLExpr builtin typeOfTCComparisonOp constraint = forAll "A" type0 $ \t1 -> forAll "B" type0 $ \t2 -> constraint t1 t2 ~~~> typeOfComparisonOp t1 t2 -typeOfComparisonOp :: (HasStandardTypes builtin) => DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin +typeOfComparisonOp :: (BuiltinHasStandardTypes builtin) => DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin typeOfComparisonOp t1 t2 = t1 ~> t2 ~> tBool -typeOfIndices :: (HasStandardTypes builtin) => DSLExpr builtin +typeOfIndices :: (BuiltinHasStandardTypes builtin) => DSLExpr builtin typeOfIndices = pi (Just "n") Explicit Relevant tNat $ \n -> tVector (tIndex n) n diff --git a/vehicle/src/Vehicle/Data/Builtin/Tensor.hs b/vehicle/src/Vehicle/Data/Builtin/Tensor.hs index aee4fb816..057466c60 100644 --- a/vehicle/src/Vehicle/Data/Builtin/Tensor.hs +++ b/vehicle/src/Vehicle/Data/Builtin/Tensor.hs @@ -6,8 +6,8 @@ import Vehicle.Compile.Arity (Arity) import Vehicle.Data.Builtin.Interface (ConvertableBuiltin (..), PrintableBuiltin (..)) import Vehicle.Data.Builtin.Standard.Core () import Vehicle.Data.Expr.Interface -import Vehicle.Data.Expr.Normalised import Vehicle.Data.Expr.Standard (cheatConvertBuiltin) +import Vehicle.Data.Expr.Value import Vehicle.Data.Tensor (Tensor (..)) import Vehicle.Libraries.StandardLibrary.Definitions (StdLibFunction (..)) import Vehicle.Prelude.Prettyprinter diff --git a/vehicle/src/Vehicle/Data/Expr/DSL.hs b/vehicle/src/Vehicle/Data/Expr/DSL.hs index bfbf3e03f..48e6c93d6 100644 --- a/vehicle/src/Vehicle/Data/Expr/DSL.hs +++ b/vehicle/src/Vehicle/Data/Expr/DSL.hs @@ -10,79 +10,79 @@ import Prelude hiding (pi) -------------------------------------------------------------------------------- -- Types DSL -builtinType :: (HasStandardTypes builtin) => BuiltinType -> DSLExpr builtin +builtinType :: (BuiltinHasStandardTypes builtin) => BuiltinType -> DSLExpr builtin builtinType = builtin . mkBuiltinType -tUnit :: (HasStandardTypes builtin) => DSLExpr builtin +tUnit :: (BuiltinHasStandardTypes builtin) => DSLExpr builtin tUnit = builtinType Unit -tBool, tNat, tRat :: (HasStandardTypes builtin) => DSLExpr builtin +tBool, tNat, tRat :: (BuiltinHasStandardTypes builtin) => DSLExpr builtin tBool = builtinType Bool tNat = builtinType Nat tRat = builtinType Rat -tVector :: (HasStandardTypes builtin) => DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin +tVector :: (BuiltinHasStandardTypes builtin) => DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin tVector tElem dim = builtinType Vector @@ [tElem] .@@ [dim] -tVectorFunctor :: (HasStandardTypes builtin) => DSLExpr builtin -> DSLExpr builtin +tVectorFunctor :: (BuiltinHasStandardTypes builtin) => DSLExpr builtin -> DSLExpr builtin tVectorFunctor n = explLam "A" type0 (`tVector` n) -tListRaw :: (HasStandardTypes builtin) => DSLExpr builtin +tListRaw :: (BuiltinHasStandardTypes builtin) => DSLExpr builtin tListRaw = builtinType List -tList :: (HasStandardTypes builtin) => DSLExpr builtin -> DSLExpr builtin +tList :: (BuiltinHasStandardTypes builtin) => DSLExpr builtin -> DSLExpr builtin tList tElem = tListRaw @@ [tElem] -tIndex :: (HasStandardTypes builtin) => DSLExpr builtin -> DSLExpr builtin +tIndex :: (BuiltinHasStandardTypes builtin) => DSLExpr builtin -> DSLExpr builtin tIndex n = builtinType Index .@@ [n] -forAllNat :: (HasStandardTypes builtin) => (DSLExpr builtin -> DSLExpr builtin) -> DSLExpr builtin +forAllNat :: (BuiltinHasStandardTypes builtin) => (DSLExpr builtin -> DSLExpr builtin) -> DSLExpr builtin forAllNat = forAll "n" tNat -forAllIrrelevantNat :: (HasStandardTypes builtin) => Name -> (DSLExpr builtin -> DSLExpr builtin) -> DSLExpr builtin +forAllIrrelevantNat :: (BuiltinHasStandardTypes builtin) => Name -> (DSLExpr builtin -> DSLExpr builtin) -> DSLExpr builtin forAllIrrelevantNat name = pi (Just name) (Implicit False) Irrelevant tNat -irrelImplNatLam :: (HasStandardTypes builtin) => Name -> (DSLExpr builtin -> DSLExpr builtin) -> DSLExpr builtin +irrelImplNatLam :: (BuiltinHasStandardTypes builtin) => Name -> (DSLExpr builtin -> DSLExpr builtin) -> DSLExpr builtin irrelImplNatLam n = lam n (Implicit False) Irrelevant tNat -natInDomainConstraint :: (HasStandardTypes builtin) => DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin +natInDomainConstraint :: (BuiltinHasStandardTypes builtin) => DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin natInDomainConstraint n t = builtin mkNatInDomainConstraint @@ [n, t] -------------------------------------------------------------------------------- -- Constructors DSL -builtinConstructor :: (HasStandardData builtin) => BuiltinConstructor -> DSLExpr builtin +builtinConstructor :: (BuiltinHasStandardData builtin) => BuiltinConstructor -> DSLExpr builtin builtinConstructor = builtin . mkBuiltinConstructor -nil :: (HasStandardData builtin) => DSLExpr builtin -> DSLExpr builtin +nil :: (BuiltinHasStandardData builtin) => DSLExpr builtin -> DSLExpr builtin nil tElem = builtinConstructor Nil @@@ [tElem] -cons :: (HasStandardData builtin) => DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin +cons :: (BuiltinHasStandardData builtin) => DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin cons tElem x xs = builtinConstructor Cons @@@ [tElem] @@ [x, xs] -natLit :: (HasStandardData builtin) => Int -> DSLExpr builtin +natLit :: (BuiltinHasStandardData builtin) => Int -> DSLExpr builtin natLit n = builtinConstructor (LNat n) -boolLit :: (HasStandardData builtin) => Bool -> DSLExpr builtin +boolLit :: (BuiltinHasStandardData builtin) => Bool -> DSLExpr builtin boolLit n = builtinConstructor (LBool n) -ratLit :: (HasStandardData builtin) => Rational -> DSLExpr builtin +ratLit :: (BuiltinHasStandardData builtin) => Rational -> DSLExpr builtin ratLit r = builtinConstructor (LRat r) -unitLit :: (HasStandardData builtin) => DSLExpr builtin +unitLit :: (BuiltinHasStandardData builtin) => DSLExpr builtin unitLit = builtinConstructor LUnit -------------------------------------------------------------------------------- -- Functions DSL -builtinFunction :: (HasStandardData builtin) => BuiltinFunction -> DSLExpr builtin +builtinFunction :: (BuiltinHasStandardData builtin) => BuiltinFunction -> DSLExpr builtin builtinFunction = builtin . mkBuiltinFunction -addNat :: (HasStandardData builtin) => DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin +addNat :: (BuiltinHasStandardData builtin) => DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin addNat x y = builtinFunction (Add AddNat) @@ [x, y] ite :: - (HasStandardData builtin) => + (BuiltinHasStandardData builtin) => DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin -> @@ -93,53 +93,53 @@ ite t c e1 e2 = builtinFunction If @@@ [t] @@ [c, e1, e2] -------------------------------------------------------------------------------- -- Type classes -builtinTypeClass :: (HasStandardTypeClasses builtin) => TypeClass -> DSLExpr builtin +builtinTypeClass :: (BuiltinHasStandardTypeClasses builtin) => TypeClass -> DSLExpr builtin builtinTypeClass = builtin . mkBuiltinTypeClass -typeClass :: (HasStandardTypeClasses builtin) => TypeClass -> NonEmpty (DSLExpr builtin) -> DSLExpr builtin +typeClass :: (BuiltinHasStandardTypeClasses builtin) => TypeClass -> NonEmpty (DSLExpr builtin) -> DSLExpr builtin typeClass tc args = builtinTypeClass tc @@ args -hasEq :: (HasStandardTypeClasses builtin) => EqualityOp -> DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin +hasEq :: (BuiltinHasStandardTypeClasses builtin) => EqualityOp -> DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin hasEq eq t1 t2 = typeClass (HasEq eq) [t1, t2] -hasOrd :: (HasStandardTypeClasses builtin) => OrderOp -> DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin +hasOrd :: (BuiltinHasStandardTypeClasses builtin) => OrderOp -> DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin hasOrd ord t1 t2 = typeClass (HasOrd ord) [t1, t2] -hasQuantifier :: (HasStandardTypeClasses builtin) => Quantifier -> DSLExpr builtin -> DSLExpr builtin +hasQuantifier :: (BuiltinHasStandardTypeClasses builtin) => Quantifier -> DSLExpr builtin -> DSLExpr builtin hasQuantifier q t = typeClass (HasQuantifier q) [t] -numOp2TypeClass :: (HasStandardTypeClasses builtin) => TypeClass -> DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin +numOp2TypeClass :: (BuiltinHasStandardTypeClasses builtin) => TypeClass -> DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin numOp2TypeClass tc t1 t2 t3 = typeClass tc [t1, t2, t3] -hasAdd :: (HasStandardTypeClasses builtin) => DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin +hasAdd :: (BuiltinHasStandardTypeClasses builtin) => DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin hasAdd = numOp2TypeClass HasAdd -hasSub :: (HasStandardTypeClasses builtin) => DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin +hasSub :: (BuiltinHasStandardTypeClasses builtin) => DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin hasSub = numOp2TypeClass HasSub -hasMul :: (HasStandardTypeClasses builtin) => DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin +hasMul :: (BuiltinHasStandardTypeClasses builtin) => DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin hasMul = numOp2TypeClass HasMul -hasDiv :: (HasStandardTypeClasses builtin) => DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin +hasDiv :: (BuiltinHasStandardTypeClasses builtin) => DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin hasDiv = numOp2TypeClass HasDiv -hasNeg :: (HasStandardTypeClasses builtin) => DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin +hasNeg :: (BuiltinHasStandardTypeClasses builtin) => DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin hasNeg t1 t2 = typeClass HasNeg [t1, t2] -hasMap :: (HasStandardTypeClasses builtin) => DSLExpr builtin -> DSLExpr builtin +hasMap :: (BuiltinHasStandardTypeClasses builtin) => DSLExpr builtin -> DSLExpr builtin hasMap tCont = typeClass HasMap [tCont] -hasFold :: (HasStandardTypeClasses builtin) => DSLExpr builtin -> DSLExpr builtin +hasFold :: (BuiltinHasStandardTypeClasses builtin) => DSLExpr builtin -> DSLExpr builtin hasFold tCont = typeClass HasFold [tCont] -hasQuantifierIn :: (HasStandardTypeClasses builtin) => Quantifier -> DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin +hasQuantifierIn :: (BuiltinHasStandardTypeClasses builtin) => Quantifier -> DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin hasQuantifierIn q tCont tElem tRes = typeClass (HasQuantifierIn q) [tCont, tElem, tRes] -hasNatLits :: (HasStandardTypeClasses builtin) => DSLExpr builtin -> DSLExpr builtin +hasNatLits :: (BuiltinHasStandardTypeClasses builtin) => DSLExpr builtin -> DSLExpr builtin hasNatLits t = typeClass HasNatLits [t] -hasRatLits :: (HasStandardTypeClasses builtin) => DSLExpr builtin -> DSLExpr builtin +hasRatLits :: (BuiltinHasStandardTypeClasses builtin) => DSLExpr builtin -> DSLExpr builtin hasRatLits t = typeClass HasRatLits [t] -hasVecLits :: (HasStandardTypeClasses builtin) => DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin +hasVecLits :: (BuiltinHasStandardTypeClasses builtin) => DSLExpr builtin -> DSLExpr builtin -> DSLExpr builtin hasVecLits n d = typeClass HasVecLits [n, d] diff --git a/vehicle/src/Vehicle/Data/Expr/Interface.hs b/vehicle/src/Vehicle/Data/Expr/Interface.hs index c663b4444..eb0dfa12c 100644 --- a/vehicle/src/Vehicle/Data/Expr/Interface.hs +++ b/vehicle/src/Vehicle/Data/Expr/Interface.hs @@ -1,7 +1,5 @@ module Vehicle.Data.Expr.Interface where -import Vehicle.Data.Builtin.Interface -import Vehicle.Data.Expr.Normalised import Vehicle.Data.Tensor import Vehicle.Libraries.StandardLibrary.Definitions import Vehicle.Prelude (TensorShape) @@ -82,7 +80,7 @@ pattern IRatLiteral p n <- (getRatLit -> Just (p, n)) class HasStandardListLits expr where getNil :: expr -> Maybe (Provenance, GenericArg expr) mkNil :: GenericArg expr -> expr - getCons :: expr -> Maybe (GenericArg expr, GenericArg expr, GenericArg expr) + getCons :: expr -> Maybe (Provenance, GenericArg expr, GenericArg expr, GenericArg expr) mkCons :: GenericArg expr -> GenericArg expr -> GenericArg expr -> expr pattern INil :: (HasStandardListLits expr) => GenericArg expr -> expr @@ -91,7 +89,7 @@ pattern INil t <- (getNil -> Just (_, t)) INil t = mkNil t pattern ICons :: (HasStandardListLits expr) => GenericArg expr -> GenericArg expr -> GenericArg expr -> expr -pattern ICons t x xs <- (getCons -> Just (t, x, xs)) +pattern ICons t x xs <- (getCons -> Just (_, t, x, xs)) where ICons t x xs = mkCons t x xs @@ -110,11 +108,11 @@ pattern IVecLiteral t xs <- (getHomoVector -> Just (t, xs)) IVecLiteral t xs = mkHomoVector t xs -------------------------------------------------------------------------------- --- HasStandardData +-- BuiltinHasStandardData -- | Indicates that this set of builtins has the standard builtin constructors -- and functions. -class HasStandardDataExpr expr where +class HasStandardData expr where mkConstructor :: Provenance -> BuiltinConstructor -> [GenericArg expr] -> expr getConstructor :: expr -> Maybe (Provenance, BuiltinConstructor, [GenericArg expr]) @@ -126,110 +124,42 @@ class HasStandardDataExpr expr where getTypeClassOp :: expr -> Maybe (Provenance, TypeClassOp, [GenericArg expr]) -instance (HasStandardData builtin) => HasStandardDataExpr (Expr var builtin) where - mkFunction p b = normAppList (Builtin p (mkBuiltinFunction b)) - getFunction e = case getBuiltinApp e of - Just (p, b, args) -> case getBuiltinFunction b of - Just f -> Just (p, f, args) - Nothing -> Nothing - _ -> Nothing - - mkConstructor p b = normAppList (Builtin p (mkBuiltinConstructor b)) - getConstructor e = case getBuiltinApp e of - Just (p, b, args) -> case getBuiltinConstructor b of - Just f -> Just (p, f, args) - Nothing -> Nothing - _ -> Nothing - - mkFreeVar p ident = normAppList (FreeVar p ident) - getFreeVar e = case getFreeVarApp e of - Just (p, ident, args) -> Just (p, ident, args) - _ -> Nothing - - getTypeClassOp e = case getBuiltinApp e of - Just (p, b, args) -> case getBuiltinTypeClassOp b of - Just f -> Just (p, f, args) - Nothing -> Nothing - _ -> Nothing - -instance (HasStandardData builtin) => HasStandardDataExpr (Value closure builtin) where - mkFunction _p b = VBuiltin (mkBuiltinFunction b) - getFunction e = case e of - VBuiltin b args -> case getBuiltinFunction b of - Just t -> Just (mempty, t, args) - Nothing -> Nothing - _ -> Nothing - - mkConstructor _p b = VBuiltin (mkBuiltinConstructor b) - getConstructor e = case e of - VBuiltin b args -> case getBuiltinConstructor b of - Just t -> Just (mempty, t, args) - Nothing -> Nothing - _ -> Nothing - - mkFreeVar _p = VFreeVar - getFreeVar = \case - VFreeVar ident args -> Just (mempty, ident, args) - _ -> Nothing - - getTypeClassOp e = case e of - VBuiltin b args -> case getBuiltinTypeClassOp b of - Just op -> Just (mempty, op, args) - Nothing -> Nothing - _ -> Nothing - -------------------------------------------------------------------------------- --- HasStandardTypes +-- BuiltinHasStandardTypes -- | Indicates that this set of builtins has the standard set of types. -class HasStandardTypesExpr expr where +class HasStandardTypes expr where mkType :: Provenance -> BuiltinType -> [GenericArg expr] -> expr getType :: expr -> Maybe (Provenance, BuiltinType, [GenericArg expr]) -instance (HasStandardTypes builtin) => HasStandardTypesExpr (Expr var builtin) where - mkType p b = normAppList (Builtin p (mkBuiltinType b)) - getType e = case getBuiltinApp e of - Just (p, b, args) -> case getBuiltinType b of - Just t -> Just (p, t, args) - Nothing -> Nothing - _ -> Nothing - -instance (HasStandardTypes builtin) => HasStandardTypesExpr (Value closure builtin) where - mkType _p b = VBuiltin (mkBuiltinType b) - getType e = case e of - VBuiltin b args -> case getBuiltinType b of - Just t -> Just (mempty, t, args) - Nothing -> Nothing - _ -> Nothing - -------------------------------------------------------------------------------- -- Constructors -pattern INullaryTypeExpr :: (HasStandardTypesExpr expr) => Provenance -> BuiltinType -> expr +pattern INullaryTypeExpr :: (HasStandardTypes expr) => Provenance -> BuiltinType -> expr pattern INullaryTypeExpr p b <- (getType -> Just (p, b, [])) where INullaryTypeExpr p b = mkType p b [] -pattern IUnitType :: (HasStandardTypesExpr expr) => Provenance -> expr +pattern IUnitType :: (HasStandardTypes expr) => Provenance -> expr pattern IUnitType p = INullaryTypeExpr p Unit -pattern IBoolType :: (HasStandardTypesExpr expr) => Provenance -> expr +pattern IBoolType :: (HasStandardTypes expr) => Provenance -> expr pattern IBoolType p = INullaryTypeExpr p Bool -pattern IIndexType :: (HasStandardTypesExpr expr) => Provenance -> expr -> expr +pattern IIndexType :: (HasStandardTypes expr) => Provenance -> expr -> expr pattern IIndexType p size <- (getType -> Just (p, Index, [IrrelevantExplicitArg _ size])) -pattern INatType :: (HasStandardTypesExpr expr) => Provenance -> expr +pattern INatType :: (HasStandardTypes expr) => Provenance -> expr pattern INatType p = INullaryTypeExpr p Nat -pattern IRatType :: (HasStandardTypesExpr expr) => Provenance -> expr +pattern IRatType :: (HasStandardTypes expr) => Provenance -> expr pattern IRatType p = INullaryTypeExpr p Rat -pattern IListType :: (HasStandardTypesExpr expr) => Provenance -> expr -> expr +pattern IListType :: (HasStandardTypes expr) => Provenance -> expr -> expr pattern IListType p tElem <- (getType -> Just (p, List, [RelevantExplicitArg _ tElem])) pattern IVectorType :: - (HasStandardTypesExpr expr) => + (HasStandardTypes expr) => Provenance -> expr -> expr -> @@ -239,19 +169,19 @@ pattern IVectorType p tElem tDim <- where IVectorType p tElem tDim = mkType p Vector [Arg p Explicit Relevant tElem, Arg p Explicit Irrelevant tDim] -pattern IRawListType :: (HasStandardTypesExpr expr) => Provenance -> expr +pattern IRawListType :: (HasStandardTypes expr) => Provenance -> expr pattern IRawListType p = INullaryTypeExpr p List -------------------------------------------------------------------------------- -- Constructors -- Can't use `[]` in a bidrectional pattern synonym until GHC 9.4.3?? -pattern INullaryConstructor :: (HasStandardDataExpr expr) => Provenance -> BuiltinConstructor -> expr +pattern INullaryConstructor :: (HasStandardData expr) => Provenance -> BuiltinConstructor -> expr pattern INullaryConstructor p t <- (getConstructor -> Just (p, t, [])) where INullaryConstructor p t = mkConstructor p t [] -pattern IUnitLiteral :: (HasStandardDataExpr expr) => Provenance -> expr +pattern IUnitLiteral :: (HasStandardData expr) => Provenance -> expr pattern IUnitLiteral p = INullaryConstructor p LUnit mkListExpr :: (HasStandardListLits expr) => expr -> [expr] -> expr @@ -263,7 +193,7 @@ mkListExpr typ = foldr cons nil nil = INil tArg cons y ys = ICons tArg (mkExpl y) (mkExpl ys) -mkVecExpr :: (HasStandardDataExpr expr) => [expr] -> expr +mkVecExpr :: (HasStandardData expr) => [expr] -> expr mkVecExpr xs = mkConstructor mempty @@ -271,7 +201,7 @@ mkVecExpr xs = (Arg mempty (Implicit True) Relevant (IUnitLiteral mempty) : (Arg mempty Explicit Relevant <$> xs)) mkTensorLayer :: - (HasStandardVecLits expr, HasStandardListLits expr, HasStandardTypesExpr expr, HasNatLits expr) => + (HasStandardVecLits expr, HasStandardListLits expr, HasStandardTypes expr, HasNatLits expr) => TensorShape -> [expr] -> expr @@ -282,7 +212,7 @@ mkTensorLayer dims xs = do mkHomoVector elementType elements tensorLikeToExpr :: - (HasStandardVecLits expr, HasStandardListLits expr, HasStandardTypesExpr expr, HasNatLits expr) => + (HasStandardVecLits expr, HasStandardListLits expr, HasStandardTypes expr, HasNatLits expr) => (a -> expr) -> TensorShape -> [a] -> @@ -290,7 +220,7 @@ tensorLikeToExpr :: tensorLikeToExpr mkElem = foldMapTensorLike mkElem mkTensorLayer tensorToExpr :: - (HasStandardVecLits expr, HasStandardListLits expr, HasStandardTypesExpr expr, HasNatLits expr) => + (HasStandardVecLits expr, HasStandardListLits expr, HasStandardTypes expr, HasNatLits expr) => (a -> expr) -> Tensor a -> expr @@ -299,88 +229,88 @@ tensorToExpr mkElem = foldMapTensor mkElem mkTensorLayer -------------------------------------------------------------------------------- -- Functions -pattern BuiltinFunc :: (HasStandardDataExpr expr) => BuiltinFunction -> [GenericArg expr] -> expr +pattern BuiltinFunc :: (HasStandardData expr) => BuiltinFunction -> [GenericArg expr] -> expr pattern BuiltinFunc f args <- (getFunction -> Just (_, f, args)) where BuiltinFunc f args = mkFunction mempty f args -pattern IOp1 :: (HasStandardDataExpr expr) => BuiltinFunction -> expr -> expr +pattern IOp1 :: (HasStandardData expr) => BuiltinFunction -> expr -> expr pattern IOp1 op x <- BuiltinFunc op [RelevantExplicitArg _ x] where IOp1 op x = BuiltinFunc op [Arg mempty Explicit Relevant x] -pattern IOp2 :: (HasStandardDataExpr expr) => BuiltinFunction -> expr -> expr -> expr +pattern IOp2 :: (HasStandardData expr) => BuiltinFunction -> expr -> expr -> expr pattern IOp2 op x y <- BuiltinFunc op [RelevantExplicitArg _ x, RelevantExplicitArg _ y] where IOp2 op x y = BuiltinFunc op [Arg mempty Explicit Relevant x, Arg mempty Explicit Relevant y] -pattern IAnd :: (HasStandardDataExpr expr) => expr -> expr -> expr +pattern IAnd :: (HasStandardData expr) => expr -> expr -> expr pattern IAnd x y = IOp2 And x y -pattern IOr :: (HasStandardDataExpr expr) => expr -> expr -> expr +pattern IOr :: (HasStandardData expr) => expr -> expr -> expr pattern IOr x y = IOp2 Or x y -pattern INot :: (HasStandardDataExpr expr) => expr -> expr +pattern INot :: (HasStandardData expr) => expr -> expr pattern INot x = IOp1 Not x -pattern IIf :: (HasStandardDataExpr expr) => expr -> expr -> expr -> expr -> expr +pattern IIf :: (HasStandardData expr) => expr -> expr -> expr -> expr -> expr pattern IIf t c x y <- BuiltinFunc If [RelevantImplicitArg _ t, RelevantExplicitArg _ c, RelevantExplicitArg _ x, RelevantExplicitArg _ y] where IIf t c x y = BuiltinFunc If [Arg mempty (Implicit True) Relevant t, Arg mempty Explicit Relevant c, Arg mempty Explicit Relevant x, Arg mempty Explicit Relevant y] -pattern IOrderOp :: (HasStandardDataExpr expr) => OrderDomain -> OrderOp -> expr -> expr -> [GenericArg expr] -> expr +pattern IOrderOp :: (HasStandardData expr) => OrderDomain -> OrderOp -> expr -> expr -> [GenericArg expr] -> expr pattern IOrderOp dom op x y args <- BuiltinFunc (Order dom op) (reverse -> (argExpr -> y) : (argExpr -> x) : args) -pattern IOrder :: (HasStandardDataExpr expr) => OrderDomain -> OrderOp -> expr -> expr -> expr +pattern IOrder :: (HasStandardData expr) => OrderDomain -> OrderOp -> expr -> expr -> expr pattern IOrder dom op x y <- IOrderOp dom op x y _ -pattern IOrderRat :: (HasStandardDataExpr expr) => OrderOp -> expr -> expr -> expr +pattern IOrderRat :: (HasStandardData expr) => OrderOp -> expr -> expr -> expr pattern IOrderRat op x y = IOp2 (Order OrderRat op) x y -pattern IEqualOp :: (HasStandardDataExpr expr) => EqualityDomain -> EqualityOp -> expr -> expr -> [GenericArg expr] -> expr +pattern IEqualOp :: (HasStandardData expr) => EqualityDomain -> EqualityOp -> expr -> expr -> [GenericArg expr] -> expr pattern IEqualOp dom op x y args <- BuiltinFunc (Equals dom op) (reverse -> (argExpr -> y) : (argExpr -> x) : args) -pattern IEqual :: (HasStandardDataExpr expr) => EqualityDomain -> expr -> expr -> expr +pattern IEqual :: (HasStandardData expr) => EqualityDomain -> expr -> expr -> expr pattern IEqual dom x y <- IEqualOp dom Eq x y _ -pattern IEqualRatOp :: (HasStandardDataExpr expr) => EqualityOp -> expr -> expr -> expr +pattern IEqualRatOp :: (HasStandardData expr) => EqualityOp -> expr -> expr -> expr pattern IEqualRatOp op x y = IOp2 (Equals EqRat op) x y -pattern IEqualRat :: (HasStandardDataExpr expr) => expr -> expr -> expr +pattern IEqualRat :: (HasStandardData expr) => expr -> expr -> expr pattern IEqualRat x y = IEqualRatOp Eq x y -pattern INotEqual :: (HasStandardDataExpr expr) => EqualityDomain -> expr -> expr -> expr +pattern INotEqual :: (HasStandardData expr) => EqualityDomain -> expr -> expr -> expr pattern INotEqual dom x y <- IEqualOp dom Neq x y _ -pattern INeg :: (HasStandardDataExpr expr) => NegDomain -> expr -> expr +pattern INeg :: (HasStandardData expr) => NegDomain -> expr -> expr pattern INeg dom x = IOp1 (Neg dom) x -pattern IAdd :: (HasStandardDataExpr expr) => AddDomain -> expr -> expr -> expr +pattern IAdd :: (HasStandardData expr) => AddDomain -> expr -> expr -> expr pattern IAdd dom x y = IOp2 (Add dom) x y -pattern ISub :: (HasStandardDataExpr expr) => SubDomain -> expr -> expr -> expr +pattern ISub :: (HasStandardData expr) => SubDomain -> expr -> expr -> expr pattern ISub dom x y = IOp2 (Sub dom) x y -pattern IMul :: (HasStandardDataExpr expr) => MulDomain -> expr -> expr -> expr +pattern IMul :: (HasStandardData expr) => MulDomain -> expr -> expr -> expr pattern IMul dom x y = IOp2 (Mul dom) x y -pattern IDiv :: (HasStandardDataExpr expr) => DivDomain -> expr -> expr -> expr +pattern IDiv :: (HasStandardData expr) => DivDomain -> expr -> expr -> expr pattern IDiv dom x y = IOp2 (Div dom) x y -pattern IMax :: (HasStandardDataExpr expr) => expr -> expr -> expr +pattern IMax :: (HasStandardData expr) => expr -> expr -> expr pattern IMax x y = IOp2 MaxRat x y -pattern IMin :: (HasStandardDataExpr expr) => expr -> expr -> expr +pattern IMin :: (HasStandardData expr) => expr -> expr -> expr pattern IMin x y = IOp2 MinRat x y pattern VIndices :: - (HasStandardDataExpr expr) => + (HasStandardData expr) => expr -> expr pattern VIndices n <- BuiltinFunc Indices [argExpr -> n] pattern IAt :: - (HasStandardDataExpr expr) => + (HasStandardData expr) => GenericArg expr -> GenericArg expr -> expr -> @@ -389,7 +319,7 @@ pattern IAt :: pattern IAt t n xs i <- BuiltinFunc At [t, n, argExpr -> xs, argExpr -> i] pattern IFoldVector :: - (HasStandardDataExpr expr) => + (HasStandardData expr) => GenericArg expr -> GenericArg expr -> GenericArg expr -> @@ -402,7 +332,7 @@ pattern IFoldVector n a b f e xs <- BuiltinFunc FoldVector [n, a, b, argExpr -> IFoldVector n a b f e xs = BuiltinFunc FoldVector [n, a, b, Arg mempty Explicit Relevant f, Arg mempty Explicit Relevant e, Arg mempty Explicit Relevant xs] pattern IMapVector :: - (HasStandardDataExpr expr) => + (HasStandardData expr) => GenericArg expr -> GenericArg expr -> GenericArg expr -> @@ -412,7 +342,7 @@ pattern IMapVector :: pattern IMapVector n a b f xs <- BuiltinFunc MapVector [n, a, b, argExpr -> f, argExpr -> xs] pattern IZipWithVector :: - (HasStandardDataExpr expr) => + (HasStandardData expr) => GenericArg expr -> GenericArg expr -> GenericArg expr -> @@ -424,7 +354,7 @@ pattern IZipWithVector :: pattern IZipWithVector a b c n f xs ys <- BuiltinFunc ZipWithVector [a, b, c, n, argExpr -> f, argExpr -> xs, argExpr -> ys] pattern IInfiniteQuantifier :: - (HasStandardDataExpr expr) => + (HasStandardData expr) => Quantifier -> [GenericArg expr] -> expr -> @@ -436,14 +366,14 @@ pattern IInfiniteQuantifier q args lam <- BuiltinFunc (Quantifier q) (reverse (Arg mempty Explicit Relevant lam : args)) pattern IForall :: - (HasStandardDataExpr expr) => + (HasStandardData expr) => [GenericArg expr] -> expr -> expr pattern IForall args lam = IInfiniteQuantifier Forall args lam pattern IExists :: - (HasStandardDataExpr expr) => + (HasStandardData expr) => [GenericArg expr] -> expr -> expr @@ -452,18 +382,18 @@ pattern IExists args lam = IInfiniteQuantifier Exists args lam -------------------------------------------------------------------------------- -- Iector operation patterns -pattern IFreeVar :: (HasStandardDataExpr expr) => Identifier -> [GenericArg expr] -> expr +pattern IFreeVar :: (HasStandardData expr) => Identifier -> [GenericArg expr] -> expr pattern IFreeVar fn spine <- (getFreeVar -> Just (_, fn, spine)) where IFreeVar fn spine = mkFreeVar mempty fn spine -pattern IStandardLib :: (HasStandardDataExpr expr) => StdLibFunction -> [GenericArg expr] -> expr +pattern IStandardLib :: (HasStandardData expr) => StdLibFunction -> [GenericArg expr] -> expr pattern IStandardLib fn spine <- IFreeVar (findStdLibFunction -> Just fn) spine where IStandardLib fn spine = IFreeVar (identifierOf fn) spine pattern IVecEqSpine :: - (HasStandardDataExpr expr) => + (HasStandardData expr) => GenericArg expr -> GenericArg expr -> GenericArg expr -> @@ -476,7 +406,7 @@ pattern IVecEqSpine t1 t2 dim sol x y <- [t1, t2, dim, sol, argExpr -> x, argExp IVecEqSpine t1 t2 dim sol x y = [t1, t2, dim, sol, Arg mempty Explicit Relevant x, Arg mempty Explicit Relevant y] pattern IVecOp2Spine :: - (HasStandardDataExpr expr) => + (HasStandardData expr) => GenericArg expr -> GenericArg expr -> GenericArg expr -> @@ -488,26 +418,26 @@ pattern IVecOp2Spine :: pattern IVecOp2Spine t1 t2 t3 dim sol x y <- [t1, t2, t3, dim, sol, argExpr -> x, argExpr -> y] pattern IVecEqArgs :: - (HasStandardDataExpr expr) => + (HasStandardData expr) => expr -> expr -> [GenericArg expr] pattern IVecEqArgs x y <- IVecEqSpine _ _ _ _ x y -pattern IVectorEqualFull :: (HasStandardDataExpr expr) => [GenericArg expr] -> expr +pattern IVectorEqualFull :: (HasStandardData expr) => [GenericArg expr] -> expr pattern IVectorEqualFull spine = IStandardLib StdEqualsVector spine -pattern IVectorNotEqualFull :: (HasStandardDataExpr expr) => [GenericArg expr] -> expr +pattern IVectorNotEqualFull :: (HasStandardData expr) => [GenericArg expr] -> expr pattern IVectorNotEqualFull spine = IStandardLib StdNotEqualsVector spine -pattern IVectorEqual :: (HasStandardDataExpr expr) => expr -> expr -> expr +pattern IVectorEqual :: (HasStandardData expr) => expr -> expr -> expr pattern IVectorEqual x y <- IVectorEqualFull (IVecEqArgs x y) -pattern IVectorNotEqual :: (HasStandardDataExpr expr) => expr -> expr -> expr +pattern IVectorNotEqual :: (HasStandardData expr) => expr -> expr -> expr pattern IVectorNotEqual x y <- IVectorNotEqualFull (IVecEqArgs x y) pattern IVectorAdd :: - (HasStandardDataExpr expr) => + (HasStandardData expr) => GenericArg expr -> GenericArg expr -> GenericArg expr -> @@ -519,7 +449,7 @@ pattern IVectorAdd :: pattern IVectorAdd a b c n f x y <- IStandardLib StdAddVector [a, b, c, n, f, argExpr -> x, argExpr -> y] pattern IVectorSub :: - (HasStandardDataExpr expr) => + (HasStandardData expr) => GenericArg expr -> GenericArg expr -> GenericArg expr -> @@ -531,7 +461,7 @@ pattern IVectorSub :: pattern IVectorSub a b c n f x y <- IStandardLib StdSubVector [a, b, c, n, f, argExpr -> x, argExpr -> y] pattern IForeachIndex :: - (HasStandardDataExpr expr) => + (HasStandardData expr) => GenericArg expr -> expr -> expr -> @@ -539,12 +469,3 @@ pattern IForeachIndex :: pattern IForeachIndex t n fn <- IStandardLib StdForeachIndex [t, argExpr -> n, argExpr -> fn] where IForeachIndex t n fn = IStandardLib StdForeachIndex [t, Arg mempty Explicit Relevant n, Arg mempty Explicit Relevant fn] - --------------------------------------------------------------------------------- --- WHNFValue Function patterns - --- TODO this should really be removed. -pattern VBuiltinFunction :: (HasStandardData builtin) => BuiltinFunction -> Spine closure builtin -> Value closure builtin -pattern VBuiltinFunction f args <- VBuiltin (getBuiltinFunction -> Just f) args - where - VBuiltinFunction f args = VBuiltin (mkBuiltinFunction f) args diff --git a/vehicle/src/Vehicle/Data/Expr/Standard.hs b/vehicle/src/Vehicle/Data/Expr/Standard.hs index 1a08ec21e..f4dc40974 100644 --- a/vehicle/src/Vehicle/Data/Expr/Standard.hs +++ b/vehicle/src/Vehicle/Data/Expr/Standard.hs @@ -1,15 +1,19 @@ +{-# OPTIONS_GHC -Wno-orphans #-} + module Vehicle.Data.Expr.Standard where import Control.Monad.Identity (Identity (..)) import Control.Monad.Writer (MonadWriter (..), execWriter) +import Data.List.NonEmpty (NonEmpty (..)) import Data.List.NonEmpty qualified as NonEmpty import Data.Set (Set) import Data.Set qualified as Set import Prettyprinter (Doc) import Vehicle.Data.Builtin.Interface import Vehicle.Data.DeBruijn (Ix (..), substDBInto) +import Vehicle.Data.Expr.Interface import Vehicle.Prelude (layoutAsText) -import Vehicle.Syntax.AST as X +import Vehicle.Syntax.AST (Arg, Binder, Expr (..), Identifier, Provenance, argExpr, getBuiltinApp, getFreeVarApp, normAppList, stdlibIdentifier, pattern BuiltinExpr) ----------------------------------------------------------------------------- -- Traversing builtins @@ -136,3 +140,81 @@ convertExprBuiltins :: Expr var builtin2 convertExprBuiltins = mapBuiltins $ \p b args -> normAppList (convertBuiltin p b) args + +----------------------------------------------------------------------------- +-- Instances + +instance (BuiltinHasStandardTypes builtin) => HasStandardTypes (Expr var builtin) where + mkType p b = normAppList (Builtin p (mkBuiltinType b)) + getType e = case getBuiltinApp e of + Just (p, b, args) -> case getBuiltinType b of + Just t -> Just (p, t, args) + Nothing -> Nothing + _ -> Nothing + +instance (BuiltinHasStandardData builtin) => HasStandardData (Expr var builtin) where + mkFunction p b = normAppList (Builtin p (mkBuiltinFunction b)) + getFunction e = case getBuiltinApp e of + Just (p, b, args) -> case getBuiltinFunction b of + Just f -> Just (p, f, args) + Nothing -> Nothing + _ -> Nothing + + mkConstructor p b = normAppList (Builtin p (mkBuiltinConstructor b)) + getConstructor e = case getBuiltinApp e of + Just (p, b, args) -> case getBuiltinConstructor b of + Just f -> Just (p, f, args) + Nothing -> Nothing + _ -> Nothing + + mkFreeVar p ident = normAppList (FreeVar p ident) + getFreeVar e = case getFreeVarApp e of + Just (p, ident, args) -> Just (p, ident, args) + _ -> Nothing + + getTypeClassOp e = case getBuiltinApp e of + Just (p, b, args) -> case getBuiltinTypeClassOp b of + Just f -> Just (p, f, args) + Nothing -> Nothing + _ -> Nothing + +instance (BuiltinHasBoolLiterals builtin) => HasBoolLits (Expr var builtin) where + getBoolLit e = case e of + Builtin _ (getBoolBuiltinLit -> Just b) -> Just (mempty, b) + _ -> Nothing + mkBoolLit p x = Builtin p (mkBoolBuiltinLit x) + +instance (BuiltinHasIndexLiterals builtin) => HasIndexLits (Expr var builtin) where + getIndexLit e = case e of + Builtin _ (getIndexBuiltinLit -> Just i) -> Just (mempty, i) + _ -> Nothing + mkIndexLit p i = Builtin p (mkIndexBuiltinLit i) + +instance (BuiltinHasNatLiterals builtin) => HasNatLits (Expr var builtin) where + getNatLit e = case e of + Builtin _ (getNatBuiltinLit -> Just n) -> Just (mempty, n) + _ -> Nothing + mkNatLit p n = Builtin p (mkNatBuiltinLit n) + +instance (BuiltinHasRatLiterals builtin) => HasRatLits (Expr var builtin) where + getRatLit e = case e of + Builtin _ (getRatBuiltinLit -> Just r) -> Just (mempty, r) + _ -> Nothing + mkRatLit p r = Builtin p (mkRatBuiltinLit r) + +instance (BuiltinHasVecLiterals builtin) => HasStandardVecLits (Expr var builtin) where + getHomoVector = \case + BuiltinExpr _ (getVecBuiltinLit -> Just {}) (t :| xs) -> Just (t, xs) + _ -> Nothing + mkHomoVector t xs = BuiltinExpr mempty (mkVecBuiltinLit (length xs)) (t :| xs) + +instance (BuiltinHasListLiterals builtin) => HasStandardListLits (Expr var builtin) where + getNil = \case + BuiltinExpr p (isBuiltinNil -> True) [t] -> Just (p, t) + _ -> Nothing + mkNil t = BuiltinExpr mempty mkBuiltinNil [t] + + getCons = \case + BuiltinExpr p (isBuiltinCons -> True) [t, x, xs] -> Just (p, t, x, xs) + _ -> Nothing + mkCons t x xs = BuiltinExpr mempty mkBuiltinCons [t, x, xs] diff --git a/vehicle/src/Vehicle/Data/Expr/Normalised.hs b/vehicle/src/Vehicle/Data/Expr/Value.hs similarity index 63% rename from vehicle/src/Vehicle/Data/Expr/Normalised.hs rename to vehicle/src/Vehicle/Data/Expr/Value.hs index 906b6ed8a..0d5cb9ebd 100644 --- a/vehicle/src/Vehicle/Data/Expr/Normalised.hs +++ b/vehicle/src/Vehicle/Data/Expr/Value.hs @@ -1,11 +1,14 @@ -module Vehicle.Data.Expr.Normalised where +module Vehicle.Data.Expr.Value where import Control.Monad (void) import Data.Map (Map) import Data.Maybe (fromMaybe) import GHC.Generics import Vehicle.Compile.Context.Bound.Core +import Vehicle.Data.Builtin.Interface +import Vehicle.Data.Builtin.Standard.Core (BuiltinFunction) import Vehicle.Data.DeBruijn +import Vehicle.Data.Expr.Interface import Vehicle.Syntax.AST ----------------------------------------------------------------------------- @@ -190,3 +193,90 @@ traverseUnnormalised :: GluedExpr builtin -> m (GluedExpr builtin) traverseUnnormalised f (Glued u n) = Glued <$> f u <*> pure n + +----------------------------------------------------------------------------- +-- Instances + +instance (BuiltinHasStandardTypes builtin) => HasStandardTypes (Value closure builtin) where + mkType _p b = VBuiltin (mkBuiltinType b) + getType e = case e of + VBuiltin b args -> case getBuiltinType b of + Just t -> Just (mempty, t, args) + Nothing -> Nothing + _ -> Nothing + +instance (BuiltinHasStandardData builtin) => HasStandardData (Value closure builtin) where + mkFunction _p b = VBuiltin (mkBuiltinFunction b) + getFunction e = case e of + VBuiltin b args -> case getBuiltinFunction b of + Just t -> Just (mempty, t, args) + Nothing -> Nothing + _ -> Nothing + + mkConstructor _p b = VBuiltin (mkBuiltinConstructor b) + getConstructor e = case e of + VBuiltin b args -> case getBuiltinConstructor b of + Just t -> Just (mempty, t, args) + Nothing -> Nothing + _ -> Nothing + + mkFreeVar _p = VFreeVar + getFreeVar = \case + VFreeVar ident args -> Just (mempty, ident, args) + _ -> Nothing + + getTypeClassOp e = case e of + VBuiltin b args -> case getBuiltinTypeClassOp b of + Just op -> Just (mempty, op, args) + Nothing -> Nothing + _ -> Nothing + +instance (BuiltinHasBoolLiterals builtin) => HasBoolLits (Value closure builtin) where + getBoolLit = \case + VBuiltin (getBoolBuiltinLit -> Just b) [] -> Just (mempty, b) + _ -> Nothing + mkBoolLit _p b = VBuiltin (mkBoolBuiltinLit b) [] + +instance (BuiltinHasIndexLiterals builtin) => HasIndexLits (Value closure builtin) where + getIndexLit e = case e of + VBuiltin (getIndexBuiltinLit -> Just n) [] -> Just (mempty, n) + _ -> Nothing + mkIndexLit _p x = VBuiltin (mkIndexBuiltinLit x) mempty + +instance (BuiltinHasNatLiterals builtin) => HasNatLits (Value closure builtin) where + getNatLit e = case e of + VBuiltin (getNatBuiltinLit -> Just b) [] -> Just (mempty, b) + _ -> Nothing + mkNatLit _p x = VBuiltin (mkNatBuiltinLit x) mempty + +instance (BuiltinHasRatLiterals builtin) => HasRatLits (Value closure builtin) where + getRatLit e = case e of + VBuiltin (getRatBuiltinLit -> Just b) [] -> Just (mempty, b) + _ -> Nothing + mkRatLit _p x = VBuiltin (mkRatBuiltinLit x) mempty + +instance (BuiltinHasVecLiterals builtin) => HasStandardVecLits (Value closure builtin) where + getHomoVector = \case + VBuiltin (getVecBuiltinLit -> Just {}) (t : xs) -> Just (t, xs) + _ -> Nothing + mkHomoVector t xs = VBuiltin (mkVecBuiltinLit (length xs)) (t : xs) + +instance (BuiltinHasListLiterals builtin) => HasStandardListLits (Value closure builtin) where + getNil = \case + VBuiltin (isBuiltinNil -> True) [t] -> Just (mempty, t) + _ -> Nothing + mkNil t = VBuiltin mkBuiltinNil [t] + + getCons = \case + VBuiltin (isBuiltinCons -> True) [t, x, xs] -> Just (mempty, t, x, xs) + _ -> Nothing + mkCons t x xs = VBuiltin mkBuiltinCons [t, x, xs] + +-------------------------------------------------------------------------------- +-- WHNFValue Function patterns + +-- TODO this should really be removed. +pattern VBuiltinFunction :: (BuiltinHasStandardData builtin) => BuiltinFunction -> Spine closure builtin -> Value closure builtin +pattern VBuiltinFunction f args <- VBuiltin (getBuiltinFunction -> Just f) args + where + VBuiltinFunction f args = VBuiltin (mkBuiltinFunction f) args diff --git a/vehicle/src/Vehicle/Data/Hashing.hs b/vehicle/src/Vehicle/Data/Hashing.hs index a81a683ac..ab4e2eb70 100644 --- a/vehicle/src/Vehicle/Data/Hashing.hs +++ b/vehicle/src/Vehicle/Data/Hashing.hs @@ -5,11 +5,11 @@ module Vehicle.Data.Hashing () where import Data.Hashable (Hashable (..)) -- import GHC.Generics (Generic) --- import Vehicle.Data.Expr.Normalised +-- import Vehicle.Data.Expr.Value import GHC.Generics (Generic) import Vehicle.Data.DeBruijn -import Vehicle.Data.Expr.Normalised (Value, WHNFClosure) +import Vehicle.Data.Expr.Value (Value, WHNFClosure) import Vehicle.Syntax.AST -- We used to have full blown alpha-equivalence based on co-deBruijn indices diff --git a/vehicle/src/Vehicle/Data/QuantifiedVariable.hs b/vehicle/src/Vehicle/Data/QuantifiedVariable.hs index ee0fd0534..deb664d3f 100644 --- a/vehicle/src/Vehicle/Data/QuantifiedVariable.hs +++ b/vehicle/src/Vehicle/Data/QuantifiedVariable.hs @@ -11,7 +11,7 @@ import Numeric (showFFloat) import Prettyprinter (brackets) import Vehicle.Data.DeBruijn import Vehicle.Data.Expr.Interface -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value import Vehicle.Data.Tensor (RationalTensor) import Vehicle.Prelude import Vehicle.Syntax.AST diff --git a/vehicle/src/Vehicle/Prelude/Warning.hs b/vehicle/src/Vehicle/Prelude/Warning.hs index 6cc9e0107..326ff6263 100644 --- a/vehicle/src/Vehicle/Prelude/Warning.hs +++ b/vehicle/src/Vehicle/Prelude/Warning.hs @@ -14,7 +14,7 @@ import Data.Set qualified as Set (singleton) import Vehicle.Compile.Context.Bound.Core import Vehicle.Data.Builtin.Loss.Core (LossBuiltin) import Vehicle.Data.Builtin.Tensor -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value import Vehicle.Data.QuantifiedVariable import Vehicle.Libraries.StandardLibrary.Definitions import Vehicle.Resource (ExternalResource) diff --git a/vehicle/tests/unit/Vehicle/Test/Unit/Compile/Normalisation.hs b/vehicle/tests/unit/Vehicle/Test/Unit/Compile/Normalisation.hs index 56dc1ab23..9b053aad4 100644 --- a/vehicle/tests/unit/Vehicle/Test/Unit/Compile/Normalisation.hs +++ b/vehicle/tests/unit/Vehicle/Test/Unit/Compile/Normalisation.hs @@ -13,7 +13,7 @@ import Vehicle.Compile.Prelude import Vehicle.Compile.Print (prettyVerbose) import Vehicle.Data.Builtin.Standard import Vehicle.Data.Expr.Interface -import Vehicle.Data.Expr.Normalised +import Vehicle.Data.Expr.Value import Vehicle.Test.Unit.Common (unitTestCase) normalisationTests :: TestTree diff --git a/vehicle/vehicle.cabal b/vehicle/vehicle.cabal index 171e8a581..5439c2670 100644 --- a/vehicle/vehicle.cabal +++ b/vehicle/vehicle.cabal @@ -154,8 +154,8 @@ library Vehicle.Data.Expr.DSL Vehicle.Data.Expr.Interface Vehicle.Data.Expr.Linear - Vehicle.Data.Expr.Normalised Vehicle.Data.Expr.Standard + Vehicle.Data.Expr.Value Vehicle.Data.Hashing Vehicle.Data.Tensor Vehicle.Export @@ -180,6 +180,7 @@ library Vehicle.Backend.Agda.Interact Vehicle.Backend.LossFunction Vehicle.Backend.LossFunction.Core + Vehicle.Backend.LossFunction.Domain Vehicle.Backend.LossFunction.Logics Vehicle.Backend.Queries Vehicle.Backend.Queries.ConstraintSearch