Skip to content

Commit c05f725

Browse files
SebastianAmentfacebook-github-bot
authored andcommitted
Pathwise Thomspon sampling for ensemble models (#2877)
Summary: Pull Request resolved: #2877 This commit adds support for pathwise Thompson sampling for ensemble models, including fully Bayesian SAAS models. Reviewed By: saitcakmak Differential Revision: D75990595 fbshipit-source-id: d80cb26658057cf63c50fd9e9b950553f95ea1be
1 parent 7c5cc76 commit c05f725

File tree

14 files changed

+438
-120
lines changed

14 files changed

+438
-120
lines changed

botorch/acquisition/thompson_sampling.py

Lines changed: 126 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,16 @@
66

77
import torch
88
from botorch.acquisition.analytic import AcquisitionFunction
9-
from botorch.acquisition.objective import PosteriorTransform
9+
from botorch.acquisition.objective import (
10+
IdentityMCObjective,
11+
MCAcquisitionObjective,
12+
PosteriorTransform,
13+
)
14+
from botorch.exceptions.errors import UnsupportedError
15+
from botorch.models.deterministic import GenericDeterministicModel
1016
from botorch.models.model import Model
1117
from botorch.sampling.pathwise.posterior_samplers import get_matheron_path_model
12-
from botorch.utils.transforms import t_batch_mode_transform
18+
from botorch.utils.transforms import is_ensemble, t_batch_mode_transform
1319
from torch import Tensor
1420

1521

@@ -32,55 +38,151 @@ class PathwiseThompsonSampling(AcquisitionFunction):
3238
def __init__(
3339
self,
3440
model: Model,
41+
objective: MCAcquisitionObjective | None = None,
3542
posterior_transform: PosteriorTransform | None = None,
3643
) -> None:
3744
r"""Single-outcome TS.
3845
46+
If using a multi-output `model`, the acquisition function requires either an
47+
`objective` or a `posterior_transform` that transforms the multi-output
48+
posterior samples to single-output posterior samples.
49+
3950
Args:
4051
model: A fitted GP model.
41-
posterior_transform: A PosteriorTransform. If using a multi-output model,
42-
a PosteriorTransform that transforms the multi-output posterior into a
43-
single-output posterior is required.
52+
objective: The MCAcquisitionObjective under which the samples are
53+
evaluated. Defaults to `IdentityMCObjective()`.
54+
posterior_transform: An optional PosteriorTransform.
4455
"""
45-
if model._is_fully_bayesian:
46-
raise NotImplementedError(
47-
"PathwiseThompsonSampling is not supported for fully Bayesian models",
48-
)
4956

5057
super().__init__(model=model)
5158
self.batch_size: int | None = None
52-
53-
def redraw(self) -> None:
59+
self.samples: GenericDeterministicModel | None = None
60+
self.ensemble_indices: Tensor | None = None
61+
62+
# NOTE: This conditional block is copied from MCAcquisitionFunction, we should
63+
# consider inherting from it and e.g. getting the X_pending logic as well.
64+
if objective is None and model.num_outputs != 1:
65+
if posterior_transform is None:
66+
raise UnsupportedError(
67+
"Must specify an objective or a posterior transform when using "
68+
"a multi-output model."
69+
)
70+
elif not posterior_transform.scalarize:
71+
raise UnsupportedError(
72+
"If using a multi-output model without an objective, "
73+
"posterior_transform must scalarize the output."
74+
)
75+
if objective is None:
76+
objective = IdentityMCObjective()
77+
self.objective = objective
78+
self.posterior_transform = posterior_transform
79+
80+
def redraw(self, batch_size: int) -> None:
81+
sample_shape = (batch_size,)
5482
self.samples = get_matheron_path_model(
55-
model=self.model, sample_shape=torch.Size([self.batch_size])
83+
model=self.model, sample_shape=torch.Size(sample_shape)
5684
)
85+
if is_ensemble(self.model):
86+
# the ensembling dimension is assumed to be part of the batch shape
87+
model_batch_shape = self.model.batch_shape
88+
if len(model_batch_shape) > 1:
89+
raise NotImplementedError(
90+
"Ensemble models with more than one ensemble dimension are not "
91+
"yet supported."
92+
)
93+
num_ensemble = model_batch_shape[0]
94+
# ensemble_indices is cached here to ensure that the acquisition function
95+
# becomes deterministic for the same input and can be optimized with LBFGS.
96+
# ensemble_indices is used in select_from_ensemble_models.
97+
self.ensemble_indices = torch.randint(
98+
0,
99+
num_ensemble,
100+
(*sample_shape, 1, self.model.num_outputs),
101+
)
57102

58103
@t_batch_mode_transform()
59104
def forward(self, X: Tensor) -> Tensor:
60105
r"""Evaluate the pathwise posterior sample draws on the candidate set X.
61106
62107
Args:
63-
X: A `(b1 x ... bk) x 1 x d`-dim batched tensor of `d`-dim design points.
108+
X: A `batch_shape x q x d`-dim batched tensor of `d`-dim design points.
64109
65110
Returns:
66-
A `(b1 x ... bk) x [num_models for fully bayesian]`-dim tensor of
67-
evaluations on the posterior sample draws.
111+
A `batch_shape`-dim tensor of evaluations on the posterior sample draws,
112+
where the samples are summed over the q-batch dimension.
68113
"""
69-
batch_size = X.shape[-2]
70-
q_dim = -2
114+
objective_values = self._pathwise_forward(X) # batch_shape x q
115+
# NOTE: The current implementation sums over the q-batch dimension, which means
116+
# that we are optimizing the sum of independent Thompson samples. In the future,
117+
# we can leverage *batched* L-BFGS optimization, rather than summing over the q
118+
# dimension, which will guarantee descent steps for all members of the batch
119+
# through batch-member-specific learning rate selection.
120+
return objective_values.sum(-1) # batch_shape
71121

122+
def _pathwise_forward(self, X: Tensor) -> Tensor:
123+
"""Evaluate the pathwise posterior sample draws on the candidate set X.
124+
125+
Args:
126+
X: A `batch_shape x q x d`-dim batched tensor of `d`-dim design points.
127+
128+
Returns:
129+
A `batch_shape x q`-dim tensor of evaluations on the posterior sample draws.
130+
"""
131+
batch_size = X.shape[-2]
72132
# batch_shape x q x 1 x d
73133
X = X.unsqueeze(-2)
74-
if self.batch_size is None:
134+
if self.samples is None:
75135
self.batch_size = batch_size
76-
self.redraw()
77-
elif self.batch_size != batch_size:
136+
self.redraw(batch_size=batch_size)
137+
138+
if self.batch_size != batch_size:
78139
raise ValueError(
79140
BATCH_SIZE_CHANGE_ERROR.format(self.batch_size, batch_size)
80141
)
142+
# batch_shape x q [x num_ensembles] x 1 x m
143+
posterior_values = self.samples(X)
144+
# batch_shape x q [x num_ensembles] x m
145+
posterior_values = posterior_values.squeeze(-2)
81146

82-
# posterior_values.shape post-squeeze:
83147
# batch_shape x q x m
84-
posterior_values = self.samples(X).squeeze(-2)
85-
# sum over batch dim and squeeze num_objectives dim (-1)
86-
return posterior_values.sum(q_dim).squeeze(-1)
148+
posterior_values = self.select_from_ensemble_models(values=posterior_values)
149+
150+
if self.posterior_transform:
151+
posterior_values = self.posterior_transform.evaluate(posterior_values)
152+
# objective removes the `m` dimension
153+
objective_values = self.objective(posterior_values) # batch_shape x q
154+
return objective_values
155+
156+
def select_from_ensemble_models(self, values: Tensor):
157+
"""Subselecting a value associated with a single sample in the ensemble for each
158+
element of samples that is not associated with an ensemble dimension.
159+
160+
NOTE: 1) uses `self.model` and `is_ensemble` to determine whether or not an
161+
ensembling dimension is present. 2) uses `self.ensemble_indices` to select the
162+
value associated with a single sample in the ensemble. `ensemble_indices`
163+
contains uniformly randomly sample indices for each element of the ensemble, but
164+
is cached to make the evaluation of the acquisition function deterministic.
165+
166+
Args:
167+
values: A `batch_shape x num_draws x q [x num_ensemble] x m`-dim Tensor.
168+
169+
Returns:
170+
A`batch_shape x num_draws x q x m`-dim where each element is contains a
171+
single sample from the ensemble, selected with `self.ensemble_indices`.
172+
"""
173+
if not is_ensemble(self.model):
174+
return values
175+
176+
ensemble_dim = -2
177+
# `ensemble_indices` are fixed so that the acquisition function becomes
178+
# deterministic for the same input and can be optimized with LBFGS.
179+
# ensemble indices have shape num_paths x 1 x m
180+
self.ensemble_indices = self.ensemble_indices.to(device=values.device)
181+
index = self.ensemble_indices
182+
input_batch_shape = values.shape[:-3]
183+
index = index.expand(*input_batch_shape, *index.shape)
184+
# samples is batch_shape x q x num_ensemble x m
185+
values_wo_ensemble = torch.gather(values, dim=ensemble_dim, index=index)
186+
return values_wo_ensemble.squeeze(
187+
ensemble_dim
188+
) # removing the ensemble dimension

botorch/acquisition/utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -575,7 +575,11 @@ def get_optimal_samples(
575575
else:
576576
sample_transform = None
577577

578-
paths = get_matheron_path_model(model=model, sample_shape=torch.Size([num_optima]))
578+
paths = get_matheron_path_model(
579+
model=model,
580+
sample_shape=torch.Size([num_optima]),
581+
ensemble_as_batch=True,
582+
)
579583
optimal_inputs, optimal_outputs = optimize_posterior_samples(
580584
paths=paths,
581585
bounds=bounds,

botorch/models/deterministic.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,12 @@ class GenericDeterministicModel(DeterministicModel):
6464
>>> model = GenericDeterministicModel(f)
6565
"""
6666

67-
def __init__(self, f: Callable[[Tensor], Tensor], num_outputs: int = 1) -> None:
67+
def __init__(
68+
self,
69+
f: Callable[[Tensor], Tensor],
70+
num_outputs: int = 1,
71+
batch_shape: torch.Size | None = None,
72+
) -> None:
6873
r"""
6974
Args:
7075
f: A callable mapping a `batch_shape x n x d`-dim input tensor `X`
@@ -75,6 +80,12 @@ def __init__(self, f: Callable[[Tensor], Tensor], num_outputs: int = 1) -> None:
7580
super().__init__()
7681
self._f = f
7782
self._num_outputs = num_outputs
83+
self._batch_shape = batch_shape
84+
85+
@property
86+
def batch_shape(self) -> torch.Size | None:
87+
r"""The batch shape of the model."""
88+
return self._batch_shape
7889

7990
def subset_output(self, idcs: list[int]) -> GenericDeterministicModel:
8091
r"""Subset the model along the output dimension.
@@ -100,7 +111,19 @@ def forward(self, X: Tensor) -> Tensor:
100111
Returns:
101112
A `batch_shape x n x m`-dimensional output tensor.
102113
"""
103-
return self._f(X)
114+
Y = self._f(X)
115+
batch_shape = Y.shape[:-2]
116+
# allowing for old behavior of not specifying the batch_shape
117+
if self.batch_shape is not None:
118+
try:
119+
torch.broadcast_shapes(self.batch_shape, batch_shape)
120+
except RuntimeError:
121+
raise ValueError(
122+
"GenericDeterministicModel was initialized with batch_shape="
123+
f"{self.batch_shape=} but the output of f has a batch_shape="
124+
f"{batch_shape=} that is not broadcastable with it."
125+
)
126+
return Y
104127

105128

106129
class AffineDeterministicModel(DeterministicModel):

botorch/sampling/pathwise/paths.py

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from __future__ import annotations
88

9-
from abc import ABC
9+
from abc import ABC, abstractmethod
1010
from collections.abc import Callable, Iterable, Iterator, Mapping
1111
from typing import Any
1212

@@ -24,6 +24,16 @@
2424
class SamplePath(ABC, TransformedModuleMixin, Module):
2525
r"""Abstract base class for Botorch sample paths."""
2626

27+
@abstractmethod
28+
def set_ensemble_as_batch(self, ensemble_as_batch: bool) -> None:
29+
"""Sets whether the ensemble dimension is considered as a batch dimension.
30+
31+
Args:
32+
ensemble_as_batch: Whether the ensemble dimension is considered as a batch
33+
dimension or not.
34+
"""
35+
pass # pragma: no cover
36+
2737

2838
class PathDict(SamplePath):
2939
r"""A dictionary of SamplePaths."""
@@ -84,6 +94,16 @@ def __getitem__(self, key: str) -> SamplePath:
8494
def __setitem__(self, key: str, val: SamplePath) -> None:
8595
self.paths[key] = val
8696

97+
def set_ensemble_as_batch(self, ensemble_as_batch: bool) -> None:
98+
"""Sets whether the ensemble dimension is considered as a batch dimension.
99+
100+
Args:
101+
ensemble_as_batch: Whether the ensemble dimension is considered as a batch
102+
dimension or not.
103+
"""
104+
for path in self.paths.values():
105+
path.set_ensemble_as_batch(ensemble_as_batch)
106+
87107

88108
class PathList(SamplePath):
89109
r"""A list of SamplePaths."""
@@ -136,6 +156,16 @@ def __getitem__(self, key: int) -> SamplePath:
136156
def __setitem__(self, key: int, val: SamplePath) -> None:
137157
self.paths[key] = val
138158

159+
def set_ensemble_as_batch(self, ensemble_as_batch: bool) -> None:
160+
"""Sets whether the ensemble dimension is considered as a batch dimension.
161+
162+
Args:
163+
ensemble_as_batch: Whether the ensemble dimension is considered as a batch
164+
dimension or not.
165+
"""
166+
for path in self.paths:
167+
path.set_ensemble_as_batch(ensemble_as_batch)
168+
139169

140170
class GeneralizedLinearPath(SamplePath):
141171
r"""A sample path in the form of a generalized linear model."""
@@ -147,6 +177,8 @@ def __init__(
147177
bias_module: Module | None = None,
148178
input_transform: TInputTransform | None = None,
149179
output_transform: TOutputTransform | None = None,
180+
is_ensemble: bool = False,
181+
ensemble_as_batch: bool = False,
150182
):
151183
r"""Initializes a GeneralizedLinearPath instance.
152184
@@ -157,10 +189,17 @@ def __init__(
157189
158190
Args:
159191
feature_map: A map used to featurize the module's inputs.
160-
weight: A tensor of weights used to combine input features.
192+
weight: A tensor of weights used to combine input features. When generated
193+
with `draw_kernel_feature_paths`, `weight` is a Tensor with the shape
194+
`sample_shape x batch_shape x num_outputs`.
161195
bias_module: An optional module used to define additive offsets.
162196
input_transform: An optional input transform for the module.
163197
output_transform: An optional output transform for the module.
198+
is_ensemble: Whether the associated model is an ensemble model or not.
199+
ensemble_as_batch: Whether the ensemble dimension is added as a batch
200+
dimension or not. If `True`, the ensemble dimension is treated as a
201+
batch dimension, which allows for the joint optimization of all members
202+
of the ensemble.
164203
"""
165204
super().__init__()
166205
self.feature_map = feature_map
@@ -170,8 +209,36 @@ def __init__(
170209
self.bias_module = bias_module
171210
self.input_transform = input_transform
172211
self.output_transform = output_transform
212+
self.is_ensemble = is_ensemble
213+
self.ensemble_as_batch = ensemble_as_batch
173214

174215
def forward(self, x: Tensor, **kwargs) -> Tensor:
216+
"""Evaluates the path.
217+
218+
Args:
219+
x: The input tensor of shape `batch_shape x [num_ensemble x] q x d`, where
220+
`num_ensemble` is the number of ensemble members and is required to
221+
*only* be included if `is_ensemble=True` and `ensemble_as_batch=True`.
222+
kwargs: Additional keyword arguments passed to the feature map.
223+
224+
Returns:
225+
A tensor of shape `batch_shape x [num_ensemble x] q x m`, where `m` is the
226+
number of outputs, where `num_ensemble` is only included if `is_ensemble`
227+
is `True`, and regardless of whether `ensemble_as_batch` is `True` or not.
228+
"""
229+
if self.is_ensemble and not self.ensemble_as_batch:
230+
# assuming that the ensembling dimension is added after (n, d), but
231+
# before the other batch dimensions, starting from the left.
232+
x = x.unsqueeze(-3)
175233
feat = self.feature_map(x, **kwargs)
176234
out = (feat @ self.weight.unsqueeze(-1)).squeeze(-1)
177235
return out if self.bias_module is None else out + self.bias_module(x)
236+
237+
def set_ensemble_as_batch(self, ensemble_as_batch: bool) -> None:
238+
"""Sets whether the ensemble dimension is considered as a batch dimension.
239+
240+
Args:
241+
ensemble_as_batch: Whether the ensemble dimension is considered as a batch
242+
dimension or not.
243+
"""
244+
self.ensemble_as_batch = ensemble_as_batch

0 commit comments

Comments
 (0)