From e26adcd50062db9c5296bfe63a3ce5f0365bc377 Mon Sep 17 00:00:00 2001 From: GuillaumeErhard <25333848+GuillaumeErhard@users.noreply.github.com> Date: Thu, 20 Mar 2025 00:52:24 +0100 Subject: [PATCH 01/21] Deprecate use_batchnorm in favor of generalized use_norm parameter --- segmentation_models_pytorch/base/modules.py | 83 ++++++++++++++++--- .../decoders/linknet/decoder.py | 32 +++++-- .../decoders/linknet/model.py | 26 +++++- .../decoders/manet/decoder.py | 28 +++++-- .../decoders/manet/model.py | 26 +++++- .../decoders/unet/decoder.py | 27 ++++-- .../decoders/unet/model.py | 26 +++++- .../decoders/unetplusplus/decoder.py | 31 +++++-- .../decoders/unetplusplus/model.py | 26 +++++- 9 files changed, 257 insertions(+), 48 deletions(-) diff --git a/segmentation_models_pytorch/base/modules.py b/segmentation_models_pytorch/base/modules.py index cbd643b6..ed3805be 100644 --- a/segmentation_models_pytorch/base/modules.py +++ b/segmentation_models_pytorch/base/modules.py @@ -1,3 +1,5 @@ +import warnings + import torch import torch.nn as nn @@ -16,11 +18,53 @@ def __init__( padding=0, stride=1, use_batchnorm=True, + use_norm="batchnorm", ): - if use_batchnorm == "inplace" and InPlaceABN is None: + if use_batchnorm is not None: + warnings.warn( + "The usage of use_batchnorm is deprecated. Please modify your code for use_norm", + DeprecationWarning, + ) + if use_batchnorm is True: + use_norm = {"type": "batchnorm"} + elif use_batchnorm is False: + use_norm = {"type": "identity"} + elif use_batchnorm == "inplace": + use_norm = { + "type": "inplace", + "activation": "leaky_relu", + "activation_param": 0.0, + } + else: + raise ValueError("Unrecognized value for use_batchnorm") + + if isinstance(use_norm, str): + norm_str = use_norm.lower() + if norm_str == "inplace": + use_norm = { + "type": "inplace", + "activation": "leaky_relu", + "activation_param": 0.0, + } + elif norm_str in ( + "batchnorm", + "identity", + "layernorm", + "groupnorm", + "instancenorm", + ): + use_norm = {"type": norm_str} + else: + raise ValueError("Unrecognized normalization type string provided") + elif isinstance(use_norm, bool): + use_norm = {"type": "batchnorm" if use_norm else "identity"} + elif not isinstance(use_norm, dict): + raise ValueError("use_norm must be a dictionary, boolean, or string") + + if use_norm["type"] == "inplace" and InPlaceABN is None: raise RuntimeError( - "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. " - + "To install see: https://github.com/mapillary/inplace_abn" + "In order to use `use_batchnorm='inplace'` or `use_norm='inplace'` the inplace_abn package must be installed. 
" + "To install see: https://github.com/mapillary/inplace_abn" ) conv = nn.Conv2d( @@ -29,21 +73,30 @@ def __init__( kernel_size, stride=stride, padding=padding, - bias=not (use_batchnorm), + bias=use_norm["type"] != "inplace", ) relu = nn.ReLU(inplace=True) - if use_batchnorm == "inplace": - bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0) - relu = nn.Identity() - - elif use_batchnorm and use_batchnorm != "inplace": - bn = nn.BatchNorm2d(out_channels) + norm_type = use_norm["type"] + extra_kwargs = {k: v for k, v in use_norm.items() if k != "type"} + if norm_type == "inplace": + norm = InPlaceABN(out_channels, **extra_kwargs) + relu = nn.Identity() + elif norm_type == "batchnorm": + norm = nn.BatchNorm2d(out_channels, **extra_kwargs) + elif norm_type == "identity": + norm = nn.Identity() + elif norm_type == "layernorm": + norm = nn.LayerNorm(out_channels, **extra_kwargs) + elif norm_type == "groupnorm": + norm = nn.GroupNorm(out_channels, **extra_kwargs) + elif norm_type == "instancenorm": + norm = nn.InstanceNorm2d(out_channels, **extra_kwargs) else: - bn = nn.Identity() + raise ValueError(f"Unrecognized normalization type: {norm_type}") - super(Conv2dReLU, self).__init__(conv, bn, relu) + super(Conv2dReLU, self).__init__(conv, norm, relu) class SCSEModule(nn.Module): @@ -127,3 +180,9 @@ def __init__(self, name, **params): def forward(self, x): return self.attention(x) + + +if __name__ == "__main__": + print(Conv2dReLU(3, 12, 4)) + print(Conv2dReLU(3, 12, 4, use_norm={"type": "batchnorm"})) + print(Conv2dReLU(3, 12, 4, use_norm={"type": "layernorm", "eps": 1e-3})) diff --git a/segmentation_models_pytorch/decoders/linknet/decoder.py b/segmentation_models_pytorch/decoders/linknet/decoder.py index 8dfd8434..128ec2dc 100644 --- a/segmentation_models_pytorch/decoders/linknet/decoder.py +++ b/segmentation_models_pytorch/decoders/linknet/decoder.py @@ -1,12 +1,18 @@ import torch import torch.nn as nn -from typing import List, Optional +from typing import Any, Dict, List, Optional, Union from segmentation_models_pytorch.base import modules class TransposeX2(nn.Sequential): - def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True): + def __init__( + self, + in_channels: int, + out_channels: int, + use_batchnorm: Union[bool, str, None] = True, + use_norm: Union[bool, str, Dict[str, Any]] = True, + ): super().__init__() layers = [ nn.ConvTranspose2d( @@ -15,14 +21,20 @@ def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = Tr nn.ReLU(inplace=True), ] - if use_batchnorm: + if use_batchnorm or use_norm: layers.insert(1, nn.BatchNorm2d(out_channels)) super().__init__(*layers) class DecoderBlock(nn.Module): - def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True): + def __init__( + self, + in_channels: int, + out_channels: int, + use_batchnorm: Union[bool, str, None] = True, + use_norm: Union[bool, str, Dict[str, Any]] = True, + ): super().__init__() self.block = nn.Sequential( @@ -31,6 +43,7 @@ def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = Tr in_channels // 4, kernel_size=1, use_batchnorm=use_batchnorm, + use_norm=use_norm, ), TransposeX2( in_channels // 4, in_channels // 4, use_batchnorm=use_batchnorm @@ -40,6 +53,7 @@ def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = Tr out_channels, kernel_size=1, use_batchnorm=use_batchnorm, + use_norm=use_norm, ), ) @@ -58,7 +72,8 @@ def __init__( encoder_channels: List[int], prefinal_channels: 
int = 32,
         n_blocks: int = 5,
-        use_batchnorm: bool = True,
+        use_batchnorm: Union[bool, str, None] = True,
+        use_norm: Union[bool, str, Dict[str, Any]] = True,
     ):
         super().__init__()
@@ -71,7 +86,12 @@ def __init__(
 
         self.blocks = nn.ModuleList(
             [
-                DecoderBlock(channels[i], channels[i + 1], use_batchnorm=use_batchnorm)
+                DecoderBlock(
+                    channels[i],
+                    channels[i + 1],
+                    use_batchnorm=use_batchnorm,
+                    use_norm=use_norm,
+                )
                 for i in range(n_blocks)
             ]
         )
diff --git a/segmentation_models_pytorch/decoders/linknet/model.py b/segmentation_models_pytorch/decoders/linknet/model.py
index 356468ed..0dca7e56 100644
--- a/segmentation_models_pytorch/decoders/linknet/model.py
+++ b/segmentation_models_pytorch/decoders/linknet/model.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Union
+from typing import Any, Dict, Optional, Union
 
 from segmentation_models_pytorch.base import (
     ClassificationHead,
@@ -29,9 +29,27 @@ class Linknet(SegmentationModel):
            Default is 5
        encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
            other pretrained weights (see table with available weights for each encoder_name)
-       decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers
+       decoder_use_batchnorm: (**Deprecated**) If **True**, BatchNorm2d layer between Conv2D and Activation layers
            is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption.
            Available options are **True, False, "inplace"**
+
+           **Note:** Deprecated, prefer using `decoder_use_norm` and set this to None.
+       decoder_use_norm: Specifies normalization between Conv2D and activation.
+           Accepts the following types:
+           - **True**: Defaults to `"batchnorm"`.
+           - **False**: No normalization (`nn.Identity`).
+           - **str**: Specifies normalization type using default parameters. Available values:
+             `"batchnorm"`, `"identity"`, `"layernorm"`, `"groupnorm"`, `"instancenorm"`, `"inplace"`.
+           - **dict**: Fully customizable normalization settings. Structure:
+             ```python
+             {"type": <norm_name>, **kwargs}
+             ```
+             where `norm_name` corresponds to normalization type (see above), and `kwargs` are passed directly to the normalization layer as defined in PyTorch documentation.
+
+           **Example**:
+           ```python
+           use_norm={"type": "groupnorm", "num_groups": 8}
+           ```
        in_channels: A number of input channels for the model, default is 3 (RGB images)
        classes: A number of classes for output mask (or you can think as a number of channels of output mask)
        activation: An activation function to apply after the final convolution layer.
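For orientation, this is how the three accepted forms of the new parameter look from user code. A minimal sketch against this patch: the encoder name and keyword values are illustrative, `decoder_use_batchnorm=None` follows the deprecation note above, and `smp` is the usual import alias for the package.

```python
import segmentation_models_pytorch as smp

# String form: pick a normalization layer by name, using its default parameters.
model = smp.Linknet("resnet34", decoder_use_batchnorm=None, decoder_use_norm="instancenorm")

# Boolean form: True maps to {"type": "batchnorm"}, False to {"type": "identity"}.
model = smp.Linknet("resnet34", decoder_use_batchnorm=None, decoder_use_norm=False)

# Dict form: "type" selects the layer, every other key is forwarded
# to the layer's constructor unchanged.
model = smp.Linknet(
    "resnet34",
    decoder_use_batchnorm=None,
    decoder_use_norm={"type": "batchnorm", "momentum": 0.05},
)
```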
@@ -60,7 +78,8 @@ def __init__( encoder_name: str = "resnet34", encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", - decoder_use_batchnorm: bool = True, + decoder_use_batchnorm: Union[bool, str, None] = True, + decoder_use_norm: Union[bool, str, Dict[str, Any]] = True, in_channels: int = 3, classes: int = 1, activation: Optional[Union[str, callable]] = None, @@ -87,6 +106,7 @@ def __init__( n_blocks=encoder_depth, prefinal_channels=32, use_batchnorm=decoder_use_batchnorm, + use_norm=decoder_use_norm, ) self.segmentation_head = SegmentationHead( diff --git a/segmentation_models_pytorch/decoders/manet/decoder.py b/segmentation_models_pytorch/decoders/manet/decoder.py index 61b1fe57..49891fe3 100644 --- a/segmentation_models_pytorch/decoders/manet/decoder.py +++ b/segmentation_models_pytorch/decoders/manet/decoder.py @@ -1,9 +1,9 @@ +from typing import Any, Dict, List, Optional, Union + import torch import torch.nn as nn import torch.nn.functional as F -from typing import List, Optional - from segmentation_models_pytorch.base import modules as md @@ -49,7 +49,8 @@ def __init__( in_channels: int, skip_channels: int, out_channels: int, - use_batchnorm: bool = True, + use_batchnorm: Union[bool, str, None] = True, + use_norm: Union[bool, str, Dict[str, Any]] = True, reduction: int = 16, ): # MFABBlock is just a modified version of SE-blocks, one for skip, one for input @@ -61,9 +62,14 @@ def __init__( kernel_size=3, padding=1, use_batchnorm=use_batchnorm, + use_norm=use_norm, ), md.Conv2dReLU( - in_channels, skip_channels, kernel_size=1, use_batchnorm=use_batchnorm + in_channels, + skip_channels, + kernel_size=1, + use_batchnorm=use_batchnorm, + use_norm=use_norm, ), ) reduced_channels = max(1, skip_channels // reduction) @@ -88,6 +94,7 @@ def __init__( kernel_size=3, padding=1, use_batchnorm=use_batchnorm, + use_norm=use_norm, ) self.conv2 = md.Conv2dReLU( out_channels, @@ -95,6 +102,7 @@ def __init__( kernel_size=3, padding=1, use_batchnorm=use_batchnorm, + use_norm=use_norm, ) def forward( @@ -119,7 +127,8 @@ def __init__( in_channels: int, skip_channels: int, out_channels: int, - use_batchnorm: bool = True, + use_batchnorm: Union[bool, str, None] = True, + use_norm: Union[bool, str, Dict[str, Any]] = True, ): super().__init__() self.conv1 = md.Conv2dReLU( @@ -128,6 +137,7 @@ def __init__( kernel_size=3, padding=1, use_batchnorm=use_batchnorm, + use_norm=use_norm, ) self.conv2 = md.Conv2dReLU( out_channels, @@ -135,6 +145,7 @@ def __init__( kernel_size=3, padding=1, use_batchnorm=use_batchnorm, + use_norm=use_norm, ) def forward( @@ -155,7 +166,8 @@ def __init__( decoder_channels: List[int], n_blocks: int = 5, reduction: int = 16, - use_batchnorm: bool = True, + use_batchnorm: Union[bool, str, None] = True, + use_norm: Union[bool, str, Dict[str, Any]] = True, pab_channels: int = 64, ): super().__init__() @@ -182,7 +194,9 @@ def __init__( self.center = PABBlock(head_channels, pab_channels=pab_channels) # combine decoder keyword arguments - kwargs = dict(use_batchnorm=use_batchnorm) # no attention type here + kwargs = dict( + use_batchnorm=use_batchnorm, use_norm=use_norm + ) # no attention type here blocks = [ MFABBlock(in_ch, skip_ch, out_ch, reduction=reduction, **kwargs) if skip_ch > 0 diff --git a/segmentation_models_pytorch/decoders/manet/model.py b/segmentation_models_pytorch/decoders/manet/model.py index 6ed59207..5405f9d7 100644 --- a/segmentation_models_pytorch/decoders/manet/model.py +++ b/segmentation_models_pytorch/decoders/manet/model.py @@ -1,4 +1,4 @@ -from 
typing import Any, List, Optional, Union +from typing import Any, Dict, List, Optional, Union from segmentation_models_pytorch.base import ( ClassificationHead, @@ -29,9 +29,27 @@ class MAnet(SegmentationModel): other pretrained weights (see table with available weights for each encoder_name) decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder. Length of the list should be the same as **encoder_depth** - decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers + decoder_use_batchnorm: (**Deprecated**) If **True**, BatchNorm2d layer between Conv2D and Activation layers is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption. Available options are **True, False, "inplace"** + + **Note:** Deprecated, prefer using `decoder_use_norm` and set this to None. + decoder_use_norm: Specifies normalization between Conv2D and activation. + Accepts the following types: + - **True**: Defaults to `"batchnorm"`. + - **False**: No normalization (`nn.Identity`). + - **str**: Specifies normalization type using default parameters. Available values: + `"batchnorm"`, `"identity"`, `"layernorm"`, `"groupnorm"`, `"instancenorm"`, `"inplace"`. + - **dict**: Fully customizable normalization settings. Structure: + ```python + {"type": , **kwargs} + ``` + where `norm_name` corresponds to normalization type (see above), and `kwargs` are passed directly to the normalization layer as defined in PyTorch documentation. + + **Example**: + ```python + use_norm={"type": "groupnorm", "num_groups": 8} + ``` decoder_pab_channels: A number of channels for PAB module in decoder. Default is 64. in_channels: A number of input channels for the model, default is 3 (RGB images) @@ -63,7 +81,8 @@ def __init__( encoder_name: str = "resnet34", encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", - decoder_use_batchnorm: bool = True, + decoder_use_batchnorm: Union[bool, str, None] = True, + decoder_use_norm: Union[bool, str, Dict[str, Any], None] = True, decoder_channels: List[int] = (256, 128, 64, 32, 16), decoder_pab_channels: int = 64, in_channels: int = 3, @@ -87,6 +106,7 @@ def __init__( decoder_channels=decoder_channels, n_blocks=encoder_depth, use_batchnorm=decoder_use_batchnorm, + use_norm=decoder_use_norm, pab_channels=decoder_pab_channels, ) diff --git a/segmentation_models_pytorch/decoders/unet/decoder.py b/segmentation_models_pytorch/decoders/unet/decoder.py index 0e4f35fd..b5fce6df 100644 --- a/segmentation_models_pytorch/decoders/unet/decoder.py +++ b/segmentation_models_pytorch/decoders/unet/decoder.py @@ -1,8 +1,9 @@ +from typing import Any, Dict, List, Optional, Sequence, Union + import torch import torch.nn as nn import torch.nn.functional as F -from typing import Optional, Sequence, List from segmentation_models_pytorch.base import modules as md @@ -14,7 +15,8 @@ def __init__( in_channels: int, skip_channels: int, out_channels: int, - use_batchnorm: bool = True, + use_batchnorm: Union[bool, str, None] = True, + use_norm: Union[bool, str, Dict[str, Any]] = True, attention_type: Optional[str] = None, interpolation_mode: str = "nearest", ): @@ -26,6 +28,7 @@ def __init__( kernel_size=3, padding=1, use_batchnorm=use_batchnorm, + use_norm=use_norm, ) self.attention1 = md.Attention( attention_type, in_channels=in_channels + skip_channels @@ -36,6 +39,7 @@ def __init__( kernel_size=3, padding=1, use_batchnorm=use_batchnorm, + use_norm=use_norm, ) self.attention2 = md.Attention(attention_type, 
in_channels=out_channels) @@ -63,13 +67,20 @@ def forward( class UnetCenterBlock(nn.Sequential): """Center block of the Unet decoder. Applied to the last feature map of the encoder.""" - def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True): + def __init__( + self, + in_channels: int, + out_channels: int, + use_batchnorm: Union[bool, str, None] = True, + use_norm: Union[bool, str, Dict[str, Any]] = True, + ): conv1 = md.Conv2dReLU( in_channels, out_channels, kernel_size=3, padding=1, use_batchnorm=use_batchnorm, + use_norm=use_norm, ) conv2 = md.Conv2dReLU( out_channels, @@ -77,6 +88,7 @@ def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = Tr kernel_size=3, padding=1, use_batchnorm=use_batchnorm, + use_norm=use_norm, ) super().__init__(conv1, conv2) @@ -93,7 +105,8 @@ def __init__( encoder_channels: Sequence[int], decoder_channels: Sequence[int], n_blocks: int = 5, - use_batchnorm: bool = True, + use_batchnorm: Union[bool, str, None] = True, + use_norm: Union[bool, str, Dict[str, Any]] = True, attention_type: Optional[str] = None, add_center_block: bool = False, interpolation_mode: str = "nearest", @@ -120,7 +133,10 @@ def __init__( if add_center_block: self.center = UnetCenterBlock( - head_channels, head_channels, use_batchnorm=use_batchnorm + head_channels, + head_channels, + use_batchnorm=use_batchnorm, + use_norm=use_norm, ) else: self.center = nn.Identity() @@ -135,6 +151,7 @@ def __init__( block_skip_channels, block_out_channels, use_batchnorm=use_batchnorm, + use_norm=use_norm, attention_type=attention_type, interpolation_mode=interpolation_mode, ) diff --git a/segmentation_models_pytorch/decoders/unet/model.py b/segmentation_models_pytorch/decoders/unet/model.py index 4b30527d..1778515e 100644 --- a/segmentation_models_pytorch/decoders/unet/model.py +++ b/segmentation_models_pytorch/decoders/unet/model.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Union, Callable, Sequence +from typing import Any, Dict, Optional, Union, Callable, Sequence from segmentation_models_pytorch.base import ( ClassificationHead, @@ -39,9 +39,27 @@ class Unet(SegmentationModel): other pretrained weights (see table with available weights for each encoder_name) decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder. Length of the list should be the same as **encoder_depth** - decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers + decoder_use_batchnorm: (**Deprecated**) If **True**, BatchNorm2d layer between Conv2D and Activation layers is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption. Available options are **True, False, "inplace"** + + **Note:** Deprecated, prefer using `decoder_use_norm` and set this to None. + decoder_use_norm: Specifies normalization between Conv2D and activation. + Accepts the following types: + - **True**: Defaults to `"batchnorm"`. + - **False**: No normalization (`nn.Identity`). + - **str**: Specifies normalization type using default parameters. Available values: + `"batchnorm"`, `"identity"`, `"layernorm"`, `"groupnorm"`, `"instancenorm"`, `"inplace"`. + - **dict**: Fully customizable normalization settings. Structure: + ```python + {"type": , **kwargs} + ``` + where `norm_name` corresponds to normalization type (see above), and `kwargs` are passed directly to the normalization layer as defined in PyTorch documentation. 
+ + **Example**: + ```python + use_norm={"type": "groupnorm", "num_groups": 8} + ``` decoder_attention_type: Attention module used in decoder of the model. Available options are **None** and **scse** (https://arxiv.org/abs/1808.08127). decoder_interpolation_mode: Interpolation mode used in decoder of the model. Available options are @@ -95,7 +113,8 @@ def __init__( encoder_name: str = "resnet34", encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", - decoder_use_batchnorm: bool = True, + decoder_use_batchnorm: Union[bool, str, None] = True, + decoder_use_norm: Union[bool, str, Dict[str, Any]] = True, decoder_channels: Sequence[int] = (256, 128, 64, 32, 16), decoder_attention_type: Optional[str] = None, decoder_interpolation_mode: str = "nearest", @@ -121,6 +140,7 @@ def __init__( decoder_channels=decoder_channels, n_blocks=encoder_depth, use_batchnorm=decoder_use_batchnorm, + use_norm=decoder_use_norm, add_center_block=add_center_block, attention_type=decoder_attention_type, interpolation_mode=decoder_interpolation_mode, diff --git a/segmentation_models_pytorch/decoders/unetplusplus/decoder.py b/segmentation_models_pytorch/decoders/unetplusplus/decoder.py index 3282849f..34f31627 100644 --- a/segmentation_models_pytorch/decoders/unetplusplus/decoder.py +++ b/segmentation_models_pytorch/decoders/unetplusplus/decoder.py @@ -2,7 +2,7 @@ import torch.nn as nn import torch.nn.functional as F -from typing import Optional, List +from typing import Any, Dict, List, Optional, Union from segmentation_models_pytorch.base import modules as md @@ -13,7 +13,8 @@ def __init__( in_channels: int, skip_channels: int, out_channels: int, - use_batchnorm: bool = True, + use_batchnorm: Union[bool, str, None] = True, + use_norm: Union[bool, str, Dict[str, Any]] = True, attention_type: Optional[str] = None, ): super().__init__() @@ -23,6 +24,7 @@ def __init__( kernel_size=3, padding=1, use_batchnorm=use_batchnorm, + use_norm=use_norm, ) self.attention1 = md.Attention( attention_type, in_channels=in_channels + skip_channels @@ -33,6 +35,7 @@ def __init__( kernel_size=3, padding=1, use_batchnorm=use_batchnorm, + use_norm=use_norm, ) self.attention2 = md.Attention(attention_type, in_channels=out_channels) @@ -50,13 +53,20 @@ def forward( class CenterBlock(nn.Sequential): - def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True): + def __init__( + self, + in_channels: int, + out_channels: int, + use_batchnorm: Union[bool, str, None] = True, + use_norm: Union[bool, str, Dict[str, Any]] = True, + ): conv1 = md.Conv2dReLU( in_channels, out_channels, kernel_size=3, padding=1, use_batchnorm=use_batchnorm, + use_norm=use_norm, ) conv2 = md.Conv2dReLU( out_channels, @@ -64,6 +74,7 @@ def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = Tr kernel_size=3, padding=1, use_batchnorm=use_batchnorm, + use_norm=use_norm, ) super().__init__(conv1, conv2) @@ -74,7 +85,8 @@ def __init__( encoder_channels: List[int], decoder_channels: List[int], n_blocks: int = 5, - use_batchnorm: bool = True, + use_batchnorm: Union[bool, str, None] = True, + use_norm: Union[bool, str, Dict[str, Any]] = True, attention_type: Optional[str] = None, center: bool = False, ): @@ -97,13 +109,20 @@ def __init__( self.out_channels = decoder_channels if center: self.center = CenterBlock( - head_channels, head_channels, use_batchnorm=use_batchnorm + head_channels, + head_channels, + use_batchnorm=use_batchnorm, + use_norm=use_norm, ) else: self.center = nn.Identity() # combine decoder keyword 
arguments - kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type) + kwargs = dict( + use_batchnorm=use_batchnorm, + use_norm=use_norm, + attention_type=attention_type, + ) blocks = {} for layer_idx in range(len(self.in_channels) - 1): diff --git a/segmentation_models_pytorch/decoders/unetplusplus/model.py b/segmentation_models_pytorch/decoders/unetplusplus/model.py index 5c3d3a91..e51bfe65 100644 --- a/segmentation_models_pytorch/decoders/unetplusplus/model.py +++ b/segmentation_models_pytorch/decoders/unetplusplus/model.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional, Union +from typing import Any, Dict, List, Optional, Union from segmentation_models_pytorch.base import ( ClassificationHead, @@ -28,9 +28,27 @@ class UnetPlusPlus(SegmentationModel): other pretrained weights (see table with available weights for each encoder_name) decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder. Length of the list should be the same as **encoder_depth** - decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers + decoder_use_batchnorm: (**Deprecated**) If **True**, BatchNorm2d layer between Conv2D and Activation layers is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption. Available options are **True, False, "inplace"** + + **Note:** Deprecated, prefer using `decoder_use_norm` and set this to None. + decoder_use_norm: Specifies normalization between Conv2D and activation. + Accepts the following types: + - **True**: Defaults to `"batchnorm"`. + - **False**: No normalization (`nn.Identity`). + - **str**: Specifies normalization type using default parameters. Available values: + `"batchnorm"`, `"identity"`, `"layernorm"`, `"groupnorm"`, `"instancenorm"`, `"inplace"`. + - **dict**: Fully customizable normalization settings. Structure: + ```python + {"type": , **kwargs} + ``` + where `norm_name` corresponds to normalization type (see above), and `kwargs` are passed directly to the normalization layer as defined in PyTorch documentation. + + **Example**: + ```python + use_norm={"type": "groupnorm", "num_groups": 8} + ``` decoder_attention_type: Attention module used in decoder of the model. Available options are **None** and **scse** (https://arxiv.org/abs/1808.08127). 
in_channels: A number of input channels for the model, default is 3 (RGB images) @@ -64,7 +82,8 @@ def __init__( encoder_name: str = "resnet34", encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", - decoder_use_batchnorm: bool = True, + decoder_use_batchnorm: Union[bool, str, None] = True, + decoder_use_norm: Union[bool, str, Dict[str, Any]] = True, decoder_channels: List[int] = (256, 128, 64, 32, 16), decoder_attention_type: Optional[str] = None, in_channels: int = 3, @@ -93,6 +112,7 @@ def __init__( decoder_channels=decoder_channels, n_blocks=encoder_depth, use_batchnorm=decoder_use_batchnorm, + use_norm=decoder_use_norm, center=True if encoder_name.startswith("vgg") else False, attention_type=decoder_attention_type, ) From 1b16b254b26142696224d21ba05eb08e6195f4cc Mon Sep 17 00:00:00 2001 From: GuillaumeErhard <25333848+GuillaumeErhard@users.noreply.github.com> Date: Fri, 21 Mar 2025 01:09:49 +0100 Subject: [PATCH 02/21] First fix following review --- segmentation_models_pytorch/base/modules.py | 139 +++++++++--------- .../decoders/linknet/decoder.py | 10 +- .../decoders/linknet/model.py | 3 +- .../decoders/manet/decoder.py | 11 +- .../decoders/manet/model.py | 4 +- .../decoders/unet/decoder.py | 9 -- .../decoders/unet/model.py | 3 +- .../decoders/unetplusplus/decoder.py | 9 -- .../decoders/unetplusplus/model.py | 3 +- tests/base/test_modules.py | 48 ++++++ tests/encoders/test_batchnorm_deprecation.py | 1 + 11 files changed, 134 insertions(+), 106 deletions(-) create mode 100644 tests/base/test_modules.py create mode 100644 tests/encoders/test_batchnorm_deprecation.py diff --git a/segmentation_models_pytorch/base/modules.py b/segmentation_models_pytorch/base/modules.py index ed3805be..a3e18186 100644 --- a/segmentation_models_pytorch/base/modules.py +++ b/segmentation_models_pytorch/base/modules.py @@ -1,3 +1,4 @@ +from typing import Any, Dict, Tuple, Union import warnings import torch @@ -8,6 +9,76 @@ except ImportError: InPlaceABN = None +def handle_decoder_use_batchnorm_deprecation(decoder_use_batchnorm: Union[bool, str, None], decoder_use_norm: Union[bool, str, Dict[str, Any]]) -> Dict[str, Any]: + if decoder_use_batchnorm is not None: + warnings.warn( + "The usage of use_batchnorm is deprecated. 
Please modify your code for use_norm", + DeprecationWarning, + ) + if decoder_use_batchnorm is True: + decoder_use_norm = {"type": "batchnorm"} + elif decoder_use_batchnorm is False: + decoder_use_norm = {"type": "identity"} + elif decoder_use_batchnorm == "inplace": + decoder_use_norm = { + "type": "inplace", + "activation": "leaky_relu", + "activation_param": 0.0, + } + else: + raise ValueError("Unrecognized value for use_batchnorm") + + return decoder_use_norm + + +def normalize_use_norm(use_norm: Union[bool, str, Dict[str, Any]]) -> Dict[str, Any]: + if isinstance(use_norm, str): + norm_str = use_norm.lower() + if norm_str == "inplace": + use_norm = { + "type": "inplace", + "activation": "leaky_relu", + "activation_param": 0.0, + } + elif norm_str in ( + "batchnorm", + "identity", + "layernorm", + "groupnorm", + "instancenorm", + ): + use_norm = {"type": norm_str} + else: + raise ValueError("Unrecognized normalization type string provided") + elif isinstance(use_norm, bool): + use_norm = {"type": "batchnorm" if use_norm else "identity"} + elif not isinstance(use_norm, dict): + raise ValueError("use_norm must be a dictionary, boolean, or string") + + return use_norm + +def get_norm_layer(use_norm: Dict[str, Any], relu: nn.Module, out_channels: int) -> Tuple[nn.Module, nn.Module]: + norm_type = use_norm["type"] + extra_kwargs = {k: v for k, v in use_norm.items() if k != "type"} + + if norm_type == "inplace": + norm = InPlaceABN(out_channels, **extra_kwargs) + relu = nn.Identity() + elif norm_type == "batchnorm": + norm = nn.BatchNorm2d(out_channels, **extra_kwargs) + elif norm_type == "identity": + norm = nn.Identity() + elif norm_type == "layernorm": + norm = nn.LayerNorm(out_channels, **extra_kwargs) + elif norm_type == "groupnorm": + norm = nn.GroupNorm(out_channels, **extra_kwargs) + elif norm_type == "instancenorm": + norm = nn.InstanceNorm2d(out_channels, **extra_kwargs) + else: + raise ValueError(f"Unrecognized normalization type: {norm_type}") + + return norm, relu + class Conv2dReLU(nn.Sequential): def __init__( @@ -17,50 +88,9 @@ def __init__( kernel_size, padding=0, stride=1, - use_batchnorm=True, use_norm="batchnorm", ): - if use_batchnorm is not None: - warnings.warn( - "The usage of use_batchnorm is deprecated. Please modify your code for use_norm", - DeprecationWarning, - ) - if use_batchnorm is True: - use_norm = {"type": "batchnorm"} - elif use_batchnorm is False: - use_norm = {"type": "identity"} - elif use_batchnorm == "inplace": - use_norm = { - "type": "inplace", - "activation": "leaky_relu", - "activation_param": 0.0, - } - else: - raise ValueError("Unrecognized value for use_batchnorm") - - if isinstance(use_norm, str): - norm_str = use_norm.lower() - if norm_str == "inplace": - use_norm = { - "type": "inplace", - "activation": "leaky_relu", - "activation_param": 0.0, - } - elif norm_str in ( - "batchnorm", - "identity", - "layernorm", - "groupnorm", - "instancenorm", - ): - use_norm = {"type": norm_str} - else: - raise ValueError("Unrecognized normalization type string provided") - elif isinstance(use_norm, bool): - use_norm = {"type": "batchnorm" if use_norm else "identity"} - elif not isinstance(use_norm, dict): - raise ValueError("use_norm must be a dictionary, boolean, or string") - + use_norm = normalize_use_norm(use_norm) if use_norm["type"] == "inplace" and InPlaceABN is None: raise RuntimeError( "In order to use `use_batchnorm='inplace'` or `use_norm='inplace'` the inplace_abn package must be installed. 
" @@ -77,24 +107,7 @@ def __init__( ) relu = nn.ReLU(inplace=True) - norm_type = use_norm["type"] - extra_kwargs = {k: v for k, v in use_norm.items() if k != "type"} - - if norm_type == "inplace": - norm = InPlaceABN(out_channels, **extra_kwargs) - relu = nn.Identity() - elif norm_type == "batchnorm": - norm = nn.BatchNorm2d(out_channels, **extra_kwargs) - elif norm_type == "identity": - norm = nn.Identity() - elif norm_type == "layernorm": - norm = nn.LayerNorm(out_channels, **extra_kwargs) - elif norm_type == "groupnorm": - norm = nn.GroupNorm(out_channels, **extra_kwargs) - elif norm_type == "instancenorm": - norm = nn.InstanceNorm2d(out_channels, **extra_kwargs) - else: - raise ValueError(f"Unrecognized normalization type: {norm_type}") + norm, relu = get_norm_layer(use_norm, relu, out_channels) super(Conv2dReLU, self).__init__(conv, norm, relu) @@ -180,9 +193,3 @@ def __init__(self, name, **params): def forward(self, x): return self.attention(x) - - -if __name__ == "__main__": - print(Conv2dReLU(3, 12, 4)) - print(Conv2dReLU(3, 12, 4, use_norm={"type": "batchnorm"})) - print(Conv2dReLU(3, 12, 4, use_norm={"type": "layernorm", "eps": 1e-3})) diff --git a/segmentation_models_pytorch/decoders/linknet/decoder.py b/segmentation_models_pytorch/decoders/linknet/decoder.py index 128ec2dc..b7833210 100644 --- a/segmentation_models_pytorch/decoders/linknet/decoder.py +++ b/segmentation_models_pytorch/decoders/linknet/decoder.py @@ -10,7 +10,6 @@ def __init__( self, in_channels: int, out_channels: int, - use_batchnorm: Union[bool, str, None] = True, use_norm: Union[bool, str, Dict[str, Any]] = True, ): super().__init__() @@ -21,7 +20,7 @@ def __init__( nn.ReLU(inplace=True), ] - if use_batchnorm or use_norm: + if use_norm: layers.insert(1, nn.BatchNorm2d(out_channels)) super().__init__(*layers) @@ -32,7 +31,6 @@ def __init__( self, in_channels: int, out_channels: int, - use_batchnorm: Union[bool, str, None] = True, use_norm: Union[bool, str, Dict[str, Any]] = True, ): super().__init__() @@ -42,17 +40,15 @@ def __init__( in_channels, in_channels // 4, kernel_size=1, - use_batchnorm=use_batchnorm, use_norm=use_norm, ), TransposeX2( - in_channels // 4, in_channels // 4, use_batchnorm=use_batchnorm + in_channels // 4, in_channels // 4, use_norm=use_norm ), modules.Conv2dReLU( in_channels // 4, out_channels, kernel_size=1, - use_batchnorm=use_batchnorm, use_norm=use_norm, ), ) @@ -72,7 +68,6 @@ def __init__( encoder_channels: List[int], prefinal_channels: int = 32, n_blocks: int = 5, - use_batchnorm: Union[bool, str, None] = True, use_norm: Union[bool, str, Dict[str, Any]] = True, ): super().__init__() @@ -89,7 +84,6 @@ def __init__( DecoderBlock( channels[i], channels[i + 1], - use_batchnorm=use_batchnorm, use_norm=use_norm, ) for i in range(n_blocks) diff --git a/segmentation_models_pytorch/decoders/linknet/model.py b/segmentation_models_pytorch/decoders/linknet/model.py index 0dca7e56..0a78564e 100644 --- a/segmentation_models_pytorch/decoders/linknet/model.py +++ b/segmentation_models_pytorch/decoders/linknet/model.py @@ -7,6 +7,7 @@ ) from segmentation_models_pytorch.encoders import get_encoder from segmentation_models_pytorch.base.hub_mixin import supports_config_loading +from segmentation_models_pytorch.base.modules import handle_decoder_use_batchnorm_deprecation from .decoder import LinknetDecoder @@ -101,11 +102,11 @@ def __init__( **kwargs, ) + decoder_use_norm = handle_decoder_use_batchnorm_deprecation(decoder_use_batchnorm, decoder_use_norm) self.decoder = LinknetDecoder( 
encoder_channels=self.encoder.out_channels, n_blocks=encoder_depth, prefinal_channels=32, - use_batchnorm=decoder_use_batchnorm, use_norm=decoder_use_norm, ) diff --git a/segmentation_models_pytorch/decoders/manet/decoder.py b/segmentation_models_pytorch/decoders/manet/decoder.py index 49891fe3..b18011bc 100644 --- a/segmentation_models_pytorch/decoders/manet/decoder.py +++ b/segmentation_models_pytorch/decoders/manet/decoder.py @@ -49,7 +49,6 @@ def __init__( in_channels: int, skip_channels: int, out_channels: int, - use_batchnorm: Union[bool, str, None] = True, use_norm: Union[bool, str, Dict[str, Any]] = True, reduction: int = 16, ): @@ -61,14 +60,12 @@ def __init__( in_channels, kernel_size=3, padding=1, - use_batchnorm=use_batchnorm, use_norm=use_norm, ), md.Conv2dReLU( in_channels, skip_channels, kernel_size=1, - use_batchnorm=use_batchnorm, use_norm=use_norm, ), ) @@ -93,7 +90,6 @@ def __init__( out_channels, kernel_size=3, padding=1, - use_batchnorm=use_batchnorm, use_norm=use_norm, ) self.conv2 = md.Conv2dReLU( @@ -101,7 +97,6 @@ def __init__( out_channels, kernel_size=3, padding=1, - use_batchnorm=use_batchnorm, use_norm=use_norm, ) @@ -127,7 +122,6 @@ def __init__( in_channels: int, skip_channels: int, out_channels: int, - use_batchnorm: Union[bool, str, None] = True, use_norm: Union[bool, str, Dict[str, Any]] = True, ): super().__init__() @@ -136,7 +130,6 @@ def __init__( out_channels, kernel_size=3, padding=1, - use_batchnorm=use_batchnorm, use_norm=use_norm, ) self.conv2 = md.Conv2dReLU( @@ -144,7 +137,6 @@ def __init__( out_channels, kernel_size=3, padding=1, - use_batchnorm=use_batchnorm, use_norm=use_norm, ) @@ -166,7 +158,6 @@ def __init__( decoder_channels: List[int], n_blocks: int = 5, reduction: int = 16, - use_batchnorm: Union[bool, str, None] = True, use_norm: Union[bool, str, Dict[str, Any]] = True, pab_channels: int = 64, ): @@ -195,7 +186,7 @@ def __init__( # combine decoder keyword arguments kwargs = dict( - use_batchnorm=use_batchnorm, use_norm=use_norm + use_norm=use_norm ) # no attention type here blocks = [ MFABBlock(in_ch, skip_ch, out_ch, reduction=reduction, **kwargs) diff --git a/segmentation_models_pytorch/decoders/manet/model.py b/segmentation_models_pytorch/decoders/manet/model.py index 5405f9d7..f1675098 100644 --- a/segmentation_models_pytorch/decoders/manet/model.py +++ b/segmentation_models_pytorch/decoders/manet/model.py @@ -7,6 +7,7 @@ ) from segmentation_models_pytorch.encoders import get_encoder from segmentation_models_pytorch.base.hub_mixin import supports_config_loading +from segmentation_models_pytorch.base.modules import handle_decoder_use_batchnorm_deprecation from .decoder import MAnetDecoder @@ -101,11 +102,12 @@ def __init__( **kwargs, ) + decoder_use_norm = handle_decoder_use_batchnorm_deprecation(decoder_use_batchnorm, decoder_use_norm) + self.decoder = MAnetDecoder( encoder_channels=self.encoder.out_channels, decoder_channels=decoder_channels, n_blocks=encoder_depth, - use_batchnorm=decoder_use_batchnorm, use_norm=decoder_use_norm, pab_channels=decoder_pab_channels, ) diff --git a/segmentation_models_pytorch/decoders/unet/decoder.py b/segmentation_models_pytorch/decoders/unet/decoder.py index b5fce6df..adf98d66 100644 --- a/segmentation_models_pytorch/decoders/unet/decoder.py +++ b/segmentation_models_pytorch/decoders/unet/decoder.py @@ -15,7 +15,6 @@ def __init__( in_channels: int, skip_channels: int, out_channels: int, - use_batchnorm: Union[bool, str, None] = True, use_norm: Union[bool, str, Dict[str, Any]] = True, 
attention_type: Optional[str] = None, interpolation_mode: str = "nearest", @@ -27,7 +26,6 @@ def __init__( out_channels, kernel_size=3, padding=1, - use_batchnorm=use_batchnorm, use_norm=use_norm, ) self.attention1 = md.Attention( @@ -38,7 +36,6 @@ def __init__( out_channels, kernel_size=3, padding=1, - use_batchnorm=use_batchnorm, use_norm=use_norm, ) self.attention2 = md.Attention(attention_type, in_channels=out_channels) @@ -71,7 +68,6 @@ def __init__( self, in_channels: int, out_channels: int, - use_batchnorm: Union[bool, str, None] = True, use_norm: Union[bool, str, Dict[str, Any]] = True, ): conv1 = md.Conv2dReLU( @@ -79,7 +75,6 @@ def __init__( out_channels, kernel_size=3, padding=1, - use_batchnorm=use_batchnorm, use_norm=use_norm, ) conv2 = md.Conv2dReLU( @@ -87,7 +82,6 @@ def __init__( out_channels, kernel_size=3, padding=1, - use_batchnorm=use_batchnorm, use_norm=use_norm, ) super().__init__(conv1, conv2) @@ -105,7 +99,6 @@ def __init__( encoder_channels: Sequence[int], decoder_channels: Sequence[int], n_blocks: int = 5, - use_batchnorm: Union[bool, str, None] = True, use_norm: Union[bool, str, Dict[str, Any]] = True, attention_type: Optional[str] = None, add_center_block: bool = False, @@ -135,7 +128,6 @@ def __init__( self.center = UnetCenterBlock( head_channels, head_channels, - use_batchnorm=use_batchnorm, use_norm=use_norm, ) else: @@ -150,7 +142,6 @@ def __init__( block_in_channels, block_skip_channels, block_out_channels, - use_batchnorm=use_batchnorm, use_norm=use_norm, attention_type=attention_type, interpolation_mode=interpolation_mode, diff --git a/segmentation_models_pytorch/decoders/unet/model.py b/segmentation_models_pytorch/decoders/unet/model.py index 1778515e..133a95cf 100644 --- a/segmentation_models_pytorch/decoders/unet/model.py +++ b/segmentation_models_pytorch/decoders/unet/model.py @@ -7,6 +7,7 @@ ) from segmentation_models_pytorch.encoders import get_encoder from segmentation_models_pytorch.base.hub_mixin import supports_config_loading +from segmentation_models_pytorch.base.modules import handle_decoder_use_batchnorm_deprecation from .decoder import UnetDecoder @@ -135,11 +136,11 @@ def __init__( ) add_center_block = encoder_name.startswith("vgg") + decoder_use_norm = handle_decoder_use_batchnorm_deprecation(decoder_use_batchnorm, decoder_use_norm) self.decoder = UnetDecoder( encoder_channels=self.encoder.out_channels, decoder_channels=decoder_channels, n_blocks=encoder_depth, - use_batchnorm=decoder_use_batchnorm, use_norm=decoder_use_norm, add_center_block=add_center_block, attention_type=decoder_attention_type, diff --git a/segmentation_models_pytorch/decoders/unetplusplus/decoder.py b/segmentation_models_pytorch/decoders/unetplusplus/decoder.py index 34f31627..b1970104 100644 --- a/segmentation_models_pytorch/decoders/unetplusplus/decoder.py +++ b/segmentation_models_pytorch/decoders/unetplusplus/decoder.py @@ -13,7 +13,6 @@ def __init__( in_channels: int, skip_channels: int, out_channels: int, - use_batchnorm: Union[bool, str, None] = True, use_norm: Union[bool, str, Dict[str, Any]] = True, attention_type: Optional[str] = None, ): @@ -23,7 +22,6 @@ def __init__( out_channels, kernel_size=3, padding=1, - use_batchnorm=use_batchnorm, use_norm=use_norm, ) self.attention1 = md.Attention( @@ -34,7 +32,6 @@ def __init__( out_channels, kernel_size=3, padding=1, - use_batchnorm=use_batchnorm, use_norm=use_norm, ) self.attention2 = md.Attention(attention_type, in_channels=out_channels) @@ -57,7 +54,6 @@ def __init__( self, in_channels: int, out_channels: int, 
- use_batchnorm: Union[bool, str, None] = True, use_norm: Union[bool, str, Dict[str, Any]] = True, ): conv1 = md.Conv2dReLU( @@ -65,7 +61,6 @@ def __init__( out_channels, kernel_size=3, padding=1, - use_batchnorm=use_batchnorm, use_norm=use_norm, ) conv2 = md.Conv2dReLU( @@ -73,7 +68,6 @@ def __init__( out_channels, kernel_size=3, padding=1, - use_batchnorm=use_batchnorm, use_norm=use_norm, ) super().__init__(conv1, conv2) @@ -85,7 +79,6 @@ def __init__( encoder_channels: List[int], decoder_channels: List[int], n_blocks: int = 5, - use_batchnorm: Union[bool, str, None] = True, use_norm: Union[bool, str, Dict[str, Any]] = True, attention_type: Optional[str] = None, center: bool = False, @@ -111,7 +104,6 @@ def __init__( self.center = CenterBlock( head_channels, head_channels, - use_batchnorm=use_batchnorm, use_norm=use_norm, ) else: @@ -119,7 +111,6 @@ def __init__( # combine decoder keyword arguments kwargs = dict( - use_batchnorm=use_batchnorm, use_norm=use_norm, attention_type=attention_type, ) diff --git a/segmentation_models_pytorch/decoders/unetplusplus/model.py b/segmentation_models_pytorch/decoders/unetplusplus/model.py index e51bfe65..17dcc18b 100644 --- a/segmentation_models_pytorch/decoders/unetplusplus/model.py +++ b/segmentation_models_pytorch/decoders/unetplusplus/model.py @@ -7,6 +7,7 @@ ) from segmentation_models_pytorch.encoders import get_encoder from segmentation_models_pytorch.base.hub_mixin import supports_config_loading +from segmentation_models_pytorch.base.modules import handle_decoder_use_batchnorm_deprecation from .decoder import UnetPlusPlusDecoder @@ -107,11 +108,11 @@ def __init__( **kwargs, ) + decoder_use_norm = handle_decoder_use_batchnorm_deprecation(decoder_use_batchnorm, decoder_use_norm) self.decoder = UnetPlusPlusDecoder( encoder_channels=self.encoder.out_channels, decoder_channels=decoder_channels, n_blocks=encoder_depth, - use_batchnorm=decoder_use_batchnorm, use_norm=decoder_use_norm, center=True if encoder_name.startswith("vgg") else False, attention_type=decoder_attention_type, diff --git a/tests/base/test_modules.py b/tests/base/test_modules.py new file mode 100644 index 00000000..702d7cda --- /dev/null +++ b/tests/base/test_modules.py @@ -0,0 +1,48 @@ +from segmentation_models_pytorch.base.modules import Conv2dReLU + + +def test_conv2drelu_batchnorm(): + module = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm="batchnorm") + + expected = ('Conv2dReLU(\n (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))' + '\n (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)' + '\n (2): ReLU(inplace=True)\n)') + assert repr(module) == expected + +def test_conv2drelu_batchnorm_with_keywords(): + module = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm={"type": "batchnorm", "momentum": 1e-4, "affine": False}) + + expected = ('Conv2dReLU(\n (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))' + '\n (1): BatchNorm2d(16, eps=1e-05, momentum=0.0001, affine=False, track_running_stats=True)' + '\n (2): ReLU(inplace=True)\n)') + assert repr(module) == expected + + +def test_conv2drelu_identity(): + module = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm="identity") + expected = ('Conv2dReLU(\n (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))' + '\n (1): Identity()' + '\n (2): ReLU(inplace=True)\n)') + assert repr(module) == expected + + +def test_conv2drelu_layernorm(): + module = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm="layernorm") + expected = 
('Conv2dReLU(\n (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))'
+                '\n (1): LayerNorm((16,), eps=1e-05, elementwise_affine=True)'
+                '\n (2): ReLU(inplace=True)\n)')
+    assert repr(module) == expected
+
+def test_conv2drelu_groupnorm():
+    # get_norm_layer forwards out_channels as GroupNorm's first (num_groups)
+    # argument, so num_channels has to be supplied explicitly here.
+    module = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm={"type": "groupnorm", "num_channels": 16})
+    expected = ('Conv2dReLU(\n (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))'
+                '\n (1): GroupNorm(16, 16, eps=1e-05, affine=True)'
+                '\n (2): ReLU(inplace=True)\n)')
+    assert repr(module) == expected
+
+def test_conv2drelu_instancenorm():
+    module = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm="instancenorm")
+    expected = ('Conv2dReLU(\n (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))'
+                '\n (1): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)'
+                '\n (2): ReLU(inplace=True)\n)')
+    assert repr(module) == expected
diff --git a/tests/encoders/test_batchnorm_deprecation.py b/tests/encoders/test_batchnorm_deprecation.py
new file mode 100644
index 00000000..65ec1518
--- /dev/null
+++ b/tests/encoders/test_batchnorm_deprecation.py
@@ -0,0 +1 @@
+# TODO: Simple equal before / after
\ No newline at end of file

From 10d496a1bf423587913014057ed7ed36f7276f84 Mon Sep 17 00:00:00 2001
From: GuillaumeErhard <25333848+GuillaumeErhard@users.noreply.github.com>
Date: Tue, 25 Mar 2025 00:46:31 +0100
Subject: [PATCH 03/21] Add test to verify before / after functionality

Fix issue in Linknet
Handle norm and relu on its own
Align pspnet
Add use_norm in upernet
Remove groupnorm possibility
Update doc
---
 segmentation_models_pytorch/base/modules.py   | 70 +++++++++----------
 .../decoders/linknet/decoder.py               | 10 ++-
 .../decoders/linknet/model.py                 | 10 +--
 .../decoders/manet/model.py                   | 12 ++--
 .../decoders/pspnet/decoder.py                | 19 ++---
 .../decoders/pspnet/model.py                  | 30 ++++++--
 .../decoders/unet/model.py                    | 10 +--
 .../decoders/unetplusplus/model.py            | 10 +--
 .../decoders/upernet/decoder.py               | 31 ++++----
 .../decoders/upernet/model.py                 | 20 +++++-
 tests/base/test_modules.py                    |  6 --
 tests/encoders/test_batchnorm_deprecation.py  | 29 +++++++-
 tests/utils.py                                |  7 ++
 13 files changed, 169 insertions(+), 95 deletions(-)

diff --git a/segmentation_models_pytorch/base/modules.py b/segmentation_models_pytorch/base/modules.py
index a3e18186..38719005 100644
--- a/segmentation_models_pytorch/base/modules.py
+++ b/segmentation_models_pytorch/base/modules.py
@@ -9,7 +9,33 @@
 except ImportError:
     InPlaceABN = None
 
-def handle_decoder_use_batchnorm_deprecation(decoder_use_batchnorm: Union[bool, str, None], decoder_use_norm: Union[bool, str, Dict[str, Any]]) -> Dict[str, Any]:
+def normalize_use_norm(decoder_use_norm: Union[bool, str, Dict[str, Any]]) -> Dict[str, Any]:
+    if isinstance(decoder_use_norm, str):
+        norm_str = decoder_use_norm.lower()
+        if norm_str == "inplace":
+            decoder_use_norm = {
+                "type": "inplace",
+                "activation": "leaky_relu",
+                "activation_param": 0.0,
+            }
+        elif norm_str in (
+            "batchnorm",
+            "identity",
+            "layernorm",
+            "groupnorm",
+            "instancenorm",
+        ):
+            decoder_use_norm = {"type": norm_str}
+        else:
+            raise ValueError("Unrecognized normalization type string provided")
+    elif isinstance(decoder_use_norm, bool):
+        decoder_use_norm = {"type": "batchnorm" if decoder_use_norm else "identity"}
+    elif not isinstance(decoder_use_norm, dict):
+        raise ValueError("use_norm must be a dictionary, boolean, or string")
+
+    return decoder_use_norm
+
+def normalize_decoder_norm(decoder_use_batchnorm: Union[bool, str, None], decoder_use_norm: 
Union[bool, str, Dict[str, Any]]) -> Dict[str, Any]: if decoder_use_batchnorm is not None: warnings.warn( "The usage of use_batchnorm is deprecated. Please modify your code for use_norm", @@ -28,57 +54,28 @@ def handle_decoder_use_batchnorm_deprecation(decoder_use_batchnorm: Union[bool, else: raise ValueError("Unrecognized value for use_batchnorm") + decoder_use_norm = normalize_use_norm(decoder_use_norm) return decoder_use_norm -def normalize_use_norm(use_norm: Union[bool, str, Dict[str, Any]]) -> Dict[str, Any]: - if isinstance(use_norm, str): - norm_str = use_norm.lower() - if norm_str == "inplace": - use_norm = { - "type": "inplace", - "activation": "leaky_relu", - "activation_param": 0.0, - } - elif norm_str in ( - "batchnorm", - "identity", - "layernorm", - "groupnorm", - "instancenorm", - ): - use_norm = {"type": norm_str} - else: - raise ValueError("Unrecognized normalization type string provided") - elif isinstance(use_norm, bool): - use_norm = {"type": "batchnorm" if use_norm else "identity"} - elif not isinstance(use_norm, dict): - raise ValueError("use_norm must be a dictionary, boolean, or string") - - return use_norm - -def get_norm_layer(use_norm: Dict[str, Any], relu: nn.Module, out_channels: int) -> Tuple[nn.Module, nn.Module]: +def get_norm_layer(use_norm: Dict[str, Any], out_channels: int) -> nn.Module: norm_type = use_norm["type"] extra_kwargs = {k: v for k, v in use_norm.items() if k != "type"} if norm_type == "inplace": norm = InPlaceABN(out_channels, **extra_kwargs) - relu = nn.Identity() elif norm_type == "batchnorm": norm = nn.BatchNorm2d(out_channels, **extra_kwargs) elif norm_type == "identity": norm = nn.Identity() elif norm_type == "layernorm": norm = nn.LayerNorm(out_channels, **extra_kwargs) - elif norm_type == "groupnorm": - norm = nn.GroupNorm(out_channels, **extra_kwargs) elif norm_type == "instancenorm": norm = nn.InstanceNorm2d(out_channels, **extra_kwargs) else: raise ValueError(f"Unrecognized normalization type: {norm_type}") - return norm, relu - + return norm class Conv2dReLU(nn.Sequential): def __init__( @@ -105,9 +102,12 @@ def __init__( padding=padding, bias=use_norm["type"] != "inplace", ) - relu = nn.ReLU(inplace=True) + norm = get_norm_layer(use_norm, out_channels) - norm, relu = get_norm_layer(use_norm, relu, out_channels) + if use_norm["type"] == "inplace": + relu = nn.Identity() + else: + relu = nn.ReLU(inplace=True) super(Conv2dReLU, self).__init__(conv, norm, relu) diff --git a/segmentation_models_pytorch/decoders/linknet/decoder.py b/segmentation_models_pytorch/decoders/linknet/decoder.py index b7833210..4e671c37 100644 --- a/segmentation_models_pytorch/decoders/linknet/decoder.py +++ b/segmentation_models_pytorch/decoders/linknet/decoder.py @@ -10,7 +10,7 @@ def __init__( self, in_channels: int, out_channels: int, - use_norm: Union[bool, str, Dict[str, Any]] = True, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", ): super().__init__() layers = [ @@ -20,8 +20,12 @@ def __init__( nn.ReLU(inplace=True), ] - if use_norm: - layers.insert(1, nn.BatchNorm2d(out_channels)) + if use_norm != "identity": + if isinstance(use_norm, dict): + if use_norm.get("type") != "identity": + layers.insert(1, modules.get_norm_layer(use_norm, out_channels)) + else: + layers.insert(1, modules.get_norm_layer(use_norm, out_channels)) super().__init__(*layers) diff --git a/segmentation_models_pytorch/decoders/linknet/model.py b/segmentation_models_pytorch/decoders/linknet/model.py index 0a78564e..f2103cfb 100644 --- 
a/segmentation_models_pytorch/decoders/linknet/model.py +++ b/segmentation_models_pytorch/decoders/linknet/model.py @@ -7,7 +7,7 @@ ) from segmentation_models_pytorch.encoders import get_encoder from segmentation_models_pytorch.base.hub_mixin import supports_config_loading -from segmentation_models_pytorch.base.modules import handle_decoder_use_batchnorm_deprecation +from segmentation_models_pytorch.base.modules import normalize_decoder_norm from .decoder import LinknetDecoder @@ -40,7 +40,7 @@ class Linknet(SegmentationModel): - **True**: Defaults to `"batchnorm"`. - **False**: No normalization (`nn.Identity`). - **str**: Specifies normalization type using default parameters. Available values: - `"batchnorm"`, `"identity"`, `"layernorm"`, `"groupnorm"`, `"instancenorm"`, `"inplace"`. + `"batchnorm"`, `"identity"`, `"layernorm"`, `"instancenorm"`, `"inplace"`. - **dict**: Fully customizable normalization settings. Structure: ```python {"type": , **kwargs} @@ -49,7 +49,7 @@ class Linknet(SegmentationModel): **Example**: ```python - use_norm={"type": "groupnorm", "num_groups": 8} + use_norm={"type": "layernorm", "eps": 1e-2} ``` in_channels: A number of input channels for the model, default is 3 (RGB images) classes: A number of classes for output mask (or you can think as a number of channels of output mask) @@ -80,7 +80,7 @@ def __init__( encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", decoder_use_batchnorm: Union[bool, str, None] = True, - decoder_use_norm: Union[bool, str, Dict[str, Any]] = True, + decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", in_channels: int = 3, classes: int = 1, activation: Optional[Union[str, callable]] = None, @@ -102,7 +102,7 @@ def __init__( **kwargs, ) - decoder_use_norm = handle_decoder_use_batchnorm_deprecation(decoder_use_batchnorm, decoder_use_norm) + decoder_use_norm = normalize_decoder_norm(decoder_use_batchnorm, decoder_use_norm) self.decoder = LinknetDecoder( encoder_channels=self.encoder.out_channels, n_blocks=encoder_depth, diff --git a/segmentation_models_pytorch/decoders/manet/model.py b/segmentation_models_pytorch/decoders/manet/model.py index f1675098..ded053a4 100644 --- a/segmentation_models_pytorch/decoders/manet/model.py +++ b/segmentation_models_pytorch/decoders/manet/model.py @@ -7,7 +7,7 @@ ) from segmentation_models_pytorch.encoders import get_encoder from segmentation_models_pytorch.base.hub_mixin import supports_config_loading -from segmentation_models_pytorch.base.modules import handle_decoder_use_batchnorm_deprecation +from segmentation_models_pytorch.base.modules import normalize_decoder_norm from .decoder import MAnetDecoder @@ -35,12 +35,12 @@ class MAnet(SegmentationModel): Available options are **True, False, "inplace"** **Note:** Deprecated, prefer using `decoder_use_norm` and set this to None. - decoder_use_norm: Specifies normalization between Conv2D and activation. + decoder_use_norm: Specifies normalization between Conv2D and activation. Accepts the following types: - **True**: Defaults to `"batchnorm"`. - **False**: No normalization (`nn.Identity`). - **str**: Specifies normalization type using default parameters. Available values: - `"batchnorm"`, `"identity"`, `"layernorm"`, `"groupnorm"`, `"instancenorm"`, `"inplace"`. + `"batchnorm"`, `"identity"`, `"layernorm"`, `"instancenorm"`, `"inplace"`. - **dict**: Fully customizable normalization settings. 
Structure: ```python {"type": , **kwargs} @@ -49,7 +49,7 @@ class MAnet(SegmentationModel): **Example**: ```python - use_norm={"type": "groupnorm", "num_groups": 8} + use_norm={"type": "layernorm", "eps": 1e-2} ``` decoder_pab_channels: A number of channels for PAB module in decoder. Default is 64. @@ -83,7 +83,7 @@ def __init__( encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", decoder_use_batchnorm: Union[bool, str, None] = True, - decoder_use_norm: Union[bool, str, Dict[str, Any], None] = True, + decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", decoder_channels: List[int] = (256, 128, 64, 32, 16), decoder_pab_channels: int = 64, in_channels: int = 3, @@ -102,7 +102,7 @@ def __init__( **kwargs, ) - decoder_use_norm = handle_decoder_use_batchnorm_deprecation(decoder_use_batchnorm, decoder_use_norm) + decoder_use_norm = normalize_decoder_norm(decoder_use_batchnorm, decoder_use_norm) self.decoder = MAnetDecoder( encoder_channels=self.encoder.out_channels, diff --git a/segmentation_models_pytorch/decoders/pspnet/decoder.py b/segmentation_models_pytorch/decoders/pspnet/decoder.py index 42ac42d0..547c563d 100644 --- a/segmentation_models_pytorch/decoders/pspnet/decoder.py +++ b/segmentation_models_pytorch/decoders/pspnet/decoder.py @@ -1,8 +1,9 @@ +from typing import Any, Dict, List, Tuple, Union + import torch import torch.nn as nn import torch.nn.functional as F -from typing import List, Tuple from segmentation_models_pytorch.base import modules @@ -12,17 +13,17 @@ def __init__( in_channels: int, out_channels: int, pool_size: int, - use_batchnorm: bool = True, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", ): super().__init__() if pool_size == 1: - use_batchnorm = False # PyTorch does not support BatchNorm for 1x1 shape + use_norm = "identity" # PyTorch does not support BatchNorm for 1x1 shape self.pool = nn.Sequential( nn.AdaptiveAvgPool2d(output_size=(pool_size, pool_size)), modules.Conv2dReLU( - in_channels, out_channels, (1, 1), use_batchnorm=use_batchnorm + in_channels, out_channels, (1, 1), use_norm=use_norm ), ) @@ -38,7 +39,7 @@ def __init__( self, in_channels: int, sizes: Tuple[int, ...] 
= (1, 2, 3, 6), - use_batchnorm: bool = True, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", ): super().__init__() @@ -48,7 +49,7 @@ def __init__( in_channels, in_channels // len(sizes), size, - use_batchnorm=use_batchnorm, + use_norm=use_norm, ) for size in sizes ] @@ -64,7 +65,7 @@ class PSPDecoder(nn.Module): def __init__( self, encoder_channels: List[int], - use_batchnorm: bool = True, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", out_channels: int = 512, dropout: float = 0.2, ): @@ -73,14 +74,14 @@ def __init__( self.psp = PSPModule( in_channels=encoder_channels[-1], sizes=(1, 2, 3, 6), - use_batchnorm=use_batchnorm, + use_norm=use_norm, ) self.conv = modules.Conv2dReLU( in_channels=encoder_channels[-1] * 2, out_channels=out_channels, kernel_size=1, - use_batchnorm=use_batchnorm, + use_norm=use_norm, ) self.dropout = nn.Dropout2d(p=dropout) diff --git a/segmentation_models_pytorch/decoders/pspnet/model.py b/segmentation_models_pytorch/decoders/pspnet/model.py index 8b99b3da..9e79bf1b 100644 --- a/segmentation_models_pytorch/decoders/pspnet/model.py +++ b/segmentation_models_pytorch/decoders/pspnet/model.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Union +from typing import Any, Dict, Optional, Union from segmentation_models_pytorch.base import ( ClassificationHead, @@ -6,6 +6,7 @@ SegmentationModel, ) from segmentation_models_pytorch.encoders import get_encoder +from segmentation_models_pytorch.base.modules import normalize_decoder_norm from segmentation_models_pytorch.base.hub_mixin import supports_config_loading from .decoder import PSPDecoder @@ -28,9 +29,27 @@ class PSPNet(SegmentationModel): encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and other pretrained weights (see table with available weights for each encoder_name) psp_out_channels: A number of filters in Spatial Pyramid - psp_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers + psp_use_batchnorm: (**Deprecated**) If **True**, BatchNorm2d layer between Conv2D and Activation layers is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption. Available options are **True, False, "inplace"** + + **Note:** Deprecated, prefer using `decoder_use_norm` and set this to None. + decoder_use_norm: Specifies normalization between Conv2D and activation. + Accepts the following types: + - **True**: Defaults to `"batchnorm"`. + - **False**: No normalization (`nn.Identity`). + - **str**: Specifies normalization type using default parameters. Available values: + `"batchnorm"`, `"identity"`, `"layernorm"`, `"instancenorm"`, `"inplace"`. + - **dict**: Fully customizable normalization settings. Structure: + ```python + {"type": , **kwargs} + ``` + where `norm_name` corresponds to normalization type (see above), and `kwargs` are passed directly to the normalization layer as defined in PyTorch documentation. 
+ + **Example**: + ```python + use_norm={"type": "layernorm", "eps": 1e-2} + ``` psp_dropout: Spatial dropout rate in [0, 1) used in Spatial Pyramid in_channels: A number of input channels for the model, default is 3 (RGB images) classes: A number of classes for output mask (or you can think as a number of channels of output mask) @@ -62,7 +81,8 @@ def __init__( encoder_weights: Optional[str] = "imagenet", encoder_depth: int = 3, psp_out_channels: int = 512, - psp_use_batchnorm: bool = True, + psp_use_batchnorm: Union[bool, str, None] = True, + decoder_use_norm: Union[bool, str, Dict[str, Any], None] = "batchnorm", psp_dropout: float = 0.2, in_channels: int = 3, classes: int = 1, @@ -80,10 +100,10 @@ def __init__( weights=encoder_weights, **kwargs, ) - + decoder_use_norm = normalize_decoder_norm(psp_use_batchnorm, decoder_use_norm) self.decoder = PSPDecoder( encoder_channels=self.encoder.out_channels, - use_batchnorm=psp_use_batchnorm, + use_norm=decoder_use_norm, out_channels=psp_out_channels, dropout=psp_dropout, ) diff --git a/segmentation_models_pytorch/decoders/unet/model.py b/segmentation_models_pytorch/decoders/unet/model.py index 133a95cf..69fb1ac9 100644 --- a/segmentation_models_pytorch/decoders/unet/model.py +++ b/segmentation_models_pytorch/decoders/unet/model.py @@ -7,7 +7,7 @@ ) from segmentation_models_pytorch.encoders import get_encoder from segmentation_models_pytorch.base.hub_mixin import supports_config_loading -from segmentation_models_pytorch.base.modules import handle_decoder_use_batchnorm_deprecation +from segmentation_models_pytorch.base.modules import normalize_decoder_norm from .decoder import UnetDecoder @@ -50,7 +50,7 @@ class Unet(SegmentationModel): - **True**: Defaults to `"batchnorm"`. - **False**: No normalization (`nn.Identity`). - **str**: Specifies normalization type using default parameters. Available values: - `"batchnorm"`, `"identity"`, `"layernorm"`, `"groupnorm"`, `"instancenorm"`, `"inplace"`. + `"batchnorm"`, `"identity"`, `"layernorm"`, `"instancenorm"`, `"inplace"`. - **dict**: Fully customizable normalization settings. Structure: ```python {"type": , **kwargs} @@ -59,7 +59,7 @@ class Unet(SegmentationModel): **Example**: ```python - use_norm={"type": "groupnorm", "num_groups": 8} + use_norm={"type": "layernorm", "eps": 1e-2} ``` decoder_attention_type: Attention module used in decoder of the model. Available options are **None** and **scse** (https://arxiv.org/abs/1808.08127). 
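For reference, the `decoder_use_norm` contract documented above can be exercised as follows. This is a usage sketch against the library's public entry points (the `segmentation_models_pytorch` import name and the `Unet` constructor), not part of this diff; the value-to-layer mapping is the one implemented in `modules.py`:

```python
import segmentation_models_pytorch as smp

# Boolean and string shorthands resolve to the same normalization settings:
model = smp.Unet("resnet34", decoder_use_norm=True)         # -> {"type": "batchnorm"}
model = smp.Unet("resnet34", decoder_use_norm="batchnorm")  # same layer, string form
model = smp.Unet("resnet34", decoder_use_norm=False)        # -> {"type": "identity"}, i.e. no norm

# Dict form forwards extra kwargs directly to the layer constructor,
# here nn.InstanceNorm2d(out_channels, affine=True):
model = smp.Unet("resnet34", decoder_use_norm={"type": "instancenorm", "affine": True})
```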
@@ -115,7 +115,7 @@ def __init__( encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", decoder_use_batchnorm: Union[bool, str, None] = True, - decoder_use_norm: Union[bool, str, Dict[str, Any]] = True, + decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", decoder_channels: Sequence[int] = (256, 128, 64, 32, 16), decoder_attention_type: Optional[str] = None, decoder_interpolation_mode: str = "nearest", @@ -136,7 +136,7 @@ def __init__( ) add_center_block = encoder_name.startswith("vgg") - decoder_use_norm = handle_decoder_use_batchnorm_deprecation(decoder_use_batchnorm, decoder_use_norm) + decoder_use_norm = normalize_decoder_norm(decoder_use_batchnorm, decoder_use_norm) self.decoder = UnetDecoder( encoder_channels=self.encoder.out_channels, decoder_channels=decoder_channels, diff --git a/segmentation_models_pytorch/decoders/unetplusplus/model.py b/segmentation_models_pytorch/decoders/unetplusplus/model.py index 17dcc18b..ea1df1c8 100644 --- a/segmentation_models_pytorch/decoders/unetplusplus/model.py +++ b/segmentation_models_pytorch/decoders/unetplusplus/model.py @@ -7,7 +7,7 @@ ) from segmentation_models_pytorch.encoders import get_encoder from segmentation_models_pytorch.base.hub_mixin import supports_config_loading -from segmentation_models_pytorch.base.modules import handle_decoder_use_batchnorm_deprecation +from segmentation_models_pytorch.base.modules import normalize_decoder_norm from .decoder import UnetPlusPlusDecoder @@ -39,7 +39,7 @@ class UnetPlusPlus(SegmentationModel): - **True**: Defaults to `"batchnorm"`. - **False**: No normalization (`nn.Identity`). - **str**: Specifies normalization type using default parameters. Available values: - `"batchnorm"`, `"identity"`, `"layernorm"`, `"groupnorm"`, `"instancenorm"`, `"inplace"`. + `"batchnorm"`, `"identity"`, `"layernorm"`, `"instancenorm"`, `"inplace"`. - **dict**: Fully customizable normalization settings. Structure: ```python {"type": , **kwargs} @@ -48,7 +48,7 @@ class UnetPlusPlus(SegmentationModel): **Example**: ```python - use_norm={"type": "groupnorm", "num_groups": 8} + use_norm={"type": "layernorm", "eps": 1e-2} ``` decoder_attention_type: Attention module used in decoder of the model. Available options are **None** and **scse** (https://arxiv.org/abs/1808.08127). 
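The old flag keeps working through the deprecation shim; the sketch below mirrors the equivalence the new test file asserts (same seed, deprecated argument vs. its `normalize_decoder_norm` mapping, parameter-identical result). `create_model` is the library's existing factory; the encoder choice and seeding follow the tests:

```python
import torch
from segmentation_models_pytorch import create_model

torch.manual_seed(42)
deprecated = create_model(
    "unetplusplus", "mobilenet_v2", None, decoder_use_batchnorm=True
)  # emits DeprecationWarning, maps True -> {"type": "batchnorm"}

torch.manual_seed(42)
replacement = create_model(
    "unetplusplus", "mobilenet_v2", None,
    decoder_use_batchnorm=None, decoder_use_norm="batchnorm",
)

# Equivalent settings under the same seed build parameter-identical models:
for (k1, v1), (k2, v2) in zip(
    deprecated.state_dict().items(), replacement.state_dict().items()
):
    assert k1 == k2 and (v1 == v2).all()
```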
@@ -84,7 +84,7 @@ def __init__( encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", decoder_use_batchnorm: Union[bool, str, None] = True, - decoder_use_norm: Union[bool, str, Dict[str, Any]] = True, + decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", decoder_channels: List[int] = (256, 128, 64, 32, 16), decoder_attention_type: Optional[str] = None, in_channels: int = 3, @@ -108,7 +108,7 @@ def __init__( **kwargs, ) - decoder_use_norm = handle_decoder_use_batchnorm_deprecation(decoder_use_batchnorm, decoder_use_norm) + decoder_use_norm = normalize_decoder_norm(decoder_use_batchnorm, decoder_use_norm) self.decoder = UnetPlusPlusDecoder( encoder_channels=self.encoder.out_channels, decoder_channels=decoder_channels, diff --git a/segmentation_models_pytorch/decoders/upernet/decoder.py b/segmentation_models_pytorch/decoders/upernet/decoder.py index 99c74fb1..97a794b4 100644 --- a/segmentation_models_pytorch/decoders/upernet/decoder.py +++ b/segmentation_models_pytorch/decoders/upernet/decoder.py @@ -1,3 +1,5 @@ +from typing import Any, Dict, Tuple, Union + import torch import torch.nn as nn import torch.nn.functional as F @@ -8,10 +10,10 @@ class PSPModule(nn.Module): def __init__( self, - in_channels, - out_channels, - sizes=(1, 2, 3, 6), - use_batchnorm=True, + in_channels: int, + out_channels: int, + sizes: Tuple[int, ...] = (1, 2, 3, 6), + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", ): super().__init__() self.blocks = nn.ModuleList( @@ -22,7 +24,7 @@ def __init__( in_channels, in_channels // len(sizes), kernel_size=1, - use_batchnorm=use_batchnorm, + use_norm=use_norm, ), ) for size in sizes @@ -32,7 +34,7 @@ def __init__( in_channels=in_channels * 2, out_channels=out_channels, kernel_size=1, - use_batchnorm=True, + use_norm="batchnorm", ) def forward(self, x): @@ -48,14 +50,14 @@ def forward(self, x): class FPNBlock(nn.Module): - def __init__(self, skip_channels, pyramid_channels, use_batchnorm=True): + def __init__(self, skip_channels: int, pyramid_channels: int, use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm"): super().__init__() self.skip_conv = ( md.Conv2dReLU( skip_channels, pyramid_channels, kernel_size=1, - use_batchnorm=use_batchnorm, + use_norm=use_norm, ) if skip_channels != 0 else nn.Identity() @@ -73,10 +75,11 @@ def forward(self, x, skip): class UPerNetDecoder(nn.Module): def __init__( self, - encoder_channels, - encoder_depth=5, - pyramid_channels=256, - segmentation_channels=64, + encoder_channels: Tuple[int, ...], + encoder_depth: int = 5, + pyramid_channels: int = 256, + segmentation_channels: int = 64, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", ): super().__init__() @@ -94,7 +97,7 @@ def __init__( in_channels=encoder_channels[0], out_channels=pyramid_channels, sizes=(1, 2, 3, 6), - use_batchnorm=True, + use_norm=use_norm, ) # FPN Module @@ -107,7 +110,7 @@ def __init__( out_channels=segmentation_channels, kernel_size=3, padding=1, - use_batchnorm=True, + use_norm=use_norm, ) def forward(self, features): diff --git a/segmentation_models_pytorch/decoders/upernet/model.py b/segmentation_models_pytorch/decoders/upernet/model.py index 7ffeee5b..caae60c2 100644 --- a/segmentation_models_pytorch/decoders/upernet/model.py +++ b/segmentation_models_pytorch/decoders/upernet/model.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Union +from typing import Any, Dict, Optional, Union from segmentation_models_pytorch.base import ( ClassificationHead, @@ -25,6 +25,22 @@ class UPerNet(SegmentationModel): other 
pretrained weights (see table with available weights for each encoder_name) decoder_pyramid_channels: A number of convolution filters in Feature Pyramid, default is 256 decoder_segmentation_channels: A number of convolution filters in segmentation blocks, default is 64 + decoder_use_norm: Specifies normalization between Conv2D and activation. + Accepts the following types: + - **True**: Defaults to `"batchnorm"`. + - **False**: No normalization (`nn.Identity`). + - **str**: Specifies normalization type using default parameters. Available values: + `"batchnorm"`, `"identity"`, `"layernorm"`, `"instancenorm"`, `"inplace"`. + - **dict**: Fully customizable normalization settings. Structure: + ```python + {"type": , **kwargs} + ``` + where `norm_name` corresponds to normalization type (see above), and `kwargs` are passed directly to the normalization layer as defined in PyTorch documentation. + + **Example**: + ```python + use_norm={"type": "layernorm", "eps": 1e-2} + ``` in_channels: A number of input channels for the model, default is 3 (RGB images) classes: A number of classes for output mask (or you can think as a number of channels of output mask) activation: An activation function to apply after the final convolution layer. @@ -58,6 +74,7 @@ def __init__( encoder_weights: Optional[str] = "imagenet", decoder_pyramid_channels: int = 256, decoder_segmentation_channels: int = 64, + decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", in_channels: int = 3, classes: int = 1, activation: Optional[Union[str, callable]] = None, @@ -79,6 +96,7 @@ def __init__( encoder_depth=encoder_depth, pyramid_channels=decoder_pyramid_channels, segmentation_channels=decoder_segmentation_channels, + use_norm=decoder_use_norm, ) self.segmentation_head = SegmentationHead( diff --git a/tests/base/test_modules.py b/tests/base/test_modules.py index 702d7cda..b64834c8 100644 --- a/tests/base/test_modules.py +++ b/tests/base/test_modules.py @@ -33,12 +33,6 @@ def test_conv2drelu_layernorm(): '\n (2): ReLU(inplace=True)\n)') assert repr(module) == expected -def test_conv2drelu_groupnorm(): - module = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm="groupnorm") - expected = ('Conv2dReLU(\n (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))' - '\n (1): Identity()' - '\n (2): ReLU(inplace=True)\n)') - assert repr(module) == expected def test_conv2drelu_instancenorm(): module = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm="instancenorm") diff --git a/tests/encoders/test_batchnorm_deprecation.py b/tests/encoders/test_batchnorm_deprecation.py index 65ec1518..e0d05ed9 100644 --- a/tests/encoders/test_batchnorm_deprecation.py +++ b/tests/encoders/test_batchnorm_deprecation.py @@ -1 +1,28 @@ -# TODO: Simple equal before / after \ No newline at end of file +import pytest + +import torch + +from segmentation_models_pytorch import create_model +from tests.utils import check_two_models_strictly_equal + + +@pytest.mark.parametrize("model_name", ["unet", "unetplusplus", "linknet", "manet"]) +@pytest.mark.parametrize("decoder_option", [True, False, "inplace"]) +def test_seg_models_before_after_use_norm(model_name, decoder_option): + torch.manual_seed(42) + model_decoder_batchnorm = create_model(model_name, "mobilenet_v2", None, decoder_use_batchnorm=decoder_option) + torch.manual_seed(42) + model_decoder_norm = create_model(model_name, "mobilenet_v2", None, decoder_use_batchnorm=None, decoder_use_norm=decoder_option) + + check_two_models_strictly_equal(model_decoder_batchnorm, 
model_decoder_norm) + + + +@pytest.mark.parametrize("decoder_option", [True, False, "inplace"]) +def test_pspnet_before_after_use_norm(decoder_option): + torch.manual_seed(42) + model_decoder_batchnorm = create_model("pspnet", "mobilenet_v2", None, psp_use_batchnorm=decoder_option) + torch.manual_seed(42) + model_decoder_norm = create_model("pspnet", "mobilenet_v2", None, psp_use_batchnorm=None, decoder_use_norm=decoder_option) + + check_two_models_strictly_equal(model_decoder_batchnorm, model_decoder_norm) diff --git a/tests/utils.py b/tests/utils.py index 6e201f1d..ff7e9792 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -58,3 +58,10 @@ def check_run_test_on_diff_or_main(filepath_patterns: List[str]): return True return False + + +def check_two_models_strictly_equal(model_a, model_b): + for (k1, v1), (k2, v2) in zip(model_a.state_dict().items(), + model_b.state_dict().items()): + assert k1 == k2, f"Key mismatch: {k1} != {k2}" + assert v1.shape == v2.shape, f"Shape mismatch in {k1}: {v1.shape} != {v2.shape}" From 467057a92928bb283ccde29bb173dcf170ccb2e1 Mon Sep 17 00:00:00 2001 From: GuillaumeErhard <25333848+GuillaumeErhard@users.noreply.github.com> Date: Wed, 26 Mar 2025 00:07:30 +0100 Subject: [PATCH 04/21] Set use_batchnorm default to None so that default use_norm Make warning visible by changing filter and add a test for it Fix test before after so that the value is looked and not the shape of tensor --- segmentation_models_pytorch/base/modules.py | 1 + .../decoders/linknet/model.py | 2 +- .../decoders/manet/model.py | 2 +- .../decoders/pspnet/model.py | 2 +- .../decoders/unet/model.py | 2 +- .../decoders/unetplusplus/model.py | 2 +- tests/encoders/test_batchnorm_deprecation.py | 16 ++++++++++++++-- tests/utils.py | 4 ++-- 8 files changed, 22 insertions(+), 9 deletions(-) diff --git a/segmentation_models_pytorch/base/modules.py b/segmentation_models_pytorch/base/modules.py index 38719005..1c334308 100644 --- a/segmentation_models_pytorch/base/modules.py +++ b/segmentation_models_pytorch/base/modules.py @@ -40,6 +40,7 @@ def normalize_decoder_norm(decoder_use_batchnorm: Union[bool, str, None], decode warnings.warn( "The usage of use_batchnorm is deprecated. 
Please modify your code for use_norm", DeprecationWarning, + stacklevel=2 ) if decoder_use_batchnorm is True: decoder_use_norm = {"type": "batchnorm"} diff --git a/segmentation_models_pytorch/decoders/linknet/model.py b/segmentation_models_pytorch/decoders/linknet/model.py index f2103cfb..941cd561 100644 --- a/segmentation_models_pytorch/decoders/linknet/model.py +++ b/segmentation_models_pytorch/decoders/linknet/model.py @@ -79,7 +79,7 @@ def __init__( encoder_name: str = "resnet34", encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", - decoder_use_batchnorm: Union[bool, str, None] = True, + decoder_use_batchnorm: Union[bool, str, None] = None, decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", in_channels: int = 3, classes: int = 1, diff --git a/segmentation_models_pytorch/decoders/manet/model.py b/segmentation_models_pytorch/decoders/manet/model.py index ded053a4..91802320 100644 --- a/segmentation_models_pytorch/decoders/manet/model.py +++ b/segmentation_models_pytorch/decoders/manet/model.py @@ -82,7 +82,7 @@ def __init__( encoder_name: str = "resnet34", encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", - decoder_use_batchnorm: Union[bool, str, None] = True, + decoder_use_batchnorm: Union[bool, str, None] = None, decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", decoder_channels: List[int] = (256, 128, 64, 32, 16), decoder_pab_channels: int = 64, diff --git a/segmentation_models_pytorch/decoders/pspnet/model.py b/segmentation_models_pytorch/decoders/pspnet/model.py index 9e79bf1b..4e4b1980 100644 --- a/segmentation_models_pytorch/decoders/pspnet/model.py +++ b/segmentation_models_pytorch/decoders/pspnet/model.py @@ -81,7 +81,7 @@ def __init__( encoder_weights: Optional[str] = "imagenet", encoder_depth: int = 3, psp_out_channels: int = 512, - psp_use_batchnorm: Union[bool, str, None] = True, + psp_use_batchnorm: Union[bool, str, None] = None, decoder_use_norm: Union[bool, str, Dict[str, Any], None] = "batchnorm", psp_dropout: float = 0.2, in_channels: int = 3, diff --git a/segmentation_models_pytorch/decoders/unet/model.py b/segmentation_models_pytorch/decoders/unet/model.py index 69fb1ac9..955150ac 100644 --- a/segmentation_models_pytorch/decoders/unet/model.py +++ b/segmentation_models_pytorch/decoders/unet/model.py @@ -114,7 +114,7 @@ def __init__( encoder_name: str = "resnet34", encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", - decoder_use_batchnorm: Union[bool, str, None] = True, + decoder_use_batchnorm: Union[bool, str, None] = None, decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", decoder_channels: Sequence[int] = (256, 128, 64, 32, 16), decoder_attention_type: Optional[str] = None, diff --git a/segmentation_models_pytorch/decoders/unetplusplus/model.py b/segmentation_models_pytorch/decoders/unetplusplus/model.py index ea1df1c8..36a57300 100644 --- a/segmentation_models_pytorch/decoders/unetplusplus/model.py +++ b/segmentation_models_pytorch/decoders/unetplusplus/model.py @@ -83,7 +83,7 @@ def __init__( encoder_name: str = "resnet34", encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", - decoder_use_batchnorm: Union[bool, str, None] = True, + decoder_use_batchnorm: Union[bool, str, None] = None, decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", decoder_channels: List[int] = (256, 128, 64, 32, 16), decoder_attention_type: Optional[str] = None, diff --git a/tests/encoders/test_batchnorm_deprecation.py 
b/tests/encoders/test_batchnorm_deprecation.py index e0d05ed9..caf87bcb 100644 --- a/tests/encoders/test_batchnorm_deprecation.py +++ b/tests/encoders/test_batchnorm_deprecation.py @@ -10,7 +10,13 @@ @pytest.mark.parametrize("decoder_option", [True, False, "inplace"]) def test_seg_models_before_after_use_norm(model_name, decoder_option): torch.manual_seed(42) - model_decoder_batchnorm = create_model(model_name, "mobilenet_v2", None, decoder_use_batchnorm=decoder_option) + with pytest.warns(DeprecationWarning): + model_decoder_batchnorm = create_model( + model_name, + "mobilenet_v2", + None, + decoder_use_batchnorm=decoder_option + ) torch.manual_seed(42) model_decoder_norm = create_model(model_name, "mobilenet_v2", None, decoder_use_batchnorm=None, decoder_use_norm=decoder_option) @@ -21,7 +27,13 @@ def test_seg_models_before_after_use_norm(model_name, decoder_option): @pytest.mark.parametrize("decoder_option", [True, False, "inplace"]) def test_pspnet_before_after_use_norm(decoder_option): torch.manual_seed(42) - model_decoder_batchnorm = create_model("pspnet", "mobilenet_v2", None, psp_use_batchnorm=decoder_option) + with pytest.warns(DeprecationWarning): + model_decoder_batchnorm = create_model( + "pspnet", + "mobilenet_v2", + None, + psp_use_batchnorm=decoder_option + ) torch.manual_seed(42) model_decoder_norm = create_model("pspnet", "mobilenet_v2", None, psp_use_batchnorm=None, decoder_use_norm=decoder_option) diff --git a/tests/utils.py b/tests/utils.py index ff7e9792..fa9efc2c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -60,8 +60,8 @@ def check_run_test_on_diff_or_main(filepath_patterns: List[str]): return False -def check_two_models_strictly_equal(model_a, model_b): +def check_two_models_strictly_equal(model_a: torch.nn.Module, model_b: torch.nn.Module) -> None: for (k1, v1), (k2, v2) in zip(model_a.state_dict().items(), model_b.state_dict().items()): assert k1 == k2, f"Key mismatch: {k1} != {k2}" - assert v1.shape == v2.shape, f"Shape mismatch in {k1}: {v1.shape} != {v2.shape}" + assert (v1 == v2).all(), f"Tensor mismatch at key '{k1}':\n{v1} !=\n{v2}" From 1ae11c31df6a1c275d94ed322a4c8dc8af782d8c Mon Sep 17 00:00:00 2001 From: GuillaumeErhard <25333848+GuillaumeErhard@users.noreply.github.com> Date: Thu, 27 Mar 2025 01:11:47 +0100 Subject: [PATCH 05/21] Changes following review --- segmentation_models_pytorch/base/modules.py | 84 ++++++++----------- .../decoders/linknet/model.py | 20 +++-- .../decoders/manet/model.py | 19 +++-- .../decoders/pspnet/model.py | 21 +++-- .../decoders/unet/model.py | 21 +++-- .../decoders/unetplusplus/model.py | 18 ++-- tests/base/test_modules.py | 51 ++++++----- tests/encoders/test_batchnorm_deprecation.py | 4 +- tests/utils.py | 5 +- 9 files changed, 124 insertions(+), 119 deletions(-) diff --git a/segmentation_models_pytorch/base/modules.py b/segmentation_models_pytorch/base/modules.py index 1c334308..9bd342e3 100644 --- a/segmentation_models_pytorch/base/modules.py +++ b/segmentation_models_pytorch/base/modules.py @@ -1,5 +1,4 @@ -from typing import Any, Dict, Tuple, Union -import warnings +from typing import Any, Dict, Union import torch import torch.nn as nn @@ -9,11 +8,18 @@ except ImportError: InPlaceABN = None -def normalize_use_norm(decoder_use_norm: Union[bool, str, Dict[str, Any]]) -> Dict[str, Any]: - if isinstance(decoder_use_norm, str): - norm_str = decoder_use_norm.lower() +def get_norm_layer(use_norm: Union[bool, str, Dict[str, Any]], out_channels: int) -> nn.Module: + supported_norms = ("inplace", "batchnorm", "identity", 
"layernorm", "instancenorm") + if use_norm is True: + norm_params = {"type": "batchnorm"} + elif use_norm is False: + norm_params = {"type": "identity"} + elif use_norm == "inplace": + norm_params = {"type": "inplace", "activation": "leaky_relu", "activation_param": 0.0} + elif isinstance(use_norm, str): + norm_str = use_norm.lower() if norm_str == "inplace": - decoder_use_norm = { + norm_params = { "type": "inplace", "activation": "leaky_relu", "activation_param": 0.0, @@ -22,48 +28,31 @@ def normalize_use_norm(decoder_use_norm: Union[bool, str, Dict[str, Any]]) -> Di "batchnorm", "identity", "layernorm", - "groupnorm", "instancenorm", ): - decoder_use_norm = {"type": norm_str} + norm_params = {"type": norm_str} else: - raise ValueError("Unrecognized normalization type string provided") - elif isinstance(decoder_use_norm, bool): - decoder_use_norm = {"type": "batchnorm" if decoder_use_norm else "identity"} - elif not isinstance(decoder_use_norm, dict): - raise ValueError("use_norm must be a dictionary, boolean, or string") - - return decoder_use_norm - -def normalize_decoder_norm(decoder_use_batchnorm: Union[bool, str, None], decoder_use_norm: Union[bool, str, Dict[str, Any]]) -> Dict[str, Any]: - if decoder_use_batchnorm is not None: - warnings.warn( - "The usage of use_batchnorm is deprecated. Please modify your code for use_norm", - DeprecationWarning, - stacklevel=2 - ) - if decoder_use_batchnorm is True: - decoder_use_norm = {"type": "batchnorm"} - elif decoder_use_batchnorm is False: - decoder_use_norm = {"type": "identity"} - elif decoder_use_batchnorm == "inplace": - decoder_use_norm = { - "type": "inplace", - "activation": "leaky_relu", - "activation_param": 0.0, - } - else: - raise ValueError("Unrecognized value for use_batchnorm") + raise ValueError(f"Unrecognized normalization type string provided: {use_norm}. Should be in {supported_norms}") + elif isinstance(use_norm, dict): + norm_params = use_norm + else: + raise ValueError("use_norm must be a dictionary, boolean, or string. Please refer to the documentation.") - decoder_use_norm = normalize_use_norm(decoder_use_norm) - return decoder_use_norm + if not "type" in norm_params: + raise ValueError(f"Malformed dictionary given in use_norm: {use_norm}. Should contain key 'type'.") + if norm_params["type"] not in supported_norms: + raise ValueError(f"Unrecognized normalization type string provided: {use_norm}. Should be in {supported_norms}") -def get_norm_layer(use_norm: Dict[str, Any], out_channels: int) -> nn.Module: - norm_type = use_norm["type"] - extra_kwargs = {k: v for k, v in use_norm.items() if k != "type"} + norm_type = norm_params["type"] + extra_kwargs = {k: v for k, v in norm_params.items() if k != "type"} - if norm_type == "inplace": + if norm_type == "inplace" and InPlaceABN is None: + raise RuntimeError( + "In order to use `use_batchnorm='inplace'` or `use_norm='inplace'` the inplace_abn package must be installed. " + "To install see: https://github.com/mapillary/inplace_abn" + ) + elif norm_type == "inplace": norm = InPlaceABN(out_channels, **extra_kwargs) elif norm_type == "batchnorm": norm = nn.BatchNorm2d(out_channels, **extra_kwargs) @@ -88,24 +77,17 @@ def __init__( stride=1, use_norm="batchnorm", ): - use_norm = normalize_use_norm(use_norm) - if use_norm["type"] == "inplace" and InPlaceABN is None: - raise RuntimeError( - "In order to use `use_batchnorm='inplace'` or `use_norm='inplace'` the inplace_abn package must be installed. 
" - "To install see: https://github.com/mapillary/inplace_abn" - ) - + norm = get_norm_layer(use_norm, out_channels) conv = nn.Conv2d( in_channels, out_channels, kernel_size, stride=stride, padding=padding, - bias=use_norm["type"] != "inplace", + bias=norm._get_name() != "BatchNorm2d", ) - norm = get_norm_layer(use_norm, out_channels) - if use_norm["type"] == "inplace": + if norm._get_name() == "Inplace": relu = nn.Identity() else: relu = nn.ReLU(inplace=True) diff --git a/segmentation_models_pytorch/decoders/linknet/model.py b/segmentation_models_pytorch/decoders/linknet/model.py index 941cd561..39ede764 100644 --- a/segmentation_models_pytorch/decoders/linknet/model.py +++ b/segmentation_models_pytorch/decoders/linknet/model.py @@ -1,3 +1,4 @@ +import warnings from typing import Any, Dict, Optional, Union from segmentation_models_pytorch.base import ( @@ -7,7 +8,6 @@ ) from segmentation_models_pytorch.encoders import get_encoder from segmentation_models_pytorch.base.hub_mixin import supports_config_loading -from segmentation_models_pytorch.base.modules import normalize_decoder_norm from .decoder import LinknetDecoder @@ -30,11 +30,6 @@ class Linknet(SegmentationModel): Default is 5 encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and other pretrained weights (see table with available weights for each encoder_name) - decoder_use_batchnorm: (**Deprecated**) If **True**, BatchNorm2d layer between Conv2D and Activation layers - is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption. - Available options are **True, False, "inplace"** - - **Note:** Deprecated, prefer using `decoder_use_norm` and set this to None. decoder_use_norm: Specifies normalization between Conv2D and activation. Accepts the following types: - **True**: Defaults to `"batchnorm"`. @@ -79,8 +74,7 @@ def __init__( encoder_name: str = "resnet34", encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", - decoder_use_batchnorm: Union[bool, str, None] = None, - decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", + decoder_use_norm: Union[bool, str, Dict[str, Any]] = True, in_channels: int = 3, classes: int = 1, activation: Optional[Union[str, callable]] = None, @@ -102,7 +96,15 @@ def __init__( **kwargs, ) - decoder_use_norm = normalize_decoder_norm(decoder_use_batchnorm, decoder_use_norm) + decoder_use_batchnorm = kwargs.pop("decoder_use_batchnorm", None) + if decoder_use_batchnorm is not None: + warnings.warn( + "The usage of decoder_use_batchnorm is deprecated. 
Please modify your code for use_norm", + DeprecationWarning, + stacklevel=2 + ) + decoder_use_norm = decoder_use_batchnorm + self.decoder = LinknetDecoder( encoder_channels=self.encoder.out_channels, n_blocks=encoder_depth, diff --git a/segmentation_models_pytorch/decoders/manet/model.py b/segmentation_models_pytorch/decoders/manet/model.py index 91802320..7b8e90e5 100644 --- a/segmentation_models_pytorch/decoders/manet/model.py +++ b/segmentation_models_pytorch/decoders/manet/model.py @@ -1,3 +1,4 @@ +import warnings from typing import Any, Dict, List, Optional, Union from segmentation_models_pytorch.base import ( @@ -7,7 +8,6 @@ ) from segmentation_models_pytorch.encoders import get_encoder from segmentation_models_pytorch.base.hub_mixin import supports_config_loading -from segmentation_models_pytorch.base.modules import normalize_decoder_norm from .decoder import MAnetDecoder @@ -30,11 +30,6 @@ class MAnet(SegmentationModel): other pretrained weights (see table with available weights for each encoder_name) decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder. Length of the list should be the same as **encoder_depth** - decoder_use_batchnorm: (**Deprecated**) If **True**, BatchNorm2d layer between Conv2D and Activation layers - is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption. - Available options are **True, False, "inplace"** - - **Note:** Deprecated, prefer using `decoder_use_norm` and set this to None. decoder_use_norm: Specifies normalization between Conv2D and activation. Accepts the following types: - **True**: Defaults to `"batchnorm"`. @@ -82,8 +77,7 @@ def __init__( encoder_name: str = "resnet34", encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", - decoder_use_batchnorm: Union[bool, str, None] = None, - decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", + decoder_use_norm: Union[bool, str, Dict[str, Any]] = True, decoder_channels: List[int] = (256, 128, 64, 32, 16), decoder_pab_channels: int = 64, in_channels: int = 3, @@ -102,7 +96,14 @@ def __init__( **kwargs, ) - decoder_use_norm = normalize_decoder_norm(decoder_use_batchnorm, decoder_use_norm) + decoder_use_batchnorm = kwargs.pop("decoder_use_batchnorm", None) + if decoder_use_batchnorm is not None: + warnings.warn( + "The usage of decoder_use_batchnorm is deprecated. 
Please modify your code for use_norm", + DeprecationWarning, + stacklevel=2 + ) + decoder_use_norm = decoder_use_batchnorm self.decoder = MAnetDecoder( encoder_channels=self.encoder.out_channels, diff --git a/segmentation_models_pytorch/decoders/pspnet/model.py b/segmentation_models_pytorch/decoders/pspnet/model.py index 4e4b1980..16ba8afa 100644 --- a/segmentation_models_pytorch/decoders/pspnet/model.py +++ b/segmentation_models_pytorch/decoders/pspnet/model.py @@ -1,3 +1,4 @@ +import warnings from typing import Any, Dict, Optional, Union from segmentation_models_pytorch.base import ( @@ -6,7 +7,6 @@ SegmentationModel, ) from segmentation_models_pytorch.encoders import get_encoder -from segmentation_models_pytorch.base.modules import normalize_decoder_norm from segmentation_models_pytorch.base.hub_mixin import supports_config_loading from .decoder import PSPDecoder @@ -29,11 +29,6 @@ class PSPNet(SegmentationModel): encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and other pretrained weights (see table with available weights for each encoder_name) psp_out_channels: A number of filters in Spatial Pyramid - psp_use_batchnorm: (**Deprecated**) If **True**, BatchNorm2d layer between Conv2D and Activation layers - is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption. - Available options are **True, False, "inplace"** - - **Note:** Deprecated, prefer using `decoder_use_norm` and set this to None. decoder_use_norm: Specifies normalization between Conv2D and activation. Accepts the following types: - **True**: Defaults to `"batchnorm"`. @@ -81,8 +76,7 @@ def __init__( encoder_weights: Optional[str] = "imagenet", encoder_depth: int = 3, psp_out_channels: int = 512, - psp_use_batchnorm: Union[bool, str, None] = None, - decoder_use_norm: Union[bool, str, Dict[str, Any], None] = "batchnorm", + decoder_use_norm: Union[bool, str, Dict[str, Any], None] = True, psp_dropout: float = 0.2, in_channels: int = 3, classes: int = 1, @@ -100,7 +94,16 @@ def __init__( weights=encoder_weights, **kwargs, ) - decoder_use_norm = normalize_decoder_norm(psp_use_batchnorm, decoder_use_norm) + + psp_use_batchnorm = kwargs.pop("psp_use_batchnorm", None) + if psp_use_batchnorm is not None: + warnings.warn( + "The usage of psp_use_batchnorm is deprecated. Please modify your code for use_norm", + DeprecationWarning, + stacklevel=2 + ) + decoder_use_norm = psp_use_batchnorm + self.decoder = PSPDecoder( encoder_channels=self.encoder.out_channels, use_norm=decoder_use_norm, diff --git a/segmentation_models_pytorch/decoders/unet/model.py b/segmentation_models_pytorch/decoders/unet/model.py index 955150ac..dd295b34 100644 --- a/segmentation_models_pytorch/decoders/unet/model.py +++ b/segmentation_models_pytorch/decoders/unet/model.py @@ -1,3 +1,4 @@ +import warnings from typing import Any, Dict, Optional, Union, Callable, Sequence from segmentation_models_pytorch.base import ( @@ -7,7 +8,6 @@ ) from segmentation_models_pytorch.encoders import get_encoder from segmentation_models_pytorch.base.hub_mixin import supports_config_loading -from segmentation_models_pytorch.base.modules import normalize_decoder_norm from .decoder import UnetDecoder @@ -40,11 +40,6 @@ class Unet(SegmentationModel): other pretrained weights (see table with available weights for each encoder_name) decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder. 
Length of the list should be the same as **encoder_depth** - decoder_use_batchnorm: (**Deprecated**) If **True**, BatchNorm2d layer between Conv2D and Activation layers - is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption. - Available options are **True, False, "inplace"** - - **Note:** Deprecated, prefer using `decoder_use_norm` and set this to None. decoder_use_norm: Specifies normalization between Conv2D and activation. Accepts the following types: - **True**: Defaults to `"batchnorm"`. @@ -114,8 +109,7 @@ def __init__( encoder_name: str = "resnet34", encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", - decoder_use_batchnorm: Union[bool, str, None] = None, - decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", + decoder_use_norm: Union[bool, str, Dict[str, Any]] = True, decoder_channels: Sequence[int] = (256, 128, 64, 32, 16), decoder_attention_type: Optional[str] = None, decoder_interpolation_mode: str = "nearest", @@ -136,7 +130,16 @@ def __init__( ) add_center_block = encoder_name.startswith("vgg") - decoder_use_norm = normalize_decoder_norm(decoder_use_batchnorm, decoder_use_norm) + + decoder_use_batchnorm = kwargs.pop("decoder_use_batchnorm", None) + if decoder_use_batchnorm is not None: + warnings.warn( + "The usage of decoder_use_batchnorm is deprecated. Please modify your code for use_norm", + DeprecationWarning, + stacklevel=2 + ) + decoder_use_norm = decoder_use_batchnorm + self.decoder = UnetDecoder( encoder_channels=self.encoder.out_channels, decoder_channels=decoder_channels, diff --git a/segmentation_models_pytorch/decoders/unetplusplus/model.py b/segmentation_models_pytorch/decoders/unetplusplus/model.py index 36a57300..66983255 100644 --- a/segmentation_models_pytorch/decoders/unetplusplus/model.py +++ b/segmentation_models_pytorch/decoders/unetplusplus/model.py @@ -1,3 +1,4 @@ +import warnings from typing import Any, Dict, List, Optional, Union from segmentation_models_pytorch.base import ( @@ -7,7 +8,6 @@ ) from segmentation_models_pytorch.encoders import get_encoder from segmentation_models_pytorch.base.hub_mixin import supports_config_loading -from segmentation_models_pytorch.base.modules import normalize_decoder_norm from .decoder import UnetPlusPlusDecoder @@ -29,11 +29,6 @@ class UnetPlusPlus(SegmentationModel): other pretrained weights (see table with available weights for each encoder_name) decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder. Length of the list should be the same as **encoder_depth** - decoder_use_batchnorm: (**Deprecated**) If **True**, BatchNorm2d layer between Conv2D and Activation layers - is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption. - Available options are **True, False, "inplace"** - - **Note:** Deprecated, prefer using `decoder_use_norm` and set this to None. decoder_use_norm: Specifies normalization between Conv2D and activation. Accepts the following types: - **True**: Defaults to `"batchnorm"`. 
@@ -83,7 +78,6 @@ def __init__( encoder_name: str = "resnet34", encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", - decoder_use_batchnorm: Union[bool, str, None] = None, decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", decoder_channels: List[int] = (256, 128, 64, 32, 16), decoder_attention_type: Optional[str] = None, @@ -108,7 +102,15 @@ def __init__( **kwargs, ) - decoder_use_norm = normalize_decoder_norm(decoder_use_batchnorm, decoder_use_norm) + decoder_use_batchnorm = kwargs.pop("decoder_use_batchnorm", None) + if decoder_use_batchnorm is not None: + warnings.warn( + "The usage of decoder_use_batchnorm is deprecated. Please modify your code for use_norm", + DeprecationWarning, + stacklevel=2 + ) + decoder_use_norm = decoder_use_batchnorm + self.decoder = UnetPlusPlusDecoder( encoder_channels=self.encoder.out_channels, decoder_channels=decoder_channels, diff --git a/tests/base/test_modules.py b/tests/base/test_modules.py index b64834c8..11ebd8c9 100644 --- a/tests/base/test_modules.py +++ b/tests/base/test_modules.py @@ -1,42 +1,51 @@ +from torch import nn + +from inplace_abn import InPlaceABN from segmentation_models_pytorch.base.modules import Conv2dReLU def test_conv2drelu_batchnorm(): module = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm="batchnorm") - expected = ('Conv2dReLU(\n (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))' - '\n (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)' - '\n (2): ReLU(inplace=True)\n)') - assert repr(module) == expected + assert isinstance(module[0], nn.Conv2d) + assert isinstance(module[1], nn.BatchNorm2d) + assert isinstance(module[2], nn.ReLU) def test_conv2drelu_batchnorm_with_keywords(): module = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm={"type": "batchnorm", "momentum": 1e-4, "affine": False}) - expected = ('Conv2dReLU(\n (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))' - '\n (1): BatchNorm2d(16, eps=1e-05, momentum=0.0001, affine=False, track_running_stats=True)' - '\n (2): ReLU(inplace=True)\n)') - assert repr(module) == expected - + assert isinstance(module[0], nn.Conv2d) + assert isinstance(module[1], nn.BatchNorm2d) + assert module[1].momentum == 1e-4 and module[1].affine == False + assert isinstance(module[2], nn.ReLU) def test_conv2drelu_identity(): module = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm="identity") - expected = ('Conv2dReLU(\n (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))' - '\n (1): Identity()' - '\n (2): ReLU(inplace=True)\n)') - assert repr(module) == expected + + assert isinstance(module[0], nn.Conv2d) + assert isinstance(module[1], nn.Identity) + assert isinstance(module[2], nn.ReLU) def test_conv2drelu_layernorm(): module = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm="layernorm") - expected = ('Conv2dReLU(\n (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))' - '\n (1): LayerNorm((16,), eps=1e-05, elementwise_affine=True)' - '\n (2): ReLU(inplace=True)\n)') - assert repr(module) == expected + + assert isinstance(module[0], nn.Conv2d) + assert isinstance(module[1], nn.LayerNorm) + assert isinstance(module[2], nn.ReLU) def test_conv2drelu_instancenorm(): module = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm="instancenorm") - expected = ('Conv2dReLU(\n (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))' - '\n (1): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)' - '\n 
(2): ReLU(inplace=True)\n)') - assert repr(module) == expected + + assert isinstance(module[0], nn.Conv2d) + assert isinstance(module[1], nn.InstanceNorm2d) + assert isinstance(module[2], nn.ReLU) + + +def test_conv2drelu_inplace(): + module = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm="inplace") + + assert isinstance(module[0], nn.Conv2d) + assert isinstance(module[1], InPlaceABN) + assert isinstance(module[2], nn.ReLU) \ No newline at end of file diff --git a/tests/encoders/test_batchnorm_deprecation.py b/tests/encoders/test_batchnorm_deprecation.py index caf87bcb..271c865e 100644 --- a/tests/encoders/test_batchnorm_deprecation.py +++ b/tests/encoders/test_batchnorm_deprecation.py @@ -20,7 +20,7 @@ def test_seg_models_before_after_use_norm(model_name, decoder_option): torch.manual_seed(42) model_decoder_norm = create_model(model_name, "mobilenet_v2", None, decoder_use_batchnorm=None, decoder_use_norm=decoder_option) - check_two_models_strictly_equal(model_decoder_batchnorm, model_decoder_norm) + check_two_models_strictly_equal(model_decoder_batchnorm, model_decoder_norm, torch.rand(1, 3, 224, 224)) @@ -37,4 +37,4 @@ def test_pspnet_before_after_use_norm(decoder_option): torch.manual_seed(42) model_decoder_norm = create_model("pspnet", "mobilenet_v2", None, psp_use_batchnorm=None, decoder_use_norm=decoder_option) - check_two_models_strictly_equal(model_decoder_batchnorm, model_decoder_norm) + check_two_models_strictly_equal(model_decoder_batchnorm, model_decoder_norm, torch.rand(1, 3, 224, 224)) diff --git a/tests/utils.py b/tests/utils.py index fa9efc2c..0b827942 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -60,8 +60,11 @@ def check_run_test_on_diff_or_main(filepath_patterns: List[str]): return False -def check_two_models_strictly_equal(model_a: torch.nn.Module, model_b: torch.nn.Module) -> None: +def check_two_models_strictly_equal(model_a: torch.nn.Module, model_b: torch.nn.Module, input_data: torch.Tensor) -> None: for (k1, v1), (k2, v2) in zip(model_a.state_dict().items(), model_b.state_dict().items()): assert k1 == k2, f"Key mismatch: {k1} != {k2}" assert (v1 == v2).all(), f"Tensor mismatch at key '{k1}':\n{v1} !=\n{v2}" + + with torch.inference_mode(): + assert (model_a(input_data) == model_b(input_data)).all() From be229515e756950fd1cedc99b716db764e57a227 Mon Sep 17 00:00:00 2001 From: GuillaumeErhard <25333848+GuillaumeErhard@users.noreply.github.com> Date: Sun, 30 Mar 2025 18:01:05 +0200 Subject: [PATCH 06/21] Revert default value to batchnorm Fix typo in decoder_use_norm doc Add description to invalid type error --- segmentation_models_pytorch/base/modules.py | 11 ++++++++--- .../decoders/linknet/decoder.py | 4 ++-- segmentation_models_pytorch/decoders/linknet/model.py | 4 ++-- segmentation_models_pytorch/decoders/manet/decoder.py | 6 +++--- segmentation_models_pytorch/decoders/manet/model.py | 4 ++-- segmentation_models_pytorch/decoders/pspnet/model.py | 4 ++-- segmentation_models_pytorch/decoders/unet/decoder.py | 6 +++--- segmentation_models_pytorch/decoders/unet/model.py | 4 ++-- .../decoders/unetplusplus/decoder.py | 6 +++--- .../decoders/unetplusplus/model.py | 2 +- 10 files changed, 28 insertions(+), 23 deletions(-) diff --git a/segmentation_models_pytorch/base/modules.py b/segmentation_models_pytorch/base/modules.py index 9bd342e3..e723cae7 100644 --- a/segmentation_models_pytorch/base/modules.py +++ b/segmentation_models_pytorch/base/modules.py @@ -8,6 +8,7 @@ except ImportError: InPlaceABN = None + def get_norm_layer(use_norm: Union[bool, str, 
Dict[str, Any]], out_channels: int) -> nn.Module: supported_norms = ("inplace", "batchnorm", "identity", "layernorm", "instancenorm") if use_norm is True: @@ -32,12 +33,15 @@ def get_norm_layer(use_norm: Union[bool, str, Dict[str, Any]], out_channels: int ): norm_params = {"type": norm_str} else: - raise ValueError(f"Unrecognized normalization type string provided: {use_norm}. Should be in {supported_norms}") + raise ValueError(f"Unrecognized normalization type string provided: {use_norm}. Should be in " + f"{supported_norms}") elif isinstance(use_norm, dict): norm_params = use_norm else: - raise ValueError("use_norm must be a dictionary, boolean, or string. Please refer to the documentation.") - + raise ValueError( + f"Invalid type for use_norm should either be a bool (batchnorm/identity), " + f"a string in {supported_norms}, or a dict like {{'type': 'batchnorm', **kwargs}}" + ) if not "type" in norm_params: raise ValueError(f"Malformed dictionary given in use_norm: {use_norm}. Should contain key 'type'.") @@ -67,6 +71,7 @@ def get_norm_layer(use_norm: Union[bool, str, Dict[str, Any]], out_channels: int return norm + class Conv2dReLU(nn.Sequential): def __init__( self, diff --git a/segmentation_models_pytorch/decoders/linknet/decoder.py b/segmentation_models_pytorch/decoders/linknet/decoder.py index 4e671c37..d46e4ec1 100644 --- a/segmentation_models_pytorch/decoders/linknet/decoder.py +++ b/segmentation_models_pytorch/decoders/linknet/decoder.py @@ -35,7 +35,7 @@ def __init__( self, in_channels: int, out_channels: int, - use_norm: Union[bool, str, Dict[str, Any]] = True, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", ): super().__init__() @@ -72,7 +72,7 @@ def __init__( encoder_channels: List[int], prefinal_channels: int = 32, n_blocks: int = 5, - use_norm: Union[bool, str, Dict[str, Any]] = True, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", ): super().__init__() diff --git a/segmentation_models_pytorch/decoders/linknet/model.py b/segmentation_models_pytorch/decoders/linknet/model.py index 39ede764..ad559bd6 100644 --- a/segmentation_models_pytorch/decoders/linknet/model.py +++ b/segmentation_models_pytorch/decoders/linknet/model.py @@ -44,7 +44,7 @@ class Linknet(SegmentationModel): **Example**: ```python - use_norm={"type": "layernorm", "eps": 1e-2} + decoder_use_norm={"type": "layernorm", "eps": 1e-2} ``` in_channels: A number of input channels for the model, default is 3 (RGB images) classes: A number of classes for output mask (or you can think as a number of channels of output mask) @@ -74,7 +74,7 @@ def __init__( encoder_name: str = "resnet34", encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", - decoder_use_norm: Union[bool, str, Dict[str, Any]] = True, + decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", in_channels: int = 3, classes: int = 1, activation: Optional[Union[str, callable]] = None, diff --git a/segmentation_models_pytorch/decoders/manet/decoder.py b/segmentation_models_pytorch/decoders/manet/decoder.py index b18011bc..07bd5384 100644 --- a/segmentation_models_pytorch/decoders/manet/decoder.py +++ b/segmentation_models_pytorch/decoders/manet/decoder.py @@ -49,7 +49,7 @@ def __init__( in_channels: int, skip_channels: int, out_channels: int, - use_norm: Union[bool, str, Dict[str, Any]] = True, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", reduction: int = 16, ): # MFABBlock is just a modified version of SE-blocks, one for skip, one for input @@ -122,7 +122,7 @@ def __init__( in_channels: int, 
skip_channels: int, out_channels: int, - use_norm: Union[bool, str, Dict[str, Any]] = True, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", ): super().__init__() self.conv1 = md.Conv2dReLU( @@ -158,7 +158,7 @@ def __init__( decoder_channels: List[int], n_blocks: int = 5, reduction: int = 16, - use_norm: Union[bool, str, Dict[str, Any]] = True, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", pab_channels: int = 64, ): super().__init__() diff --git a/segmentation_models_pytorch/decoders/manet/model.py b/segmentation_models_pytorch/decoders/manet/model.py index 7b8e90e5..c3691123 100644 --- a/segmentation_models_pytorch/decoders/manet/model.py +++ b/segmentation_models_pytorch/decoders/manet/model.py @@ -44,7 +44,7 @@ class MAnet(SegmentationModel): **Example**: ```python - use_norm={"type": "layernorm", "eps": 1e-2} + decoder_use_norm={"type": "layernorm", "eps": 1e-2} ``` decoder_pab_channels: A number of channels for PAB module in decoder. Default is 64. @@ -77,7 +77,7 @@ def __init__( encoder_name: str = "resnet34", encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", - decoder_use_norm: Union[bool, str, Dict[str, Any]] = True, + decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", decoder_channels: List[int] = (256, 128, 64, 32, 16), decoder_pab_channels: int = 64, in_channels: int = 3, diff --git a/segmentation_models_pytorch/decoders/pspnet/model.py b/segmentation_models_pytorch/decoders/pspnet/model.py index 16ba8afa..dde6a0f1 100644 --- a/segmentation_models_pytorch/decoders/pspnet/model.py +++ b/segmentation_models_pytorch/decoders/pspnet/model.py @@ -43,7 +43,7 @@ class PSPNet(SegmentationModel): **Example**: ```python - use_norm={"type": "layernorm", "eps": 1e-2} + decoder_use_norm={"type": "layernorm", "eps": 1e-2} ``` psp_dropout: Spatial dropout rate in [0, 1) used in Spatial Pyramid in_channels: A number of input channels for the model, default is 3 (RGB images) @@ -76,7 +76,7 @@ def __init__( encoder_weights: Optional[str] = "imagenet", encoder_depth: int = 3, psp_out_channels: int = 512, - decoder_use_norm: Union[bool, str, Dict[str, Any], None] = True, + decoder_use_norm: Union[bool, str, Dict[str, Any], None] = "batchnorm", psp_dropout: float = 0.2, in_channels: int = 3, classes: int = 1, diff --git a/segmentation_models_pytorch/decoders/unet/decoder.py b/segmentation_models_pytorch/decoders/unet/decoder.py index adf98d66..cfeb267e 100644 --- a/segmentation_models_pytorch/decoders/unet/decoder.py +++ b/segmentation_models_pytorch/decoders/unet/decoder.py @@ -15,7 +15,7 @@ def __init__( in_channels: int, skip_channels: int, out_channels: int, - use_norm: Union[bool, str, Dict[str, Any]] = True, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", attention_type: Optional[str] = None, interpolation_mode: str = "nearest", ): @@ -68,7 +68,7 @@ def __init__( self, in_channels: int, out_channels: int, - use_norm: Union[bool, str, Dict[str, Any]] = True, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", ): conv1 = md.Conv2dReLU( in_channels, @@ -99,7 +99,7 @@ def __init__( encoder_channels: Sequence[int], decoder_channels: Sequence[int], n_blocks: int = 5, - use_norm: Union[bool, str, Dict[str, Any]] = True, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", attention_type: Optional[str] = None, add_center_block: bool = False, interpolation_mode: str = "nearest", diff --git a/segmentation_models_pytorch/decoders/unet/model.py b/segmentation_models_pytorch/decoders/unet/model.py index dd295b34..0bc99d1f 
100644 --- a/segmentation_models_pytorch/decoders/unet/model.py +++ b/segmentation_models_pytorch/decoders/unet/model.py @@ -54,7 +54,7 @@ class Unet(SegmentationModel): **Example**: ```python - use_norm={"type": "layernorm", "eps": 1e-2} + decoder_use_norm={"type": "layernorm", "eps": 1e-2} ``` decoder_attention_type: Attention module used in decoder of the model. Available options are **None** and **scse** (https://arxiv.org/abs/1808.08127). @@ -109,7 +109,7 @@ def __init__( encoder_name: str = "resnet34", encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", - decoder_use_norm: Union[bool, str, Dict[str, Any]] = True, + decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", decoder_channels: Sequence[int] = (256, 128, 64, 32, 16), decoder_attention_type: Optional[str] = None, decoder_interpolation_mode: str = "nearest", diff --git a/segmentation_models_pytorch/decoders/unetplusplus/decoder.py b/segmentation_models_pytorch/decoders/unetplusplus/decoder.py index b1970104..42d7b338 100644 --- a/segmentation_models_pytorch/decoders/unetplusplus/decoder.py +++ b/segmentation_models_pytorch/decoders/unetplusplus/decoder.py @@ -13,7 +13,7 @@ def __init__( in_channels: int, skip_channels: int, out_channels: int, - use_norm: Union[bool, str, Dict[str, Any]] = True, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", attention_type: Optional[str] = None, ): super().__init__() @@ -54,7 +54,7 @@ def __init__( self, in_channels: int, out_channels: int, - use_norm: Union[bool, str, Dict[str, Any]] = True, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", ): conv1 = md.Conv2dReLU( in_channels, @@ -79,7 +79,7 @@ def __init__( encoder_channels: List[int], decoder_channels: List[int], n_blocks: int = 5, - use_norm: Union[bool, str, Dict[str, Any]] = True, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", attention_type: Optional[str] = None, center: bool = False, ): diff --git a/segmentation_models_pytorch/decoders/unetplusplus/model.py b/segmentation_models_pytorch/decoders/unetplusplus/model.py index 66983255..29872661 100644 --- a/segmentation_models_pytorch/decoders/unetplusplus/model.py +++ b/segmentation_models_pytorch/decoders/unetplusplus/model.py @@ -43,7 +43,7 @@ class UnetPlusPlus(SegmentationModel): **Example**: ```python - use_norm={"type": "layernorm", "eps": 1e-2} + decoder_use_norm={"type": "layernorm", "eps": 1e-2} ``` decoder_attention_type: Attention module used in decoder of the model. Available options are **None** and **scse** (https://arxiv.org/abs/1808.08127). From 7c883613c9cc8a3af0c1599cd5804291818282d4 Mon Sep 17 00:00:00 2001 From: qubvel Date: Sat, 5 Apr 2025 11:12:53 +0000 Subject: [PATCH 07/21] Minor style fixes --- segmentation_models_pytorch/base/modules.py | 86 ++++++++++++--------- 1 file changed, 49 insertions(+), 37 deletions(-) diff --git a/segmentation_models_pytorch/base/modules.py b/segmentation_models_pytorch/base/modules.py index e723cae7..5febe53c 100644 --- a/segmentation_models_pytorch/base/modules.py +++ b/segmentation_models_pytorch/base/modules.py @@ -9,14 +9,20 @@ InPlaceABN = None -def get_norm_layer(use_norm: Union[bool, str, Dict[str, Any]], out_channels: int) -> nn.Module: +def get_norm_layer( + use_norm: Union[bool, str, Dict[str, Any]], out_channels: int +) -> nn.Module: supported_norms = ("inplace", "batchnorm", "identity", "layernorm", "instancenorm") + + # Step 1. 
Convert to dict representation
+
+    ## Check boolean
     if use_norm is True:
         norm_params = {"type": "batchnorm"}
     elif use_norm is False:
         norm_params = {"type": "identity"}
-    elif use_norm == "inplace":
-        norm_params = {"type": "inplace", "activation": "leaky_relu", "activation_param": 0.0}
+
+    ## Check string
     elif isinstance(use_norm, str):
         norm_str = use_norm.lower()
         if norm_str == "inplace":
@@ -25,47 +31,53 @@ def get_norm_layer(use_norm: Union[bool, str, Dict[str, Any]], out_channels: int
                 "activation": "leaky_relu",
                 "activation_param": 0.0,
             }
-        elif norm_str in (
-            "batchnorm",
-            "identity",
-            "layernorm",
-            "instancenorm",
-        ):
+        elif norm_str in supported_norms:
             norm_params = {"type": norm_str}
         else:
-            raise ValueError(f"Unrecognized normalization type string provided: {use_norm}. Should be in "
-                             f"{supported_norms}")
+            raise ValueError(
+                f"Unrecognized normalization type string provided: {use_norm}. Should be in "
+                f"{supported_norms}"
+            )
+
+    ## Check dict
     elif isinstance(use_norm, dict):
         norm_params = use_norm
+
     else:
         raise ValueError(
             f"Invalid type for use_norm should either be a bool (batchnorm/identity), "
             f"a string in {supported_norms}, or a dict like {{'type': 'batchnorm', **kwargs}}"
         )
-    if not "type" in norm_params:
-        raise ValueError(f"Malformed dictionary given in use_norm: {use_norm}. Should contain key 'type'.")
+    # Step 2. Check if the dict is valid
+    if "type" not in norm_params:
+        raise ValueError(
+            f"Malformed dictionary given in use_norm: {use_norm}. Should contain key 'type'."
+        )
     if norm_params["type"] not in supported_norms:
-        raise ValueError(f"Unrecognized normalization type string provided: {use_norm}. Should be in {supported_norms}")
-
-    norm_type = norm_params["type"]
-    extra_kwargs = {k: v for k, v in norm_params.items() if k != "type"}
-
-    if norm_type == "inplace" and InPlaceABN is None:
+        raise ValueError(
+            f"Unrecognized normalization type string provided: {use_norm}. Should be in {supported_norms}"
+        )
+    if norm_params["type"] == "inplace" and InPlaceABN is None:
         raise RuntimeError(
-            "In order to use `use_batchnorm='inplace'` or `use_norm='inplace'` the inplace_abn package must be installed. "
+            "In order to use `use_norm='inplace'` the inplace_abn package must be installed. "
             "To install see: https://github.com/mapillary/inplace_abn"
         )
-    elif norm_type == "inplace":
-        norm = InPlaceABN(out_channels, **extra_kwargs)
+
+    # Step 3. 
Initialize the norm layer + norm_type = norm_params["type"] + norm_kwargs = {k: v for k, v in norm_params.items() if k != "type"} + + if norm_type == "inplace": + norm = InPlaceABN(out_channels, **norm_kwargs) elif norm_type == "batchnorm": - norm = nn.BatchNorm2d(out_channels, **extra_kwargs) + norm = nn.BatchNorm2d(out_channels, **norm_kwargs) elif norm_type == "identity": norm = nn.Identity() elif norm_type == "layernorm": - norm = nn.LayerNorm(out_channels, **extra_kwargs) + norm = nn.LayerNorm(out_channels, **norm_kwargs) elif norm_type == "instancenorm": - norm = nn.InstanceNorm2d(out_channels, **extra_kwargs) + norm = nn.InstanceNorm2d(out_channels, **norm_kwargs) else: raise ValueError(f"Unrecognized normalization type: {norm_type}") @@ -75,29 +87,29 @@ def get_norm_layer(use_norm: Union[bool, str, Dict[str, Any]], out_channels: int class Conv2dReLU(nn.Sequential): def __init__( self, - in_channels, - out_channels, - kernel_size, - padding=0, - stride=1, - use_norm="batchnorm", + in_channels: int, + out_channels: int, + kernel_size: int, + padding: int = 0, + stride: int = 1, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", ): norm = get_norm_layer(use_norm, out_channels) + + is_batchnorm = isinstance(norm, nn.BatchNorm2d) conv = nn.Conv2d( in_channels, out_channels, kernel_size, stride=stride, padding=padding, - bias=norm._get_name() != "BatchNorm2d", + bias=is_batchnorm, ) - if norm._get_name() == "Inplace": - relu = nn.Identity() - else: - relu = nn.ReLU(inplace=True) + is_inplaceabn = InPlaceABN is not None and isinstance(norm, InPlaceABN) + activation = nn.Identity() if is_inplaceabn else nn.ReLU(inplace=True) - super(Conv2dReLU, self).__init__(conv, norm, relu) + super(Conv2dReLU, self).__init__(conv, norm, activation) class SCSEModule(nn.Module): From 1255ee08b40bde25853728f80444cdeaca4c1860 Mon Sep 17 00:00:00 2001 From: qubvel Date: Sat, 5 Apr 2025 11:15:40 +0000 Subject: [PATCH 08/21] Fixup --- .../decoders/linknet/decoder.py | 4 +- .../decoders/linknet/model.py | 2 +- .../decoders/manet/decoder.py | 4 +- .../decoders/manet/model.py | 2 +- .../decoders/pspnet/decoder.py | 4 +- .../decoders/pspnet/model.py | 2 +- .../decoders/unet/model.py | 2 +- .../decoders/unetplusplus/model.py | 2 +- .../decoders/upernet/decoder.py | 7 +++- tests/base/test_modules.py | 14 +++++-- tests/encoders/test_batchnorm_deprecation.py | 37 ++++++++++++------- tests/utils.py | 9 +++-- 12 files changed, 54 insertions(+), 35 deletions(-) diff --git a/segmentation_models_pytorch/decoders/linknet/decoder.py b/segmentation_models_pytorch/decoders/linknet/decoder.py index d46e4ec1..dc15ac65 100644 --- a/segmentation_models_pytorch/decoders/linknet/decoder.py +++ b/segmentation_models_pytorch/decoders/linknet/decoder.py @@ -46,9 +46,7 @@ def __init__( kernel_size=1, use_norm=use_norm, ), - TransposeX2( - in_channels // 4, in_channels // 4, use_norm=use_norm - ), + TransposeX2(in_channels // 4, in_channels // 4, use_norm=use_norm), modules.Conv2dReLU( in_channels // 4, out_channels, diff --git a/segmentation_models_pytorch/decoders/linknet/model.py b/segmentation_models_pytorch/decoders/linknet/model.py index ad559bd6..894444f7 100644 --- a/segmentation_models_pytorch/decoders/linknet/model.py +++ b/segmentation_models_pytorch/decoders/linknet/model.py @@ -101,7 +101,7 @@ def __init__( warnings.warn( "The usage of decoder_use_batchnorm is deprecated. 
Please modify your code for use_norm", DeprecationWarning, - stacklevel=2 + stacklevel=2, ) decoder_use_norm = decoder_use_batchnorm diff --git a/segmentation_models_pytorch/decoders/manet/decoder.py b/segmentation_models_pytorch/decoders/manet/decoder.py index 07bd5384..ae2498c7 100644 --- a/segmentation_models_pytorch/decoders/manet/decoder.py +++ b/segmentation_models_pytorch/decoders/manet/decoder.py @@ -185,9 +185,7 @@ def __init__( self.center = PABBlock(head_channels, pab_channels=pab_channels) # combine decoder keyword arguments - kwargs = dict( - use_norm=use_norm - ) # no attention type here + kwargs = dict(use_norm=use_norm) # no attention type here blocks = [ MFABBlock(in_ch, skip_ch, out_ch, reduction=reduction, **kwargs) if skip_ch > 0 diff --git a/segmentation_models_pytorch/decoders/manet/model.py b/segmentation_models_pytorch/decoders/manet/model.py index c3691123..2beaaf25 100644 --- a/segmentation_models_pytorch/decoders/manet/model.py +++ b/segmentation_models_pytorch/decoders/manet/model.py @@ -101,7 +101,7 @@ def __init__( warnings.warn( "The usage of decoder_use_batchnorm is deprecated. Please modify your code for use_norm", DeprecationWarning, - stacklevel=2 + stacklevel=2, ) decoder_use_norm = decoder_use_batchnorm diff --git a/segmentation_models_pytorch/decoders/pspnet/decoder.py b/segmentation_models_pytorch/decoders/pspnet/decoder.py index 547c563d..ae0fda43 100644 --- a/segmentation_models_pytorch/decoders/pspnet/decoder.py +++ b/segmentation_models_pytorch/decoders/pspnet/decoder.py @@ -22,9 +22,7 @@ def __init__( self.pool = nn.Sequential( nn.AdaptiveAvgPool2d(output_size=(pool_size, pool_size)), - modules.Conv2dReLU( - in_channels, out_channels, (1, 1), use_norm=use_norm - ), + modules.Conv2dReLU(in_channels, out_channels, (1, 1), use_norm=use_norm), ) def forward(self, x: torch.Tensor) -> torch.Tensor: diff --git a/segmentation_models_pytorch/decoders/pspnet/model.py b/segmentation_models_pytorch/decoders/pspnet/model.py index dde6a0f1..44cbbc6e 100644 --- a/segmentation_models_pytorch/decoders/pspnet/model.py +++ b/segmentation_models_pytorch/decoders/pspnet/model.py @@ -100,7 +100,7 @@ def __init__( warnings.warn( "The usage of psp_use_batchnorm is deprecated. Please modify your code for use_norm", DeprecationWarning, - stacklevel=2 + stacklevel=2, ) decoder_use_norm = psp_use_batchnorm diff --git a/segmentation_models_pytorch/decoders/unet/model.py b/segmentation_models_pytorch/decoders/unet/model.py index 0bc99d1f..aa0b5c5e 100644 --- a/segmentation_models_pytorch/decoders/unet/model.py +++ b/segmentation_models_pytorch/decoders/unet/model.py @@ -136,7 +136,7 @@ def __init__( warnings.warn( "The usage of decoder_use_batchnorm is deprecated. Please modify your code for use_norm", DeprecationWarning, - stacklevel=2 + stacklevel=2, ) decoder_use_norm = decoder_use_batchnorm diff --git a/segmentation_models_pytorch/decoders/unetplusplus/model.py b/segmentation_models_pytorch/decoders/unetplusplus/model.py index 29872661..ce5106b5 100644 --- a/segmentation_models_pytorch/decoders/unetplusplus/model.py +++ b/segmentation_models_pytorch/decoders/unetplusplus/model.py @@ -107,7 +107,7 @@ def __init__( warnings.warn( "The usage of decoder_use_batchnorm is deprecated. 
Please modify your code for use_norm", DeprecationWarning, - stacklevel=2 + stacklevel=2, ) decoder_use_norm = decoder_use_batchnorm diff --git a/segmentation_models_pytorch/decoders/upernet/decoder.py b/segmentation_models_pytorch/decoders/upernet/decoder.py index 97a794b4..fa6f9b05 100644 --- a/segmentation_models_pytorch/decoders/upernet/decoder.py +++ b/segmentation_models_pytorch/decoders/upernet/decoder.py @@ -50,7 +50,12 @@ def forward(self, x): class FPNBlock(nn.Module): - def __init__(self, skip_channels: int, pyramid_channels: int, use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm"): + def __init__( + self, + skip_channels: int, + pyramid_channels: int, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", + ): super().__init__() self.skip_conv = ( md.Conv2dReLU( diff --git a/tests/base/test_modules.py b/tests/base/test_modules.py index 11ebd8c9..13522484 100644 --- a/tests/base/test_modules.py +++ b/tests/base/test_modules.py @@ -11,14 +11,22 @@ def test_conv2drelu_batchnorm(): assert isinstance(module[1], nn.BatchNorm2d) assert isinstance(module[2], nn.ReLU) + def test_conv2drelu_batchnorm_with_keywords(): - module = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm={"type": "batchnorm", "momentum": 1e-4, "affine": False}) + module = Conv2dReLU( + 3, + 16, + kernel_size=3, + padding=1, + use_norm={"type": "batchnorm", "momentum": 1e-4, "affine": False}, + ) assert isinstance(module[0], nn.Conv2d) assert isinstance(module[1], nn.BatchNorm2d) - assert module[1].momentum == 1e-4 and module[1].affine == False + assert module[1].momentum == 1e-4 and module[1].affine is False assert isinstance(module[2], nn.ReLU) + def test_conv2drelu_identity(): module = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm="identity") @@ -48,4 +56,4 @@ def test_conv2drelu_inplace(): assert isinstance(module[0], nn.Conv2d) assert isinstance(module[1], InPlaceABN) - assert isinstance(module[2], nn.ReLU) \ No newline at end of file + assert isinstance(module[2], nn.ReLU) diff --git a/tests/encoders/test_batchnorm_deprecation.py b/tests/encoders/test_batchnorm_deprecation.py index 271c865e..f7186988 100644 --- a/tests/encoders/test_batchnorm_deprecation.py +++ b/tests/encoders/test_batchnorm_deprecation.py @@ -12,16 +12,20 @@ def test_seg_models_before_after_use_norm(model_name, decoder_option): torch.manual_seed(42) with pytest.warns(DeprecationWarning): model_decoder_batchnorm = create_model( - model_name, - "mobilenet_v2", - None, - decoder_use_batchnorm=decoder_option + model_name, "mobilenet_v2", None, decoder_use_batchnorm=decoder_option ) torch.manual_seed(42) - model_decoder_norm = create_model(model_name, "mobilenet_v2", None, decoder_use_batchnorm=None, decoder_use_norm=decoder_option) - - check_two_models_strictly_equal(model_decoder_batchnorm, model_decoder_norm, torch.rand(1, 3, 224, 224)) + model_decoder_norm = create_model( + model_name, + "mobilenet_v2", + None, + decoder_use_batchnorm=None, + decoder_use_norm=decoder_option, + ) + check_two_models_strictly_equal( + model_decoder_batchnorm, model_decoder_norm, torch.rand(1, 3, 224, 224) + ) @pytest.mark.parametrize("decoder_option", [True, False, "inplace"]) @@ -29,12 +33,17 @@ def test_pspnet_before_after_use_norm(decoder_option): torch.manual_seed(42) with pytest.warns(DeprecationWarning): model_decoder_batchnorm = create_model( - "pspnet", - "mobilenet_v2", - None, - psp_use_batchnorm=decoder_option + "pspnet", "mobilenet_v2", None, psp_use_batchnorm=decoder_option ) torch.manual_seed(42) - model_decoder_norm = 
create_model("pspnet", "mobilenet_v2", None, psp_use_batchnorm=None, decoder_use_norm=decoder_option) - - check_two_models_strictly_equal(model_decoder_batchnorm, model_decoder_norm, torch.rand(1, 3, 224, 224)) + model_decoder_norm = create_model( + "pspnet", + "mobilenet_v2", + None, + psp_use_batchnorm=None, + decoder_use_norm=decoder_option, + ) + + check_two_models_strictly_equal( + model_decoder_batchnorm, model_decoder_norm, torch.rand(1, 3, 224, 224) + ) diff --git a/tests/utils.py b/tests/utils.py index 0b827942..bd6376a3 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -60,9 +60,12 @@ def check_run_test_on_diff_or_main(filepath_patterns: List[str]): return False -def check_two_models_strictly_equal(model_a: torch.nn.Module, model_b: torch.nn.Module, input_data: torch.Tensor) -> None: - for (k1, v1), (k2, v2) in zip(model_a.state_dict().items(), - model_b.state_dict().items()): +def check_two_models_strictly_equal( + model_a: torch.nn.Module, model_b: torch.nn.Module, input_data: torch.Tensor +) -> None: + for (k1, v1), (k2, v2) in zip( + model_a.state_dict().items(), model_b.state_dict().items() + ): assert k1 == k2, f"Key mismatch: {k1} != {k2}" assert (v1 == v2).all(), f"Tensor mismatch at key '{k1}':\n{v1} !=\n{v2}" From b0d41136a7d67ba9f9526d130d2fca0621e955b8 Mon Sep 17 00:00:00 2001 From: qubvel Date: Sat, 5 Apr 2025 11:18:48 +0000 Subject: [PATCH 09/21] Fix bias term --- segmentation_models_pytorch/base/modules.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/segmentation_models_pytorch/base/modules.py b/segmentation_models_pytorch/base/modules.py index 5febe53c..e6440033 100644 --- a/segmentation_models_pytorch/base/modules.py +++ b/segmentation_models_pytorch/base/modules.py @@ -96,14 +96,14 @@ def __init__( ): norm = get_norm_layer(use_norm, out_channels) - is_batchnorm = isinstance(norm, nn.BatchNorm2d) + is_identity = isinstance(norm, nn.Identity) conv = nn.Conv2d( in_channels, out_channels, kernel_size, stride=stride, padding=padding, - bias=is_batchnorm, + bias=is_identity, ) is_inplaceabn = InPlaceABN is not None and isinstance(norm, InPlaceABN) From 799f8f4590f35512f192643c5bdb13c1f8717b58 Mon Sep 17 00:00:00 2001 From: qubvel Date: Sat, 5 Apr 2025 11:48:20 +0000 Subject: [PATCH 10/21] Refine error message --- segmentation_models_pytorch/base/modules.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/segmentation_models_pytorch/base/modules.py b/segmentation_models_pytorch/base/modules.py index e6440033..15cfdb12 100644 --- a/segmentation_models_pytorch/base/modules.py +++ b/segmentation_models_pytorch/base/modules.py @@ -60,8 +60,10 @@ def get_norm_layer( ) if norm_params["type"] == "inplace" and InPlaceABN is None: raise RuntimeError( - "In order to use `use_norm='inplace'` the inplace_abn package must be installed. " - "To install see: https://github.com/mapillary/inplace_abn" + "In order to use `use_norm='inplace'` the inplace_abn package must be installed. Use:\n" + " $ pip install -U wheel setuptools\n" + " $ pip install inplace_abn --no-build-isolation\n" + "Also see: https://github.com/mapillary/inplace_abn" ) # Step 3. 
Initialize the norm layer From cb10389ac4726b648b4202a74522f4d2faa182b5 Mon Sep 17 00:00:00 2001 From: qubvel Date: Sat, 5 Apr 2025 11:49:12 +0000 Subject: [PATCH 11/21] Move deprecation on top + some typehint fixes --- .../decoders/linknet/model.py | 20 ++++++++--------- .../decoders/manet/model.py | 22 +++++++++---------- .../decoders/pspnet/model.py | 20 ++++++++--------- .../decoders/unetplusplus/model.py | 22 +++++++++---------- .../decoders/upernet/model.py | 4 ++-- 5 files changed, 44 insertions(+), 44 deletions(-) diff --git a/segmentation_models_pytorch/decoders/linknet/model.py b/segmentation_models_pytorch/decoders/linknet/model.py index 894444f7..7b1819e5 100644 --- a/segmentation_models_pytorch/decoders/linknet/model.py +++ b/segmentation_models_pytorch/decoders/linknet/model.py @@ -1,5 +1,5 @@ import warnings -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Union, Callable from segmentation_models_pytorch.base import ( ClassificationHead, @@ -77,7 +77,7 @@ def __init__( decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", in_channels: int = 3, classes: int = 1, - activation: Optional[Union[str, callable]] = None, + activation: Optional[Union[str, Callable]] = None, aux_params: Optional[dict] = None, **kwargs: dict[str, Any], ): @@ -88,14 +88,6 @@ def __init__( "Encoder `{}` is not supported for Linknet".format(encoder_name) ) - self.encoder = get_encoder( - encoder_name, - in_channels=in_channels, - depth=encoder_depth, - weights=encoder_weights, - **kwargs, - ) - decoder_use_batchnorm = kwargs.pop("decoder_use_batchnorm", None) if decoder_use_batchnorm is not None: warnings.warn( @@ -105,6 +97,14 @@ def __init__( ) decoder_use_norm = decoder_use_batchnorm + self.encoder = get_encoder( + encoder_name, + in_channels=in_channels, + depth=encoder_depth, + weights=encoder_weights, + **kwargs, + ) + self.decoder = LinknetDecoder( encoder_channels=self.encoder.out_channels, n_blocks=encoder_depth, diff --git a/segmentation_models_pytorch/decoders/manet/model.py b/segmentation_models_pytorch/decoders/manet/model.py index 2beaaf25..d020e7be 100644 --- a/segmentation_models_pytorch/decoders/manet/model.py +++ b/segmentation_models_pytorch/decoders/manet/model.py @@ -1,5 +1,5 @@ import warnings -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, Optional, Union, Sequence, Callable from segmentation_models_pytorch.base import ( ClassificationHead, @@ -78,24 +78,16 @@ def __init__( encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", - decoder_channels: List[int] = (256, 128, 64, 32, 16), + decoder_channels: Sequence[int] = (256, 128, 64, 32, 16), decoder_pab_channels: int = 64, in_channels: int = 3, classes: int = 1, - activation: Optional[Union[str, callable]] = None, + activation: Optional[Union[str, Callable]] = None, aux_params: Optional[dict] = None, **kwargs: dict[str, Any], ): super().__init__() - self.encoder = get_encoder( - encoder_name, - in_channels=in_channels, - depth=encoder_depth, - weights=encoder_weights, - **kwargs, - ) - decoder_use_batchnorm = kwargs.pop("decoder_use_batchnorm", None) if decoder_use_batchnorm is not None: warnings.warn( @@ -105,6 +97,14 @@ def __init__( ) decoder_use_norm = decoder_use_batchnorm + self.encoder = get_encoder( + encoder_name, + in_channels=in_channels, + depth=encoder_depth, + weights=encoder_weights, + **kwargs, + ) + self.decoder = MAnetDecoder( 
encoder_channels=self.encoder.out_channels, decoder_channels=decoder_channels, diff --git a/segmentation_models_pytorch/decoders/pspnet/model.py b/segmentation_models_pytorch/decoders/pspnet/model.py index 44cbbc6e..5a354edf 100644 --- a/segmentation_models_pytorch/decoders/pspnet/model.py +++ b/segmentation_models_pytorch/decoders/pspnet/model.py @@ -1,5 +1,5 @@ import warnings -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Union, Callable from segmentation_models_pytorch.base import ( ClassificationHead, @@ -80,21 +80,13 @@ def __init__( psp_dropout: float = 0.2, in_channels: int = 3, classes: int = 1, - activation: Optional[Union[str, callable]] = None, + activation: Optional[Union[str, Callable]] = None, upsampling: int = 8, aux_params: Optional[dict] = None, **kwargs: dict[str, Any], ): super().__init__() - self.encoder = get_encoder( - encoder_name, - in_channels=in_channels, - depth=encoder_depth, - weights=encoder_weights, - **kwargs, - ) - psp_use_batchnorm = kwargs.pop("psp_use_batchnorm", None) if psp_use_batchnorm is not None: warnings.warn( @@ -104,6 +96,14 @@ def __init__( ) decoder_use_norm = psp_use_batchnorm + self.encoder = get_encoder( + encoder_name, + in_channels=in_channels, + depth=encoder_depth, + weights=encoder_weights, + **kwargs, + ) + self.decoder = PSPDecoder( encoder_channels=self.encoder.out_channels, use_norm=decoder_use_norm, diff --git a/segmentation_models_pytorch/decoders/unetplusplus/model.py b/segmentation_models_pytorch/decoders/unetplusplus/model.py index ce5106b5..257629e7 100644 --- a/segmentation_models_pytorch/decoders/unetplusplus/model.py +++ b/segmentation_models_pytorch/decoders/unetplusplus/model.py @@ -1,5 +1,5 @@ import warnings -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, Sequence, Optional, Union, Callable from segmentation_models_pytorch.base import ( ClassificationHead, @@ -79,11 +79,11 @@ def __init__( encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", - decoder_channels: List[int] = (256, 128, 64, 32, 16), + decoder_channels: Sequence[int] = (256, 128, 64, 32, 16), decoder_attention_type: Optional[str] = None, in_channels: int = 3, classes: int = 1, - activation: Optional[Union[str, callable]] = None, + activation: Optional[Union[str, Callable]] = None, aux_params: Optional[dict] = None, **kwargs: dict[str, Any], ): @@ -94,14 +94,6 @@ def __init__( "UnetPlusPlus is not support encoder_name={}".format(encoder_name) ) - self.encoder = get_encoder( - encoder_name, - in_channels=in_channels, - depth=encoder_depth, - weights=encoder_weights, - **kwargs, - ) - decoder_use_batchnorm = kwargs.pop("decoder_use_batchnorm", None) if decoder_use_batchnorm is not None: warnings.warn( @@ -111,6 +103,14 @@ def __init__( ) decoder_use_norm = decoder_use_batchnorm + self.encoder = get_encoder( + encoder_name, + in_channels=in_channels, + depth=encoder_depth, + weights=encoder_weights, + **kwargs, + ) + self.decoder = UnetPlusPlusDecoder( encoder_channels=self.encoder.out_channels, decoder_channels=decoder_channels, diff --git a/segmentation_models_pytorch/decoders/upernet/model.py b/segmentation_models_pytorch/decoders/upernet/model.py index caae60c2..6ad5afd5 100644 --- a/segmentation_models_pytorch/decoders/upernet/model.py +++ b/segmentation_models_pytorch/decoders/upernet/model.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, 
Union, Callable from segmentation_models_pytorch.base import ( ClassificationHead, @@ -77,7 +77,7 @@ def __init__( decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", in_channels: int = 3, classes: int = 1, - activation: Optional[Union[str, callable]] = None, + activation: Optional[Union[str, Callable]] = None, aux_params: Optional[dict] = None, **kwargs: dict[str, Any], ): From 4b6792f90d9297addbb9a307a8707031f239089b Mon Sep 17 00:00:00 2001 From: qubvel Date: Sat, 5 Apr 2025 11:49:38 +0000 Subject: [PATCH 12/21] Redesign ConvBnRelu block --- .../decoders/linknet/decoder.py | 21 ++++++------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/segmentation_models_pytorch/decoders/linknet/decoder.py b/segmentation_models_pytorch/decoders/linknet/decoder.py index dc15ac65..95c7f9f6 100644 --- a/segmentation_models_pytorch/decoders/linknet/decoder.py +++ b/segmentation_models_pytorch/decoders/linknet/decoder.py @@ -13,21 +13,12 @@ def __init__( use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", ): super().__init__() - layers = [ - nn.ConvTranspose2d( - in_channels, out_channels, kernel_size=4, stride=2, padding=1 - ), - nn.ReLU(inplace=True), - ] - - if use_norm != "identity": - if isinstance(use_norm, dict): - if use_norm.get("type") != "identity": - layers.insert(1, modules.get_norm_layer(use_norm, out_channels)) - else: - layers.insert(1, modules.get_norm_layer(use_norm, out_channels)) - - super().__init__(*layers) + conv = nn.ConvTranspose2d( + in_channels, out_channels, kernel_size=4, stride=2, padding=1 + ) + norm = modules.get_norm_layer(use_norm, out_channels) + activation = nn.ReLU(inplace=True) + super().__init__(conv, norm, activation) class DecoderBlock(nn.Module): From 22ea56992419e3ee225a9db55b4eb563e7a97936 Mon Sep 17 00:00:00 2001 From: qubvel Date: Sat, 5 Apr 2025 11:49:57 +0000 Subject: [PATCH 13/21] Fix kernel_size type --- segmentation_models_pytorch/decoders/pspnet/decoder.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/segmentation_models_pytorch/decoders/pspnet/decoder.py b/segmentation_models_pytorch/decoders/pspnet/decoder.py index ae0fda43..80ad289c 100644 --- a/segmentation_models_pytorch/decoders/pspnet/decoder.py +++ b/segmentation_models_pytorch/decoders/pspnet/decoder.py @@ -22,7 +22,9 @@ def __init__( self.pool = nn.Sequential( nn.AdaptiveAvgPool2d(output_size=(pool_size, pool_size)), - modules.Conv2dReLU(in_channels, out_channels, (1, 1), use_norm=use_norm), + modules.Conv2dReLU( + in_channels, out_channels, kernel_size=1, use_norm=use_norm + ), ) def forward(self, x: torch.Tensor) -> torch.Tensor: From ce59ffaeceffdcc5580f3d784f001f07dd226b05 Mon Sep 17 00:00:00 2001 From: qubvel Date: Sat, 5 Apr 2025 11:50:10 +0000 Subject: [PATCH 14/21] Better type hints --- .../decoders/unetplusplus/decoder.py | 6 +++--- segmentation_models_pytorch/decoders/upernet/decoder.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/segmentation_models_pytorch/decoders/unetplusplus/decoder.py b/segmentation_models_pytorch/decoders/unetplusplus/decoder.py index 42d7b338..e09327ac 100644 --- a/segmentation_models_pytorch/decoders/unetplusplus/decoder.py +++ b/segmentation_models_pytorch/decoders/unetplusplus/decoder.py @@ -2,7 +2,7 @@ import torch.nn as nn import torch.nn.functional as F -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union, Sequence from segmentation_models_pytorch.base import modules as md @@ -76,8 +76,8 @@ def __init__( class 
UnetPlusPlusDecoder(nn.Module): def __init__( self, - encoder_channels: List[int], - decoder_channels: List[int], + encoder_channels: Sequence[int], + decoder_channels: Sequence[int], n_blocks: int = 5, use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", attention_type: Optional[str] = None, diff --git a/segmentation_models_pytorch/decoders/upernet/decoder.py b/segmentation_models_pytorch/decoders/upernet/decoder.py index fa6f9b05..810778f3 100644 --- a/segmentation_models_pytorch/decoders/upernet/decoder.py +++ b/segmentation_models_pytorch/decoders/upernet/decoder.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Tuple, Union +from typing import Any, Dict, Union, Sequence import torch import torch.nn as nn @@ -12,7 +12,7 @@ def __init__( self, in_channels: int, out_channels: int, - sizes: Tuple[int, ...] = (1, 2, 3, 6), + sizes: Sequence[int] = (1, 2, 3, 6), use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", ): super().__init__() @@ -80,7 +80,7 @@ def forward(self, x, skip): class UPerNetDecoder(nn.Module): def __init__( self, - encoder_channels: Tuple[int, ...], + encoder_channels: Sequence[int], encoder_depth: int = 5, pyramid_channels: int = 256, segmentation_channels: int = 64, From 05d6d7aec553e496eb5030f8999a53f8b113d599 Mon Sep 17 00:00:00 2001 From: qubvel Date: Sat, 5 Apr 2025 11:50:24 +0000 Subject: [PATCH 15/21] Fix InplaceABN test --- tests/base/test_modules.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/base/test_modules.py b/tests/base/test_modules.py index 13522484..5afa8e4f 100644 --- a/tests/base/test_modules.py +++ b/tests/base/test_modules.py @@ -1,6 +1,5 @@ +import pytest from torch import nn - -from inplace_abn import InPlaceABN from segmentation_models_pytorch.base.modules import Conv2dReLU @@ -52,8 +51,14 @@ def test_conv2drelu_instancenorm(): def test_conv2drelu_inplace(): + try: + from inplace_abn import InPlaceABN + except ImportError: + pytest.skip("InPlaceABN is not installed") + module = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm="inplace") + assert len(module) == 3 assert isinstance(module[0], nn.Conv2d) assert isinstance(module[1], InPlaceABN) - assert isinstance(module[2], nn.ReLU) + assert isinstance(module[2], nn.Identity) From 2856bc51483bcab6444d807eaff6f1de260725f2 Mon Sep 17 00:00:00 2001 From: qubvel Date: Sat, 5 Apr 2025 11:55:01 +0000 Subject: [PATCH 16/21] Fix segformer --- segmentation_models_pytorch/decoders/segformer/decoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/segmentation_models_pytorch/decoders/segformer/decoder.py b/segmentation_models_pytorch/decoders/segformer/decoder.py index cd160a4c..2bfadfff 100644 --- a/segmentation_models_pytorch/decoders/segformer/decoder.py +++ b/segmentation_models_pytorch/decoders/segformer/decoder.py @@ -50,7 +50,7 @@ def __init__( in_channels=(len(encoder_channels) - 1) * segmentation_channels, out_channels=segmentation_channels, kernel_size=1, - use_batchnorm=True, + use_norm="batchnorm", ) def forward(self, features: List[torch.Tensor]) -> torch.Tensor: From 846e112d5294b4e02bafd584ca1abe8b1cbdc2b9 Mon Sep 17 00:00:00 2001 From: qubvel Date: Sat, 5 Apr 2025 12:17:01 +0000 Subject: [PATCH 17/21] Minor fix --- segmentation_models_pytorch/decoders/pspnet/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/segmentation_models_pytorch/decoders/pspnet/model.py b/segmentation_models_pytorch/decoders/pspnet/model.py index 5a354edf..4b2d19f0 100644 --- 
a/segmentation_models_pytorch/decoders/pspnet/model.py +++ b/segmentation_models_pytorch/decoders/pspnet/model.py @@ -76,7 +76,7 @@ def __init__( encoder_weights: Optional[str] = "imagenet", encoder_depth: int = 3, psp_out_channels: int = 512, - decoder_use_norm: Union[bool, str, Dict[str, Any], None] = "batchnorm", + decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", psp_dropout: float = 0.2, in_channels: int = 3, classes: int = 1, @@ -90,7 +90,7 @@ def __init__( psp_use_batchnorm = kwargs.pop("psp_use_batchnorm", None) if psp_use_batchnorm is not None: warnings.warn( - "The usage of psp_use_batchnorm is deprecated. Please modify your code for use_norm", + "The usage of psp_use_batchnorm is deprecated. Please modify your code for decoder_use_norm", DeprecationWarning, stacklevel=2, ) From e8852c9105ba75da6e03547a137e85275e257271 Mon Sep 17 00:00:00 2001 From: qubvel Date: Sat, 5 Apr 2025 12:17:24 +0000 Subject: [PATCH 18/21] Fix deprecation tests --- tests/encoders/test_batchnorm_deprecation.py | 31 ++++++++++++-------- tests/utils.py | 11 +++++-- 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/tests/encoders/test_batchnorm_deprecation.py b/tests/encoders/test_batchnorm_deprecation.py index f7186988..ff53563f 100644 --- a/tests/encoders/test_batchnorm_deprecation.py +++ b/tests/encoders/test_batchnorm_deprecation.py @@ -2,7 +2,7 @@ import torch -from segmentation_models_pytorch import create_model +import segmentation_models_pytorch as smp from tests.utils import check_two_models_strictly_equal @@ -11,18 +11,21 @@ def test_seg_models_before_after_use_norm(model_name, decoder_option): torch.manual_seed(42) with pytest.warns(DeprecationWarning): - model_decoder_batchnorm = create_model( - model_name, "mobilenet_v2", None, decoder_use_batchnorm=decoder_option + model_decoder_batchnorm = smp.create_model( + model_name, + "mobilenet_v2", + encoder_weights=None, + decoder_use_batchnorm=decoder_option, ) - torch.manual_seed(42) - model_decoder_norm = create_model( + model_decoder_norm = smp.create_model( model_name, "mobilenet_v2", - None, - decoder_use_batchnorm=None, + encoder_weights=None, decoder_use_norm=decoder_option, ) + model_decoder_norm.load_state_dict(model_decoder_batchnorm.state_dict()) + check_two_models_strictly_equal( model_decoder_batchnorm, model_decoder_norm, torch.rand(1, 3, 224, 224) ) @@ -32,17 +35,19 @@ def test_seg_models_before_after_use_norm(model_name, decoder_option): def test_pspnet_before_after_use_norm(decoder_option): torch.manual_seed(42) with pytest.warns(DeprecationWarning): - model_decoder_batchnorm = create_model( - "pspnet", "mobilenet_v2", None, psp_use_batchnorm=decoder_option + model_decoder_batchnorm = smp.create_model( + "pspnet", + "mobilenet_v2", + encoder_weights=None, + psp_use_batchnorm=decoder_option, ) - torch.manual_seed(42) - model_decoder_norm = create_model( + model_decoder_norm = smp.create_model( "pspnet", "mobilenet_v2", - None, - psp_use_batchnorm=None, + encoder_weights=None, decoder_use_norm=decoder_option, ) + model_decoder_norm.load_state_dict(model_decoder_batchnorm.state_dict()) check_two_models_strictly_equal( model_decoder_batchnorm, model_decoder_norm, torch.rand(1, 3, 224, 224) diff --git a/tests/utils.py b/tests/utils.py index bd6376a3..1e97b40b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -67,7 +67,14 @@ def check_two_models_strictly_equal( model_a.state_dict().items(), model_b.state_dict().items() ): assert k1 == k2, f"Key mismatch: {k1} != {k2}" - assert (v1 == v2).all(), f"Tensor 
mismatch at key '{k1}':\n{v1} !=\n{v2}" + torch.testing.assert_close( + v1, v2, msg=f"Tensor mismatch at key '{k1}':\n{v1} !=\n{v2}" + ) + model_a.eval() + model_b.eval() with torch.inference_mode(): - assert (model_a(input_data) == model_b(input_data)).all() + output_a = model_a(input_data) + output_b = model_b(input_data) + + torch.testing.assert_close(output_a, output_b) From c8c114a6d63d0f41f7d3051ab8499bb229c8fc01 Mon Sep 17 00:00:00 2001 From: qubvel Date: Sat, 5 Apr 2025 12:24:15 +0000 Subject: [PATCH 19/21] Bump tolerance a bit --- tests/models/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/base.py b/tests/models/base.py index f7492986..b96e76e8 100644 --- a/tests/models/base.py +++ b/tests/models/base.py @@ -282,4 +282,4 @@ def test_torch_script(self): eager_output = model(sample) self.assertEqual(scripted_output.shape, eager_output.shape) - torch.testing.assert_close(scripted_output, eager_output) + torch.testing.assert_close(scripted_output, eager_output, rtol=1e-3, atol=1e-3) From 1f422a1e3d654835eef88163e8af5dbf5839ab25 Mon Sep 17 00:00:00 2001 From: qubvel Date: Sat, 5 Apr 2025 12:24:40 +0000 Subject: [PATCH 20/21] Move validation on top (important) --- .../decoders/unet/model.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/segmentation_models_pytorch/decoders/unet/model.py b/segmentation_models_pytorch/decoders/unet/model.py index aa0b5c5e..22d7db11 100644 --- a/segmentation_models_pytorch/decoders/unet/model.py +++ b/segmentation_models_pytorch/decoders/unet/model.py @@ -121,6 +121,15 @@ def __init__( ): super().__init__() + decoder_use_batchnorm = kwargs.pop("decoder_use_batchnorm", None) + if decoder_use_batchnorm is not None: + warnings.warn( + "The usage of decoder_use_batchnorm is deprecated. Please modify your code for decoder_use_norm", + DeprecationWarning, + stacklevel=2, + ) + decoder_use_norm = decoder_use_batchnorm + self.encoder = get_encoder( encoder_name, in_channels=in_channels, @@ -131,15 +140,6 @@ def __init__( add_center_block = encoder_name.startswith("vgg") - decoder_use_batchnorm = kwargs.pop("decoder_use_batchnorm", None) - if decoder_use_batchnorm is not None: - warnings.warn( - "The usage of decoder_use_batchnorm is deprecated. Please modify your code for use_norm", - DeprecationWarning, - stacklevel=2, - ) - decoder_use_norm = decoder_use_batchnorm - self.decoder = UnetDecoder( encoder_channels=self.encoder.out_channels, decoder_channels=decoder_channels, From b53525e7318a353983ea962535c2fd69f173d81c Mon Sep 17 00:00:00 2001 From: qubvel Date: Sat, 5 Apr 2025 12:24:54 +0000 Subject: [PATCH 21/21] Fix deprecation message --- segmentation_models_pytorch/decoders/linknet/model.py | 2 +- segmentation_models_pytorch/decoders/manet/model.py | 2 +- segmentation_models_pytorch/decoders/unetplusplus/model.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/segmentation_models_pytorch/decoders/linknet/model.py b/segmentation_models_pytorch/decoders/linknet/model.py index 7b1819e5..be0d01b2 100644 --- a/segmentation_models_pytorch/decoders/linknet/model.py +++ b/segmentation_models_pytorch/decoders/linknet/model.py @@ -91,7 +91,7 @@ def __init__( decoder_use_batchnorm = kwargs.pop("decoder_use_batchnorm", None) if decoder_use_batchnorm is not None: warnings.warn( - "The usage of decoder_use_batchnorm is deprecated. Please modify your code for use_norm", + "The usage of decoder_use_batchnorm is deprecated. 
Please modify your code for decoder_use_norm", DeprecationWarning, stacklevel=2, ) diff --git a/segmentation_models_pytorch/decoders/manet/model.py b/segmentation_models_pytorch/decoders/manet/model.py index d020e7be..a478b5c5 100644 --- a/segmentation_models_pytorch/decoders/manet/model.py +++ b/segmentation_models_pytorch/decoders/manet/model.py @@ -91,7 +91,7 @@ def __init__( decoder_use_batchnorm = kwargs.pop("decoder_use_batchnorm", None) if decoder_use_batchnorm is not None: warnings.warn( - "The usage of decoder_use_batchnorm is deprecated. Please modify your code for use_norm", + "The usage of decoder_use_batchnorm is deprecated. Please modify your code for decoder_use_norm", DeprecationWarning, stacklevel=2, ) diff --git a/segmentation_models_pytorch/decoders/unetplusplus/model.py b/segmentation_models_pytorch/decoders/unetplusplus/model.py index 257629e7..be0f8f83 100644 --- a/segmentation_models_pytorch/decoders/unetplusplus/model.py +++ b/segmentation_models_pytorch/decoders/unetplusplus/model.py @@ -97,7 +97,7 @@ def __init__( decoder_use_batchnorm = kwargs.pop("decoder_use_batchnorm", None) if decoder_use_batchnorm is not None: warnings.warn( - "The usage of decoder_use_batchnorm is deprecated. Please modify your code for use_norm", + "The usage of decoder_use_batchnorm is deprecated. Please modify your code for decoder_use_norm", DeprecationWarning, stacklevel=2, )
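A note on PATCH 11 and PATCH 20, which both hoist the `kwargs.pop` above the `get_encoder` call: the commits do not say why, but the evident reason is that anything left in `**kwargs` is forwarded to the encoder factory, so a deprecated `decoder_use_batchnorm` would otherwise reach an encoder that does not expect it. A self-contained sketch of the failure mode; the stub stands in for the real `get_encoder`:

```python
def get_encoder_stub(name: str, **kwargs) -> None:
    # the real factory forwards kwargs to encoder constructors,
    # which reject unknown keyword arguments
    if kwargs:
        raise TypeError(f"unexpected encoder kwargs: {kwargs}")

kwargs = {"decoder_use_batchnorm": True}  # deprecated key from a caller

# pop the deprecated key first, as the reordered __init__ now does...
decoder_use_batchnorm = kwargs.pop("decoder_use_batchnorm", None)

# ...so only genuine encoder arguments remain when the encoder is built
get_encoder_stub("resnet34", **kwargs)
```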
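PATCH 12's rewrite of `TransposeX2` is behavior-preserving: the old version also ended up as conv-transpose, norm, ReLU, since the norm was inserted at index 1, and skipping the insert for identity norms was equivalent to inserting a no-op `nn.Identity`. A usage sketch of the flattened block, assuming the patched linknet decoder; the shape arithmetic follows from kernel_size=4, stride=2, padding=1:

```python
import torch
from segmentation_models_pytorch.decoders.linknet.decoder import TransposeX2

up = TransposeX2(64, 32, use_norm="batchnorm")
x = torch.rand(1, 64, 28, 28)
y = up(x)  # (28 - 1) * 2 - 2 * 1 + 4 = 56: the block exactly doubles H and W
print(y.shape)  # torch.Size([1, 32, 56, 56])
```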
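Finally, with PATCH 21 the deprecation warnings name the actual replacement argument, which makes the migration path for downstream code explicit. A before-and-after sketch; the encoder and model choices are illustrative only:

```python
import segmentation_models_pytorch as smp

# Deprecated spelling: still works, warns, and is mapped onto decoder_use_norm
model = smp.Unet("resnet34", encoder_weights=None, decoder_use_batchnorm=True)

# New spellings accepted by decoder_use_norm
model = smp.Unet("resnet34", encoder_weights=None, decoder_use_norm="batchnorm")
model = smp.Unet(
    "resnet34", encoder_weights=None, decoder_use_norm={"type": "layernorm", "eps": 1e-2}
)

# PSPNet's psp_use_batchnorm alias is deprecated the same way
model = smp.PSPNet("resnet34", encoder_weights=None, decoder_use_norm="identity")
```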