@@ -89,26 +89,48 @@ def log_jac_det(self, value, *inputs):
89
89
90
90
91
91
class Ordered (Transform ):
92
+ """
93
+ Transforms a vector of values into a vector of ordered values.
94
+
95
+ Parameters
96
+ ----------
97
+ positive: If True, all values are positive. This has better geometry than just chaining with a log transform.
98
+ ascending: If True, the values are in ascending order (default). If False, the values are in descending order.
99
+ """
100
+
92
101
name = "ordered"
93
102
94
- def __init__ (self , ndim_supp = None ):
103
+ def __init__ (self , ndim_supp = None , positive = False , ascending = True ):
95
104
if ndim_supp is not None :
96
105
warnings .warn ("ndim_supp argument is deprecated and has no effect" , FutureWarning )
106
+ self .positive = positive
107
+ self .ascending = ascending
97
108
98
109
def backward (self , value , * inputs ):
99
- x = pt .zeros (value .shape )
100
- x = pt .set_subtensor (x [..., 0 ], value [..., 0 ])
101
- x = pt .set_subtensor (x [..., 1 :], pt .exp (value [..., 1 :]))
102
- return pt .cumsum (x , axis = - 1 )
110
+ if self .positive : # Transform both initial value and deltas to be positive
111
+ x = pt .exp (value )
112
+ else : # Transform only deltas to be positive
113
+ x = pt .empty (value .shape )
114
+ x = pt .set_subtensor (x [..., 0 ], value [..., 0 ])
115
+ x = pt .set_subtensor (x [..., 1 :], pt .exp (value [..., 1 :]))
116
+ x = pt .cumsum (x , axis = - 1 ) # Add deltas cumulatively to initial value
117
+ if not self .ascending :
118
+ x = x [..., ::- 1 ]
119
+ return x
103
120
104
121
def forward (self , value , * inputs ):
105
- y = pt .zeros (value .shape )
106
- y = pt .set_subtensor (y [..., 0 ], value [..., 0 ])
122
+ if not self .ascending :
123
+ value = value [..., ::- 1 ]
124
+ y = pt .empty (value .shape )
125
+ y = pt .set_subtensor (y [..., 0 ], pt .log (value [..., 0 ]) if self .positive else value [..., 0 ])
107
126
y = pt .set_subtensor (y [..., 1 :], pt .log (value [..., 1 :] - value [..., :- 1 ]))
108
127
return y
109
128
110
129
def log_jac_det (self , value , * inputs ):
111
- return pt .sum (value [..., 1 :], axis = - 1 )
130
+ if self .positive :
131
+ return pt .sum (value , axis = - 1 )
132
+ else :
133
+ return pt .sum (value [..., 1 :], axis = - 1 )
112
134
113
135
114
136
class SumTo1 (Transform ):
@@ -132,8 +154,7 @@ def forward(self, value, *inputs):
132
154
return value [..., :- 1 ]
133
155
134
156
def log_jac_det (self , value , * inputs ):
135
- y = pt .zeros (value .shape )
136
- return pt .sum (y , axis = - 1 )
157
+ return pt .zeros (value .shape [:- 1 ])
137
158
138
159
139
160
class CholeskyCovPacked (Transform ):
0 commit comments