
Commit 3d18ae3

SamuelGabriel authored and facebook-github-bot committed
Allow batched fixed features in gen_candidates_scipy and gen_candidates_torch (#2893)
Summary: This PR lays the groundwork for batched mixed optimization. To enable it, we need to allow setting different fixed features for different initial conditions during optimization. We do this by accepting tensors of shape [b] or [b, q] as fixed-feature values in `gen_candidates_scipy`, and in `gen_candidates_torch` for compatibility. Reviewed By: saitcakmak. Differential Revision: D77043260
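For illustration (an editorial sketch, not code from this commit), the new call pattern looks roughly like the following; the toy model, data, and the choice of feature index 2 are made up:

import torch
from botorch.acquisition import ExpectedImprovement
from botorch.fit import fit_gpytorch_mll
from botorch.generation.gen import gen_candidates_scipy
from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood

# Toy model and acquisition function (placeholders for a real setup).
train_X = torch.rand(10, 3, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))
acqf = ExpectedImprovement(model, best_f=train_Y.max())

b, q, d = 4, 1, 3
initial_conditions = torch.rand(b, q, d, dtype=torch.double)
# New here: feature 2 is fixed to a different value for each of the b initial
# conditions (shape [b]; a [b, q] tensor or a plain float also work).
fixed_features = {2: torch.linspace(0.1, 0.9, b, dtype=torch.double)}

candidates, acq_values = gen_candidates_scipy(
    initial_conditions=initial_conditions,
    acquisition_function=acqf,
    lower_bounds=torch.zeros(d, dtype=torch.double),
    upper_bounds=torch.ones(d, dtype=torch.double),
    fixed_features=fixed_features,
)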
1 parent: a80b98f · commit: 3d18ae3

File tree

5 files changed, +206 -76 lines changed


botorch/generation/gen.py

Lines changed: 78 additions & 70 deletions
@@ -14,7 +14,7 @@
 import warnings
 from collections.abc import Callable
 from functools import partial
-from typing import Any, NoReturn
+from typing import Any, Mapping, NoReturn

 import numpy as np
 import numpy.typing as npt
@@ -64,7 +64,7 @@ def gen_candidates_scipy(
     equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
     nonlinear_inequality_constraints: list[tuple[Callable, bool]] | None = None,
     options: dict[str, Any] | None = None,
-    fixed_features: dict[int, float | None] | None = None,
+    fixed_features: Mapping[int, float | Tensor] | None = None,
     timeout_sec: float | None = None,
     use_parallel_mode: bool | None = None,
 ) -> tuple[Tensor, Tensor]:
@@ -107,11 +107,11 @@ def gen_candidates_scipy(
             and SLSQP if inequality or equality constraints are present. If
             `with_grad=False`, then we use a two-point finite difference estimate
             of the gradient.
-        fixed_features: This is a dictionary of feature indices to values, where
+        fixed_features: Mapping[int, float | Tensor] | None,
             all generated candidates will have features fixed to these values.
-            If the dictionary value is None, then that feature will just be
-            fixed to the clamped value and not optimized. Assumes values to be
-            compatible with lower_bounds and upper_bounds!
+            If passing tensors as values, they should have either shape `b` or
+            `b x q` to fix the same feature to different values in the batch.
+            Assumes values to be compatible with lower_bounds and upper_bounds!
         timeout_sec: Timeout (in seconds) for `scipy.optimize.minimize` routine -
             if provided, optimization will stop after this many seconds and return
             the best solution found so far.
@@ -211,18 +211,17 @@ def f(x):
        timeout_sec=timeout_sec,
    )

+    f_np_wrapper = _get_f_np_wrapper(
+        clamped_candidates.shape,
+        initial_conditions.device,
+        initial_conditions.dtype,
+        with_grad,
+    )
+
    if not why_not_fast_path and use_parallel_mode is not False:
        if is_constrained:
            raise RuntimeWarning("Method L-BFGS-B cannot handle constraints.")

-        f_np_wrapper = _get_f_np_wrapper(
-            clamped_candidates.shape,
-            initial_conditions.device,
-            initial_conditions.dtype,
-            with_grad,
-            batched=True,
-        )
-
        batched_x0 = _arrayify(clamped_candidates).reshape(len(clamped_candidates), -1)

        l_bfgs_b_bounds = translate_bounds_for_lbfgsb(
@@ -242,6 +241,7 @@ def f(x):
            bounds=l_bfgs_b_bounds,
            # constraints=constraints,
            callback=options.get("callback", None),
+            pass_batch_indices=True,
            **minimize_options,
        )
        for res in results:
@@ -264,21 +264,38 @@ def f(x):
        else:
            logger.debug(msg)

-        f_np_wrapper = _get_f_np_wrapper(
-            clamped_candidates.shape,
-            initial_conditions.device,
-            initial_conditions.dtype,
-            with_grad,
-        )
+        if (
+            fixed_features
+            and any(
+                torch.is_tensor(ff) and ff.ndim > 0 for ff in fixed_features.values()
+            )
+            and max_optimization_problem_aggregation_size != 1
+        ):
+            raise UnsupportedError(
+                "Batch shaped fixed features are not "
+                "supported, when optimizing more than one optimization "
+                "problem at a time."
+            )

        all_xs = []
        split_candidates = clamped_candidates.split(
            max_optimization_problem_aggregation_size
        )
-        for candidates_ in split_candidates:
-            # We optimize the candidates at hand as a single problem
+        for i, candidates_ in enumerate(split_candidates):
+            if fixed_features:
+                fixed_features_ = {
+                    k: ff[i : i + 1].item()
+                    # from the test above, we know that we only treat one candidate
+                    # at a time thus we can use index i
+                    if torch.is_tensor(ff) and ff.ndim > 0
+                    else ff
+                    for k, ff in fixed_features.items()
+                }
+            else:
+                fixed_features_ = None

            _no_fixed_features = _remove_fixed_features_from_optimization(
-                fixed_features=fixed_features,
+                fixed_features=fixed_features_,
                acquisition_function=acquisition_function,
                initial_conditions=None,
                d=initial_conditions_all_features.shape[-1],
@@ -296,7 +313,7 @@ def f(x):

            f_np_wrapper_ = partial(
                f_np_wrapper,
-                fixed_features=fixed_features,
+                fixed_features=fixed_features_,
            )

            x0 = candidates_.flatten()
@@ -363,13 +380,14 @@ def f(x):
    return clamped_candidates, batch_acquisition


-def _get_f_np_wrapper(shapeX, device, dtype, with_grad, batched=False):
+def _get_f_np_wrapper(shapeX, device, dtype, with_grad):
    if with_grad:

        def f_np_wrapper(
            x: npt.NDArray,
            f: Callable,
-            fixed_features: dict[int, float] | None,
+            fixed_features: Mapping[int, float | Tensor] | None,
+            batch_indices: list[int] | None = None,
        ) -> tuple[float | np.NDArray, np.NDArray]:
            """Given a torch callable, compute value + grad given a numpy array."""
            if np.isnan(x).any():
@@ -387,8 +405,21 @@ def f_np_wrapper(
                .contiguous()
                .requires_grad_(True)
            )
+            if fixed_features is not None:
+                if batch_indices is not None:
+                    this_fixed_features = {
+                        k: ff[batch_indices]
+                        if torch.is_tensor(ff) and ff.ndim > 0
+                        else ff
+                        for k, ff in fixed_features.items()
+                    }
+                else:
+                    this_fixed_features = fixed_features
+            else:
+                this_fixed_features = None
+
            X_fix = fix_features(
-                X, fixed_features=fixed_features, replace_current_value=False
+                X, fixed_features=this_fixed_features, replace_current_value=False
            )
            # we compute the loss on the whole batch, under the assumption that f
            # treats multiple inputs in the 0th dimension as independent
@@ -409,7 +440,7 @@ def f_np_wrapper(
                raise OptimizationGradientError(msg, current_x=x)
            fval = (
                losses.detach().view(-1).cpu().numpy()
-                if batched
+                if batch_indices is not None
                else loss.detach().item()
            )  # the view(-1) seems necessary as f might return a single scalar
            return fval, gradf
@@ -485,7 +516,7 @@ def gen_candidates_torch(
    optimizer: type[Optimizer] = torch.optim.Adam,
    options: dict[str, float | str] | None = None,
    callback: Callable[[int, Tensor, Tensor], NoReturn] | None = None,
-    fixed_features: dict[int, float | None] | None = None,
+    fixed_features: Mapping[int, float | Tensor] | None = None,
    timeout_sec: float | None = None,
 ) -> tuple[Tensor, Tensor]:
    r"""Generate a set of candidates using a `torch.optim` optimizer.
@@ -507,9 +538,10 @@ def gen_candidates_torch(
            the loss and gradients, but before calling the optimizer.
        fixed_features: This is a dictionary of feature indices to values, where
            all generated candidates will have features fixed to these values.
-            If the dictionary value is None, then that feature will just be
-            fixed to the clamped value and not optimized. Assumes values to be
-            compatible with lower_bounds and upper_bounds!
+            If a float is passed it is fixed across [b,q], if a tensor is passed:
+            it might either be of shape [b,q] or [b], in which case the same value
+            is used across the q dimension.
+            Assumes values to be compatible with lower_bounds and upper_bounds!
        timeout_sec: Timeout (in seconds) for optimization. If provided,
            `gen_candidates_torch` will stop after this many seconds and return
            the best solution found so far.
@@ -533,46 +565,21 @@ def gen_candidates_torch(
            upper_bounds=bounds[1],
        )
    """
-    assert not fixed_features or not any(
-        torch.is_tensor(v) for v in fixed_features.values()
-    ), "`gen_candidates_torch` does not support tensor-valued fixed features."
-
    start_time = time.monotonic()
    options = options or {}
-
-    # if there are fixed features we may optimize over a domain of lower dimension
-    if fixed_features:
-        subproblem = _remove_fixed_features_from_optimization(
-            fixed_features=fixed_features,
-            acquisition_function=acquisition_function,
-            initial_conditions=initial_conditions,
-            d=initial_conditions.shape[-1],
-            lower_bounds=lower_bounds,
-            upper_bounds=upper_bounds,
-            inequality_constraints=None,
-            equality_constraints=None,
-            nonlinear_inequality_constraints=None,
-        )
-
-        # call the routine with no fixed_features
-        elapsed = time.monotonic() - start_time
-        clamped_candidates, batch_acquisition = gen_candidates_torch(
-            initial_conditions=subproblem.initial_conditions,
-            acquisition_function=subproblem.acquisition_function,
-            lower_bounds=subproblem.lower_bounds,
-            upper_bounds=subproblem.upper_bounds,
-            optimizer=optimizer,
-            options=options,
-            callback=callback,
-            fixed_features=None,
-            timeout_sec=timeout_sec - elapsed if timeout_sec else None,
-        )
-        clamped_candidates = subproblem.acquisition_function._construct_X_full(
-            clamped_candidates
-        )
-        return clamped_candidates, batch_acquisition
+    # We remove max_optimization_problem_aggregation_size as it does not affect
+    # the 1st order optimizers implemented in this method.
+    # Here, it does not matter whether one combines multiple optimizations into
+    # one or not.
+    options.pop("max_optimization_problem_aggregation_size", None)
    _clamp = partial(columnwise_clamp, lower=lower_bounds, upper=upper_bounds)
-    clamped_candidates = _clamp(initial_conditions).requires_grad_(True)
+    clamped_candidates = _clamp(initial_conditions)
+    if fixed_features:
+        clamped_candidates = clamped_candidates[
+            ...,
+            [i for i in range(clamped_candidates.shape[-1]) if i not in fixed_features],
+        ]
+    clamped_candidates = clamped_candidates.requires_grad_(True)
    _optimizer = optimizer(params=[clamped_candidates], lr=options.get("lr", 0.025))

    i = 0
@@ -583,7 +590,7 @@ def gen_candidates_torch(
        with torch.no_grad():
            X = _clamp(clamped_candidates).requires_grad_(True)

-        loss = -acquisition_function(X).sum()
+        loss = -acquisition_function(fix_features(X, fixed_features)).sum()
        grad = torch.autograd.grad(loss, X)[0]
        if callback:
            callback(i, loss, grad)
@@ -602,6 +609,7 @@ def assign_grad():
        logger.info(f"Optimization timed out after {runtime} seconds.")

    clamped_candidates = _clamp(clamped_candidates)
+    clamped_candidates = fix_features(clamped_candidates, fixed_features)
    with torch.no_grad():
        batch_acquisition = acquisition_function(clamped_candidates)

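Editorial note (not part of the diff): with the recursion into `_remove_fixed_features_from_optimization` gone, `gen_candidates_torch` now drops the fixed columns before optimizing and re-applies the fixed values via `fix_features` when evaluating the acquisition function. A toy sketch of that column-selection idiom, with made-up shapes:

import torch

# Made-up shapes: 2 restarts, q = 3 points per restart, d = 5 features.
b, q, d = 2, 3, 5
candidates = torch.rand(b, q, d)
fixed_features = {1: 0.5, 4: torch.tensor([0.2, 0.8])}  # a float and a [b]-shaped tensor

# Optimize only the free columns; the fixed ones are re-inserted later.
free_columns = [i for i in range(d) if i not in fixed_features]
free_candidates = candidates[..., free_columns]
print(free_candidates.shape)  # torch.Size([2, 3, 3])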
botorch/optim/optimize.py

Lines changed: 12 additions & 1 deletion
@@ -367,7 +367,18 @@ def _optimize_acqf_batch(opt_inputs: OptimizeAcqfInputs) -> tuple[Tensor, Tensor
    def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
        batch_candidates_list: list[Tensor] = []
        batch_acq_values_list: list[Tensor] = []
+
        batched_ics = batch_initial_conditions.split(batch_limit)
+        if opt_inputs.fixed_features is None:
+            batched_fixed_features = {}
+        else:
+            batched_fixed_features = {
+                k: ff.split(batch_limit)
+                if torch.is_tensor(ff) and ff.numel() > 1
+                else [ff] * len(batched_ics)
+                for k, ff in opt_inputs.fixed_features.items()
+            }
+
        opt_warnings = []
        timeout_sec = (
            opt_inputs.timeout_sec / len(batched_ics)
@@ -393,7 +404,7 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
                lower_bounds=lower_bounds,
                upper_bounds=upper_bounds,
                options=gen_options,
-                fixed_features=opt_inputs.fixed_features,
+                fixed_features={k: v[i] for k, v in batched_fixed_features.items()},
                timeout_sec=timeout_sec,
                **gen_kwargs,
            )
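Editorial sketch (not from the diff) of the splitting logic above, with made-up sizes: tensor-valued fixed features are split with the same `batch_limit` as the initial conditions, while scalar values are simply repeated once per chunk:

import torch

batch_limit = 2
batch_initial_conditions = torch.rand(5, 1, 3)  # 5 restarts, q = 1, d = 3 (made up)
fixed_features = {0: 0.25, 2: torch.linspace(0.0, 1.0, 5)}

batched_ics = batch_initial_conditions.split(batch_limit)
batched_fixed_features = {
    k: ff.split(batch_limit)
    if torch.is_tensor(ff) and ff.numel() > 1
    else [ff] * len(batched_ics)
    for k, ff in fixed_features.items()
}

# Each chunk of initial conditions gets the matching chunk of fixed features.
for i, ics in enumerate(batched_ics):
    chunk_ff = {k: v[i] for k, v in batched_fixed_features.items()}
    print(ics.shape, chunk_ff)  # chunks of 2, 2, and 1 restarts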

botorch/optim/utils/acquisition_utils.py

Lines changed: 19 additions & 4 deletions
@@ -13,7 +13,7 @@

 import torch
 from botorch.acquisition.acquisition import AcquisitionFunction
-from botorch.exceptions.errors import BotorchError
+from botorch.exceptions.errors import BotorchError, BotorchTensorDimensionError
 from botorch.exceptions.warnings import BotorchWarning
 from botorch.models.gpytorch import ModelListGPyTorchModel
 from torch import Tensor
@@ -58,14 +58,16 @@ def columnwise_clamp(

    out = X.clamp(lower, upper)
    if raise_on_violation and not X.allclose(out):
-        raise BotorchError("Original value(s) are out of bounds.")
+        raise BotorchError(
+            "Original value(s) are out of bounds: " f"{out=}, {X=}, {lower=}, {upper=}."
+        )

    return out


 def fix_features(
    X: Tensor,
-    fixed_features: Mapping[int, float] | None = None,
+    fixed_features: Mapping[int, float | Tensor] | None = None,
    replace_current_value: bool = True,
 ) -> Tensor:
    r"""Fix feature values in a Tensor.
@@ -79,6 +81,8 @@ def fix_features(
        fixed_features: A mapping with keys as column indices and values
            equal to what the feature should be set to in `X`. Keys should be in the
            range `[0, p - 1]`.
+            If a tensor is passed as value, it has to either have shape `b x q` or
+            `b`, in which case the same value is used across the q dimension.
        replace_current_value: If True, replace the specified indexes, otherwise
            the indices are inserted.

@@ -102,7 +106,18 @@ def fix_features(
    for index in range(new_X.shape[-1]):
        if index in fixed_features:
            value = fixed_features[index]
-            value = torch.full_like(new_X[..., index], value)
+            if torch.is_tensor(value) and value.ndim > 0:
+                if X.ndim != 3:
+                    raise BotorchTensorDimensionError(
+                        "X must be a 3-dimensional tensor, as value is a tensor."
+                        f"X.shape = {X.shape}, value.shape = {value.shape}."
+                    )
+                _b, q, _reduced_p = X.shape
+                if value.ndim == 1:
+                    # Repeat values across the q dimension.
+                    value = value.unsqueeze(-1).repeat(1, q)
+            else:
+                value = torch.full_like(new_X[..., index], value)
            new_X[..., index] = value
        else:
            new_X[..., index] = X[..., filtered_index]
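Editorial sketch (not from the diff) of the resulting `fix_features` behavior: a `[b]`-shaped value is repeated across the q dimension, a `[b x q]` value is used as-is, and plain floats broadcast as before:

import torch
from botorch.optim.utils.acquisition_utils import fix_features

b, q, d = 2, 3, 4
X = torch.zeros(b, q, d)

# [b]-shaped value: one constant per batch, repeated across the q dimension.
out = fix_features(X, fixed_features={1: torch.tensor([0.3, 0.7])})
print(out[..., 1])
# tensor([[0.3000, 0.3000, 0.3000],
#         [0.7000, 0.7000, 0.7000]])

# Plain floats still broadcast over all of [b, q], as before.
out = fix_features(X, fixed_features={0: 0.5})
print(out[..., 0].unique())  # tensor([0.5000])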
