1
1
import abc
2
2
import warnings
3
+ from typing import Literal
3
4
4
5
import numpy as np
5
6
import scipy .stats as stats
6
7
from numpy import broadcast_shapes as np_broadcast_shapes
7
8
from numpy import einsum as np_einsum
9
+ from numpy import sqrt as np_sqrt
8
10
from numpy .linalg import cholesky as np_cholesky
11
+ from numpy .linalg import eigh as np_eigh
12
+ from numpy .linalg import svd as np_svd
9
13
10
- import pytensor
11
14
from pytensor .tensor import get_vector_length , specify_shape
12
15
from pytensor .tensor .basic import as_tensor_variable
13
16
from pytensor .tensor .math import sqrt
@@ -852,8 +855,17 @@ class MvNormalRV(RandomVariable):
852
855
signature = "(n),(n,n)->(n)"
853
856
dtype = "floatX"
854
857
_print_name = ("MultivariateNormal" , "\\ operatorname{MultivariateNormal}" )
858
+ __props__ = ("name" , "signature" , "dtype" , "inplace" , "method" )
855
859
856
- def __call__ (self , mean = None , cov = None , size = None , ** kwargs ):
860
+ def __init__ (self , * args , method : Literal ["cholesky" , "svd" , "eigh" ], ** kwargs ):
861
+ super ().__init__ (* args , ** kwargs )
862
+ if method not in ("cholesky" , "svd" , "eigh" ):
863
+ raise ValueError (
864
+ f"Unknown method { method } . The method must be one of 'cholesky', 'svd', or 'eigh'."
865
+ )
866
+ self .method = method
867
+
868
+ def __call__ (self , mean , cov , size = None , ** kwargs ):
857
869
r""" "Draw samples from a multivariate normal distribution.
858
870
859
871
Signature
@@ -876,33 +888,34 @@ def __call__(self, mean=None, cov=None, size=None, **kwargs):
876
888
is specified, a single `N`-dimensional sample is returned.
877
889
878
890
"""
879
- dtype = pytensor .config .floatX if self .dtype == "floatX" else self .dtype
880
-
881
- if mean is None :
882
- mean = np .array ([0.0 ], dtype = dtype )
883
- if cov is None :
884
- cov = np .array ([[1.0 ]], dtype = dtype )
885
891
return super ().__call__ (mean , cov , size = size , ** kwargs )
886
892
887
- @classmethod
888
- def rng_fn (cls , rng , mean , cov , size ):
893
+ def rng_fn (self , rng , mean , cov , size ):
889
894
if size is None :
890
895
size = np_broadcast_shapes (mean .shape [:- 1 ], cov .shape [:- 2 ])
891
896
892
- chol = np_cholesky (cov )
897
+ if self .method == "cholesky" :
898
+ A = np_cholesky (cov )
899
+ elif self .method == "svd" :
900
+ A , s , _ = np_svd (cov )
901
+ A *= np_sqrt (s , out = s )[..., None , :]
902
+ else :
903
+ w , A = np_eigh (cov )
904
+ A *= np_sqrt (w , out = w )[..., None , :]
905
+
893
906
out = rng .normal (size = (* size , mean .shape [- 1 ]))
894
907
np_einsum (
895
908
"...ij,...j->...i" , # numpy doesn't have a batch matrix-vector product
896
- chol ,
909
+ A ,
897
910
out ,
898
- out = out ,
899
911
optimize = False , # Nothing to optimize with two operands, skip costly setup
912
+ out = out ,
900
913
)
901
914
out += mean
902
915
return out
903
916
904
917
905
- multivariate_normal = MvNormalRV ()
918
+ multivariate_normal = MvNormalRV (method = "cholesky" )
906
919
907
920
908
921
class DirichletRV (RandomVariable ):
0 commit comments