Skip to content

Commit

Permalink
Add EtR (embed-then-regress) featurizer for training.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 693477935
  • Loading branch information
xingyousong authored and copybara-github committed Nov 5, 2024
1 parent 1570f7b commit 64367f0
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 4 deletions.
14 changes: 12 additions & 2 deletions optformer/embed_then_regress/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,18 @@ def unwarp(self, ys: jt.Float[np.ndarray, 'L']) -> jt.Float[np.ndarray, 'L']:


def default_warper() -> StatefulWarper:
  """Warper used in original paper.

  Returns:
    A `SequentialWarper` built from, in list order, a `HalfRankWarper`, a
    `LinearScalingWarper(scale=0.5)` and a
    `SigmoidDampenWarper(curvature=1.0, scale=1.0)`.
  """
  return SequentialWarper([
      HalfRankWarper(),
      LinearScalingWarper(scale=0.5),
      SigmoidDampenWarper(curvature=1.0, scale=1.0),
  ])


def new_warper() -> StatefulWarper:
  """New warper which may be more stable.

  Returns:
    A `SequentialWarper` built from, in list order, a `HalfRankWarper`, a
    `LinearScalingWarper(scale=1.0)` and a
    `SigmoidDampenWarper(curvature=0.5, scale=1.0)`.
  """
  stages = [
      HalfRankWarper(),
      LinearScalingWarper(scale=1.0),
      SigmoidDampenWarper(curvature=0.5, scale=1.0),
  ]
  return SequentialWarper(stages)
124 changes: 124 additions & 0 deletions optformer/embed_then_regress/vizier/featurizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Featurizer for Vizier; used for training only."""

import functools
from typing import Sequence

import attrs
import numpy as np
from optformer.common.data import featurizers
from optformer.common.data import filters
from optformer.embed_then_regress import normalization
from optformer.embed_then_regress.vizier import serializers
from optformer.vizier.data import augmenters
import tensorflow.google.compat.v2 as tf
from vizier import pyvizier as vz


# Alias for a `filters.Filter` parameterized on a whole Vizier study
# (`vz.ProblemAndTrials`).
VizierFilter = filters.Filter[vz.ProblemAndTrials]


@attrs.define(init=True, kw_only=True)
class ICLFeaturizer(featurizers.Featurizer[vz.ProblemAndTrials]):
  """Turns a Vizier study into string/float tensors for ICL training.

  Attributes:
    min_context: Lower bound handed to `np.random.randint` when sampling the
      number of context trials.
    max_context: Upper bound handed to `np.random.randint` when sampling the
      number of context trials (NumPy treats this bound as exclusive).
    max_trials: A study is cut down to its first `max_trials` trials.
    warper: Stateful warper for y-values; it is re-trained on every call to
      `to_features`.
  """

  min_context: int = 5
  max_context: int = 100
  max_trials: int = 120

  warper: normalization.StatefulWarper = attrs.field(
      factory=normalization.default_warper
  )

  # Filters run before augmentation, the augmenters themselves, and filters
  # run after augmentation.
  _prefilters: Sequence[VizierFilter] = attrs.field(factory=list)
  _augmenters: Sequence[augmenters.VizierAugmenter] = attrs.field(factory=list)
  _postfilters: Sequence[VizierFilter] = attrs.field(factory=list)

  @functools.cached_property
  def element_spec(self) -> dict[str, tf.TensorSpec]:
    """Tensor specs of the dictionary produced by `to_features`."""
    spec = {}
    spec['x'] = tf.TensorSpec(shape=(None,), dtype=tf.string)  # L
    spec['y'] = tf.TensorSpec(shape=(None,), dtype=tf.float32)  # L
    spec['metadata'] = tf.TensorSpec(shape=(), dtype=tf.string)  # Scalar
    spec['mask'] = tf.TensorSpec(shape=(None,), dtype=tf.bool)  # L
    return spec

  @functools.cached_property
  def empty_output(self) -> dict[str, tf.Tensor]:
    """Length-1 placeholder features with the same keys as `element_spec`."""
    placeholder = {}
    placeholder['x'] = tf.constant([''], dtype=tf.string)
    placeholder['y'] = tf.constant([0.0], dtype=tf.float32)
    placeholder['metadata'] = tf.constant('', dtype=tf.string)
    placeholder['mask'] = tf.constant([False], dtype=tf.bool)
    return placeholder

  def to_features(self, study: vz.ProblemAndTrials, /) -> dict[str, tf.Tensor]:
    """Featurizes one study into shuffled, warped training tensors.

    The study's trial list is truncated in place to `max_trials` entries, and
    augmenters may also modify the study in place.

    Args:
      study: Study to featurize.

    Returns:
      Dictionary with keys 'x' (serialized trials), 'y' (warped objective
      values), 'metadata' (serialized search space) and 'mask' (True on the
      leading `num_context` positions after shuffling).

    Raises:
      ValueError: If a pre- or post-filter rejects the study, if the study has
        no trials, or if the warped y-values contain NaN.
    """
    # pylint:disable=invalid-name

    def _require_accepted(study_filters, candidate):
      # Raises on the first filter whose call result is falsy.
      for study_filter in study_filters:
        if not study_filter(candidate):
          raise ValueError(f'{study_filter} rejected study.')

    _require_accepted(self._prefilters, study)
    for study_augmenter in self._augmenters:
      # NOTE: Study may be modified in-place rather than copied.
      study = study_augmenter.augment_study(study)
    _require_accepted(self._postfilters, study)

    if not study.trials:
      raise ValueError('Study has no trials.')

    # Keep only the leading `max_trials` trials (mutates `study.trials`).
    study.trials[:] = study.trials[: self.max_trials]

    search_space = study.problem.search_space
    metadata = serializers.SearchSpaceSerializer().to_str(search_space)
    m_name = study.problem.metric_information.item().name
    L = len(study.trials)

    # Serialize each trial together with its final metric value, one trial at
    # a time, so a failure surfaces at the first offending trial.
    x_serializer = serializers.XSerializer(search_space)
    pairs = [
        (
            x_serializer.to_str(trial),
            trial.final_measurement_or_die.metrics[m_name].value,
        )
        for trial in study.trials
    ]

    # Draw the context size first, then the shuffle order.
    num_context = np.random.randint(self.min_context, self.max_context)
    order = np.random.permutation(L)
    xs = [pairs[i][0] for i in order]
    ys = np.array([pairs[i][1] for i in order])

    # Mark the leading `num_context` shuffled positions. If `num_context` is
    # at least L, every entry ends up True.
    mask = np.zeros(L, dtype=bool)
    mask[:num_context] = True

    # Fit the warper on that same leading prefix only, then warp every value.
    self.warper.train(ys[:num_context])
    ys = self.warper.warp(ys)

    if np.isnan(ys).any():
      raise ValueError(f'Y values contain NaN: {ys}')

    return {
        'x': tf.constant(xs, dtype=tf.string),
        'y': tf.constant(ys, dtype=tf.float32),
        'metadata': tf.constant(metadata, dtype=tf.string),
        'mask': tf.constant(mask, dtype=tf.bool),
    }
6 changes: 5 additions & 1 deletion optformer/embed_then_regress/vizier/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,15 @@ def to_str(self, t: vz.TrialSuggestion, /) -> str:
for pc in self.search_space.parameters:
value = param_dict[pc.name]
if isinstance(value, (float, int)):
new_param_dict[pc.name] = format(value, '.2e') # Scientific notation.
# NOTE: This works best if the search space is normalized.
new_param_dict[pc.name] = format(value, '.2f')
else:
new_param_dict[pc.name] = value

metadata_str = vs_lib.MetadataSerializer().to_str(t.metadata)
return self.primitive_serializer.to_str(
{'params': new_param_dict, 'metadata': metadata_str}
)


# Re-export of `vs_lib.SearchSpaceSerializer` under this module's namespace.
SearchSpaceSerializer = vs_lib.SearchSpaceSerializer
3 changes: 2 additions & 1 deletion optformer/vizier/data/augmenters_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@ def test_e2e(self):
ss = vz.SearchSpace()
ss.add(vz.ParameterConfig.factory('A', bounds=(0, 1)))
ss.add(vz.ParameterConfig.factory('B', bounds=(0, 1)))
ss.add(vz.ParameterConfig.factory('C', bounds=(0, 1)))
ss.add(vz.ParameterConfig.factory('C', feasible_values=['c1', 'c2', 'c3']))

new_ss = augmenters.SearchSpacePermuter(seed=0).augment(ss)
self.assertEqual([p.name for p in new_ss.parameters], ['A', 'C', 'B'])
self.assertEqual(new_ss.parameters[1].feasible_values, ['c3', 'c2', 'c1'])


class MetricsConfigPermuterTest(absltest.TestCase):
Expand Down

0 comments on commit 64367f0

Please sign in to comment.