Skip to content

Commit b19cba7

Browse files
xingyousongcopybara-github
authored andcommitted
Internal change.
PiperOrigin-RevId: 691043078
1 parent 753f82a commit b19cba7

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

optformer/vizier/data/augmenters.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -483,18 +483,21 @@ def augment_study(self, study: vz.ProblemAndTrials, /) -> vz.ProblemAndTrials:
483483
return self.augment(study)
484484

485485

486+
@attrs.define(kw_only=True)
486487
class StandardizeSearchSpace(VizierIdempotentAugmenter[vz.ProblemAndTrials]):
487488
"""Standardizes the search space and corresponding trials.
488489
489490
DOUBLE, INTEGER, and DISCRETE parameters are scaled to [0,1] range, and
490491
corresponding ParameterConfigs are all DOUBLE.
491492
492-
CATEGORICAL parameters use standard ["0", "1", "2", ...] feasible values.
493+
CATEGORICAL params use ["0", "1", "2", ...] or ["a", "b", "c", ...].
493494
494495
This is useful mainly for serializations for string-based regressors, where we
495496
don't care about reversibility back into original space.
496497
"""
497498

499+
alpha_categorical: bool = attrs.field(default=False)
500+
498501
def augment(self, study: vz.ProblemAndTrials, /) -> vz.ProblemAndTrials:
499502
# Create a forward converter to map to normalized feature space.
500503
old_search_space = study.problem.search_space
@@ -508,10 +511,11 @@ def augment(self, study: vz.ProblemAndTrials, /) -> vz.ProblemAndTrials:
508511
new_search_space = vz.SearchSpace()
509512
for i, pc in enumerate(old_search_space.parameters):
510513
if pc.type == vz.ParameterType.CATEGORICAL:
511-
new_pc = vz.ParameterConfig.factory(
512-
f'x{i}',
513-
feasible_values=[str(j) for j in range(len(pc.feasible_values))],
514-
)
514+
if self.alpha_categorical:
515+
feasibles = [chr(j + 97) for j in range(len(pc.feasible_values))]
516+
else:
517+
feasibles = [str(j) for j in range(len(pc.feasible_values))]
518+
new_pc = vz.ParameterConfig.factory(f'x{i}', feasible_values=feasibles)
515519
else:
516520
new_pc = vz.ParameterConfig.factory(f'x{i}', bounds=(0.0, 1.0))
517521
new_search_space.add(new_pc)

0 commit comments

Comments
 (0)