Skip to content

Commit

Permalink
Merge pull request #7 from cics-nd/predict-f-samples
Browse files Browse the repository at this point in the history
Predict f samples
  • Loading branch information
sdatkinson authored Jul 22, 2019
2 parents e24b3b8 + dc55c25 commit 0fc5c94
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 2 deletions.
14 changes: 14 additions & 0 deletions gptorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,20 @@ def predict_y(self, input_new, diag=True):
else:
return self.likelihood.predict_mean_covariance(mean_f, cov_f)

@input_as_tensor
def predict_f_samples(self, input_new, n_samples=1):
"""
Return [n_samp x n_test x d_y] matrix of samples
:param input_new:
:param n_samples:
:return:
"""
mu, sigma = self.predict_f(input_new, diag=False)
chol_s = cholesky(sigma)
samp = mu + torch.stack([torch.mm(chol_s, Variable(torch.Tensor(r)))
for r in np.random.randn(n_samples, *mu.size())])
return samp

@input_as_tensor
def predict_y_samples(self, input_new, n_samples=1):
"""
Expand Down
2 changes: 1 addition & 1 deletion test/test_likelihoods.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_logp(self):
assert isinstance(logp, Variable)

# Value
assert logp.data.numpy() == expected_logp
assert logp.data.numpy() == pytest.approx(expected_logp)

def test_predict_mean_variance(self):
"""
Expand Down
42 changes: 41 additions & 1 deletion test/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import torch as th

from gptorch.model import Param, Model, GPModel
from gptorch.kernels import Rbf
from gptorch.models import GPR


class TestParam(object):
Expand Down Expand Up @@ -58,4 +60,42 @@ class TestGPModel(object):
"""
Tests for the GPModel class
"""
pass
def test_predict_f_samples(self):
# TODO mock a GPModel? Using GPR for the moment.
n, dx, dy = 5, 3, 2
x, y = np.random.randn(n, dx), np.random.randn(n, dy)
kern = Rbf(dx, ARD=True)
gp = GPR(y, x, kern)

n_test = 5
x_test = np.random.randn(n_test, dx)
f_samples = gp.predict_f_samples(x_test)
assert isinstance(f_samples, th.Tensor)
assert f_samples.ndimension() == 3 # [sample x n_test x dy]
assert f_samples.shape == (1, n_test, dy)

n_samples = 3
f_samples_2 = gp.predict_f_samples(x_test, n_samples=n_samples)
assert isinstance(f_samples_2, th.Tensor)
assert f_samples_2.ndimension() == 3 # [sample x n_test x dy]
assert f_samples_2.shape == (n_samples, n_test, dy)

def test_predict_y_samples(self):
# TODO mock a GPModel? Using GPR for the moment.
n, dx, dy = 5, 3, 2
x, y = np.random.randn(n, dx), np.random.randn(n, dy)
kern = Rbf(dx, ARD=True)
gp = GPR(y, x, kern)

n_test = 5
x_test = np.random.randn(n_test, dx)
y_samples = gp.predict_y_samples(x_test)
assert isinstance(y_samples, th.Tensor)
assert y_samples.ndimension() == 3 # [sample x n_test x dy]
assert y_samples.shape == (1, n_test, dy)

n_samples = 3
y_samples_2 = gp.predict_y_samples(x_test, n_samples=n_samples)
assert isinstance(y_samples_2, th.Tensor)
assert y_samples_2.ndimension() == 3 # [sample x n_test x dy]
assert y_samples_2.shape == (n_samples, n_test, dy)
3 changes: 3 additions & 0 deletions test/test_models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# File: __init__.py
# File Created: Sunday, 21st July 2019 8:06:08 pm
# Author: Steven Atkinson ([email protected])
14 changes: 14 additions & 0 deletions test/test_models/test_gpr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# File: test_gpr.py
# File Created: Sunday, 21st July 2019 7:53:26 pm
# Author: Steven Atkinson ([email protected])

"""
Tests for the GPR class
"""

import pytest


def TestGPR(object):
# TODO
pass

0 comments on commit 0fc5c94

Please sign in to comment.