Skip to content

Commit 6600655

Browse files
esantorellafacebook-github-bot
authored andcommitted
Fix pathwise sampler bug (#2337)
Summary: Pull Request resolved: #2337 `_draw_matheron_paths_ExactGP` calls `_gaussian_update_ExactGP` (as the update_strategy argument). It was passing an incorrect argument `train_targets` in place of `target_values`. This led the `Y` values to being silently ignored. Reviewed By: saitcakmak Differential Revision: D57175013 fbshipit-source-id: f3a1c304ee221d1460284b51f423af894c7c6a9f
1 parent 07f12b7 commit 6600655

File tree

3 files changed

+19
-2
lines changed

3 files changed

+19
-2
lines changed

botorch/sampling/pathwise/posterior_samplers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def _draw_matheron_paths_ExactGP(
139139
update_paths = update_strategy(
140140
model=model,
141141
sample_values=sample_values,
142-
train_targets=train_Y,
142+
target_values=train_Y,
143143
)
144144

145145
return MatheronPath(

botorch/sampling/pathwise/update_strategies.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,6 @@ def _gaussian_update_ExactGP(
112112
points: Optional[Tensor] = None,
113113
noise_covariance: Optional[Union[Tensor, LinearOperator]] = None,
114114
scale_tril: Optional[Union[Tensor, LinearOperator]] = None,
115-
**ignore: Any,
116115
) -> GeneralizedLinearPath:
117116
if points is None:
118117
(points,) = get_train_inputs(model, transformed=True)

test/sampling/pathwise/test_update_strategies.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from botorch.utils.testing import BotorchTestCase
2727
from gpytorch.kernels import MaternKernel, RBFKernel, ScaleKernel
2828
from gpytorch.likelihoods import BernoulliLikelihood
29+
from gpytorch.models import ExactGP
2930
from linear_operator.operators import ZeroLinearOperator
3031
from linear_operator.utils.cholesky import psd_safe_cholesky
3132
from torch import Size
@@ -204,3 +205,20 @@ def _test_gaussian_updates(self, model):
204205
with patch.object(model, "likelihood", new=BernoulliLikelihood()):
205206
with self.assertRaises(NotImplementedError):
206207
gaussian_update(model=model, sample_values=sample_values)
208+
209+
with self.subTest("Exact models with `None` target_values"):
210+
assert isinstance(model, ExactGP)
211+
torch.manual_seed(0)
212+
path_none_target_values = gaussian_update(
213+
model=model,
214+
sample_values=sample_values,
215+
)
216+
torch.manual_seed(0)
217+
path_with_target_values = gaussian_update(
218+
model=model,
219+
sample_values=sample_values,
220+
target_values=get_train_targets(model, transformed=True),
221+
)
222+
self.assertAllClose(
223+
path_none_target_values.weight, path_with_target_values.weight
224+
)

0 commit comments

Comments
 (0)