diff --git a/graphium/nn/architectures/__init__.py b/graphium/nn/architectures/__init__.py index 1025d900a..a02a5cc8c 100644 --- a/graphium/nn/architectures/__init__.py +++ b/graphium/nn/architectures/__init__.py @@ -3,3 +3,4 @@ from .global_architectures import TaskHeads from .global_architectures import GraphOutputNN from .pyg_architectures import FeedForwardPyg +from .global_architectures import EnsembleFeedForwardNN diff --git a/graphium/nn/architectures/global_architectures.py b/graphium/nn/architectures/global_architectures.py index 6570ca492..1b6f44dd9 100644 --- a/graphium/nn/architectures/global_architectures.py +++ b/graphium/nn/architectures/global_architectures.py @@ -166,16 +166,18 @@ def __init__( self.last_layer_is_readout = last_layer_is_readout self._readout_cache = None + self.full_dims = [self.in_dim] + self.hidden_dims + [self.out_dim] + self._parse_layers(layer_type=layer_type, residual_type=residual_type) + self._create_layers() + self._check_bad_arguments() + + def _parse_layers(self, layer_type, residual_type): # Parse the layer and residuals from graphium.utils.spaces import LAYERS_DICT, RESIDUALS_DICT self.layer_class, self.layer_name = self._parse_class_from_dict(layer_type, LAYERS_DICT) self.residual_class, self.residual_name = self._parse_class_from_dict(residual_type, RESIDUALS_DICT) - self.full_dims = [self.in_dim] + self.hidden_dims + [self.out_dim] - self._create_layers() - self._check_bad_arguments() - def _check_bad_arguments(self): r""" Raise comprehensive errors if the arguments seem wrong @@ -403,6 +405,352 @@ def __repr__(self): return class_str + layer_str +class EnsembleFeedForwardNN(FeedForwardNN): + def __init__( + self, + in_dim: int, + out_dim: int, + hidden_dims: Union[List[int], int], + num_ensemble: int, + reduction: Union[str, Callable], + subset_in_dim: Union[float, int] = 1.0, + depth: Optional[int] = None, + activation: Union[str, Callable] = "relu", + last_activation: Union[str, Callable] = "none", + dropout: float = 0.0, + last_dropout: float = 0.0, + normalization: Union[str, Callable] = "none", + first_normalization: Union[str, Callable] = "none", + last_normalization: Union[str, Callable] = "none", + residual_type: str = "none", + residual_skip_steps: int = 1, + name: str = "LNN", + layer_type: Union[str, nn.Module] = "ens-fc", + layer_kwargs: Optional[Dict] = None, + last_layer_is_readout: bool = False, + ): + r""" + An ensemble of flexible neural network architecture, with variable hidden dimensions, + support for multiple layer types, and support for different residual + connections. + + Parameters: + + in_dim: + Input feature dimensions of the layer + + out_dim: + Output feature dimensions of the layer + + hidden_dims: + Either an integer specifying all the hidden dimensions, + or a list of dimensions in the hidden layers. + Be careful, the "simple" residual type only supports + hidden dimensions of the same value. + + num_ensemble: + Number of MLPs that run in parallel. + + reduction: + Reduction to use at the end of the MLP. Choices: + + - "none" or `None`: No reduction + - "mean": Mean reduction + - "sum": Sum reduction + - "max": Max reduction + - "min": Min reduction + - "median": Median reduction + - `Callable`: Any callable function. Must take `dim` as a keyword argument. + + subset_in_dim: + If float, ratio of the subset of the ensemble to use. Must be between 0 and 1. + If int, number of elements to subset from in_dim. + If `None`, the subset_in_dim is set to `1.0`. + A different subset is used for each ensemble. 
+ Only valid if the input shape is `[B, Din]`. + + depth: + If `hidden_dims` is an integer, `depth` is 1 + the number of + hidden layers to use. + If `hidden_dims` is a list, then + `depth` must be `None` or equal to `len(hidden_dims) + 1` + + activation: + activation function to use in the hidden layers. + + last_activation: + activation function to use in the last layer. + + dropout: + The ratio of units to dropout. Must be between 0 and 1 + + last_dropout: + The ratio of units to dropout for the last_layer. Must be between 0 and 1 + + normalization: + Normalization to use. Choices: + + - "none" or `None`: No normalization + - "batch_norm": Batch normalization + - "layer_norm": Layer normalization + - `Callable`: Any callable function + + first_normalization: + Whether to use batch normalization **before** the first layer + + last_normalization: + Whether to use batch normalization in the last layer + + residual_type: + - "none": No residual connection + - "simple": Residual connection similar to the ResNet architecture. + See class `ResidualConnectionSimple` + - "weighted": Residual connection similar to the Resnet architecture, + but with weights applied before the summation. See class `ResidualConnectionWeighted` + - "concat": Residual connection where the residual is concatenated instead + of being added. + - "densenet": Residual connection where the residual of all previous layers + are concatenated. This leads to a strong increase in the number of parameters + if there are multiple hidden layers. + + residual_skip_steps: + The number of steps to skip between each residual connection. + If `1`, all the layers are connected. If `2`, half of the + layers are connected. + + name: + Name attributed to the current network, for display and printing + purposes. + + layer_type: + The type of layers to use in the network. + Either "ens-fc" as the `EnsembleFCLayer`, or a class representing the `nn.Module` + to use. + + layer_kwargs: + The arguments to be used in the initialization of the layer provided by `layer_type` + + last_layer_is_readout: Whether the last layer should be treated as a readout layer. + Allows to use the `mup.MuReadout` from the muTransfer method https://github.com/microsoft/mup + + """ + + # Parse the ensemble arguments + if layer_kwargs is None: + layer_kwargs = {} + layer_kwargs["num_ensemble"] = self._parse_num_ensemble(num_ensemble, layer_kwargs) + + # Parse the sample input dimension + self.subset_in_dim, self.subset_idx = self._parse_subset_in_dim(in_dim, subset_in_dim, num_ensemble) + + super().__init__( + in_dim=in_dim, + out_dim=out_dim, + hidden_dims=hidden_dims, + depth=depth, + activation=activation, + last_activation=last_activation, + dropout=dropout, + last_dropout=last_dropout, + normalization=normalization, + first_normalization=first_normalization, + last_normalization=last_normalization, + residual_type=residual_type, + residual_skip_steps=residual_skip_steps, + name=name, + layer_type=layer_type, + layer_kwargs=layer_kwargs, + last_layer_is_readout=last_layer_is_readout, + ) + + # Parse the reduction + self.reduction = reduction + self.reduction_fn = self._parse_reduction(reduction) + + def _create_layers(self): + self.full_dims[0] = self.subset_in_dim + super()._create_layers() + + def _parse_num_ensemble(self, num_ensemble: int, layer_kwargs) -> int: + r""" + Parse the num_ensemble argument. 
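+ Parameters:
+ num_ensemble: Number of MLPs that run in parallel.
+ layer_kwargs: Keyword arguments of the layer; if it already contains a `num_ensemble` entry, that entry must match the `num_ensemble` argument.
+ Returns:
+ The validated number of parallel MLPs.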
+ """ + num_ensemble_out = num_ensemble + + # Get the num_ensemble from the layer_kwargs if it exists + num_ensemble_2 = None + if layer_kwargs is None: + layer_kwargs = {} + else: + num_ensemble_2 = layer_kwargs.get("num_ensemble", None) + + if num_ensemble is None: + num_ensemble_out = num_ensemble_2 + + # Check that the num_ensemble is consistent + if num_ensemble_2 is not None: + assert ( + num_ensemble_2 == num_ensemble + ), f"num_ensemble={num_ensemble} != num_ensemble_2={num_ensemble_2}" + + # Check that `num_ensemble_out` is not None + assert ( + num_ensemble_out is not None + ), f"num_ensemble={num_ensemble} and num_ensemble_2={num_ensemble_2}" + + return num_ensemble_out + + def _parse_reduction(self, reduction: Optional[Union[str, Callable]]) -> Optional[Callable]: + r""" + Parse the reduction argument. + """ + + if isinstance(reduction, str): + reduction = reduction.lower() + if reduction is None or reduction == "none": + return None + elif reduction == "mean": + return torch.mean + elif reduction == "sum": + return torch.sum + elif reduction == "max": + + def max_vals(x, dim): + return torch.max(x, dim=dim).values + + return max_vals + elif reduction == "min": + + def min_vals(x, dim): + return torch.min(x, dim=dim).values + + return min_vals + elif reduction == "median": + + def median_vals(x, dim): + return torch.median(x, dim=dim).values + + return median_vals + elif callable(reduction): + return reduction + else: + raise ValueError(f"Unknown reduction {reduction}") + + def _parse_subset_in_dim( + self, in_dim: int, subset_in_dim: Union[float, int], num_ensemble: int + ) -> Tuple[float, int]: + r""" + Parse the subset_in_dim argument and the subset_in_dim. + + The subset_in_dim is the ratio of the hidden features to use by each MLP of the ensemble. + The subset_in_dim is the number of input features to use by each MLP of the ensemble. + + Parameters: + + in_dim: The number of input features, before subsampling + + subset_in_dim: + Ratio of the subset of features to use by each MLP of the ensemble. + Must be between 0 and 1. A different subset is used for each ensemble. + Only valid if the input shape is `[B, Din]`. + + If None, the subset_in_dim is set to 1.0. + + num_ensemble: + Number of MLPs that run in parallel. + + Returns: + + subset_in_dim: The ratio of the subset of features to use by each MLP of the ensemble. + subset_idx: The indices of the features to use by each MLP of the ensemble. 
+ """ + + # Parse the subset_in_dim, make sure value is between 0 and 1 + subset_idx = None + if subset_in_dim is None: + return 1.0, None + if isinstance(subset_in_dim, int): + assert ( + subset_in_dim > 0 and subset_in_dim <= in_dim + ), f"subset_in_dim={subset_in_dim}, in_dim={in_dim}" + elif isinstance(subset_in_dim, float): + assert subset_in_dim > 0.0 and subset_in_dim <= 1.0, f"subset_in_dim={subset_in_dim}" + + # Convert to integer value + subset_in_dim = int(in_dim * subset_in_dim) + if subset_in_dim == 0: + subset_in_dim = 1 + + # Create the subset_idx, which is a list of indices to use for each ensemble + if subset_in_dim != in_dim: + subset_idx = torch.stack([torch.randperm(in_dim)[:subset_in_dim] for _ in range(num_ensemble)]) + + return subset_in_dim, subset_idx + + def _parse_layers(self, layer_type, residual_type): + # Parse the layer and residuals + from graphium.utils.spaces import ENSEMBLE_LAYERS_DICT, RESIDUALS_DICT + + self.layer_class, self.layer_name = self._parse_class_from_dict(layer_type, ENSEMBLE_LAYERS_DICT) + self.residual_class, self.residual_name = self._parse_class_from_dict(residual_type, RESIDUALS_DICT) + + def forward(self, h: torch.Tensor) -> torch.Tensor: + r""" + Subset the hidden dimension for each MLP, + forward the ensemble MLP on the input features, + then reduce the output if specified. + + Parameters: + + h: `torch.Tensor[B, Din]` or `torch.Tensor[..., 1, B, Din]` or `torch.Tensor[..., L, B, Din]`: + + Input feature tensor, before the MLP. + `Din` is the number of input features, `B` is the batch size, and `L` is the number of ensembles. + + Returns: + + `torch.Tensor[..., L, B, Dout]` or `torch.Tensor[..., B, Dout]`: + + Output feature tensor, after the MLP. + `Dout` is the number of output features, `B` is the batch size, and `L` is the number of ensembles. + `L` is removed if a reduction is specified. + """ + # Subset the input features for each MLP in the ensemble + if self.subset_idx is not None: + if len(h.shape) != 2: + assert ( + h.shape[-3] == 1 + ), f"Expected shape to be [B, Din] or [..., 1, B, Din] when using `subset_in_dim`, got {h.shape}." + h = h[..., self.subset_idx].transpose(-2, -3) + + # Run the standard forward pass + h = super().forward(h) + + # Reduce the output if specified + if self.reduction_fn is not None: + h = self.reduction_fn(h, dim=-3) + + return h + + def get_init_kwargs(self) -> Dict[str, Any]: + """ + Get a dictionary that can be used to instanciate a new object with identical parameters. 
+ """ + kw = super().get_init_kwargs() + kw["num_ensemble"] = self.num_ensemble + kw["reduction"] = self.reduction + return kw + + def __repr__(self): + r""" + Controls how the class is printed + """ + class_str = f"{self.name}(depth={self.depth}, {self.residual_layer})\n , num_ensemble={self.num_ensemble}, reduction={self.reduction}\n " + layer_str = f"[{self.layer_class.__name__}[{' -> '.join(map(str, self.full_dims))}]" + + return class_str + layer_str + + class FeedForwardGraph(FeedForwardNN): def __init__( self, diff --git a/graphium/nn/base_layers.py b/graphium/nn/base_layers.py index 8a8b29f1a..1d402b760 100644 --- a/graphium/nn/base_layers.py +++ b/graphium/nn/base_layers.py @@ -236,44 +236,7 @@ class MuReadoutGraphium(MuReadout): def __init__(self, in_features, *args, **kwargs): super().__init__(in_features, *args, **kwargs) - self.base_width = in_features - - @property - def absolute_width(self): - return float(self.in_features) - - @property - def base_width(self): - return self._base_width - - @base_width.setter - def base_width(self, val): - if val is None: - return - assert isinstance( - val, (int, torch.int, torch.long) - ), f"`base_width` must be None, int or long, provided {val} of type {type(val)}" - self._base_width = val - - def width_mult(self): - return self.absolute_width / self.base_width - - -class MuReadoutGraphium(MuReadout): - """ - PopTorch-compatible replacement for `mup.MuReadout` - - Not quite a drop-in replacement for `mup.MuReadout` - you need to specify - `base_width`. - - Set `base_width` to width of base model passed to `mup.set_base_shapes` - to get same results on IPU and CPU. Should still "work" with any other - value, but won't give the same results as CPU - """ - - def __init__(self, in_features, *args, **kwargs): - super().__init__(in_features, *args, **kwargs) - self.base_width = in_features + self._base_width = in_features @property def absolute_width(self): @@ -479,7 +442,7 @@ def __init__( in_dim: int, hidden_dims: Union[Iterable[int], int], out_dim: int, - depth: int, + depth: Optional[int] = None, activation: Union[str, Callable] = "relu", last_activation: Union[str, Callable] = "none", dropout: float = 0.0, @@ -490,6 +453,8 @@ def __init__( last_layer_is_readout: bool = False, droppath_rate: float = 0.0, constant_droppath_rate: bool = True, + fc_layer: FCLayer = FCLayer, + fc_layer_kwargs: Optional[dict] = None, ): r""" Simple multi-layer perceptron, built of a series of FCLayers @@ -538,12 +503,17 @@ def __init__( If `True`, drop rates will remain constant accross layers. Otherwise, drop rates will vary stochastically. See `DropPath.get_stochastic_drop_rate` + fc_layer: + The fully connected layer to use. Must inherit from `FCLayer`. + fc_layer_kwargs: + Keyword arguments to pass to the fully connected layer. 
""" super().__init__() self.in_dim = in_dim self.out_dim = out_dim + self.fc_layer_kwargs = deepcopy(fc_layer_kwargs) or {} # Parse the hidden dimensions and depth if isinstance(hidden_dims, int): @@ -560,12 +530,12 @@ def __init__( all_dims = [in_dim] + self.hidden_dims + [out_dim] fully_connected = [] - if depth == 0: + if self.depth == 0: self.fully_connected = None return else: - for ii in range(depth): - if ii < (depth - 1): + for ii in range(self.depth): + if ii < (self.depth - 1): # Define the parameters for all intermediate layers this_activation = activation this_normalization = normalization @@ -581,11 +551,11 @@ def __init__( if constant_droppath_rate: this_drop_rate = droppath_rate else: - this_drop_rate = DropPath.get_stochastic_drop_rate(droppath_rate, ii, depth) + this_drop_rate = DropPath.get_stochastic_drop_rate(droppath_rate, ii, self.depth) # Add a fully-connected layer fully_connected.append( - FCLayer( + fc_layer( all_dims[ii], all_dims[ii + 1], activation=this_activation, @@ -593,6 +563,7 @@ def __init__( dropout=this_dropout, is_readout_layer=is_readout_layer, droppath_rate=this_drop_rate, + **self.fc_layer_kwargs, ) ) diff --git a/graphium/nn/ensemble_layers.py b/graphium/nn/ensemble_layers.py new file mode 100644 index 000000000..41c9a4018 --- /dev/null +++ b/graphium/nn/ensemble_layers.py @@ -0,0 +1,445 @@ +from typing import Union, Callable, Optional, Type, Tuple, Iterable +from copy import deepcopy +from loguru import logger + + +import torch +import torch.nn as nn +import mup.init as mupi +from mup import set_base_shapes + +from graphium.nn.base_layers import FCLayer, MLP + + +class EnsembleLinear(nn.Module): + def __init__( + self, + in_dim: int, + out_dim: int, + num_ensemble: int, + bias: bool = True, + init_fn: Optional[Callable] = None, + ): + r""" + Multiple linear layers that are applied in parallel with batched matrix multiplication with `torch.matmul`. + + Parameters: + in_dim: + Input dimension of the linear layers + out_dim: + Output dimension of the linear layers. + num_ensemble: + Number of linear layers in the ensemble. + + + """ + super(EnsembleLinear, self).__init__() + + # Initialize weight and bias as learnable parameters + self.weight = nn.Parameter(torch.Tensor(num_ensemble, out_dim, in_dim)) + if bias: + self.bias = nn.Parameter(torch.Tensor(num_ensemble, 1, out_dim)) + else: + self.register_parameter("bias", None) + + # Initialize parameters + self.init_fn = init_fn if init_fn is not None else mupi.xavier_uniform_ + self.reset_parameters() + + def reset_parameters(self): + """ + Reset the parameters of the linear layer using the `init_fn`. + """ + set_base_shapes(self, None, rescale_params=False) # Set the shapes of the tensors, useful for mup + # Initialize weight using the provided initialization function + self.init_fn(self.weight) + + # Initialize bias if present + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, h: torch.Tensor) -> torch.Tensor: + r""" + Apply the batched linear transformation on the input features. + + Parameters: + h: `torch.Tensor[B, Din]` or `torch.Tensor[..., 1, B, Din]` or `torch.Tensor[..., L, B, Din]`: + Input feature tensor, before the batched linear transformation. + `Din` is the number of input features, `B` is the batch size, and `L` is the number of linear layers. + + Returns: + `torch.Tensor[..., L, B, Dout]`: + Output feature tensor, after the batched linear transformation. 
+ `Dout` is the number of output features, , `B` is the batch size, and `L` is the number of linear layers. + """ + + # Perform the linear transformation using torch.matmul + h = torch.matmul(self.weight, h.transpose(-1, -2)).transpose(-1, -2) + + # Add bias if present + if self.bias is not None: + h += self.bias + + return h + + +class EnsembleFCLayer(FCLayer): + def __init__( + self, + in_dim: int, + out_dim: int, + num_ensemble: int, + activation: Union[str, Callable] = "relu", + dropout: float = 0.0, + normalization: Union[str, Callable] = "none", + bias: bool = True, + init_fn: Optional[Callable] = None, + is_readout_layer: bool = False, + droppath_rate: float = 0.0, + ): + r""" + Multiple fully connected layers running in parallel. + This layer is centered around a `torch.nn.Linear` module. + The order in which transformations are applied is: + + - Dense Layer + - Activation + - Dropout (if applicable) + - Batch Normalization (if applicable) + + Parameters: + in_dim: + Input dimension of the layer (the `torch.nn.Linear`) + out_dim: + Output dimension of the layer. + num_ensemble: + Number of linear layers in the ensemble. + dropout: + The ratio of units to dropout. No dropout by default. + activation: + Activation function to use. + normalization: + Normalization to use. Choices: + + - "none" or `None`: No normalization + - "batch_norm": Batch normalization + - "layer_norm": Layer normalization + - `Callable`: Any callable function + bias: + Whether to enable bias in for the linear layer. + init_fn: + Initialization function to use for the weight of the layer. Default is + $$\mathcal{U}(-\sqrt{k}, \sqrt{k})$$ with $$k=\frac{1}{ \text{in_dim}}$$ + is_readout_layer: Whether the layer should be treated as a readout layer by replacing of `torch.nn.Linear` + by `mup.MuReadout` from the muTransfer method https://github.com/microsoft/mup + + droppath_rate: + stochastic depth drop rate, between 0 and 1, see https://arxiv.org/abs/1603.09382 + Attributes: + dropout (int): + The ratio of units to dropout. + normalization (None or Callable): + Normalization layer + linear (`torch.nn.Linear`): + The linear layer + activation (`torch.nn.Module`): + The activation layer + init_fn (Callable): + Initialization function used for the weight of the layer + in_dim (int): + Input dimension of the linear layer + out_dim (int): + Output dimension of the linear layer + """ + + super().__init__( + in_dim=in_dim, + out_dim=out_dim, + activation=activation, + dropout=dropout, + normalization=normalization, + bias=bias, + init_fn=init_fn, + is_readout_layer=is_readout_layer, + droppath_rate=droppath_rate, + ) + + # Linear layer, or MuReadout layer + if not is_readout_layer: + self.linear = EnsembleLinear( + in_dim, out_dim, num_ensemble=num_ensemble, bias=bias, init_fn=init_fn + ) + else: + self.linear = EnsembleMuReadoutGraphium(in_dim, out_dim, num_ensemble=num_ensemble, bias=bias) + + self.reset_parameters() + + def reset_parameters(self, init_fn=None): + """ + Reset the parameters of the linear layer using the `init_fn`. + """ + set_base_shapes(self, None, rescale_params=False) # Set the shapes of the tensors, useful for mup + self.linear.reset_parameters() + + def __repr__(self): + rep = super().__repr__() + rep = rep[:-1] + f", num_ensemble={self.linear.weight.shape[0]})" + return rep + + +class EnsembleMuReadoutGraphium(EnsembleLinear): + """ + This layer implements an ensemble version of μP with a 1/width multiplier and a + constant variance initialization for both weights and biases. 
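+ `base_width` defaults to `in_dim`, and the forward pass rescales its input by `output_mult / width_mult()` before the ensembled linear transformation.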
+ """ + + def __init__( + self, + in_dim: int, + out_dim: int, + num_ensemble: int, + bias: bool = True, + init_fn: Optional[Callable] = None, + readout_zero_init=False, + output_mult=1.0, + ): + self.in_dim = in_dim + self.output_mult = output_mult + self.readout_zero_init = readout_zero_init + self._base_width = in_dim + super().__init__( + in_dim=in_dim, + out_dim=out_dim, + num_ensemble=num_ensemble, + bias=bias, + init_fn=init_fn, + ) + + def reset_parameters(self) -> None: + if self.readout_zero_init: + self.weight.data[:] = 0 + if self.bias is not None: + self.bias.data[:] = 0 + else: + super().reset_parameters() + + def width_mult(self): + assert hasattr(self.weight, "infshape"), ( + "Please call set_base_shapes(...). If using torch.nn.DataParallel, " + "switch to distributed training with " + "torch.nn.parallel.DistributedDataParallel instead" + ) + return self.weight.infshape.width_mult() + + def _rescale_parameters(self): + """Rescale parameters to convert SP initialization to μP initialization. + + Warning: This method is NOT idempotent and should be called only once + unless you know what you are doing. + """ + if hasattr(self, "_has_rescaled_params") and self._has_rescaled_params: + raise RuntimeError( + "`_rescale_parameters` has been called once before already. " + "Unless you know what you are doing, usually you should not be calling `_rescale_parameters` more than once.\n" + "If you called `set_base_shapes` on a model loaded from a checkpoint, " + "or just want to re-set the base shapes of an existing model, " + "make sure to set the flag `rescale_params=False`.\n" + "To bypass this error and *still rescale parameters*, set `self._has_rescaled_params=False` before this call." + ) + if self.bias is not None: + self.bias.data *= self.width_mult() ** 0.5 + self.weight.data *= self.width_mult() ** 0.5 + self._has_rescaled_params = True + + def forward(self, x): + return super().forward(self.output_mult * x / self.width_mult()) + + @property + def absolute_width(self): + return float(self.in_dim) + + @property + def base_width(self): + return self._base_width + + @base_width.setter + def base_width(self, val): + if val is None: + return + assert isinstance( + val, (int, torch.int, torch.long) + ), f"`base_width` must be None, int or long, provided {val} of type {type(val)}" + self._base_width = val + + def width_mult(self): + return self.absolute_width / self.base_width + + +class EnsembleMLP(MLP): + def __init__( + self, + in_dim: int, + hidden_dims: Union[Iterable[int], int], + out_dim: int, + num_ensemble: int, + depth: Optional[int] = None, + reduction: Optional[Union[str, Callable]] = "none", + activation: Union[str, Callable] = "relu", + last_activation: Union[str, Callable] = "none", + dropout: float = 0.0, + last_dropout: float = 0.0, + normalization: Union[Type[None], str, Callable] = "none", + last_normalization: Union[Type[None], str, Callable] = "none", + first_normalization: Union[Type[None], str, Callable] = "none", + last_layer_is_readout: bool = False, + droppath_rate: float = 0.0, + constant_droppath_rate: bool = True, + ): + r""" + Simple multi-layer perceptron, built of a series of FCLayers + + Parameters: + in_dim: + Input dimension of the MLP + hidden_dims: + Either an integer specifying all the hidden dimensions, + or a list of dimensions in the hidden layers. + out_dim: + Output dimension of the MLP. + num_ensemble: + Number of MLPs that run in parallel. + depth: + If `hidden_dims` is an integer, `depth` is 1 + the number of + hidden layers to use. 
+ If `hidden_dims` is a list, then + `depth` must be `None` or equal to `len(hidden_dims) + 1` + reduction: + Reduction to use at the end of the MLP. Choices: + + - "none" or `None`: No reduction + - "mean": Mean reduction + - "sum": Sum reduction + - "max": Max reduction + - "min": Min reduction + - "median": Median reduction + - `Callable`: Any callable function. Must take `dim` as a keyword argument. + activation: + Activation function to use in all the layers except the last. + if `layers==1`, this parameter is ignored + last_activation: + Activation function to use in the last layer. + dropout: + The ratio of units to dropout. Must be between 0 and 1 + normalization: + Normalization to use. Choices: + + - "none" or `None`: No normalization + - "batch_norm": Batch normalization + - "layer_norm": Layer normalization in the hidden layers. + - `Callable`: Any callable function + + if `layers==1`, this parameter is ignored + last_normalization: + Norrmalization to use **after the last layer**. Same options as `normalization`. + first_normalization: + Norrmalization to use in **before the first layer**. Same options as `normalization`. + last_dropout: + The ratio of units to dropout at the last layer. + last_layer_is_readout: Whether the last layer should be treated as a readout layer. + Allows to use the `mup.MuReadout` from the muTransfer method https://github.com/microsoft/mup + droppath_rate: + stochastic depth drop rate, between 0 and 1. + See https://arxiv.org/abs/1603.09382 + constant_droppath_rate: + If `True`, drop rates will remain constant accross layers. + Otherwise, drop rates will vary stochastically. + See `DropPath.get_stochastic_drop_rate` + """ + + super().__init__( + in_dim=in_dim, + hidden_dims=hidden_dims, + out_dim=out_dim, + depth=depth, + activation=activation, + last_activation=last_activation, + dropout=dropout, + last_dropout=last_dropout, + normalization=normalization, + last_normalization=last_normalization, + first_normalization=first_normalization, + last_layer_is_readout=last_layer_is_readout, + droppath_rate=droppath_rate, + constant_droppath_rate=constant_droppath_rate, + fc_layer=EnsembleFCLayer, + fc_layer_kwargs={"num_ensemble": num_ensemble}, + ) + + self.reduction_fn = self._parse_reduction(reduction) + + def _parse_reduction(self, reduction: Optional[Union[str, Callable]]) -> Optional[Callable]: + r""" + Parse the reduction argument. + """ + + if isinstance(reduction, str): + reduction = reduction.lower() + if reduction is None or reduction == "none": + return None + elif reduction == "mean": + return torch.mean + elif reduction == "sum": + return torch.sum + elif reduction == "max": + + def max_vals(x, dim): + return torch.max(x, dim=dim).values + + return max_vals + elif reduction == "min": + + def min_vals(x, dim): + return torch.min(x, dim=dim).values + + return min_vals + elif reduction == "median": + + def median_vals(x, dim): + return torch.median(x, dim=dim).values + + return median_vals + elif callable(reduction): + return reduction + else: + raise ValueError(f"Unknown reduction {reduction}") + + def forward(self, h: torch.Tensor) -> torch.Tensor: + r""" + Apply the ensemble MLP on the input features, then reduce the output if specified. + + Parameters: + + h: `torch.Tensor[B, Din]` or `torch.Tensor[..., 1, B, Din]` or `torch.Tensor[..., L, B, Din]`: + + Input feature tensor, before the MLP. + `Din` is the number of input features, `B` is the batch size, and `L` is the number of ensembles. 
+ + Returns: + + `torch.Tensor[..., L, B, Dout]` or `torch.Tensor[..., B, Dout]`: + + Output feature tensor, after the MLP. + `Dout` is the number of output features, `B` is the batch size, and `L` is the number of ensembles. + `L` is removed if a reduction is specified. + """ + h = super().forward(h) + if self.reduction_fn is not None: + h = self.reduction_fn(h, dim=-3) + return h + + def __repr__(self): + r""" + Controls how the class is printed + """ + rep = super().__repr__() + rep = rep[:-1] + f", num_ensemble={self.layers[0].linear.weight.shape[0]})" diff --git a/graphium/utils/spaces.py b/graphium/utils/spaces.py index d821223a4..e18cd2302 100644 --- a/graphium/utils/spaces.py +++ b/graphium/utils/spaces.py @@ -4,7 +4,8 @@ import torchmetrics.functional as TorchMetrics import graphium.nn.base_layers as BaseLayers -from graphium.nn.architectures import FeedForwardNN, FeedForwardPyg, TaskHeads +import graphium.nn.ensemble_layers as EnsembleLayers +import graphium.nn.architectures as Architectures import graphium.utils.custom_lr as CustomLR import graphium.data.datamodule as Datamodules import graphium.ipu.ipu_losses as IPULosses @@ -27,6 +28,10 @@ "fc": BaseLayers.FCLayer, } +ENSEMBLE_FC_LAYERS_DICT = { + "ens-fc": EnsembleLayers.EnsembleFCLayer, +} + PYG_LAYERS_DICT = { "pyg:gcn": PygLayers.GCNConvPyg, "pyg:gin": PygLayers.GINConvPyg, @@ -41,6 +46,7 @@ LAYERS_DICT = deepcopy(FC_LAYERS_DICT) LAYERS_DICT.update(deepcopy(PYG_LAYERS_DICT)) +ENSEMBLE_LAYERS_DICT = deepcopy(ENSEMBLE_FC_LAYERS_DICT) RESIDUALS_DICT = { "none": Residuals.ResidualConnectionNone, @@ -132,4 +138,9 @@ "dummy-pretrained-model": "tests/dummy-pretrained-model.ckpt", # dummy model used for testing purposes } -FINETUNING_HEADS_DICT = {"mlp": FeedForwardNN, "gnn": FeedForwardPyg, "task_head": TaskHeads} +FINETUNING_HEADS_DICT = { + "mlp": Architectures.FeedForwardNN, + "gnn": Architectures.FeedForwardPyg, + "task_head": Architectures.TaskHeads, + "ens-mlp": Architectures.EnsembleFeedForwardNN, +} diff --git a/tests/test_ensemble_layers.py b/tests/test_ensemble_layers.py new file mode 100644 index 000000000..e43b14ac1 --- /dev/null +++ b/tests/test_ensemble_layers.py @@ -0,0 +1,584 @@ +""" +Unit tests for the different layers of graphium/nn/ensemble_layers +""" + +import numpy as np +import torch +from torch.nn import Linear +import unittest as ut + +from graphium.nn.base_layers import FCLayer, MLP, MuReadoutGraphium +from graphium.nn.ensemble_layers import ( + EnsembleLinear, + EnsembleFCLayer, + EnsembleMLP, + EnsembleMuReadoutGraphium, +) +from graphium.nn.architectures import FeedForwardNN, EnsembleFeedForwardNN + + +class test_Ensemble_Layers(ut.TestCase): + def check_ensemble_linear( + self, + in_dim: int, + out_dim: int, + num_ensemble: int, + batch_size: int, + more_batch_dim: int, + use_mureadout=False, + ): + msg = f"Testing EnsembleLinear with in_dim={in_dim}, out_dim={out_dim}, num_ensemble={num_ensemble}, batch_size={batch_size}, more_batch_dim={more_batch_dim}" + + if use_mureadout: + # Create EnsembleMuReadoutGraphium instance + ensemble_linear = EnsembleMuReadoutGraphium(in_dim, out_dim, num_ensemble) + # Create equivalent separate Linear layers with synchronized weights and biases + linear_layers = [MuReadoutGraphium(in_dim, out_dim) for _ in range(num_ensemble)] + else: + # Create EnsembleLinear instance + ensemble_linear = EnsembleLinear(in_dim, out_dim, num_ensemble) + # Create equivalent separate Linear layers with synchronized weights and biases + linear_layers = [Linear(in_dim, out_dim) for _ in 
range(num_ensemble)] + + for i, linear_layer in enumerate(linear_layers): + linear_layer.weight.data = ensemble_linear.weight.data[i] + if ensemble_linear.bias is not None: + linear_layer.bias.data = ensemble_linear.bias.data[i].squeeze() + + # Test with a sample input + input_tensor = torch.randn(batch_size, in_dim) + ensemble_output = ensemble_linear(input_tensor) + + # Check for the output shape + self.assertEqual(ensemble_output.shape, (num_ensemble, batch_size, out_dim), msg=msg) + + # Make sure that the outputs of the individual layers are the same as the ensemble output + for i, linear_layer in enumerate(linear_layers): + individual_output = linear_layer(input_tensor) + individual_output = individual_output.detach().numpy() + ensemble_output_i = ensemble_output[i].detach().numpy() + np.testing.assert_allclose(ensemble_output_i, individual_output, atol=1e-5, err_msg=msg) + + # Test with a sample input with the extra `num_ensemble` and `more_batch_dim` dimension + if more_batch_dim: + out_shape = (more_batch_dim, num_ensemble, batch_size, out_dim) + input_tensor = torch.randn(more_batch_dim, num_ensemble, batch_size, in_dim) + else: + out_shape = (num_ensemble, batch_size, out_dim) + input_tensor = torch.randn(num_ensemble, batch_size, in_dim) + ensemble_output = ensemble_linear(input_tensor) + + # Check for the output shape + self.assertEqual(ensemble_output.shape, out_shape, msg=msg) + + # Make sure that the outputs of the individual layers are the same as the ensemble output + for i, linear_layer in enumerate(linear_layers): + if more_batch_dim: + individual_output = linear_layer(input_tensor[:, i]) + ensemble_output_i = ensemble_output[:, i] + else: + individual_output = linear_layer(input_tensor[i]) + ensemble_output_i = ensemble_output[i] + individual_output = individual_output.detach().numpy() + ensemble_output_i = ensemble_output_i.detach().numpy() + np.testing.assert_allclose(ensemble_output_i, individual_output, atol=1e-5, err_msg=msg) + + def test_ensemble_linear(self): + # more_batch_dim=0 + self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0) + self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=0) + self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=0) + + # more_batch_dim=1 + self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1) + self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=1) + self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=1) + + # more_batch_dim=7 + self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7) + self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=7) + self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=7) + + def test_ensemble_mureadout_graphium(self): + # Test `use_mureadout` + # more_batch_dim=0 + self.check_ensemble_linear( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0, use_mureadout=True + ) + self.check_ensemble_linear( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=0, use_mureadout=True + ) + self.check_ensemble_linear( + in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=0, use_mureadout=True + ) + + # more_batch_dim=1 + self.check_ensemble_linear( + in_dim=11, out_dim=5, num_ensemble=3, 
batch_size=13, more_batch_dim=1, use_mureadout=True + ) + self.check_ensemble_linear( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=1, use_mureadout=True + ) + self.check_ensemble_linear( + in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=1, use_mureadout=True + ) + + # more_batch_dim=7 + self.check_ensemble_linear( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, use_mureadout=True + ) + self.check_ensemble_linear( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=7, use_mureadout=True + ) + self.check_ensemble_linear( + in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=7, use_mureadout=True + ) + + def check_ensemble_fclayer( + self, + in_dim: int, + out_dim: int, + num_ensemble: int, + batch_size: int, + more_batch_dim: int, + is_readout_layer=False, + ): + msg = f"Testing EnsembleFCLayer with in_dim={in_dim}, out_dim={out_dim}, num_ensemble={num_ensemble}, batch_size={batch_size}, more_batch_dim={more_batch_dim}" + + # Create EnsembleFCLayer instance + ensemble_fclayer = EnsembleFCLayer(in_dim, out_dim, num_ensemble, is_readout_layer=is_readout_layer) + + # Create equivalent separate FCLayer layers with synchronized weights and biases + fc_layers = [FCLayer(in_dim, out_dim, is_readout_layer=is_readout_layer) for _ in range(num_ensemble)] + for i, fc_layer in enumerate(fc_layers): + fc_layer.linear.weight.data = ensemble_fclayer.linear.weight.data[i] + if ensemble_fclayer.bias is not None: + fc_layer.linear.bias.data = ensemble_fclayer.linear.bias.data[i].squeeze() + + # Test with a sample input + input_tensor = torch.randn(batch_size, in_dim) + ensemble_output = ensemble_fclayer(input_tensor) + + # Check for the output shape + self.assertEqual(ensemble_output.shape, (num_ensemble, batch_size, out_dim), msg=msg) + + # Make sure that the outputs of the individual layers are the same as the ensemble output + for i, fc_layer in enumerate(fc_layers): + individual_output = fc_layer(input_tensor) + individual_output = individual_output.detach().numpy() + ensemble_output_i = ensemble_output[i].detach().numpy() + np.testing.assert_allclose(ensemble_output_i, individual_output, atol=1e-5, err_msg=msg) + + # Test with a sample input with the extra `num_ensemble` and `more_batch_dim` dimension + if more_batch_dim: + out_shape = (more_batch_dim, num_ensemble, batch_size, out_dim) + input_tensor = torch.randn(more_batch_dim, num_ensemble, batch_size, in_dim) + else: + out_shape = (num_ensemble, batch_size, out_dim) + input_tensor = torch.randn(num_ensemble, batch_size, in_dim) + ensemble_output = ensemble_fclayer(input_tensor) + + # Check for the output shape + self.assertEqual(ensemble_output.shape, out_shape, msg=msg) + + # Make sure that the outputs of the individual layers are the same as the ensemble output + for i, fc_layer in enumerate(fc_layers): + if more_batch_dim: + individual_output = fc_layer(input_tensor[:, i]) + ensemble_output_i = ensemble_output[:, i] + else: + individual_output = fc_layer(input_tensor[i]) + ensemble_output_i = ensemble_output[i] + individual_output = individual_output.detach().numpy() + ensemble_output_i = ensemble_output_i.detach().numpy() + np.testing.assert_allclose(ensemble_output_i, individual_output, atol=1e-5, err_msg=msg) + + def test_ensemble_fclayer(self): + # more_batch_dim=0 + self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0) + self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=3, 
batch_size=1, more_batch_dim=0) + self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=0) + + # more_batch_dim=1 + self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1) + self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=1) + self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=1) + + # more_batch_dim=7 + self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7) + self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=7) + self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=7) + + # Test `is_readout_layer` + self.check_ensemble_fclayer( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0, is_readout_layer=True + ) + self.check_ensemble_fclayer( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1, is_readout_layer=True + ) + self.check_ensemble_fclayer( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, is_readout_layer=True + ) + + def check_ensemble_mlp( + self, + in_dim: int, + out_dim: int, + num_ensemble: int, + batch_size: int, + more_batch_dim: int, + last_layer_is_readout=False, + ): + msg = f"Testing EnsembleMLP with in_dim={in_dim}, out_dim={out_dim}, num_ensemble={num_ensemble}, batch_size={batch_size}, more_batch_dim={more_batch_dim}" + + # Create EnsembleMLP instance + hidden_dims = [17, 17, 17] + ensemble_mlp = EnsembleMLP( + in_dim, hidden_dims, out_dim, num_ensemble, last_layer_is_readout=last_layer_is_readout + ) + + # Create equivalent separate MLP layers with synchronized weights and biases + mlps = [ + MLP(in_dim, hidden_dims, out_dim, last_layer_is_readout=last_layer_is_readout) + for _ in range(num_ensemble) + ] + for i, mlp in enumerate(mlps): + for j, layer in enumerate(mlp.fully_connected): + layer.linear.weight.data = ensemble_mlp.fully_connected[j].linear.weight.data[i] + if layer.bias is not None: + layer.linear.bias.data = ensemble_mlp.fully_connected[j].linear.bias.data[i].squeeze() + + # Test with a sample input + input_tensor = torch.randn(batch_size, in_dim) + ensemble_output = ensemble_mlp(input_tensor) + + # Check for the output shape + self.assertEqual(ensemble_output.shape, (num_ensemble, batch_size, out_dim), msg=msg) + + # Make sure that the outputs of the individual layers are the same as the ensemble output + for i, mlp in enumerate(mlps): + individual_output = mlp(input_tensor) + individual_output = individual_output.detach().numpy() + ensemble_output_i = ensemble_output[i].detach().numpy() + np.testing.assert_allclose(ensemble_output_i, individual_output, atol=1e-5, err_msg=msg) + + # Test with a sample input with the extra `num_ensemble` and `more_batch_dim` dimension + if more_batch_dim: + out_shape = (more_batch_dim, num_ensemble, batch_size, out_dim) + input_tensor = torch.randn(more_batch_dim, num_ensemble, batch_size, in_dim) + else: + out_shape = (num_ensemble, batch_size, out_dim) + input_tensor = torch.randn(num_ensemble, batch_size, in_dim) + ensemble_output = ensemble_mlp(input_tensor) + + # Check for the output shape + self.assertEqual(ensemble_output.shape, out_shape, msg=msg) + + # Make sure that the outputs of the individual layers are the same as the ensemble output + for i, mlp in enumerate(mlps): + if more_batch_dim: + individual_output = mlp(input_tensor[:, i]) + 
ensemble_output_i = ensemble_output[:, i] + else: + individual_output = mlp(input_tensor[i]) + ensemble_output_i = ensemble_output[i] + individual_output = individual_output.detach().numpy() + ensemble_output_i = ensemble_output_i.detach().numpy() + np.testing.assert_allclose(ensemble_output_i, individual_output, atol=1e-5, err_msg=msg) + + def test_ensemble_mlp(self): + # more_batch_dim=0 + self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0) + self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=0) + self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=0) + + # more_batch_dim=1 + self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1) + self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=1) + self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=1) + + # more_batch_dim=7 + self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7) + self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=7) + self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=7) + + # Test `last_layer_is_readout` + self.check_ensemble_mlp( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0, last_layer_is_readout=True + ) + self.check_ensemble_mlp( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1, last_layer_is_readout=True + ) + self.check_ensemble_mlp( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, last_layer_is_readout=True + ) + + def check_ensemble_feedforwardnn( + self, + in_dim: int, + out_dim: int, + num_ensemble: int, + batch_size: int, + more_batch_dim: int, + last_layer_is_readout=False, + ): + msg = f"Testing EnsembleFeedForwardNN with in_dim={in_dim}, out_dim={out_dim}, num_ensemble={num_ensemble}, batch_size={batch_size}, more_batch_dim={more_batch_dim}" + + # Create EnsembleFeedForwardNN instance + hidden_dims = [17, 17, 17] + ensemble_mlp = EnsembleFeedForwardNN( + in_dim, + out_dim, + hidden_dims, + num_ensemble, + reduction=None, + last_layer_is_readout=last_layer_is_readout, + ) + + # Create equivalent separate MLP layers with synchronized weights and biases + mlps = [ + FeedForwardNN(in_dim, out_dim, hidden_dims, last_layer_is_readout=last_layer_is_readout) + for _ in range(num_ensemble) + ] + for i, mlp in enumerate(mlps): + for j, layer in enumerate(mlp.layers): + layer.linear.weight.data = ensemble_mlp.layers[j].linear.weight.data[i] + if layer.bias is not None: + layer.linear.bias.data = ensemble_mlp.layers[j].linear.bias.data[i].squeeze() + + # Test with a sample input + input_tensor = torch.randn(batch_size, in_dim) + ensemble_output = ensemble_mlp(input_tensor) + + # Check for the output shape + self.assertEqual(ensemble_output.shape, (num_ensemble, batch_size, out_dim), msg=msg) + + # Make sure that the outputs of the individual layers are the same as the ensemble output + individual_outputs = [] + for i, mlp in enumerate(mlps): + individual_outputs.append(mlp(input_tensor)) + individual_outputs = torch.stack(individual_outputs).detach().numpy() + for i, mlp in enumerate(mlps): + ensemble_output_i = ensemble_output[i].detach().numpy() + np.testing.assert_allclose( + ensemble_output_i, individual_outputs[..., i, :, :], atol=1e-5, err_msg=msg + ) + + # Test with a sample input with the 
extra `num_ensemble` and `more_batch_dim` dimension + if more_batch_dim: + out_shape = (more_batch_dim, num_ensemble, batch_size, out_dim) + input_tensor = torch.randn(more_batch_dim, num_ensemble, batch_size, in_dim) + else: + out_shape = (num_ensemble, batch_size, out_dim) + input_tensor = torch.randn(num_ensemble, batch_size, in_dim) + ensemble_output = ensemble_mlp(input_tensor) + + # Check for the output shape + self.assertEqual(ensemble_output.shape, out_shape, msg=msg) + + # Make sure that the outputs of the individual layers are the same as the ensemble output + for i, mlp in enumerate(mlps): + if more_batch_dim: + individual_output = mlp(input_tensor[:, i]) + ensemble_output_i = ensemble_output[:, i] + else: + individual_output = mlp(input_tensor[i]) + ensemble_output_i = ensemble_output[i] + individual_output = individual_output.detach().numpy() + ensemble_output_i = ensemble_output_i.detach().numpy() + np.testing.assert_allclose(ensemble_output_i, individual_output, atol=1e-5, err_msg=msg) + + def check_ensemble_feedforwardnn_mean( + self, + in_dim: int, + out_dim: int, + num_ensemble: int, + batch_size: int, + more_batch_dim: int, + last_layer_is_readout=False, + ): + msg = f"Testing EnsembleFeedForwardNN with in_dim={in_dim}, out_dim={out_dim}, num_ensemble={num_ensemble}, batch_size={batch_size}, more_batch_dim={more_batch_dim}" + + # Create EnsembleFeedForwardNN instance + hidden_dims = [17, 17, 17] + ensemble_mlp = EnsembleFeedForwardNN( + in_dim, + out_dim, + hidden_dims, + num_ensemble, + reduction="mean", + last_layer_is_readout=last_layer_is_readout, + ) + + # Create equivalent separate MLP layers with synchronized weights and biases + mlps = [ + FeedForwardNN(in_dim, out_dim, hidden_dims, last_layer_is_readout=last_layer_is_readout) + for _ in range(num_ensemble) + ] + for i, mlp in enumerate(mlps): + for j, layer in enumerate(mlp.layers): + layer.linear.weight.data = ensemble_mlp.layers[j].linear.weight.data[i] + if layer.bias is not None: + layer.linear.bias.data = ensemble_mlp.layers[j].linear.bias.data[i].squeeze() + + # Test with a sample input + input_tensor = torch.randn(batch_size, in_dim) + ensemble_output = ensemble_mlp(input_tensor) + + # Check for the output shape + self.assertEqual(ensemble_output.shape, (batch_size, out_dim), msg=msg) + + # Make sure that the outputs of the individual layers are the same as the ensemble output + individual_outputs = [] + for i, mlp in enumerate(mlps): + individual_outputs.append(mlp(input_tensor)) + individual_outputs = torch.stack(individual_outputs, dim=-3) + individual_outputs = individual_outputs.mean(dim=-3).detach().numpy() + np.testing.assert_allclose( + ensemble_output.detach().numpy(), individual_outputs, atol=1e-5, err_msg=msg + ) + + # Test with a sample input with the extra `num_ensemble` and `more_batch_dim` dimension + if more_batch_dim: + out_shape = (more_batch_dim, batch_size, out_dim) + input_tensor = torch.randn(more_batch_dim, num_ensemble, batch_size, in_dim) + else: + out_shape = (batch_size, out_dim) + input_tensor = torch.randn(num_ensemble, batch_size, in_dim) + ensemble_output = ensemble_mlp(input_tensor).detach() + + # Check for the output shape + self.assertEqual(ensemble_output.shape, out_shape, msg=msg) + + # Make sure that the outputs of the individual layers are the same as the ensemble output + individual_outputs = [] + for i, mlp in enumerate(mlps): + if more_batch_dim: + individual_outputs.append(mlp(input_tensor[:, i])) + else: + individual_outputs.append(mlp(input_tensor[i])) + 
individual_output = torch.stack(individual_outputs, dim=-3).mean(dim=-3).detach().numpy() + np.testing.assert_allclose(ensemble_output, individual_output, atol=1e-5, err_msg=msg) + + def check_ensemble_feedforwardnn_simple( + self, + in_dim: int, + out_dim: int, + num_ensemble: int, + batch_size: int, + more_batch_dim: int, + last_layer_is_readout=False, + **kwargs, + ): + msg = f"Testing EnsembleFeedForwardNN with in_dim={in_dim}, out_dim={out_dim}, num_ensemble={num_ensemble}, batch_size={batch_size}, more_batch_dim={more_batch_dim}" + + # Create EnsembleFeedForwardNN instance + hidden_dims = [17, 17, 17] + ensemble_mlp = EnsembleFeedForwardNN( + in_dim, + out_dim, + hidden_dims, + num_ensemble, + reduction=None, + last_layer_is_readout=last_layer_is_readout, + **kwargs, + ) + + # Test with a sample input + input_tensor = torch.randn(batch_size, in_dim) + ensemble_output = ensemble_mlp(input_tensor) + + # Check for the output shape + self.assertEqual(ensemble_output.shape, (num_ensemble, batch_size, out_dim), msg=msg) + + def test_ensemble_feedforwardnn(self): + # more_batch_dim=0 + self.check_ensemble_feedforwardnn( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0 + ) + self.check_ensemble_feedforwardnn( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=0 + ) + self.check_ensemble_feedforwardnn( + in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=0 + ) + + # more_batch_dim=1 + self.check_ensemble_feedforwardnn( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1 + ) + self.check_ensemble_feedforwardnn( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=1 + ) + self.check_ensemble_feedforwardnn( + in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=1 + ) + + # more_batch_dim=7 + self.check_ensemble_feedforwardnn( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7 + ) + self.check_ensemble_feedforwardnn( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=7 + ) + self.check_ensemble_feedforwardnn( + in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=7 + ) + + # Test `last_layer_is_readout` + self.check_ensemble_feedforwardnn( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0, last_layer_is_readout=True + ) + self.check_ensemble_feedforwardnn( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1, last_layer_is_readout=True + ) + self.check_ensemble_feedforwardnn( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, last_layer_is_readout=True + ) + + # Test `reduction` + self.check_ensemble_feedforwardnn_mean( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0, last_layer_is_readout=True + ) + self.check_ensemble_feedforwardnn_mean( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1, last_layer_is_readout=True + ) + self.check_ensemble_feedforwardnn_mean( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, last_layer_is_readout=True + ) + + # Test `subset_in_dim` + self.check_ensemble_feedforwardnn_simple( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0, subset_in_dim=0.5 + ) + self.check_ensemble_feedforwardnn_simple( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1, subset_in_dim=0.5 + ) + self.check_ensemble_feedforwardnn_simple( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, subset_in_dim=0.5 + ) + 
self.check_ensemble_feedforwardnn_simple( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0, subset_in_dim=7 + ) + self.check_ensemble_feedforwardnn_simple( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1, subset_in_dim=7 + ) + self.check_ensemble_feedforwardnn_simple( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, subset_in_dim=7 + ) + with self.assertRaises(AssertionError): + self.check_ensemble_feedforwardnn_simple( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0, subset_in_dim=1.5 + ) + with self.assertRaises(AssertionError): + self.check_ensemble_feedforwardnn_simple( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1, subset_in_dim=39 + ) + with self.assertRaises(AssertionError): + self.check_ensemble_feedforwardnn_simple( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, subset_in_dim=39 + ) + + +if __name__ == "__main__": + ut.main()
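A minimal usage sketch of the new ensemble modules follows. It relies only on the input/output shapes asserted in tests/test_ensemble_layers.py above; the concrete dimension values are arbitrary illustration choices.

# Minimal usage sketch for the modules added in this diff; shapes follow the
# expectations encoded in tests/test_ensemble_layers.py.
import torch

from graphium.nn.ensemble_layers import EnsembleLinear
from graphium.nn.architectures import EnsembleFeedForwardNN

num_ensemble, batch_size, in_dim, out_dim = 3, 13, 11, 5

# EnsembleLinear: `num_ensemble` independent linear maps applied with a single
# batched matmul. A [B, Din] input is broadcast to every ensemble member.
lin = EnsembleLinear(in_dim, out_dim, num_ensemble)
h = torch.randn(batch_size, in_dim)
assert lin(h).shape == (num_ensemble, batch_size, out_dim)

# An input that already carries the ensemble axis, [L, B, Din], keeps it.
h_ens = torch.randn(num_ensemble, batch_size, in_dim)
assert lin(h_ens).shape == (num_ensemble, batch_size, out_dim)

# EnsembleFeedForwardNN: reduction=None keeps the ensemble axis ...
mlp = EnsembleFeedForwardNN(
    in_dim, out_dim, hidden_dims=[17, 17, 17], num_ensemble=num_ensemble, reduction=None
)
assert mlp(h).shape == (num_ensemble, batch_size, out_dim)

# ... while reduction="mean" averages the ensemble predictions away.
mlp_mean = EnsembleFeedForwardNN(
    in_dim, out_dim, hidden_dims=[17, 17, 17], num_ensemble=num_ensemble, reduction="mean"
)
assert mlp_mean(h).shape == (batch_size, out_dim)

# subset_in_dim < 1.0 gives each ensemble member its own random subset of the
# input features; the output shape is unchanged.
mlp_subset = EnsembleFeedForwardNN(
    in_dim, out_dim, hidden_dims=[17, 17, 17], num_ensemble=num_ensemble,
    reduction=None, subset_in_dim=0.5
)
assert mlp_subset(h).shape == (num_ensemble, batch_size, out_dim)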