Merge: Improve GP Fit (#472)
- The scipy-based fit is now used in all cases
- Instead of switching the fit routine, the MLL type is switched based on whether transfer learning (TL) is active or not
Scienfitz authored Jan 31, 2025
2 parents fb6e7d8 + 190bbe9 commit daaf4d1
Showing 3 changed files with 15 additions and 7 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [0.12.2] - 2025-01-31
+### Changed
+- More robust settings for the GP fitting
+
 ## [0.12.1] - 2025-01-29
 ### Changed
 - Default of `allow_recommending_already_recommended` is changed back to `False`
16 changes: 10 additions & 6 deletions baybe/surrogates/gaussian_process/core.py
@@ -201,15 +201,19 @@ def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
             covar_module=covar_module,
             likelihood=likelihood,
         )
-        mll = gpytorch.ExactMarginalLogLikelihood(self._model.likelihood, self._model)
 
-        # TODO: This is a simple temporary workaround to avoid model overfitting
-        # via early stopping in the presence of task parameters, which currently
-        # have no prior configured.
+        # TODO: This is still a temporary workaround to avoid overfitting seen in
+        # low-dimensional TL cases. More robust settings are being researched.
         if context.n_task_dimensions > 0:
-            botorch.optim.fit.fit_gpytorch_mll_torch(mll, step_limit=200)
+            mll = gpytorch.mlls.LeaveOneOutPseudoLikelihood(
+                self._model.likelihood, self._model
+            )
         else:
-            botorch.fit.fit_gpytorch_mll(mll)
+            mll = gpytorch.ExactMarginalLogLikelihood(
+                self._model.likelihood, self._model
+            )
+
+        botorch.fit.fit_gpytorch_mll(mll)
 
     @override
     def __str__(self) -> str:
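For illustration, here is a minimal self-contained sketch of the fitting pattern adopted in the core.py hunk above: the MLL type is selected based on whether task (transfer-learning) dimensions are present, and the same scipy-based `botorch.fit.fit_gpytorch_mll` call then handles both branches. The toy training data, the bare `SingleTaskGP`, and the `has_task_dimensions` flag are illustrative stand-ins for BayBE's internal model construction and its `context.n_task_dimensions` check, not part of this commit.

import botorch
import gpytorch
import torch
from botorch.models import SingleTaskGP

# Hypothetical toy data; in BayBE this comes from the measurement table.
train_x = torch.rand(20, 3, dtype=torch.float64)
train_y = train_x.sin().sum(dim=-1, keepdim=True)

model = SingleTaskGP(train_x, train_y)

# Stand-in for `context.n_task_dimensions > 0`.
has_task_dimensions = False

if has_task_dimensions:
    # Leave-one-out pseudo-likelihood counters the overfitting seen
    # in low-dimensional transfer-learning cases.
    mll = gpytorch.mlls.LeaveOneOutPseudoLikelihood(model.likelihood, model)
else:
    # Standard exact marginal log-likelihood for the plain GP case.
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model)

# One scipy-based fitting routine (L-BFGS-B by default) for both cases.
botorch.fit.fit_gpytorch_mll(mll)

As the new TODO comment notes, the leave-one-out branch is itself a temporary workaround while more robust settings are researched.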
2 changes: 1 addition & 1 deletion baybe/surrogates/naive.py
@@ -37,7 +37,7 @@ def _estimate_moments(
         import torch
 
         # TODO: use target value bounds for covariance scaling when explicitly provided
-        mean = self._model * torch.ones([len(candidates_comp_scaled)])
+        mean = self._model * torch.ones([len(candidates_comp_scaled)])  # type: ignore[operator]
         var = torch.ones(len(candidates_comp_scaled))
         return mean, var
 
