1616from ax .core .optimization_config import OptimizationConfig
1717from ax .core .objective import Objective as AxObjective
1818from ax .runners import SyntheticRunner
19- from ax .modelbridge .factory import get_sobol
20- from ax .modelbridge .torch import TorchModelBridge
19+ from ax .adapter .factory import get_sobol
20+ from ax .adapter .torch import TorchAdapter
2121from ax .core .observation import ObservationFeatures
2222from ax .core .generator_run import GeneratorRun
2323from ax .storage .json_store .save import save_experiment
2424from ax .storage .metric_registry import register_metrics
2525
26- from ax .modelbridge .registry import Models , ST_MTGP_trans
27-
28- try :
29- # For Ax >= 0.5.0
30- from ax .modelbridge .transforms .derelativize import Derelativize
31- from ax .modelbridge .transforms .convert_metric_names import (
32- ConvertMetricNames ,
33- )
34- from ax .modelbridge .transforms .trial_as_task import TrialAsTask
35- from ax .modelbridge .transforms .stratified_standardize_y import (
36- StratifiedStandardizeY ,
37- )
38- from ax .modelbridge .transforms .task_encode import TaskChoiceToIntTaskChoice
39- from ax .modelbridge .registry import MBM_X_trans
40-
41- MT_MTGP_trans = MBM_X_trans + [
42- Derelativize ,
43- ConvertMetricNames ,
44- TrialAsTask ,
45- StratifiedStandardizeY ,
46- TaskChoiceToIntTaskChoice ,
47- ]
26+ from ax .adapter .registry import Generators , MBM_X_trans
27+ from ax .adapter .transforms .derelativize import Derelativize
28+ from ax .adapter .transforms .metrics_as_task import MetricsAsTask
29+ from ax .adapter .transforms .trial_as_task import TrialAsTask
30+ from ax .adapter .transforms .stratified_standardize_y import (
31+ StratifiedStandardizeY ,
32+ )
33+ from ax .adapter .transforms .task_encode import TaskChoiceToIntTaskChoice
34+ from ax .adapter .registry import ST_MTGP_trans
4835
49- except ImportError :
50- # For Ax < 0.5.0
51- from ax .modelbridge .registry import MT_MTGP_trans
36+ MT_MTGP_trans = MBM_X_trans + [
37+ Derelativize ,
38+ MetricsAsTask ,
39+ TrialAsTask ,
40+ StratifiedStandardizeY ,
41+ TaskChoiceToIntTaskChoice ,
42+ ]
5243
5344from ax .core .experiment import Experiment
5445from ax .core .data import Data
55- from ax .modelbridge .transforms .convert_metric_names import (
56- tconfig_from_mt_experiment ,
57- )
5846
5947from optimas .generators .ax .base import AxGenerator
6048from optimas .core import (
@@ -81,7 +69,7 @@ def get_MTGP(
8169 trial_index : Optional [int ] = None ,
8270 device : torch .device = torch .device ("cpu" ),
8371 dtype : torch .dtype = torch .double ,
84- ) -> TorchModelBridge :
72+ ) -> TorchAdapter :
8573 """Instantiate a Multi-task Gaussian Process (MTGP) model.
8674
8775 Points are generated with EI (Expected Improvement).
@@ -94,11 +82,21 @@ def get_MTGP(
9482 t .index : t .trial_type for t in experiment .trials .values ()
9583 }
9684 transforms = MT_MTGP_trans
85+
86+ # Build MetricsAsTask config manually (replaces tconfig_from_mt_experiment)
87+ canonical = experiment ._metric_to_canonical_name
88+ metric_task_map = {}
89+ for metric_name , canonical_name in canonical .items ():
90+ if metric_name != canonical_name :
91+ if canonical_name not in metric_task_map :
92+ metric_task_map [canonical_name ] = []
93+ metric_task_map [canonical_name ].append (metric_name )
94+
9795 transform_configs = {
9896 "TrialAsTask" : {
9997 "trial_level_map" : {"trial_type" : trial_index_to_type }
10098 },
101- "ConvertMetricNames " : tconfig_from_mt_experiment ( experiment ) ,
99+ "MetricsAsTask " : { "metric_task_map" : metric_task_map } ,
102100 }
103101 else :
104102 # Set transforms for a Single-type MTGP model.
@@ -125,7 +123,7 @@ def get_MTGP(
125123 )
126124
127125 return assert_is_instance (
128- Models .ST_MTGP (
126+ Generators .ST_MTGP (
129127 experiment = experiment ,
130128 search_space = search_space or experiment .search_space ,
131129 data = data ,
@@ -135,7 +133,7 @@ def get_MTGP(
135133 torch_device = device ,
136134 status_quo_features = status_quo_features ,
137135 ),
138- TorchModelBridge ,
136+ TorchAdapter ,
139137 )
140138
141139
@@ -354,7 +352,6 @@ def _incorporate_external_data(self, trials: List[Trial]) -> None:
354352 arms .append (
355353 Arm (parameters = params , name = param_to_name [arm .signature ])
356354 )
357- # self._next_id += 1
358355
359356 # Create new batch trial.
360357 gr = GeneratorRun (arms = arms , weights = [1.0 ] * len (arms ))
@@ -597,7 +594,7 @@ def _save_model_to_file(self) -> None:
597594
598595
599596def max_utility_from_GP (
600- n : int , m : TorchModelBridge , gr : GeneratorRun , hifi_task : str
597+ n : int , m : TorchAdapter , gr : GeneratorRun , hifi_task : str
601598) -> GeneratorRun :
602599 """Select the max utility points according to the MTGP predictions.
603600
0 commit comments