|
31 | 31 | from aesara.tensor import gammaln, sigmoid
|
32 | 32 | from aesara.tensor.nlinalg import det, eigh, matrix_inverse, trace
|
33 | 33 | from aesara.tensor.random.basic import MultinomialRV, dirichlet, multivariate_normal
|
34 |
| -from aesara.tensor.random.op import RandomVariable, default_shape_from_params |
| 34 | +from aesara.tensor.random.op import RandomVariable, default_supp_shape_from_params |
35 | 35 | from aesara.tensor.random.utils import broadcast_params
|
36 | 36 | from aesara.tensor.slinalg import Cholesky
|
37 | 37 | from aesara.tensor.slinalg import solve_lower_triangular as solve_lower
|
@@ -295,8 +295,10 @@ def __call__(self, nu, mu=None, cov=None, size=None, **kwargs):
|
295 | 295 | cov = np.array([[1.0]], dtype=dtype)
|
296 | 296 | return super().__call__(nu, mu, cov, size=size, **kwargs)
|
297 | 297 |
|
298 |
| - def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None): |
299 |
| - return default_shape_from_params(self.ndim_supp, dist_params, rep_param_idx, param_shapes) |
| 298 | + def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None): |
| 299 | + return default_supp_shape_from_params( |
| 300 | + self.ndim_supp, dist_params, rep_param_idx, param_shapes |
| 301 | + ) |
300 | 302 |
|
301 | 303 | @classmethod
|
302 | 304 | def rng_fn(cls, rng, nu, mu, cov, size):
|
@@ -607,8 +609,10 @@ class DirichletMultinomialRV(RandomVariable):
|
607 | 609 | dtype = "int64"
|
608 | 610 | _print_name = ("DirichletMN", "\\operatorname{DirichletMN}")
|
609 | 611 |
|
610 |
| - def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None): |
611 |
| - return default_shape_from_params(self.ndim_supp, dist_params, rep_param_idx, param_shapes) |
| 612 | + def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None): |
| 613 | + return default_supp_shape_from_params( |
| 614 | + self.ndim_supp, dist_params, rep_param_idx, param_shapes |
| 615 | + ) |
612 | 616 |
|
613 | 617 | @classmethod
|
614 | 618 | def rng_fn(cls, rng, n, a, size):
|
@@ -903,7 +907,7 @@ class WishartRV(RandomVariable):
|
903 | 907 | dtype = "floatX"
|
904 | 908 | _print_name = ("Wishart", "\\operatorname{Wishart}")
|
905 | 909 |
|
906 |
| - def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None): |
| 910 | + def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None): |
907 | 911 | # The shape of second parameter `V` defines the shape of the output.
|
908 | 912 | return dist_params[1].shape[-2:]
|
909 | 913 |
|
@@ -1471,7 +1475,7 @@ def make_node(self, rng, size, dtype, n, eta):
|
1471 | 1475 |
|
1472 | 1476 | return super().make_node(rng, size, dtype, n, eta)
|
1473 | 1477 |
|
1474 |
| - def _shape_from_params(self, dist_params, **kwargs): |
| 1478 | + def _supp_shape_from_params(self, dist_params, **kwargs): |
1475 | 1479 | n = dist_params[0]
|
1476 | 1480 | dist_shape = ((n * (n - 1)) // 2,)
|
1477 | 1481 | return dist_shape
|
|
0 commit comments