Commit 8f02bea

Marginalapprox fix (#6076)
* switch from using DensityDist to using Potential
* increase tolerance on flaky tests, add test using find_MAP
* refactor MarginalApprox tests
* run pre-commit
* address comments, pass approx arg correctly, improve docstrings
* fix comments, pass jitter through correctly, get rid of is_observed arg
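The first bullet is the heart of the change: instead of wrapping the approximate marginal likelihood in a pm.DensityDist with y as observed, the log-likelihood is now added to the model as a pm.Potential term. A minimal illustrative sketch of that mechanism (not code from this commit; the toy model and names are invented):

import numpy as np
import pymc as pm

y = np.random.randn(30)

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 10.0)
    sigma = pm.HalfNormal("sigma", 1.0)
    # A Potential adds an arbitrary log-probability term to the model's
    # joint log-density; here it plays the role of the data likelihood.
    loglik = pm.logp(pm.Normal.dist(mu=mu, sigma=sigma), y).sum()
    pm.Potential("loglik", loglik)
    map_est = pm.find_MAP()

One practical consequence: a Potential is not a random variable, so y is no longer an observed variable in the model and cannot be drawn from with prior or posterior predictive sampling.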
1 parent: 0b191ad

2 files changed (+60 -44 in pymc/gp/gp.py and +56 -48 in pymc/tests/test_gp.py: +60 -92 lines total)


pymc/gp/gp.py (+4 -44)
@@ -685,18 +685,13 @@ def __init__(self, approx="VFE", *, mean_func=Zero(), cov_func=Constant(0.0)):
         super().__init__(mean_func=mean_func, cov_func=cov_func)
 
     def __add__(self, other):
-        # new_gp will default to FITC approx
        new_gp = super().__add__(other)
-        # make sure new gp has correct approx
        if not self.approx == other.approx:
            raise TypeError("Cannot add GPs with different approximations")
        new_gp.approx = self.approx
        return new_gp
 
-    # Use y as first argument, so that we can use functools.partial
-    # in marginal_likelihood instead of lambda. This makes pickling
-    # possible.
-    def _build_marginal_likelihood_logp(self, y, X, Xu, sigma, jitter):
+    def _build_marginal_likelihood_loglik(self, y, X, Xu, sigma, jitter):
        sigma2 = at.square(sigma)
        Kuu = self.cov_func(Xu)
        Kuf = self.cov_func(Xu, X)
@@ -725,9 +720,7 @@ def _build_marginal_likelihood_logp(self, y, X, Xu, sigma, jitter):
        quadratic = 0.5 * (at.dot(r, r_l) - at.dot(c, c))
        return -1.0 * (constant + logdet + quadratic + trace)
 
-    def marginal_likelihood(
-        self, name, X, Xu, y, noise=None, is_observed=True, jitter=JITTER_DEFAULT, **kwargs
-    ):
+    def marginal_likelihood(self, name, X, Xu, y, noise=None, jitter=JITTER_DEFAULT, **kwargs):
        R"""
        Returns the approximate marginal likelihood distribution, given the input
        locations `X`, inducing point locations `Xu`, data `y`, and white noise
@@ -747,9 +740,6 @@ def marginal_likelihood(
            noise. Must have shape `(n, )`.
        noise: scalar, Variable
            Standard deviation of the Gaussian noise.
-        is_observed: bool
-            Whether to set `y` as an `observed` variable in the `model`.
-            Default is `True`.
        jitter: scalar
            A small correction added to the diagonal of positive semi-definite
            covariance matrices to ensure numerical stability.
@@ -767,38 +757,8 @@
        else:
            self.sigma = noise
 
-        if is_observed:
-            return pm.DensityDist(
-                name,
-                X,
-                Xu,
-                self.sigma,
-                jitter,
-                logp=self._build_marginal_likelihood_logp,
-                observed=y,
-                ndims_params=[2, 2, 0],
-                size=X.shape[0],
-                **kwargs,
-            )
-        else:
-            warnings.warn(
-                "The 'is_observed' argument has been deprecated. If the GP is "
-                "unobserved use gp.Latent instead.",
-                FutureWarning,
-            )
-            return pm.DensityDist(
-                name,
-                X,
-                Xu,
-                self.sigma,
-                jitter,
-                logp=self._build_marginal_likelihood_logp,
-                observed=y,
-                ndims_params=[2, 2, 0],
-                # ndim_supp=1,
-                size=X.shape[0],
-                **kwargs,
-            )
+        approx_loglik = self._build_marginal_likelihood_loglik(y, X, Xu, noise, jitter)
+        pm.Potential(f"marginalapprox_loglik_{name}", approx_loglik, **kwargs)
 
    def _build_conditional(
        self, Xnew, pred_noise, diag, X, Xu, y, sigma, cov_total, mean_total, jitter
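With the diff above applied, marginal_likelihood no longer accepts is_observed, registers a Potential named marginalapprox_loglik_<name>, and returns nothing. A hedged usage sketch under that reading of the diff (the data, priors, and shapes below are invented for illustration):

import numpy as np
import pymc as pm

X = np.linspace(0, 10, 100)[:, None]
Xu = np.linspace(0, 10, 10)[:, None]  # inducing point locations
y = np.sin(X.flatten()) + 0.1 * np.random.randn(100)

with pm.Model() as model:
    ell = pm.Gamma("ell", alpha=2.0, beta=1.0)
    cov_func = pm.gp.cov.ExpQuad(1, ls=ell)
    gp = pm.gp.MarginalApprox(cov_func=cov_func, approx="VFE")
    sigma = pm.HalfNormal("sigma", sigma=1.0)
    # No return value to capture; the call adds the Potential to the model.
    gp.marginal_likelihood("lik", X=X, Xu=Xu, y=y, noise=sigma)
    mp = pm.find_MAP()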

pymc/tests/test_gp.py (+56 -48)
@@ -846,63 +846,71 @@ def testLatent2(self):
 
 class TestMarginalVsMarginalApprox:
     R"""
-    Compare logp of models Marginal and MarginalApprox.
-    Should be nearly equal when inducing points are same as inputs.
+    Compare test fits of models Marginal and MarginalApprox.
    """
 
    def setup_method(self):
-        X = np.random.randn(50, 3)
-        y = np.random.randn(50)
-        Xnew = np.random.randn(60, 3)
-        pnew = np.random.randn(60)
-        with pm.Model() as model:
-            cov_func = pm.gp.cov.ExpQuad(3, [0.1, 0.2, 0.3])
-            mean_func = pm.gp.mean.Constant(0.5)
-            gp = pm.gp.Marginal(mean_func=mean_func, cov_func=cov_func)
-            sigma = 0.1
-            f = gp.marginal_likelihood("f", X, y, noise=sigma)
-            p = gp.conditional("p", Xnew)
-        self.logp = model.compile_logp()({"p": pnew})
-        self.X = X
-        self.Xnew = Xnew
-        self.y = y
-        self.sigma = sigma
-        self.pnew = pnew
-        self.gp = gp
+        self.sigma = 0.1
+        self.x = np.linspace(-5, 5, 30)
+        self.y = np.random.normal(0.25 * self.x, self.sigma)
+        with pm.Model() as model:
+            cov_func = pm.gp.cov.Linear(1, c=0.0)
+            c = pm.Normal("c", mu=20.0, sigma=100.0)  # far from true value
+            mean_func = pm.gp.mean.Constant(c)
+            self.gp = pm.gp.Marginal(mean_func=mean_func, cov_func=cov_func)
+            sigma = pm.HalfNormal("sigma", sigma=100)
+            self.gp.marginal_likelihood("lik", self.x[:, None], self.y, sigma)
+            self.map_full = pm.find_MAP(method="bfgs")  # bfgs seems to work much better than lbfgsb
+
+        self.x_new = np.linspace(-6, 6, 20)
+
+        # Include additive Gaussian noise, return diagonal of predicted covariance matrix
+        with model:
+            self.pred_mu, self.pred_var = self.gp.predict(
+                self.x_new[:, None], point=self.map_full, pred_noise=True, diag=True
+            )
 
-    @pytest.mark.parametrize("approx", ["FITC", "VFE", "DTC"])
-    def testApproximations(self, approx):
-        with pm.Model() as model:
-            cov_func = pm.gp.cov.ExpQuad(3, [0.1, 0.2, 0.3])
-            mean_func = pm.gp.mean.Constant(0.5)
-            gp = pm.gp.MarginalApprox(mean_func=mean_func, cov_func=cov_func, approx=approx)
-            f = gp.marginal_likelihood("f", self.X, self.X, self.y, self.sigma)
-            p = gp.conditional("p", self.Xnew)
-        approx_logp = model.compile_logp()({"p": self.pnew})
-        npt.assert_allclose(approx_logp, self.logp, atol=0, rtol=1e-2)
+        # Don't include additive Gaussian noise, return full predicted covariance matrix
+        with model:
+            self.pred_mu, self.pred_covar = self.gp.predict(
+                self.x_new[:, None], point=self.map_full, pred_noise=False, diag=False
+            )
 
    @pytest.mark.parametrize("approx", ["FITC", "VFE", "DTC"])
-    def testPredictVar(self, approx):
+    def test_fits_and_preds(self, approx):
+        """Get MAP estimate for GP approximation, compare results and predictions to what's
+        returned by an unapproximated GP. The tolerances are fairly wide, but narrow relative
+        to initial values of the unknown parameters.
+        """
+
        with pm.Model() as model:
-            cov_func = pm.gp.cov.ExpQuad(3, [0.1, 0.2, 0.3])
-            mean_func = pm.gp.mean.Constant(0.5)
+            cov_func = pm.gp.cov.Linear(1, c=0.0)
+            c = pm.Normal("c", mu=20.0, sigma=100.0, initval=-500.0)
+            mean_func = pm.gp.mean.Constant(c)
            gp = pm.gp.MarginalApprox(mean_func=mean_func, cov_func=cov_func, approx=approx)
-            f = gp.marginal_likelihood("f", self.X, self.X, self.y, self.sigma)
-            mu1, var1 = self.gp.predict(self.Xnew, diag=True)
-            mu2, var2 = gp.predict(self.Xnew, diag=True)
-        npt.assert_allclose(mu1, mu2, atol=0, rtol=1e-3)
-        npt.assert_allclose(var1, var2, atol=0, rtol=1e-3)
+            sigma = pm.HalfNormal("sigma", sigma=100, initval=50.0)
+            gp.marginal_likelihood("lik", self.x[:, None], self.x[:, None], self.y, sigma)
+            map_approx = pm.find_MAP(method="bfgs")
+
+        # Check MAP gets approximately correct result
+        npt.assert_allclose(self.map_full["c"], map_approx["c"], atol=0.01, rtol=0.1)
+        npt.assert_allclose(self.map_full["sigma"], map_approx["sigma"], atol=0.01, rtol=0.1)
+
+        # Check that predict (and conditional) work, include noise, with diagonal non-full pred var.
+        with model:
+            pred_mu_approx, pred_var_approx = gp.predict(
+                self.x_new[:, None], point=map_approx, pred_noise=True, diag=True
+            )
+        npt.assert_allclose(self.pred_mu, pred_mu_approx, atol=0.0, rtol=0.1)
+        npt.assert_allclose(self.pred_var, pred_var_approx, atol=0.0, rtol=0.1)
 
-    def testPredictCov(self):
-        with pm.Model() as model:
-            cov_func = pm.gp.cov.ExpQuad(3, [0.1, 0.2, 0.3])
-            mean_func = pm.gp.mean.Constant(0.5)
-            gp = pm.gp.MarginalApprox(mean_func=mean_func, cov_func=cov_func, approx="DTC")
-            f = gp.marginal_likelihood("f", self.X, self.X, self.y, self.sigma)
-            mu1, cov1 = self.gp.predict(self.Xnew, pred_noise=True)
-            mu2, cov2 = gp.predict(self.Xnew, pred_noise=True)
-        npt.assert_allclose(mu1, mu2, atol=0, rtol=1e-3)
-        npt.assert_allclose(cov1, cov2, atol=0, rtol=1e-3)
+        # Check that predict (and conditional) work, no noise, full pred covariance.
+        with model:
+            pred_mu_approx, pred_covar_approx = gp.predict(
+                self.x_new[:, None], point=map_approx, pred_noise=False, diag=False
+            )
+        npt.assert_allclose(self.pred_mu, pred_mu_approx, atol=0.0, rtol=0.1)
+        npt.assert_allclose(self.pred_covar, pred_covar_approx, atol=0.0, rtol=0.1)
 
 
 class TestGPAdditive: