Skip to content

Commit 9e6c761

Browse files
committed
fix: return hyper->value dict from hyperparam search
1 parent 2e66794 commit 9e6c761

File tree

3 files changed

+12
-8
lines changed

3 files changed

+12
-8
lines changed

rlevaluation/hypers/api.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,17 @@ def select_best_hypers(
5757

5858
rng = np.random.default_rng(0)
5959
out = bootstrap_hyper_selection(rng, score_per_seed, statistic.value, prefer.value, threshold)
60+
config = {
61+
col: df[col][out.best_idx] for col in cols
62+
}
6063

6164
return HyperSelectionResult(
62-
best_configuration=df.row(out.best_idx),
65+
best_configuration=config,
6366
best_score=out.best_score,
6467

6568
uncertainty_set_configurations=[
66-
df.row(idx) for idx in out.uncertainty_set_idxs
69+
{col: df[col][int(idx)] for col in cols}
70+
for idx in out.uncertainty_set_idxs
6771
],
6872
uncertainty_set_probs=out.uncertainty_set_probs,
6973
sample_stat=out.sample_stat,
@@ -72,10 +76,10 @@ def select_best_hypers(
7276
)
7377

7478
class HyperSelectionResult(NamedTuple):
75-
best_configuration: tuple[Any, ...]
79+
best_configuration: dict[str, Any]
7680
best_score: float
7781

78-
uncertainty_set_configurations: list[tuple[Any, ...]]
82+
uncertainty_set_configurations: list[dict[str, Any]]
7983
uncertainty_set_probs: np.ndarray
8084
sample_stat: float
8185
ci: tuple[float, float]

rlevaluation/hypers/reporting.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,14 @@ def pretty_print_hyper_selection_result(result: HyperSelectionResult, d: DataDef
3535
if len(result.uncertainty_set_probs) > 1:
3636
out += 'Possible best configurations:\n'
3737
out += '-----------------------------\n'
38-
for i, hyper in enumerate(cols):
39-
hyper_val = result.uncertainty_set_configurations[0][i]
38+
for hyper in cols:
39+
hyper_val = result.uncertainty_set_configurations[0][hyper]
4040
if isinstance(hyper_val, float) and np.isnan(hyper_val): continue
4141
ws = 4 + col_len - len(hyper)
4242
out += f'{hyper}:' + ' ' * ws
4343

4444
for config in result.uncertainty_set_configurations:
45-
out += f'{config[i]} '
45+
out += f'{config[hyper]} '
4646

4747
out += '\n'
4848

tests/test_hypers.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,4 @@ def test_select_best_hypers():
1313
d = data_definition(hyper_cols=['alpha'])
1414

1515
best = select_best_hypers(test_data, 'result', Preference.high, data_definition=d)
16-
assert best.best_configuration[0] == 0.01
16+
assert best.best_configuration['alpha'] == 0.01

0 commit comments

Comments
 (0)