-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathexample_hpolib.py
66 lines (53 loc) · 2.04 KB
/
example_hpolib.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""
Please install the dependencies first:
```shell
$ pip install optuna==4.0.0b0 optunahub
```
"""
from argparse import ArgumentParser
import pickle
import numpy as np
import optuna
import optunahub
# Tabular HPOLib datasets; ``--dataset_id`` is an index into this list.
dataset_names = ["naval_propulsion", "parkinsons_telemonitoring", "protein_structure", "slice_localization"]

parser = ArgumentParser(
    description="Optimize a tabular HPOLib benchmark with OptunaHub's customizable TPE sampler."
)
parser.add_argument(
    "--dataset_id",
    # Derived from ``dataset_names`` so the valid ids can never drift from the list.
    choices=list(range(len(dataset_names))),
    type=int,
    required=True,
    help="index of the dataset to optimize: "
    + ", ".join(f"{i}={name}" for i, name in enumerate(dataset_names)),
)
args = parser.parse_args()
dataset_id = args.dataset_id
dataset_name = dataset_names[dataset_id]

# Use a context manager so the file handle is closed once the table is loaded.
# NOTE: pickle can execute arbitrary code while loading -- only use trusted files.
with open(f"examples/{dataset_name}.pkl", mode="rb") as dataset_file:
    dataset = pickle.load(dataset_file)

seed = 0
# Module-level RNG; ``objective`` draws the evaluation-seed index of each trial from it.
rng = np.random.RandomState(seed)
def objective(trial: optuna.Trial) -> float:
    """Return the tabulated result for one sampled hyperparameter configuration.

    Every hyperparameter is sampled as an index into its discrete grid. The
    indices are joined, in the fixed order below, into a digit string that
    keys ``dataset``. An evaluation-seed index in ``range(4)`` is then drawn
    from the module-level ``rng`` to select the value that is returned.
    """
    # (name, grid size, sample with suggest_int?) -- order defines the key layout.
    # True  -> trial.suggest_int(name, low=0, high=size - 1)
    # False -> trial.suggest_categorical(name, list(range(size)))
    search_space = (
        ("activation_fn_1", 2, False),
        ("activation_fn_2", 2, False),
        ("batch_size", 4, True),
        ("dropout_1", 3, True),
        ("dropout_2", 3, True),
        ("init_lr", 6, True),
        ("lr_schedule", 2, False),
        ("n_units_1", 6, True),
        ("n_units_2", 6, True),
    )

    key_digits = []
    for param_name, grid_size, use_int in search_space:
        if use_int:
            choice = trial.suggest_int(param_name, low=0, high=grid_size - 1)
        else:
            choice = trial.suggest_categorical(param_name, list(range(grid_size)))
        key_digits.append(str(choice))
    table_key = "".join(key_digits)

    # Draw the seed index before indexing the table, matching the original order.
    seed_index = rng.randint(4)
    return dataset[table_key][seed_index]
module = optunahub.load_module(package="samplers/tpe_tutorial")
# NOTE: Please check https://hub.optuna.org/samplers/tpe_tutorial/ for the parameter descriptions.
tpe_config = {
"consider_prior": True,
"consider_magic_clip": True,
"multivariate": True,
"b_magic_exponent": 1.0,
"min_bandwidth_factor": 0.01,
"gamma_strategy": "linear",
"gamma_beta": 0.1,
"weight_strategy": "old-decay",
"bandwidth_strategy": "hyperopt",
"categorical_prior_weight": None,
}
sampler = module.CustomizableTPESampler(seed=seed)
study = optuna.create_study(sampler=sampler)
study.optimize(objective, n_trials=100)