Skip to content
Merged

Ax1.0 #297

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions .github/workflows/unix-noax.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.10', 3.11, 3.12]
python-version: [3.11, 3.12]

steps:
- uses: actions/checkout@v4
Expand All @@ -28,12 +28,11 @@ jobs:
- shell: bash -l {0}
name: Install dependencies
run: |
conda install -c conda-forge pytorch-cpu
conda install -c pytorch numpy pandas
conda install -c conda-forge mpi4py mpich
pip install torch --index-url https://download.pytorch.org/whl/cpu
pip install .[test]
pip install git+https://github.com/campa-consortium/gest-api.git
pip install git+https://github.com/xopt-org/xopt.git
pip install gest-api
pip install xopt
pip uninstall --yes ax-platform # Run without Ax
- shell: bash -l {0}
name: Run unit tests without Ax
Expand Down
9 changes: 4 additions & 5 deletions .github/workflows/unix-openmpi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.10', 3.11, 3.12]
python-version: [3.11, 3.12]

steps:
- uses: actions/checkout@v4
Expand All @@ -28,12 +28,11 @@ jobs:
- shell: bash -l {0}
name: Install dependencies
run: |
conda install -c conda-forge "numpy<2.4" "pandas<3"
conda install -c conda-forge pytorch-cpu
conda install -c conda-forge mpi4py openmpi=5.*
pip install torch --index-url https://download.pytorch.org/whl/cpu
pip install .[test]
pip install git+https://github.com/campa-consortium/gest-api.git
pip install --upgrade-strategy=only-if-needed git+https://github.com/xopt-org/xopt.git
pip install gest-api
pip install xopt
- shell: bash -l {0}
name: Run unit tests with openMPI
run: |
Expand Down
9 changes: 4 additions & 5 deletions .github/workflows/unix.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.10', 3.11, 3.12]
python-version: [3.11, 3.12]

steps:
- uses: actions/checkout@v4
Expand All @@ -28,12 +28,11 @@ jobs:
- shell: bash -l {0}
name: Install dependencies
run: |
conda install -c conda-forge "numpy<2.4" "pandas<3"
conda install -c conda-forge pytorch-cpu
conda install -c conda-forge mpi4py mpich
pip install torch --index-url https://download.pytorch.org/whl/cpu
pip install .[test]
pip install git+https://github.com/campa-consortium/gest-api.git
pip install --upgrade-strategy=only-if-needed git+https://github.com/xopt-org/xopt.git
pip install gest-api
pip install xopt
- shell: bash -l {0}
name: Run unit tests with MPICH
run: |
Expand Down
4 changes: 2 additions & 2 deletions doc/environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@ dependencies:
- pip
- pip:
- -e ..
- ax-platform >= 0.5.0, < 1.0.0
- ax-platform >=1.0.0
- autodoc_pydantic >= 2.0.1
- ipykernel
- matplotlib
- nbsphinx
- numpydoc
- git+https://github.com/campa-consortium/gest-api.git
- gest-api
- pydata-sphinx-theme
- sphinx-copybutton
- sphinx-design
Expand Down
2 changes: 1 addition & 1 deletion optimas/generators/ax/developer/ax_metric.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Contains the definition of the Ax metric used for multitask optimization."""

import pandas as pd
from ax import Metric
from ax.core.metric import Metric
from ax.core.batch_trial import BatchTrial
from ax.core.data import Data
from ax.utils.common.result import Ok
Expand Down
69 changes: 33 additions & 36 deletions optimas/generators/ax/developer/multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,45 +16,33 @@
from ax.core.optimization_config import OptimizationConfig
from ax.core.objective import Objective as AxObjective
from ax.runners import SyntheticRunner
from ax.modelbridge.factory import get_sobol
from ax.modelbridge.torch import TorchModelBridge
from ax.adapter.factory import get_sobol
from ax.adapter.torch import TorchAdapter
from ax.core.observation import ObservationFeatures
from ax.core.generator_run import GeneratorRun
from ax.storage.json_store.save import save_experiment
from ax.storage.metric_registry import register_metrics

from ax.modelbridge.registry import Models, ST_MTGP_trans

try:
# For Ax >= 0.5.0
from ax.modelbridge.transforms.derelativize import Derelativize
from ax.modelbridge.transforms.convert_metric_names import (
ConvertMetricNames,
)
from ax.modelbridge.transforms.trial_as_task import TrialAsTask
from ax.modelbridge.transforms.stratified_standardize_y import (
StratifiedStandardizeY,
)
from ax.modelbridge.transforms.task_encode import TaskChoiceToIntTaskChoice
from ax.modelbridge.registry import MBM_X_trans

MT_MTGP_trans = MBM_X_trans + [
Derelativize,
ConvertMetricNames,
TrialAsTask,
StratifiedStandardizeY,
TaskChoiceToIntTaskChoice,
]
from ax.adapter.registry import Generators, MBM_X_trans
from ax.adapter.transforms.derelativize import Derelativize
from ax.adapter.transforms.metrics_as_task import MetricsAsTask
from ax.adapter.transforms.trial_as_task import TrialAsTask
from ax.adapter.transforms.stratified_standardize_y import (
StratifiedStandardizeY,
)
from ax.adapter.transforms.task_encode import TaskChoiceToIntTaskChoice
from ax.adapter.registry import ST_MTGP_trans

except ImportError:
# For Ax < 0.5.0
from ax.modelbridge.registry import MT_MTGP_trans
MT_MTGP_trans = MBM_X_trans + [
Derelativize,
MetricsAsTask,
TrialAsTask,
StratifiedStandardizeY,
TaskChoiceToIntTaskChoice,
]

from ax.core.experiment import Experiment
from ax.core.data import Data
from ax.modelbridge.transforms.convert_metric_names import (
tconfig_from_mt_experiment,
)

from optimas.generators.ax.base import AxGenerator
from optimas.core import (
Expand All @@ -81,7 +69,7 @@ def get_MTGP(
trial_index: Optional[int] = None,
device: torch.device = torch.device("cpu"),
dtype: torch.dtype = torch.double,
) -> TorchModelBridge:
) -> TorchAdapter:
"""Instantiate a Multi-task Gaussian Process (MTGP) model.

Points are generated with EI (Expected Improvement).
Expand All @@ -94,11 +82,21 @@ def get_MTGP(
t.index: t.trial_type for t in experiment.trials.values()
}
transforms = MT_MTGP_trans

# Build MetricsAsTask config manually (replaces tconfig_from_mt_experiment)
canonical = experiment._metric_to_canonical_name
metric_task_map = {}
for metric_name, canonical_name in canonical.items():
if metric_name != canonical_name:
if canonical_name not in metric_task_map:
metric_task_map[canonical_name] = []
metric_task_map[canonical_name].append(metric_name)

transform_configs = {
"TrialAsTask": {
"trial_level_map": {"trial_type": trial_index_to_type}
},
"ConvertMetricNames": tconfig_from_mt_experiment(experiment),
"MetricsAsTask": {"metric_task_map": metric_task_map},
}
else:
# Set transforms for a Single-type MTGP model.
Expand All @@ -125,7 +123,7 @@ def get_MTGP(
)

return assert_is_instance(
Models.ST_MTGP(
Generators.ST_MTGP(
experiment=experiment,
search_space=search_space or experiment.search_space,
data=data,
Expand All @@ -135,7 +133,7 @@ def get_MTGP(
torch_device=device,
status_quo_features=status_quo_features,
),
TorchModelBridge,
TorchAdapter,
)


Expand Down Expand Up @@ -354,7 +352,6 @@ def _incorporate_external_data(self, trials: List[Trial]) -> None:
arms.append(
Arm(parameters=params, name=param_to_name[arm.signature])
)
# self._next_id += 1

# Create new batch trial.
gr = GeneratorRun(arms=arms, weights=[1.0] * len(arms))
Expand Down Expand Up @@ -597,7 +594,7 @@ def _save_model_to_file(self) -> None:


def max_utility_from_GP(
n: int, m: TorchModelBridge, gr: GeneratorRun, hifi_task: str
n: int, m: TorchAdapter, gr: GeneratorRun, hifi_task: str
) -> GeneratorRun:
"""Select the max utility points according to the MTGP predictions.

Expand Down
2 changes: 1 addition & 1 deletion optimas/generators/ax/import_error_dummy_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ def __init__(self, *args, **kwargs) -> None:
raise RuntimeError(
"You need to install ax-platform, in order "
"to use Ax-based generators in optimas.\n"
"e.g. with `pip install 'ax-platform<1.0.0'`"
"e.g. with `pip install 'ax-platform'`"
)
9 changes: 5 additions & 4 deletions optimas/generators/ax/service/ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,9 @@ def _create_ax_client(self) -> AxClient:

def _use_cuda(self, ax_client: AxClient):
"""Determine whether the AxClient uses CUDA."""
for step in ax_client.generation_strategy._steps:
if "torch_device" in step.model_kwargs:
if step.model_kwargs["torch_device"] == "cuda":
return True
for node in ax_client.generation_strategy._nodes:
for gs in node.generator_specs:
if "torch_device" in gs.generator_kwargs:
if gs.generator_kwargs["torch_device"] == "cuda":
return True
return False
36 changes: 18 additions & 18 deletions optimas/generators/ax/service/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,11 @@
ObjectiveProperties,
FixedFeatures,
)
from ax.modelbridge.registry import Models
from ax.modelbridge.generation_strategy import (
GenerationStep,
GenerationStrategy,
)
from ax.modelbridge.transition_criterion import MaxTrials, MinTrials
from ax import Arm
from ax.adapter.registry import Generators
from ax.generation_strategy.generation_node import GenerationStep
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.generation_strategy.transition_criterion import MinTrials
from ax.core.arm import Arm

from optimas.core import (
Trial,
Expand Down Expand Up @@ -262,15 +260,19 @@ def _insert_unknown_trial(self, trial: Trial) -> None:
# initialization trials, but only if they have not failed.
if trial.completed and not self._enforce_n_init:
generation_strategy = self._ax_client.generation_strategy
current_step = generation_strategy.current_step
current_node = generation_strategy.current_node
# Reduce only if there are still Sobol trials left.
if current_step.model == Models.SOBOL:
for tc in current_step.transition_criteria:
# Looping over all criterial makes sure we reduce
is_sobol = any(
gs.generator_enum == Generators.SOBOL
for gs in current_node.generator_specs
)
if is_sobol:
for tc in current_node.transition_criteria:
# Looping over all criteria makes sure we reduce
# the transition thresholds due to `_n_init`
# (i.e., max trials) and `min_trials_observed=1` (
# i.e., min trials).
if isinstance(tc, (MinTrials, MaxTrials)):
if isinstance(tc, MinTrials):
tc.threshold -= 1
generation_strategy._maybe_transition_to_next_node()
return ax_trial
Expand All @@ -296,13 +298,11 @@ def _complete_trial(self, ax_trial_index: int, trial: Trial) -> None:
def _create_ax_client(self) -> AxClient:
"""Create Ax client."""
bo_model_kwargs = {
"torch_dtype": torch.double,
"torch_device": torch.device(self.torch_device),
"fit_out_of_design": self._fit_out_of_design,
}
ax_client = AxClient(
generation_strategy=GenerationStrategy(
self._create_generation_steps(bo_model_kwargs)
nodes=self._create_generation_steps(bo_model_kwargs)
),
verbose_logging=False,
)
Expand Down Expand Up @@ -339,7 +339,7 @@ def _create_sobol_step(self) -> GenerationStep:
# This also allows the generator to work well when
# `sim_workers` > `n_init`.
return GenerationStep(
model=Models.SOBOL,
generator=Generators.SOBOL,
num_trials=self._n_init,
min_trials_observed=1,
enforce_num_trials=False,
Expand All @@ -366,8 +366,8 @@ def _update_parameter(self, parameter):
# Delete the fitted model from the generation strategy, otherwise
# the parameter won't be updated.
generation_strategy = self._ax_client.generation_strategy
if generation_strategy._model is not None:
del generation_strategy._curr.model_spec._fitted_model
if generation_strategy._curr.generator_spec._fitted_adapter is not None:
generation_strategy._curr.generator_spec._fitted_adapter = None
parameters = self._create_ax_parameters()
new_search_space = InstantiationBase.make_search_space(parameters, None)
self._ax_client.experiment.search_space.update_parameter(
Expand Down
Loading
Loading