Skip to content

Commit 2aecb95

Browse files
committed
Allow decomposition methods in MvNormal
1 parent 2823dfc commit 2aecb95

File tree

6 files changed

+113
-18
lines changed

6 files changed

+113
-18
lines changed

pytensor/link/jax/dispatch/random.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,6 @@ def jax_sample_fn(op, node):
128128
@jax_sample_fn.register(ptr.BetaRV)
129129
@jax_sample_fn.register(ptr.DirichletRV)
130130
@jax_sample_fn.register(ptr.PoissonRV)
131-
@jax_sample_fn.register(ptr.MvNormalRV)
132131
def jax_sample_fn_generic(op, node):
133132
"""Generic JAX implementation of random variables."""
134133
name = op.name
@@ -173,6 +172,20 @@ def sample_fn(rng, size, dtype, *parameters):
173172
return sample_fn
174173

175174

175+
@jax_sample_fn.register(ptr.MvNormalRV)
176+
def jax_sample_mvnormal(op, node):
177+
def sample_fn(rng, size, dtype, mean, cov):
178+
rng_key = rng["jax_state"]
179+
rng_key, sampling_key = jax.random.split(rng_key, 2)
180+
sample = jax.random.multivariate_normal(
181+
sampling_key, mean, cov, shape=size, dtype=dtype, method=op.method
182+
)
183+
rng["jax_state"] = rng_key
184+
return (rng, sample)
185+
186+
return sample_fn
187+
188+
176189
@jax_sample_fn.register(ptr.BernoulliRV)
177190
def jax_sample_fn_bernoulli(op, node):
178191
"""JAX implementation of `BernoulliRV`."""

pytensor/link/numba/dispatch/random.py

+16-3
Original file line numberDiff line numberDiff line change
@@ -144,11 +144,24 @@ def random_fn(rng, p):
144144

145145
@numba_core_rv_funcify.register(ptr.MvNormalRV)
146146
def core_MvNormalRV(op, node):
147+
method = op.method
148+
147149
@numba_basic.numba_njit
148150
def random_fn(rng, mean, cov):
149-
chol = np.linalg.cholesky(cov)
150-
stdnorm = rng.normal(size=cov.shape[-1])
151-
return np.dot(chol, stdnorm) + mean
151+
if method == "cholesky":
152+
A = np.linalg.cholesky(cov)
153+
elif method == "svd":
154+
A, s, _ = np.linalg.svd(cov)
155+
A *= np.sqrt(s)[None, :]
156+
else:
157+
w, A = np.linalg.eigh(cov)
158+
A *= np.sqrt(w)[None, :]
159+
160+
out = rng.normal(size=cov.shape[-1])
161+
# out argument not working correctly: https://github.com/numba/numba/issues/9924
162+
out[:] = np.dot(A, out)
163+
out += mean
164+
return out
152165

153166
random_fn.handles_out = True
154167
return random_fn

pytensor/tensor/random/basic.py

+27-14
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
import abc
22
import warnings
3+
from typing import Literal
34

45
import numpy as np
56
import scipy.stats as stats
67
from numpy import broadcast_shapes as np_broadcast_shapes
78
from numpy import einsum as np_einsum
9+
from numpy import sqrt as np_sqrt
810
from numpy.linalg import cholesky as np_cholesky
11+
from numpy.linalg import eigh as np_eigh
12+
from numpy.linalg import svd as np_svd
913

10-
import pytensor
1114
from pytensor.tensor import get_vector_length, specify_shape
1215
from pytensor.tensor.basic import as_tensor_variable
1316
from pytensor.tensor.math import sqrt
@@ -852,8 +855,17 @@ class MvNormalRV(RandomVariable):
852855
signature = "(n),(n,n)->(n)"
853856
dtype = "floatX"
854857
_print_name = ("MultivariateNormal", "\\operatorname{MultivariateNormal}")
858+
__props__ = ("name", "signature", "dtype", "inplace", "method")
855859

856-
def __call__(self, mean=None, cov=None, size=None, **kwargs):
860+
def __init__(self, *args, method: Literal["cholesky", "svd", "eigh"], **kwargs):
861+
super().__init__(*args, **kwargs)
862+
if method not in ("cholesky", "svd", "eigh"):
863+
raise ValueError(
864+
f"Unknown method {method}. The method must be one of 'cholesky', 'svd', or 'eigh'."
865+
)
866+
self.method = method
867+
868+
def __call__(self, mean, cov, size=None, **kwargs):
857869
r""" "Draw samples from a multivariate normal distribution.
858870
859871
Signature
@@ -876,33 +888,34 @@ def __call__(self, mean=None, cov=None, size=None, **kwargs):
876888
is specified, a single `N`-dimensional sample is returned.
877889
878890
"""
879-
dtype = pytensor.config.floatX if self.dtype == "floatX" else self.dtype
880-
881-
if mean is None:
882-
mean = np.array([0.0], dtype=dtype)
883-
if cov is None:
884-
cov = np.array([[1.0]], dtype=dtype)
885891
return super().__call__(mean, cov, size=size, **kwargs)
886892

887-
@classmethod
888-
def rng_fn(cls, rng, mean, cov, size):
893+
def rng_fn(self, rng, mean, cov, size):
889894
if size is None:
890895
size = np_broadcast_shapes(mean.shape[:-1], cov.shape[:-2])
891896

892-
chol = np_cholesky(cov)
897+
if self.method == "cholesky":
898+
A = np_cholesky(cov)
899+
elif self.method == "svd":
900+
A, s, _ = np_svd(cov)
901+
A *= np_sqrt(s, out=s)[..., None, :]
902+
else:
903+
w, A = np_eigh(cov)
904+
A *= np_sqrt(w, out=w)[..., None, :]
905+
893906
out = rng.normal(size=(*size, mean.shape[-1]))
894907
np_einsum(
895908
"...ij,...j->...i", # numpy doesn't have a batch matrix-vector product
896-
chol,
909+
A,
897910
out,
898-
out=out,
899911
optimize=False, # Nothing to optimize with two operands, skip costly setup
912+
out=out,
900913
)
901914
out += mean
902915
return out
903916

904917

905-
multivariate_normal = MvNormalRV()
918+
multivariate_normal = MvNormalRV(method="cholesky")
906919

907920

908921
class DirichletRV(RandomVariable):

tests/link/jax/test_random.py

+6
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
batched_permutation_tester,
1919
batched_unweighted_choice_without_replacement_tester,
2020
batched_weighted_choice_without_replacement_tester,
21+
create_mvnormal_cov_decomposition_method_test,
2122
)
2223

2324

@@ -547,6 +548,11 @@ def test_random_mvnormal():
547548
np.testing.assert_allclose(samples.mean(axis=0), mu, atol=0.1)
548549

549550

551+
test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_method_test(
552+
"JAX"
553+
)
554+
555+
550556
@pytest.mark.parametrize(
551557
"parameter, size",
552558
[

tests/link/numba/test_random.py

+6
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
batched_permutation_tester,
2323
batched_unweighted_choice_without_replacement_tester,
2424
batched_weighted_choice_without_replacement_tester,
25+
create_mvnormal_cov_decomposition_method_test,
2526
)
2627

2728

@@ -147,6 +148,11 @@ def test_multivariate_normal():
147148
)
148149

149150

151+
test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_method_test(
152+
"NUMBA"
153+
)
154+
155+
150156
@pytest.mark.parametrize(
151157
"rv_op, dist_args, size",
152158
[

tests/tensor/random/test_basic.py

+44
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from pytensor.tensor import ones, stack
2020
from pytensor.tensor.random.basic import (
2121
ChoiceWithoutReplacement,
22+
MvNormalRV,
2223
PermutationRV,
2324
_gamma,
2425
bernoulli,
@@ -686,6 +687,49 @@ def test_mvnormal_ShapeFeature():
686687
assert s4.get_test_value() == 3
687688

688689

690+
def create_mvnormal_cov_decomposition_method_test(mode):
691+
@pytest.mark.parametrize("psd", (True, False))
692+
@pytest.mark.parametrize("method", ("cholesky", "svd", "eigh"))
693+
def test_mvnormal_cov_decomposition_method(method, psd):
694+
mean = 2 ** np.arange(3)
695+
if psd:
696+
cov = [
697+
[1, 0.5, -1],
698+
[0.5, 2, 0],
699+
[-1, 0, 3],
700+
]
701+
else:
702+
cov = [
703+
[1, 0.5, 0],
704+
[0.5, 2, 0],
705+
[0, 0, 0],
706+
]
707+
rng = shared(np.random.default_rng(675))
708+
draws = MvNormalRV(method=method)(mean, cov, rng=rng, size=(10_000,))
709+
assert draws.owner.op.method == method
710+
711+
# JAX doesn't raise errors at runtime
712+
if not psd and method == "cholesky":
713+
if mode == "JAX":
714+
# JAX doesn't raise errors at runtime, instead it returns nan
715+
np.isnan(draws.eval(mode=mode)).all()
716+
else:
717+
with pytest.raises(np.linalg.LinAlgError):
718+
draws.eval(mode=mode)
719+
720+
else:
721+
draws_eval = draws.eval(mode=mode)
722+
np.testing.assert_allclose(np.mean(draws_eval, axis=0), mean, rtol=0.02)
723+
np.testing.assert_allclose(np.cov(draws_eval, rowvar=False), cov, atol=0.1)
724+
725+
return test_mvnormal_cov_decomposition_method
726+
727+
728+
test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_method_test(
729+
None
730+
)
731+
732+
689733
@pytest.mark.parametrize(
690734
"alphas, size",
691735
[

0 commit comments

Comments
 (0)