47
47
from pymc3 .distributions .continuous import ChiSquared , Normal , assert_negative_support
48
48
from pymc3 .distributions .dist_math import bound , factln , logpow , multigammaln
49
49
from pymc3 .distributions .distribution import Continuous , Discrete
50
+ from pymc3 .distributions .shape_utils import broadcast_dist_samples_to , to_tuple
50
51
from pymc3 .math import kron_diag , kron_dot
51
52
52
53
__all__ = [
@@ -739,6 +740,26 @@ def __str__(self):
739
740
matrix_pos_def = PosDefMatrix ()
740
741
741
742
743
+ class WishartRV (RandomVariable ):
744
+ name = "wishart"
745
+ ndim_supp = 2
746
+ ndims_params = [0 , 2 ]
747
+ dtype = "floatX"
748
+ _print_name = ("Wishart" , "\\ operatorname{Wishart}" )
749
+
750
+ def _shape_from_params (self , dist_params , rep_param_idx = 1 , param_shapes = None ):
751
+ # The shape of second parameter `V` defines the shape of the output.
752
+ return dist_params [1 ].shape
753
+
754
+ @classmethod
755
+ def rng_fn (cls , rng , nu , V , size = None ):
756
+ size = size if size else 1 # Default size for Scipy's wishart.rvs is 1
757
+ return stats .wishart .rvs (np .int (nu ), V , size = size , random_state = rng )
758
+
759
+
760
+ wishart = WishartRV ()
761
+
762
+
742
763
class Wishart (Continuous ):
743
764
r"""
744
765
Wishart log-likelihood.
@@ -775,9 +796,13 @@ class Wishart(Continuous):
775
796
This distribution is unusable in a PyMC3 model. You should instead
776
797
use LKJCholeskyCov or LKJCorr.
777
798
"""
799
+ rv_op = wishart
800
+
801
+ @classmethod
802
+ def dist (cls , nu , V , * args , ** kwargs ):
803
+ nu = at .as_tensor_variable (intX (nu ))
804
+ V = at .as_tensor_variable (floatX (V ))
778
805
779
- def __init__ (self , nu , V , * args , ** kwargs ):
780
- super ().__init__ (* args , ** kwargs )
781
806
warnings .warn (
782
807
"The Wishart distribution can currently not be used "
783
808
"for MCMC sampling. The probability of sampling a "
@@ -787,34 +812,13 @@ def __init__(self, nu, V, *args, **kwargs):
787
812
"https://github.com/pymc-devs/pymc3/issues/538." ,
788
813
UserWarning ,
789
814
)
790
- self .nu = nu = at .as_tensor_variable (nu )
791
- self .p = p = at .as_tensor_variable (V .shape [0 ])
792
- self .V = V = at .as_tensor_variable (V )
793
- self .mean = nu * V
794
- self .mode = at .switch (at .ge (nu , p + 1 ), (nu - p - 1 ) * V , np .nan )
795
815
796
- def random (self , point = None , size = None ):
797
- """
798
- Draw random values from Wishart distribution.
816
+ # mean = nu * V
817
+ # p = V.shape[0]
818
+ # mode = at.switch(at.ge(nu, p + 1), (nu - p - 1) * V, np.nan)
819
+ return super ().dist ([nu , V ], * args , ** kwargs )
799
820
800
- Parameters
801
- ----------
802
- point: dict, optional
803
- Dict of variable values on which random values are to be
804
- conditioned (uses default point if not specified).
805
- size: int, optional
806
- Desired size of random sample (returns one sample if not
807
- specified).
808
-
809
- Returns
810
- -------
811
- array
812
- """
813
- # nu, V = draw_values([self.nu, self.V], point=point, size=size)
814
- # size = 1 if size is None else size
815
- # return generate_samples(stats.wishart.rvs, nu.item(), V, broadcast_shape=(size,))
816
-
817
- def logp (self , X ):
821
+ def logp (X , nu , V ):
818
822
"""
819
823
Calculate log-probability of Wishart distribution
820
824
at specified value.
@@ -828,9 +832,8 @@ def logp(self, X):
828
832
-------
829
833
TensorVariable
830
834
"""
831
- nu = self .nu
832
- p = self .p
833
- V = self .V
835
+
836
+ p = V .shape [0 ]
834
837
835
838
IVI = det (V )
836
839
IXI = det (X )
@@ -1445,6 +1448,36 @@ def _distr_parameters_for_repr(self):
1445
1448
return ["eta" , "n" ]
1446
1449
1447
1450
1451
+ class MatrixNormalRV (RandomVariable ):
1452
+ name = "matrixnormal"
1453
+ ndim_supp = 2
1454
+ ndims_params = [2 , 2 , 2 ]
1455
+ dtype = "floatX"
1456
+ _print_name = ("MatrixNormal" , "\\ operatorname{MatrixNormal}" )
1457
+
1458
+ @classmethod
1459
+ def rng_fn (cls , rng , mu , rowchol , colchol , size = None ):
1460
+
1461
+ size = to_tuple (size )
1462
+ dist_shape = to_tuple ([rowchol .shape [0 ], colchol .shape [0 ]])
1463
+ output_shape = size + dist_shape
1464
+
1465
+ # Broadcasting all parameters
1466
+ (mu ,) = broadcast_dist_samples_to (to_shape = output_shape , samples = [mu ], size = size )
1467
+ rowchol = np .broadcast_to (rowchol , shape = size + rowchol .shape [- 2 :])
1468
+
1469
+ colchol = np .broadcast_to (colchol , shape = size + colchol .shape [- 2 :])
1470
+ colchol = np .swapaxes (colchol , - 1 , - 2 ) # Take transpose
1471
+
1472
+ standard_normal = rng .standard_normal (output_shape )
1473
+ samples = mu + np .matmul (rowchol , np .matmul (standard_normal , colchol ))
1474
+
1475
+ return samples
1476
+
1477
+
1478
+ matrixnormal = MatrixNormalRV ()
1479
+
1480
+
1448
1481
class MatrixNormal (Continuous ):
1449
1482
r"""
1450
1483
Matrix-valued normal log-likelihood.
@@ -1533,175 +1566,101 @@ class MatrixNormal(Continuous):
1533
1566
vals = pm.MatrixNormal('vals', mu=mu, colchol=colchol, rowcov=rowcov,
1534
1567
observed=data, shape=(m, n))
1535
1568
"""
1569
+ rv_op = matrixnormal
1536
1570
1537
- def __init__ (
1538
- self ,
1539
- mu = 0 ,
1571
+ @classmethod
1572
+ def dist (
1573
+ cls ,
1574
+ mu ,
1540
1575
rowcov = None ,
1541
1576
rowchol = None ,
1542
- rowtau = None ,
1543
1577
colcov = None ,
1544
1578
colchol = None ,
1545
- coltau = None ,
1546
1579
shape = None ,
1547
1580
* args ,
1548
1581
** kwargs ,
1549
1582
):
1550
- self ._setup_matrices (colcov , colchol , coltau , rowcov , rowchol , rowtau )
1551
- if shape is None :
1552
- raise TypeError ("shape is a required argument" )
1553
- assert len (shape ) == 2 , "shape must have length 2: mxn"
1554
- self .shape = shape
1555
- super ().__init__ (shape = shape , * args , ** kwargs )
1556
- self .mu = at .as_tensor_variable (mu )
1557
- self .mean = self .median = self .mode = self .mu
1558
- self .solve_lower = solve_lower_triangular
1559
- self .solve_upper = solve_upper_triangular
1560
-
1561
- def _setup_matrices (self , colcov , colchol , coltau , rowcov , rowchol , rowtau ):
1583
+
1562
1584
cholesky = Cholesky (lower = True , on_error = "raise" )
1563
1585
1586
+ if mu .ndim == 1 :
1587
+ raise ValueError (
1588
+ "1x1 Matrix was provided. Please use Normal distribution for such cases."
1589
+ )
1590
+
1564
1591
# Among-row matrices
1565
- if len ([i for i in [rowtau , rowcov , rowchol ] if i is not None ]) != 1 :
1592
+ if len ([i for i in [rowcov , rowchol ] if i is not None ]) != 1 :
1566
1593
raise ValueError (
1567
- "Incompatible parameterization. "
1568
- "Specify exactly one of rowtau, rowcov, "
1569
- "or rowchol."
1594
+ "Incompatible parameterization. Specify exactly one of rowcov, or rowchol."
1570
1595
)
1571
1596
if rowcov is not None :
1572
- self .m = rowcov .shape [0 ]
1573
- self ._rowcov_type = "cov"
1574
- rowcov = at .as_tensor_variable (rowcov )
1575
1597
if rowcov .ndim != 2 :
1576
1598
raise ValueError ("rowcov must be two dimensional." )
1577
- self .rowchol_cov = cholesky (rowcov )
1578
- self .rowcov = rowcov
1579
- elif rowtau is not None :
1580
- raise ValueError ("rowtau not supported at this time" )
1581
- self .m = rowtau .shape [0 ]
1582
- self ._rowcov_type = "tau"
1583
- rowtau = at .as_tensor_variable (rowtau )
1584
- if rowtau .ndim != 2 :
1585
- raise ValueError ("rowtau must be two dimensional." )
1586
- self .rowchol_tau = cholesky (rowtau )
1587
- self .rowtau = rowtau
1599
+ rowchol_cov = cholesky (rowcov )
1588
1600
else :
1589
- self .m = rowchol .shape [0 ]
1590
- self ._rowcov_type = "chol"
1591
1601
if rowchol .ndim != 2 :
1592
1602
raise ValueError ("rowchol must be two dimensional." )
1593
- self . rowchol_cov = at .as_tensor_variable (rowchol )
1603
+ rowchol_cov = at .as_tensor_variable (rowchol )
1594
1604
1595
1605
# Among-column matrices
1596
- if len ([i for i in [coltau , colcov , colchol ] if i is not None ]) != 1 :
1606
+ if len ([i for i in [colcov , colchol ] if i is not None ]) != 1 :
1597
1607
raise ValueError (
1598
- "Incompatible parameterization. "
1599
- "Specify exactly one of coltau, colcov, "
1600
- "or colchol."
1608
+ "Incompatible parameterization. Specify exactly one of colcov, or colchol."
1601
1609
)
1602
1610
if colcov is not None :
1603
- self .n = colcov .shape [0 ]
1604
- self ._colcov_type = "cov"
1605
1611
colcov = at .as_tensor_variable (colcov )
1606
1612
if colcov .ndim != 2 :
1607
1613
raise ValueError ("colcov must be two dimensional." )
1608
- self .colchol_cov = cholesky (colcov )
1609
- self .colcov = colcov
1610
- elif coltau is not None :
1611
- raise ValueError ("coltau not supported at this time" )
1612
- self .n = coltau .shape [0 ]
1613
- self ._colcov_type = "tau"
1614
- coltau = at .as_tensor_variable (coltau )
1615
- if coltau .ndim != 2 :
1616
- raise ValueError ("coltau must be two dimensional." )
1617
- self .colchol_tau = cholesky (coltau )
1618
- self .coltau = coltau
1614
+ colchol_cov = cholesky (colcov )
1619
1615
else :
1620
- self .n = colchol .shape [0 ]
1621
- self ._colcov_type = "chol"
1622
1616
if colchol .ndim != 2 :
1623
1617
raise ValueError ("colchol must be two dimensional." )
1624
- self . colchol_cov = at .as_tensor_variable (colchol )
1618
+ colchol_cov = at .as_tensor_variable (colchol )
1625
1619
1626
- def random (self , point = None , size = None ):
1620
+ mu = at .as_tensor_variable (floatX (mu ))
1621
+ # mean = median = mode = mu
1622
+
1623
+ return super ().dist ([mu , rowchol_cov , colchol_cov ], ** kwargs )
1624
+
1625
+ def logp (value , mu , rowchol , colchol ):
1627
1626
"""
1628
- Draw random values from Matrix-valued Normal distribution.
1627
+ Calculate log-probability of Matrix-valued Normal distribution
1628
+ at specified value.
1629
1629
1630
1630
Parameters
1631
1631
----------
1632
- point: dict, optional
1633
- Dict of variable values on which random values are to be
1634
- conditioned (uses default point if not specified).
1635
- size: int, optional
1636
- Desired size of random sample (returns one sample if not
1637
- specified).
1632
+ value: numeric
1633
+ Value for which log-probability is calculated.
1638
1634
1639
1635
Returns
1640
1636
-------
1641
- array
1637
+ TensorVariable
1642
1638
"""
1643
- # mu, colchol, rowchol = draw_values(
1644
- # [self.mu, self.colchol_cov, self.rowchol_cov], point=point, size=size
1645
- # )
1646
- # size = to_tuple(size)
1647
- # dist_shape = to_tuple(self.shape)
1648
- # output_shape = size + dist_shape
1649
- #
1650
- # # Broadcasting all parameters
1651
- # (mu,) = broadcast_dist_samples_to(to_shape=output_shape, samples=[mu], size=size)
1652
- # rowchol = np.broadcast_to(rowchol, shape=size + rowchol.shape[-2:])
1653
- #
1654
- # colchol = np.broadcast_to(colchol, shape=size + colchol.shape[-2:])
1655
- # colchol = np.swapaxes(colchol, -1, -2) # Take transpose
1656
- #
1657
- # standard_normal = np.random.standard_normal(output_shape)
1658
- # samples = mu + np.matmul(rowchol, np.matmul(standard_normal, colchol))
1659
- # return samples
1660
-
1661
- def _trquaddist (self , value ):
1662
- """Compute Tr[colcov^-1 @ (x - mu).T @ rowcov^-1 @ (x - mu)] and
1663
- the logdet of colcov and rowcov."""
1664
1639
1665
- delta = value - self . mu
1666
- rowchol_cov = self . rowchol_cov
1667
- colchol_cov = self . colchol_cov
1640
+ # Compute Tr[colcov^-1 @ (x - mu).T @ rowcov^-1 @ (x - mu)] and
1641
+ # the logdet of colcov and rowcov.
1642
+ delta = value - mu
1668
1643
1669
1644
# Find exponent piece by piece
1670
- right_quaddist = self . solve_lower ( rowchol_cov , delta )
1645
+ right_quaddist = solve_lower_triangular ( rowchol , delta )
1671
1646
quaddist = at .nlinalg .matrix_dot (right_quaddist .T , right_quaddist )
1672
- quaddist = self . solve_lower ( colchol_cov , quaddist )
1673
- quaddist = self . solve_upper ( colchol_cov .T , quaddist )
1647
+ quaddist = solve_lower_triangular ( colchol , quaddist )
1648
+ quaddist = solve_upper_triangular ( colchol .T , quaddist )
1674
1649
trquaddist = at .nlinalg .trace (quaddist )
1675
1650
1676
- coldiag = at .diag (colchol_cov )
1677
- rowdiag = at .diag (rowchol_cov )
1651
+ coldiag = at .diag (colchol )
1652
+ rowdiag = at .diag (rowchol )
1678
1653
half_collogdet = at .sum (at .log (coldiag )) # logdet(M) = 2*Tr(log(L))
1679
1654
half_rowlogdet = at .sum (at .log (rowdiag )) # Using Cholesky: M = L L^T
1680
- return trquaddist , half_collogdet , half_rowlogdet
1681
-
1682
- def logp (self , value ):
1683
- """
1684
- Calculate log-probability of Matrix-valued Normal distribution
1685
- at specified value.
1686
1655
1687
- Parameters
1688
- ----------
1689
- value: numeric
1690
- Value for which log-probability is calculated.
1656
+ m = rowchol .shape [0 ]
1657
+ n = colchol .shape [0 ]
1691
1658
1692
- Returns
1693
- -------
1694
- TensorVariable
1695
- """
1696
- trquaddist , half_collogdet , half_rowlogdet = self ._trquaddist (value )
1697
- m = self .m
1698
- n = self .n
1699
1659
norm = - 0.5 * m * n * pm .floatX (np .log (2 * np .pi ))
1700
1660
return norm - 0.5 * trquaddist - m * half_collogdet - n * half_rowlogdet
1701
1661
1702
1662
def _distr_parameters_for_repr (self ):
1703
- mapping = {"tau" : "tau" , "cov" : "cov" , "chol" : "chol_cov" }
1704
- return ["mu" , "row" + mapping [self ._rowcov_type ], "col" + mapping [self ._colcov_type ]]
1663
+ return ["mu" ]
1705
1664
1706
1665
1707
1666
class KroneckerNormalRV (RandomVariable ):
0 commit comments