Skip to content

Commit 73afed1

Browse files
SamuelGabriel authored and facebook-github-bot committed
Batched Optimization For Mixed Optimization (#2895)
Summary: So far, our optimization in mixed search spaces works on each restart separately and sequentially instead of batching them. Here, we change this to batch the restarts, based on the new L-BFGS-B implementation that supports this. This speeds up mixed search spaces a lot (depending on the problem, around 3-4x speedups). Differential Revision: D76517454
1 parent 3d18ae3 commit 73afed1

File tree

2 files changed

+254
-73
lines changed

2 files changed

+254
-73
lines changed

botorch/optim/optimize_mixed.py

Lines changed: 134 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,11 @@
5050
MAX_ITER_INIT = 100
5151
CONVERGENCE_TOL = 1e-8 # Optimizer convergence tolerance.
5252
DUPLICATE_TOL = 1e-6 # Tolerance for deduplicating initial candidates.
53+
STOP_AFTER_SHARE_CONVERGED = 1.0 # We optimize multiple configurations at once
54+
# in `optimize_acqf_mixed_alternating`. This option controls, whether to stop
55+
# optimizing after the given share has converged.
56+
# Convergence is defined as the improvements of one discrete, followed by a scalar
57+
# optimization yield less than `CONVERGENCE_TOL` improvements.
5358

5459
SUPPORTED_OPTIONS = {
5560
"initialization_strategy",
@@ -564,63 +569,99 @@ def discrete_step(
564569
discrete_dims: A tensor of indices corresponding to binary and
565570
integer parameters.
566571
cat_dims: A tensor of indices corresponding to categorical parameters.
567-
current_x: Starting point. A tensor of shape `d`.
572+
current_x: Starting point. A tensor of shape `d` or `b x d`.
568573
569574
Returns:
570575
A tuple of two tensors: a (d)-dim tensor of optimized point
571576
and a scalar tensor of correspondins acquisition value.
572577
"""
578+
batched_current_x = current_x.view(-1, current_x.shape[-1]).clone()
573579
with torch.no_grad():
574-
current_acqval = opt_inputs.acq_function(current_x.unsqueeze(0))
580+
current_acqvals = opt_inputs.acq_function(batched_current_x.unsqueeze(1))
575581
options = opt_inputs.options or {}
576-
for _ in range(
577-
assert_is_instance(options.get("maxiter_discrete", MAX_ITER_DISCRETE), int)
578-
):
579-
neighbors = []
580-
if discrete_dims.numel():
581-
x_neighbors_discrete = get_nearest_neighbors(
582-
current_x=current_x.detach(),
583-
bounds=opt_inputs.bounds,
584-
discrete_dims=discrete_dims,
585-
)
586-
x_neighbors_discrete = _filter_infeasible(
587-
X=x_neighbors_discrete,
588-
inequality_constraints=opt_inputs.inequality_constraints,
589-
)
590-
neighbors.append(x_neighbors_discrete)
582+
maxiter_discrete = options.get("maxiter_discrete", MAX_ITER_DISCRETE)
583+
done = torch.zeros(len(batched_current_x), dtype=torch.bool)
584+
for _ in range(assert_is_instance(maxiter_discrete, int)):
585+
# we don't batch this, as the number of x_neighbors can be different
586+
# for each entry (as duplicates are removed), and the most expensive
587+
# op is the acq_function, which is batched
588+
# TODO one could try removing duplicate removal or use nested tensors
589+
# to make this parallel, if this loop is parallelized the second loop
590+
# in a few lines is also easy to parallelize
591+
x_neighbors_list = []
592+
for i in range(len(done)):
593+
x_neighbors = None
594+
if ~done[i]:
595+
neighbors = []
596+
if discrete_dims.numel():
597+
x_neighbors_discrete = get_nearest_neighbors(
598+
current_x=batched_current_x[i].detach(),
599+
bounds=opt_inputs.bounds,
600+
discrete_dims=discrete_dims,
601+
)
602+
x_neighbors_discrete = _filter_infeasible(
603+
X=x_neighbors_discrete,
604+
inequality_constraints=opt_inputs.inequality_constraints,
605+
)
606+
neighbors.append(x_neighbors_discrete)
591607

592-
if cat_dims.numel():
593-
x_neighbors_cat = get_categorical_neighbors(
594-
current_x=current_x.detach(),
595-
bounds=opt_inputs.bounds,
596-
cat_dims=cat_dims,
597-
)
598-
x_neighbors_cat = _filter_infeasible(
599-
X=x_neighbors_cat,
600-
inequality_constraints=opt_inputs.inequality_constraints,
601-
)
602-
neighbors.append(x_neighbors_cat)
608+
if cat_dims.numel():
609+
x_neighbors_cat = get_categorical_neighbors(
610+
current_x=batched_current_x[i].detach(),
611+
bounds=opt_inputs.bounds,
612+
cat_dims=cat_dims,
613+
)
614+
x_neighbors_cat = _filter_infeasible(
615+
X=x_neighbors_cat,
616+
inequality_constraints=opt_inputs.inequality_constraints,
617+
)
618+
neighbors.append(x_neighbors_cat)
619+
620+
x_neighbors = torch.cat(neighbors, dim=0)
621+
if x_neighbors.numel() == 0:
622+
# Exit gracefully with last point if no feasible neighbors left.
623+
done[i] = True
624+
x_neighbors_list.append(x_neighbors)
603625

604-
x_neighbors = torch.cat(neighbors, dim=0)
605-
if x_neighbors.numel() == 0:
606-
# Exit gracefully with last point if there are no feasible neighbors.
626+
if done.all():
607627
break
628+
629+
all_x_neighbors = torch.cat(
630+
[
631+
x_neighbors
632+
for x_neighbors in x_neighbors_list
633+
if x_neighbors is not None
634+
],
635+
dim=0,
636+
)
608637
with torch.no_grad():
609638
acq_vals = torch.cat(
610639
[
611640
opt_inputs.acq_function(X_.unsqueeze(-2))
612-
for X_ in x_neighbors.split(
641+
for X_ in all_x_neighbors.split(
613642
options.get("init_batch_limit", MAX_BATCH_SIZE)
614643
)
615644
]
616645
)
617-
argmax = acq_vals.argmax()
618-
improvement = acq_vals[argmax] - current_acqval
619-
if improvement > 0:
620-
current_acqval, current_x = acq_vals[argmax], x_neighbors[argmax]
621-
if improvement <= options.get("tol", CONVERGENCE_TOL):
622-
break
623-
return current_x, current_acqval
646+
offset = 0
647+
for i in range(len(done)):
648+
# assuming the offset incurred due to done samples is 0
649+
if ~done[i]:
650+
width = len(x_neighbors_list[i])
651+
x_neighbors = all_x_neighbors[offset : offset + width]
652+
max_acq, argmax = acq_vals[offset : offset + width].max(dim=0)
653+
improvement = acq_vals[offset + argmax] - current_acqvals[i]
654+
if improvement > 0:
655+
current_acqvals[i], batched_current_x[i] = (
656+
max_acq,
657+
x_neighbors[argmax],
658+
)
659+
if improvement <= options.get("tol", CONVERGENCE_TOL):
660+
done[i] = True
661+
662+
offset += width
663+
664+
return batched_current_x.view_as(current_x), current_acqvals
624665

625666

626667
def continuous_step(
@@ -638,36 +679,41 @@ def continuous_step(
638679
discrete_dims: A tensor of indices corresponding to binary and
639680
integer parameters.
640681
cat_dims: A tensor of indices corresponding to categorical parameters.
641-
current_x: Starting point. A tensor of shape `d`.
682+
current_x: Starting point. A tensor of shape `(b) x d`.
642683
643684
Returns:
644685
A tuple of two tensors: a (1 x d)-dim tensor of optimized points
645686
and a (1)-dim tensor of acquisition values.
646687
"""
688+
d = current_x.shape[-1]
689+
batched_current_x = current_x.view(-1, d)
690+
647691
options = opt_inputs.options or {}
648692
non_cont_dims = torch.cat((discrete_dims, cat_dims), dim=0)
649693

650-
if len(non_cont_dims) == len(current_x): # nothing continuous to optimize
694+
if len(non_cont_dims) == d: # nothing continuous to optimize
651695
with torch.no_grad():
652-
return current_x, opt_inputs.acq_function(current_x.unsqueeze(0))
696+
return current_x, opt_inputs.acq_function(batched_current_x)
653697

654698
updated_opt_inputs = dataclasses.replace(
655699
opt_inputs,
656700
q=1,
657701
num_restarts=1,
658702
raw_samples=None,
659-
batch_initial_conditions=current_x.unsqueeze(0),
703+
batch_initial_conditions=batched_current_x[:, None],
660704
fixed_features={
661-
**dict(zip(non_cont_dims.tolist(), current_x[non_cont_dims])),
705+
**{d: batched_current_x[:, d] for d in non_cont_dims.tolist()},
662706
**(opt_inputs.fixed_features or {}),
663707
},
664708
options={
665709
"maxiter": options.get("maxiter_continuous", MAX_ITER_CONT),
666710
"tol": options.get("tol", CONVERGENCE_TOL),
667711
"batch_limit": options.get("batch_limit", MAX_BATCH_SIZE),
712+
"max_optimization_problem_aggregation_size": 1,
668713
},
669714
)
670-
return _optimize_acqf(opt_inputs=updated_opt_inputs)
715+
best_X, best_acq_values = _optimize_acqf(opt_inputs=updated_opt_inputs)
716+
return best_X.view_as(current_x), best_acq_values
671717

672718

673719
def optimize_acqf_mixed_alternating(
@@ -761,6 +807,12 @@ def optimize_acqf_mixed_alternating(
761807

762808
fixed_features = fixed_features or {}
763809
options = options or {}
810+
if options.get("max_optimization_problem_aggregation_size", 1) != 1:
811+
raise UnsupportedError(
812+
"optimize_acqf_mixed_alternating does not support "
813+
"max_optimization_problem_aggregation_size != 1. "
814+
"You might leave this option empty, though."
815+
)
764816
options.setdefault("batch_limit", MAX_BATCH_SIZE)
765817
options.setdefault("init_batch_limit", options["batch_limit"])
766818
if not (keys := set(options.keys())).issubset(SUPPORTED_OPTIONS):
@@ -793,11 +845,18 @@ def optimize_acqf_mixed_alternating(
793845
fixed_features=fixed_features,
794846
post_processing_func=post_processing_func,
795847
batch_initial_conditions=None,
796-
return_best_only=True,
848+
return_best_only=False, # We don't want to perform the cont. optimization
849+
# step and only return best, but this function itself only returns best
797850
gen_candidates=gen_candidates_scipy,
798-
sequential=sequential,
851+
sequential=sequential, # only relevant if all dims are cont.
799852
)
800-
_validate_sequential_inputs(opt_inputs=opt_inputs)
853+
if sequential:
854+
# Sequential optimization requires return_best_only to be True
855+
# But we turn it off here, as we "manually" perform the seq.
856+
# conditioning in the loop below
857+
_validate_sequential_inputs(
858+
opt_inputs=dataclasses.replace(opt_inputs, return_best_only=True)
859+
)
801860

802861
base_X_pending = acq_function.X_pending if q > 1 else None
803862
dim = bounds.shape[-1]
@@ -808,7 +867,12 @@ def optimize_acqf_mixed_alternating(
808867
non_cont_dims = [*discrete_dims, *cat_dims]
809868
if len(non_cont_dims) == 0:
810869
# If the problem is fully continuous, fall back to standard optimization.
811-
return _optimize_acqf(opt_inputs=opt_inputs)
870+
return _optimize_acqf(
871+
opt_inputs=dataclasses.replace(
872+
opt_inputs,
873+
return_best_only=True,
874+
)
875+
)
812876
if not (
813877
isinstance(non_cont_dims, list)
814878
and len(set(non_cont_dims)) == len(non_cont_dims)
@@ -842,26 +906,28 @@ def optimize_acqf_mixed_alternating(
842906
cont_dims=cont_dims,
843907
)
844908

845-
# TODO: Eliminate this for loop. Tensors being unequal sizes could potentially
846-
# be handled by concatenating them rather than stacking, and keeping a list
847-
# of indices.
848-
for i in range(num_restarts):
849-
alternate_steps = 0
850-
while alternate_steps < options.get("maxiter_alternating", MAX_ITER_ALTER):
851-
starting_acq_val = best_acq_val[i].clone()
852-
alternate_steps += 1
853-
for step in (discrete_step, continuous_step):
854-
best_X[i], best_acq_val[i] = step(
855-
opt_inputs=opt_inputs,
856-
discrete_dims=discrete_dims_t,
857-
cat_dims=cat_dims_t,
858-
current_x=best_X[i],
859-
)
909+
done = torch.zeros(len(best_X), dtype=torch.bool, device=tkwargs["device"])
910+
for _step in range(options.get("maxiter_alternating", MAX_ITER_ALTER)):
911+
starting_acq_val = best_acq_val.clone()
912+
best_X, best_acq_val = discrete_step(
913+
opt_inputs=opt_inputs,
914+
discrete_dims=discrete_dims_t,
915+
cat_dims=cat_dims_t,
916+
current_x=best_X,
917+
)
918+
919+
best_X[~done], best_acq_val[~done] = continuous_step(
920+
opt_inputs=opt_inputs,
921+
discrete_dims=discrete_dims_t,
922+
cat_dims=cat_dims_t,
923+
current_x=best_X[~done],
924+
)
860925

861-
improvement = best_acq_val[i] - starting_acq_val
862-
if improvement < options.get("tol", CONVERGENCE_TOL):
863-
# Check for convergence
864-
break
926+
improvement = best_acq_val - starting_acq_val
927+
done_now = improvement < options.get("tol", CONVERGENCE_TOL)
928+
done = done | done_now
929+
if done.float().mean() >= STOP_AFTER_SHARE_CONVERGED:
930+
break
865931

866932
new_candidate = best_X[torch.argmax(best_acq_val)].unsqueeze(0)
867933
candidates = torch.cat([candidates, new_candidate], dim=-2)

0 commit comments

Comments
 (0)