Skip to content

Commit d113b7c

Browse files
authored
Merge pull request #297 from optimas-org/ax1.0
Ax1.0
2 parents 4b201cc + b13796d commit d113b7c

File tree

14 files changed

+148
-138
lines changed

14 files changed

+148
-138
lines changed

.github/workflows/unix-noax.yml

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ jobs:
1111
runs-on: ubuntu-latest
1212
strategy:
1313
matrix:
14-
python-version: ['3.10', 3.11, 3.12]
14+
python-version: [3.11, 3.12]
1515

1616
steps:
1717
- uses: actions/checkout@v4
@@ -28,12 +28,11 @@ jobs:
2828
- shell: bash -l {0}
2929
name: Install dependencies
3030
run: |
31-
conda install -c conda-forge pytorch-cpu
32-
conda install -c pytorch numpy pandas
3331
conda install -c conda-forge mpi4py mpich
32+
pip install torch --index-url https://download.pytorch.org/whl/cpu
3433
pip install .[test]
35-
pip install git+https://github.com/campa-consortium/gest-api.git
36-
pip install git+https://github.com/xopt-org/xopt.git
34+
pip install gest-api
35+
pip install xopt
3736
pip uninstall --yes ax-platform # Run without Ax
3837
- shell: bash -l {0}
3938
name: Run unit tests without Ax

.github/workflows/unix-openmpi.yml

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ jobs:
1111
runs-on: ubuntu-latest
1212
strategy:
1313
matrix:
14-
python-version: ['3.10', 3.11, 3.12]
14+
python-version: [3.11, 3.12]
1515

1616
steps:
1717
- uses: actions/checkout@v4
@@ -28,12 +28,11 @@ jobs:
2828
- shell: bash -l {0}
2929
name: Install dependencies
3030
run: |
31-
conda install -c conda-forge "numpy<2.4" "pandas<3"
32-
conda install -c conda-forge pytorch-cpu
3331
conda install -c conda-forge mpi4py openmpi=5.*
32+
pip install torch --index-url https://download.pytorch.org/whl/cpu
3433
pip install .[test]
35-
pip install git+https://github.com/campa-consortium/gest-api.git
36-
pip install --upgrade-strategy=only-if-needed git+https://github.com/xopt-org/xopt.git
34+
pip install gest-api
35+
pip install xopt
3736
- shell: bash -l {0}
3837
name: Run unit tests with openMPI
3938
run: |

.github/workflows/unix.yml

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ jobs:
1111
runs-on: ubuntu-latest
1212
strategy:
1313
matrix:
14-
python-version: ['3.10', 3.11, 3.12]
14+
python-version: [3.11, 3.12]
1515

1616
steps:
1717
- uses: actions/checkout@v4
@@ -28,12 +28,11 @@ jobs:
2828
- shell: bash -l {0}
2929
name: Install dependencies
3030
run: |
31-
conda install -c conda-forge "numpy<2.4" "pandas<3"
32-
conda install -c conda-forge pytorch-cpu
3331
conda install -c conda-forge mpi4py mpich
32+
pip install torch --index-url https://download.pytorch.org/whl/cpu
3433
pip install .[test]
35-
pip install git+https://github.com/campa-consortium/gest-api.git
36-
pip install --upgrade-strategy=only-if-needed git+https://github.com/xopt-org/xopt.git
34+
pip install gest-api
35+
pip install xopt
3736
- shell: bash -l {0}
3837
name: Run unit tests with MPICH
3938
run: |

doc/environment.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@ dependencies:
55
- pip
66
- pip:
77
- -e ..
8-
- ax-platform >= 0.5.0, < 1.0.0
8+
- ax-platform >=1.0.0
99
- autodoc_pydantic >= 2.0.1
1010
- ipykernel
1111
- matplotlib
1212
- nbsphinx
1313
- numpydoc
14-
- git+https://github.com/campa-consortium/gest-api.git
14+
- gest-api
1515
- pydata-sphinx-theme
1616
- sphinx-copybutton
1717
- sphinx-design

optimas/generators/ax/developer/ax_metric.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Contains the definition of the Ax metric used for multitask optimization."""
22

33
import pandas as pd
4-
from ax import Metric
4+
from ax.core.metric import Metric
55
from ax.core.batch_trial import BatchTrial
66
from ax.core.data import Data
77
from ax.utils.common.result import Ok

optimas/generators/ax/developer/multitask.py

Lines changed: 33 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -16,45 +16,33 @@
1616
from ax.core.optimization_config import OptimizationConfig
1717
from ax.core.objective import Objective as AxObjective
1818
from ax.runners import SyntheticRunner
19-
from ax.modelbridge.factory import get_sobol
20-
from ax.modelbridge.torch import TorchModelBridge
19+
from ax.adapter.factory import get_sobol
20+
from ax.adapter.torch import TorchAdapter
2121
from ax.core.observation import ObservationFeatures
2222
from ax.core.generator_run import GeneratorRun
2323
from ax.storage.json_store.save import save_experiment
2424
from ax.storage.metric_registry import register_metrics
2525

26-
from ax.modelbridge.registry import Models, ST_MTGP_trans
27-
28-
try:
29-
# For Ax >= 0.5.0
30-
from ax.modelbridge.transforms.derelativize import Derelativize
31-
from ax.modelbridge.transforms.convert_metric_names import (
32-
ConvertMetricNames,
33-
)
34-
from ax.modelbridge.transforms.trial_as_task import TrialAsTask
35-
from ax.modelbridge.transforms.stratified_standardize_y import (
36-
StratifiedStandardizeY,
37-
)
38-
from ax.modelbridge.transforms.task_encode import TaskChoiceToIntTaskChoice
39-
from ax.modelbridge.registry import MBM_X_trans
40-
41-
MT_MTGP_trans = MBM_X_trans + [
42-
Derelativize,
43-
ConvertMetricNames,
44-
TrialAsTask,
45-
StratifiedStandardizeY,
46-
TaskChoiceToIntTaskChoice,
47-
]
26+
from ax.adapter.registry import Generators, MBM_X_trans
27+
from ax.adapter.transforms.derelativize import Derelativize
28+
from ax.adapter.transforms.metrics_as_task import MetricsAsTask
29+
from ax.adapter.transforms.trial_as_task import TrialAsTask
30+
from ax.adapter.transforms.stratified_standardize_y import (
31+
StratifiedStandardizeY,
32+
)
33+
from ax.adapter.transforms.task_encode import TaskChoiceToIntTaskChoice
34+
from ax.adapter.registry import ST_MTGP_trans
4835

49-
except ImportError:
50-
# For Ax < 0.5.0
51-
from ax.modelbridge.registry import MT_MTGP_trans
36+
MT_MTGP_trans = MBM_X_trans + [
37+
Derelativize,
38+
MetricsAsTask,
39+
TrialAsTask,
40+
StratifiedStandardizeY,
41+
TaskChoiceToIntTaskChoice,
42+
]
5243

5344
from ax.core.experiment import Experiment
5445
from ax.core.data import Data
55-
from ax.modelbridge.transforms.convert_metric_names import (
56-
tconfig_from_mt_experiment,
57-
)
5846

5947
from optimas.generators.ax.base import AxGenerator
6048
from optimas.core import (
@@ -81,7 +69,7 @@ def get_MTGP(
8169
trial_index: Optional[int] = None,
8270
device: torch.device = torch.device("cpu"),
8371
dtype: torch.dtype = torch.double,
84-
) -> TorchModelBridge:
72+
) -> TorchAdapter:
8573
"""Instantiate a Multi-task Gaussian Process (MTGP) model.
8674
8775
Points are generated with EI (Expected Improvement).
@@ -94,11 +82,21 @@ def get_MTGP(
9482
t.index: t.trial_type for t in experiment.trials.values()
9583
}
9684
transforms = MT_MTGP_trans
85+
86+
# Build MetricsAsTask config manually (replaces tconfig_from_mt_experiment)
87+
canonical = experiment._metric_to_canonical_name
88+
metric_task_map = {}
89+
for metric_name, canonical_name in canonical.items():
90+
if metric_name != canonical_name:
91+
if canonical_name not in metric_task_map:
92+
metric_task_map[canonical_name] = []
93+
metric_task_map[canonical_name].append(metric_name)
94+
9795
transform_configs = {
9896
"TrialAsTask": {
9997
"trial_level_map": {"trial_type": trial_index_to_type}
10098
},
101-
"ConvertMetricNames": tconfig_from_mt_experiment(experiment),
99+
"MetricsAsTask": {"metric_task_map": metric_task_map},
102100
}
103101
else:
104102
# Set transforms for a Single-type MTGP model.
@@ -125,7 +123,7 @@ def get_MTGP(
125123
)
126124

127125
return assert_is_instance(
128-
Models.ST_MTGP(
126+
Generators.ST_MTGP(
129127
experiment=experiment,
130128
search_space=search_space or experiment.search_space,
131129
data=data,
@@ -135,7 +133,7 @@ def get_MTGP(
135133
torch_device=device,
136134
status_quo_features=status_quo_features,
137135
),
138-
TorchModelBridge,
136+
TorchAdapter,
139137
)
140138

141139

@@ -354,7 +352,6 @@ def _incorporate_external_data(self, trials: List[Trial]) -> None:
354352
arms.append(
355353
Arm(parameters=params, name=param_to_name[arm.signature])
356354
)
357-
# self._next_id += 1
358355

359356
# Create new batch trial.
360357
gr = GeneratorRun(arms=arms, weights=[1.0] * len(arms))
@@ -597,7 +594,7 @@ def _save_model_to_file(self) -> None:
597594

598595

599596
def max_utility_from_GP(
600-
n: int, m: TorchModelBridge, gr: GeneratorRun, hifi_task: str
597+
n: int, m: TorchAdapter, gr: GeneratorRun, hifi_task: str
601598
) -> GeneratorRun:
602599
"""Select the max utility points according to the MTGP predictions.
603600

optimas/generators/ax/import_error_dummy_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,5 @@ def __init__(self, *args, **kwargs) -> None:
1212
raise RuntimeError(
1313
"You need to install ax-platform, in order "
1414
"to use Ax-based generators in optimas.\n"
15-
"e.g. with `pip install 'ax-platform<1.0.0'`"
15+
"e.g. with `pip install 'ax-platform'`"
1616
)

optimas/generators/ax/service/ax_client.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,9 @@ def _create_ax_client(self) -> AxClient:
134134

135135
def _use_cuda(self, ax_client: AxClient):
136136
"""Determine whether the AxClient uses CUDA."""
137-
for step in ax_client.generation_strategy._steps:
138-
if "torch_device" in step.model_kwargs:
139-
if step.model_kwargs["torch_device"] == "cuda":
140-
return True
137+
for node in ax_client.generation_strategy._nodes:
138+
for gs in node.generator_specs:
139+
if "torch_device" in gs.generator_kwargs:
140+
if gs.generator_kwargs["torch_device"] == "cuda":
141+
return True
141142
return False

optimas/generators/ax/service/base.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,11 @@
1010
ObjectiveProperties,
1111
FixedFeatures,
1212
)
13-
from ax.modelbridge.registry import Models
14-
from ax.modelbridge.generation_strategy import (
15-
GenerationStep,
16-
GenerationStrategy,
17-
)
18-
from ax.modelbridge.transition_criterion import MaxTrials, MinTrials
19-
from ax import Arm
13+
from ax.adapter.registry import Generators
14+
from ax.generation_strategy.generation_node import GenerationStep
15+
from ax.generation_strategy.generation_strategy import GenerationStrategy
16+
from ax.generation_strategy.transition_criterion import MinTrials
17+
from ax.core.arm import Arm
2018

2119
from optimas.core import (
2220
Trial,
@@ -262,15 +260,19 @@ def _insert_unknown_trial(self, trial: Trial) -> None:
262260
# initialization trials, but only if they have not failed.
263261
if trial.completed and not self._enforce_n_init:
264262
generation_strategy = self._ax_client.generation_strategy
265-
current_step = generation_strategy.current_step
263+
current_node = generation_strategy.current_node
266264
# Reduce only if there are still Sobol trials left.
267-
if current_step.model == Models.SOBOL:
268-
for tc in current_step.transition_criteria:
269-
# Looping over all criterial makes sure we reduce
265+
is_sobol = any(
266+
gs.generator_enum == Generators.SOBOL
267+
for gs in current_node.generator_specs
268+
)
269+
if is_sobol:
270+
for tc in current_node.transition_criteria:
271+
# Looping over all criteria makes sure we reduce
270272
# the transition thresholds due to `_n_init`
271273
# (i.e., max trials) and `min_trials_observed=1` (
272274
# i.e., min trials).
273-
if isinstance(tc, (MinTrials, MaxTrials)):
275+
if isinstance(tc, MinTrials):
274276
tc.threshold -= 1
275277
generation_strategy._maybe_transition_to_next_node()
276278
return ax_trial
@@ -296,13 +298,11 @@ def _complete_trial(self, ax_trial_index: int, trial: Trial) -> None:
296298
def _create_ax_client(self) -> AxClient:
297299
"""Create Ax client."""
298300
bo_model_kwargs = {
299-
"torch_dtype": torch.double,
300301
"torch_device": torch.device(self.torch_device),
301-
"fit_out_of_design": self._fit_out_of_design,
302302
}
303303
ax_client = AxClient(
304304
generation_strategy=GenerationStrategy(
305-
self._create_generation_steps(bo_model_kwargs)
305+
nodes=self._create_generation_steps(bo_model_kwargs)
306306
),
307307
verbose_logging=False,
308308
)
@@ -339,7 +339,7 @@ def _create_sobol_step(self) -> GenerationStep:
339339
# This also allows the generator to work well when
340340
# `sim_workers` > `n_init`.
341341
return GenerationStep(
342-
model=Models.SOBOL,
342+
generator=Generators.SOBOL,
343343
num_trials=self._n_init,
344344
min_trials_observed=1,
345345
enforce_num_trials=False,
@@ -366,8 +366,8 @@ def _update_parameter(self, parameter):
366366
# Delete the fitted model from the generation strategy, otherwise
367367
# the parameter won't be updated.
368368
generation_strategy = self._ax_client.generation_strategy
369-
if generation_strategy._model is not None:
370-
del generation_strategy._curr.model_spec._fitted_model
369+
if generation_strategy._curr.generator_spec._fitted_adapter is not None:
370+
generation_strategy._curr.generator_spec._fitted_adapter = None
371371
parameters = self._create_ax_parameters()
372372
new_search_space = InstantiationBase.make_search_space(parameters, None)
373373
self._ax_client.experiment.search_space.update_parameter(

0 commit comments

Comments
 (0)