Skip to content

Commit 0fc5c94

Browse files
authored
Merge pull request #7 from cics-nd/predict-f-samples
Predict f samples
2 parents e24b3b8 + dc55c25 commit 0fc5c94

File tree

5 files changed

+73
-2
lines changed

5 files changed

+73
-2
lines changed

gptorch/model.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,20 @@ def predict_y(self, input_new, diag=True):
495495
else:
496496
return self.likelihood.predict_mean_covariance(mean_f, cov_f)
497497

498+
@input_as_tensor
499+
def predict_f_samples(self, input_new, n_samples=1):
500+
"""
501+
Return [n_samp x n_test x d_y] matrix of samples
502+
:param input_new:
503+
:param n_samples:
504+
:return:
505+
"""
506+
mu, sigma = self.predict_f(input_new, diag=False)
507+
chol_s = cholesky(sigma)
508+
samp = mu + torch.stack([torch.mm(chol_s, Variable(torch.Tensor(r)))
509+
for r in np.random.randn(n_samples, *mu.size())])
510+
return samp
511+
498512
@input_as_tensor
499513
def predict_y_samples(self, input_new, n_samples=1):
500514
"""

test/test_likelihoods.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def test_logp(self):
5555
assert isinstance(logp, Variable)
5656

5757
# Value
58-
assert logp.data.numpy() == expected_logp
58+
assert logp.data.numpy() == pytest.approx(expected_logp)
5959

6060
def test_predict_mean_variance(self):
6161
"""

test/test_model.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import torch as th
1313

1414
from gptorch.model import Param, Model, GPModel
15+
from gptorch.kernels import Rbf
16+
from gptorch.models import GPR
1517

1618

1719
class TestParam(object):
@@ -58,4 +60,42 @@ class TestGPModel(object):
5860
"""
5961
Tests for the GPModel class
6062
"""
61-
pass
63+
def test_predict_f_samples(self):
64+
# TODO mock a GPModel? Using GPR for the moment.
65+
n, dx, dy = 5, 3, 2
66+
x, y = np.random.randn(n, dx), np.random.randn(n, dy)
67+
kern = Rbf(dx, ARD=True)
68+
gp = GPR(y, x, kern)
69+
70+
n_test = 5
71+
x_test = np.random.randn(n_test, dx)
72+
f_samples = gp.predict_f_samples(x_test)
73+
assert isinstance(f_samples, th.Tensor)
74+
assert f_samples.ndimension() == 3 # [sample x n_test x dy]
75+
assert f_samples.shape == (1, n_test, dy)
76+
77+
n_samples = 3
78+
f_samples_2 = gp.predict_f_samples(x_test, n_samples=n_samples)
79+
assert isinstance(f_samples_2, th.Tensor)
80+
assert f_samples_2.ndimension() == 3 # [sample x n_test x dy]
81+
assert f_samples_2.shape == (n_samples, n_test, dy)
82+
83+
def test_predict_y_samples(self):
84+
# TODO mock a GPModel? Using GPR for the moment.
85+
n, dx, dy = 5, 3, 2
86+
x, y = np.random.randn(n, dx), np.random.randn(n, dy)
87+
kern = Rbf(dx, ARD=True)
88+
gp = GPR(y, x, kern)
89+
90+
n_test = 5
91+
x_test = np.random.randn(n_test, dx)
92+
y_samples = gp.predict_y_samples(x_test)
93+
assert isinstance(y_samples, th.Tensor)
94+
assert y_samples.ndimension() == 3 # [sample x n_test x dy]
95+
assert y_samples.shape == (1, n_test, dy)
96+
97+
n_samples = 3
98+
y_samples_2 = gp.predict_y_samples(x_test, n_samples=n_samples)
99+
assert isinstance(y_samples_2, th.Tensor)
100+
assert y_samples_2.ndimension() == 3 # [sample x n_test x dy]
101+
assert y_samples_2.shape == (n_samples, n_test, dy)

test/test_models/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# File: __init__.py
2+
# File Created: Sunday, 21st July 2019 8:06:08 pm
3+
# Author: Steven Atkinson ([email protected])

test/test_models/test_gpr.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# File: test_gpr.py
2+
# File Created: Sunday, 21st July 2019 7:53:26 pm
3+
# Author: Steven Atkinson ([email protected])
4+
5+
"""
6+
Tests for the GPR class
7+
"""
8+
9+
import pytest
10+
11+
12+
def TestGPR(object):
13+
# TODO
14+
pass

0 commit comments

Comments
 (0)