20
20
21
21
import copy
22
22
import datetime
23
+ import enum
23
24
import random
24
25
from typing import Any , Callable , Mapping , Optional , Sequence , Union
25
26
35
36
from vizier import algorithms as vza
36
37
from vizier import pyvizier as vz
37
38
from vizier ._src .algorithms .designers import quasi_random
39
+ from vizier ._src .algorithms .designers import scalarization
38
40
from vizier ._src .algorithms .designers .gp import acquisitions
39
41
from vizier ._src .algorithms .designers .gp import output_warpers
40
42
from vizier ._src .algorithms .optimizers import eagle_strategy as es
51
53
tfd = tfp .distributions
52
54
53
55
56
+ class MultimetricPromisingRegionPenaltyType (enum .Enum ):
57
+ """The type of penalty to apply to the points outside the promising region.
58
+
59
+ Configures the penalty term in `PEScoreFunction` for multimetric problems.
60
+ """
61
+
62
+ # The penalty is applied to the points outside the union of the promising
63
+ # regions of all metrics.
64
+ UNION = 'union'
65
+ # The penalty is applied to the points outside the intersection of the
66
+ # promising regions of all metrics.
67
+ INTERSECTION = 'intersection'
68
+ # The penalty applied to a point in the search space is the average of
69
+ # the penalties with respect to the promising regions of all metrics.
70
+ AVERAGE = 'average'
71
+
72
+
54
73
class UCBPEConfig (eqx .Module ):
55
74
"""UCB-PE config parameters."""
56
75
@@ -92,6 +111,13 @@ class UCBPEConfig(eqx.Module):
92
111
optimize_set_acquisition_for_exploration : bool = eqx .field (
93
112
default = False , static = True
94
113
)
114
+ # The type of penalty to apply to the points outside the promising region for
115
+ # multimetric problems.
116
+ multimetric_promising_region_penalty_type : (
117
+ MultimetricPromisingRegionPenaltyType
118
+ ) = eqx .field (
119
+ default = MultimetricPromisingRegionPenaltyType .AVERAGE , static = True
120
+ )
95
121
96
122
def __repr__ (self ):
97
123
return eqx .tree_pformat (self , short_arrays = False )
@@ -155,10 +181,28 @@ def _compute_ucb_threshold(
155
181
The predicted mean of the feature array with the maximum UCB among `xs`.
156
182
"""
157
183
pred_mean = gprm .mean ()
158
- ucb_values = jnp .where (
159
- is_missing , - jnp .inf , pred_mean + ucb_coefficient * gprm .stddev ()
160
- )
161
- return pred_mean [jnp .argmax (ucb_values )]
184
+ if pred_mean .ndim > 1 :
185
+ # In the multimetric case, the predicted mean and stddev are of shape
186
+ # [num_points, num_metrics].
187
+ ucb_values = jnp .where (
188
+ jnp .tile (is_missing [:, jnp .newaxis ], (1 , pred_mean .shape [- 1 ])),
189
+ - jnp .inf ,
190
+ pred_mean + ucb_coefficient * gprm .stddev (),
191
+ )
192
+ # The indices of the points with the maximum UCB values for each metric.
193
+ best_ucb_indices = jnp .argmax (ucb_values , axis = 0 )
194
+ return jax .vmap (
195
+ lambda pred_mean , best_ucb_idx : pred_mean [best_ucb_idx ],
196
+ in_axes = - 1 ,
197
+ out_axes = - 1 ,
198
+ )(pred_mean , best_ucb_indices )
199
+ else :
200
+ # In the single metric case, the predicted mean and stddev are of shape
201
+ # [num_points].
202
+ ucb_values = jnp .where (
203
+ is_missing , - jnp .inf , pred_mean + ucb_coefficient * gprm .stddev ()
204
+ )
205
+ return pred_mean [jnp .argmax (ucb_values )]
162
206
163
207
164
208
# TODO: Use acquisitions.TrustRegion instead.
@@ -238,12 +282,45 @@ class UCBScoreFunction(eqx.Module):
238
282
on completed and pending trials.
239
283
ucb_coefficient: The UCB coefficient.
240
284
trust_region: Trust region.
285
+ scalarization_weights_rng: Random key for scalarization.
286
+ labels: Labels, shaped as [num_index_points, num_metrics].
287
+ num_scalarizations: Number of scalarizations.
241
288
"""
242
289
243
290
predictive : sp .UniformEnsemblePredictive
244
291
predictive_all_features : sp .UniformEnsemblePredictive
245
292
ucb_coefficient : jt .Float [jt .Array , '' ]
246
293
trust_region : Optional [acquisitions .TrustRegion ]
294
+ labels : types .PaddedArray
295
+ scalarizer : scalarization .Scalarization
296
+
297
+ def __init__ (
298
+ self ,
299
+ predictive : sp .UniformEnsemblePredictive ,
300
+ predictive_all_features : sp .UniformEnsemblePredictive ,
301
+ ucb_coefficient : jt .Float [jt .Array , '' ],
302
+ trust_region : Optional [acquisitions .TrustRegion ],
303
+ scalarization_weights_rng : jax .Array ,
304
+ labels : types .PaddedArray ,
305
+ num_scalarizations : int = 1000 ,
306
+ ):
307
+ self .predictive = predictive
308
+ self .predictive_all_features = predictive_all_features
309
+ self .ucb_coefficient = ucb_coefficient
310
+ self .trust_region = trust_region
311
+ self .labels = labels
312
+ weights = jax .random .normal (
313
+ scalarization_weights_rng ,
314
+ shape = (num_scalarizations , self .labels .shape [1 ]),
315
+ )
316
+ weights = jnp .abs (weights )
317
+ weights = weights / jnp .linalg .norm (weights , axis = - 1 , keepdims = True )
318
+ ref_point = (
319
+ acquisitions .get_reference_point (self .labels , scale = 0.01 )
320
+ if self .labels .shape [0 ] > 0
321
+ else None
322
+ )
323
+ self .scalarizer = scalarization .HyperVolumeScalarization (weights , ref_point )
247
324
248
325
def score (
249
326
self , xs : types .ModelInput , seed : Optional [jax .Array ] = None
@@ -264,9 +341,26 @@ def score_with_aux(
264
341
mean = gprm .mean ()
265
342
stddev_from_all = gprm_all_features .stddev ()
266
343
acq_values = mean + self .ucb_coefficient * stddev_from_all
344
+ # `self.labels` is of shape [num_index_points, num_metrics].
345
+ if self .labels .shape [1 ] > 1 :
346
+ scalarized = self .scalarizer (acq_values )
347
+ padded_labels = self .labels .replace_fill_value (- np .inf ).padded_array
348
+ if padded_labels .shape [0 ] > 0 :
349
+ # Broadcast max_scalarized to the same shape as scalarized and take max.
350
+ max_scalarized = jnp .max (self .scalarizer (padded_labels ), axis = - 1 )
351
+ shape_mismatch = len (scalarized .shape ) - len (max_scalarized .shape )
352
+ expand_max = jnp .expand_dims (
353
+ max_scalarized , axis = range (- shape_mismatch , 0 )
354
+ )
355
+ scalarized = jnp .maximum (scalarized , expand_max )
356
+ scalarized_acq_values = jnp .mean (scalarized , axis = 0 )
357
+ else :
358
+ scalarized_acq_values = acq_values
267
359
if self .trust_region is not None :
268
- acq_values = _apply_trust_region (self .trust_region , xs , acq_values )
269
- return acq_values , {
360
+ scalarized_acq_values = _apply_trust_region (
361
+ self .trust_region , xs , scalarized_acq_values
362
+ )
363
+ return scalarized_acq_values , {
270
364
'mean' : mean ,
271
365
'stddev' : gprm .stddev (),
272
366
'stddev_from_all' : stddev_from_all ,
@@ -303,6 +397,9 @@ class PEScoreFunction(eqx.Module):
303
397
explore_ucb_coefficient : jt .Float [jt .Array , '' ]
304
398
penalty_coefficient : jt .Float [jt .Array , '' ]
305
399
trust_region : Optional [acquisitions .TrustRegion ]
400
+ multimetric_promising_region_penalty_type : (
401
+ MultimetricPromisingRegionPenaltyType
402
+ )
306
403
307
404
def score (
308
405
self , xs : types .ModelInput , seed : Optional [jax .Array ] = None
@@ -333,10 +430,34 @@ def score_with_aux(
333
430
334
431
gprm_all = self .predictive_all_features .predict (xs )
335
432
stddev_from_all = gprm_all .stddev ()
336
- acq_values = stddev_from_all + self .penalty_coefficient * jnp .minimum (
433
+ penalty = self .penalty_coefficient * jnp .minimum (
337
434
explore_ucb - threshold ,
338
435
0.0 ,
339
436
)
437
+ # `stddev_from_all` and `penalty` are of shape
438
+ # [num_index_points, num_metrics] for multi-metric problems or
439
+ # [num_index_points] for single-metric problems.
440
+ if stddev_from_all .ndim > 1 :
441
+ if self .multimetric_promising_region_penalty_type == (
442
+ MultimetricPromisingRegionPenaltyType .UNION
443
+ ):
444
+ scalarized_penalty = jnp .max (penalty , axis = - 1 )
445
+ elif self .multimetric_promising_region_penalty_type == (
446
+ MultimetricPromisingRegionPenaltyType .INTERSECTION
447
+ ):
448
+ scalarized_penalty = jnp .min (penalty , axis = - 1 )
449
+ elif self .multimetric_promising_region_penalty_type == (
450
+ MultimetricPromisingRegionPenaltyType .AVERAGE
451
+ ):
452
+ scalarized_penalty = jnp .mean (penalty , axis = - 1 )
453
+ else :
454
+ raise ValueError (
455
+ 'Unsupported multimetric promising region penalty type:'
456
+ f' { self .multimetric_promising_region_penalty_type } '
457
+ )
458
+ acq_values = jnp .mean (stddev_from_all , axis = - 1 ) + scalarized_penalty
459
+ else :
460
+ acq_values = stddev_from_all + penalty
340
461
if self .trust_region is not None :
341
462
acq_values = _apply_trust_region (self .trust_region , xs , acq_values )
342
463
return acq_values , {
@@ -537,8 +658,14 @@ def __attrs_post_init__(self):
537
658
# Extra validations
538
659
if self ._problem .search_space .is_conditional :
539
660
raise ValueError (f'{ type (self )} does not support conditional search.' )
540
- elif len (self ._problem .metric_information ) != 1 :
541
- raise ValueError (f'{ type (self )} works with exactly one metric.' )
661
+ elif (
662
+ len (self ._problem .metric_information ) != 1
663
+ and self ._config .optimize_set_acquisition_for_exploration
664
+ ):
665
+ raise ValueError (
666
+ f'{ type (self )} works with exactly one metric with'
667
+ ' `optimize_set_acquisition_for_exploration` enabled.'
668
+ )
542
669
543
670
# Extra initializations.
544
671
# Discrete parameters are continuified to account for their actual values.
@@ -554,7 +681,7 @@ def __attrs_post_init__(self):
554
681
self ._problem .search_space ,
555
682
seed = int (jax .random .randint (qrs_seed , [], 0 , 2 ** 16 )),
556
683
)
557
- self ._output_warper = None
684
+ self ._output_warpers : list [ output_warpers . OutputWarper ] = []
558
685
559
686
def update (
560
687
self , completed : vza .CompletedTrials , all_active : vza .ActiveTrials
@@ -717,10 +844,15 @@ def _trials_to_data(self, trials: Sequence[vz.Trial]) -> types.ModelData:
717
844
data .labels .shape ,
718
845
_get_features_shape (data .features ),
719
846
)
720
- self ._output_warper = output_warpers .create_default_warper ()
721
- warped_labels = self ._output_warper .warp (np .array (data .labels .unpad ()))
847
+ unpadded_labels = np .asarray (data .labels .unpad ())
848
+ warped_labels = []
849
+ self ._output_warpers = []
850
+ for i in range (data .labels .shape [1 ]):
851
+ output_warper = output_warpers .create_default_warper ()
852
+ warped_labels .append (output_warper .warp (unpadded_labels [:, i : i + 1 ]))
853
+ self ._output_warpers .append (output_warper )
722
854
labels = types .PaddedArray .from_array (
723
- warped_labels ,
855
+ np . concatenate ( warped_labels , axis = - 1 ) ,
724
856
data .labels .padded_array .shape ,
725
857
fill_value = data .labels .fill_value ,
726
858
)
@@ -773,7 +905,10 @@ def _get_predictive_all_features(
773
905
# Pending features are only used to predict standard deviation, so their
774
906
# labels do not matter, and we simply set them to 0.
775
907
dummy_labels = jnp .zeros (
776
- shape = (pending_features .continuous .unpad ().shape [0 ], 1 ),
908
+ shape = (
909
+ pending_features .continuous .unpad ().shape [0 ],
910
+ data .labels .shape [- 1 ],
911
+ ),
777
912
dtype = data .labels .padded_array .dtype ,
778
913
)
779
914
all_labels = jnp .concatenate ([data .labels .unpad (), dummy_labels ], axis = 0 )
@@ -840,11 +975,14 @@ def _suggest_one(
840
975
# When `use_ucb` is true, the acquisition function computes the UCB
841
976
# values. Otherwise, it computes the Pure-Exploration acquisition values.
842
977
if use_ucb :
978
+ scalarization_weights_rng , self ._rng = jax .random .split (self ._rng )
843
979
scoring_fn = UCBScoreFunction (
844
980
predictive ,
845
981
predictive_all_features ,
846
982
ucb_coefficient = self ._config .ucb_coefficient ,
847
983
trust_region = tr if self ._use_trust_region else None ,
984
+ scalarization_weights_rng = scalarization_weights_rng ,
985
+ labels = data .labels ,
848
986
)
849
987
else :
850
988
scoring_fn = PEScoreFunction (
@@ -854,6 +992,9 @@ def _suggest_one(
854
992
ucb_coefficient = self ._config .ucb_coefficient ,
855
993
explore_ucb_coefficient = self ._config .explore_region_ucb_coefficient ,
856
994
trust_region = tr if self ._use_trust_region else None ,
995
+ multimetric_promising_region_penalty_type = (
996
+ self ._config .multimetric_promising_region_penalty_type
997
+ ),
857
998
)
858
999
859
1000
if isinstance (acquisition_optimizer , vb .VectorizedOptimizer ):
@@ -910,9 +1051,11 @@ def _suggest_one(
910
1051
# debugging needs.
911
1052
metadata = best_candidate .metadata .ns (self ._metadata_ns )
912
1053
metadata .ns ('prediction_in_warped_y_space' ).update ({
913
- 'mean' : f'{ predict_mean [0 ]} ' ,
914
- 'stddev' : f'{ predict_stddev [0 ]} ' ,
915
- 'stddev_from_all' : f'{ predict_stddev_from_all [0 ]} ' ,
1054
+ 'mean' : np .array2string (np .asarray (predict_mean [0 ]), separator = ',' ),
1055
+ 'stddev' : np .array2string (np .asarray (predict_stddev [0 ]), separator = ',' ),
1056
+ 'stddev_from_all' : np .array2string (
1057
+ np .asarray (predict_stddev_from_all [0 ]), separator = ','
1058
+ ),
916
1059
'acquisition' : f'{ acquisition } ' ,
917
1060
'use_ucb' : f'{ use_ucb } ' ,
918
1061
'trust_radius' : f'{ tr .trust_radius } ' ,
@@ -1060,20 +1203,36 @@ def sample(
1060
1203
)
1061
1204
samples = eqx .filter_jit (acquisitions .sample_from_predictive )(
1062
1205
predictive , xs , num_samples , key = rng
1063
- ) # (num_samples, num_trials)
1064
- # Scope the samples to non-padded only (there's a single padded dimension).
1206
+ )
1207
+ # Scope `samples` to non-padded only (there's a single padded dimension).
1208
+ # `samples` has shape: [num_samples, num_trials] for single metric or
1209
+ # [num_samples, num_trials, num_metrics] for multi-metric problems.
1210
+ if samples .ndim == 2 :
1211
+ samples = jnp .expand_dims (samples , axis = - 1 )
1065
1212
samples = samples [
1066
- :, ~ (xs .continuous .is_missing [0 ] | xs .categorical .is_missing [0 ])
1213
+ :, ~ (xs .continuous .is_missing [0 ] | xs .categorical .is_missing [0 ]), :
1067
1214
]
1068
1215
# TODO: vectorize output warping.
1069
- if self ._output_warper is not None :
1070
- return np .vstack ([
1071
- self ._output_warper .unwarp (samples [i ][..., np .newaxis ]).reshape (- 1 )
1072
- for i in range (samples .shape [0 ])
1073
- ])
1216
+ if self ._output_warpers :
1217
+ unwarped_samples = []
1218
+ for metric_idx , output_warper in enumerate (self ._output_warpers ):
1219
+ unwarped_samples .append (
1220
+ np .vstack ([
1221
+ output_warper .unwarp (
1222
+ samples [i ][:, metric_idx : metric_idx + 1 ]
1223
+ ).reshape (- 1 )
1224
+ for i in range (samples .shape [0 ])
1225
+ ])
1226
+ )
1227
+ unwarped_samples = np .stack (unwarped_samples , axis = - 1 )
1228
+ if unwarped_samples .shape [- 1 ] > 1 :
1229
+ return unwarped_samples
1230
+ else :
1231
+ return np .squeeze (unwarped_samples , axis = - 1 )
1074
1232
else :
1075
1233
raise TypeError (
1076
- 'Output warper is expected to be set, but found to be None.'
1234
+ 'Output warpers are expected to be set, but found to be'
1235
+ f' { self ._output_warpers } .'
1077
1236
)
1078
1237
1079
1238
@profiler .record_runtime
0 commit comments