diff --git a/vehicle-python/pyproject.toml b/vehicle-python/pyproject.toml index 80de92829..82a6fe7cf 100644 --- a/vehicle-python/pyproject.toml +++ b/vehicle-python/pyproject.toml @@ -15,7 +15,7 @@ readme = 'README.md' license = { file = 'LICENSE' } dynamic = ["version"] requires-python = ">=3.9,<3.14" -dependencies = ["typing_extensions >=4.6,<5"] +dependencies = ["typing_extensions >=4.6,<5", "jaxtyping>=0.2, <0.3"] [project.optional-dependencies] test = ["pytest >=7.1,<9", "packaging >=23", "pygments >=2.14, <3"] diff --git a/vehicle-python/src/vehicle_lang/ast/__init__.py b/vehicle-python/src/vehicle_lang/ast/__init__.py index 71f5e44a9..b9063b0c1 100644 --- a/vehicle-python/src/vehicle_lang/ast/__init__.py +++ b/vehicle-python/src/vehicle_lang/ast/__init__.py @@ -386,21 +386,6 @@ def __init__(self) -> None: raise TypeError("Cannot instantiate abstract class Expression") -@dataclass(frozen=True) -class App(Expression): - provenance: Provenance = field(repr=False) - function: Expression - arguments: Sequence[Expression] - - -@dataclass(frozen=True) -class PartialApp(Expression): - provenance: Provenance = field(repr=False) - arity: int - function: Expression - arguments: Sequence[Expression] - - @dataclass(frozen=True) class Binder(AST): provenance: Provenance = field(repr=False) @@ -409,49 +394,44 @@ class Binder(AST): @dataclass(frozen=True) -class BoundVar(Expression): - provenance: Provenance = field(repr=False) - name: Name - - -@dataclass(frozen=True) -class Builtin(Expression): +class Pi(Expression): provenance: Provenance = field(repr=False) - builtin: Union[BuiltinConstant, BuiltinFunction, BuiltinLiteral, BuiltinType] + binder: Binder + body: Expression @dataclass(frozen=True) -class FreeVar(Expression): +class Lam(Expression): provenance: Provenance = field(repr=False) - name: Name + binder: Binder + body: Expression @dataclass(frozen=True) -class Lam(Expression): +class App(Expression): provenance: Provenance = field(repr=False) - binders: Sequence[Binder] - body: Expression + function: Expression + arguments: Sequence[Expression] @dataclass(frozen=True) -class Let(Expression): +class PartialApp(Expression): provenance: Provenance = field(repr=False) - bound: Expression - binder: Binder - body: Expression + arity: int + function: Expression + arguments: Sequence[Expression] @dataclass(frozen=True) -class Pi(Expression): +class Var(Expression): provenance: Provenance = field(repr=False) - binder: Binder - body: Expression + name: Name @dataclass(frozen=True) -class Universe(Expression): +class Builtin(Expression): provenance: Provenance = field(repr=False) - level: UniverseLevel + builtin: Union[BuiltinConstant, BuiltinFunction, BuiltinLiteral, BuiltinType] ################################################################################ diff --git a/vehicle-python/src/vehicle_lang/compile/abc/translation.py b/vehicle-python/src/vehicle_lang/compile/abc/translation.py index 019f9f2af..4d9f67ca2 100644 --- a/vehicle-python/src/vehicle_lang/compile/abc/translation.py +++ b/vehicle-python/src/vehicle_lang/compile/abc/translation.py @@ -66,54 +66,30 @@ def translate_expression( ) -> vcl_var.Expression: if isinstance(expression, vcl_ast.App): return self.translate_App(expression) - if isinstance(expression, vcl_ast.BoundVar): - return self.translate_BoundVar(expression) + if isinstance(expression, vcl_ast.Var): + return self.translate_Var(expression) if isinstance(expression, vcl_ast.Builtin): return self.translate_Builtin(expression) - if isinstance(expression, vcl_ast.FreeVar): - return self.translate_FreeVar(expression) if isinstance(expression, vcl_ast.Lam): return self.translate_Lam(expression) - if isinstance(expression, vcl_ast.Let): - return self.translate_Let(expression) if isinstance(expression, vcl_ast.PartialApp): return self.translate_PartialApp(expression) if isinstance(expression, vcl_ast.Pi): return self.translate_Pi(expression) - if isinstance(expression, vcl_ast.Universe): - return self.translate_Universe(expression) raise NotImplementedError(type(expression).__name__) @abstractmethod def translate_App(self, expression: vcl_ast.App) -> vcl_var.Expression: ... @abstractmethod - def translate_BoundVar( - self, expression: vcl_ast.BoundVar - ) -> vcl_var.Expression: ... + def translate_Var(self, expression: vcl_ast.Var) -> vcl_var.Expression: ... @abstractmethod def translate_Builtin(self, expression: vcl_ast.Builtin) -> vcl_var.Expression: ... - @abstractmethod - def translate_FreeVar(self, expression: vcl_ast.FreeVar) -> vcl_var.Expression: ... - @abstractmethod def translate_Lam(self, expression: vcl_ast.Lam) -> vcl_var.Expression: ... - def translate_Let(self, expression: vcl_ast.Let) -> vcl_var.Expression: - return self.translate_expression( - vcl_ast.App( - provenance=expression.provenance, - function=vcl_ast.Lam( - provenance=expression.provenance, - binders=(expression.binder,), - body=expression.body, - ), - arguments=[expression.bound], - ) - ) - @abstractmethod def translate_PartialApp( self, expression: vcl_ast.PartialApp @@ -121,8 +97,3 @@ def translate_PartialApp( @abstractmethod def translate_Pi(self, expression: vcl_ast.Pi) -> vcl_var.Expression: ... - - @abstractmethod - def translate_Universe( - self, expression: vcl_ast.Universe - ) -> vcl_var.Expression: ... diff --git a/vehicle-python/src/vehicle_lang/compile/python/__init__.py b/vehicle-python/src/vehicle_lang/compile/python/__init__.py index 67ab38197..b3b10cc99 100644 --- a/vehicle-python/src/vehicle_lang/compile/python/__init__.py +++ b/vehicle-python/src/vehicle_lang/compile/python/__init__.py @@ -94,18 +94,19 @@ def translate_declarations( self.ignored_types.append(name) def translate_DefFunction(self, declaration: vcl.DefFunction) -> py.stmt: - if isinstance(declaration.body, vcl.Lam): + body = declaration.body + binders = [] + while isinstance(body, vcl.Lam): + binders.append(self.translate_binder(body.binder)) + body = body.body + + if binders: return py.FunctionDef( name=declaration.name, - args=py_binder( - *( - self.translate_binder(binder) - for binder in declaration.body.binders - ) - ), + args=py_binder(*binders), body=[ py.Return( - value=self.translate_expression(declaration.body.body), + value=self.translate_expression(body), **asdict(declaration.provenance), ) ], @@ -184,11 +185,9 @@ def translate_App(self, expression: vcl.App) -> py.expr: meetOrJoin, loss = expression.arguments if not isinstance(loss, vcl.Lam): raise VehicleOptimiseTypeError(expression) - if len(loss.binders) != 1: - raise VehicleOptimiseTypeError(expression) # NOTE: We extract the name of the bound variable from the lambda, # which should be the _second_ argument. - name = loss.binders[0].name + name = loss.binder.name return py_app( py_builtin( builtin=expression.function.builtin.__class__.__name__, @@ -227,7 +226,7 @@ def translate_App(self, expression: vcl.App) -> py.expr: provenance=expression.provenance, ) - def translate_BoundVar(self, expression: vcl.BoundVar) -> py.expr: + def translate_Var(self, expression: vcl.Var) -> py.expr: return py_name(expression.name, provenance=expression.provenance) def translate_Builtin(self, expression: vcl.Builtin) -> py.expr: @@ -292,16 +291,9 @@ def translate_Builtin(self, expression: vcl.Builtin) -> py.expr: provenance=expression.provenance, ) - def translate_FreeVar(self, expression: vcl.FreeVar) -> py.expr: - # NOTE: We ignore any declaration where translation touches a type. - if expression.name in self.ignored_types: - raise EraseType() - else: - return py_name(expression.name, provenance=expression.provenance) - def translate_Lam(self, expression: vcl.Lam) -> py.expr: return py.Lambda( - args=py_binder(*(map(self.translate_binder, expression.binders))), + args=py_binder(self.translate_binder(expression.binder)), body=self.translate_expression(expression.body), **asdict(expression.provenance), ) @@ -316,9 +308,6 @@ def translate_PartialApp(self, expression: vcl.PartialApp) -> py.expr: provenance=expression.provenance, ) - def translate_Universe(self, _expression: vcl.Universe) -> py.expr: - raise EraseType() - def py_name(name: vcl.Name, *, provenance: vcl.Provenance) -> py.Name: """Make a name.""" diff --git a/vehicle/src/Vehicle/Backend/LossFunction/JSON.hs b/vehicle/src/Vehicle/Backend/LossFunction/JSON.hs index 50d4ac8e5..644e786f7 100644 --- a/vehicle/src/Vehicle/Backend/LossFunction/JSON.hs +++ b/vehicle/src/Vehicle/Backend/LossFunction/JSON.hs @@ -16,7 +16,8 @@ import Vehicle.Compile.Arity import Vehicle.Compile.Context.Name import Vehicle.Compile.Error import Vehicle.Compile.Normalise.NBE (eval) -import Vehicle.Compile.Prelude (Decl, Doc, Expr (..), GenericArg (..), Ix (..), ModulePath (..), Name, Position, Prog, Provenance (..), Range (..), filterOutNonExplicitArgs, getBinderName, mkExplicitBinder, normAppList) +import Vehicle.Compile.Prelude (Doc, HasProvenance (..), Ix (..), ModulePath (..), Name, Position, Provenance (..), Range (..), filterOutNonExplicitArgs, getBinderName, mkExplicitBinder, normAppList) +import Vehicle.Compile.Prelude qualified as S (Binder, Decl, Expr (..), GenericDecl (..), GenericProg (..), Prog) import Vehicle.Compile.Print import Vehicle.Compile.Type.Irrelevance (removeIrrelevantCodeFromProg) import Vehicle.Data.Builtin.Loss (DimensionDataBuiltin, DimensionTypeBuiltin, LossTensorBuiltin (..), RatTensorBuiltin) @@ -25,45 +26,49 @@ import Vehicle.Data.Builtin.Tensor () import Vehicle.Data.Code.Interface import Vehicle.Data.Code.Value import Vehicle.Data.Tensor (Tensor, mapTensor) -import Vehicle.Prelude (Annotation (..), GenericDecl (..), GenericProg (..), HasName (..), HasType (..), Identifier (..), Position (..), explicit, indent, jsonOptions, line, squotes) +import Vehicle.Prelude (Annotation (..), GenericArg (..), HasName (..), HasType (..), Identifier (..), Position (..), explicit, indent, jsonOptions, line, squotes) import Vehicle.Prelude.Logging.Class import Vehicle.Syntax.Prelude (developerError) -------------------------------------------------------------------------------- -- Public method -convertToJSONProg :: (MonadCompile m) => Prog LossTensorBuiltin -> m JProg +convertToJSONProg :: (MonadCompile m) => S.Prog LossTensorBuiltin -> m JProg convertToJSONProg prog = logCompilerPass MinDetail currentPass $ do relevantProg <- removeIrrelevantCodeFromProg prog runFreshNameContextT $ convertProg relevantProg -convertFromJSONProg :: JProg -> Prog LossTensorBuiltin +convertFromJSONProg :: JProg -> S.Prog LossTensorBuiltin convertFromJSONProg = fromJProg -------------------------------------------------------------------------------- -- The AST exported to JSON newtype JProg - = JProg [JDecl] + = Main [JDecl] deriving (Generic) data JDecl - = JDecl Name JExpr JExpr + = DefFunction Provenance Name JExpr JExpr deriving (Generic) +data JBinder + = Binder Provenance Name JExpr + deriving (Show, Generic) + data JExpr = -- Types - RatType + Pi JExpr JExpr + | Lam JBinder JExpr + | Var Name [JExpr] + | RatType | TensorType JExpr | DimensionType | DimensionsType | DimensionIndexType - | Fun JExpr JExpr | -- Rational tensors - Lambda Name JExpr JExpr - | Var Name [JExpr] - | RatTensor (Tensor Rat) + RatTensor (Tensor Rat) | RatLiteral Rat | NegRatTensor JExpr | AddRatTensor JExpr JExpr @@ -122,6 +127,9 @@ instance ToJSON JDecl where instance ToJSON JExpr where toJSON = genericToJSON jsonOptions +instance ToJSON JBinder where + toJSON = genericToJSON jsonOptions + instance ToJSON Position where toJSON = genericToJSON jsonOptions @@ -143,18 +151,18 @@ type MonadJSON m = MonadNameContext m ) -convertProg :: (MonadJSON m) => Prog LossTensorBuiltin -> m JProg -convertProg (Main decls) = JProg <$> traverse convertDecl decls +convertProg :: (MonadJSON m) => S.Prog LossTensorBuiltin -> m JProg +convertProg (S.Main decls) = Main <$> traverse convertDecl decls -convertDecl :: (MonadJSON m) => Decl LossTensorBuiltin -> m JDecl +convertDecl :: (MonadJSON m) => S.Decl LossTensorBuiltin -> m JDecl convertDecl = \case - DefAbstract {} -> compilerDeveloperError "Found abstract definition when converting to JSON" - DefFunction _ ident _ typ body -> do + S.DefAbstract {} -> compilerDeveloperError "Found abstract definition when converting to JSON" + S.DefFunction p ident _ typ body -> do typ' <- convertExpr mempty typ expr' <- convertExpr mempty body - return $ JDecl (nameOf ident) typ' expr' + return $ DefFunction p (nameOf ident) typ' expr' -convertExpr :: (MonadJSON m) => BoundEnv LossTensorBuiltin -> Expr LossTensorBuiltin -> m JExpr +convertExpr :: (MonadJSON m) => BoundEnv LossTensorBuiltin -> S.Expr LossTensorBuiltin -> m JExpr convertExpr env body = convertValue =<< eval mempty env body convertValue :: (MonadJSON m) => Value LossTensorBuiltin -> m JExpr @@ -164,17 +172,14 @@ convertValue expr = do VMeta {} -> resolutionError currentPass "VMeta" VFreeVar {} -> resolutionError currentPass "VFreeVar" VUniverse {} -> resolutionError currentPass "Universe" - VLam binder (Closure env body) -> do - let name = getBinderName binder - typ' <- convertValue (typeOf binder) - lv <- getBinderDepth - let newEnv = extendEnvWithBound lv binder env - body' <- addNameToContext binder $ convertExpr newEnv body - return $ Lambda name typ' body' + VLam binder closure -> do + binder' <- convertBinder binder + body' <- convertClosure binder closure + return $ Lam binder' body' VPi binder body -> do typ' <- convertValue (typeOf binder) body' <- addNameToContext binder $ convertValue body - return $ Fun typ' body' + return $ Pi typ' body' VBuiltin b spine -> convertBuiltin b $ filterOutNonExplicitArgs spine VBoundVar v spine -> do name <- lvToProperName mempty v @@ -183,6 +188,19 @@ convertValue expr = do showExit result return result +convertBinder :: (MonadJSON m) => VBinder LossTensorBuiltin -> m JBinder +convertBinder binder = do + let p = provenanceOf binder + let name = getBinderName binder + typ' <- convertValue (typeOf binder) + return $ Binder p name typ' + +convertClosure :: (MonadJSON m) => VBinder LossTensorBuiltin -> Closure LossTensorBuiltin -> m JExpr +convertClosure binder (Closure env body) = do + lv <- getBinderDepth + let newEnv = extendEnvWithBound lv binder env + addNameToContext binder $ convertExpr newEnv body + convertBuiltin :: (MonadJSON m) => LossTensorBuiltin -> [Value LossTensorBuiltin] -> m JExpr convertBuiltin b spine = case b of LossTensorRat op -> case op of @@ -284,35 +302,34 @@ showExit _e = do -------------------------------------------------------------------------------- -- Conversion back (for printing purposes) -fromJProg :: JProg -> Prog LossTensorBuiltin +fromJProg :: JProg -> S.Prog LossTensorBuiltin fromJProg = \case - JProg decls -> Main (fmap fromJDecl decls) + Main decls -> S.Main (fmap fromJDecl decls) -fromJDecl :: JDecl -> Decl LossTensorBuiltin +fromJDecl :: JDecl -> S.Decl LossTensorBuiltin fromJDecl = \case - JDecl name typ body -> + DefFunction p name typ body -> runFreshNameContext $ do typ' <- fromJExpr typ body' <- fromJExpr body let ident = Identifier (ModulePath []) name - return $ DefFunction mempty ident [AnnProperty] typ' body' + return $ S.DefFunction p ident [AnnProperty] typ' body' -fromJExpr :: (MonadNameContext m) => JExpr -> m (Expr LossTensorBuiltin) +fromJExpr :: (MonadNameContext m) => JExpr -> m (S.Expr LossTensorBuiltin) fromJExpr = \case - Lambda name typ body -> do - typ' <- fromJExpr typ - let binder' = mkExplicitBinder typ' (Just name) + Lam binder body -> do + binder' <- fromJBinder binder body' <- addNameToContext binder' (fromJExpr body) - return $ Lam mempty binder' body' - Fun input output -> do + return $ S.Lam mempty binder' body' + Pi input output -> do input' <- fromJExpr input let binder' = mkExplicitBinder input' Nothing - Pi mempty binder' <$> fromJExpr output + S.Pi mempty binder' <$> fromJExpr output Var name spine -> do nameCtx <- getNameContext let ix = maybe (developerError ("ill-scoped JExpr, no variable" <+> squotes (pretty name))) Ix (elemIndex (Just name) nameCtx) spine' <- traverse fromJExpr spine - return $ normAppList (BoundVar mempty ix) (fmap explicit spine') + return $ normAppList (S.BoundVar mempty ix) (fmap explicit spine') RatType -> fromRatOp L.RatType [] TensorType t -> fromDimType L.TensorType [t] DimensionType -> fromDimType L.DimensionType [] @@ -341,11 +358,16 @@ fromJExpr = \case ConstTensor c ds -> fromDimData L.ConstTensor [c, ds] StackTensor xs -> fromDimData (L.StackTensor (length xs)) xs -fromRatOp :: (MonadNameContext m) => RatTensorBuiltin -> [JExpr] -> m (Expr LossTensorBuiltin) +fromJBinder :: (MonadNameContext m) => JBinder -> m (S.Binder LossTensorBuiltin) +fromJBinder (Binder _ name typ) = do + typ' <- fromJExpr typ + return $ mkExplicitBinder typ' (Just name) + +fromRatOp :: (MonadNameContext m) => RatTensorBuiltin -> [JExpr] -> m (S.Expr LossTensorBuiltin) fromRatOp op xs = IRatTensorOp op . fmap explicit <$> traverse fromJExpr xs -fromDimType :: (MonadNameContext m) => DimensionTypeBuiltin -> [JExpr] -> m (Expr LossTensorBuiltin) +fromDimType :: (MonadNameContext m) => DimensionTypeBuiltin -> [JExpr] -> m (S.Expr LossTensorBuiltin) fromDimType op xs = IDimensionTypeOp op . fmap explicit <$> traverse fromJExpr xs -fromDimData :: (MonadNameContext m) => DimensionDataBuiltin -> [JExpr] -> m (Expr LossTensorBuiltin) +fromDimData :: (MonadNameContext m) => DimensionDataBuiltin -> [JExpr] -> m (S.Expr LossTensorBuiltin) fromDimData op xs = IDimensionDataOp op . fmap explicit <$> traverse fromJExpr xs