Skip to content

Commit 87fd3a0

Browse files
committed
Generalize ordered transform
1 parent d34ed95 commit 87fd3a0

File tree

2 files changed

+58
-12
lines changed

2 files changed

+58
-12
lines changed

pymc/distributions/transforms.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -89,26 +89,48 @@ def log_jac_det(self, value, *inputs):
8989

9090

9191
class Ordered(Transform):
92+
"""
93+
Transforms a vector of values into a vector of ordered values.
94+
95+
Parameters
96+
----------
97+
positive: If True, all values are positive. This has better geometry than just chaining with a log transform.
98+
ascending: If True, the values are in ascending order (default). If False, the values are in descending order.
99+
"""
100+
92101
name = "ordered"
93102

94-
def __init__(self, ndim_supp=None):
103+
def __init__(self, ndim_supp=None, positive=False, ascending=True):
95104
if ndim_supp is not None:
96105
warnings.warn("ndim_supp argument is deprecated and has no effect", FutureWarning)
106+
self.positive = positive
107+
self.ascending = ascending
97108

98109
def backward(self, value, *inputs):
99-
x = pt.zeros(value.shape)
100-
x = pt.set_subtensor(x[..., 0], value[..., 0])
101-
x = pt.set_subtensor(x[..., 1:], pt.exp(value[..., 1:]))
102-
return pt.cumsum(x, axis=-1)
110+
if self.positive: # Transform both initial value and deltas to be positive
111+
x = pt.exp(value)
112+
else: # Transform only deltas to be positive
113+
x = pt.empty(value.shape)
114+
x = pt.set_subtensor(x[..., 0], value[..., 0])
115+
x = pt.set_subtensor(x[..., 1:], pt.exp(value[..., 1:]))
116+
x = pt.cumsum(x, axis=-1) # Add deltas cumulatively to initial value
117+
if not self.ascending:
118+
x = x[..., ::-1]
119+
return x
103120

104121
def forward(self, value, *inputs):
105-
y = pt.zeros(value.shape)
106-
y = pt.set_subtensor(y[..., 0], value[..., 0])
122+
if not self.ascending:
123+
value = value[..., ::-1]
124+
y = pt.empty(value.shape)
125+
y = pt.set_subtensor(y[..., 0], pt.log(value[..., 0]) if self.positive else value[..., 0])
107126
y = pt.set_subtensor(y[..., 1:], pt.log(value[..., 1:] - value[..., :-1]))
108127
return y
109128

110129
def log_jac_det(self, value, *inputs):
111-
return pt.sum(value[..., 1:], axis=-1)
130+
if self.positive:
131+
return pt.sum(value, axis=-1)
132+
else:
133+
return pt.sum(value[..., 1:], axis=-1)
112134

113135

114136
class SumTo1(Transform):
@@ -132,8 +154,7 @@ def forward(self, value, *inputs):
132154
return value[..., :-1]
133155

134156
def log_jac_det(self, value, *inputs):
135-
y = pt.zeros(value.shape)
136-
return pt.sum(y, axis=-1)
157+
return pt.zeros(value.shape[:-1])
137158

138159

139160
class CholeskyCovPacked(Transform):

tests/distributions/test_transform.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def check_jacobian_det(
103103
x = make_comparable(x)
104104

105105
if not elemwise:
106-
jac = pt.log(pt.nlinalg.det(jacobian(x, [y])))
106+
jac = pt.log(pt.abs(pt.nlinalg.det(jacobian(x, [y]))))
107107
else:
108108
jac = pt.log(pt.abs(pt.diag(jacobian(x, [y]))))
109109

@@ -115,7 +115,7 @@ def check_jacobian_det(
115115
)
116116

117117
for yval in domain.vals:
118-
assert_allclose(actual_ljd(yval), computed_ljd(yval), rtol=tol)
118+
assert_allclose(actual_ljd(yval), computed_ljd(yval), rtol=tol, atol=tol)
119119

120120

121121
def test_simplex():
@@ -281,6 +281,31 @@ def test_ordered():
281281
vals = get_values(tr.ordered, Vector(R, 3), pt.vector, floatX(np.zeros(3)))
282282
assert_array_equal(np.diff(vals) >= 0, True)
283283

284+
# Check that positive=True creates positive and still ordered values
285+
vals = get_values(tr.Ordered(positive=True), Vector(R, 3), pt.vector, floatX(np.zeros(3)))
286+
assert_array_equal(vals > 0, True)
287+
assert_array_equal(np.diff(vals) >= 0, True)
288+
289+
# Check that positive=True and ascending=False creates descending values
290+
vals = get_values(
291+
tr.Ordered(positive=True, ascending=False), Vector(R, 3), pt.vector, floatX(np.zeros(3))
292+
)
293+
assert_array_equal(vals > 0, True)
294+
assert_array_equal(np.diff(vals) <= 0, True)
295+
296+
# Check that forward and backward are still inverses
297+
ord, vals = tr.Ordered(positive=True, ascending=False), np.array([0.3, 0.2, 0.1])
298+
assert_allclose(vals, ord.backward(ord.forward(vals)).eval())
299+
300+
# Check the jacobian for positive=True and ascending=False
301+
check_jacobian_det(
302+
tr.Ordered(positive=True, ascending=False),
303+
Vector(R, 2),
304+
pt.vector,
305+
floatX(np.array([1, 1])),
306+
elemwise=False,
307+
)
308+
284309

285310
def test_chain_values():
286311
chain_tranf = tr.Chain([tr.logodds, tr.ordered])

0 commit comments

Comments
 (0)