Skip to content

Commit 48b67c2

Browse files
vizier-team, copybara-github
vizier-team
authored and committed
Internal
PiperOrigin-RevId: 712550496
1 parent e0d923e commit 48b67c2

File tree

2 files changed

+252
-47
lines changed

2 files changed

+252
-47
lines changed

vizier/_src/algorithms/designers/gp_ucb_pe.py

+185 -26
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020

2121
import copy
2222
import datetime
23+
import enum
2324
import random
2425
from typing import Any, Callable, Mapping, Optional, Sequence, Union
2526

@@ -35,6 +36,7 @@
3536
from vizier import algorithms as vza
3637
from vizier import pyvizier as vz
3738
from vizier._src.algorithms.designers import quasi_random
39+
from vizier._src.algorithms.designers import scalarization
3840
from vizier._src.algorithms.designers.gp import acquisitions
3941
from vizier._src.algorithms.designers.gp import output_warpers
4042
from vizier._src.algorithms.optimizers import eagle_strategy as es
@@ -51,6 +53,23 @@
5153
tfd = tfp.distributions
5254

5355

56+
class MultimetricPromisingRegionPenaltyType(enum.Enum):
57+
"""The type of penalty to apply to the points outside the promising region.
58+
59+
Configures the penalty term in `PEScoreFunction` for multimetric problems.
60+
"""
61+
62+
# The penalty is applied to the points outside the union of the promising
63+
# regions of all metrics.
64+
UNION = 'union'
65+
# The penalty is applied to the points outside the intersection of the
66+
# promising regions of all metrics.
67+
INTERSECTION = 'intersection'
68+
# The penalty applied to a point in the search space is the average of
69+
# the penalties with respect to the promising regions of all metrics.
70+
AVERAGE = 'average'
71+
72+
5473
class UCBPEConfig(eqx.Module):
5574
"""UCB-PE config parameters."""
5675

@@ -92,6 +111,13 @@ class UCBPEConfig(eqx.Module):
92111
optimize_set_acquisition_for_exploration: bool = eqx.field(
93112
default=False, static=True
94113
)
114+
# The type of penalty to apply to the points outside the promising region for
115+
# multimetric problems.
116+
multimetric_promising_region_penalty_type: (
117+
MultimetricPromisingRegionPenaltyType
118+
) = eqx.field(
119+
default=MultimetricPromisingRegionPenaltyType.AVERAGE, static=True
120+
)
95121

96122
def __repr__(self):
97123
return eqx.tree_pformat(self, short_arrays=False)
@@ -155,10 +181,28 @@ def _compute_ucb_threshold(
155181
The predicted mean of the feature array with the maximum UCB among `xs`.
156182
"""
157183
pred_mean = gprm.mean()
158-
ucb_values = jnp.where(
159-
is_missing, -jnp.inf, pred_mean + ucb_coefficient * gprm.stddev()
160-
)
161-
return pred_mean[jnp.argmax(ucb_values)]
184+
if pred_mean.ndim > 1:
185+
# In the multimetric case, the predicted mean and stddev are of shape
186+
# [num_points, num_metrics].
187+
ucb_values = jnp.where(
188+
jnp.tile(is_missing[:, jnp.newaxis], (1, pred_mean.shape[-1])),
189+
-jnp.inf,
190+
pred_mean + ucb_coefficient * gprm.stddev(),
191+
)
192+
# The indices of the points with the maximum UCB values for each metric.
193+
best_ucb_indices = jnp.argmax(ucb_values, axis=0)
194+
return jax.vmap(
195+
lambda pred_mean, best_ucb_idx: pred_mean[best_ucb_idx],
196+
in_axes=-1,
197+
out_axes=-1,
198+
)(pred_mean, best_ucb_indices)
199+
else:
200+
# In the single metric case, the predicted mean and stddev are of shape
201+
# [num_points].
202+
ucb_values = jnp.where(
203+
is_missing, -jnp.inf, pred_mean + ucb_coefficient * gprm.stddev()
204+
)
205+
return pred_mean[jnp.argmax(ucb_values)]
162206

163207

164208
# TODO: Use acquisitions.TrustRegion instead.
@@ -238,12 +282,45 @@ class UCBScoreFunction(eqx.Module):
238282
on completed and pending trials.
239283
ucb_coefficient: The UCB coefficient.
240284
trust_region: Trust region.
285+
scalarization_weights_rng: Random key for scalarization.
286+
labels: Labels, shaped as [num_index_points, num_metrics].
287+
num_scalarizations: Number of scalarizations.
241288
"""
242289

243290
predictive: sp.UniformEnsemblePredictive
244291
predictive_all_features: sp.UniformEnsemblePredictive
245292
ucb_coefficient: jt.Float[jt.Array, '']
246293
trust_region: Optional[acquisitions.TrustRegion]
294+
labels: types.PaddedArray
295+
scalarizer: scalarization.Scalarization
296+
297+
def __init__(
298+
self,
299+
predictive: sp.UniformEnsemblePredictive,
300+
predictive_all_features: sp.UniformEnsemblePredictive,
301+
ucb_coefficient: jt.Float[jt.Array, ''],
302+
trust_region: Optional[acquisitions.TrustRegion],
303+
scalarization_weights_rng: jax.Array,
304+
labels: types.PaddedArray,
305+
num_scalarizations: int = 1000,
306+
):
307+
self.predictive = predictive
308+
self.predictive_all_features = predictive_all_features
309+
self.ucb_coefficient = ucb_coefficient
310+
self.trust_region = trust_region
311+
self.labels = labels
312+
weights = jax.random.normal(
313+
scalarization_weights_rng,
314+
shape=(num_scalarizations, self.labels.shape[1]),
315+
)
316+
weights = jnp.abs(weights)
317+
weights = weights / jnp.linalg.norm(weights, axis=-1, keepdims=True)
318+
ref_point = (
319+
acquisitions.get_reference_point(self.labels, scale=0.01)
320+
if self.labels.shape[0] > 0
321+
else None
322+
)
323+
self.scalarizer = scalarization.HyperVolumeScalarization(weights, ref_point)
247324

248325
def score(
249326
self, xs: types.ModelInput, seed: Optional[jax.Array] = None
@@ -264,9 +341,26 @@ def score_with_aux(
264341
mean = gprm.mean()
265342
stddev_from_all = gprm_all_features.stddev()
266343
acq_values = mean + self.ucb_coefficient * stddev_from_all
344+
# `self.labels` is of shape [num_index_points, num_metrics].
345+
if self.labels.shape[1] > 1:
346+
scalarized = self.scalarizer(acq_values)
347+
padded_labels = self.labels.replace_fill_value(-np.inf).padded_array
348+
if padded_labels.shape[0] > 0:
349+
# Broadcast max_scalarized to the same shape as scalarized and take max.
350+
max_scalarized = jnp.max(self.scalarizer(padded_labels), axis=-1)
351+
shape_mismatch = len(scalarized.shape) - len(max_scalarized.shape)
352+
expand_max = jnp.expand_dims(
353+
max_scalarized, axis=range(-shape_mismatch, 0)
354+
)
355+
scalarized = jnp.maximum(scalarized, expand_max)
356+
scalarized_acq_values = jnp.mean(scalarized, axis=0)
357+
else:
358+
scalarized_acq_values = acq_values
267359
if self.trust_region is not None:
268-
acq_values = _apply_trust_region(self.trust_region, xs, acq_values)
269-
return acq_values, {
360+
scalarized_acq_values = _apply_trust_region(
361+
self.trust_region, xs, scalarized_acq_values
362+
)
363+
return scalarized_acq_values, {
270364
'mean': mean,
271365
'stddev': gprm.stddev(),
272366
'stddev_from_all': stddev_from_all,
@@ -303,6 +397,9 @@ class PEScoreFunction(eqx.Module):
303397
explore_ucb_coefficient: jt.Float[jt.Array, '']
304398
penalty_coefficient: jt.Float[jt.Array, '']
305399
trust_region: Optional[acquisitions.TrustRegion]
400+
multimetric_promising_region_penalty_type: (
401+
MultimetricPromisingRegionPenaltyType
402+
)
306403

307404
def score(
308405
self, xs: types.ModelInput, seed: Optional[jax.Array] = None
@@ -333,10 +430,34 @@ def score_with_aux(
333430

334431
gprm_all = self.predictive_all_features.predict(xs)
335432
stddev_from_all = gprm_all.stddev()
336-
acq_values = stddev_from_all + self.penalty_coefficient * jnp.minimum(
433+
penalty = self.penalty_coefficient * jnp.minimum(
337434
explore_ucb - threshold,
338435
0.0,
339436
)
437+
# `stddev_from_all` and `penalty` are of shape
438+
# [num_index_points, num_metrics] for multi-metric problems or
439+
# [num_index_points] for single-metric problems.
440+
if stddev_from_all.ndim > 1:
441+
if self.multimetric_promising_region_penalty_type == (
442+
MultimetricPromisingRegionPenaltyType.UNION
443+
):
444+
scalarized_penalty = jnp.max(penalty, axis=-1)
445+
elif self.multimetric_promising_region_penalty_type == (
446+
MultimetricPromisingRegionPenaltyType.INTERSECTION
447+
):
448+
scalarized_penalty = jnp.min(penalty, axis=-1)
449+
elif self.multimetric_promising_region_penalty_type == (
450+
MultimetricPromisingRegionPenaltyType.AVERAGE
451+
):
452+
scalarized_penalty = jnp.mean(penalty, axis=-1)
453+
else:
454+
raise ValueError(
455+
'Unsupported multimetric promising region penalty type:'
456+
f' {self.multimetric_promising_region_penalty_type}'
457+
)
458+
acq_values = jnp.mean(stddev_from_all, axis=-1) + scalarized_penalty
459+
else:
460+
acq_values = stddev_from_all + penalty
340461
if self.trust_region is not None:
341462
acq_values = _apply_trust_region(self.trust_region, xs, acq_values)
342463
return acq_values, {
@@ -537,8 +658,14 @@ def __attrs_post_init__(self):
537658
# Extra validations
538659
if self._problem.search_space.is_conditional:
539660
raise ValueError(f'{type(self)} does not support conditional search.')
540-
elif len(self._problem.metric_information) != 1:
541-
raise ValueError(f'{type(self)} works with exactly one metric.')
661+
elif (
662+
len(self._problem.metric_information) != 1
663+
and self._config.optimize_set_acquisition_for_exploration
664+
):
665+
raise ValueError(
666+
f'{type(self)} works with exactly one metric with'
667+
' `optimize_set_acquisition_for_exploration` enabled.'
668+
)
542669

543670
# Extra initializations.
544671
# Discrete parameters are continuified to account for their actual values.
@@ -554,7 +681,7 @@ def __attrs_post_init__(self):
554681
self._problem.search_space,
555682
seed=int(jax.random.randint(qrs_seed, [], 0, 2**16)),
556683
)
557-
self._output_warper = None
684+
self._output_warpers: list[output_warpers.OutputWarper] = []
558685

559686
def update(
560687
self, completed: vza.CompletedTrials, all_active: vza.ActiveTrials
@@ -717,10 +844,15 @@ def _trials_to_data(self, trials: Sequence[vz.Trial]) -> types.ModelData:
717844
data.labels.shape,
718845
_get_features_shape(data.features),
719846
)
720-
self._output_warper = output_warpers.create_default_warper()
721-
warped_labels = self._output_warper.warp(np.array(data.labels.unpad()))
847+
unpadded_labels = np.asarray(data.labels.unpad())
848+
warped_labels = []
849+
self._output_warpers = []
850+
for i in range(data.labels.shape[1]):
851+
output_warper = output_warpers.create_default_warper()
852+
warped_labels.append(output_warper.warp(unpadded_labels[:, i : i + 1]))
853+
self._output_warpers.append(output_warper)
722854
labels = types.PaddedArray.from_array(
723-
warped_labels,
855+
np.concatenate(warped_labels, axis=-1),
724856
data.labels.padded_array.shape,
725857
fill_value=data.labels.fill_value,
726858
)
@@ -773,7 +905,10 @@ def _get_predictive_all_features(
773905
# Pending features are only used to predict standard deviation, so their
774906
# labels do not matter, and we simply set them to 0.
775907
dummy_labels = jnp.zeros(
776-
shape=(pending_features.continuous.unpad().shape[0], 1),
908+
shape=(
909+
pending_features.continuous.unpad().shape[0],
910+
data.labels.shape[-1],
911+
),
777912
dtype=data.labels.padded_array.dtype,
778913
)
779914
all_labels = jnp.concatenate([data.labels.unpad(), dummy_labels], axis=0)
@@ -840,11 +975,14 @@ def _suggest_one(
840975
# When `use_ucb` is true, the acquisition function computes the UCB
841976
# values. Otherwise, it computes the Pure-Exploration acquisition values.
842977
if use_ucb:
978+
scalarization_weights_rng, self._rng = jax.random.split(self._rng)
843979
scoring_fn = UCBScoreFunction(
844980
predictive,
845981
predictive_all_features,
846982
ucb_coefficient=self._config.ucb_coefficient,
847983
trust_region=tr if self._use_trust_region else None,
984+
scalarization_weights_rng=scalarization_weights_rng,
985+
labels=data.labels,
848986
)
849987
else:
850988
scoring_fn = PEScoreFunction(
@@ -854,6 +992,9 @@ def _suggest_one(
854992
ucb_coefficient=self._config.ucb_coefficient,
855993
explore_ucb_coefficient=self._config.explore_region_ucb_coefficient,
856994
trust_region=tr if self._use_trust_region else None,
995+
multimetric_promising_region_penalty_type=(
996+
self._config.multimetric_promising_region_penalty_type
997+
),
857998
)
858999

8591000
if isinstance(acquisition_optimizer, vb.VectorizedOptimizer):
@@ -910,9 +1051,11 @@ def _suggest_one(
9101051
# debugging needs.
9111052
metadata = best_candidate.metadata.ns(self._metadata_ns)
9121053
metadata.ns('prediction_in_warped_y_space').update({
913-
'mean': f'{predict_mean[0]}',
914-
'stddev': f'{predict_stddev[0]}',
915-
'stddev_from_all': f'{predict_stddev_from_all[0]}',
1054+
'mean': np.array2string(np.asarray(predict_mean[0]), separator=','),
1055+
'stddev': np.array2string(np.asarray(predict_stddev[0]), separator=','),
1056+
'stddev_from_all': np.array2string(
1057+
np.asarray(predict_stddev_from_all[0]), separator=','
1058+
),
9161059
'acquisition': f'{acquisition}',
9171060
'use_ucb': f'{use_ucb}',
9181061
'trust_radius': f'{tr.trust_radius}',
@@ -1060,20 +1203,36 @@ def sample(
10601203
)
10611204
samples = eqx.filter_jit(acquisitions.sample_from_predictive)(
10621205
predictive, xs, num_samples, key=rng
1063-
) # (num_samples, num_trials)
1064-
# Scope the samples to non-padded only (there's a single padded dimension).
1206+
)
1207+
# Scope `samples` to non-padded only (there's a single padded dimension).
1208+
# `samples` has shape: [num_samples, num_trials] for single metric or
1209+
# [num_samples, num_trials, num_metrics] for multi-metric problems.
1210+
if samples.ndim == 2:
1211+
samples = jnp.expand_dims(samples, axis=-1)
10651212
samples = samples[
1066-
:, ~(xs.continuous.is_missing[0] | xs.categorical.is_missing[0])
1213+
:, ~(xs.continuous.is_missing[0] | xs.categorical.is_missing[0]), :
10671214
]
10681215
# TODO: vectorize output warping.
1069-
if self._output_warper is not None:
1070-
return np.vstack([
1071-
self._output_warper.unwarp(samples[i][..., np.newaxis]).reshape(-1)
1072-
for i in range(samples.shape[0])
1073-
])
1216+
if self._output_warpers:
1217+
unwarped_samples = []
1218+
for metric_idx, output_warper in enumerate(self._output_warpers):
1219+
unwarped_samples.append(
1220+
np.vstack([
1221+
output_warper.unwarp(
1222+
samples[i][:, metric_idx : metric_idx + 1]
1223+
).reshape(-1)
1224+
for i in range(samples.shape[0])
1225+
])
1226+
)
1227+
unwarped_samples = np.stack(unwarped_samples, axis=-1)
1228+
if unwarped_samples.shape[-1] > 1:
1229+
return unwarped_samples
1230+
else:
1231+
return np.squeeze(unwarped_samples, axis=-1)
10741232
else:
10751233
raise TypeError(
1076-
'Output warper is expected to be set, but found to be None.'
1234+
'Output warpers are expected to be set, but found to be'
1235+
f' {self._output_warpers}.'
10771236
)
10781237

10791238
@profiler.record_runtime

0 commit comments

Comments
 (0)