Skip to content

Commit 1d2f840

Browse files
esantorellafacebook-github-bot
authored andcommitted
Remove _filter_kwargs (#2341)
Summary: Pull Request resolved: #2341 Now when incorrect arguments are passed to `optimize_acqf` in `gen_kwargs`, they will raise an exception (as incorrect arguments generally do) rather than being silently ignored. There was only one usage of `_filter_kwargs`, and it would not be triggered by correct usage after a prior change that stopped creating unused arguments. Reviewed By: Balandat Differential Revision: D57250387 fbshipit-source-id: b2090154151d88a7a34b978f1ca067222b44da59
1 parent e2d90ac commit 1d2f840

File tree

6 files changed

+63
-166
lines changed

6 files changed

+63
-166
lines changed

botorch/optim/optimize.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
TGenInitialConditions,
3737
)
3838
from botorch.optim.stopping import ExpMAStoppingCriterion
39-
from botorch.optim.utils import _filter_kwargs
4039
from torch import Tensor
4140

4241
INIT_OPTION_KEYS = {
@@ -314,8 +313,6 @@ def _optimize_batch_candidates() -> Tuple[Tensor, Tensor, List[Warning]]:
314313
"timeout_sec": timeout_sec,
315314
}
316315

317-
# only add parameter constraints to gen_kwargs if they are specified
318-
# to avoid unnecessary warnings in _filter_kwargs
319316
for constraint_name in [
320317
"inequality_constraints",
321318
"equality_constraints",
@@ -324,8 +321,6 @@ def _optimize_batch_candidates() -> Tuple[Tensor, Tensor, List[Warning]]:
324321
if (constraint := getattr(opt_inputs, constraint_name)) is not None:
325322
gen_kwargs[constraint_name] = constraint
326323

327-
filtered_gen_kwargs = _filter_kwargs(opt_inputs.gen_candidates, **gen_kwargs)
328-
329324
for i, batched_ics_ in enumerate(batched_ics):
330325
# optimize using random restart optimization
331326
with warnings.catch_warnings(record=True) as ws:
@@ -334,7 +329,7 @@ def _optimize_batch_candidates() -> Tuple[Tensor, Tensor, List[Warning]]:
334329
batch_candidates_curr,
335330
batch_acq_values_curr,
336331
) = opt_inputs.gen_candidates(
337-
batched_ics_, opt_inputs.acq_function, **filtered_gen_kwargs
332+
batched_ics_, opt_inputs.acq_function, **gen_kwargs
338333
)
339334
opt_warnings += ws
340335
batch_candidates_list.append(batch_candidates_curr)

botorch/optim/utils/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
get_X_baseline,
1111
)
1212
from botorch.optim.utils.common import (
13-
_filter_kwargs,
1413
_handle_numerical_errors,
1514
_warning_handler_template,
1615
)
@@ -31,7 +30,6 @@
3130
from botorch.optim.utils.timeout import minimize_with_timeout
3231

3332
__all__ = [
34-
"_filter_kwargs",
3533
"_handle_numerical_errors",
3634
"_warning_handler_template",
3735
"as_ndarray",

botorch/optim/utils/common.py

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,34 +8,14 @@
88

99
from __future__ import annotations
1010

11-
from inspect import signature
1211
from logging import debug as logging_debug
13-
from typing import Any, Callable, Optional, Tuple
14-
from warnings import warn, warn_explicit, WarningMessage
12+
from typing import Callable, Optional, Tuple
13+
from warnings import warn_explicit, WarningMessage
1514

1615
import numpy as np
1716
from linear_operator.utils.errors import NanError, NotPSDError
1817

1918

20-
def _filter_kwargs(function: Callable, **kwargs: Any) -> Any:
21-
r"""Filter out kwargs that are not applicable for a given function.
22-
Return a copy of given kwargs dict with only the required kwargs."""
23-
allowed_params = signature(function).parameters
24-
removed = {k for k in kwargs.keys() if k not in allowed_params}
25-
if len(removed) > 0:
26-
fn_descriptor = (
27-
f" for function {function.__name__}"
28-
if hasattr(function, "__name__")
29-
else ""
30-
)
31-
warn(
32-
f"Keyword arguments {list(removed)} will be ignored because they are"
33-
f" not allowed parameters{fn_descriptor}. Allowed "
34-
f"parameters are {list(allowed_params.keys())}."
35-
)
36-
return {k: v for k, v in kwargs.items() if k not in removed}
37-
38-
3919
def _handle_numerical_errors(
4020
error: RuntimeError, x: np.ndarray, dtype: Optional[np.dtype] = None
4121
) -> Tuple[np.ndarray, np.ndarray]:

test/acquisition/test_knowledge_gradient.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,8 @@
2424
)
2525
from botorch.acquisition.utils import project_to_sample_points
2626
from botorch.exceptions.errors import UnsupportedError
27-
from botorch.generation.gen import gen_candidates_scipy
2827
from botorch.models import SingleTaskGP
2928
from botorch.optim.optimize import optimize_acqf
30-
from botorch.optim.utils import _filter_kwargs
3129
from botorch.posteriors.gpytorch import GPyTorchPosterior
3230
from botorch.sampling.normal import IIDNormalSampler, SobolQMCNormalSampler
3331
from botorch.utils.test_helpers import DummyNonScalarizingPosteriorTransform
@@ -588,11 +586,6 @@ def test_optimize_w_posterior_transform(self):
588586
torch.zeros(2, n_f + 1, 2, **tkwargs),
589587
torch.zeros(2, **tkwargs),
590588
),
591-
), mock.patch(
592-
f"{optimize_acqf.__module__}._filter_kwargs",
593-
wraps=lambda f, **kwargs: _filter_kwargs(
594-
function=gen_candidates_scipy, **kwargs
595-
),
596589
):
597590

598591
candidate, value = optimize_acqf(

test/optim/test_optimize.py

Lines changed: 59 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
import itertools
88
import warnings
9-
from inspect import signature
109
from itertools import product
1110
from unittest import mock
1211

@@ -114,10 +113,8 @@ class TestOptimizeAcqf(BotorchTestCase):
114113
@mock.patch("botorch.generation.gen.gen_candidates_torch")
115114
@mock.patch("botorch.optim.optimize.gen_batch_initial_conditions")
116115
@mock.patch("botorch.optim.optimize.gen_candidates_scipy")
117-
@mock.patch("botorch.optim.utils.common.signature")
118116
def test_optimize_acqf_joint(
119117
self,
120-
mock_signature,
121118
mock_gen_candidates_scipy,
122119
mock_gen_batch_initial_conditions,
123120
mock_gen_candidates_torch,
@@ -134,10 +131,6 @@ def test_optimize_acqf_joint(
134131
mock_gen_candidates_scipy,
135132
mock_gen_candidates_torch,
136133
):
137-
if mock_gen_candidates == mock_gen_candidates_torch:
138-
mock_signature.return_value = signature(gen_candidates_torch)
139-
else:
140-
mock_signature.return_value = signature(gen_candidates_scipy)
141134

142135
mock_gen_batch_initial_conditions.return_value = torch.zeros(
143136
num_restarts, q, 3, device=self.device, dtype=dtype
@@ -264,12 +257,14 @@ def test_optimize_acqf_joint(
264257
)
265258

266259
@mock.patch("botorch.optim.optimize.gen_batch_initial_conditions")
267-
@mock.patch("botorch.optim.optimize.gen_candidates_scipy")
268-
@mock.patch("botorch.generation.gen.gen_candidates_torch")
269-
@mock.patch("botorch.optim.utils.common.signature")
260+
@mock.patch(
261+
"botorch.optim.optimize.gen_candidates_scipy", wraps=gen_candidates_scipy
262+
)
263+
@mock.patch(
264+
"botorch.generation.gen.gen_candidates_torch", wraps=gen_candidates_torch
265+
)
270266
def test_optimize_acqf_sequential(
271267
self,
272-
mock_signature,
273268
mock_gen_candidates_torch,
274269
mock_gen_candidates_scipy,
275270
mock_gen_batch_initial_conditions,
@@ -278,11 +273,6 @@ def test_optimize_acqf_sequential(
278273
for mock_gen_candidates, timeout_sec in product(
279274
[mock_gen_candidates_scipy, mock_gen_candidates_torch], [None, 1e-4]
280275
):
281-
if mock_gen_candidates == mock_gen_candidates_torch:
282-
mock_signature.return_value = signature(gen_candidates_torch)
283-
else:
284-
mock_signature.return_value = signature(gen_candidates_scipy)
285-
mock_gen_candidates.__name__ = "gen_candidates"
286276
q = 3
287277
num_restarts = 2
288278
raw_samples = 10
@@ -1019,16 +1009,12 @@ def nlc4(x):
10191009
raw_samples=16,
10201010
)
10211011

1022-
@mock.patch("botorch.generation.gen.gen_candidates_torch")
10231012
@mock.patch("botorch.optim.optimize.gen_batch_initial_conditions")
10241013
@mock.patch("botorch.optim.optimize.gen_candidates_scipy")
1025-
@mock.patch("botorch.optim.utils.common.signature")
10261014
def test_optimize_acqf_non_linear_constraints_sequential(
10271015
self,
1028-
mock_signature,
10291016
mock_gen_candidates_scipy,
10301017
mock_gen_batch_initial_conditions,
1031-
mock_gen_candidates_torch,
10321018
):
10331019
def nlc(x):
10341020
return 4 * x[..., 2] - 5
@@ -1037,90 +1023,63 @@ def nlc(x):
10371023
num_restarts = 2
10381024
raw_samples = 10
10391025
options = {}
1040-
for mock_gen_candidates in (
1041-
mock_gen_candidates_torch,
1042-
mock_gen_candidates_scipy,
1043-
):
1044-
if mock_gen_candidates == mock_gen_candidates_torch:
1045-
mock_signature.return_value = signature(gen_candidates_torch)
1046-
else:
1047-
mock_signature.return_value = signature(gen_candidates_scipy)
1048-
for dtype in (torch.float, torch.double):
1049-
mock_acq_function = MockAcquisitionFunction()
1050-
mock_gen_batch_initial_conditions.side_effect = [
1051-
torch.zeros(num_restarts, 1, 3, device=self.device, dtype=dtype)
1052-
for _ in range(q)
1053-
]
1054-
gcs_return_vals = [
1055-
(
1056-
torch.tensor(
1057-
[[[1.0, 2.0, 3.0]]], device=self.device, dtype=dtype
1058-
),
1059-
torch.tensor([i], device=self.device, dtype=dtype),
1060-
)
1061-
# for nonlinear inequality constraints the batch_limit variable is
1062-
# currently set to 1 by default and hence gen_candidates_scipy is
1063-
# called num_restarts*q times
1064-
for i in range(num_restarts * q)
1065-
]
1066-
mock_gen_candidates.side_effect = gcs_return_vals
1067-
expected_candidates = torch.cat(
1068-
[cands[0] for cands, _ in gcs_return_vals[::num_restarts]], dim=-2
1026+
1027+
for dtype in (torch.float, torch.double):
1028+
mock_acq_function = MockAcquisitionFunction()
1029+
mock_gen_batch_initial_conditions.side_effect = [
1030+
torch.zeros(num_restarts, 1, 3, device=self.device, dtype=dtype)
1031+
for _ in range(q)
1032+
]
1033+
gcs_return_vals = [
1034+
(
1035+
torch.tensor([[[1.0, 2.0, 3.0]]], device=self.device, dtype=dtype),
1036+
torch.tensor([i], device=self.device, dtype=dtype),
10691037
)
1070-
bounds = torch.stack(
1071-
[
1072-
torch.zeros(3, device=self.device, dtype=dtype),
1073-
4 * torch.ones(3, device=self.device, dtype=dtype),
1074-
]
1038+
# for nonlinear inequality constraints the batch_limit variable is
1039+
# currently set to 1 by default and hence gen_candidates_scipy is
1040+
# called num_restarts*q times
1041+
for i in range(num_restarts * q)
1042+
]
1043+
mock_gen_candidates_scipy.side_effect = gcs_return_vals
1044+
expected_candidates = torch.cat(
1045+
[cands[0] for cands, _ in gcs_return_vals[::num_restarts]], dim=-2
1046+
)
1047+
bounds = torch.stack(
1048+
[
1049+
torch.zeros(3, device=self.device, dtype=dtype),
1050+
4 * torch.ones(3, device=self.device, dtype=dtype),
1051+
]
1052+
)
1053+
with warnings.catch_warnings(record=True) as ws:
1054+
candidates, acq_value = optimize_acqf(
1055+
acq_function=mock_acq_function,
1056+
bounds=bounds,
1057+
q=q,
1058+
num_restarts=num_restarts,
1059+
raw_samples=raw_samples,
1060+
options=options,
1061+
nonlinear_inequality_constraints=[nlc],
1062+
sequential=True,
1063+
ic_generator=mock_gen_batch_initial_conditions,
1064+
gen_candidates=mock_gen_candidates_scipy,
10751065
)
1076-
with warnings.catch_warnings(record=True) as ws:
1077-
candidates, acq_value = optimize_acqf(
1078-
acq_function=mock_acq_function,
1079-
bounds=bounds,
1080-
q=q,
1081-
num_restarts=num_restarts,
1082-
raw_samples=raw_samples,
1083-
options=options,
1084-
nonlinear_inequality_constraints=[nlc],
1085-
sequential=True,
1086-
ic_generator=mock_gen_batch_initial_conditions,
1087-
gen_candidates=mock_gen_candidates,
1088-
)
1089-
if mock_gen_candidates == mock_gen_candidates_torch:
1090-
self.assertEqual(len(ws), 3)
1091-
message = (
1092-
"Keyword arguments ['nonlinear_inequality_constraints']"
1093-
" will be ignored because they are not allowed parameters for"
1094-
" function gen_candidates. Allowed parameters are "
1095-
" ['initial_conditions', 'acquisition_function', "
1096-
"'lower_bounds', 'upper_bounds', 'optimizer', 'options',"
1097-
" 'callback', 'fixed_features', 'timeout_sec']."
1098-
)
1099-
expected_warning_raised = (
1100-
issubclass(w.category, UserWarning)
1101-
and message == str(w.message)
1102-
for w in ws
1103-
)
1104-
self.assertTrue(expected_warning_raised)
1105-
# check message
1106-
else:
1107-
self.assertEqual(len(ws), 0)
1108-
self.assertTrue(torch.equal(candidates, expected_candidates))
1109-
# Extract the relevant entries from gcs_return_vals to
1110-
# perform comparison with.
1111-
self.assertTrue(
1112-
torch.equal(
1113-
acq_value,
1114-
torch.cat(
1115-
[
1116-
expected_acq_value
1117-
for _, expected_acq_value in gcs_return_vals[
1118-
num_restarts - 1 :: num_restarts
1119-
]
1066+
self.assertEqual(len(ws), 0)
1067+
self.assertTrue(torch.equal(candidates, expected_candidates))
1068+
# Extract the relevant entries from gcs_return_vals to
1069+
# perform comparison with.
1070+
self.assertTrue(
1071+
torch.equal(
1072+
acq_value,
1073+
torch.cat(
1074+
[
1075+
expected_acq_value
1076+
for _, expected_acq_value in gcs_return_vals[
1077+
num_restarts - 1 :: num_restarts
11201078
]
1121-
),
1079+
]
11221080
),
1123-
)
1081+
),
1082+
)
11241083

11251084
def test_constraint_caching(self):
11261085
def nlc(x):

test/optim/utils/test_common.py

Lines changed: 1 addition & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -10,40 +10,12 @@
1010
from warnings import catch_warnings, warn
1111

1212
import numpy as np
13-
from botorch.optim.utils import (
14-
_filter_kwargs,
15-
_handle_numerical_errors,
16-
_warning_handler_template,
17-
)
13+
from botorch.optim.utils import _handle_numerical_errors, _warning_handler_template
1814
from botorch.utils.testing import BotorchTestCase
1915
from linear_operator.utils.errors import NanError, NotPSDError
2016

2117

2218
class TestUtilsCommon(BotorchTestCase):
23-
def test__filter_kwargs(self) -> None:
24-
def mock_adam(params, lr: float = 0.001) -> None:
25-
return # pragma: nocover
26-
27-
kwargs = {"lr": 0.01, "maxiter": 3000}
28-
expected_msg = (
29-
r"Keyword arguments \['maxiter'\] will be ignored because they are "
30-
r"not allowed parameters for function mock_adam. Allowed parameters "
31-
r"are \['params', 'lr'\]."
32-
)
33-
34-
with self.assertWarnsRegex(Warning, expected_msg):
35-
valid_kwargs = _filter_kwargs(mock_adam, **kwargs)
36-
self.assertEqual(set(valid_kwargs.keys()), {"lr"})
37-
38-
mock_partial = partial(mock_adam, lr=2.0)
39-
expected_msg = (
40-
r"Keyword arguments \['maxiter'\] will be ignored because they are "
41-
r"not allowed parameters. Allowed parameters are \['params', 'lr'\]."
42-
)
43-
with self.assertWarnsRegex(Warning, expected_msg):
44-
valid_kwargs = _filter_kwargs(mock_partial, **kwargs)
45-
self.assertEqual(set(valid_kwargs.keys()), {"lr"})
46-
4719
def test_handle_numerical_errors(self):
4820
x = np.zeros(1, dtype=np.float64)
4921

0 commit comments

Comments
 (0)