5
5
from typing import Any , ClassVar
6
6
7
7
import pandas as pd
8
- from attr .converters import optional
9
8
from attrs import define , field
9
+ from attrs .converters import optional as optional_c
10
+ from attrs .validators import gt , instance_of
10
11
11
12
from baybe .acquisition .acqfs import qThompsonSampling
12
13
from baybe .exceptions import (
@@ -50,12 +51,12 @@ class BotorchRecommender(BayesianRecommender):
50
51
# Object variables
51
52
sequential_continuous : bool = field (default = False )
52
53
"""Flag defining whether to apply sequential greedy or batch optimization in
53
- **continuous** search spaces. ( In discrete/hybrid spaces, sequential greedy
54
- optimization is applied automatically.)
54
+ **continuous** search spaces. In discrete/hybrid spaces, sequential greedy
55
+ optimization is applied automatically.
55
56
"""
56
57
57
58
hybrid_sampler : DiscreteSamplingMethod | None = field (
58
- converter = optional (DiscreteSamplingMethod ), default = None
59
+ converter = optional_c (DiscreteSamplingMethod ), default = None
59
60
)
60
61
"""Strategy used for sampling the discrete subspace when performing hybrid search
61
62
space optimization."""
@@ -64,6 +65,16 @@ class BotorchRecommender(BayesianRecommender):
64
65
"""Percentage of discrete search space that is sampled when performing hybrid search
65
66
space optimization. Ignored when ``hybrid_sampler="None"``."""
66
67
68
+ n_restarts : int = field (validator = [instance_of (int ), gt (0 )], default = 10 )
69
+ """Number of times gradient-based optimization is restarted from different initial
70
+ points. **Does not affect purely discrete optimization**.
71
+ """
72
+
73
+ n_raw_samples : int = field (validator = [instance_of (int ), gt (0 )], default = 64 )
74
+ """Number of raw samples drawn for the initialization heuristic in gradient-based
75
+ optimization. **Does not affect purely discrete optimization**.
76
+ """
77
+
67
78
@sampling_percentage .validator
68
79
def _validate_percentage ( # noqa: DOC101, DOC103
69
80
self , _ : Any , value : float
@@ -168,8 +179,8 @@ def _recommend_continuous(
168
179
acq_function = self ._botorch_acqf ,
169
180
bounds = torch .from_numpy (subspace_continuous .comp_rep_bounds .values ),
170
181
q = batch_size ,
171
- num_restarts = 5 , # TODO make choice for num_restarts
172
- raw_samples = 10 , # TODO make choice for raw_samples
182
+ num_restarts = self . n_restarts ,
183
+ raw_samples = self . n_raw_samples ,
173
184
equality_constraints = [
174
185
c .to_botorch (subspace_continuous .parameters )
175
186
for c in subspace_continuous .constraints_lin_eq
@@ -252,8 +263,8 @@ def _recommend_hybrid(
252
263
acq_function = self ._botorch_acqf ,
253
264
bounds = torch .from_numpy (searchspace .comp_rep_bounds .values ),
254
265
q = batch_size ,
255
- num_restarts = 5 , # TODO make choice for num_restarts
256
- raw_samples = 10 , # TODO make choice for raw_samples
266
+ num_restarts = self . n_restarts ,
267
+ raw_samples = self . n_raw_samples ,
257
268
fixed_features_list = fixed_features_list ,
258
269
equality_constraints = [
259
270
c .to_botorch (
0 commit comments