
Commit faf8d8d

esantorella authored and facebook-github-bot committed
More removal of unused ** arguments (#2336)
Summary:
Pull Request resolved: #2336

I am auditing cases in which BoTorch functions accept `**kwargs` arguments and then don't use them. In many cases, it was easier to fix the issue than to write it up; these are the easy cases.

Note on inheritance: type-checkers say that if a method accepts an argument, so must methods that override it.
* I removed `**kwargs` from abstract methods to remove the implication that their subclasses must also support `**kwargs` even when they don't use them.
* In cases where the base method accepts `**kwargs` and is defined in GPyTorch, I chose to "inconsistently override" the GPyTorch signature rather than also accept ignored `**kwargs` in BoTorch (see the sketch below). This was the case for overrides of `Module.forward` (in an ExactGP), `Kernel.forward`, and `Likelihood.forward`.

Changes:
* A small correctness fix in error catching
* Some small typing fixes
* Removed `**kwargs` in many cases

Reviewed By: saitcakmak

Differential Revision: D56849296

fbshipit-source-id: b059148e018608fac9691ee255cd73118d2e52b1
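To illustrate the override constraint described above, here is a minimal sketch (stand-in classes, not BoTorch or GPyTorch code) of why a type-checker wants overrides to keep accepting `**kwargs` once the base method does, and what an "inconsistent override" suppression looks like:

from typing import Any


class Base:
    def forward(self, x: int, **kwargs: Any) -> int:
        return x


class Child(Base):
    # mypy/pyre flag this narrower signature as an inconsistent override,
    # because code holding a `Base` may legally pass keyword arguments.
    # Dropping **kwargs anyway requires a targeted suppression:
    def forward(self, x: int) -> int:  # type: ignore[override]
        return x + 1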
1 parent 6600655 commit faf8d8d

File tree

14 files changed: +92, -62 lines

botorch/acquisition/multi_objective/predictive_entropy_search.py

Lines changed: 5 additions & 3 deletions
@@ -891,7 +891,7 @@ def _safe_update_omega(
         check_no_nans(omega_f_nat_cov_new)
         return omega_f_nat_mean_new, omega_f_nat_cov_new

-    except RuntimeError or InputDataError:
+    except (RuntimeError, InputDataError):
         return omega_f_nat_mean, omega_f_nat_cov


@@ -1070,7 +1070,7 @@ def _update_damping_when_converged(
     damping_factor: Tensor,
     iteration: Tensor,
     threshold: float = 1e-3,
-) -> Tensor:
+) -> Tuple[Tensor, Tensor, Tensor]:
     r"""Set the damping factor to 0 once converged. Convergence is determined by the
     relative change in the entries of the mean and covariance matrix.

@@ -1087,8 +1087,10 @@ def _update_damping_when_converged(
         damping_factor: A `batch_shape`-dim Tensor containing the damping factor.

     Returns:
-        A `batch_shape x param_shape`-dim Tensor containing the updated damping
+        - A `batch_shape x param_shape`-dim Tensor containing the updated damping
             factor.
+        - Difference between `mean_new` and `mean_old`
+        - Difference between `cov_new` and `cov_old`
     """
     df = damping_factor.clone()
     delta_mean = mean_new - mean_old
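The first hunk above fixes a subtle Python pitfall: `except RuntimeError or InputDataError:` first evaluates `RuntimeError or InputDataError` as a boolean expression, which returns just `RuntimeError` (the first truthy operand), so `InputDataError` was silently propagating. A standalone sketch of the bug and the fix, using a local stand-in for `botorch.exceptions.InputDataError`:

class InputDataError(Exception):  # stand-in, not the BoTorch class
    pass


def buggy() -> str:
    try:
        raise InputDataError("bad data")
    except RuntimeError or InputDataError:  # `or` yields RuntimeError only
        return "caught"


def fixed() -> str:
    try:
        raise InputDataError("bad data")
    except (RuntimeError, InputDataError):  # a tuple catches either type
        return "caught"


print(fixed())  # prints "caught"
try:
    buggy()
except InputDataError:
    print("buggy() let InputDataError escape")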

botorch/models/fully_bayesian.py

Lines changed: 2 additions & 1 deletion
@@ -133,7 +133,8 @@ def sample(self) -> None:

     @abstractmethod
     def postprocess_mcmc_samples(
-        self, mcmc_samples: Dict[str, Tensor], **kwargs: Any
+        self,
+        mcmc_samples: Dict[str, Tensor],
     ) -> Dict[str, Tensor]:
         """Post-process the final MCMC samples."""
         pass  # pragma: no cover

botorch/models/gpytorch.py

Lines changed: 8 additions & 7 deletions
@@ -204,7 +204,9 @@ def posterior(
             return posterior_transform(posterior)
         return posterior

-    def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> Model:
+    def condition_on_observations(
+        self, X: Tensor, Y: Tensor, noise: Optional[Tensor] = None, **kwargs: Any
+    ) -> Model:
         r"""Condition the model on new observations.

         Args:

@@ -219,6 +221,9 @@ def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> Model:
             standard broadcasting semantics. If `Y` has fewer batch dimensions
             than `X`, its is assumed that the missing batch dimensions are
             the same for all `Y`.
+            noise: If not `None`, a tensor of the same shape as `Y` representing
+                the associated noise variance.
+            kwargs: Passed to `self.get_fantasy_model`.

         Returns:
             A `Model` object of the same type, representing the original model

@@ -233,14 +238,14 @@ def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> Model:
         >>> new_Y = torch.sin(new_X[:, 0]) + torch.cos(new_X[:, 1])
         >>> model = model.condition_on_observations(X=new_X, Y=new_Y)
         """
-        Yvar = kwargs.pop("noise", None)
+        Yvar = noise

         if hasattr(self, "outcome_transform"):
             # pass the transformed data to get_fantasy_model below
             # (unless we've already trasnformed if BatchedMultiOutputGPyTorchModel)
             if not isinstance(self, BatchedMultiOutputGPyTorchModel):
                 # `noise` is assumed to already be outcome-transformed.
-                Y, _ = self.outcome_transform(Y, Yvar)
+                Y, _ = self.outcome_transform(Y=Y, Yvar=Yvar)
         # validate using strict=False, since we cannot tell if Y has an explicit
         # output dimension
         self._validate_tensor_args(X=X, Y=Y, Yvar=Yvar, strict=False)

@@ -356,7 +361,6 @@ def posterior(
         output_indices: Optional[List[int]] = None,
         observation_noise: Union[bool, Tensor] = False,
         posterior_transform: Optional[PosteriorTransform] = None,
-        **kwargs: Any,
     ) -> Union[GPyTorchPosterior, TransformedPosterior]:
         r"""Computes the posterior over model outputs at the provided points.

@@ -609,7 +613,6 @@ def posterior(
         output_indices: Optional[List[int]] = None,
         observation_noise: Union[bool, Tensor] = False,
         posterior_transform: Optional[PosteriorTransform] = None,
-        **kwargs: Any,
     ) -> Union[GPyTorchPosterior, PosteriorList]:
         r"""Computes the posterior over model outputs at the provided points.
         If any model returns a MultitaskMultivariateNormal posterior, then that

@@ -661,7 +664,6 @@ def posterior(
             X=X,
             output_indices=output_indices,
             observation_noise=observation_noise,
-            **kwargs,
         )
         if not returns_untransformed:
             mvns = [p.distribution for p in posterior.posteriors]

@@ -756,7 +758,6 @@ def posterior(
         output_indices: Optional[List[int]] = None,
         observation_noise: Union[bool, Tensor] = False,
         posterior_transform: Optional[PosteriorTransform] = None,
-        **kwargs: Any,
     ) -> Union[GPyTorchPosterior, TransformedPosterior]:
         r"""Computes the posterior over model outputs at the provided points.
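With `noise` promoted to an explicit keyword argument here, the docstring example above extends naturally. A hypothetical usage sketch (the variable names are assumed, following the existing docstring example):

>>> new_X = torch.rand(5, 2)
>>> new_Y = torch.sin(new_X[:, 0]) + torch.cos(new_X[:, 1])
>>> new_Yvar = torch.full_like(new_Y, 1e-4)  # known observation noise variance
>>> model = model.condition_on_observations(X=new_X, Y=new_Y, noise=new_Yvar)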

botorch/models/higher_order_gp.py

Lines changed: 7 additions & 6 deletions
@@ -61,7 +61,7 @@ class FlattenedStandardize(Standardize):
     def __init__(
         self,
         output_shape: torch.Size,
-        batch_shape: torch.Size = None,
+        batch_shape: Optional[torch.Size] = None,
         min_stdv: float = 1e-8,
     ):
         r"""

@@ -385,7 +385,7 @@ def get_fantasy_model(self, inputs, targets, **kwargs):
         return super().get_fantasy_model(inputs, reshaped_targets, **kwargs)

     def condition_on_observations(
-        self, X: Tensor, Y: Tensor, **kwargs: Any
+        self, X: Tensor, Y: Tensor, noise: Optional[torch.Tensor] = None, **kwargs: Any
     ) -> HigherOrderGP:
         r"""Condition the model on new observations.

@@ -401,17 +401,19 @@ def condition_on_observations(
             standard broadcasting semantics. If `Y` has fewer batch dimensions
             than `X`, its is assumed that the missing batch dimensions are
             the same for all `Y`.
+            noise: If not None, a tensor of the same shape as `Y` representing
+                the noise variance associated with each observation.
+            kwargs: Passed to `condition_on_observations`.

         Returns:
             A `BatchedMultiOutputGPyTorchModel` object of the same type with
             `n + n'` training examples, representing the original model
             conditioned on the new observations `(X, Y)` (and possibly noise
             observations passed in via kwargs).
         """
-        noise = kwargs.get("noise")
         if hasattr(self, "outcome_transform"):
             # we need to apply transforms before shifting batch indices around
-            Y, noise = self.outcome_transform(Y, noise)
+            Y, noise = self.outcome_transform(Y=Y, Yvar=noise)
         self._validate_tensor_args(X=X, Y=Y, Yvar=noise, strict=False)

         # we don't need to do un-squeezing because Y already is batched

@@ -420,7 +422,7 @@ def condition_on_observations(
         # kwargs.update({"noise": noise})
         fantasy_model = super(
             BatchedMultiOutputGPyTorchModel, self
-        ).condition_on_observations(X=X, Y=Y, **kwargs)
+        ).condition_on_observations(X=X, Y=Y, noise=noise, **kwargs)
         fantasy_model._input_batch_shape = fantasy_model.train_targets.shape[
             : (-1 if self._num_outputs == 1 else -2)
         ]

@@ -433,7 +435,6 @@ def posterior(
         output_indices: Optional[List[int]] = None,
         observation_noise: Union[bool, Tensor] = False,
         posterior_transform: Optional[PosteriorTransform] = None,
-        **kwargs: Any,
     ) -> GPyTorchPosterior:
         self.eval()  # make sure we're calling a posterior
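The first hunk in this file is one of the "small typing fixes": a default of `None` requires an `Optional` annotation under PEP 484. A minimal sketch of the pattern (the helper below is hypothetical, not BoTorch code):

from typing import Optional

import torch


# Before: `batch_shape: torch.Size = None` -- ill-typed, since None is not a torch.Size.
def resolve_batch_shape(batch_shape: Optional[torch.Size] = None) -> torch.Size:
    # Fall back to an empty (non-batched) shape when no batch shape is given.
    return batch_shape if batch_shape is not None else torch.Size([])


print(resolve_batch_shape())                 # torch.Size([])
print(resolve_batch_shape(torch.Size([3])))  # torch.Size([3])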

botorch/models/kernels/categorical.py

Lines changed: 0 additions & 1 deletion
@@ -28,7 +28,6 @@ def forward(
         x2: Tensor,
         diag: bool = False,
         last_dim_is_batch: bool = False,
-        **kwargs,
     ) -> Tensor:
         delta = x1.unsqueeze(-2) != x2.unsqueeze(-3)
         dists = delta / self.lengthscale.unsqueeze(-2)

botorch/models/likelihoods/pairwise.py

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@

 import math
 from abc import ABC, abstractmethod
-from typing import Any, Tuple
+from typing import Tuple

 import torch
 from botorch.utils.probability.utils import (

@@ -41,7 +41,7 @@ def __init__(self, max_plate_nesting: int = 1):
         """
         super().__init__(max_plate_nesting)

-    def forward(self, utility: Tensor, D: Tensor, **kwargs: Any) -> Bernoulli:
+    def forward(self, utility: Tensor, D: Tensor) -> Bernoulli:
         """Given the difference in (estimated) utility util_diff = f(v) - f(u),
         return a Bernoulli distribution object representing the likelihood of
         the user prefer v over u.

botorch/models/model.py

Lines changed: 3 additions & 4 deletions
@@ -95,7 +95,6 @@ def posterior(
         output_indices: Optional[List[int]] = None,
         observation_noise: Union[bool, Tensor] = False,
         posterior_transform: Optional[PosteriorTransform] = None,
-        **kwargs: Any,
     ) -> Posterior:
         r"""Computes the posterior over model outputs at the provided points.

@@ -301,7 +300,9 @@ def __init__(self, args):

     @abstractmethod
     def condition_on_observations(
-        self: TFantasizeMixin, X: Tensor, Y: Tensor, **kwargs: Any
+        self: TFantasizeMixin,
+        X: Tensor,
+        Y: Tensor,
     ) -> TFantasizeMixin:
         """
         Classes that inherit from `FantasizeMixin` must implement

@@ -314,7 +315,6 @@ def posterior(
         X: Tensor,
         *args,
         observation_noise: bool = False,
-        **kwargs: Any,
     ) -> Posterior:
         """
         Classes that inherit from `FantasizeMixin` must implement

@@ -474,7 +474,6 @@ def posterior(
         output_indices: Optional[List[int]] = None,
         observation_noise: Union[bool, Tensor] = False,
         posterior_transform: Optional[Callable[[PosteriorList], Posterior]] = None,
-        **kwargs: Any,
     ) -> Posterior:
         r"""Computes the posterior over model outputs at the provided points.

botorch/models/multitask.py

Lines changed: 0 additions & 1 deletion
@@ -572,7 +572,6 @@ def posterior(
         output_indices: Optional[List[int]] = None,
         observation_noise: Union[bool, Tensor] = False,
         posterior_transform: Optional[PosteriorTransform] = None,
-        **kwargs: Any,
     ) -> MultitaskGPPosterior:
         self.eval()

botorch/models/pairwise_gp.py

Lines changed: 2 additions & 3 deletions
@@ -1070,7 +1070,6 @@ def posterior(
         output_indices: Optional[List[int]] = None,
         observation_noise: bool = False,
         posterior_transform: Optional[PosteriorTransform] = None,
-        **kwargs: Any,
     ) -> Posterior:
         r"""Computes the posterior over model outputs at the provided points.

@@ -1100,11 +1099,11 @@ def posterior(
             return posterior_transform(posterior)
         return posterior

-    def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> Model:
+    def condition_on_observations(self, X: Tensor, Y: Tensor) -> Model:
         r"""Condition the model on new observations.

         Note that unlike other BoTorch models, PairwiseGP requires Y to be
-        pairwise comparisons
+        pairwise comparisons.

         Args:
             X: A `batch_shape x n x d` dimension tensor X
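Since `PairwiseGP.condition_on_observations` no longer takes `**kwargs`, its whole surface is `X` and `Y`, where `Y` holds pairwise comparisons rather than outcomes. A hypothetical usage sketch (the comparison convention assumed here: each row `(i, j)` of `Y` records that datapoint `i` was preferred over datapoint `j`):

>>> new_X = torch.rand(4, 2)
>>> new_comps = torch.tensor([[0, 1], [2, 3]])  # 0 beats 1, 2 beats 3
>>> model = model.condition_on_observations(X=new_X, Y=new_comps)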

botorch/optim/closures/model_closures.py

Lines changed: 7 additions & 7 deletions
@@ -104,8 +104,8 @@ def get_loss_closure_with_grads(
 @GetLossClosureWithGrads.register(object, object, object, object)
 def _get_loss_closure_with_grads_fallback(
     mll: MarginalLogLikelihood,
-    _: object,
-    __: object,
+    _likelihood_type: object,
+    _model_type: object,
     data_loader: Optional[DataLoader],
     parameters: Dict[str, Tensor],
     reducer: Callable[[Tensor], Tensor] = Tensor.sum,

@@ -127,8 +127,8 @@ def _get_loss_closure_with_grads_fallback(
 @GetLossClosure.register(MarginalLogLikelihood, object, object, DataLoader)
 def _get_loss_closure_fallback_external(
     mll: MarginalLogLikelihood,
-    _: object,
-    __: object,
+    _likelihood_type: object,
+    _model_type: object,
     data_loader: DataLoader,
     **ignore: Any,
 ) -> Callable[[], Tensor]:

@@ -153,7 +153,7 @@ def closure(**kwargs: Any) -> Tensor:

 @GetLossClosure.register(MarginalLogLikelihood, object, object, NoneType)
 def _get_loss_closure_fallback_internal(
-    mll: MarginalLogLikelihood, _: object, __: object, ___: NoneType, **ignore: Any
+    mll: MarginalLogLikelihood, _: object, __: object, ___: None, **ignore: Any
 ) -> Callable[[], Tensor]:
     r"""Fallback loss closure with internally managed data."""

@@ -167,7 +167,7 @@ def closure(**kwargs: Any) -> Tensor:

 @GetLossClosure.register(ExactMarginalLogLikelihood, object, object, NoneType)
 def _get_loss_closure_exact_internal(
-    mll: ExactMarginalLogLikelihood, _: object, __: object, ___: NoneType, **ignore: Any
+    mll: ExactMarginalLogLikelihood, _: object, __: object, ___: None, **ignore: Any
 ) -> Callable[[], Tensor]:
     r"""ExactMarginalLogLikelihood loss closure with internally managed data."""

@@ -183,7 +183,7 @@ def closure(**kwargs: Any) -> Tensor:

 @GetLossClosure.register(SumMarginalLogLikelihood, object, object, NoneType)
 def _get_loss_closure_sum_internal(
-    mll: SumMarginalLogLikelihood, _: object, __: object, ___: NoneType, **ignore: Any
+    mll: SumMarginalLogLikelihood, _: object, __: object, ___: None, **ignore: Any
 ) -> Callable[[], Tensor]:
     r"""SumMarginalLogLikelihood loss closure with internally managed data."""
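Two things change throughout this file: dispatch-only parameters gain descriptive underscore-prefixed names instead of bare `_`/`__`, and annotations use `None` rather than `NoneType` (under PEP 484, an annotation of `None` already means `type(None)`). A simplified sketch of the same pattern, using `functools.singledispatch` as a stand-in for BoTorch's dispatcher:

from functools import singledispatch


@singledispatch
def describe(_dispatch_arg: object) -> str:
    # The underscore-prefixed name documents that this argument exists only
    # to drive dispatch, more clearly than a bare `_`.
    return "fallback"


# `type(None)` is NoneType; as an *annotation*, plain `None` means the same thing.
@describe.register(type(None))
def _describe_none(_dispatch_arg: None) -> str:
    return "internally managed data"


print(describe(3.14))  # fallback
print(describe(None))  # internally managed data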
