MAX_ITER_INIT = 100
CONVERGENCE_TOL = 1e-8  # Optimizer convergence tolerance.
DUPLICATE_TOL = 1e-6  # Tolerance for deduplicating initial candidates.
+STOP_AFTER_SHARE_CONVERGED = 1.0  # We optimize multiple configurations at once
+# in `optimize_acqf_mixed_alternating`. This option controls whether to stop
+# optimizing after the given share of configurations has converged.
+# Convergence here means that one discrete step followed by one continuous
+# step yields less than `CONVERGENCE_TOL` improvement.

SUPPORTED_OPTIONS = {
    "initialization_strategy",
@@ -564,63 +569,99 @@ def discrete_step(
        discrete_dims: A tensor of indices corresponding to binary and
            integer parameters.
        cat_dims: A tensor of indices corresponding to categorical parameters.
-        current_x: Starting point. A tensor of shape `d`.
+        current_x: Starting point. A tensor of shape `d` or `b x d`.

    Returns:
        A tuple of two tensors: a (d)-dim tensor of optimized point
        and a scalar tensor of corresponding acquisition value.
    """
+    batched_current_x = current_x.view(-1, current_x.shape[-1]).clone()
    with torch.no_grad():
-        current_acqval = opt_inputs.acq_function(current_x.unsqueeze(0))
+        current_acqvals = opt_inputs.acq_function(batched_current_x.unsqueeze(1))
    options = opt_inputs.options or {}
-    for _ in range(
-        assert_is_instance(options.get("maxiter_discrete", MAX_ITER_DISCRETE), int)
-    ):
-        neighbors = []
-        if discrete_dims.numel():
-            x_neighbors_discrete = get_nearest_neighbors(
-                current_x=current_x.detach(),
-                bounds=opt_inputs.bounds,
-                discrete_dims=discrete_dims,
-            )
-            x_neighbors_discrete = _filter_infeasible(
-                X=x_neighbors_discrete,
-                inequality_constraints=opt_inputs.inequality_constraints,
-            )
-            neighbors.append(x_neighbors_discrete)
+    maxiter_discrete = options.get("maxiter_discrete", MAX_ITER_DISCRETE)
+    done = torch.zeros(len(batched_current_x), dtype=torch.bool)
+    for _ in range(assert_is_instance(maxiter_discrete, int)):
+        # We don't batch this, as the number of x_neighbors can differ for
+        # each entry (duplicates are removed), and the most expensive op is
+        # the acq_function, which is batched.
+        # TODO: One could drop the duplicate removal or use nested tensors to
+        # parallelize this loop; the second loop a few lines below would then
+        # be easy to parallelize as well.
+        x_neighbors_list = []
+        for i in range(len(done)):
+            x_neighbors = None
+            if ~done[i]:
+                neighbors = []
+                if discrete_dims.numel():
+                    x_neighbors_discrete = get_nearest_neighbors(
+                        current_x=batched_current_x[i].detach(),
+                        bounds=opt_inputs.bounds,
+                        discrete_dims=discrete_dims,
+                    )
+                    x_neighbors_discrete = _filter_infeasible(
+                        X=x_neighbors_discrete,
+                        inequality_constraints=opt_inputs.inequality_constraints,
+                    )
+                    neighbors.append(x_neighbors_discrete)

-        if cat_dims.numel():
-            x_neighbors_cat = get_categorical_neighbors(
-                current_x=current_x.detach(),
-                bounds=opt_inputs.bounds,
-                cat_dims=cat_dims,
-            )
-            x_neighbors_cat = _filter_infeasible(
-                X=x_neighbors_cat,
-                inequality_constraints=opt_inputs.inequality_constraints,
-            )
-            neighbors.append(x_neighbors_cat)
+                if cat_dims.numel():
+                    x_neighbors_cat = get_categorical_neighbors(
+                        current_x=batched_current_x[i].detach(),
+                        bounds=opt_inputs.bounds,
+                        cat_dims=cat_dims,
+                    )
+                    x_neighbors_cat = _filter_infeasible(
+                        X=x_neighbors_cat,
+                        inequality_constraints=opt_inputs.inequality_constraints,
+                    )
+                    neighbors.append(x_neighbors_cat)
+
+                x_neighbors = torch.cat(neighbors, dim=0)
+                if x_neighbors.numel() == 0:
+                    # Exit gracefully with last point if no feasible neighbors left.
+                    done[i] = True
+            x_neighbors_list.append(x_neighbors)

-        x_neighbors = torch.cat(neighbors, dim=0)
-        if x_neighbors.numel() == 0:
-            # Exit gracefully with last point if there are no feasible neighbors.
+        if done.all():
            break
+
+        all_x_neighbors = torch.cat(
+            [
+                x_neighbors
+                for x_neighbors in x_neighbors_list
+                if x_neighbors is not None
+            ],
+            dim=0,
+        )

        with torch.no_grad():
            acq_vals = torch.cat(
                [
                    opt_inputs.acq_function(X_.unsqueeze(-2))
-                    for X_ in x_neighbors.split(
+                    for X_ in all_x_neighbors.split(
                        options.get("init_batch_limit", MAX_BATCH_SIZE)
                    )
                ]
            )
-        argmax = acq_vals.argmax()
-        improvement = acq_vals[argmax] - current_acqval
-        if improvement > 0:
-            current_acqval, current_x = acq_vals[argmax], x_neighbors[argmax]
-        if improvement <= options.get("tol", CONVERGENCE_TOL):
-            break
-    return current_x, current_acqval
+        offset = 0
+        for i in range(len(done)):
+            # Entries that were already done contributed no neighbors,
+            # so they add nothing to the offset.
+            if ~done[i]:
+                width = len(x_neighbors_list[i])
+                x_neighbors = all_x_neighbors[offset : offset + width]
+                max_acq, argmax = acq_vals[offset : offset + width].max(dim=0)
+                improvement = acq_vals[offset + argmax] - current_acqvals[i]
+                if improvement > 0:
+                    current_acqvals[i], batched_current_x[i] = (
+                        max_acq,
+                        x_neighbors[argmax],
+                    )
+                if improvement <= options.get("tol", CONVERGENCE_TOL):
+                    done[i] = True
+
+                offset += width
+
+    return batched_current_x.view_as(current_x), current_acqvals


def continuous_step(
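
A self-contained sketch of the concatenate-evaluate-slice pattern the new `discrete_step` uses for ragged neighbor sets; `acq_function` here is a toy stand-in, not BoTorch's API:

    import torch

    def acq_function(X: torch.Tensor) -> torch.Tensor:
        # Toy batched scorer standing in for the real acquisition function.
        return -(X**2).sum(dim=-1)

    # Variable-length neighbor sets (ragged after dedup, so we cannot stack).
    x_neighbors_list = [torch.randn(n, 4) for n in (5, 2, 7)]

    # One concatenation, one batched evaluation, then per-entry slices
    # recovered via running offsets.
    all_x_neighbors = torch.cat(x_neighbors_list, dim=0)
    acq_vals = acq_function(all_x_neighbors)

    offset = 0
    for x_neighbors in x_neighbors_list:
        width = len(x_neighbors)
        max_acq, argmax = acq_vals[offset : offset + width].max(dim=0)
        print(f"best acq value in this slice: {max_acq.item():.3f}")
        offset += width
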
@@ -638,36 +679,41 @@ def continuous_step(
        discrete_dims: A tensor of indices corresponding to binary and
            integer parameters.
        cat_dims: A tensor of indices corresponding to categorical parameters.
-        current_x: Starting point. A tensor of shape `d`.
+        current_x: Starting point. A tensor of shape `(b) x d`.

    Returns:
        A tuple of two tensors: a (1 x d)-dim tensor of optimized points
        and a (1)-dim tensor of acquisition values.
    """
+    d = current_x.shape[-1]
+    batched_current_x = current_x.view(-1, d)
+
    options = opt_inputs.options or {}
    non_cont_dims = torch.cat((discrete_dims, cat_dims), dim=0)

-    if len(non_cont_dims) == len(current_x):  # nothing continuous to optimize
+    if len(non_cont_dims) == d:  # nothing continuous to optimize
        with torch.no_grad():
-            return current_x, opt_inputs.acq_function(current_x.unsqueeze(0))
+            return current_x, opt_inputs.acq_function(batched_current_x)

    updated_opt_inputs = dataclasses.replace(
        opt_inputs,
        q=1,
        num_restarts=1,
        raw_samples=None,
-        batch_initial_conditions=current_x.unsqueeze(0),
+        batch_initial_conditions=batched_current_x[:, None],
        fixed_features={
-            **dict(zip(non_cont_dims.tolist(), current_x[non_cont_dims])),
+            **{d: batched_current_x[:, d] for d in non_cont_dims.tolist()},
            **(opt_inputs.fixed_features or {}),
        },
        options={
            "maxiter": options.get("maxiter_continuous", MAX_ITER_CONT),
            "tol": options.get("tol", CONVERGENCE_TOL),
            "batch_limit": options.get("batch_limit", MAX_BATCH_SIZE),
+            "max_optimization_problem_aggregation_size": 1,
        },
    )
-    return _optimize_acqf(opt_inputs=updated_opt_inputs)
+    best_X, best_acq_values = _optimize_acqf(opt_inputs=updated_opt_inputs)
+    return best_X.view_as(current_x), best_acq_values


def optimize_acqf_mixed_alternating(
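
A small sketch of the `view(-1, d)` / `view_as` round-trip both steps now rely on to accept either a `d` or `b x d` starting point (`noop_step` is hypothetical, assuming contiguous inputs so `view` is valid). Note also that the fixed features above become per-restart column tensors (`batched_current_x[:, d]`) rather than scalars, so each restart pins its own discrete values.

    import torch

    def noop_step(current_x: torch.Tensor) -> torch.Tensor:
        d = current_x.shape[-1]
        batched = current_x.view(-1, d)  # 1 x d if the input was unbatched
        # ... per-row updates on `batched` would happen here ...
        return batched.view_as(current_x)  # restore the caller's shape

    print(noop_step(torch.zeros(3)).shape)     # torch.Size([3])
    print(noop_step(torch.zeros(5, 3)).shape)  # torch.Size([5, 3])
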
@@ -761,6 +807,12 @@ def optimize_acqf_mixed_alternating(

    fixed_features = fixed_features or {}
    options = options or {}
+    if options.get("max_optimization_problem_aggregation_size", 1) != 1:
+        raise UnsupportedError(
+            "optimize_acqf_mixed_alternating does not support "
+            "max_optimization_problem_aggregation_size != 1. "
+            "You may leave this option unset instead."
+        )
    options.setdefault("batch_limit", MAX_BATCH_SIZE)
    options.setdefault("init_batch_limit", options["batch_limit"])
    if not (keys := set(options.keys())).issubset(SUPPORTED_OPTIONS):
@@ -793,11 +845,18 @@ def optimize_acqf_mixed_alternating(
        fixed_features=fixed_features,
        post_processing_func=post_processing_func,
        batch_initial_conditions=None,
-        return_best_only=True,
+        return_best_only=False,  # We want all restarts back so that the
+        # alternating steps below can refine each of them; this function
+        # itself still returns only the best point.
        gen_candidates=gen_candidates_scipy,
-        sequential=sequential,
+        sequential=sequential,  # Only relevant if all dims are continuous.
    )
-    _validate_sequential_inputs(opt_inputs=opt_inputs)
+    if sequential:
+        # Sequential optimization requires return_best_only to be True,
+        # but we turn it off here, as we "manually" perform the sequential
+        # conditioning in the loop below.
+        _validate_sequential_inputs(
+            opt_inputs=dataclasses.replace(opt_inputs, return_best_only=True)
+        )

    base_X_pending = acq_function.X_pending if q > 1 else None
    dim = bounds.shape[-1]
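
The `dataclasses.replace` call above validates against a modified copy rather than mutating the inputs used downstream; a minimal sketch of that pattern with a hypothetical dataclass:

    import dataclasses

    @dataclasses.dataclass
    class Inputs:  # hypothetical stand-in for the real OptimizeAcqfInputs
        sequential: bool
        return_best_only: bool

    def validate(inputs: Inputs) -> None:
        if inputs.sequential and not inputs.return_best_only:
            raise ValueError("sequential requires return_best_only=True")

    inputs = Inputs(sequential=True, return_best_only=False)
    # Check the invariant on a copy; `inputs` itself stays unchanged.
    validate(dataclasses.replace(inputs, return_best_only=True))
    print(inputs.return_best_only)  # False
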
@@ -808,7 +867,12 @@ def optimize_acqf_mixed_alternating(
    non_cont_dims = [*discrete_dims, *cat_dims]
    if len(non_cont_dims) == 0:
        # If the problem is fully continuous, fall back to standard optimization.
-        return _optimize_acqf(opt_inputs=opt_inputs)
+        return _optimize_acqf(
+            opt_inputs=dataclasses.replace(
+                opt_inputs,
+                return_best_only=True,
+            )
+        )
    if not (
        isinstance(non_cont_dims, list)
        and len(set(non_cont_dims)) == len(non_cont_dims)
@@ -842,26 +906,28 @@ def optimize_acqf_mixed_alternating(
            cont_dims=cont_dims,
        )

-        # TODO: Eliminate this for loop. Tensors being unequal sizes could potentially
-        # be handled by concatenating them rather than stacking, and keeping a list
-        # of indices.
-        for i in range(num_restarts):
-            alternate_steps = 0
-            while alternate_steps < options.get("maxiter_alternating", MAX_ITER_ALTER):
-                starting_acq_val = best_acq_val[i].clone()
-                alternate_steps += 1
-                for step in (discrete_step, continuous_step):
-                    best_X[i], best_acq_val[i] = step(
-                        opt_inputs=opt_inputs,
-                        discrete_dims=discrete_dims_t,
-                        cat_dims=cat_dims_t,
-                        current_x=best_X[i],
-                    )
+        done = torch.zeros(len(best_X), dtype=torch.bool, device=tkwargs["device"])
+        for _step in range(options.get("maxiter_alternating", MAX_ITER_ALTER)):
+            starting_acq_val = best_acq_val.clone()
+            best_X, best_acq_val = discrete_step(
+                opt_inputs=opt_inputs,
+                discrete_dims=discrete_dims_t,
+                cat_dims=cat_dims_t,
+                current_x=best_X,
+            )
+
+            best_X[~done], best_acq_val[~done] = continuous_step(
+                opt_inputs=opt_inputs,
+                discrete_dims=discrete_dims_t,
+                cat_dims=cat_dims_t,
+                current_x=best_X[~done],
+            )

-                improvement = best_acq_val[i] - starting_acq_val
-                if improvement < options.get("tol", CONVERGENCE_TOL):
-                    # Check for convergence
-                    break
+            improvement = best_acq_val - starting_acq_val
+            done_now = improvement < options.get("tol", CONVERGENCE_TOL)
+            done = done | done_now
+            if done.float().mean() >= STOP_AFTER_SHARE_CONVERGED:
+                break

        new_candidate = best_X[torch.argmax(best_acq_val)].unsqueeze(0)
        candidates = torch.cat([candidates, new_candidate], dim=-2)
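
Finally, a sketch of the masked in-place update pattern the new alternating loop uses: only restarts that have not converged are passed to the next step, and the results are scattered back through the boolean mask (`fake_step` is a toy stand-in):

    import torch

    def fake_step(x: torch.Tensor) -> torch.Tensor:
        # Stand-in for continuous_step acting only on the active rows.
        return x + 100.0

    best_X = torch.arange(10.0).reshape(5, 2)
    done = torch.tensor([False, True, False, True, False])

    # Boolean-mask indexing selects the active rows; assigning scatters back.
    best_X[~done] = fake_step(best_X[~done])
    print(best_X)  # rows 0, 2, 4 updated; converged rows 1 and 3 untouched
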