Skip to content

Commit 5b3de17

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Skip tensor validation in condition_on_observations when fantasizing (#2378)
Summary: Pull Request resolved: #2378 Resolves #2376 When fantasizing, we expect `Y` to have an additional batch dimension, so it will fail the tensor shape checks. Since `Y` is produced by the model itself, it should always be a valid input. Disabling these checks to avoid misleading warnings. Reviewed By: Balandat Differential Revision: D58564318 fbshipit-source-id: 5d21784b116a7a72c93d3904dcdd36e172e785ad
1 parent 62c96f0 commit 5b3de17

File tree

3 files changed

+27
-5
lines changed

3 files changed

+27
-5
lines changed

botorch/models/gpytorch.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
mod_batch_shape,
4040
multioutput_to_batch_mode_transform,
4141
)
42+
from botorch.models.utils.assorted import fantasize as fantasize_flag
4243
from botorch.posteriors.fully_bayesian import GaussianMixturePosterior
4344
from botorch.posteriors.gpytorch import GPyTorchPosterior
4445
from botorch.utils.multitask import separate_mtmvn
@@ -246,9 +247,11 @@ def condition_on_observations(
246247
if not isinstance(self, BatchedMultiOutputGPyTorchModel):
247248
# `noise` is assumed to already be outcome-transformed.
248249
Y, _ = self.outcome_transform(Y=Y, Yvar=Yvar)
249-
# validate using strict=False, since we cannot tell if Y has an explicit
250-
# output dimension
251-
self._validate_tensor_args(X=X, Y=Y, Yvar=Yvar, strict=False)
250+
# Validate using strict=False, since we cannot tell if Y has an explicit
251+
# output dimension. Do not check shapes when fantasizing as they are
252+
# not expected to match.
253+
if fantasize_flag.off():
254+
self._validate_tensor_args(X=X, Y=Y, Yvar=Yvar, strict=False)
252255
if Y.size(-1) == 1:
253256
Y = Y.squeeze(-1)
254257
if Yvar is not None:
@@ -505,7 +508,9 @@ def condition_on_observations(
505508
# We need to apply transforms before shifting batch indices around.
506509
# `noise` is assumed to already be outcome-transformed.
507510
Y, _ = self.outcome_transform(Y)
508-
self._validate_tensor_args(X=X, Y=Y, Yvar=noise, strict=False)
511+
# Do not check shapes when fantasizing as they are not expected to match.
512+
if fantasize_flag.off():
513+
self._validate_tensor_args(X=X, Y=Y, Yvar=noise, strict=False)
509514
inputs = X
510515
if self._num_outputs > 1:
511516
inputs, targets, noise = multioutput_to_batch_mode_transform(

botorch/models/higher_order_gp.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from botorch.models.transforms.input import InputTransform
2626
from botorch.models.transforms.outcome import OutcomeTransform, Standardize
2727
from botorch.models.utils import gpt_posterior_settings
28+
from botorch.models.utils.assorted import fantasize as fantasize_flag
2829
from botorch.models.utils.gpytorch_modules import (
2930
get_gaussian_likelihood_with_gamma_prior,
3031
)
@@ -414,7 +415,9 @@ def condition_on_observations(
414415
if hasattr(self, "outcome_transform"):
415416
# we need to apply transforms before shifting batch indices around
416417
Y, noise = self.outcome_transform(Y=Y, Yvar=noise)
417-
self._validate_tensor_args(X=X, Y=Y, Yvar=noise, strict=False)
418+
# Do not check shapes when fantasizing as they are not expected to match.
419+
if fantasize_flag.off():
420+
self._validate_tensor_args(X=X, Y=Y, Yvar=noise, strict=False)
418421

419422
# we don't need to do un-squeezing because Y already is batched
420423
# we don't support fixed noise here yet

test/models/test_gpytorch.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,20 @@ def test_validate_tensor_args(self) -> None:
250250
):
251251
GPyTorchModel._validate_tensor_args(X, Y, Yvar, strict=strict)
252252

253+
def test_condition_on_observations_tensor_validation(self) -> None:
254+
model = SimpleGPyTorchModel(torch.rand(5, 1), torch.randn(5, 1))
255+
model.posterior(torch.rand(2, 1)) # evaluate the model to form caches.
256+
# Outside of fantasize, the inputs are validated.
257+
with self.assertWarnsRegex(
258+
BotorchTensorDimensionWarning, "Non-strict enforcement of"
259+
):
260+
model.condition_on_observations(torch.randn(2, 1), torch.randn(5, 2, 1))
261+
# Inside of fantasize, the inputs are not validated.
262+
with fantasize(), warnings.catch_warnings(record=True) as ws:
263+
warnings.filterwarnings("always", category=BotorchTensorDimensionWarning)
264+
model.condition_on_observations(torch.randn(2, 1), torch.randn(5, 2, 1))
265+
self.assertFalse(any(w.category is BotorchTensorDimensionWarning for w in ws))
266+
253267
def test_fantasize_flag(self):
254268
train_X = torch.rand(5, 1)
255269
train_Y = torch.sin(train_X)

0 commit comments

Comments
 (0)