Add EtR featurizer for training and switch default warper to original.
PiperOrigin-RevId: 693488306
xingyousong authored and copybara-github committed Nov 5, 2024
1 parent 1570f7b commit 840260a
Showing 5 changed files with 156 additions and 6 deletions.
14 changes: 12 additions & 2 deletions optformer/embed_then_regress/normalization.py
@@ -234,8 +234,18 @@ def unwarp(self, ys: jt.Float[np.ndarray, 'L']) -> jt.Float[np.ndarray, 'L']:
 
 
 def default_warper() -> StatefulWarper:
+  """Warper used in original paper."""
   return SequentialWarper([
       HalfRankWarper(),
-      LinearScalingWarper(),
-      SigmoidDampenWarper(),
+      LinearScalingWarper(scale=0.5),
+      SigmoidDampenWarper(curvature=1.0, scale=1.0),
   ])
+
+
+def new_warper() -> StatefulWarper:
+  """New warper which may be more stable."""
+  return SequentialWarper([
+      HalfRankWarper(),
+      LinearScalingWarper(scale=1.0),
+      SigmoidDampenWarper(curvature=0.5, scale=1.0),
+  ])
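
For context, a minimal sketch (with made-up y-values) of how a StatefulWarper produced by these factories is driven, mirroring the train/warp/unwarp calls used by the featurizer added in this commit:

import numpy as np

from optformer.embed_then_regress import normalization

ys = np.array([3.2, -1.0, 0.5, 7.9, 2.4])  # Hypothetical raw objective values.
num_context = 3

warper = normalization.default_warper()  # Or normalization.new_warper().
warper.train(ys[:num_context])  # Fit warping statistics on context points only.
warped = warper.warp(ys)        # Warp context and target values alike.
recovered = warper.unwarp(warped)  # Map warped values back, e.g. for predictions.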
123 changes: 123 additions & 0 deletions optformer/embed_then_regress/vizier/featurizer.py
@@ -0,0 +1,123 @@
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Featurizer for Vizier; used for training only."""

import functools
from typing import Sequence

import attrs
import numpy as np
from optformer.common.data import featurizers
from optformer.common.data import filters
from optformer.embed_then_regress import normalization
from optformer.embed_then_regress.vizier import serializers
from optformer.vizier.data import augmenters
import tensorflow.google.compat.v2 as tf
from vizier import pyvizier as vz


VizierFilter = filters.Filter[vz.ProblemAndTrials]


@attrs.define(init=True, kw_only=True)
class ICLFeaturizer(featurizers.Featurizer[vz.ProblemAndTrials]):
"""Converts a Vizier study to strings suitable for ICL training."""

min_context: int = attrs.field(default=5)
max_context: int = attrs.field(default=100)
max_trials: int = attrs.field(default=120)

warper: normalization.StatefulWarper = attrs.field(
factory=normalization.default_warper
)

_prefilters: Sequence[VizierFilter] = attrs.field(factory=list)
_augmenters: Sequence[augmenters.VizierAugmenter] = attrs.field(factory=list)
_postfilters: Sequence[VizierFilter] = attrs.field(factory=list)

@functools.cached_property
def element_spec(self) -> dict[str, tf.TensorSpec]:
return {
'x': tf.TensorSpec(shape=(None,), dtype=tf.string), # L
'y': tf.TensorSpec(shape=(None,), dtype=tf.float32), # L
'metadata': tf.TensorSpec(shape=(), dtype=tf.string), # Scalar
'mask': tf.TensorSpec(shape=(None,), dtype=tf.bool), # L
}

@functools.cached_property
def empty_output(self) -> dict[str, tf.Tensor]:
return {
'x': tf.constant([''], dtype=tf.string),
'y': tf.constant([0.0], dtype=tf.float32),
'metadata': tf.constant('', dtype=tf.string),
'mask': tf.constant([False], dtype=tf.bool),
}

def to_features(self, study: vz.ProblemAndTrials, /) -> dict[str, tf.Tensor]:
# pylint:disable=invalid-name
for study_filter in self._prefilters:
if not study_filter(study):
raise ValueError(f'{study_filter} rejected study.')

for study_augmenter in self._augmenters:
# NOTE: Study may be modified in-place rather than copied.
study = study_augmenter.augment_study(study)

for study_filter in self._postfilters:
if not study_filter(study):
raise ValueError(f'{study_filter} rejected study.')

if not study.trials:
raise ValueError('Study has no trials.')

# Limit maximum sequence length.
study.trials[:] = study.trials[: self.max_trials]

problem = study.problem
ss_str = serializers.SearchSpaceSerializer().to_str(problem.search_space)
m_name = problem.metric_information.item().name
L = len(study.trials)

xs = []
ys = []
x_serializer = serializers.XSerializer(study.problem.search_space)
for trial in study.trials:
xs.append(x_serializer.to_str(trial))
ys.append(trial.final_measurement_or_die.metrics[m_name].value)

    # Sample the number of in-context trials (upper bound is exclusive).
    num_context = np.random.randint(self.min_context, self.max_context)

    # Mark the context trials; all remaining trials are prediction targets.
    mask = np.ones(L, dtype=bool)
    # Apply a random permutation so the context/target split is random.
    perm = np.random.permutation(L)
    xs = [xs[i] for i in perm]
    ys = [ys[i] for i in perm]
    mask[num_context:] = False

    # Warp y-values, fitting the warper on the context split only.
ys = np.array(ys)
self.warper.train(ys[:num_context])
ys = self.warper.warp(ys)

if np.isnan(ys).any():
raise ValueError(f'Y values contain NaN: {ys}')

return {
'x': tf.constant(xs, dtype=tf.string),
'y': tf.constant(ys, dtype=tf.float32),
'metadata': tf.constant(ss_str, dtype=tf.string),
'mask': tf.constant(mask, dtype=tf.bool),
}
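
As a usage illustration, the sketch below builds a tiny synthetic study (parameter names and values are hypothetical) and runs it through the featurizer; it assumes standard OSS pyvizier constructors for ProblemStatement, Trial, and Measurement:

from optformer.embed_then_regress.vizier import featurizer as featurizer_lib
from vizier import pyvizier as vz

# Assemble a small synthetic study.
problem = vz.ProblemStatement()
problem.search_space.add(vz.ParameterConfig.factory('lr', bounds=(0.0, 1.0)))
problem.metric_information.append(
    vz.MetricInformation(name='acc', goal=vz.ObjectiveMetricGoal.MAXIMIZE)
)

trials = []
for i in range(20):
  trial = vz.Trial(parameters={'lr': (i + 1) / 20})
  trial.complete(vz.Measurement(metrics={'acc': float(i)}))
  trials.append(trial)

study = vz.ProblemAndTrials(problem=problem, trials=trials)
featurizer = featurizer_lib.ICLFeaturizer(min_context=5, max_context=10)
features = featurizer.to_features(study)  # Keys: 'x', 'y', 'metadata', 'mask'.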
10 changes: 9 additions & 1 deletion optformer/embed_then_regress/vizier/serializers.py
@@ -34,18 +34,26 @@ class XSerializer(SuggestionSerializer):
       factory=s_lib.PrimitiveSerializer, kw_only=True
   )
 
+  # Use decimal or scientific notation for numeric param values.
+  # Decimal works best only if search space is standardized.
+  use_scientific: bool = attrs.field(default=False, kw_only=True)
+
   def to_str(self, t: vz.TrialSuggestion, /) -> str:
     param_dict = t.parameters.as_dict()
 
     new_param_dict = dict()
     for pc in self.search_space.parameters:
       value = param_dict[pc.name]
       if isinstance(value, (float, int)):
-        new_param_dict[pc.name] = format(value, '.2e')  # Scientific notation.
+        float_format = '.2e' if self.use_scientific else '.2f'
+        new_param_dict[pc.name] = format(value, float_format)
       else:
         new_param_dict[pc.name] = value
 
     metadata_str = vs_lib.MetadataSerializer().to_str(t.metadata)
     return self.primitive_serializer.to_str(
         {'params': new_param_dict, 'metadata': metadata_str}
     )
+
+
+SearchSpaceSerializer = vs_lib.SearchSpaceSerializer
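
To make the new flag concrete, the two float formats behave as follows (plain Python, illustrative values):

# use_scientific=True keeps precision across magnitudes:
format(0.0312, '.2e')  # -> '3.12e-02'
format(512.0, '.2e')   # -> '5.12e+02'

# The default '.2f' is shorter but collapses small values unless the
# search space has been standardized to a comparable range:
format(0.0312, '.2f')  # -> '0.03'
format(0.0003, '.2f')  # -> '0.00'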
12 changes: 10 additions & 2 deletions optformer/vizier/data/augmenters.py
@@ -61,18 +61,26 @@ class VizierIdempotentAugmenter(VizierAugmenter[_T]):
 
 @attrs.define
 class SearchSpacePermuter(VizierAugmenter[vz.SearchSpace]):
-  """Permutes the search space's parameters."""
+  """Permutes the search space's parameters and feasible values."""
 
   seed: Optional[int] = attrs.field(init=True, kw_only=True, default=None)
 
   def augment(self, search_space: vz.SearchSpace, /) -> vz.SearchSpace:
     """Logic below reduces expensive object-copying as much as possible."""
     # pylint: disable=protected-access
     rng = random.Random(self.seed)
 
-    # Pop out all parameter configs, shuffle, then put back in.
+    # Pop out all parameter configs and shuffle ordering.
     parameter_names = list(search_space.parameter_names)
     p_configs = [search_space.pop(name) for name in parameter_names]
     rng.shuffle(p_configs)
 
+    # Shuffle feasible values within each parameter config.
+    for p_config in p_configs:
+      if p_config._feasible_values:
+        rng.shuffle(p_config._feasible_values)
+
+    # Then put back in.
     for p_config in p_configs:
       search_space.add(p_config)
3 changes: 2 additions & 1 deletion optformer/vizier/data/augmenters_test.py
@@ -29,10 +29,11 @@ def test_e2e(self):
     ss = vz.SearchSpace()
     ss.add(vz.ParameterConfig.factory('A', bounds=(0, 1)))
     ss.add(vz.ParameterConfig.factory('B', bounds=(0, 1)))
-    ss.add(vz.ParameterConfig.factory('C', bounds=(0, 1)))
+    ss.add(vz.ParameterConfig.factory('C', feasible_values=['c1', 'c2', 'c3']))
 
     new_ss = augmenters.SearchSpacePermuter(seed=0).augment(ss)
     self.assertEqual([p.name for p in new_ss.parameters], ['A', 'C', 'B'])
+    self.assertEqual(new_ss.parameters[1].feasible_values, ['c3', 'c2', 'c1'])
 
 
 class MetricsConfigPermuterTest(absltest.TestCase):
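
Beyond the unit test above, a small hypothetical sketch of seeded reproducibility: SearchSpacePermuter draws from random.Random(seed), so equal seeds give identical shuffles. Note that augment mutates its input in place, hence the deepcopies below.

import copy

from optformer.vizier.data import augmenters
from vizier import pyvizier as vz

ss = vz.SearchSpace()
for name in ('A', 'B', 'C', 'D'):
  ss.add(vz.ParameterConfig.factory(name, feasible_values=['x', 'y', 'z']))

# `augment` modifies the search space in place, so permute fresh copies.
ss_a = augmenters.SearchSpacePermuter(seed=42).augment(copy.deepcopy(ss))
ss_b = augmenters.SearchSpacePermuter(seed=42).augment(copy.deepcopy(ss))
assert [p.name for p in ss_a.parameters] == [p.name for p in ss_b.parameters]
assert ss_a.parameters[0].feasible_values == ss_b.parameters[0].feasible_values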