Commit ad5ff03

michaeldeistler authored and janfb committed
Match API between BNRE and NRE-C
1 parent 0ae6126 commit ad5ff03

6 files changed · +102 −38 lines changed

sbi/inference/snre/bnre.py

Lines changed: 67 additions & 16 deletions
@@ -1,4 +1,4 @@
-from typing import Callable, Optional, Union
+from typing import Callable, Dict, Optional, Union
 
 import torch
 from torch import Tensor, nn, ones
@@ -18,14 +18,14 @@ def __init__(
         logging_level: Union[int, str] = "warning",
         summary_writer: Optional[TensorboardSummaryWriter] = None,
         show_progress_bars: bool = True,
-        regularization_strength: float = 100.0,
     ):
 
-        r"""Balanced neural ratio estimation (BNRE)[1]. BNRE is a variation of NRE aiming to
-        produce more conservative posterior approximations
+        r"""Balanced neural ratio estimation (BNRE)[1]. BNRE is a variation of NRE
+        aiming to produce more conservative posterior approximations
 
         [1] Delaunoy, A., Hermans, J., Rozet, F., Wehenkel, A., & Louppe, G..
-        Towards Reliable Simulation-Based Inference with Balanced Neural Ratio Estimation.
+        Towards Reliable Simulation-Based Inference with Balanced Neural Ratio
+        Estimation.
         NeurIPS 2022. https://arxiv.org/abs/2208.13624
 
         Args:
@@ -36,27 +36,78 @@ def __init__(
                 a string, use a pre-configured network of the provided type (one of
                 linear, mlp, resnet). Alternatively, a function that builds a custom
                 neural network can be provided. The function will be called with the
-                first batch of simulations $(\theta, x)$, which can thus be used for shape
-                inference and potentially for z-scoring. It needs to return a PyTorch
-                `nn.Module` implementing the classifier.
+                first batch of simulations $(\theta, x)$, which can thus be used for
+                shape inference and potentially for z-scoring. It needs to return a
+                PyTorch `nn.Module` implementing the classifier.
             device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}".
             logging_level: Minimum severity of messages to log. One of the strings
                 INFO, WARNING, DEBUG, ERROR and CRITICAL.
             summary_writer: A tensorboard `SummaryWriter` to control, among others, log
                 file location (default is `<current working directory>/logs`.)
             show_progress_bars: Whether to show a progressbar during simulation and
                 sampling.
-            regularization_strength: The multiplicative coefficient applied to the
-                balancing regularizer ($\lambda$)
         """
 
-        self.regularization_strength = regularization_strength
-        kwargs = del_entries(
-            locals(), entries=("self", "__class__", "regularization_strength")
-        )
+        kwargs = del_entries(locals(), entries=("self", "__class__"))
         super().__init__(**kwargs)
 
-    def _loss(self, theta: Tensor, x: Tensor, num_atoms: int) -> Tensor:
+    def train(
+        self,
+        regularization_strength: float = 100.0,
+        training_batch_size: int = 50,
+        learning_rate: float = 5e-4,
+        validation_fraction: float = 0.1,
+        stop_after_epochs: int = 20,
+        max_num_epochs: int = 2**31 - 1,
+        clip_max_norm: Optional[float] = 5.0,
+        resume_training: bool = False,
+        discard_prior_samples: bool = False,
+        retrain_from_scratch: bool = False,
+        show_train_summary: bool = False,
+        dataloader_kwargs: Optional[Dict] = None,
+    ) -> nn.Module:
+        r"""Return classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.
+
+        Args:
+            regularization_strength: The multiplicative coefficient applied to the
+                balancing regularizer ($\lambda$).
+            training_batch_size: Training batch size.
+            learning_rate: Learning rate for Adam optimizer.
+            validation_fraction: The fraction of data to use for validation.
+            stop_after_epochs: The number of epochs to wait for improvement on the
+                validation set before terminating training.
+            max_num_epochs: Maximum number of epochs to run. If reached, we stop
+                training even when the validation loss is still decreasing. Otherwise,
+                we train until validation loss increases (see also `stop_after_epochs`).
+            clip_max_norm: Value at which to clip the total gradient norm in order to
+                prevent exploding gradients. Use None for no clipping.
+            exclude_invalid_x: Whether to exclude simulation outputs `x=NaN` or `x=±∞`
+                during training. Expect errors, silent or explicit, when `False`.
+            resume_training: Can be used in case training time is limited, e.g. on a
+                cluster. If `True`, the split between train and validation set, the
+                optimizer, the number of epochs, and the best validation log-prob will
+                be restored from the last time `.train()` was called.
+            discard_prior_samples: Whether to discard samples simulated in round 1, i.e.
+                from the prior. Training may be sped up by ignoring such less targeted
+                samples.
+            retrain_from_scratch: Whether to retrain the conditional density
+                estimator for the posterior from scratch each round.
+            show_train_summary: Whether to print the number of epochs and validation
+                loss and leakage after the training.
+            dataloader_kwargs: Additional or updated kwargs to be passed to the training
+                and validation dataloaders (like, e.g., a collate_fn)
+        Returns:
+            Classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.
+        """
+        kwargs = del_entries(locals(), entries=("self", "__class__"))
+        kwargs["loss_kwargs"] = {
+            "regularization_strength": kwargs.pop("regularization_strength")
+        }
+        return super().train(**kwargs)
+
+    def _loss(
+        self, theta: Tensor, x: Tensor, num_atoms: int, regularization_strength: float
+    ) -> Tensor:
         """Returns the binary cross-entropy loss for the trained classifier.
 
         The classifier takes as input a $(\theta,x)$ pair. It is trained to predict 1
@@ -87,4 +138,4 @@ def _loss(self, theta: Tensor, x: Tensor, num_atoms: int) -> Tensor:
             .square()
         )
 
-        return bce + self.regularization_strength * regularizer
+        return bce + regularization_strength * regularizer
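
To make the API change concrete, here is a minimal usage sketch of BNRE after this commit, mirroring the updated tutorial cell further down: `regularization_strength` is no longer passed to the constructor but to `.train()`. The toy `prior` and `simulator` below are illustrative assumptions, not part of this diff.

import torch
from sbi.inference import BNRE
from sbi.utils import BoxUniform

# Toy setup (assumed for illustration; any prior/simulator works the same way).
prior = BoxUniform(low=-2 * torch.ones(2), high=2 * torch.ones(2))

def simulator(theta):
    # Toy Gaussian simulator.
    return theta + 0.1 * torch.randn_like(theta)

theta = prior.sample((1000,))
x = simulator(theta)

inference = BNRE(prior)  # no `regularization_strength` in __init__ anymore
_ = inference.append_simulations(theta, x).train(
    regularization_strength=100.0  # now a `.train()` argument (default 100.0)
)
posterior = inference.build_posterior()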

sbi/inference/snre/snre_a.py

Lines changed: 3 additions & 1 deletion
@@ -1,4 +1,4 @@
-from typing import Callable, Dict, Optional, Union
+from typing import Any, Callable, Dict, Optional, Union
 
 import torch
 from torch import Tensor, nn, ones
@@ -60,6 +60,7 @@ def train(
         retrain_from_scratch: bool = False,
         show_train_summary: bool = False,
         dataloader_kwargs: Optional[Dict] = None,
+        loss_kwargs: Dict[str, Any] = {},
     ) -> nn.Module:
         r"""Return classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.
 
@@ -87,6 +88,7 @@ def train(
                 loss and leakage after the training.
             dataloader_kwargs: Additional or updated kwargs to be passed to the training
                 and validation dataloaders (like, e.g., a collate_fn)
+            loss_kwargs: Additional or updated kwargs to be passed to the self._loss fn.
 
         Returns:
             Classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.
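
The new `loss_kwargs` argument is what lets a subclass expose loss-specific options on `train()` instead of `__init__`. Below is a stripped-down sketch of that pattern with toy classes (not sbi's actual implementation): the base trainer simply unpacks `loss_kwargs` into `self._loss`, exactly as BNRE does above with `regularization_strength`.

from typing import Any, Dict

import torch
from torch import Tensor


class ToyRatioEstimator:
    """Stand-in for the base trainer with the new `loss_kwargs` argument."""

    def train(self, loss_kwargs: Dict[str, Any] = {}) -> Tensor:
        theta, x = torch.randn(8, 2), torch.randn(8, 3)  # dummy batch
        # Extra keyword arguments are forwarded straight to the loss.
        return self._loss(theta, x, num_atoms=2, **loss_kwargs)

    def _loss(self, theta: Tensor, x: Tensor, num_atoms: int) -> Tensor:
        return torch.zeros(())  # placeholder for the BCE term


class ToyBNRE(ToyRatioEstimator):
    def train(self, regularization_strength: float = 100.0) -> Tensor:
        # Pack the loss-specific option into `loss_kwargs`, as BNRE.train does.
        return super().train(
            loss_kwargs={"regularization_strength": regularization_strength}
        )

    def _loss(self, theta, x, num_atoms, regularization_strength: float) -> Tensor:
        bce = super()._loss(theta, x, num_atoms)
        regularizer = torch.zeros(())  # placeholder for the balancing term
        return bce + regularization_strength * regularizer


loss = ToyBNRE().train(regularization_strength=20.0)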

sbi/inference/snre/snre_base.py

Lines changed: 0 additions & 1 deletion
@@ -6,7 +6,6 @@
 from torch import Tensor, eye, nn, ones, optim
 from torch.distributions import Distribution
 from torch.nn.utils.clip_grad import clip_grad_norm_
-from torch.utils import data
 from torch.utils.tensorboard.writer import SummaryWriter
 
 from sbi import utils as utils

sbi/inference/snre/snre_c.py

Lines changed: 22 additions & 14 deletions
@@ -6,7 +6,7 @@
 
 from sbi.inference.snre.snre_base import RatioEstimator
 from sbi.types import TensorboardSummaryWriter
-from sbi.utils import del_entries, repeat_rows
+from sbi.utils import del_entries
 
 
 class SNRE_C(RatioEstimator):
@@ -85,9 +85,10 @@ def train(
                 `num_atoms` for SNRE_B except SNRE_C has an additional independently
                 drawn sample. The total number of alternative parameters `NRE-C` "sees"
                 is $2K-1$ or `2 * num_classes - 1` divided between two loss terms.
-            gamma: Determines the relative weight of the sum of all $K$ dependently drawn
-                classes against the marginally drawn one. Specifically, $p(y=k) := p_K$,
-                $p(y=0) := p_0$, $p_0 = 1 - K p_K$, and finally $\gamma := K p_K / p_0$.
+            gamma: Determines the relative weight of the sum of all $K$ dependently
+                drawn classes against the marginally drawn one. Specifically,
+                $p(y=k) := p_K$, $p(y=0) := p_0$, $p_0 = 1 - K p_K$, and finally
+                $\gamma := K p_K / p_0$.
             training_batch_size: Training batch size.
             learning_rate: Learning rate for Adam optimizer.
             validation_fraction: The fraction of data to use for validation.
@@ -125,26 +126,31 @@ def train(
125126
def _loss(
126127
self, theta: Tensor, x: Tensor, num_atoms: int, gamma: float
127128
) -> torch.Tensor:
128-
r"""Return cross-entropy loss (via ''multi-class sigmoid'' activation) for 1-out-of-`K + 1` classification.
129+
r"""Return cross-entropy loss (via ''multi-class sigmoid'' activation) for
130+
1-out-of-`K + 1` classification.
129131
130-
At optimum, this loss function returns the exact likelihood-to-evidence ratio in the first round.
131-
Details of loss computation are described in Contrastive Neural Ratio Estimation[1]. The paper
132-
does not discuss the sequential case.
132+
At optimum, this loss function returns the exact likelihood-to-evidence ratio
133+
in the first round.
134+
Details of loss computation are described in Contrastive Neural Ratio
135+
Estimation[1]. The paper does not discuss the sequential case.
133136
134137
[1] _Contrastive Neural Ratio Estimation_, Benajmin Kurt Miller, et. al.,
135138
NeurIPS 2022, https://arxiv.org/abs/2210.06170
136139
"""
137140

138141
# Reminder: K = num_classes
139-
# The algorithm is written with K, so we convert back to K format rather than reasoning in num_atoms.
142+
# The algorithm is written with K, so we convert back to K format rather than
143+
# reasoning in num_atoms.
140144
num_classes = num_atoms - 1
141145
assert num_classes >= 1, f"num_classes = {num_classes} must be greater than 1."
142146

143147
assert theta.shape[0] == x.shape[0], "Batch sizes for theta and x must match."
144148
batch_size = theta.shape[0]
145149

146-
# We append an contrastive theta to the marginal case because we will remove the jointly drawn
147-
# sample in the logits_marginal[:, 0] position. That makes the remaining sample marginally drawn.
150+
# We append an contrastive theta to the marginal case because we will remove
151+
# the jointly drawn
152+
# sample in the logits_marginal[:, 0] position. That makes the remaining sample
153+
# marginally drawn.
148154
# We have a batch of `batch_size` datapoints.
149155
logits_marginal = self._classifier_logits(theta, x, num_classes + 1).reshape(
150156
batch_size, num_classes + 1
@@ -191,10 +197,12 @@ def _loss(
     def _get_prior_probs_marginal_and_joint(
         num_classes: int, gamma: float
     ) -> Tuple[float, float]:
-        """Return a tuple (p_marginal, p_joint) where `p_marginal := `$p_0$, `p_joint := `$p_K$.
+        """Return a tuple (p_marginal, p_joint) where `p_marginal := `$p_0$,
+        `p_joint := `$p_K$.
 
-        We let the joint (dependently drawn) class to be equally likely across K options.
-        The marginal class is therefore restricted to get the remaining probability.
+        We let the joint (dependently drawn) class to be equally likely across K
+        options. The marginal class is therefore restricted to get the remaining
+        probability.
         """
         assert num_classes >= 1
         p_joint = gamma / (1 + gamma * num_classes)
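
As a quick sanity check on the class-prior bookkeeping in the `gamma` docstring above ($p(y=0) := p_0$, $p(y=k) := p_K$, $p_0 = 1 - K p_K$): the rest of this helper lies outside the hunk, so the snippet below reconstructs `p_marginal` from that docstring relation rather than quoting sbi's code, and only verifies that the probabilities normalize.

def prior_probs_marginal_and_joint(num_classes: int, gamma: float):
    # `p_joint` is the line visible in the diff; `p_marginal` is recovered from
    # the docstring relation p_0 = 1 - K * p_K (an assumption, not sbi's code).
    assert num_classes >= 1
    p_joint = gamma / (1 + gamma * num_classes)
    p_marginal = 1 - num_classes * p_joint
    return p_marginal, p_joint


p_marginal, p_joint = prior_probs_marginal_and_joint(num_classes=5, gamma=1.0)
assert abs(p_marginal + 5 * p_joint - 1.0) < 1e-12  # p_0 + K * p_K == 1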

tests/linearGaussian_snre_test.py

Lines changed: 6 additions & 2 deletions
@@ -195,20 +195,24 @@ def simulator(theta):
 
     if method_str == "sre":
         inference = SNRE_B(**kwargs)
+        train_kwargs = {}
     elif method_str == "aalr":
         inference = AALR(**kwargs)
+        train_kwargs = {}
    elif method_str == "bnre":
-        inference = BNRE(regularization_strength=20, **kwargs)
+        inference = BNRE(**kwargs)
+        train_kwargs = {"regularization_strength": 20}
     elif method_str == "nrec":
         inference = SNRE_C(**kwargs)
+        train_kwargs = {}
     else:
         raise ValueError(f"{method_str} is not an allowed option")
 
     # Should use default `num_atoms=10` for SRE; `num_atoms=2` for AALR
     theta, x = simulate_for_sbi(
         simulator, prior, num_simulations, simulation_batch_size=50
     )
-    ratio_estimator = inference.append_simulations(theta, x).train()
+    ratio_estimator = inference.append_simulations(theta, x).train(**train_kwargs)
     potential_fn, theta_transform = ratio_estimator_based_potential(
         ratio_estimator=ratio_estimator, prior=prior, x_o=x_o
     )

tutorials/16_implemented_methods.ipynb

Lines changed: 4 additions & 4 deletions
@@ -272,10 +272,10 @@
    "source": [
     "from sbi.inference import BNRE\n",
     "\n",
-    "inference = BNRE(prior, regularization_strength=100.)\n",
+    "inference = BNRE(prior)\n",
     "theta = prior.sample((num_sims,))\n",
     "x = simulator(theta)\n",
-    "_ = inference.append_simulations(theta, x).train()\n",
+    "_ = inference.append_simulations(theta, x).train(regularization_strength=100.)\n",
     "posterior = inference.build_posterior().set_default_x(x_o)"
    ]
   },
@@ -420,7 +420,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3.9.7 (conda)",
+   "display_name": "Python 3 (ipykernel)",
    "language": "python",
    "name": "python3"
   },
@@ -434,7 +434,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-   "version": "3.9.7"
+   "version": "3.8.12"
  },
  "vscode": {
   "interpreter": {
