@@ -846,63 +846,71 @@ def testLatent2(self):
846
846
847
847
class TestMarginalVsMarginalApprox :
848
848
R"""
849
- Compare logp of models Marginal and MarginalApprox.
850
- Should be nearly equal when inducing points are same as inputs.
849
+ Compare test fits of models Marginal and MarginalApprox.
851
850
"""
852
851
853
852
def setup_method (self ):
854
- X = np .random .randn (50 , 3 )
855
- y = np .random .randn (50 )
856
- Xnew = np .random .randn (60 , 3 )
857
- pnew = np .random .randn (60 )
858
- with pm .Model () as model :
859
- cov_func = pm .gp .cov .ExpQuad (3 , [0.1 , 0.2 , 0.3 ])
860
- mean_func = pm .gp .mean .Constant (0.5 )
861
- gp = pm .gp .Marginal (mean_func = mean_func , cov_func = cov_func )
862
- sigma = 0.1
863
- f = gp .marginal_likelihood ("f" , X , y , noise = sigma )
864
- p = gp .conditional ("p" , Xnew )
865
- self .logp = model .compile_logp ()({"p" : pnew })
866
- self .X = X
867
- self .Xnew = Xnew
868
- self .y = y
869
- self .sigma = sigma
870
- self .pnew = pnew
871
- self .gp = gp
853
+ self .sigma = 0.1
854
+ self .x = np .linspace (- 5 , 5 , 30 )
855
+ self .y = np .random .normal (0.25 * self .x , self .sigma )
856
+ with pm .Model () as model :
857
+ cov_func = pm .gp .cov .Linear (1 , c = 0.0 )
858
+ c = pm .Normal ("c" , mu = 20.0 , sigma = 100.0 ) # far from true value
859
+ mean_func = pm .gp .mean .Constant (c )
860
+ self .gp = pm .gp .Marginal (mean_func = mean_func , cov_func = cov_func )
861
+ sigma = pm .HalfNormal ("sigma" , sigma = 100 )
862
+ self .gp .marginal_likelihood ("lik" , self .x [:, None ], self .y , sigma )
863
+ self .map_full = pm .find_MAP (method = "bfgs" ) # bfgs seems to work much better than lbfgsb
864
+
865
+ self .x_new = np .linspace (- 6 , 6 , 20 )
866
+
867
+ # Include additive Gaussian noise, return diagonal of predicted covariance matrix
868
+ with model :
869
+ self .pred_mu , self .pred_var = self .gp .predict (
870
+ self .x_new [:, None ], point = self .map_full , pred_noise = True , diag = True
871
+ )
872
872
873
- @pytest .mark .parametrize ("approx" , ["FITC" , "VFE" , "DTC" ])
874
- def testApproximations (self , approx ):
875
- with pm .Model () as model :
876
- cov_func = pm .gp .cov .ExpQuad (3 , [0.1 , 0.2 , 0.3 ])
877
- mean_func = pm .gp .mean .Constant (0.5 )
878
- gp = pm .gp .MarginalApprox (mean_func = mean_func , cov_func = cov_func , approx = approx )
879
- f = gp .marginal_likelihood ("f" , self .X , self .X , self .y , self .sigma )
880
- p = gp .conditional ("p" , self .Xnew )
881
- approx_logp = model .compile_logp ()({"p" : self .pnew })
882
- npt .assert_allclose (approx_logp , self .logp , atol = 0 , rtol = 1e-2 )
873
+ # Dont include additive Gaussian noise, return full predicted covariance matrix
874
+ with model :
875
+ self .pred_mu , self .pred_covar = self .gp .predict (
876
+ self .x_new [:, None ], point = self .map_full , pred_noise = False , diag = False
877
+ )
883
878
884
879
@pytest .mark .parametrize ("approx" , ["FITC" , "VFE" , "DTC" ])
885
- def testPredictVar (self , approx ):
880
+ def test_fits_and_preds (self , approx ):
881
+ """Get MAP estimate for GP approximation, compare results and predictions to what's returned
882
+ by an unapproximated GP. The tolerances are fairly wide, but narrow relative to initial
883
+ values of the unknown parameters.
884
+ """
885
+
886
886
with pm .Model () as model :
887
- cov_func = pm .gp .cov .ExpQuad (3 , [0.1 , 0.2 , 0.3 ])
888
- mean_func = pm .gp .mean .Constant (0.5 )
887
+ cov_func = pm .gp .cov .Linear (1 , c = 0.0 )
888
+ c = pm .Normal ("c" , mu = 20.0 , sigma = 100.0 , initval = - 500.0 )
889
+ mean_func = pm .gp .mean .Constant (c )
889
890
gp = pm .gp .MarginalApprox (mean_func = mean_func , cov_func = cov_func , approx = approx )
890
- f = gp .marginal_likelihood ("f" , self .X , self .X , self .y , self .sigma )
891
- mu1 , var1 = self .gp .predict (self .Xnew , diag = True )
892
- mu2 , var2 = gp .predict (self .Xnew , diag = True )
893
- npt .assert_allclose (mu1 , mu2 , atol = 0 , rtol = 1e-3 )
894
- npt .assert_allclose (var1 , var2 , atol = 0 , rtol = 1e-3 )
891
+ sigma = pm .HalfNormal ("sigma" , sigma = 100 , initval = 50.0 )
892
+ gp .marginal_likelihood ("lik" , self .x [:, None ], self .x [:, None ], self .y , sigma )
893
+ map_approx = pm .find_MAP (method = "bfgs" )
894
+
895
+ # Check MAP gets approximately correct result
896
+ npt .assert_allclose (self .map_full ["c" ], map_approx ["c" ], atol = 0.01 , rtol = 0.1 )
897
+ npt .assert_allclose (self .map_full ["sigma" ], map_approx ["sigma" ], atol = 0.01 , rtol = 0.1 )
898
+
899
+ # Check that predict (and conditional) work, include noise, with diagonal non-full pred var.
900
+ with model :
901
+ pred_mu_approx , pred_var_approx = gp .predict (
902
+ self .x_new [:, None ], point = map_approx , pred_noise = True , diag = True
903
+ )
904
+ npt .assert_allclose (self .pred_mu , pred_mu_approx , atol = 0.0 , rtol = 0.1 )
905
+ npt .assert_allclose (self .pred_var , pred_var_approx , atol = 0.0 , rtol = 0.1 )
895
906
896
- def testPredictCov (self ):
897
- with pm .Model () as model :
898
- cov_func = pm .gp .cov .ExpQuad (3 , [0.1 , 0.2 , 0.3 ])
899
- mean_func = pm .gp .mean .Constant (0.5 )
900
- gp = pm .gp .MarginalApprox (mean_func = mean_func , cov_func = cov_func , approx = "DTC" )
901
- f = gp .marginal_likelihood ("f" , self .X , self .X , self .y , self .sigma )
902
- mu1 , cov1 = self .gp .predict (self .Xnew , pred_noise = True )
903
- mu2 , cov2 = gp .predict (self .Xnew , pred_noise = True )
904
- npt .assert_allclose (mu1 , mu2 , atol = 0 , rtol = 1e-3 )
905
- npt .assert_allclose (cov1 , cov2 , atol = 0 , rtol = 1e-3 )
907
+ # Check that predict (and conditional) work, no noise, full pred covariance.
908
+ with model :
909
+ pred_mu_approx , pred_var_approx = gp .predict (
910
+ self .x_new [:, None ], point = map_approx , pred_noise = True , diag = True
911
+ )
912
+ npt .assert_allclose (self .pred_mu , pred_mu_approx , atol = 0.0 , rtol = 0.1 )
913
+ npt .assert_allclose (self .pred_var , pred_var_approx , atol = 0.0 , rtol = 0.1 )
906
914
907
915
908
916
class TestGPAdditive :
0 commit comments