1212import torch as th
1313
1414from gptorch .model import Param , Model , GPModel
15+ from gptorch .kernels import Rbf
16+ from gptorch .models import GPR
1517
1618
1719class TestParam (object ):
@@ -58,4 +60,42 @@ class TestGPModel(object):
5860 """
5961 Tests for the GPModel class
6062 """
61- pass
63+ def test_predict_f_samples (self ):
64+ # TODO mock a GPModel? Using GPR for the moment.
65+ n , dx , dy = 5 , 3 , 2
66+ x , y = np .random .randn (n , dx ), np .random .randn (n , dy )
67+ kern = Rbf (dx , ARD = True )
68+ gp = GPR (y , x , kern )
69+
70+ n_test = 5
71+ x_test = np .random .randn (n_test , dx )
72+ f_samples = gp .predict_f_samples (x_test )
73+ assert isinstance (f_samples , th .Tensor )
74+ assert f_samples .ndimension () == 3 # [sample x n_test x dy]
75+ assert f_samples .shape == (1 , n_test , dy )
76+
77+ n_samples = 3
78+ f_samples_2 = gp .predict_f_samples (x_test , n_samples = n_samples )
79+ assert isinstance (f_samples_2 , th .Tensor )
80+ assert f_samples_2 .ndimension () == 3 # [sample x n_test x dy]
81+ assert f_samples_2 .shape == (n_samples , n_test , dy )
82+
83+ def test_predict_y_samples (self ):
84+ # TODO mock a GPModel? Using GPR for the moment.
85+ n , dx , dy = 5 , 3 , 2
86+ x , y = np .random .randn (n , dx ), np .random .randn (n , dy )
87+ kern = Rbf (dx , ARD = True )
88+ gp = GPR (y , x , kern )
89+
90+ n_test = 5
91+ x_test = np .random .randn (n_test , dx )
92+ y_samples = gp .predict_y_samples (x_test )
93+ assert isinstance (y_samples , th .Tensor )
94+ assert y_samples .ndimension () == 3 # [sample x n_test x dy]
95+ assert y_samples .shape == (1 , n_test , dy )
96+
97+ n_samples = 3
98+ y_samples_2 = gp .predict_y_samples (x_test , n_samples = n_samples )
99+ assert isinstance (y_samples_2 , th .Tensor )
100+ assert y_samples_2 .ndimension () == 3 # [sample x n_test x dy]
101+ assert y_samples_2 .shape == (n_samples , n_test , dy )
0 commit comments