
Commit 86b5659

SamuelGabriel authored and facebook-github-bot committed
Batched Optimization For Mixed Optimization (#2895)
Summary: So far, our optimization in mixed search spaces works on each restart separately and sequentially instead of batching them. Here, we change this to batch the restarts, based on the new L-BFGS-B implementation that supports this. This speeds up optimization in mixed search spaces considerably (around 3-4x, depending on the problem).

Differential Revision: D76517454
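To illustrate the idea, here is a minimal, hypothetical sketch, not code from this commit: `acqf`, the toy objective, and the shapes are made up. The previous behavior evaluated the acquisition function once per restart in a Python loop, whereas the batched path evaluates all restarts in a single forward pass:

import torch

def acqf(X: torch.Tensor) -> torch.Tensor:
    # Stand-in for an acquisition function: maps a `batch x q x d` input
    # to one value per batch entry (illustrative only).
    return -(X.sum(dim=-1) ** 2).squeeze(-1)

num_restarts, d = 20, 5
restarts = torch.rand(num_restarts, d)

# Sequential: one acquisition call per restart (previous behavior).
vals_sequential = torch.cat([acqf(x.view(1, 1, d)) for x in restarts])

# Batched: a single call over all restarts at once (new behavior), which
# amortizes the per-call overhead; the commit reports roughly 3-4x
# end-to-end speedups, depending on the problem.
vals_batched = acqf(restarts.unsqueeze(1))

assert torch.allclose(vals_sequential, vals_batched)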
1 parent b4b8586 commit 86b5659

File tree

2 files changed: +263, −73 lines


botorch/optim/optimize_mixed.py

Lines changed: 143 additions & 68 deletions
@@ -50,6 +50,11 @@
 MAX_ITER_INIT = 100
 CONVERGENCE_TOL = 1e-8  # Optimizer convergence tolerance.
 DUPLICATE_TOL = 1e-6  # Tolerance for deduplicating initial candidates.
+STOP_AFTER_SHARE_CONVERGED = 1.0  # We optimize multiple configurations at once
+# in `optimize_acqf_mixed_alternating`. This option controls whether to stop
+# optimizing after the given share of configurations has converged.
+# Convergence means that one discrete step followed by a continuous
+# optimization yields less than `CONVERGENCE_TOL` improvement.

 SUPPORTED_OPTIONS = {
     "initialization_strategy",
@@ -63,6 +68,7 @@
     "std_cont_perturbation",
     "batch_limit",
     "init_batch_limit",
+    "stop_after_share_converged",
 }
 SUPPORTED_INITIALIZATION = {"continuous_relaxation", "equally_spaced", "random"}

@@ -564,63 +570,99 @@ def discrete_step(
         discrete_dims: A tensor of indices corresponding to binary and
             integer parameters.
         cat_dims: A tensor of indices corresponding to categorical parameters.
-        current_x: Starting point. A tensor of shape `d`.
+        current_x: Starting point. A tensor of shape `d` or `b x d`.

     Returns:
         A tuple of two tensors: a (d)-dim tensor of optimized point
         and a scalar tensor of corresponding acquisition value.
     """
+    batched_current_x = current_x.view(-1, current_x.shape[-1]).clone()
     with torch.no_grad():
-        current_acqval = opt_inputs.acq_function(current_x.unsqueeze(0))
+        current_acqvals = opt_inputs.acq_function(batched_current_x.unsqueeze(1))
     options = opt_inputs.options or {}
-    for _ in range(
-        assert_is_instance(options.get("maxiter_discrete", MAX_ITER_DISCRETE), int)
-    ):
-        neighbors = []
-        if discrete_dims.numel():
-            x_neighbors_discrete = get_nearest_neighbors(
-                current_x=current_x.detach(),
-                bounds=opt_inputs.bounds,
-                discrete_dims=discrete_dims,
-            )
-            x_neighbors_discrete = _filter_infeasible(
-                X=x_neighbors_discrete,
-                inequality_constraints=opt_inputs.inequality_constraints,
-            )
-            neighbors.append(x_neighbors_discrete)
+    maxiter_discrete = options.get("maxiter_discrete", MAX_ITER_DISCRETE)
+    done = torch.zeros(len(batched_current_x), dtype=torch.bool)
+    for _ in range(assert_is_instance(maxiter_discrete, int)):
+        # We don't batch this loop, as the number of x_neighbors can differ
+        # for each entry (duplicates are removed), and the most expensive
+        # op, the acq_function, is batched below.
+        # TODO: one could drop the duplicate removal or use nested tensors
+        # to parallelize this loop; if it is parallelized, the second loop
+        # a few lines below is also easy to parallelize.
+        x_neighbors_list = []
+        for i in range(len(done)):
+            x_neighbors = None
+            if ~done[i]:
+                neighbors = []
+                if discrete_dims.numel():
+                    x_neighbors_discrete = get_nearest_neighbors(
+                        current_x=batched_current_x[i].detach(),
+                        bounds=opt_inputs.bounds,
+                        discrete_dims=discrete_dims,
+                    )
+                    x_neighbors_discrete = _filter_infeasible(
+                        X=x_neighbors_discrete,
+                        inequality_constraints=opt_inputs.inequality_constraints,
+                    )
+                    neighbors.append(x_neighbors_discrete)

-        if cat_dims.numel():
-            x_neighbors_cat = get_categorical_neighbors(
-                current_x=current_x.detach(),
-                bounds=opt_inputs.bounds,
-                cat_dims=cat_dims,
-            )
-            x_neighbors_cat = _filter_infeasible(
-                X=x_neighbors_cat,
-                inequality_constraints=opt_inputs.inequality_constraints,
-            )
-            neighbors.append(x_neighbors_cat)
+                if cat_dims.numel():
+                    x_neighbors_cat = get_categorical_neighbors(
+                        current_x=batched_current_x[i].detach(),
+                        bounds=opt_inputs.bounds,
+                        cat_dims=cat_dims,
+                    )
+                    x_neighbors_cat = _filter_infeasible(
+                        X=x_neighbors_cat,
+                        inequality_constraints=opt_inputs.inequality_constraints,
+                    )
+                    neighbors.append(x_neighbors_cat)
+
+                x_neighbors = torch.cat(neighbors, dim=0)
+                if x_neighbors.numel() == 0:
+                    # Exit gracefully with last point if no feasible neighbors left.
+                    done[i] = True
+            x_neighbors_list.append(x_neighbors)

-        x_neighbors = torch.cat(neighbors, dim=0)
-        if x_neighbors.numel() == 0:
-            # Exit gracefully with last point if there are no feasible neighbors.
+        if done.all():
             break
+
+        all_x_neighbors = torch.cat(
+            [
+                x_neighbors
+                for x_neighbors in x_neighbors_list
+                if x_neighbors is not None
+            ],
+            dim=0,
+        )
         with torch.no_grad():
             acq_vals = torch.cat(
                 [
                     opt_inputs.acq_function(X_.unsqueeze(-2))
-                    for X_ in x_neighbors.split(
+                    for X_ in all_x_neighbors.split(
                         options.get("init_batch_limit", MAX_BATCH_SIZE)
                     )
                 ]
             )
-        argmax = acq_vals.argmax()
-        improvement = acq_vals[argmax] - current_acqval
-        if improvement > 0:
-            current_acqval, current_x = acq_vals[argmax], x_neighbors[argmax]
-        if improvement <= options.get("tol", CONVERGENCE_TOL):
-            break
-    return current_x, current_acqval
+        offset = 0
+        for i in range(len(done)):
+            # Entries that are already done contribute no rows to
+            # all_x_neighbors, so they do not advance the offset.
+            if ~done[i]:
+                width = len(x_neighbors_list[i])
+                x_neighbors = all_x_neighbors[offset : offset + width]
+                max_acq, argmax = acq_vals[offset : offset + width].max(dim=0)
+                improvement = acq_vals[offset + argmax] - current_acqvals[i]
+                if improvement > 0:
+                    current_acqvals[i], batched_current_x[i] = (
+                        max_acq,
+                        x_neighbors[argmax],
+                    )
+                if improvement <= options.get("tol", CONVERGENCE_TOL):
+                    done[i] = True
+
+                offset += width
+
+    return batched_current_x.view_as(current_x), current_acqvals


 def continuous_step(
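The offset bookkeeping above (concatenate variable-length neighbor sets, score them in one batched acquisition call, then slice per restart) can be easier to see in isolation. The following is a small, self-contained sketch of the same pattern with a made-up `score` function and random data; it is illustrative only, not code from this file:

import torch

def score(X: torch.Tensor) -> torch.Tensor:
    # Stand-in for a batched acquisition evaluation: one value per row.
    return X.sum(dim=-1)

# Variable-length neighbor sets for three restarts (e.g. after deduplication
# and feasibility filtering, each restart can have a different count).
neighbor_sets = [torch.rand(4, 3), torch.rand(2, 3), torch.rand(5, 3)]

# Evaluate everything in one batched call over the concatenation.
all_neighbors = torch.cat(neighbor_sets, dim=0)
all_vals = score(all_neighbors)

# Recover each restart's best neighbor by slicing with running offsets.
best_per_restart = []
offset = 0
for neighbors in neighbor_sets:
    width = len(neighbors)
    vals = all_vals[offset : offset + width]
    best_per_restart.append(neighbors[vals.argmax()])
    offset += width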
@@ -638,36 +680,41 @@ def continuous_step(
         discrete_dims: A tensor of indices corresponding to binary and
             integer parameters.
         cat_dims: A tensor of indices corresponding to categorical parameters.
-        current_x: Starting point. A tensor of shape `d`.
+        current_x: Starting point. A tensor of shape `(b) x d`.

     Returns:
         A tuple of two tensors: a (1 x d)-dim tensor of optimized points
         and a (1)-dim tensor of acquisition values.
     """
+    d = current_x.shape[-1]
+    batched_current_x = current_x.view(-1, d)
+
     options = opt_inputs.options or {}
     non_cont_dims = torch.cat((discrete_dims, cat_dims), dim=0)

-    if len(non_cont_dims) == len(current_x):  # nothing continuous to optimize
+    if len(non_cont_dims) == d:  # nothing continuous to optimize
         with torch.no_grad():
-            return current_x, opt_inputs.acq_function(current_x.unsqueeze(0))
+            return current_x, opt_inputs.acq_function(batched_current_x)

     updated_opt_inputs = dataclasses.replace(
         opt_inputs,
         q=1,
         num_restarts=1,
         raw_samples=None,
-        batch_initial_conditions=current_x.unsqueeze(0),
+        batch_initial_conditions=batched_current_x[:, None],
         fixed_features={
-            **dict(zip(non_cont_dims.tolist(), current_x[non_cont_dims])),
+            **{d: batched_current_x[:, d] for d in non_cont_dims.tolist()},
             **(opt_inputs.fixed_features or {}),
         },
         options={
             "maxiter": options.get("maxiter_continuous", MAX_ITER_CONT),
             "tol": options.get("tol", CONVERGENCE_TOL),
             "batch_limit": options.get("batch_limit", MAX_BATCH_SIZE),
+            "max_optimization_problem_aggregation_size": 1,
         },
     )
-    return _optimize_acqf(opt_inputs=updated_opt_inputs)
+    best_X, best_acq_values = _optimize_acqf(opt_inputs=updated_opt_inputs)
+    return best_X.view_as(current_x), best_acq_values


 def optimize_acqf_mixed_alternating(
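One detail worth noting in the hunk above: `fixed_features` now maps each non-continuous dimension to a tensor with one value per restart rather than a scalar. A tiny, hypothetical illustration of building such a mapping (the shapes and dimension indices are made up):

import torch

b, d = 3, 6  # hypothetical: 3 restarts over a 6-dimensional space
batched_current_x = torch.rand(b, d)
non_cont_dims = [2, 4]  # hypothetical discrete/categorical dimensions

# One fixed value per restart for each non-continuous dimension, so the
# continuous optimizer keeps those dimensions at their current values.
fixed_features = {i: batched_current_x[:, i] for i in non_cont_dims}
# fixed_features[2].shape == (3,), i.e. one value for each of the b restarts.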
@@ -730,6 +777,12 @@ def optimize_acqf_mixed_alternating(
             in a `no_grad` context, which reduces memory usage. As a result,
             `init_batch_limit` can be set to a larger value than `batch_limit`.
             Defaults to `batch_limit`, if given.
+        - "stop_after_share_converged": We optimize multiple configurations at
+            once in `optimize_acqf_mixed_alternating`. This option controls
+            whether to stop optimizing once the given share of configurations
+            has converged. Convergence means that one discrete step followed by
+            a continuous optimization yields less than `options["tol"]`
+            improvement. Defaults to 1.
         q: Number of candidates.
         raw_samples: Number of initial candidates used to select starting points from.
             Defaults to 1024.
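For context, a hypothetical end-to-end usage sketch of the new option follows. The toy model, data, and argument values are illustrative and not part of this commit; keyword arguments other than the `options` keys documented above reflect the public signature of `optimize_acqf_mixed_alternating` as best understood and may differ between versions:

import torch
from botorch.acquisition import LogExpectedImprovement
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.optim.optimize_mixed import optimize_acqf_mixed_alternating
from gpytorch.mlls import ExactMarginalLogLikelihood

# Toy mixed problem: dims 0-1 continuous in [0, 1], dim 2 integer in {0, ..., 5}.
bounds = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 5.0]])
train_X = torch.rand(10, 3) * (bounds[1] - bounds[0]) + bounds[0]
train_X[:, 2] = train_X[:, 2].round()
train_Y = -(train_X - 0.5).pow(2).sum(dim=-1, keepdim=True)

model = SingleTaskGP(train_X, train_Y)
fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))
acqf = LogExpectedImprovement(model=model, best_f=train_Y.max())

candidate, acq_value = optimize_acqf_mixed_alternating(
    acq_function=acqf,
    bounds=bounds,
    discrete_dims=[2],
    q=1,
    num_restarts=20,
    raw_samples=256,
    # Stop alternating once 75% of the batched restarts have converged.
    options={"stop_after_share_converged": 0.75},
)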
@@ -761,6 +814,12 @@ def optimize_acqf_mixed_alternating(

     fixed_features = fixed_features or {}
     options = options or {}
+    if options.get("max_optimization_problem_aggregation_size", 1) != 1:
+        raise UnsupportedError(
+            "optimize_acqf_mixed_alternating does not support "
+            "max_optimization_problem_aggregation_size != 1. "
+            "You might leave this option empty, though."
+        )
     options.setdefault("batch_limit", MAX_BATCH_SIZE)
     options.setdefault("init_batch_limit", options["batch_limit"])
     if not (keys := set(options.keys())).issubset(SUPPORTED_OPTIONS):
@@ -793,11 +852,18 @@ def optimize_acqf_mixed_alternating(
         fixed_features=fixed_features,
         post_processing_func=post_processing_func,
         batch_initial_conditions=None,
-        return_best_only=True,
+        return_best_only=False,  # We don't want the continuous optimization step
+        # to return only the best point; this function itself returns only the best.
         gen_candidates=gen_candidates_scipy,
-        sequential=sequential,
+        sequential=sequential,  # only relevant if all dims are continuous
     )
-    _validate_sequential_inputs(opt_inputs=opt_inputs)
+    if sequential:
+        # Sequential optimization requires return_best_only to be True,
+        # but we turn it off here, as we "manually" perform the sequential
+        # conditioning in the loop below.
+        _validate_sequential_inputs(
+            opt_inputs=dataclasses.replace(opt_inputs, return_best_only=True)
+        )

     base_X_pending = acq_function.X_pending if q > 1 else None
     dim = bounds.shape[-1]
@@ -808,7 +874,12 @@ def optimize_acqf_mixed_alternating(
     non_cont_dims = [*discrete_dims, *cat_dims]
     if len(non_cont_dims) == 0:
         # If the problem is fully continuous, fall back to standard optimization.
-        return _optimize_acqf(opt_inputs=opt_inputs)
+        return _optimize_acqf(
+            opt_inputs=dataclasses.replace(
+                opt_inputs,
+                return_best_only=True,
+            )
+        )
     if not (
         isinstance(non_cont_dims, list)
         and len(set(non_cont_dims)) == len(non_cont_dims)
@@ -842,26 +913,30 @@ def optimize_acqf_mixed_alternating(
         cont_dims=cont_dims,
     )

-        # TODO: Eliminate this for loop. Tensors being unequal sizes could potentially
-        # be handled by concatenating them rather than stacking, and keeping a list
-        # of indices.
-        for i in range(num_restarts):
-            alternate_steps = 0
-            while alternate_steps < options.get("maxiter_alternating", MAX_ITER_ALTER):
-                starting_acq_val = best_acq_val[i].clone()
-                alternate_steps += 1
-                for step in (discrete_step, continuous_step):
-                    best_X[i], best_acq_val[i] = step(
-                        opt_inputs=opt_inputs,
-                        discrete_dims=discrete_dims_t,
-                        cat_dims=cat_dims_t,
-                        current_x=best_X[i],
-                    )
+        done = torch.zeros(len(best_X), dtype=torch.bool, device=tkwargs["device"])
+        for _step in range(options.get("maxiter_alternating", MAX_ITER_ALTER)):
+            starting_acq_val = best_acq_val.clone()
+            best_X, best_acq_val = discrete_step(
+                opt_inputs=opt_inputs,
+                discrete_dims=discrete_dims_t,
+                cat_dims=cat_dims_t,
+                current_x=best_X,
+            )
+
+            best_X[~done], best_acq_val[~done] = continuous_step(
+                opt_inputs=opt_inputs,
+                discrete_dims=discrete_dims_t,
+                cat_dims=cat_dims_t,
+                current_x=best_X[~done],
+            )

-                improvement = best_acq_val[i] - starting_acq_val
-                if improvement < options.get("tol", CONVERGENCE_TOL):
-                    # Check for convergence
-                    break
+            improvement = best_acq_val - starting_acq_val
+            done_now = improvement < options.get("tol", CONVERGENCE_TOL)
+            done = done | done_now
+            if done.float().mean() >= options.get(
+                "stop_after_share_converged", STOP_AFTER_SHARE_CONVERGED
+            ):
+                break

         new_candidate = best_X[torch.argmax(best_acq_val)].unsqueeze(0)
         candidates = torch.cat([candidates, new_candidate], dim=-2)
