1212"""
1313
1414import warnings
15- from dataclasses import replace
15+ from dataclasses import dataclass , replace
1616from typing import Literal
1717
1818import numpy as np
@@ -85,8 +85,8 @@ def run_multistart_optimization(
8585
8686 scheduled_steps = scheduled_steps [1 :]
8787
88- sorted_sample = exploration_res [ " sorted_sample" ]
89- sorted_values = exploration_res [ " sorted_values" ]
88+ sorted_sample = exploration_res . sorted_sample
89+ sorted_values = exploration_res . sorted_values
9090
9191 stopping_maxopt = options .stopping_maxopt
9292 if stopping_maxopt > len (sorted_sample ):
@@ -172,7 +172,7 @@ def single_optimization(x0, step_id):
172172 "start_parameters" : state ["start_history" ],
173173 "local_optima" : state ["result_history" ],
174174 "exploration_sample" : sorted_sample ,
175- "exploration_results" : exploration_res [ " sorted_values" ] ,
175+ "exploration_results" : sorted_values ,
176176 }
177177
178178 raw_res = state ["best_res" ]
@@ -288,12 +288,27 @@ def _draw_exploration_sample(
288288 return sample_scaled
289289
290290
291+ @dataclass (frozen = True )
292+ class _InternalExplorationResult :
293+ """Exploration result of the multistart optimization.
294+
295+ Attributes:
296+ sorted_values: List of sorted function values.
297+ sorted_sample: 2d numpy array where each row is the internal parameter
298+ vector corresponding to the sorted function values.
299+
300+ """
301+
302+ sorted_values : list [float ]
303+ sorted_sample : NDArray [np .float64 ]
304+
305+
291306def run_explorations (
292307 internal_problem : InternalOptimizationProblem ,
293308 sample : NDArray [np .float64 ],
294309 n_cores : int ,
295310 step_id : int ,
296- ) -> dict [ str , NDArray [ np . float64 ]] :
311+ ) -> _InternalExplorationResult :
297312 """Do the function evaluations for the exploration phase.
298313
299314 Args:
@@ -305,11 +320,11 @@ def run_explorations(
305320 step_id: The identifier of the exploration step.
306321
307322 Returns:
308- dict: A dictionary with the the following entries:
309- " sorted_values": 1d numpy array with sorted function values. Invalid
310- function values are excluded.
311- " sorted_sample" : 2d numpy array with corresponding internal parameter
312- vectors .
323+ A data object containing
324+ - sorted_values: List of sorted function values. Invalid function values are
325+ excluded.
326+ - sorted_sample: 2d numpy array where each row is the internal parameter
327+ vector corresponding to the sorted function values .
313328
314329 """
315330 internal_problem = internal_problem .with_step_id (step_id )
@@ -334,10 +349,10 @@ def run_explorations(
334349 # of the sign switch.
335350 sorting_indices = np .argsort (valid_values )
336351
337- out = {
338- " sorted_values" : valid_values [sorting_indices ],
339- " sorted_sample" : valid_sample [sorting_indices ],
340- }
352+ out = _InternalExplorationResult (
353+ sorted_values = valid_values [sorting_indices ]. tolist () ,
354+ sorted_sample = valid_sample [sorting_indices ],
355+ )
341356
342357 return out
343358
0 commit comments