Skip to content

Commit 898c92f

Browse files
committed
Added a PartialOrder transform
1 parent 7d62c53 commit 898c92f

File tree

5 files changed

+234
-0
lines changed

5 files changed

+234
-0
lines changed

Diff for: docs/api_reference.rst

+10
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,16 @@ Distributions
4444
histogram_approximation
4545

4646

47+
Transforms
48+
==========
49+
50+
.. currentmodule:: pymc_extras.distributions.transforms
51+
.. autosummary::
52+
:toctree: generated/
53+
54+
PartialOrder
55+
56+
4757
Utils
4858
=====
4959

Diff for: pymc_extras/distributions/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from pymc_extras.distributions.histogram_utils import histogram_approximation
2727
from pymc_extras.distributions.multivariate import R2D2M2CP
2828
from pymc_extras.distributions.timeseries import DiscreteMarkovChain
29+
from pymc_extras.distributions.transforms import PartialOrder
2930

3031
__all__ = [
3132
"Chi",
@@ -37,4 +38,5 @@
3738
"R2D2M2CP",
3839
"Skellam",
3940
"histogram_approximation",
41+
"PartialOrder",
4042
]

Diff for: pymc_extras/distributions/transforms/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from pymc_extras.distributions.transforms.partial_order import PartialOrder
2+
3+
__all__ = ["PartialOrder"]
+148
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
# Copyright 2025 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import numpy as np
15+
import pytensor.tensor as pt
16+
17+
from pymc.logprob.transforms import Transform
18+
19+
__all__ = ["PartialOrder"]
20+
21+
22+
# Find the minimum value for a given dtype
23+
def dtype_minval(dtype):
    """Return the smallest finite value representable by ``dtype``.

    Integer dtypes are looked up with ``np.iinfo``; every other dtype is
    passed to ``np.finfo``.
    """
    if np.issubdtype(dtype, np.integer):
        return np.iinfo(dtype).min
    return np.finfo(dtype).min
25+
26+
27+
# A padded version of np.where
28+
def padded_where(x, to_len, padval=-1):
    """Return the indices of the truthy entries of ``x`` padded to ``to_len``.

    Only the first index array returned by ``np.where`` is used. The result
    is that array followed by ``to_len - len(indices)`` copies of ``padval``.
    """
    hits = np.where(x)[0]
    padding = np.full(to_len - len(hits), padval)
    return np.concatenate([hits, padding])
31+
32+
33+
# Partial order transform
34+
class PartialOrder(Transform):
    """Create a PartialOrder transform

    This is a more flexible version of the pymc ordered transform that
    allows specifying a (strict) partial order on the elements.

    It works in O(N*D) in runtime, but takes O(N^3) in initialization,
    where N is the number of nodes in the dag and
    D is the maximum in-degree of a node in the transitive reduction.

    Attributes
    ----------
    ts_inds : ndarray of int, shape ``(..., N)``
        Node indices listed in a topological order of each DAG.
    dag : ndarray of int, shape ``(..., N, D)``
        For every node, the indices of its direct predecessors in the
        transitive reduction, padded with ``-1``.
    is_start : ndarray of bool, shape ``(..., N)``
        True for nodes that have no predecessor.
    """

    name = "partial_order"

    def __init__(self, adj_mat):
        """
        Parameters
        ----------
        adj_mat: array_like
            adjacency matrix for the DAG that generates the partial order,
            where ``adj_mat[i][j] = 1`` denotes ``i < j``.
            Note this also accepts multiple DAGs if RV is multidimensional

        Raises
        ------
        ValueError
            If the input has fewer than 2 dimensions, is not square in its
            last two dimensions, contains values other than 0 and 1,
            contains no edge at all, or describes a graph with a cycle.
        """
        adj_mat = np.asarray(adj_mat)

        # Basic input checks
        if adj_mat.ndim < 2:
            raise ValueError("Adjacency matrix must have at least 2 dimensions")
        if adj_mat.shape[-2] != adj_mat.shape[-1]:
            raise ValueError("Adjacency matrix is not square")
        # Check every entry: comparing only min/max lets values such as 0.5
        # through, and those would silently become edges in astype(bool).
        if not ((adj_mat == 0) | (adj_mat == 1)).all():
            raise ValueError("Adjacency matrix must contain only 0s and 1s")
        # An edge-free input was rejected before as well (its max is not 1);
        # keep rejecting it, but say why.
        if not adj_mat.any():
            raise ValueError("Adjacency matrix must contain at least one edge")

        n_nodes = adj_mat.shape[-1]

        # Create index over the first ellipsis dimensions
        idx = np.ix_(*[np.arange(s) for s in adj_mat.shape[:-2]])

        # Transitive closure using Floyd-Warshall
        # (astype returns a copy, so the in-place |= never touches the input)
        tc = adj_mat.astype(bool)
        for k in range(n_nodes):
            tc |= np.logical_and(tc[..., :, k, None], tc[..., None, k, :])

        # Check if the dag is acyclic
        if np.any(tc.diagonal(axis1=-2, axis2=-1)):
            raise ValueError("Partial order contains equalities")

        # Transitive reduction using the closure
        # This gives the minimum description of the partial order
        # This is to minmax the input degree
        adj_mat = tc * (1 - np.matmul(tc, tc))

        # Find the maximum in-degree of the reduced dag
        dag_idim = adj_mat.sum(axis=-2).max()

        # Topological sort
        ts_inds = np.zeros(adj_mat.shape[:-1], dtype=int)
        dm = adj_mat.copy()
        # Iterate over the node axis (the last one); shape[1] would be a
        # batch dimension when there is more than one leading batch axis.
        for i in range(n_nodes):
            assert dm.sum(axis=-2).min() == 0  # DAG is acyclic
            nind = np.argmin(dm.sum(axis=-2), axis=-1)
            dm[(*idx, slice(None), nind)] = 1  # Make nind not show up again
            dm[(*idx, nind, slice(None))] = 0  # Allow its children to show
            ts_inds[(*idx, i)] = nind
        self.ts_inds = ts_inds

        # Change the dag to adjacency lists (with -1 for NA)
        dag_T = np.apply_along_axis(padded_where, axis=-2, arr=adj_mat, padval=-1, to_len=dag_idim)
        self.dag = np.swapaxes(dag_T, -2, -1)
        self.is_start = np.all(self.dag[..., :, :] == -1, axis=-1)

    def initvals(self, lower=-1, upper=1):
        """Return initial values that satisfy the partial order.

        ``N`` evenly spaced values between ``lower`` and ``upper`` are
        assigned to the nodes according to their position in the stored
        topological order.
        """
        vals = np.linspace(lower, upper, self.dag.shape[-2])
        # argsort of the topological order gives each node's rank in it
        inds = np.argsort(self.ts_inds, axis=-1)
        return vals[inds]

    def backward(self, value, *inputs):
        """Map unconstrained ``value`` to values obeying the partial order.

        Start nodes keep their value. Every other node becomes the maximum
        of its predecessors plus ``exp(value)``.
        """
        minv = dtype_minval(value.dtype)
        # Append one slot holding the dtype minimum: the -1 padding in
        # self.dag indexes this slot, so padding never wins the max.
        x = pt.concatenate(
            [pt.zeros_like(value), pt.full(value.shape[:-1], minv)[..., None]], axis=-1
        )

        # Indices to allow broadcasting the max over the last dimension.
        # Each batch index is expanded to the shape of one row of
        # predecessor lists, (*batch, D), for any number of batch dims.
        row_shape = self.dag.shape[:-2] + self.dag.shape[-1:]
        idx = np.ix_(*[np.arange(s) for s in self.dag.shape[:-2]])
        idx2 = tuple(np.broadcast_to(i[..., None], row_shape).copy() for i in idx)

        # Has to be done stepwise as next steps depend on previous values
        # Also has to be done in topological order, hence the ts_inds
        for i in range(self.dag.shape[-2]):
            tsi = self.ts_inds[..., i]
            if len(tsi.shape) == 0:
                tsi = int(tsi)  # if shape 0, it's a scalar
            ni = (*idx, tsi)  # i-th node in topological order
            eni = (Ellipsis, *ni)
            ist = self.is_start[ni]

            mval = pt.max(x[(Ellipsis, *idx2, self.dag[ni])], axis=-1)
            x = pt.set_subtensor(x[eni], ist * value[eni] + (1 - ist) * (mval + pt.exp(value[eni])))
        return x[..., :-1]

    def forward(self, value, *inputs):
        """Map order-respecting ``value`` to the unconstrained space.

        Start nodes keep their value. Every other node becomes the log of
        its distance above the maximum of its predecessors.
        """
        minv = dtype_minval(value.dtype)
        # Extra slot with the dtype minimum, addressed by the -1 padding
        vx = pt.concatenate([value, pt.full(value.shape[:-1], minv)[..., None]], axis=-1)

        # Indices to allow broadcasting the max over the last dimension.
        # Each batch index is expanded to the full shape of self.dag,
        # (*batch, N, D), for any number of batch dims.
        idx = np.ix_(*[np.arange(s) for s in self.dag.shape[:-2]])
        idx = tuple(np.broadcast_to(i[..., None, None], self.dag.shape).copy() for i in idx)

        y = self.is_start * value + (1 - self.is_start) * (
            pt.log(value - pt.max(vx[(Ellipsis, *idx, self.dag[..., :])], axis=-1))
        )

        return y

    def log_jac_det(self, value, *inputs):
        """Return the sum of ``value`` over the non-start nodes of the last axis."""
        return pt.sum(value * (1 - self.is_start), axis=-1)

Diff for: tests/distributions/test_transform.py

+71
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Copyright 2025 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import numpy as np
15+
import pymc as pm
16+
17+
from pymc_extras.distributions.transforms import PartialOrder
18+
19+
20+
class TestPartialOrder:
    """Tests for the ``PartialOrder`` transform."""

    # Two DAGs over four nodes, stacked along a leading batch axis
    adj_mats = np.array(
        [
            # 0 < {1, 2} < 3
            [[0, 1, 1, 0], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 0]],
            # 1 < 0 < 3 < 2
            [[0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 0]],
        ]
    )

    # One value vector per DAG above, each satisfying its partial order
    valid_values = np.array([[0, 2, 1, 3], [1, 0, 3, 2]], dtype=float)

    # Test that forward and backward are inverses of each other
    # And that it works when extra dimensions are added in data
    def test_forward_backward_dimensionality(self):
        po = PartialOrder(self.adj_mats)
        po0 = PartialOrder(self.adj_mats[0])
        vv = self.valid_values
        vv0 = self.valid_values[0]

        testsets = [
            (vv, po),
            (po.initvals(), po),
            (vv0, po0),
            (po0.initvals(), po0),
            (np.tile(vv0, (2, 1)), po0),
            (np.tile(vv0, (2, 3, 2, 1)), po0),
            (np.tile(vv, (2, 3, 2, 1, 1)), po),
        ]

        # Distinct loop names so the fixtures above are not shadowed
        for values, transform in testsets:
            fw = transform.forward(values)
            bw = transform.backward(fw)
            np.testing.assert_allclose(bw.eval(), values)

    def test_sample_model(self):
        po = PartialOrder(self.adj_mats)
        with pm.Model():
            pm.Normal("x", size=(2, 4), transform=po, initval=po.initvals(-1, 1))
            # Fixed seed so a failure can be reproduced
            idata = pm.sample(random_seed=20250101)

        # Check that the order constraints are satisfied
        # Move the (2, 4) variable axes in front of (chain, draw)
        xvs = idata.posterior.x.values.transpose(2, 3, 0, 1)

        # One assert per constraint so a failure names the violated pair
        x0 = xvs[0]  # 0 < {1, 2} < 3
        assert (x0[0] < x0[1]).all()
        assert (x0[0] < x0[2]).all()
        assert (x0[1] < x0[3]).all()
        assert (x0[2] < x0[3]).all()

        x1 = xvs[1]  # 1 < 0 < 3 < 2
        assert (x1[1] < x1[0]).all()
        assert (x1[0] < x1[3]).all()
        assert (x1[3] < x1[2]).all()

0 commit comments

Comments
 (0)