-from typing import Callable, Optional, Union
+from typing import Callable, Dict, Optional, Union
 
 import torch
 from torch import Tensor, nn, ones
@@ -18,14 +18,14 @@ def __init__(
         logging_level: Union[int, str] = "warning",
         summary_writer: Optional[TensorboardSummaryWriter] = None,
         show_progress_bars: bool = True,
-        regularization_strength: float = 100.0,
     ):
 
-        r"""Balanced neural ratio estimation (BNRE)[1]. BNRE is a variation of NRE aiming to
-        produce more conservative posterior approximations
+        r"""Balanced neural ratio estimation (BNRE)[1]. BNRE is a variation of NRE
+        aiming to produce more conservative posterior approximations
 
         [1] Delaunoy, A., Hermans, J., Rozet, F., Wehenkel, A., & Louppe, G..
-        Towards Reliable Simulation-Based Inference with Balanced Neural Ratio Estimation.
+        Towards Reliable Simulation-Based Inference with Balanced Neural Ratio
+        Estimation.
         NeurIPS 2022. https://arxiv.org/abs/2208.13624
 
         Args:
@@ -36,27 +36,76 @@ def __init__(
                 a string, use a pre-configured network of the provided type (one of
                 linear, mlp, resnet). Alternatively, a function that builds a custom
                 neural network can be provided. The function will be called with the
-                first batch of simulations $(\theta, x)$, which can thus be used for shape
-                inference and potentially for z-scoring. It needs to return a PyTorch
-                `nn.Module` implementing the classifier.
+                first batch of simulations $(\theta, x)$, which can thus be used for
+                shape inference and potentially for z-scoring. It needs to return a
+                PyTorch `nn.Module` implementing the classifier.
             device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}".
             logging_level: Minimum severity of messages to log. One of the strings
                 INFO, WARNING, DEBUG, ERROR and CRITICAL.
             summary_writer: A tensorboard `SummaryWriter` to control, among others, log
                 file location (default is `<current working directory>/logs`.)
             show_progress_bars: Whether to show a progressbar during simulation and
                 sampling.
-            regularization_strength: The multiplicative coefficient applied to the
-                balancing regularizer ($\lambda$)
         """
 
-        self.regularization_strength = regularization_strength
-        kwargs = del_entries(
-            locals(), entries=("self", "__class__", "regularization_strength")
-        )
+        kwargs = del_entries(locals(), entries=("self", "__class__"))
         super().__init__(**kwargs)
 
-    def _loss(self, theta: Tensor, x: Tensor, num_atoms: int) -> Tensor:
+    def train(
+        self,
+        regularization_strength: float = 100.0,
+        training_batch_size: int = 50,
+        learning_rate: float = 5e-4,
+        validation_fraction: float = 0.1,
+        stop_after_epochs: int = 20,
+        max_num_epochs: int = 2**31 - 1,
+        clip_max_norm: Optional[float] = 5.0,
+        resume_training: bool = False,
+        discard_prior_samples: bool = False,
+        retrain_from_scratch: bool = False,
+        show_train_summary: bool = False,
+        dataloader_kwargs: Optional[Dict] = None,
+    ) -> nn.Module:
+        r"""Return classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.
+
+        Args:
+            regularization_strength: The multiplicative coefficient applied to the
+                balancing regularizer ($\lambda$).
+            training_batch_size: Training batch size.
+            learning_rate: Learning rate for Adam optimizer.
+            validation_fraction: The fraction of data to use for validation.
+            stop_after_epochs: The number of epochs to wait for improvement on the
+                validation set before terminating training.
+            max_num_epochs: Maximum number of epochs to run. If reached, we stop
+                training even when the validation loss is still decreasing. Otherwise,
+                we train until validation loss increases (see also `stop_after_epochs`).
+            clip_max_norm: Value at which to clip the total gradient norm in order to
+                prevent exploding gradients. Use None for no clipping.
+            resume_training: Can be used in case training time is limited, e.g. on a
+                cluster. If `True`, the split between train and validation set, the
+                optimizer, the number of epochs, and the best validation log-prob will
+                be restored from the last time `.train()` was called.
+            discard_prior_samples: Whether to discard samples simulated in round 1, i.e.
+                from the prior. Training may be sped up by ignoring such less targeted
+                samples.
+            retrain_from_scratch: Whether to retrain the conditional density
+                estimator for the posterior from scratch each round.
+            show_train_summary: Whether to print the number of epochs and validation
+                loss and leakage after the training.
+            dataloader_kwargs: Additional or updated kwargs to be passed to the training
+                and validation dataloaders (like, e.g., a collate_fn).
+        Returns:
+            Classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.
+        """
+        kwargs = del_entries(locals(), entries=("self", "__class__"))
+        kwargs["loss_kwargs"] = {
+            "regularization_strength": kwargs.pop("regularization_strength")
+        }
+        return super().train(**kwargs)
+
+    def _loss(
+        self, theta: Tensor, x: Tensor, num_atoms: int, regularization_strength: float
+    ) -> Tensor:
         """Returns the binary cross-entropy loss for the trained classifier.
 
         The classifier takes as input a $(\theta,x)$ pair. It is trained to predict 1
@@ -87,4 +136,4 @@ def _loss(self, theta: Tensor, x: Tensor, num_atoms: int) -> Tensor:
             .square()
         )
 
-        return bce + self.regularization_strength * regularizer
+        return bce + regularization_strength * regularizer
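
For reference (not part of the diff): the objective assembled in `_loss` is the standard NRE binary cross-entropy plus the balancing penalty of [1], weighted by `regularization_strength` ($\lambda$). Writing the classifier's sigmoid output as $\hat{d}(\theta, x)$, the loss sketched here is

$$\mathcal{L} = \mathrm{BCE} + \lambda \left( \mathbb{E}_{p(\theta, x)}\big[\hat{d}(\theta, x)\big] + \mathbb{E}_{p(\theta)p(x)}\big[\hat{d}(\theta, x)\big] - 1 \right)^{2},$$

so the regularizer vanishes exactly when the classifier is balanced, i.e. when the two expectations sum to one.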
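Also not part of the diff: a minimal usage sketch of the changed API, in which `regularization_strength` is now passed to `.train()` rather than to the constructor. It assumes sbi's `BNRE` and `BoxUniform` are importable as shown; the toy simulator, prior bounds, and sample count are placeholders for illustration.

```python
import torch

from sbi.inference import BNRE
from sbi.utils import BoxUniform

# Toy inference problem: 2-d uniform prior, Gaussian-noise simulator (illustrative only).
prior = BoxUniform(low=-2 * torch.ones(2), high=2 * torch.ones(2))

def simulator(theta: torch.Tensor) -> torch.Tensor:
    return theta + 0.1 * torch.randn_like(theta)

theta = prior.sample((1000,))
x = simulator(theta)

# The balancing coefficient is now a `train()` argument, not an `__init__` argument.
inference = BNRE(prior=prior)
inference.append_simulations(theta, x)
ratio_estimator = inference.train(regularization_strength=100.0)
```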