Skip to content

Commit

Permalink
Internal
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 712550496
  • Loading branch information
vizier-team authored and copybara-github committed Jan 6, 2025
1 parent 33e91e1 commit 1b07fd2
Show file tree
Hide file tree
Showing 2 changed files with 252 additions and 47 deletions.
211 changes: 185 additions & 26 deletions vizier/_src/algorithms/designers/gp_ucb_pe.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import copy
import datetime
import enum
import random
from typing import Any, Callable, Mapping, Optional, Sequence, Union

Expand All @@ -35,6 +36,7 @@
from vizier import algorithms as vza
from vizier import pyvizier as vz
from vizier._src.algorithms.designers import quasi_random
from vizier._src.algorithms.designers import scalarization
from vizier._src.algorithms.designers.gp import acquisitions
from vizier._src.algorithms.designers.gp import output_warpers
from vizier._src.algorithms.optimizers import eagle_strategy as es
Expand All @@ -51,6 +53,23 @@
tfd = tfp.distributions


class MultimetricPromisingRegionPenaltyType(enum.Enum):
"""The type of penalty to apply to the points outside the promising region.
Configures the penalty term in `PEScoreFunction` for multimetric problems.
"""

# The penalty is applied to the points outside the union of the promising
# regions of all metrics.
UNION = 'union'
# The penalty is applied to the points outside the intersection of the
# promising regions of all metrics.
INTERSECTION = 'intersection'
# The penalty applied to a point in the search space is the average of
# the penalties with respect to the promising regions of all metrics.
AVERAGE = 'average'


class UCBPEConfig(eqx.Module):
"""UCB-PE config parameters."""

Expand Down Expand Up @@ -92,6 +111,13 @@ class UCBPEConfig(eqx.Module):
optimize_set_acquisition_for_exploration: bool = eqx.field(
default=False, static=True
)
# The type of penalty to apply to the points outside the promising region for
# multimetric problems.
multimetric_promising_region_penalty_type: (
MultimetricPromisingRegionPenaltyType
) = eqx.field(
default=MultimetricPromisingRegionPenaltyType.AVERAGE, static=True
)

def __repr__(self):
return eqx.tree_pformat(self, short_arrays=False)
Expand Down Expand Up @@ -155,10 +181,28 @@ def _compute_ucb_threshold(
The predicted mean of the feature array with the maximum UCB among `xs`.
"""
pred_mean = gprm.mean()
ucb_values = jnp.where(
is_missing, -jnp.inf, pred_mean + ucb_coefficient * gprm.stddev()
)
return pred_mean[jnp.argmax(ucb_values)]
if pred_mean.ndim > 1:
# In the multimetric case, the predicted mean and stddev are of shape
# [num_points, num_metrics].
ucb_values = jnp.where(
jnp.tile(is_missing[:, jnp.newaxis], (1, pred_mean.shape[-1])),
-jnp.inf,
pred_mean + ucb_coefficient * gprm.stddev(),
)
# The indices of the points with the maximum UCB values for each metric.
best_ucb_indices = jnp.argmax(ucb_values, axis=0)
return jax.vmap(
lambda pred_mean, best_ucb_idx: pred_mean[best_ucb_idx],
in_axes=-1,
out_axes=-1,
)(pred_mean, best_ucb_indices)
else:
# In the single metric case, the predicted mean and stddev are of shape
# [num_points].
ucb_values = jnp.where(
is_missing, -jnp.inf, pred_mean + ucb_coefficient * gprm.stddev()
)
return pred_mean[jnp.argmax(ucb_values)]


# TODO: Use acquisitions.TrustRegion instead.
Expand Down Expand Up @@ -238,12 +282,45 @@ class UCBScoreFunction(eqx.Module):
on completed and pending trials.
ucb_coefficient: The UCB coefficient.
trust_region: Trust region.
scalarization_weights_rng: Random key for scalarization.
labels: Labels, shaped as [num_index_points, num_metrics].
num_scalarizations: Number of scalarizations.
"""

predictive: sp.UniformEnsemblePredictive
predictive_all_features: sp.UniformEnsemblePredictive
ucb_coefficient: jt.Float[jt.Array, '']
trust_region: Optional[acquisitions.TrustRegion]
labels: types.PaddedArray
scalarizer: scalarization.Scalarization

def __init__(
self,
predictive: sp.UniformEnsemblePredictive,
predictive_all_features: sp.UniformEnsemblePredictive,
ucb_coefficient: jt.Float[jt.Array, ''],
trust_region: Optional[acquisitions.TrustRegion],
scalarization_weights_rng: jax.Array,
labels: types.PaddedArray,
num_scalarizations: int = 1000,
):
self.predictive = predictive
self.predictive_all_features = predictive_all_features
self.ucb_coefficient = ucb_coefficient
self.trust_region = trust_region
self.labels = labels
weights = jax.random.normal(
scalarization_weights_rng,
shape=(num_scalarizations, self.labels.shape[1]),
)
weights = jnp.abs(weights)
weights = weights / jnp.linalg.norm(weights, axis=-1, keepdims=True)
ref_point = (
acquisitions.get_reference_point(self.labels, scale=0.01)
if self.labels.shape[0] > 0
else None
)
self.scalarizer = scalarization.HyperVolumeScalarization(weights, ref_point)

def score(
self, xs: types.ModelInput, seed: Optional[jax.Array] = None
Expand All @@ -264,9 +341,26 @@ def score_with_aux(
mean = gprm.mean()
stddev_from_all = gprm_all_features.stddev()
acq_values = mean + self.ucb_coefficient * stddev_from_all
# `self.labels` is of shape [num_index_points, num_metrics].
if self.labels.shape[1] > 1:
scalarized = self.scalarizer(acq_values)
padded_labels = self.labels.replace_fill_value(-np.inf).padded_array
if padded_labels.shape[0] > 0:
# Broadcast max_scalarized to the same shape as scalarized and take max.
max_scalarized = jnp.max(self.scalarizer(padded_labels), axis=-1)
shape_mismatch = len(scalarized.shape) - len(max_scalarized.shape)
expand_max = jnp.expand_dims(
max_scalarized, axis=range(-shape_mismatch, 0)
)
scalarized = jnp.maximum(scalarized, expand_max)
scalarized_acq_values = jnp.mean(scalarized, axis=0)
else:
scalarized_acq_values = acq_values
if self.trust_region is not None:
acq_values = _apply_trust_region(self.trust_region, xs, acq_values)
return acq_values, {
scalarized_acq_values = _apply_trust_region(
self.trust_region, xs, scalarized_acq_values
)
return scalarized_acq_values, {
'mean': mean,
'stddev': gprm.stddev(),
'stddev_from_all': stddev_from_all,
Expand Down Expand Up @@ -303,6 +397,9 @@ class PEScoreFunction(eqx.Module):
explore_ucb_coefficient: jt.Float[jt.Array, '']
penalty_coefficient: jt.Float[jt.Array, '']
trust_region: Optional[acquisitions.TrustRegion]
multimetric_promising_region_penalty_type: (
MultimetricPromisingRegionPenaltyType
)

def score(
self, xs: types.ModelInput, seed: Optional[jax.Array] = None
Expand Down Expand Up @@ -333,10 +430,34 @@ def score_with_aux(

gprm_all = self.predictive_all_features.predict(xs)
stddev_from_all = gprm_all.stddev()
acq_values = stddev_from_all + self.penalty_coefficient * jnp.minimum(
penalty = self.penalty_coefficient * jnp.minimum(
explore_ucb - threshold,
0.0,
)
# `stddev_from_all` and `penalty` are of shape
# [num_index_points, num_metrics] for multi-metric problems or
# [num_index_points] for single-metric problems.
if stddev_from_all.ndim > 1:
if self.multimetric_promising_region_penalty_type == (
MultimetricPromisingRegionPenaltyType.UNION
):
scalarized_penalty = jnp.max(penalty, axis=-1)
elif self.multimetric_promising_region_penalty_type == (
MultimetricPromisingRegionPenaltyType.INTERSECTION
):
scalarized_penalty = jnp.min(penalty, axis=-1)
elif self.multimetric_promising_region_penalty_type == (
MultimetricPromisingRegionPenaltyType.AVERAGE
):
scalarized_penalty = jnp.mean(penalty, axis=-1)
else:
raise ValueError(
'Unsupported multimetric promising region penalty type:'
f' {self.multimetric_promising_region_penalty_type}'
)
acq_values = jnp.mean(stddev_from_all, axis=-1) + scalarized_penalty
else:
acq_values = stddev_from_all + penalty
if self.trust_region is not None:
acq_values = _apply_trust_region(self.trust_region, xs, acq_values)
return acq_values, {
Expand Down Expand Up @@ -537,8 +658,14 @@ def __attrs_post_init__(self):
# Extra validations
if self._problem.search_space.is_conditional:
raise ValueError(f'{type(self)} does not support conditional search.')
elif len(self._problem.metric_information) != 1:
raise ValueError(f'{type(self)} works with exactly one metric.')
elif (
len(self._problem.metric_information) != 1
and self._config.optimize_set_acquisition_for_exploration
):
raise ValueError(
f'{type(self)} works with exactly one metric with'
' `optimize_set_acquisition_for_exploration` enabled.'
)

# Extra initializations.
# Discrete parameters are continuified to account for their actual values.
Expand All @@ -554,7 +681,7 @@ def __attrs_post_init__(self):
self._problem.search_space,
seed=int(jax.random.randint(qrs_seed, [], 0, 2**16)),
)
self._output_warper = None
self._output_warpers: list[output_warpers.OutputWarper] = []

def update(
self, completed: vza.CompletedTrials, all_active: vza.ActiveTrials
Expand Down Expand Up @@ -717,10 +844,15 @@ def _trials_to_data(self, trials: Sequence[vz.Trial]) -> types.ModelData:
data.labels.shape,
_get_features_shape(data.features),
)
self._output_warper = output_warpers.create_default_warper()
warped_labels = self._output_warper.warp(np.array(data.labels.unpad()))
unpadded_labels = np.asarray(data.labels.unpad())
warped_labels = []
self._output_warpers = []
for i in range(data.labels.shape[1]):
output_warper = output_warpers.create_default_warper()
warped_labels.append(output_warper.warp(unpadded_labels[:, i : i + 1]))
self._output_warpers.append(output_warper)
labels = types.PaddedArray.from_array(
warped_labels,
np.concatenate(warped_labels, axis=-1),
data.labels.padded_array.shape,
fill_value=data.labels.fill_value,
)
Expand Down Expand Up @@ -773,7 +905,10 @@ def _get_predictive_all_features(
# Pending features are only used to predict standard deviation, so their
# labels do not matter, and we simply set them to 0.
dummy_labels = jnp.zeros(
shape=(pending_features.continuous.unpad().shape[0], 1),
shape=(
pending_features.continuous.unpad().shape[0],
data.labels.shape[-1],
),
dtype=data.labels.padded_array.dtype,
)
all_labels = jnp.concatenate([data.labels.unpad(), dummy_labels], axis=0)
Expand Down Expand Up @@ -840,11 +975,14 @@ def _suggest_one(
# When `use_ucb` is true, the acquisition function computes the UCB
# values. Otherwise, it computes the Pure-Exploration acquisition values.
if use_ucb:
scalarization_weights_rng, self._rng = jax.random.split(self._rng)
scoring_fn = UCBScoreFunction(
predictive,
predictive_all_features,
ucb_coefficient=self._config.ucb_coefficient,
trust_region=tr if self._use_trust_region else None,
scalarization_weights_rng=scalarization_weights_rng,
labels=data.labels,
)
else:
scoring_fn = PEScoreFunction(
Expand All @@ -854,6 +992,9 @@ def _suggest_one(
ucb_coefficient=self._config.ucb_coefficient,
explore_ucb_coefficient=self._config.explore_region_ucb_coefficient,
trust_region=tr if self._use_trust_region else None,
multimetric_promising_region_penalty_type=(
self._config.multimetric_promising_region_penalty_type
),
)

if isinstance(acquisition_optimizer, vb.VectorizedOptimizer):
Expand Down Expand Up @@ -910,9 +1051,11 @@ def _suggest_one(
# debugging needs.
metadata = best_candidate.metadata.ns(self._metadata_ns)
metadata.ns('prediction_in_warped_y_space').update({
'mean': f'{predict_mean[0]}',
'stddev': f'{predict_stddev[0]}',
'stddev_from_all': f'{predict_stddev_from_all[0]}',
'mean': np.array2string(np.asarray(predict_mean[0]), separator=','),
'stddev': np.array2string(np.asarray(predict_stddev[0]), separator=','),
'stddev_from_all': np.array2string(
np.asarray(predict_stddev_from_all[0]), separator=','
),
'acquisition': f'{acquisition}',
'use_ucb': f'{use_ucb}',
'trust_radius': f'{tr.trust_radius}',
Expand Down Expand Up @@ -1060,20 +1203,36 @@ def sample(
)
samples = eqx.filter_jit(acquisitions.sample_from_predictive)(
predictive, xs, num_samples, key=rng
) # (num_samples, num_trials)
# Scope the samples to non-padded only (there's a single padded dimension).
)
# Scope `samples` to non-padded only (there's a single padded dimension).
# `samples` has shape: [num_samples, num_trials] for single metric or
# [num_samples, num_trials, num_metrics] for multi-metric problems.
if samples.ndim == 2:
samples = jnp.expand_dims(samples, axis=-1)
samples = samples[
:, ~(xs.continuous.is_missing[0] | xs.categorical.is_missing[0])
:, ~(xs.continuous.is_missing[0] | xs.categorical.is_missing[0]), :
]
# TODO: vectorize output warping.
if self._output_warper is not None:
return np.vstack([
self._output_warper.unwarp(samples[i][..., np.newaxis]).reshape(-1)
for i in range(samples.shape[0])
])
if self._output_warpers:
unwarped_samples = []
for metric_idx, output_warper in enumerate(self._output_warpers):
unwarped_samples.append(
np.vstack([
output_warper.unwarp(
samples[i][:, metric_idx : metric_idx + 1]
).reshape(-1)
for i in range(samples.shape[0])
])
)
unwarped_samples = np.stack(unwarped_samples, axis=-1)
if unwarped_samples.shape[-1] > 1:
return unwarped_samples
else:
return np.squeeze(unwarped_samples, axis=-1)
else:
raise TypeError(
'Output warper is expected to be set, but found to be None.'
'Output warpers are expected to be set, but found to be'
f' {self._output_warpers}.'
)

@profiler.record_runtime
Expand Down
Loading

0 comments on commit 1b07fd2

Please sign in to comment.