6
6
7
7
import torch
8
8
from botorch .acquisition .analytic import AcquisitionFunction
9
- from botorch .acquisition .objective import PosteriorTransform
9
+ from botorch .acquisition .objective import (
10
+ IdentityMCObjective ,
11
+ MCAcquisitionObjective ,
12
+ PosteriorTransform ,
13
+ )
14
+ from botorch .exceptions .errors import UnsupportedError
15
+ from botorch .models .deterministic import GenericDeterministicModel
10
16
from botorch .models .model import Model
11
17
from botorch .sampling .pathwise .posterior_samplers import get_matheron_path_model
12
- from botorch .utils .transforms import t_batch_mode_transform
18
+ from botorch .utils .transforms import is_ensemble , t_batch_mode_transform
13
19
from torch import Tensor
14
20
15
21
@@ -32,55 +38,151 @@ class PathwiseThompsonSampling(AcquisitionFunction):
32
38
def __init__(
    self,
    model: Model,
    objective: MCAcquisitionObjective | None = None,
    posterior_transform: PosteriorTransform | None = None,
) -> None:
    r"""Single-outcome Thompson sampling via pathwise posterior draws.

    With a multi-output `model`, either an `objective` or a scalarizing
    `posterior_transform` must be supplied so that the multi-output posterior
    samples are reduced to single-output values.

    Args:
        model: A fitted GP model.
        objective: The MCAcquisitionObjective under which the samples are
            evaluated. Defaults to `IdentityMCObjective()`.
        posterior_transform: An optional PosteriorTransform.

    Raises:
        UnsupportedError: If `objective` is None, `model.num_outputs != 1`, and
            `posterior_transform` is either None or does not scalarize.
    """
    super().__init__(model=model)

    # Lazily populated state: the sample paths (and, for ensemble models, the
    # per-path ensemble member indices) are only drawn once the batch size is
    # known.
    self.batch_size: int | None = None
    self.samples: GenericDeterministicModel | None = None
    self.ensemble_indices: Tensor | None = None

    # NOTE: This validation mirrors the one in MCAcquisitionFunction; we should
    # consider inheriting from it and e.g. getting the X_pending logic as well.
    if objective is None:
        if model.num_outputs != 1:
            if posterior_transform is None:
                raise UnsupportedError(
                    "Must specify an objective or a posterior transform when using "
                    "a multi-output model."
                )
            if not posterior_transform.scalarize:
                raise UnsupportedError(
                    "If using a multi-output model without an objective, "
                    "posterior_transform must scalarize the output."
                )
        objective = IdentityMCObjective()

    self.objective = objective
    self.posterior_transform = posterior_transform
79
+
80
def redraw(self, batch_size: int) -> None:
    r"""Draw a fresh set of posterior sample paths and cache them on `self`.

    Args:
        batch_size: The number of sample paths to draw; used as the (single)
            sample-shape dimension of the path model.

    Raises:
        NotImplementedError: If the model is an ensemble whose `batch_shape`
            has more than one dimension.
    """
    sample_shape = torch.Size([batch_size])
    self.samples = get_matheron_path_model(
        model=self.model, sample_shape=sample_shape
    )
    # Keep the cached batch size in sync with the freshly drawn paths. Without
    # this, calling `redraw` before the first forward pass leaves
    # `self.batch_size` as None while `self.samples` is set, so the batch-size
    # consistency check in the forward pass fails spuriously.
    self.batch_size = batch_size

    if is_ensemble(self.model):
        # the ensembling dimension is assumed to be part of the batch shape
        model_batch_shape = self.model.batch_shape
        if len(model_batch_shape) > 1:
            raise NotImplementedError(
                "Ensemble models with more than one ensemble dimension are not "
                "yet supported."
            )
        num_ensemble = model_batch_shape[0]
        # ensemble_indices is cached here to ensure that the acquisition function
        # becomes deterministic for the same input and can be optimized with LBFGS.
        # ensemble_indices is used in select_from_ensemble_models.
        self.ensemble_indices = torch.randint(
            0,
            num_ensemble,
            (*sample_shape, 1, self.model.num_outputs),
        )
57
102
58
103
@t_batch_mode_transform()
def forward(self, X: Tensor) -> Tensor:
    r"""Evaluate the pathwise posterior sample draws on the candidate set X.

    Args:
        X: A `batch_shape x q x d`-dim batched tensor of `d`-dim design points.

    Returns:
        A `batch_shape`-dim tensor holding, for each t-batch, the sum over the
        q-batch dimension of the per-point sample-path objective values.
    """
    # NOTE: Summing over q means we optimize the sum of independent Thompson
    # samples. In the future, *batched* L-BFGS optimization could replace the
    # sum, which would guarantee descent steps for every member of the batch
    # through batch-member-specific learning rate selection.
    return self._pathwise_forward(X).sum(dim=-1)  # batch_shape
71
121
122
def _pathwise_forward(self, X: Tensor) -> Tensor:
    """Evaluate the cached sample paths at X without reducing over q.

    On the first call the sample paths are drawn lazily, with one path per
    q-batch element; later calls must use the same q.

    Args:
        X: A `batch_shape x q x d`-dim batched tensor of `d`-dim design points.

    Returns:
        A `batch_shape x q`-dim tensor of evaluations on the posterior sample draws.

    Raises:
        ValueError: If q differs from the batch size the paths were drawn with.
    """
    q = X.shape[-2]
    if self.samples is None:
        # First evaluation: remember q and draw one sample path per q-element.
        self.batch_size = q
        self.redraw(batch_size=q)

    if self.batch_size != q:
        raise ValueError(BATCH_SIZE_CHANGE_ERROR.format(self.batch_size, q))

    # batch_shape x q x 1 x d
    path_inputs = X.unsqueeze(-2)
    # batch_shape x q [x num_ensembles] x 1 x m, then drop the singleton dim:
    # batch_shape x q [x num_ensembles] x m
    draws = self.samples(path_inputs).squeeze(-2)
    # batch_shape x q x m
    draws = self.select_from_ensemble_models(values=draws)

    if self.posterior_transform:
        draws = self.posterior_transform.evaluate(draws)
    # objective removes the `m` dimension -> batch_shape x q
    return self.objective(draws)
155
+
156
def select_from_ensemble_models(self, values: Tensor) -> Tensor:
    """Pick one ensemble member's value for every sample path.

    For non-ensemble models (as determined by `is_ensemble(self.model)`) the
    input is returned unchanged. Otherwise `self.ensemble_indices` is used to
    gather a single entry along the ensemble dimension (dim `-2`), which is
    then squeezed out. The indices are drawn uniformly at random in `redraw`
    and cached, which keeps the acquisition function deterministic for a given
    input.

    Args:
        values: A `batch_shape x q [x num_ensemble] x m`-dim Tensor, where the
            ensemble dimension is present only for ensemble models.

    Returns:
        A `batch_shape x q x m`-dim Tensor in which each element contains a
        single sample from the ensemble, selected with `self.ensemble_indices`.
    """
    if not is_ensemble(self.model):
        return values

    # The cached indices have shape `num_paths x 1 x m`; keep them on the same
    # device as the values so that `gather` can consume them.
    self.ensemble_indices = self.ensemble_indices.to(device=values.device)
    cached = self.ensemble_indices
    # Broadcast the cached indices over the leading (t-batch) dimensions.
    leading_shape = values.shape[:-3]
    gather_index = cached.expand(*leading_shape, *cached.shape)
    # values is batch_shape x q x num_ensemble x m; select along num_ensemble
    # and remove the resulting singleton ensemble dimension.
    picked = torch.gather(values, dim=-2, index=gather_index)
    return picked.squeeze(-2)
0 commit comments