From b370d1131ec7d682823cb9c1fb0743a0ce6ae6c1 Mon Sep 17 00:00:00 2001 From: Colin Carroll Date: Thu, 5 Sep 2019 14:25:02 -0400 Subject: [PATCH] Add matrix multiplication infix operator --- RELEASE-NOTES.md | 1 + docs/source/index.rst | 2 +- pymc3/model.py | 16 +++++++++++++--- pymc3/tests/test_model.py | 28 ++++++++++++++++++++++++++++ 4 files changed, 43 insertions(+), 4 deletions(-) diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 5933018c24..d0427c33c9 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -9,6 +9,7 @@ - Added `Matern12` covariance function for Gaussian processes. This is the Matern kernel with nu=1/2. - Progressbar reports number of divergences in real time, when available [#3547](https://github.com/pymc-devs/pymc3/pull/3547). - Sampling from variational approximation now allows for alternative trace backends [#3550]. +- Infix `@` operator now works with random variables and deterministics [#3578](https://github.com/pymc-devs/pymc3/pull/3578). ### Maintenance - Moved math operations out of `Rice`, `TruncatedNormal`, `Triangular` and `ZeroInflatedNegativeBinomial` `random` methods. Math operations on values returned by `draw_values` might not broadcast well, and all the `size` aware broadcasting is left to `generate_samples`. Fixes [#3481](https://github.com/pymc-devs/pymc3/issues/3481) and [#3508](https://github.com/pymc-devs/pymc3/issues/3508) diff --git a/docs/source/index.rst b/docs/source/index.rst index 75c328f2a8..e7a540af02 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -26,7 +26,7 @@ weights = pm.Normal('weights', mu=0, sigma=1) noise = pm.Gamma('noise', alpha=2, beta=1) y_observed = pm.Normal('y_observed', - mu=X.dot(weights), + mu=X @ weights, sigma=noise, observed=y) diff --git a/pymc3/model.py b/pymc3/model.py index 083de57101..00423e5514 100644 --- a/pymc3/model.py +++ b/pymc3/model.py @@ -30,6 +30,16 @@ FlatView = collections.namedtuple('FlatView', 'input, replacements, view') +class PyMC3Variable(TensorVariable): + """Class to wrap Theano TensorVariable for custom behavior.""" + + # Implement matrix multiplication infix operator: X @ w + __matmul__ = tt.dot + + def __rmatmul__(self, other): + return tt.dot(other, self) + + class InstanceMethod: """Class for hiding references to instance methods so they can be pickled. @@ -1245,7 +1255,7 @@ def _get_scaling(total_size, shape, ndim): return tt.as_tensor(floatX(coef)) -class FreeRV(Factor, TensorVariable): +class FreeRV(Factor, PyMC3Variable): """Unobserved random variable that a model is specified in terms of.""" def __init__(self, type=None, owner=None, index=None, name=None, @@ -1354,7 +1364,7 @@ def as_tensor(data, name, model, distribution): return data -class ObservedRV(Factor, TensorVariable): +class ObservedRV(Factor, PyMC3Variable): """Observed random variable that a model is specified in terms of. Potentially partially observed. """ @@ -1525,7 +1535,7 @@ def Potential(name, var, model=None): return var -class TransformedRV(TensorVariable): +class TransformedRV(PyMC3Variable): """ Parameters ---------- diff --git a/pymc3/tests/test_model.py b/pymc3/tests/test_model.py index 8bb47aeec3..78ad1daf04 100644 --- a/pymc3/tests/test_model.py +++ b/pymc3/tests/test_model.py @@ -157,6 +157,34 @@ def test_nested(self): assert theano.config.compute_test_value == 'ignore' assert theano.config.compute_test_value == 'off' +def test_matrix_multiplication(): + # Check matrix multiplication works between RVs, transformed RVs, + # Deterministics, and numpy arrays + with pm.Model() as linear_model: + matrix = pm.Normal('matrix', shape=(2, 2)) + transformed = pm.Gamma('transformed', alpha=2, beta=1, shape=2) + rv_rv = pm.Deterministic('rv_rv', matrix @ transformed) + np_rv = pm.Deterministic('np_rv', np.ones((2, 2)) @ transformed) + rv_np = pm.Deterministic('rv_np', matrix @ np.ones(2)) + rv_det = pm.Deterministic('rv_det', matrix @ rv_rv) + det_rv = pm.Deterministic('det_rv', rv_rv @ transformed) + + posterior = pm.sample(10, + tune=0, + compute_convergence_checks=False, + progressbar=False) + for point in posterior.points(): + npt.assert_almost_equal(point['matrix'] @ point['transformed'], + point['rv_rv']) + npt.assert_almost_equal(np.ones((2, 2)) @ point['transformed'], + point['np_rv']) + npt.assert_almost_equal(point['matrix'] @ np.ones(2), + point['rv_np']) + npt.assert_almost_equal(point['matrix'] @ point['rv_rv'], + point['rv_det']) + npt.assert_almost_equal(point['rv_rv'] @ point['transformed'], + point['det_rv']) + def test_duplicate_vars(): with pytest.raises(ValueError) as err: