Skip to content

Commit 054cae9

Browse files
authored
Merge: Integrate Hook Examples (#331)
Integrating the dev branch dealing with the new hooks mechanic and corresponding examples. Includes #322 #287
2 parents 063c5ca + c96f777 commit 054cae9

11 files changed

+12898
-27
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
4848
`DiscreteExcludeConstraint`, `DiscreteLinkedParametersConstraint` and
4949
`DiscreteNoLabelDuplicatesConstraint`
5050
- Discrete search space Cartesian product can be created lazily via Polars
51+
- Examples demonstrating the `register_hooks` utility: basic registration mechanism,
52+
monitoring the probability of improvement, and automatic campaign stopping
5153

5254
### Changed
5355
- Passing an `Objective` to `Campaign` is now optional

baybe/simulation/scenarios.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def simulate_scenarios(
3131
initial_data: list[pd.DataFrame] | None = None,
3232
groupby: list[str] | None = None,
3333
n_mc_iterations: int = 1,
34+
random_seed: int | None = None,
3435
impute_mode: Literal[
3536
"error", "worst", "best", "mean", "random", "ignore"
3637
] = "error",
@@ -52,6 +53,9 @@ def simulate_scenarios(
5253
A separate simulation will be conducted for each partition, with the search
5354
restricted to that partition.
5455
n_mc_iterations: The number of Monte Carlo simulations to be used.
56+
random_seed: An optional integer specifying the random seed for the first Monte
57+
Carlo run. Each subsequent runs will increase this value by 1. If omitted,
58+
the current random seed is used.
5559
impute_mode: See :func:`baybe.simulation.core.simulate_experiment`.
5660
noise_percent: See :func:`baybe.simulation.core.simulate_experiment`.
5761
@@ -64,7 +68,10 @@ def simulate_scenarios(
6468
function:
6569
6670
* ``Scenario``: Specifies the scenario identifier of the respective simulation.
67-
* ``Random_Seed``: Specifies the random seed used for the respective simulation.
71+
* ``Monte_Carlo_Run``: Specifies the Monte Carlo repetition of the
72+
respective simulation.
73+
* Optional, if ``random_seed`` is provided: A column ``Random_Seed`` that
74+
specifies the random seed used for the respective simulation.
6875
* Optional, if ``initial_data`` is provided: A column ``Initial_Data`` that
6976
specifies the index of the initial data set used for the respective
7077
simulation.
@@ -90,24 +97,26 @@ def make_xyzpy_callable(result_variable: str) -> Callable:
9097
@xyzpy.label(var_names=[result_variable])
9198
def simulate(
9299
Scenario: str,
93-
Random_Seed=None,
100+
Monte_Carlo_Run: int,
94101
Initial_Data=None,
95102
):
96103
"""Callable for xyzpy simulation."""
97104
data = None if initial_data is None else initial_data[Initial_Data]
98-
return SimulationResult(
99-
_simulate_groupby(
100-
scenarios[Scenario],
101-
lookup,
102-
batch_size=batch_size,
103-
n_doe_iterations=n_doe_iterations,
104-
initial_data=data,
105-
groupby=groupby,
106-
random_seed=Random_Seed,
107-
impute_mode=impute_mode,
108-
noise_percent=noise_percent,
109-
)
105+
seed = None if random_seed is None else Monte_Carlo_Run + _DEFAULT_SEED
106+
result = _simulate_groupby(
107+
scenarios[Scenario],
108+
lookup,
109+
batch_size=batch_size,
110+
n_doe_iterations=n_doe_iterations,
111+
initial_data=data,
112+
groupby=groupby,
113+
random_seed=seed,
114+
impute_mode=impute_mode,
115+
noise_percent=noise_percent,
110116
)
117+
if random_seed is not None:
118+
result["Random_Seed"] = seed
119+
return SimulationResult(result)
111120

112121
return simulate
113122

@@ -130,7 +139,7 @@ def unpack_simulation_results(array: DataArray) -> pd.DataFrame:
130139

131140
# Collect the settings to be simulated
132141
combos = {"Scenario": scenarios.keys()}
133-
combos["Random_Seed"] = range(_DEFAULT_SEED, _DEFAULT_SEED + n_mc_iterations)
142+
combos["Monte_Carlo_Run"] = range(n_mc_iterations)
134143
if initial_data:
135144
combos["Initial_Data"] = range(len(initial_data))
136145

baybe/utils/botorch_wrapper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""A wrapper class for synthetic BoTorch test functions."""
22

3+
import torch
34
from botorch.test_functions import SyntheticTestFunction
4-
from torch import Tensor
55

66

77
def botorch_function_wrapper(test_function: SyntheticTestFunction):
@@ -19,7 +19,7 @@ def botorch_function_wrapper(test_function: SyntheticTestFunction):
1919

2020
def wrapper(*x: float) -> float:
2121
# Cast the provided list of floats to a tensor.
22-
x_tensor = Tensor(x)
22+
x_tensor = torch.tensor(x)
2323
result = test_function.forward(x_tensor)
2424
# We do not need to return a tuple here.
2525
return float(result)

baybe/utils/plotting.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,11 @@
1010
import matplotlib.pyplot as plt
1111
from matplotlib.axes import Axes
1212
from matplotlib.figure import Figure
13+
from mpl_toolkits.mplot3d import Axes3D
1314

1415

1516
def create_example_plots(
16-
ax: Axes,
17+
ax: Axes | Axes3D,
1718
base_name: str,
1819
path: Path = Path("."),
1920
) -> None:
@@ -31,6 +32,9 @@ def create_example_plots(
3132
ax: The Axes object containing the figure that should be plotted.
3233
base_name: The base name that is used for naming the output files.
3334
path: Optional path to the directory in which the plots should be saved.
35+
36+
Returns:
37+
The ``Figure`` containing ``ax``
3438
"""
3539
# Check whether we immediately return due to just running a SMOKE_TEST
3640
if "SMOKE_TEST" in os.environ:
@@ -96,6 +100,9 @@ def create_example_plots(
96100
ax.xaxis.label.set_fontsize(fontsize)
97101
ax.yaxis.label.set_color(color)
98102
ax.yaxis.label.set_fontsize(fontsize)
103+
if isinstance(ax, Axes3D):
104+
ax.zaxis.label.set_color(color)
105+
ax.zaxis.label.set_fontsize(fontsize)
99106

100107
# Adjust the size of the ax
101108
# mypy thinks that ax.figure might become None, hence the explicit ignore
@@ -105,18 +112,22 @@ def create_example_plots(
105112
warnings.warn("Could not adjust size of plot due to it not being a Figure.")
106113

107114
# Adjust the labels
108-
for label in ax.get_xticklabels() + ax.get_yticklabels():
115+
ticklabels = ax.get_xticklabels() + ax.get_yticklabels()
116+
if isinstance(ax, Axes3D):
117+
ticklabels += ax.get_zticklabels()
118+
for label in ticklabels:
109119
label.set_color(color)
110120
label.set_fontsize(fontsize)
111121

112-
# Adjust the legend
122+
# Adjust the legend if it exists
113123
legend = ax.get_legend()
114-
legend.get_frame().set_alpha(framealpha)
115-
legend.get_title().set_color(color)
116-
legend.get_title().set_fontsize(fontsize)
117-
for text in legend.get_texts():
118-
text.set_fontsize(fontsize)
119-
text.set_color(color)
124+
if legend:
125+
legend.get_frame().set_alpha(framealpha)
126+
legend.get_title().set_color(color)
127+
legend.get_title().set_fontsize(fontsize)
128+
for text in legend.get_texts():
129+
text.set_fontsize(fontsize)
130+
text.set_color(color)
120131

121132
output_path = Path(path, f"{base_name}_{theme_name}.svg")
122133
# mypy thinks that ax.figure might become None, hence the explicit ignore
@@ -128,4 +139,4 @@ def create_example_plots(
128139
)
129140
else:
130141
warnings.warn("Plots could not be saved.")
131-
plt.close()
142+
plt.close()

0 commit comments

Comments
 (0)