 from aesara.tensor import gammaln
 from aesara.tensor.nlinalg import det, eigh, matrix_inverse, trace
 from aesara.tensor.random.basic import MultinomialRV, dirichlet, multivariate_normal
+from aesara.tensor.random.op import RandomVariable, default_shape_from_params
 from aesara.tensor.random.utils import broadcast_params
 from aesara.tensor.slinalg import (
     Cholesky,

 from pymc3.aesaraf import floatX, intX
 from pymc3.distributions import transforms
-from pymc3.distributions.continuous import ChiSquared, Normal
+from pymc3.distributions.continuous import ChiSquared, Normal, assert_negative_support
 from pymc3.distributions.dist_math import bound, factln, logpow, multigammaln
 from pymc3.distributions.distribution import Continuous, Discrete
 from pymc3.math import kron_diag, kron_dot, kron_solve_lower, kronecker
@@ -248,6 +249,48 @@ def _distr_parameters_for_repr(self):
         return ["mu", "cov"]


+class MvStudentTRV(RandomVariable):
+    name = "multivariate_studentt"
+    ndim_supp = 1
+    ndims_params = [0, 1, 2]
+    dtype = "floatX"
+    _print_name = ("MvStudentT", "\\operatorname{MvStudentT}")
+
+    def __call__(self, nu, mu=None, cov=None, size=None, **kwargs):
+
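+        # Resolve the "floatX" placeholder to the dtype currently configured in aesara,
+        # so the default mu and cov arrays below are created with a matching dtype.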
+        dtype = aesara.config.floatX if self.dtype == "floatX" else self.dtype
+
+        if mu is None:
+            mu = np.array([0.0], dtype=dtype)
+        if cov is None:
+            cov = np.array([[1.0]], dtype=dtype)
+        return super().__call__(nu, mu, cov, size=size, **kwargs)
+
+    def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
+        return default_shape_from_params(self.ndim_supp, dist_params, rep_param_idx, param_shapes)
+
+    @classmethod
+    def rng_fn(cls, rng, nu, mu, cov, size):
+
+        # Don't reassign broadcasted cov, since MvNormal expects two dimensional cov only.
+        mu, _ = broadcast_params([mu, cov], cls.ndims_params[1:])
+
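+        # Scale-mixture construction: a MvStudentT(nu, mu, cov) draw is
+        # mu + MvNormal(0, cov) / sqrt(ChiSquared(nu) / nu), built step by step below.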
+        chi2_samples = np.sqrt(rng.chisquare(nu, size=size) / nu)
+        # Add distribution shape to chi2 samples
+        chi2_samples = chi2_samples.reshape(chi2_samples.shape + (1,) * len(mu.shape))
+
+        mv_samples = multivariate_normal.rng_fn(rng=rng, mean=np.zeros_like(mu), cov=cov, size=size)
+
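+        # Broadcast mu up to the full sample shape so it can be added to the zero-mean draws.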
+        size = tuple(size or ())
+        if size:
+            mu = np.broadcast_to(mu, size + mu.shape)
+
+        return (mv_samples / chi2_samples) + mu
+
+
+mv_studentt = MvStudentTRV()
+
+
 class MvStudentT(Continuous):
     r"""
     Multivariate Student-T log-likelihood.
@@ -273,8 +316,8 @@ class MvStudentT(Continuous):

     Parameters
     ----------
-    nu: int
-        Degrees of freedom.
+    nu: float
+        Degrees of freedom, should be a positive scalar.
     Sigma: matrix
         Covariance matrix. Use `cov` in new code.
     mu: array
@@ -288,55 +331,21 @@ class MvStudentT(Continuous):
     lower: bool, default=True
         Whether the Cholesky factor is given as a lower triangular matrix.
     """
+    rv_op = mv_studentt

-    def __init__(
-        self, nu, Sigma=None, mu=None, cov=None, tau=None, chol=None, lower=True, *args, **kwargs
-    ):
+    @classmethod
+    def dist(cls, nu, Sigma=None, mu=None, cov=None, tau=None, chol=None, lower=True, **kwargs):
         if Sigma is not None:
             if cov is not None:
                 raise ValueError("Specify only one of cov and Sigma")
             cov = Sigma
-        super().__init__(mu=mu, cov=cov, tau=tau, chol=chol, lower=lower, *args, **kwargs)
-        self.nu = nu = at.as_tensor_variable(nu)
-        self.mean = self.median = self.mode = self.mu = self.mu
-
-    def random(self, point=None, size=None):
-        """
-        Draw random values from Multivariate Student's T distribution.
-
-        Parameters
-        ----------
-        point: dict, optional
-            Dict of variable values on which random values are to be
-            conditioned (uses default point if not specified).
-        size: int, optional
-            Desired size of random sample (returns one sample if not
-            specified).
-
-        Returns
-        -------
-        array
-        """
-        # with _DrawValuesContext():
-        #     nu, mu = draw_values([self.nu, self.mu], point=point, size=size)
-        #     if self._cov_type == "cov":
-        #         (cov,) = draw_values([self.cov], point=point, size=size)
-        #         dist = MvNormal.dist(mu=np.zeros_like(mu), cov=cov, shape=self.shape)
-        #     elif self._cov_type == "tau":
-        #         (tau,) = draw_values([self.tau], point=point, size=size)
-        #         dist = MvNormal.dist(mu=np.zeros_like(mu), tau=tau, shape=self.shape)
-        #     else:
-        #         (chol,) = draw_values([self.chol_cov], point=point, size=size)
-        #         dist = MvNormal.dist(mu=np.zeros_like(mu), chol=chol, shape=self.shape)
-        #
-        #     samples = dist.random(point, size)
-        #
-        #     chi2_samples = np.random.chisquare(nu, size)
-        #     # Add distribution shape to chi2 samples
-        #     chi2_samples = chi2_samples.reshape(chi2_samples.shape + (1,) * len(self.shape))
-        #     return (samples / np.sqrt(chi2_samples / nu)) + mu
+        nu = at.as_tensor_variable(floatX(nu))
+        mu = at.as_tensor_variable(floatX(mu))
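+        # Resolve whichever of cov, chol or tau was passed into a single covariance matrix.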
+        cov = quaddist_matrix(cov, chol, tau, lower)
+        assert_negative_support(nu, "nu", "MvStudentT")
+        return super().dist([nu, mu, cov], **kwargs)

-    def logp(value, nu, cov):
+    def logp(value, nu, mu, cov):
         """
         Calculate log-probability of Multivariate Student's T distribution
         at specified value.
@@ -350,15 +359,15 @@ def logp(value, nu, cov):
         -------
         TensorVariable
         """
-        quaddist, logdet, ok = quaddist_parse(value, nu, cov)
+        quaddist, logdet, ok = quaddist_parse(value, mu, cov)
         k = floatX(value.shape[-1])

-        norm = gammaln((nu + k) / 2.0) - gammaln(nu / 2.0) - 0.5 * k * floatX(np.log(nu * np.pi))
+        norm = gammaln((nu + k) / 2.0) - gammaln(nu / 2.0) - 0.5 * k * at.log(nu * np.pi)
         inner = -(nu + k) / 2.0 * at.log1p(quaddist / nu)
         return bound(norm + inner - logdet, ok)

     def _distr_parameters_for_repr(self):
-        return ["mu", "nu", "cov"]
+        return ["nu", "mu", "cov"]


 class Dirichlet(Continuous):