Skip to content

Commit aca2a96

Browse files
authored
Merge branch 'main' into align-mypy-version
2 parents 0527932 + 4e58df8 commit aca2a96

5 files changed

Lines changed: 36 additions & 17 deletions

File tree

src/optimagic/optimization/multistart.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
"""
1313

1414
import warnings
15-
from dataclasses import replace
15+
from dataclasses import dataclass, replace
1616
from typing import Literal
1717

1818
import numpy as np
@@ -85,8 +85,8 @@ def run_multistart_optimization(
8585

8686
scheduled_steps = scheduled_steps[1:]
8787

88-
sorted_sample = exploration_res["sorted_sample"]
89-
sorted_values = exploration_res["sorted_values"]
88+
sorted_sample = exploration_res.sorted_sample
89+
sorted_values = exploration_res.sorted_values
9090

9191
stopping_maxopt = options.stopping_maxopt
9292
if stopping_maxopt > len(sorted_sample):
@@ -172,7 +172,7 @@ def single_optimization(x0, step_id):
172172
"start_parameters": state["start_history"],
173173
"local_optima": state["result_history"],
174174
"exploration_sample": sorted_sample,
175-
"exploration_results": exploration_res["sorted_values"],
175+
"exploration_results": sorted_values,
176176
}
177177

178178
raw_res = state["best_res"]
@@ -288,12 +288,27 @@ def _draw_exploration_sample(
288288
return sample_scaled
289289

290290

291+
@dataclass(frozen=True)
292+
class _InternalExplorationResult:
293+
"""Exploration result of the multistart optimization.
294+
295+
Attributes:
296+
sorted_values: List of sorted function values.
297+
sorted_sample: 2d numpy array where each row is the internal parameter
298+
vector corresponding to the sorted function values.
299+
300+
"""
301+
302+
sorted_values: list[float]
303+
sorted_sample: NDArray[np.float64]
304+
305+
291306
def run_explorations(
292307
internal_problem: InternalOptimizationProblem,
293308
sample: NDArray[np.float64],
294309
n_cores: int,
295310
step_id: int,
296-
) -> dict[str, NDArray[np.float64]]:
311+
) -> _InternalExplorationResult:
297312
"""Do the function evaluations for the exploration phase.
298313
299314
Args:
@@ -305,11 +320,11 @@ def run_explorations(
305320
step_id: The identifier of the exploration step.
306321
307322
Returns:
308-
dict: A dictionary with the the following entries:
309-
"sorted_values": 1d numpy array with sorted function values. Invalid
310-
function values are excluded.
311-
"sorted_sample": 2d numpy array with corresponding internal parameter
312-
vectors.
323+
A data object containing
324+
- sorted_values: List of sorted function values. Invalid function values are
325+
excluded.
326+
- sorted_sample: 2d numpy array where each row is the internal parameter
327+
vector corresponding to the sorted function values.
313328
314329
"""
315330
internal_problem = internal_problem.with_step_id(step_id)
@@ -334,10 +349,10 @@ def run_explorations(
334349
# of the sign switch.
335350
sorting_indices = np.argsort(valid_values)
336351

337-
out = {
338-
"sorted_values": valid_values[sorting_indices],
339-
"sorted_sample": valid_sample[sorting_indices],
340-
}
352+
out = _InternalExplorationResult(
353+
sorted_values=valid_values[sorting_indices].tolist(),
354+
sorted_sample=valid_sample[sorting_indices],
355+
)
341356

342357
return out
343358

src/optimagic/optimization/process_results.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,9 @@ def _process_multistart_info(
137137
solver_type: AggregationLevel,
138138
extra_fields: ExtraResultFields,
139139
) -> MultistartInfo:
140+
# The `info` dictionary is obtained from the `multistart_info` field of the
141+
# InternalOptimizeResult returned by `run_multistart_optimization` function.
142+
140143
starts = [converter.params_from_internal(x) for x in info["start_parameters"]]
141144

142145
optima = []

src/optimagic/visualization/history_plots.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ def _extract_plotting_data_from_results_object(
339339
if stack_multistart and local_histories is not None:
340340
stacked = _get_stacked_local_histories(local_histories, res.direction)
341341
if show_exploration:
342-
fun = res.multistart_info.exploration_results.tolist()[::-1] + stacked.fun
342+
fun = res.multistart_info.exploration_results[::-1] + stacked.fun
343343
params = res.multistart_info.exploration_sample[::-1] + stacked.params
344344

345345
stacked = History(

tests/optimagic/optimization/test_multistart.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ def with_step_id(self, step_id):
8484
exp_values = np.array([-9, -1])
8585
exp_sample = np.array([[4, 5], [0, 1]])
8686

87-
aaae(calculated["sorted_sample"], exp_sample)
88-
aaae(calculated["sorted_values"], exp_values)
87+
aaae(calculated.sorted_sample, exp_sample)
88+
aaae(calculated.sorted_values, exp_values)
8989

9090

9191
def test_get_batched_optimization_sample():

tests/optimagic/optimization/test_with_multistart.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def test_multistart_optimization_with_sum_of_squares_at_defaults(
8080
assert hasattr(res, "multistart_info")
8181
ms_info = res.multistart_info
8282
assert len(ms_info.exploration_sample) == 400
83+
assert isinstance(ms_info.exploration_results, list)
8384
assert len(ms_info.exploration_results) == 400
8485
assert all(isinstance(entry, float) for entry in ms_info.exploration_results)
8586
assert all(isinstance(entry, OptimizeResult) for entry in ms_info.local_optima)

0 commit comments

Comments
 (0)