55from typing import Any , ClassVar
66
77import pandas as pd
8- from attr .converters import optional
98from attrs import define , field
9+ from attrs .converters import optional as optional_c
10+ from attrs .validators import gt , instance_of
1011
1112from baybe .acquisition .acqfs import qThompsonSampling
1213from baybe .exceptions import (
@@ -50,12 +51,12 @@ class BotorchRecommender(BayesianRecommender):
5051 # Object variables
5152 sequential_continuous : bool = field (default = False )
5253 """Flag defining whether to apply sequential greedy or batch optimization in
53- **continuous** search spaces. ( In discrete/hybrid spaces, sequential greedy
54- optimization is applied automatically.)
54+ **continuous** search spaces. In discrete/hybrid spaces, sequential greedy
55+ optimization is applied automatically.
5556 """
5657
5758 hybrid_sampler : DiscreteSamplingMethod | None = field (
58- converter = optional (DiscreteSamplingMethod ), default = None
59+ converter = optional_c (DiscreteSamplingMethod ), default = None
5960 )
6061 """Strategy used for sampling the discrete subspace when performing hybrid search
6162 space optimization."""
@@ -64,6 +65,16 @@ class BotorchRecommender(BayesianRecommender):
6465 """Percentage of discrete search space that is sampled when performing hybrid search
6566 space optimization. Ignored when ``hybrid_sampler="None"``."""
6667
68+ n_restarts : int = field (validator = [instance_of (int ), gt (0 )], default = 10 )
69+ """Number of times gradient-based optimization is restarted from different initial
70+ points. **Does not affect purely discrete optimization**.
71+ """
72+
73+ n_raw_samples : int = field (validator = [instance_of (int ), gt (0 )], default = 64 )
74+ """Number of raw samples drawn for the initialization heuristic in gradient-based
75+ optimization. **Does not affect purely discrete optimization**.
76+ """
77+
6778 @sampling_percentage .validator
6879 def _validate_percentage ( # noqa: DOC101, DOC103
6980 self , _ : Any , value : float
@@ -168,8 +179,8 @@ def _recommend_continuous(
168179 acq_function = self ._botorch_acqf ,
169180 bounds = torch .from_numpy (subspace_continuous .comp_rep_bounds .values ),
170181 q = batch_size ,
171- num_restarts = 5 , # TODO make choice for num_restarts
172- raw_samples = 10 , # TODO make choice for raw_samples
182+ num_restarts = self . n_restarts ,
183+ raw_samples = self . n_raw_samples ,
173184 equality_constraints = [
174185 c .to_botorch (subspace_continuous .parameters )
175186 for c in subspace_continuous .constraints_lin_eq
@@ -252,8 +263,8 @@ def _recommend_hybrid(
252263 acq_function = self ._botorch_acqf ,
253264 bounds = torch .from_numpy (searchspace .comp_rep_bounds .values ),
254265 q = batch_size ,
255- num_restarts = 5 , # TODO make choice for num_restarts
256- raw_samples = 10 , # TODO make choice for raw_samples
266+ num_restarts = self . n_restarts ,
267+ raw_samples = self . n_raw_samples ,
257268 fixed_features_list = fixed_features_list ,
258269 equality_constraints = [
259270 c .to_botorch (
0 commit comments