Skip to content

Commit db2e947

Browse files
esantorellafacebook-github-bot
authored andcommitted
Introduce fast_optimize context manager to speed up testing (#2563)
Summary: Pull Request resolved: #2563 Context: Many BoTorch tests take a while because they run optimization code. But it's good that this code runs rather than being avoided or mocked out, becuase the tests are ensuring that things work end-to-end. Borrowing a page from Ax's `fast_botorch_optimize`, this commit introduces the same thing to BoTorch, with the exception of `fit_fully_bayesian_model_nuts`. A future commit to Ax can remove that functionality from Ax in favor of importing it from BoTorch, but we might not want to do it right way because then Ax won't work with older versions of BoTorch. This PR: * Introduces `fast_optimize`, which is the same as Ax's `fast_botorch_optimize`, but with different import paths. * Applies it to a slow test, reducing runtime to 2s from 6s-10s. Reviewed By: sdaulton Differential Revision: D63838626 fbshipit-source-id: d2f8d6b496df1b50baf4f07327713a38473157d4
1 parent a0a2c05 commit db2e947

File tree

5 files changed

+264
-0
lines changed

5 files changed

+264
-0
lines changed

botorch/test_utils/__init__.py

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
test_utils has its own directory with 'botorch/' to avoid circular dependencies:
9+
Anything in 'tests/' can depend on anything in 'botorch/test_utils/', and
10+
anything in 'botorch/test_utils/' can depend on anything in the rest of
11+
'botorch/'.
12+
"""
13+
14+
from botorch.test_utils.mock import fast_optimize
15+
16+
__all__ = ["fast_optimize"]

botorch/test_utils/mock.py

+124
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
Utilities for speeding up optimization in tests.
9+
10+
"""
11+
from __future__ import annotations
12+
13+
from collections.abc import Generator
14+
from contextlib import contextmanager, ExitStack
15+
from functools import wraps
16+
from typing import Any, Callable
17+
from unittest import mock
18+
19+
from botorch.optim.initializers import (
20+
gen_batch_initial_conditions,
21+
gen_one_shot_kg_initial_conditions,
22+
)
23+
24+
from botorch.optim.utils.timeout import minimize_with_timeout
25+
from scipy.optimize import OptimizeResult
26+
from torch import Tensor
27+
28+
29+
@contextmanager
30+
def fast_optimize_context_manager(
31+
force: bool = False,
32+
) -> Generator[None, None, None]:
33+
"""A context manager to force botorch to speed up optimization. Currently, the
34+
primary tactic is to force the underlying scipy methods to stop after just one
35+
iteration.
36+
37+
force: If True will not raise an AssertionError if no mocks are called.
38+
USE RESPONSIBLY.
39+
"""
40+
41+
def one_iteration_minimize(*args: Any, **kwargs: Any) -> OptimizeResult:
42+
if kwargs["options"] is None:
43+
kwargs["options"] = {}
44+
45+
kwargs["options"]["maxiter"] = 1
46+
return minimize_with_timeout(*args, **kwargs)
47+
48+
def minimal_gen_ics(*args: Any, **kwargs: Any) -> Tensor:
49+
kwargs["num_restarts"] = 2
50+
kwargs["raw_samples"] = 4
51+
52+
return gen_batch_initial_conditions(*args, **kwargs)
53+
54+
def minimal_gen_os_ics(*args: Any, **kwargs: Any) -> Tensor | None:
55+
kwargs["num_restarts"] = 2
56+
kwargs["raw_samples"] = 4
57+
58+
return gen_one_shot_kg_initial_conditions(*args, **kwargs)
59+
60+
with ExitStack() as es:
61+
# Note this `minimize_with_timeout` is defined in optim.utils.timeout;
62+
# this mock only has an effect when calling a function used in
63+
# `botorch.generation.gen`, such as `gen_candidates_scipy`.
64+
mock_generation = es.enter_context(
65+
mock.patch(
66+
"botorch.generation.gen.minimize_with_timeout",
67+
wraps=one_iteration_minimize,
68+
)
69+
)
70+
71+
# Similarly, works when using calling a function defined in
72+
# `optim.core`, such as `scipy_minimize` and `torch_minimize`.
73+
mock_fit = es.enter_context(
74+
mock.patch(
75+
"botorch.optim.core.minimize_with_timeout",
76+
wraps=one_iteration_minimize,
77+
)
78+
)
79+
80+
# Works when calling a function in `optim.optimize` such as
81+
# `optimize_acqf`
82+
mock_gen_ics = es.enter_context(
83+
mock.patch(
84+
"botorch.optim.optimize.gen_batch_initial_conditions",
85+
wraps=minimal_gen_ics,
86+
)
87+
)
88+
89+
# Works when calling a function in `optim.optimize` such as
90+
# `optimize_acqf`
91+
mock_gen_os_ics = es.enter_context(
92+
mock.patch(
93+
"botorch.optim.optimize.gen_one_shot_kg_initial_conditions",
94+
wraps=minimal_gen_os_ics,
95+
)
96+
)
97+
98+
yield
99+
100+
if (not force) and all(
101+
mock_.call_count < 1
102+
for mock_ in [
103+
mock_generation,
104+
mock_fit,
105+
mock_gen_ics,
106+
mock_gen_os_ics,
107+
]
108+
):
109+
raise AssertionError(
110+
"No mocks were called in the context manager. Please remove unused "
111+
"fast_optimize_context_manager()."
112+
)
113+
114+
115+
def fast_optimize(f: Callable) -> Callable:
116+
"""Wraps f in the fast_botorch_optimize_context_manager for use as a decorator."""
117+
118+
@wraps(f)
119+
# pyre-fixme[3]: Return type must be annotated.
120+
def inner(*args: Any, **kwargs: Any):
121+
with fast_optimize_context_manager():
122+
return f(*args, **kwargs)
123+
124+
return inner

sphinx/source/index.rst

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ BoTorch API Reference
2222
settings
2323
logging
2424
test_functions
25+
test_utils
2526
exceptions
2627
utils
2728

sphinx/source/test_utils.rst

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
.. role:: hidden
2+
:class: hidden-section
3+
4+
5+
botorch.test_utils
6+
========================================================
7+
.. automodule:: botorch.test_utils
8+
9+
Mock
10+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
11+
.. automodule:: botorch.test_utils.mock
12+
:members:

test/test_utils/test_mock.py

+111
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
8+
import warnings
9+
from unittest.mock import patch
10+
11+
import torch
12+
from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
13+
from botorch.exceptions.warnings import BadInitialCandidatesWarning
14+
from botorch.generation.gen import gen_candidates_scipy
15+
from botorch.models.gp_regression import SingleTaskGP
16+
from botorch.optim.core import scipy_minimize
17+
from botorch.optim.initializers import gen_batch_initial_conditions, initialize_q_batch
18+
from botorch.optim.optimize import optimize_acqf
19+
20+
from botorch.test_utils.mock import fast_optimize, fast_optimize_context_manager
21+
from botorch.utils.testing import BotorchTestCase, MockAcquisitionFunction
22+
23+
24+
class SinAcqusitionFunction(MockAcquisitionFunction):
25+
"""Simple acquisition function with known numerical properties."""
26+
27+
def __init__(self, *args, **kwargs): # noqa: D107
28+
return
29+
30+
def __call__(self, X):
31+
return torch.sin(X[..., 0].max(dim=-1).values)
32+
33+
34+
class TestMock(BotorchTestCase):
35+
def test_fast_optimize_context_manager(self):
36+
with self.subTest("gen_candidates_scipy"):
37+
with fast_optimize_context_manager():
38+
cand, value = gen_candidates_scipy(
39+
initial_conditions=torch.tensor([[0.0]]),
40+
acquisition_function=SinAcqusitionFunction(),
41+
)
42+
# When not using `fast_optimize`, the value is 1.0. With it, the value is
43+
# around 0.84
44+
self.assertLess(value.item(), 0.99)
45+
46+
with self.subTest("scipy_minimize"):
47+
x = torch.tensor([0.0])
48+
49+
def closure():
50+
return torch.sin(x), [torch.cos(x)]
51+
52+
with fast_optimize_context_manager():
53+
result = scipy_minimize(closure=closure, parameters={"x": x})
54+
self.assertEqual(
55+
result.message, "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT"
56+
)
57+
58+
with self.subTest("optimize_acqf"):
59+
with fast_optimize_context_manager():
60+
cand, value = optimize_acqf(
61+
acq_function=SinAcqusitionFunction(),
62+
bounds=torch.tensor([[-2.0], [2.0]]),
63+
q=1,
64+
num_restarts=32,
65+
batch_initial_conditions=torch.tensor([[0.0]]),
66+
)
67+
self.assertLess(value.item(), 0.99)
68+
69+
with self.subTest("gen_batch_initial_conditions"):
70+
with fast_optimize_context_manager(), patch(
71+
"botorch.optim.initializers.initialize_q_batch",
72+
wraps=initialize_q_batch,
73+
) as mock_init_q_batch:
74+
cand, value = optimize_acqf(
75+
acq_function=SinAcqusitionFunction(),
76+
bounds=torch.tensor([[-2.0], [2.0]]),
77+
q=1,
78+
num_restarts=32,
79+
raw_samples=16,
80+
)
81+
self.assertEqual(mock_init_q_batch.call_args[1]["n"], 2)
82+
83+
@fast_optimize
84+
def test_decorator(self) -> None:
85+
model = SingleTaskGP(
86+
train_X=torch.tensor([[0.0]], dtype=torch.double),
87+
train_Y=torch.tensor([[0.0]], dtype=torch.double),
88+
)
89+
acqf = qKnowledgeGradient(model=model, num_fantasies=64)
90+
# this is called within gen_one_shot_kg_initial_conditions
91+
with patch(
92+
"botorch.optim.initializers.gen_batch_initial_conditions",
93+
wraps=gen_batch_initial_conditions,
94+
) as mock_gen_batch_ics, warnings.catch_warnings():
95+
warnings.simplefilter("ignore", category=BadInitialCandidatesWarning)
96+
cand, value = optimize_acqf(
97+
acq_function=acqf,
98+
bounds=torch.tensor([[-2.0], [2.0]]),
99+
q=1,
100+
num_restarts=32,
101+
raw_samples=16,
102+
)
103+
104+
called_with = mock_gen_batch_ics.call_args[1]
105+
self.assertEqual(called_with["num_restarts"], 2)
106+
self.assertEqual(called_with["raw_samples"], 4)
107+
108+
def test_raises_when_unused(self) -> None:
109+
with self.assertRaisesRegex(AssertionError, "No mocks were called"):
110+
with fast_optimize_context_manager():
111+
pass

0 commit comments

Comments
 (0)