Skip to content

Commit cf2b60e

Browse files
committed
Added a PartialOrder transform
1 parent 7d62c53 commit cf2b60e

File tree

3 files changed

+226
-0
lines changed

3 files changed

+226
-0
lines changed

Diff for: pymc_extras/distributions/transforms/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from pymc_extras.distributions.transforms.partial_order import PartialOrder
2+
3+
__all__ = ["PartialOrder"]
4+
Diff for: pymc_extras/distributions/transforms/partial_order.py

+138
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
# Copyright 2025 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from pymc.logprob.transforms import Transform
15+
import pytensor.tensor as pt
16+
import numpy as np
17+
__all__ = [
18+
"PartialOrder"
19+
]
20+
21+
# Find the minimum value for a given dtype
22+
def dtype_minval(dtype):
    """Return the most negative value representable by ``dtype``.

    Integer dtypes are looked up with ``np.iinfo``; every other dtype is
    handed to ``np.finfo``.
    """
    if np.issubdtype(dtype, np.integer):
        return np.iinfo(dtype).min
    return np.finfo(dtype).min
24+
25+
# A padded version of np.where
26+
def padded_where(x, to_len, padval=-1):
    """Return the indices of the non-zero entries of ``x``, padded to ``to_len``.

    Only the first index array returned by ``np.where(x)`` is used. ``padval``
    is appended after the found indices until the result has ``to_len`` entries.
    """
    hits = np.where(x)[0]
    padding = np.full(to_len - len(hits), padval)
    return np.concatenate([hits, padding])
29+
30+
# Partial order transform
31+
class PartialOrder(Transform):
    """Transform that constrains the last axis of a value to a strict partial order.

    This is a more flexible version of the pymc ordered transform that
    allows specifying a (strict) partial order on the elements.

    It works in O(N*D) in runtime, but takes O(N^3) in initialization,
    where N is the number of nodes in the dag and
    D is the maximum in-degree of a node in the transitive reduction.
    """

    name = "partial_order"

    def __init__(self, adj_mat):
        """Create a PartialOrder transform

        Parameters
        ----------
        adj_mat: adjacency matrix for the DAG that generates the partial order,
            where adj_mat[i][j] = 1 denotes i<j.
            Note this also accepts multiple DAGs if RV is multidimensional:
            every leading axis before the last two is treated as a batch axis.

        Raises
        ------
        ValueError
            If the matrix has fewer than 2 dimensions, is not square, holds
            anything other than 0s and 1s (or no 1 at all), or describes a
            relation that contains a cycle.
        """
        # Accept anything array-like (nested lists included)
        adj_mat = np.asarray(adj_mat)

        # Basic input checks
        if adj_mat.ndim < 2:
            raise ValueError("Adjacency matrix must have at least 2 dimensions")
        if adj_mat.shape[-2] != adj_mat.shape[-1]:
            raise ValueError("Adjacency matrix is not square")
        # Check every entry, not just the extremes, so values such as 0.5 are
        # rejected. An input without any edge is still rejected, as before.
        only_binary = np.logical_or(adj_mat == 0, adj_mat == 1).all()
        if not only_binary or not adj_mat.any():
            raise ValueError("Adjacency matrix must contain only 0s and 1s")

        n_nodes = adj_mat.shape[-1]

        # Create index over the first ellipsis dimensions
        idx = np.ix_(*[np.arange(s) for s in adj_mat.shape[:-2]])

        # Transitive closure using Floyd-Warshall
        # (astype returns a copy, so the caller's array is not modified)
        tc = adj_mat.astype(bool)
        for k in range(n_nodes):
            tc |= np.logical_and(tc[..., :, k, None], tc[..., None, k, :])

        # Check if the dag is acyclic
        if np.any(tc.diagonal(axis1=-2, axis2=-1)):
            raise ValueError("Partial order contains equalities")

        # Transitive reduction using the closure
        # This gives the minimum description of the partial order
        # This is to minmax the input degree
        adj_mat = tc * (1 - np.matmul(tc, tc))

        # Find the maximum in-degree of the reduced dag
        dag_idim = adj_mat.sum(axis=-2).max()

        # Topological sort: ts_inds[..., i] is the i-th node in topological order
        ts_inds = np.zeros(adj_mat.shape[:-1], dtype=int)
        dm = adj_mat.copy()
        # Iterate over the node axis (the last one); shape[1] would be a
        # batch axis whenever there is more than one batch dimension.
        for i in range(n_nodes):
            in_degree = dm.sum(axis=-2)
            # Every DAG must still have a node without pending parents
            assert (in_degree.min(axis=-1) == 0).all()  # DAG is acyclic
            nind = np.argmin(in_degree, axis=-1)
            dm[idx + (slice(None), nind)] = 1  # Make nind not show up again
            dm[idx + (nind, slice(None))] = 0  # Allow its children to show
            ts_inds[idx + (i,)] = nind
        self.ts_inds = ts_inds

        # Change the dag to adjacency lists (with -1 for NA):
        # dag[..., j, :] lists the parents of node j in the reduced dag
        dag_T = np.apply_along_axis(
            padded_where, axis=-2, arr=adj_mat, padval=-1, to_len=dag_idim
        )
        self.dag = np.swapaxes(dag_T, -2, -1)
        # Nodes whose parent list is all padding have no lower bound
        self.is_start = np.all(self.dag[..., :, :] == -1, axis=-1)

    def _batch_index(self, trailing_shape):
        """Return one index array per batch axis of ``self.dag``.

        Each array has shape ``batch_shape + trailing_shape`` and holds the
        position along its own batch axis, so it can be combined with an index
        array of that same shape in an advanced-indexing expression.
        With no batch axes an empty tuple is returned.
        """
        batch_shape = self.dag.shape[:-2]
        full_shape = batch_shape + tuple(trailing_shape)
        expand = (Ellipsis,) + (None,) * len(trailing_shape)
        grids = np.ix_(*[np.arange(s) for s in batch_shape])
        return tuple(np.broadcast_to(g[expand], full_shape).copy() for g in grids)

    def initvals(self, lower=-1, upper=1):
        """Return values in ``[lower, upper]`` that satisfy the partial order.

        Evenly spaced values are assigned to the nodes following their
        topological order, so every node is larger than all of its parents.
        The result has the shape of ``self.ts_inds``.
        """
        vals = np.linspace(lower, upper, self.dag.shape[-2])
        # argsort of a permutation is its inverse: node -> topological position
        inds = np.argsort(self.ts_inds, axis=-1)
        return vals[inds]

    def backward(self, value, *inputs):
        """Map an unconstrained ``value`` to one that respects the partial order.

        Start nodes keep their value; every other node is set to the maximum
        of its parents plus ``exp(value)``.
        """
        minv = dtype_minval(value.dtype)
        # Extra trailing slot holding the dtype minimum: the -1 padding in
        # self.dag indexes it, so padding can never win the max below.
        x = pt.concatenate(
            [pt.zeros_like(value), pt.full(value.shape[:-1], minv)[..., None]], axis=-1
        )

        # Indices to allow broadcasting the max over the last dimension
        idx = np.ix_(*[np.arange(s) for s in self.dag.shape[:-2]])
        idx2 = self._batch_index((self.dag.shape[-1],))

        # Has to be done stepwise as next steps depend on previous values
        # Also has to be done in topological order, hence the ts_inds
        for i in range(self.dag.shape[-2]):
            tsi = self.ts_inds[..., i]
            if len(tsi.shape) == 0:
                tsi = int(tsi)  # if shape 0, it's a scalar
            ni = idx + (tsi,)  # i-th node in topological order
            eni = (Ellipsis,) + ni
            ist = self.is_start[ni]

            mval = pt.max(x[(Ellipsis,) + idx2 + (self.dag[ni],)], axis=-1)
            new_val = ist * value[eni] + (1 - ist) * (mval + pt.exp(value[eni]))
            x = pt.set_subtensor(x[eni], new_val)
        return x[..., :-1]

    def forward(self, value, *inputs):
        """Map an ordered ``value`` to the unconstrained space.

        Start nodes keep their value; every other node becomes the log of its
        distance to the maximum of its parents.
        """
        minv = dtype_minval(value.dtype)
        # Extra trailing slot holding the dtype minimum, hit by the -1 padding
        vx = pt.concatenate([value, pt.full(value.shape[:-1], minv)[..., None]], axis=-1)

        # Indices to allow broadcasting the max over the last dimension
        idx = self._batch_index(self.dag.shape[-2:])

        max_parent = pt.max(vx[(Ellipsis,) + idx + (self.dag[..., :],)], axis=-1)
        return self.is_start * value + (1 - self.is_start) * pt.log(value - max_parent)

    def log_jac_det(self, value, *inputs):
        """Return the sum of ``value`` over the non-start nodes of the last axis."""
        return pt.sum(value * (1 - self.is_start), axis=-1)

Diff for: tests/distributions/test_transform.py

+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright 2025 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import numpy as np
15+
import pymc as pm
16+
import pytensor
17+
import pytensor.tensor as pt
18+
import pytest
19+
import numpy as np
20+
21+
from pymc_extras.distributions.transforms import (
22+
PartialOrder
23+
)
24+
25+
class TestPartialOrder:
    """Tests for the PartialOrder transform."""

    # Two DAGs over four nodes, stacked along a leading batch axis
    adj_mats = np.array(
        [
            # 0 < {1, 2} < 3
            [[0, 1, 1, 0], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 0]],
            # 1 < 0 < 3 < 2
            [[0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 0]],
        ]
    )

    # One vector per DAG above that satisfies its order constraints
    valid_values = np.array([[0, 2, 1, 3], [1, 0, 3, 2]], dtype=float)

    # Test that forward and backward are inverses of each other
    # And that it works when extra dimensions are added in data
    def test_forward_backward_dimensionality(self):
        po = PartialOrder(self.adj_mats)
        po0 = PartialOrder(self.adj_mats[0])
        vv = self.valid_values
        vv0 = self.valid_values[0]

        testsets = [
            (vv, po),
            (po.initvals(), po),
            (vv0, po0),
            (po0.initvals(), po0),
            (np.tile(vv0, (2, 1)), po0),
            (np.tile(vv0, (2, 3, 2, 1)), po0),
            (np.tile(vv, (2, 3, 2, 1, 1)), po),
        ]

        # Distinct loop names so the fixtures above are not rebound
        for values, transform in testsets:
            fw = transform.forward(values)
            bw = transform.backward(fw)
            np.testing.assert_allclose(bw.eval(), values)

    # Invalid adjacency matrices must be rejected at construction time
    def test_invalid_adjacency_raises(self):
        with pytest.raises(ValueError, match="not square"):
            PartialOrder(np.zeros((3, 4)))
        with pytest.raises(ValueError, match="equalities"):
            PartialOrder(np.array([[0, 1], [1, 0]]))  # 0 < 1 < 0 is a cycle

    def test_sample_model(self):
        po = PartialOrder(self.adj_mats)
        with pm.Model():
            pm.Normal("x", size=(2, 4), transform=po, initval=po.initvals(-1, 1))
            # Seeded so that a failure can be reproduced
            idata = pm.sample(random_seed=20250101)

        # Check that the order constraints are satisfied
        # (chain, draw, dag, node) -> (dag, node, chain, draw)
        xvs = idata.posterior.x.values.transpose(2, 3, 0, 1)
        x0 = xvs[0]  # 0 < {1, 2} < 3
        assert (x0[0] < x0[1]).all()
        assert (x0[0] < x0[2]).all()
        assert (x0[1] < x0[3]).all()
        assert (x0[2] < x0[3]).all()
        x1 = xvs[1]  # 1 < 0 < 3 < 2
        assert (x1[1] < x1[0]).all()
        assert (x1[0] < x1[3]).all()
        assert (x1[3] < x1[2]).all()

0 commit comments

Comments
 (0)