MAX_ITER_INIT = 100
CONVERGENCE_TOL = 1e-8  # Optimizer convergence tolerance.
DUPLICATE_TOL = 1e-6  # Tolerance for deduplicating initial candidates.
+STOP_AFTER_SHARE_CONVERGED = 1.0  # We optimize multiple configurations at once
+# in `optimize_acqf_mixed_alternating`. This option controls whether to stop
+# optimizing after the given share of configurations has converged. Convergence
+# is defined as one discrete step followed by one continuous step yielding an
+# improvement of less than `CONVERGENCE_TOL`.

SUPPORTED_OPTIONS = {
    "initialization_strategy",
    ...
    "std_cont_perturbation",
    "batch_limit",
    "init_batch_limit",
+    "stop_after_share_converged",
}
SUPPORTED_INITIALIZATION = {"continuous_relaxation", "equally_spaced", "random"}
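A standalone sketch of the stopping rule this constant feeds into; the improvement values below are made up for illustration:

import torch

CONVERGENCE_TOL = 1e-8
STOP_AFTER_SHARE_CONVERGED = 1.0  # stop once this share of restarts has converged

# Hypothetical per-restart improvements of one discrete + one continuous step.
improvement = torch.tensor([5e-3, 2e-9, 0.0, 1e-4])
done = improvement < CONVERGENCE_TOL                   # per-restart convergence mask
share_converged = done.float().mean()                  # here: 0.5
stop = share_converged >= STOP_AFTER_SHARE_CONVERGED   # with the default 1.0: keep iterating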
@@ -564,63 +570,99 @@ def discrete_step(
        discrete_dims: A tensor of indices corresponding to binary and
            integer parameters.
        cat_dims: A tensor of indices corresponding to categorical parameters.
-        current_x: Starting point. A tensor of shape `d`.
+        current_x: Starting point. A tensor of shape `d` or `b x d`.

    Returns:
        A tuple of two tensors: a (d)-dim tensor of optimized point
        and a scalar tensor of corresponding acquisition value.
    """
+    batched_current_x = current_x.view(-1, current_x.shape[-1]).clone()
    with torch.no_grad():
-        current_acqval = opt_inputs.acq_function(current_x.unsqueeze(0))
+        current_acqvals = opt_inputs.acq_function(batched_current_x.unsqueeze(1))
    options = opt_inputs.options or {}
-    for _ in range(
-        assert_is_instance(options.get("maxiter_discrete", MAX_ITER_DISCRETE), int)
-    ):
-        neighbors = []
-        if discrete_dims.numel():
-            x_neighbors_discrete = get_nearest_neighbors(
-                current_x=current_x.detach(),
-                bounds=opt_inputs.bounds,
-                discrete_dims=discrete_dims,
-            )
-            x_neighbors_discrete = _filter_infeasible(
-                X=x_neighbors_discrete,
-                inequality_constraints=opt_inputs.inequality_constraints,
-            )
-            neighbors.append(x_neighbors_discrete)
+    maxiter_discrete = options.get("maxiter_discrete", MAX_ITER_DISCRETE)
+    done = torch.zeros(len(batched_current_x), dtype=torch.bool)
+    for _ in range(assert_is_instance(maxiter_discrete, int)):
+        # We don't batch this, as the number of x_neighbors can differ for each
+        # entry (duplicates are removed), and the most expensive op is the
+        # acq_function call, which is batched.
+        # TODO: One could try dropping the duplicate removal or using nested
+        # tensors to parallelize this loop; the second loop a few lines below
+        # would then be easy to parallelize as well.
+        x_neighbors_list = []
+        for i in range(len(done)):
+            x_neighbors = None
+            if ~done[i]:
+                neighbors = []
+                if discrete_dims.numel():
+                    x_neighbors_discrete = get_nearest_neighbors(
+                        current_x=batched_current_x[i].detach(),
+                        bounds=opt_inputs.bounds,
+                        discrete_dims=discrete_dims,
+                    )
+                    x_neighbors_discrete = _filter_infeasible(
+                        X=x_neighbors_discrete,
+                        inequality_constraints=opt_inputs.inequality_constraints,
+                    )
+                    neighbors.append(x_neighbors_discrete)

-        if cat_dims.numel():
-            x_neighbors_cat = get_categorical_neighbors(
-                current_x=current_x.detach(),
-                bounds=opt_inputs.bounds,
-                cat_dims=cat_dims,
-            )
-            x_neighbors_cat = _filter_infeasible(
-                X=x_neighbors_cat,
-                inequality_constraints=opt_inputs.inequality_constraints,
-            )
-            neighbors.append(x_neighbors_cat)
+                if cat_dims.numel():
+                    x_neighbors_cat = get_categorical_neighbors(
+                        current_x=batched_current_x[i].detach(),
+                        bounds=opt_inputs.bounds,
+                        cat_dims=cat_dims,
+                    )
+                    x_neighbors_cat = _filter_infeasible(
+                        X=x_neighbors_cat,
+                        inequality_constraints=opt_inputs.inequality_constraints,
+                    )
+                    neighbors.append(x_neighbors_cat)
+
+                x_neighbors = torch.cat(neighbors, dim=0)
+                if x_neighbors.numel() == 0:
+                    # Exit gracefully with last point if no feasible neighbors left.
+                    done[i] = True
+            x_neighbors_list.append(x_neighbors)

-        x_neighbors = torch.cat(neighbors, dim=0)
-        if x_neighbors.numel() == 0:
-            # Exit gracefully with last point if there are no feasible neighbors.
+        if done.all():
            break
+
+        all_x_neighbors = torch.cat(
+            [
+                x_neighbors
+                for x_neighbors in x_neighbors_list
+                if x_neighbors is not None
+            ],
+            dim=0,
+        )
        with torch.no_grad():
            acq_vals = torch.cat(
                [
                    opt_inputs.acq_function(X_.unsqueeze(-2))
-                    for X_ in x_neighbors.split(
+                    for X_ in all_x_neighbors.split(
                        options.get("init_batch_limit", MAX_BATCH_SIZE)
                    )
                ]
            )
-        argmax = acq_vals.argmax()
-        improvement = acq_vals[argmax] - current_acqval
-        if improvement > 0:
-            current_acqval, current_x = acq_vals[argmax], x_neighbors[argmax]
-        if improvement <= options.get("tol", CONVERGENCE_TOL):
-            break
-    return current_x, current_acqval
+        offset = 0
+        for i in range(len(done)):
+            # Entries that are already done contributed no neighbors, so they
+            # add nothing to the offset.
+            if ~done[i]:
+                width = len(x_neighbors_list[i])
+                x_neighbors = all_x_neighbors[offset : offset + width]
+                max_acq, argmax = acq_vals[offset : offset + width].max(dim=0)
+                improvement = acq_vals[offset + argmax] - current_acqvals[i]
+                if improvement > 0:
+                    current_acqvals[i], batched_current_x[i] = (
+                        max_acq,
+                        x_neighbors[argmax],
+                    )
+                if improvement <= options.get("tol", CONVERGENCE_TOL):
+                    done[i] = True
+
+                offset += width
+
+    return batched_current_x.view_as(current_x), current_acqvals


def continuous_step(
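The offset bookkeeping in the new `discrete_step` works by concatenating the variable-length neighbor sets, evaluating them in one batched acquisition pass, and slicing the values back apart per restart. A standalone sketch of that pattern with toy tensors (the sum below is only a stand-in for the acquisition function):

import torch

# Hypothetical neighbor sets of different sizes for two restarts.
x_neighbors_list = [torch.randn(3, 2), torch.randn(5, 2)]
all_x_neighbors = torch.cat(x_neighbors_list, dim=0)  # shape (8, 2)
acq_vals = all_x_neighbors.sum(dim=-1)                # stand-in for acq_function

offset = 0
for neighbors in x_neighbors_list:
    width = len(neighbors)
    best_val, argmax = acq_vals[offset : offset + width].max(dim=0)
    best_neighbor = neighbors[argmax]  # best candidate for this restart
    offset += width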
@@ -638,36 +680,41 @@ def continuous_step(
        discrete_dims: A tensor of indices corresponding to binary and
            integer parameters.
        cat_dims: A tensor of indices corresponding to categorical parameters.
-        current_x: Starting point. A tensor of shape `d`.
+        current_x: Starting point. A tensor of shape `(b) x d`.

    Returns:
        A tuple of two tensors: a (1 x d)-dim tensor of optimized points
        and a (1)-dim tensor of acquisition values.
    """
+    d = current_x.shape[-1]
+    batched_current_x = current_x.view(-1, d)
+
    options = opt_inputs.options or {}
    non_cont_dims = torch.cat((discrete_dims, cat_dims), dim=0)

-    if len(non_cont_dims) == len(current_x):  # nothing continuous to optimize
+    if len(non_cont_dims) == d:  # nothing continuous to optimize
        with torch.no_grad():
-            return current_x, opt_inputs.acq_function(current_x.unsqueeze(0))
+            return current_x, opt_inputs.acq_function(batched_current_x)

    updated_opt_inputs = dataclasses.replace(
        opt_inputs,
        q=1,
        num_restarts=1,
        raw_samples=None,
-        batch_initial_conditions=current_x.unsqueeze(0),
+        batch_initial_conditions=batched_current_x[:, None],
        fixed_features={
-            **dict(zip(non_cont_dims.tolist(), current_x[non_cont_dims])),
+            **{d: batched_current_x[:, d] for d in non_cont_dims.tolist()},
            **(opt_inputs.fixed_features or {}),
        },
        options={
            "maxiter": options.get("maxiter_continuous", MAX_ITER_CONT),
            "tol": options.get("tol", CONVERGENCE_TOL),
            "batch_limit": options.get("batch_limit", MAX_BATCH_SIZE),
+            "max_optimization_problem_aggregation_size": 1,
        },
    )
-    return _optimize_acqf(opt_inputs=updated_opt_inputs)
+    best_X, best_acq_values = _optimize_acqf(opt_inputs=updated_opt_inputs)
+    return best_X.view_as(current_x), best_acq_values


def optimize_acqf_mixed_alternating(
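Both steps now normalize their input to a 2-d batch and restore the caller's shape on return. A minimal standalone illustration of that reshaping contract (toy tensors only, not the real step functions):

import torch

def toy_step(current_x: torch.Tensor) -> torch.Tensor:
    # Accept either a `d` or a `b x d` tensor and work on a `b x d` view ...
    batched = current_x.view(-1, current_x.shape[-1])
    updated = batched + 1.0  # stand-in for the actual optimization step
    # ... then hand the result back in whatever shape the caller used.
    return updated.view_as(current_x)

assert toy_step(torch.zeros(3)).shape == (3,)       # unbatched input
assert toy_step(torch.zeros(4, 3)).shape == (4, 3)  # batched input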
@@ -730,6 +777,11 @@ def optimize_acqf_mixed_alternating(
              in a `no_grad` context, which reduces memory usage. As a result,
              `init_batch_limit` can be set to a larger value than `batch_limit`.
              Defaults to `batch_limit`, if given.
+            - "stop_after_share_converged": We optimize multiple configurations at
+              once in `optimize_acqf_mixed_alternating`. This option controls whether
+              to stop optimizing after the given share of configurations has converged.
+              Convergence is defined as one discrete step followed by one continuous
+              step yielding an improvement of less than `options["tol"]`. Defaults to 1.
        q: Number of candidates.
        raw_samples: Number of initial candidates used to select starting points from.
            Defaults to 1024.
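A usage sketch of the new option; the import path follows the file this diff patches, and `acqf` and `bounds` (a fitted acquisition function and a `2 x d` bounds tensor) are assumed to exist elsewhere:

from botorch.optim.optimize_mixed import optimize_acqf_mixed_alternating

candidates, acq_value = optimize_acqf_mixed_alternating(
    acq_function=acqf,       # assumed: any fitted BoTorch acquisition function
    bounds=bounds,           # assumed: 2 x d tensor of parameter bounds
    discrete_dims=[0, 1],    # indices of binary/integer parameters
    cat_dims=[2],            # indices of categorical parameters
    q=2,
    options={
        # Stop the alternating loop once 80% of the restarts have converged.
        "stop_after_share_converged": 0.8,
    },
)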
@@ -761,6 +813,12 @@ def optimize_acqf_mixed_alternating(

    fixed_features = fixed_features or {}
    options = options or {}
+    if options.get("max_optimization_problem_aggregation_size", 1) != 1:
+        raise UnsupportedError(
+            "optimize_acqf_mixed_alternating does not support "
+            "max_optimization_problem_aggregation_size != 1. "
+            "Leave this option unset instead."
+        )
    options.setdefault("batch_limit", MAX_BATCH_SIZE)
    options.setdefault("init_batch_limit", options["batch_limit"])
    if not (keys := set(options.keys())).issubset(SUPPORTED_OPTIONS):
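For completeness, a standalone sketch of what the new guard rejects; `UnsupportedError` is imported from `botorch.exceptions.errors`, and the dict mirrors a hypothetical user-supplied `options` argument:

from botorch.exceptions.errors import UnsupportedError

options = {"max_optimization_problem_aggregation_size": 4}  # hypothetical user input
if options.get("max_optimization_problem_aggregation_size", 1) != 1:
    raise UnsupportedError(
        "optimize_acqf_mixed_alternating does not support "
        "max_optimization_problem_aggregation_size != 1. "
        "Leave this option unset instead."
    )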
@@ -793,11 +851,18 @@ def optimize_acqf_mixed_alternating(
        fixed_features=fixed_features,
        post_processing_func=post_processing_func,
        batch_initial_conditions=None,
-        return_best_only=True,
+        return_best_only=False,  # We need all restarts back so the alternating
+        # steps below can refine them; this function still returns only the best.
        gen_candidates=gen_candidates_scipy,
-        sequential=sequential,
+        sequential=sequential,  # only relevant if all dims are continuous
    )
-    _validate_sequential_inputs(opt_inputs=opt_inputs)
+    if sequential:
+        # Sequential optimization requires `return_best_only=True`, but we turn
+        # it off above because the sequential conditioning is performed manually
+        # in the loop below; validate against a copy with the flag restored.
+        _validate_sequential_inputs(
+            opt_inputs=dataclasses.replace(opt_inputs, return_best_only=True)
+        )

    base_X_pending = acq_function.X_pending if q > 1 else None
    dim = bounds.shape[-1]
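The validation above relies on `dataclasses.replace` producing a modified copy instead of mutating `opt_inputs`. A minimal standalone illustration with a toy dataclass (not the real `OptimizeAcqfInputs`):

import dataclasses

@dataclasses.dataclass
class ToyOptInputs:
    return_best_only: bool = False
    sequential: bool = True

opt_inputs = ToyOptInputs()
validated_view = dataclasses.replace(opt_inputs, return_best_only=True)
assert validated_view.return_best_only   # the copy has the flag restored
assert not opt_inputs.return_best_only   # the original is left untouched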
@@ -808,7 +873,12 @@ def optimize_acqf_mixed_alternating(
    non_cont_dims = [*discrete_dims, *cat_dims]
    if len(non_cont_dims) == 0:
        # If the problem is fully continuous, fall back to standard optimization.
-        return _optimize_acqf(opt_inputs=opt_inputs)
+        return _optimize_acqf(
+            opt_inputs=dataclasses.replace(
+                opt_inputs,
+                return_best_only=True,
+            )
+        )
    if not (
        isinstance(non_cont_dims, list)
        and len(set(non_cont_dims)) == len(non_cont_dims)
@@ -842,26 +912,30 @@ def optimize_acqf_mixed_alternating(
        cont_dims=cont_dims,
    )

-    # TODO: Eliminate this for loop. Tensors being unequal sizes could potentially
-    # be handled by concatenating them rather than stacking, and keeping a list
-    # of indices.
-    for i in range(num_restarts):
-        alternate_steps = 0
-        while alternate_steps < options.get("maxiter_alternating", MAX_ITER_ALTER):
-            starting_acq_val = best_acq_val[i].clone()
-            alternate_steps += 1
-            for step in (discrete_step, continuous_step):
-                best_X[i], best_acq_val[i] = step(
-                    opt_inputs=opt_inputs,
-                    discrete_dims=discrete_dims_t,
-                    cat_dims=cat_dims_t,
-                    current_x=best_X[i],
-                )
+    done = torch.zeros(len(best_X), dtype=torch.bool, device=tkwargs["device"])
+    for _step in range(options.get("maxiter_alternating", MAX_ITER_ALTER)):
+        starting_acq_val = best_acq_val.clone()
+        best_X, best_acq_val = discrete_step(
+            opt_inputs=opt_inputs,
+            discrete_dims=discrete_dims_t,
+            cat_dims=cat_dims_t,
+            current_x=best_X,
+        )
+
+        best_X[~done], best_acq_val[~done] = continuous_step(
+            opt_inputs=opt_inputs,
+            discrete_dims=discrete_dims_t,
+            cat_dims=cat_dims_t,
+            current_x=best_X[~done],
+        )

-            improvement = best_acq_val[i] - starting_acq_val
-            if improvement < options.get("tol", CONVERGENCE_TOL):
-                # Check for convergence
-                break
+        improvement = best_acq_val - starting_acq_val
+        done_now = improvement < options.get("tol", CONVERGENCE_TOL)
+        done = done | done_now
+        if done.float().mean() >= options.get(
+            "stop_after_share_converged", STOP_AFTER_SHARE_CONVERGED
+        ):
+            break

    new_candidate = best_X[torch.argmax(best_acq_val)].unsqueeze(0)
    candidates = torch.cat([candidates, new_candidate], dim=-2)
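A toy illustration of the boolean-mask update used in the loop above: only restarts that have not converged are re-optimized, while converged rows keep their current values.

import torch

best_X = torch.zeros(4, 2)
done = torch.tensor([False, True, False, True])
best_X[~done] = best_X[~done] + 1.0  # stand-in for continuous_step on active rows
# Rows 0 and 2 changed; rows 1 and 3 (already converged) are untouched.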