Commit 930d435

SamuelGabriel authored and facebook-github-bot committed
Batched Optimization For Mixed Optimization (#2895)
Summary: Pull Request resolved: #2895

So far, our optimization in mixed search spaces worked on each restart separately and sequentially instead of batching them. Here, we change this to batch the restarts, based on the new L-BFGS-B implementation that supports batched optimization. This speeds up optimization over mixed search spaces considerably (around 3-4x, depending on the problem).

Differential Revision: D76517454
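To make the change concrete, here is a minimal, self-contained usage sketch of the entry point this commit touches. The model setup is illustrative, and the keyword names (`discrete_dims`, `cat_dims`, `stop_after_share_converged`) are taken from the diff below; treat this as a sketch to check against the installed BoTorch version, not as the canonical API.

```python
# Illustrative sketch only; argument and option names are taken from the diff
# below and may differ across BoTorch versions.
import torch
from botorch.acquisition import qLogExpectedImprovement
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.optim.optimize_mixed import optimize_acqf_mixed_alternating
from gpytorch.mlls import ExactMarginalLogLikelihood

# Toy data: dim 0 is continuous, dim 1 integer-valued, dim 2 categorical.
train_X = torch.rand(10, 3, dtype=torch.double)
train_X[:, 1] = train_X[:, 1].mul(5).round()
train_X[:, 2] = train_X[:, 2].mul(2).round()
train_Y = train_X.sum(dim=-1, keepdim=True)

model = SingleTaskGP(train_X=train_X, train_Y=train_Y)
fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))
acqf = qLogExpectedImprovement(model=model, best_f=train_Y.max())

bounds = torch.tensor([[0.0, 0.0, 0.0], [1.0, 5.0, 2.0]], dtype=torch.double)
# With this commit, all restarts are optimized as one batch instead of one
# restart at a time.
candidates, acq_values = optimize_acqf_mixed_alternating(
    acq_function=acqf,
    bounds=bounds,
    discrete_dims=[1],
    cat_dims=[2],
    q=2,
    num_restarts=4,
    raw_samples=64,
    options={"stop_after_share_converged": 1.0},
)
```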
1 parent ba2d731 commit 930d435

File tree

2 files changed (+262, -73 lines)

botorch/optim/optimize_mixed.py

Lines changed: 142 additions & 68 deletions
@@ -50,6 +50,11 @@
 MAX_ITER_INIT = 100
 CONVERGENCE_TOL = 1e-8  # Optimizer convergence tolerance.
 DUPLICATE_TOL = 1e-6  # Tolerance for deduplicating initial candidates.
+STOP_AFTER_SHARE_CONVERGED = 1.0  # We optimize multiple configurations at once
+# in `optimize_acqf_mixed_alternating`. This option controls whether to stop
+# optimizing after the given share of configurations has converged.
+# Convergence is defined as one discrete step, followed by a scalar (continuous)
+# optimization, yielding less than `CONVERGENCE_TOL` improvement.
 
 SUPPORTED_OPTIONS = {
     "initialization_strategy",
@@ -63,6 +68,7 @@
     "std_cont_perturbation",
     "batch_limit",
     "init_batch_limit",
+    "stop_after_share_converged",
 }
 SUPPORTED_INITIALIZATION = {"continuous_relaxation", "equally_spaced", "random"}
 
@@ -564,63 +570,99 @@ def discrete_step(
         discrete_dims: A tensor of indices corresponding to binary and
             integer parameters.
         cat_dims: A tensor of indices corresponding to categorical parameters.
-        current_x: Starting point. A tensor of shape `d`.
+        current_x: Starting point. A tensor of shape `d` or `b x d`.
 
     Returns:
         A tuple of two tensors: a (d)-dim tensor of optimized point
         and a scalar tensor of corresponding acquisition value.
     """
+    batched_current_x = current_x.view(-1, current_x.shape[-1]).clone()
     with torch.no_grad():
-        current_acqval = opt_inputs.acq_function(current_x.unsqueeze(0))
+        current_acqvals = opt_inputs.acq_function(batched_current_x.unsqueeze(1))
     options = opt_inputs.options or {}
-    for _ in range(
-        assert_is_instance(options.get("maxiter_discrete", MAX_ITER_DISCRETE), int)
-    ):
-        neighbors = []
-        if discrete_dims.numel():
-            x_neighbors_discrete = get_nearest_neighbors(
-                current_x=current_x.detach(),
-                bounds=opt_inputs.bounds,
-                discrete_dims=discrete_dims,
-            )
-            x_neighbors_discrete = _filter_infeasible(
-                X=x_neighbors_discrete,
-                inequality_constraints=opt_inputs.inequality_constraints,
-            )
-            neighbors.append(x_neighbors_discrete)
+    maxiter_discrete = options.get("maxiter_discrete", MAX_ITER_DISCRETE)
+    done = torch.zeros(len(batched_current_x), dtype=torch.bool)
+    for _ in range(assert_is_instance(maxiter_discrete, int)):
+        # We don't batch this, as the number of x_neighbors can be different
+        # for each entry (as duplicates are removed), and the most expensive
+        # op is the acq_function, which is batched.
+        # TODO: One could try removing the duplicate removal or using nested
+        # tensors to make this parallel; if this loop is parallelized, the
+        # second loop a few lines below is also easy to parallelize.
+        x_neighbors_list = []
+        for i in range(len(done)):
+            x_neighbors = None
+            if ~done[i]:
+                neighbors = []
+                if discrete_dims.numel():
+                    x_neighbors_discrete = get_nearest_neighbors(
+                        current_x=batched_current_x[i].detach(),
+                        bounds=opt_inputs.bounds,
+                        discrete_dims=discrete_dims,
+                    )
+                    x_neighbors_discrete = _filter_infeasible(
+                        X=x_neighbors_discrete,
+                        inequality_constraints=opt_inputs.inequality_constraints,
+                    )
+                    neighbors.append(x_neighbors_discrete)
 
-        if cat_dims.numel():
-            x_neighbors_cat = get_categorical_neighbors(
-                current_x=current_x.detach(),
-                bounds=opt_inputs.bounds,
-                cat_dims=cat_dims,
-            )
-            x_neighbors_cat = _filter_infeasible(
-                X=x_neighbors_cat,
-                inequality_constraints=opt_inputs.inequality_constraints,
-            )
-            neighbors.append(x_neighbors_cat)
+                if cat_dims.numel():
+                    x_neighbors_cat = get_categorical_neighbors(
+                        current_x=batched_current_x[i].detach(),
+                        bounds=opt_inputs.bounds,
+                        cat_dims=cat_dims,
+                    )
+                    x_neighbors_cat = _filter_infeasible(
+                        X=x_neighbors_cat,
+                        inequality_constraints=opt_inputs.inequality_constraints,
+                    )
+                    neighbors.append(x_neighbors_cat)
+
+                x_neighbors = torch.cat(neighbors, dim=0)
+                if x_neighbors.numel() == 0:
+                    # Exit gracefully with last point if no feasible neighbors left.
+                    done[i] = True
+            x_neighbors_list.append(x_neighbors)
 
-        x_neighbors = torch.cat(neighbors, dim=0)
-        if x_neighbors.numel() == 0:
-            # Exit gracefully with last point if there are no feasible neighbors.
+        if done.all():
             break
+
+        all_x_neighbors = torch.cat(
+            [
+                x_neighbors
+                for x_neighbors in x_neighbors_list
+                if x_neighbors is not None
+            ],
+            dim=0,
+        )
         with torch.no_grad():
             acq_vals = torch.cat(
                 [
                     opt_inputs.acq_function(X_.unsqueeze(-2))
-                    for X_ in x_neighbors.split(
+                    for X_ in all_x_neighbors.split(
                         options.get("init_batch_limit", MAX_BATCH_SIZE)
                     )
                 ]
             )
-        argmax = acq_vals.argmax()
-        improvement = acq_vals[argmax] - current_acqval
-        if improvement > 0:
-            current_acqval, current_x = acq_vals[argmax], x_neighbors[argmax]
-        if improvement <= options.get("tol", CONVERGENCE_TOL):
-            break
-    return current_x, current_acqval
+        offset = 0
+        for i in range(len(done)):
+            # Assuming the offset incurred by done samples is 0.
+            if ~done[i]:
+                width = len(x_neighbors_list[i])
+                x_neighbors = all_x_neighbors[offset : offset + width]
+                max_acq, argmax = acq_vals[offset : offset + width].max(dim=0)
+                improvement = acq_vals[offset + argmax] - current_acqvals[i]
+                if improvement > 0:
+                    current_acqvals[i], batched_current_x[i] = (
+                        max_acq,
+                        x_neighbors[argmax],
+                    )
+                if improvement <= options.get("tol", CONVERGENCE_TOL):
+                    done[i] = True
+
+                offset += width
+
+    return batched_current_x.view_as(current_x), current_acqvals
 
 
 def continuous_step(
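A note on the bookkeeping in the new `discrete_step`: per-restart neighbor sets can have different sizes after deduplication and feasibility filtering, so they are concatenated for a single batched acquisition evaluation and then sliced back per restart using running offsets. A self-contained toy sketch of that pattern (not the BoTorch code; a plain sum stands in for the acquisition function):

```python
import torch

# Three restarts with 2, 3, and 1 feasible neighbors, respectively.
x_neighbors_list = [torch.rand(2, 4), torch.rand(3, 4), torch.rand(1, 4)]
all_x_neighbors = torch.cat(x_neighbors_list, dim=0)  # shape: (6, 4)

# One batched evaluation over all neighbors at once; a sum stands in for
# `opt_inputs.acq_function`.
acq_vals = all_x_neighbors.sum(dim=-1)  # shape: (6,)

# Slice the flat result back into per-restart blocks via running offsets.
offset = 0
for i, x_neighbors in enumerate(x_neighbors_list):
    width = len(x_neighbors)
    max_acq, argmax = acq_vals[offset : offset + width].max(dim=0)
    print(f"restart {i}: best of {width} neighbors has value {max_acq.item():.3f}")
    offset += width
```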
@@ -638,36 +680,41 @@ def continuous_step(
         discrete_dims: A tensor of indices corresponding to binary and
             integer parameters.
         cat_dims: A tensor of indices corresponding to categorical parameters.
-        current_x: Starting point. A tensor of shape `d`.
+        current_x: Starting point. A tensor of shape `(b) x d`.
 
     Returns:
         A tuple of two tensors: a (1 x d)-dim tensor of optimized points
         and a (1)-dim tensor of acquisition values.
     """
+    d = current_x.shape[-1]
+    batched_current_x = current_x.view(-1, d)
+
     options = opt_inputs.options or {}
     non_cont_dims = torch.cat((discrete_dims, cat_dims), dim=0)
 
-    if len(non_cont_dims) == len(current_x):  # nothing continuous to optimize
+    if len(non_cont_dims) == d:  # nothing continuous to optimize
         with torch.no_grad():
-            return current_x, opt_inputs.acq_function(current_x.unsqueeze(0))
+            return current_x, opt_inputs.acq_function(batched_current_x)
 
     updated_opt_inputs = dataclasses.replace(
         opt_inputs,
         q=1,
         num_restarts=1,
         raw_samples=None,
-        batch_initial_conditions=current_x.unsqueeze(0),
+        batch_initial_conditions=batched_current_x[:, None],
         fixed_features={
-            **dict(zip(non_cont_dims.tolist(), current_x[non_cont_dims])),
+            **{d: batched_current_x[:, d] for d in non_cont_dims.tolist()},
            **(opt_inputs.fixed_features or {}),
         },
         options={
             "maxiter": options.get("maxiter_continuous", MAX_ITER_CONT),
             "tol": options.get("tol", CONVERGENCE_TOL),
             "batch_limit": options.get("batch_limit", MAX_BATCH_SIZE),
+            "max_optimization_problem_aggregation_size": 1,
         },
     )
-    return _optimize_acqf(opt_inputs=updated_opt_inputs)
+    best_X, best_acq_values = _optimize_acqf(opt_inputs=updated_opt_inputs)
+    return best_X.view_as(current_x), best_acq_values
 
 
 def optimize_acqf_mixed_alternating(
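Both `discrete_step` and `continuous_step` now accept either a single point of shape `d` or a batch of shape `b x d`, flattening to `(-1, d)` internally and restoring the caller's shape on return. A minimal illustration of that shape round-trip (illustrative only, not the BoTorch code):

```python
import torch

def step(current_x: torch.Tensor) -> torch.Tensor:
    d = current_x.shape[-1]
    batched = current_x.view(-1, d)    # `d` -> `1 x d`; `b x d` stays `b x d`
    batched = batched + 1.0            # placeholder for the actual optimization
    return batched.view_as(current_x)  # restore whichever shape the caller passed

print(step(torch.zeros(3)).shape)     # torch.Size([3])
print(step(torch.zeros(5, 3)).shape)  # torch.Size([5, 3])
```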
@@ -730,6 +777,11 @@ def optimize_acqf_mixed_alternating(
             in a `no_grad` context, which reduces memory usage. As a result,
             `init_batch_limit` can be set to a larger value than `batch_limit`.
             Defaults to `batch_limit`, if given.
+        - "stop_after_share_converged": We optimize multiple configurations at
+            once in `optimize_acqf_mixed_alternating`. This option controls whether
+            to stop optimizing after the given share of configurations has converged.
+            Convergence is defined as one discrete step followed by a scalar
+            optimization yielding less than `options["tol"]` improvement. Defaults to 1.
         q: Number of candidates.
         raw_samples: Number of initial candidates used to select starting points from.
             Defaults to 1024.
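To make the new option concrete: with `stop_after_share_converged=0.75` and 20 restarts, the alternating loop stops once at least 15 restarts have individually converged; the default of 1.0 waits for all of them. A tiny sketch of the check (values are made up):

```python
import torch

done = torch.tensor([True] * 15 + [False] * 5)  # 15 of 20 restarts converged
if done.float().mean() >= 0.75:                 # stop_after_share_converged
    print("stop alternating: at least 75% of the restarts have converged")
```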
@@ -761,6 +813,12 @@ def optimize_acqf_mixed_alternating(
 
     fixed_features = fixed_features or {}
     options = options or {}
+    if options.get("max_optimization_problem_aggregation_size", 1) != 1:
+        raise UnsupportedError(
+            "optimize_acqf_mixed_alternating does not support "
+            "max_optimization_problem_aggregation_size != 1. "
+            "You might leave this option empty, though."
+        )
     options.setdefault("batch_limit", MAX_BATCH_SIZE)
     options.setdefault("init_batch_limit", options["batch_limit"])
     if not (keys := set(options.keys())).issubset(SUPPORTED_OPTIONS):
@@ -793,11 +851,18 @@ def optimize_acqf_mixed_alternating(
         fixed_features=fixed_features,
         post_processing_func=post_processing_func,
         batch_initial_conditions=None,
-        return_best_only=True,
+        return_best_only=False,  # We don't want the continuous optimization
+        # step to return only the best point; this function itself returns the best.
         gen_candidates=gen_candidates_scipy,
-        sequential=sequential,
+        sequential=sequential,  # only relevant if all dims are continuous
     )
-    _validate_sequential_inputs(opt_inputs=opt_inputs)
+    if sequential:
+        # Sequential optimization requires return_best_only to be True,
+        # but we turn it off here, as we "manually" perform the sequential
+        # conditioning in the loop below.
+        _validate_sequential_inputs(
+            opt_inputs=dataclasses.replace(opt_inputs, return_best_only=True)
+        )
 
     base_X_pending = acq_function.X_pending if q > 1 else None
     dim = bounds.shape[-1]
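`dataclasses.replace` is used here, and again in the fully continuous fallback below, to validate or run `_optimize_acqf` with `return_best_only=True` without mutating the shared `opt_inputs`. A generic illustration of the pattern with a hypothetical stand-in dataclass (the real object is BoTorch's `OptimizeAcqfInputs`):

```python
import dataclasses

@dataclasses.dataclass
class OptInputs:  # hypothetical stand-in for OptimizeAcqfInputs
    num_restarts: int
    return_best_only: bool

opt_inputs = OptInputs(num_restarts=20, return_best_only=False)
# Copy with one field overridden; the original instance is left untouched.
validated = dataclasses.replace(opt_inputs, return_best_only=True)
print(opt_inputs.return_best_only, validated.return_best_only)  # False True
```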
@@ -808,7 +873,12 @@ def optimize_acqf_mixed_alternating(
     non_cont_dims = [*discrete_dims, *cat_dims]
     if len(non_cont_dims) == 0:
         # If the problem is fully continuous, fall back to standard optimization.
-        return _optimize_acqf(opt_inputs=opt_inputs)
+        return _optimize_acqf(
+            opt_inputs=dataclasses.replace(
+                opt_inputs,
+                return_best_only=True,
+            )
+        )
     if not (
         isinstance(non_cont_dims, list)
         and len(set(non_cont_dims)) == len(non_cont_dims)
@@ -842,26 +912,30 @@ def optimize_acqf_mixed_alternating(
             cont_dims=cont_dims,
         )
 
-        # TODO: Eliminate this for loop. Tensors being unequal sizes could potentially
-        # be handled by concatenating them rather than stacking, and keeping a list
-        # of indices.
-        for i in range(num_restarts):
-            alternate_steps = 0
-            while alternate_steps < options.get("maxiter_alternating", MAX_ITER_ALTER):
-                starting_acq_val = best_acq_val[i].clone()
-                alternate_steps += 1
-                for step in (discrete_step, continuous_step):
-                    best_X[i], best_acq_val[i] = step(
-                        opt_inputs=opt_inputs,
-                        discrete_dims=discrete_dims_t,
-                        cat_dims=cat_dims_t,
-                        current_x=best_X[i],
-                    )
+        done = torch.zeros(len(best_X), dtype=torch.bool, device=tkwargs["device"])
+        for _step in range(options.get("maxiter_alternating", MAX_ITER_ALTER)):
+            starting_acq_val = best_acq_val.clone()
+            best_X, best_acq_val = discrete_step(
+                opt_inputs=opt_inputs,
+                discrete_dims=discrete_dims_t,
+                cat_dims=cat_dims_t,
+                current_x=best_X,
+            )
+
+            best_X[~done], best_acq_val[~done] = continuous_step(
+                opt_inputs=opt_inputs,
+                discrete_dims=discrete_dims_t,
+                cat_dims=cat_dims_t,
+                current_x=best_X[~done],
+            )
 
-                improvement = best_acq_val[i] - starting_acq_val
-                if improvement < options.get("tol", CONVERGENCE_TOL):
-                    # Check for convergence
-                    break
+            improvement = best_acq_val - starting_acq_val
+            done_now = improvement < options.get("tol", CONVERGENCE_TOL)
+            done = done | done_now
+            if done.float().mean() >= options.get(
+                "stop_after_share_converged", STOP_AFTER_SHARE_CONVERGED
+            ):
+                break
 
         new_candidate = best_X[torch.argmax(best_acq_val)].unsqueeze(0)
         candidates = torch.cat([candidates, new_candidate], dim=-2)
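The per-restart `for`/`while` loops are replaced by a single alternating loop over batched restarts with a boolean `done` mask. A self-contained toy version of that control flow, with stand-in step functions replacing `discrete_step` and `continuous_step` (illustrative only):

```python
import torch

def toy_discrete_step(x):    # stand-in for discrete_step
    x = x + 0.5
    return x, x.sum(dim=-1)

def toy_continuous_step(x):  # stand-in for continuous_step
    x = x + 0.1
    return x, x.sum(dim=-1)

best_X = torch.zeros(4, 3)   # 4 restarts, 3 dims
best_acq_val = best_X.sum(dim=-1)
done = torch.zeros(4, dtype=torch.bool)
tol, stop_share = 1e-8, 1.0

for _ in range(10):  # maxiter_alternating
    starting_acq_val = best_acq_val.clone()
    best_X, best_acq_val = toy_discrete_step(best_X)
    # Only restarts that have not converged yet get the continuous update.
    best_X[~done], best_acq_val[~done] = toy_continuous_step(best_X[~done])
    done = done | (best_acq_val - starting_acq_val < tol)
    if done.float().mean() >= stop_share:  # stop_after_share_converged
        break
```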
