@@ -7,8 +7,12 @@ import Data.Set qualified as Set
7
7
import GHC.Generics (Generic )
8
8
import Vehicle.Backend.Prelude (DifferentiableLogicID )
9
9
import Vehicle.Compile.Prelude
10
+ import Vehicle.Data.Builtin.Interface
10
11
import Vehicle.Data.Builtin.Loss
12
+ import Vehicle.Data.Builtin.Standard
13
+ import Vehicle.Data.Code.Interface
11
14
import Vehicle.Data.Code.Value (BoundEnv , Closure (.. ), VBinder , Value (.. ))
15
+ import Vehicle.Data.Tensor
12
16
import Vehicle.Libraries.StandardLibrary.Definitions (StdLibFunction (.. ))
13
17
14
18
--------------------------------------------------------------------------------
@@ -62,49 +66,49 @@ type CompiledDifferentiableLogic = (DifferentiableLogicID, DifferentiableLogicIm
62
66
63
67
--------------------------------------------------------------------------------
64
68
-- Views on Boolean Tensors
65
- {-
69
+
66
70
-- | A view on all possible expressions that can have type `Tensor Bool`.
67
71
data BoolTensorView expr
68
72
= VConstBoolTensor expr expr
69
73
| VBoolTensor (Tensor Bool )
70
- | VStackBoolTensor Int [expr]
71
- | VAndTensor expr expr
72
- | VOrTensor expr expr
73
- | VNotTensor expr
74
- | VOrderRatTensor OrderOp expr expr
75
- | VEqualsRatTensor EqualityOp expr expr
76
- | VQuantifyRatTensor Quantifier expr
77
- | VReduceAndTensor expr
78
- | VReduceOrTensor expr
74
+ | VStackBoolTensor Int ( GenericArg expr ) [expr ]
75
+ | VAndTensor ( GenericArg expr ) expr expr
76
+ | VOrTensor ( GenericArg expr ) expr expr
77
+ | VNotTensor ( GenericArg expr ) expr
78
+ | VOrderRatTensor OrderOp ( GenericArg expr ) expr expr
79
+ | VEqualsRatTensor EqualityOp ( GenericArg expr ) expr expr
80
+ | VQuantifyRatTensor Quantifier ( GenericArg expr ) expr
81
+ | VReduceAndTensor ( GenericArg expr ) expr
82
+ | VReduceOrTensor ( GenericArg expr ) expr
79
83
80
84
fromBoolTensorView :: (HasBoolTensors expr , HasDimensionData expr ) => BoolTensorView expr -> expr
81
85
fromBoolTensorView = \ case
82
86
VBoolTensor y -> INullaryBoolTensorOp (BoolTensor y)
83
- VAndTensor x y -> IBoolTensorOp AndBoolTensor ( explicit <$> [ x, y])
84
- VOrTensor x y -> IBoolTensorOp OrBoolTensor ( explicit <$> [ x, y])
85
- VNotTensor x -> IBoolTensorOp NotBoolTensor ( explicit <$> [x])
86
- VOrderRatTensor op x y -> IBoolTensorOp (OrderRatTensor op) ( explicit <$> [ x, y])
87
- VEqualsRatTensor op x y -> IBoolTensorOp (EqualsRatTensor op) ( explicit <$> [ x, y])
88
- VQuantifyRatTensor op x -> IBoolTensorOp (QuantifyRatTensor op) ( explicit <$> [x])
89
- VReduceAndTensor x -> IBoolTensorOp ReduceAndTensor ( explicit <$> [x])
90
- VReduceOrTensor x -> IBoolTensorOp ReduceOrTensor ( explicit <$> [x])
91
- VConstBoolTensor x y -> IBoolConstTensor x y
92
- VStackBoolTensor n xs -> IDimensionDataOp (StackTensor n) (explicit <$> xs)
87
+ VAndTensor dims x y -> IBoolTensorOp AndBoolTensor [dims, explicit x, explicit y]
88
+ VOrTensor dims x y -> IBoolTensorOp OrBoolTensor [dims, explicit x, explicit y]
89
+ VNotTensor dims x -> IBoolTensorOp NotBoolTensor [dims, explicit x]
90
+ VOrderRatTensor op dims x y -> IBoolTensorOp (OrderRatTensor op) [dims, explicit x, explicit y]
91
+ VEqualsRatTensor op dims x y -> IBoolTensorOp (EqualsRatTensor op) [dims, explicit x, explicit y]
92
+ VQuantifyRatTensor op dims x -> IBoolTensorOp (QuantifyRatTensor op) [dims, explicit x]
93
+ VReduceAndTensor dims x -> IBoolTensorOp ReduceAndTensor [dims, explicit x]
94
+ VReduceOrTensor dims x -> IBoolTensorOp ReduceOrTensor [dims, explicit x]
95
+ VConstBoolTensor x dims -> IDimensionDataOp ConstTensor [implicit IBoolElementType , explicit x, explicit dims]
96
+ VStackBoolTensor n elemDims xs -> IDimensionDataOp (StackTensor n) (implicit IBoolElementType : elemDims : ( explicit <$> xs) )
93
97
94
98
toBoolTensorView :: (HasDimensionData expr , HasBoolTensors expr ) => expr -> BoolTensorView expr
95
99
toBoolTensorView expr = case getBoolTensorOp expr of
96
100
Just (BoolTensor b, [] ) -> VBoolTensor b
97
- Just (AndBoolTensor, [x, y]) -> VAndTensor (argExpr x) (argExpr y)
98
- Just (OrBoolTensor, [x, y]) -> VOrTensor (argExpr x) (argExpr y)
99
- Just (NotBoolTensor, [x]) -> VNotTensor (argExpr x)
100
- Just (OrderRatTensor op, [x, y]) -> VOrderRatTensor op (argExpr x) (argExpr y)
101
- Just (EqualsRatTensor op, [x, y]) -> VEqualsRatTensor op (argExpr x) (argExpr y)
102
- Just (QuantifyRatTensor op, [x]) -> VQuantifyRatTensor op (argExpr x)
103
- Just (ReduceAndTensor, [x]) -> VReduceAndTensor (argExpr x)
104
- Just (ReduceOrTensor, [x]) -> VReduceOrTensor (argExpr x)
101
+ Just (AndBoolTensor , [dims, x, y]) -> VAndTensor dims (argExpr x) (argExpr y)
102
+ Just (OrBoolTensor , [dims, x, y]) -> VOrTensor dims (argExpr x) (argExpr y)
103
+ Just (NotBoolTensor , [dims, x]) -> VNotTensor dims (argExpr x)
104
+ Just (OrderRatTensor op, [dims, x, y]) -> VOrderRatTensor op dims (argExpr x) (argExpr y)
105
+ Just (EqualsRatTensor op, [dims, x, y]) -> VEqualsRatTensor op dims (argExpr x) (argExpr y)
106
+ Just (QuantifyRatTensor op, [dims, x]) -> VQuantifyRatTensor op dims (argExpr x)
107
+ Just (ReduceAndTensor , [dims, x]) -> VReduceAndTensor dims (argExpr x)
108
+ Just (ReduceOrTensor , [dims, x]) -> VReduceOrTensor dims (argExpr x)
105
109
Nothing -> case getDimensionDataOp expr of
106
- Just (ConstTensor, [x, y ]) -> VConstBoolTensor (argExpr x) (argExpr y )
107
- Just (StackTensor n, args) -> VStackBoolTensor n (fmap argExpr args)
110
+ Just (ConstTensor , [argExpr -> IBoolElementType , x, dims ]) -> VConstBoolTensor (argExpr x) (argExpr dims )
111
+ Just (StackTensor n, (argExpr -> IBoolElementType ) : elemDims : args) -> VStackBoolTensor n elemDims (fmap argExpr args)
108
112
_ -> developerError " ill-typed BoolTensor expression"
109
113
_ -> developerError " ill-typed BoolTensor expression"
110
114
@@ -116,60 +120,60 @@ data RatTensorView expr
116
120
= VConstRatTensor expr expr
117
121
| VRatTensor (Tensor Rational )
118
122
| VRatTensorVar Lv
119
- | VStackRatTensor Int [expr]
120
- | VNegRatTensor expr
121
- | VAddRatTensor expr expr
122
- | VSubRatTensor expr expr
123
- | VMulRatTensor expr expr
124
- | VDivRatTensor expr expr
125
- | VMinRatTensor expr expr
126
- | VMaxRatTensor expr expr
127
- | VReduceAddRatTensor expr
128
- | VReduceMulRatTensor expr
129
- | VReduceMinRatTensor expr
130
- | VReduceMaxRatTensor expr
131
- | VSearchRatTensor expr expr expr expr
123
+ | VStackRatTensor Int ( GenericArg expr ) [expr ]
124
+ | VNegRatTensor ( GenericArg expr ) expr
125
+ | VAddRatTensor ( GenericArg expr ) expr expr
126
+ | VSubRatTensor ( GenericArg expr ) expr expr
127
+ | VMulRatTensor ( GenericArg expr ) expr expr
128
+ | VDivRatTensor ( GenericArg expr ) expr expr
129
+ | VMinRatTensor ( GenericArg expr ) expr expr
130
+ | VMaxRatTensor ( GenericArg expr ) expr expr
131
+ | VReduceAddRatTensor ( GenericArg expr ) expr
132
+ | VReduceMulRatTensor ( GenericArg expr ) expr
133
+ | VReduceMinRatTensor ( GenericArg expr ) expr
134
+ | VReduceMaxRatTensor ( GenericArg expr ) expr
135
+ | VSearchRatTensor ( GenericArg expr ) expr expr expr expr
132
136
133
137
fromRatTensorView :: (BuiltinHasRatTensor builtin , BuiltinHasDimensionData builtin ) => RatTensorView (Value builtin ) -> Value builtin
134
138
fromRatTensorView = \ case
135
139
VRatTensor y -> INullaryRatTensorOp (RatTensor y)
136
- VNegRatTensor x -> IRatTensorOp NegRatTensor ( explicit <$> [x])
137
- VAddRatTensor x y -> IRatTensorOp AddRatTensor ( explicit <$> [ x, y])
138
- VSubRatTensor x y -> IRatTensorOp SubRatTensor ( explicit <$> [ x, y])
139
- VMulRatTensor x y -> IRatTensorOp MulRatTensor ( explicit <$> [ x, y])
140
- VDivRatTensor x y -> IRatTensorOp DivRatTensor ( explicit <$> [ x, y])
141
- VMinRatTensor x y -> IRatTensorOp MinRatTensor ( explicit <$> [ x, y])
142
- VMaxRatTensor x y -> IRatTensorOp MaxRatTensor ( explicit <$> [ x, y])
143
- VReduceAddRatTensor x -> IRatTensorOp ReduceAddRatTensor ( explicit <$> [x])
144
- VReduceMulRatTensor x -> IRatTensorOp ReduceMulRatTensor ( explicit <$> [x])
145
- VReduceMinRatTensor x -> IRatTensorOp ReduceMinRatTensor ( explicit <$> [x])
146
- VReduceMaxRatTensor x -> IRatTensorOp ReduceMaxRatTensor ( explicit <$> [x])
147
- VConstRatTensor x y -> IRatConstTensor [ x, y ]
148
- VStackRatTensor n xs -> IDimensionDataOp (StackTensor n) (explicit <$> xs)
149
- VSearchRatTensor reduce lower upper fn -> IRatTensorOp SearchRatTensor (explicit <$> [reduce, lower, upper, fn])
140
+ VNegRatTensor dims x -> IRatTensorOp NegRatTensor [dims, explicit x]
141
+ VAddRatTensor dims x y -> IRatTensorOp AddRatTensor [dims, explicit x, explicit y]
142
+ VSubRatTensor dims x y -> IRatTensorOp SubRatTensor [dims, explicit x, explicit y]
143
+ VMulRatTensor dims x y -> IRatTensorOp MulRatTensor [dims, explicit x, explicit y]
144
+ VDivRatTensor dims x y -> IRatTensorOp DivRatTensor [dims, explicit x, explicit y]
145
+ VMinRatTensor dims x y -> IRatTensorOp MinRatTensor [dims, explicit x, explicit y]
146
+ VMaxRatTensor dims x y -> IRatTensorOp MaxRatTensor [dims, explicit x, explicit y]
147
+ VReduceAddRatTensor dims x -> IRatTensorOp ReduceAddRatTensor [dims, explicit x]
148
+ VReduceMulRatTensor dims x -> IRatTensorOp ReduceMulRatTensor [dims, explicit x]
149
+ VReduceMinRatTensor dims x -> IRatTensorOp ReduceMinRatTensor [dims, explicit x]
150
+ VReduceMaxRatTensor dims x -> IRatTensorOp ReduceMaxRatTensor [dims, explicit x]
151
+ VConstRatTensor x dims -> IDimensionDataOp ConstTensor [implicit IRatElementType , explicit x, explicit dims ]
152
+ VStackRatTensor n elemDims xs -> IDimensionDataOp (StackTensor n) (implicit IRatElementType : elemDims : ( explicit <$> xs) )
153
+ VSearchRatTensor dims reduce lower upper fn -> IRatTensorOp SearchRatTensor (dims : ( explicit <$> [reduce, lower, upper, fn]) )
150
154
VRatTensorVar v -> VBoundVar v []
151
155
152
156
toRatTensorView :: (BuiltinHasRatTensor builtin , BuiltinHasDimensionData builtin ) => Value builtin -> RatTensorView (Value builtin )
153
157
toRatTensorView expr = case getRatTensorOp expr of
154
158
Just (RatTensor b, [] ) -> VRatTensor b
155
- Just (NegRatTensor, [x]) -> VNegRatTensor (argExpr x)
156
- Just (AddRatTensor, [x, y]) -> VAddRatTensor (argExpr x) (argExpr y)
157
- Just (SubRatTensor, [x, y]) -> VSubRatTensor (argExpr x) (argExpr y)
158
- Just (MulRatTensor, [x, y]) -> VMulRatTensor (argExpr x) (argExpr y)
159
- Just (DivRatTensor, [x, y]) -> VDivRatTensor (argExpr x) (argExpr y)
160
- Just (MinRatTensor, [x, y]) -> VMinRatTensor (argExpr x) (argExpr y)
161
- Just (MaxRatTensor, [x, y]) -> VMaxRatTensor (argExpr x) (argExpr y)
162
- Just (ReduceAddRatTensor, [x]) -> VReduceAddRatTensor (argExpr x)
163
- Just (ReduceMulRatTensor, [x]) -> VReduceMulRatTensor (argExpr x)
164
- Just (ReduceMinRatTensor, [x]) -> VReduceMinRatTensor (argExpr x)
165
- Just (ReduceMaxRatTensor, [x]) -> VReduceMaxRatTensor (argExpr x)
166
- Just (SearchRatTensor, [reduce, lower, upper, fn]) -> VSearchRatTensor (argExpr reduce) (argExpr lower) (argExpr upper) (argExpr fn)
159
+ Just (NegRatTensor , [dims, x]) -> VNegRatTensor dims (argExpr x)
160
+ Just (AddRatTensor , [dims, x, y]) -> VAddRatTensor dims (argExpr x) (argExpr y)
161
+ Just (SubRatTensor , [dims, x, y]) -> VSubRatTensor dims (argExpr x) (argExpr y)
162
+ Just (MulRatTensor , [dims, x, y]) -> VMulRatTensor dims (argExpr x) (argExpr y)
163
+ Just (DivRatTensor , [dims, x, y]) -> VDivRatTensor dims (argExpr x) (argExpr y)
164
+ Just (MinRatTensor , [dims, x, y]) -> VMinRatTensor dims (argExpr x) (argExpr y)
165
+ Just (MaxRatTensor , [dims, x, y]) -> VMaxRatTensor dims (argExpr x) (argExpr y)
166
+ Just (ReduceAddRatTensor , [dims, x]) -> VReduceAddRatTensor dims (argExpr x)
167
+ Just (ReduceMulRatTensor , [dims, x]) -> VReduceMulRatTensor dims (argExpr x)
168
+ Just (ReduceMinRatTensor , [dims, x]) -> VReduceMinRatTensor dims (argExpr x)
169
+ Just (ReduceMaxRatTensor , [dims, x]) -> VReduceMaxRatTensor dims (argExpr x)
170
+ Just (SearchRatTensor , [dims, reduce, lower, upper, fn]) -> VSearchRatTensor dims (argExpr reduce) (argExpr lower) (argExpr upper) (argExpr fn)
167
171
Nothing -> case getDimensionDataOp expr of
168
- Just (ConstTensor, [x, y ]) -> VConstRatTensor (argExpr x) (argExpr y )
169
- Just (StackTensor n, args) -> VStackRatTensor n (fmap argExpr args)
172
+ Just (ConstTensor , [argExpr -> IRatElementType , x, dims ]) -> VConstRatTensor (argExpr x) (argExpr dims )
173
+ Just (StackTensor n, (argExpr -> IRatElementType ) : dims : args) -> VStackRatTensor n dims (fmap argExpr args)
170
174
_ -> developerError " ill-typed RatTensor expression"
171
175
_ -> developerError " ill-typed RatTensor expression"
172
- -}
176
+
173
177
--------------------------------------------------------------------------------
174
178
-- Other
175
179
0 commit comments