Deprecate use_batchnorm in favor of generalized use_norm parameter #1095

Merged
merged 22 commits on Apr 5, 2025
Changes from 6 commits
83 changes: 66 additions & 17 deletions segmentation_models_pytorch/base/modules.py
@@ -1,3 +1,5 @@
from typing import Any, Dict, Union

import torch
import torch.nn as nn

@@ -6,6 +8,64 @@
except ImportError:
InPlaceABN = None

def get_norm_layer(use_norm: Union[bool, str, Dict[str, Any]], out_channels: int) -> nn.Module:
supported_norms = ("inplace", "batchnorm", "identity", "layernorm", "instancenorm")
if use_norm is True:
norm_params = {"type": "batchnorm"}
elif use_norm is False:
norm_params = {"type": "identity"}
elif use_norm == "inplace":
norm_params = {"type": "inplace", "activation": "leaky_relu", "activation_param": 0.0}
elif isinstance(use_norm, str):
norm_str = use_norm.lower()
if norm_str == "inplace":
norm_params = {
"type": "inplace",
"activation": "leaky_relu",
"activation_param": 0.0,
}
elif norm_str in (
"batchnorm",
"identity",
"layernorm",
"instancenorm",
):
norm_params = {"type": norm_str}
else:
raise ValueError(f"Unrecognized normalization type string provided: {use_norm}. Should be in {supported_norms}")
Collaborator: Nice!

elif isinstance(use_norm, dict):
norm_params = use_norm
else:
raise ValueError("use_norm must be a dictionary, boolean, or string. Please refer to the documentation.")
Collaborator: Let's have a more descriptive error here; specify what kind of string and dict structure it should be.

Contributor Author: Made a proposal.

if "type" not in norm_params:
raise ValueError(f"Malformed dictionary given in use_norm: {use_norm}. Should contain key 'type'.")
if norm_params["type"] not in supported_norms:
raise ValueError(f"Unrecognized normalization type string provided: {use_norm}. Should be in {supported_norms}")

norm_type = norm_params["type"]
extra_kwargs = {k: v for k, v in norm_params.items() if k != "type"}

if norm_type == "inplace" and InPlaceABN is None:
raise RuntimeError(
"In order to use `use_batchnorm='inplace'` or `use_norm='inplace'` the inplace_abn package must be installed. "
"To install see: https://github.com/mapillary/inplace_abn"
)
elif norm_type == "inplace":
norm = InPlaceABN(out_channels, **extra_kwargs)
elif norm_type == "batchnorm":
norm = nn.BatchNorm2d(out_channels, **extra_kwargs)
elif norm_type == "identity":
norm = nn.Identity()
elif norm_type == "layernorm":
norm = nn.LayerNorm(out_channels, **extra_kwargs)
elif norm_type == "instancenorm":
norm = nn.InstanceNorm2d(out_channels, **extra_kwargs)
else:
raise ValueError(f"Unrecognized normalization type: {norm_type}")

return norm

class Conv2dReLU(nn.Sequential):
def __init__(
@@ -15,35 +75,24 @@ def __init__(
kernel_size,
padding=0,
stride=1,
use_batchnorm=True,
use_norm="batchnorm",
):
if use_batchnorm == "inplace" and InPlaceABN is None:
raise RuntimeError(
"In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. "
+ "To install see: https://github.com/mapillary/inplace_abn"
)

norm = get_norm_layer(use_norm, out_channels)
conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
bias=not (use_batchnorm),
bias=norm._get_name() != "BatchNorm2d",
)
relu = nn.ReLU(inplace=True)

if use_batchnorm == "inplace":
bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0)
if norm._get_name() == "InPlaceABN":
relu = nn.Identity()

elif use_batchnorm and use_batchnorm != "inplace":
bn = nn.BatchNorm2d(out_channels)

else:
bn = nn.Identity()
relu = nn.ReLU(inplace=True)

super(Conv2dReLU, self).__init__(conv, bn, relu)
super(Conv2dReLU, self).__init__(conv, norm, relu)


class SCSEModule(nn.Module):
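
As a quick sanity check of the new helper, here is a minimal sketch of how `get_norm_layer` resolves the different `use_norm` values, assuming this branch of segmentation_models_pytorch is installed; the mappings follow the diff above.

```python
import torch.nn as nn
from segmentation_models_pytorch.base.modules import get_norm_layer

# Booleans keep the old behaviour: True -> BatchNorm2d, False -> Identity.
assert isinstance(get_norm_layer(True, 64), nn.BatchNorm2d)
assert isinstance(get_norm_layer(False, 64), nn.Identity)

# Strings select a norm type with default parameters.
assert isinstance(get_norm_layer("instancenorm", 64), nn.InstanceNorm2d)

# Dicts pass any extra keys straight to the layer constructor.
layer = get_norm_layer({"type": "layernorm", "eps": 1e-2}, 64)
assert isinstance(layer, nn.LayerNorm) and layer.eps == 1e-2
```
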
38 changes: 28 additions & 10 deletions segmentation_models_pytorch/decoders/linknet/decoder.py
@@ -1,12 +1,17 @@
import torch
import torch.nn as nn

from typing import List, Optional
from typing import Any, Dict, List, Optional, Union
from segmentation_models_pytorch.base import modules


class TransposeX2(nn.Sequential):
def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True):
def __init__(
self,
in_channels: int,
out_channels: int,
use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
):
super().__init__()
layers = [
nn.ConvTranspose2d(
@@ -15,31 +20,40 @@ def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = Tr
nn.ReLU(inplace=True),
]

if use_batchnorm:
layers.insert(1, nn.BatchNorm2d(out_channels))
if use_norm != "identity":
if isinstance(use_norm, dict):
if use_norm.get("type") != "identity":
layers.insert(1, modules.get_norm_layer(use_norm, out_channels))
else:
layers.insert(1, modules.get_norm_layer(use_norm, out_channels))

super().__init__(*layers)


class DecoderBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True):
def __init__(
self,
in_channels: int,
out_channels: int,
use_norm: Union[bool, str, Dict[str, Any]] = True,
):
super().__init__()

self.block = nn.Sequential(
modules.Conv2dReLU(
in_channels,
in_channels // 4,
kernel_size=1,
use_batchnorm=use_batchnorm,
use_norm=use_norm,
),
TransposeX2(
in_channels // 4, in_channels // 4, use_batchnorm=use_batchnorm
in_channels // 4, in_channels // 4, use_norm=use_norm
),
modules.Conv2dReLU(
in_channels // 4,
out_channels,
kernel_size=1,
use_batchnorm=use_batchnorm,
use_norm=use_norm,
),
)

@@ -58,7 +72,7 @@ def __init__(
encoder_channels: List[int],
prefinal_channels: int = 32,
n_blocks: int = 5,
use_batchnorm: bool = True,
use_norm: Union[bool, str, Dict[str, Any]] = True,
):
super().__init__()

@@ -71,7 +85,11 @@ def __init__(

self.blocks = nn.ModuleList(
[
DecoderBlock(channels[i], channels[i + 1], use_batchnorm=use_batchnorm)
DecoderBlock(
channels[i],
channels[i + 1],
use_norm=use_norm,
)
for i in range(n_blocks)
]
)
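
A construction-only sketch of how the renamed argument flows through the Linknet decoder blocks (hypothetical usage, assuming this branch; the printed shape is illustrative):

```python
import torch
from segmentation_models_pytorch.decoders.linknet.decoder import DecoderBlock, TransposeX2

# The same `use_norm` value is forwarded to every Conv2dReLU and TransposeX2 inside.
block = DecoderBlock(in_channels=64, out_channels=32, use_norm={"type": "instancenorm"})
up = TransposeX2(64, 64, use_norm="identity")  # the norm layer is skipped entirely

x = torch.rand(2, 64, 16, 16)
print(block(x).shape)  # expected: torch.Size([2, 32, 32, 32])
```
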
35 changes: 29 additions & 6 deletions segmentation_models_pytorch/decoders/linknet/model.py
@@ -1,4 +1,5 @@
from typing import Any, Optional, Union
import warnings
from typing import Any, Dict, Optional, Union

from segmentation_models_pytorch.base import (
ClassificationHead,
@@ -29,9 +30,22 @@ class Linknet(SegmentationModel):
Default is 5
encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
other pretrained weights (see table with available weights for each encoder_name)
decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers
is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption.
Available options are **True, False, "inplace"**
decoder_use_norm: Specifies normalization between Conv2D and activation.
Accepts the following types:
- **True**: Defaults to `"batchnorm"`.
- **False**: No normalization (`nn.Identity`).
- **str**: Specifies normalization type using default parameters. Available values:
`"batchnorm"`, `"identity"`, `"layernorm"`, `"instancenorm"`, `"inplace"`.
- **dict**: Fully customizable normalization settings. Structure:
```python
{"type": <norm_type>, **kwargs}
```
where `<norm_type>` is one of the normalization types listed above, and `kwargs` are passed directly to the corresponding normalization layer as defined in the PyTorch documentation.

**Example**:
```python
use_norm={"type": "layernorm", "eps": 1e-2}
```
Collaborator: Thanks for the detailed docstring, really appreciate it!

in_channels: A number of input channels for the model, default is 3 (RGB images)
classes: A number of classes for output mask (or you can think as a number of channels of output mask)
activation: An activation function to apply after the final convolution layer.
Expand Down Expand Up @@ -60,7 +74,7 @@ def __init__(
encoder_name: str = "resnet34",
encoder_depth: int = 5,
encoder_weights: Optional[str] = "imagenet",
decoder_use_batchnorm: bool = True,
decoder_use_norm: Union[bool, str, Dict[str, Any]] = True,
in_channels: int = 3,
classes: int = 1,
activation: Optional[Union[str, callable]] = None,
@@ -82,11 +96,20 @@ def __init__(
**kwargs,
)

decoder_use_batchnorm = kwargs.pop("decoder_use_batchnorm", None)
if decoder_use_batchnorm is not None:
warnings.warn(
"The usage of decoder_use_batchnorm is deprecated. Please modify your code to use decoder_use_norm instead.",
DeprecationWarning,
stacklevel=2
)
decoder_use_norm = decoder_use_batchnorm

self.decoder = LinknetDecoder(
encoder_channels=self.encoder.out_channels,
n_blocks=encoder_depth,
prefinal_channels=32,
use_batchnorm=decoder_use_batchnorm,
use_norm=decoder_use_norm,
)

self.segmentation_head = SegmentationHead(
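
A usage sketch at the model level, assuming this branch (`encoder_weights=None` avoids downloading pretrained weights):

```python
import segmentation_models_pytorch as smp

# New style: any value accepted by `use_norm` is forwarded to the decoder.
model = smp.Linknet(
    encoder_name="resnet34",
    encoder_weights=None,
    decoder_use_norm={"type": "instancenorm"},
)

# Old style, e.g. decoder_use_batchnorm=False, is caught via **kwargs,
# remapped to decoder_use_norm, and a DeprecationWarning is emitted (see the shim above).
```
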
29 changes: 17 additions & 12 deletions segmentation_models_pytorch/decoders/manet/decoder.py
@@ -1,9 +1,9 @@
from typing import Any, Dict, List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import List, Optional

from segmentation_models_pytorch.base import modules as md


@@ -49,7 +49,7 @@ def __init__(
in_channels: int,
skip_channels: int,
out_channels: int,
use_batchnorm: bool = True,
use_norm: Union[bool, str, Dict[str, Any]] = True,
reduction: int = 16,
):
# MFABBlock is just a modified version of SE-blocks, one for skip, one for input
@@ -60,10 +60,13 @@ def __init__(
in_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
use_norm=use_norm,
),
md.Conv2dReLU(
in_channels, skip_channels, kernel_size=1, use_batchnorm=use_batchnorm
in_channels,
skip_channels,
kernel_size=1,
use_norm=use_norm,
),
)
reduced_channels = max(1, skip_channels // reduction)
@@ -87,14 +90,14 @@ def __init__(
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
use_norm=use_norm,
)
self.conv2 = md.Conv2dReLU(
out_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
use_norm=use_norm,
)

def forward(
@@ -119,22 +122,22 @@ def __init__(
in_channels: int,
skip_channels: int,
out_channels: int,
use_batchnorm: bool = True,
use_norm: Union[bool, str, Dict[str, Any]] = True,
):
super().__init__()
self.conv1 = md.Conv2dReLU(
in_channels + skip_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
use_norm=use_norm,
)
self.conv2 = md.Conv2dReLU(
out_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
use_norm=use_norm,
)

def forward(
@@ -155,7 +158,7 @@ def __init__(
decoder_channels: List[int],
n_blocks: int = 5,
reduction: int = 16,
use_batchnorm: bool = True,
use_norm: Union[bool, str, Dict[str, Any]] = True,
pab_channels: int = 64,
):
super().__init__()
@@ -182,7 +185,9 @@ def __init__(
self.center = PABBlock(head_channels, pab_channels=pab_channels)

# combine decoder keyword arguments
kwargs = dict(use_batchnorm=use_batchnorm) # no attention type here
kwargs = dict(
use_norm=use_norm
) # no attention type here
blocks = [
MFABBlock(in_ch, skip_ch, out_ch, reduction=reduction, **kwargs)
if skip_ch > 0
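
The MAnet decoder blocks accept the same values; a construction-only sketch (hypothetical usage, assuming this branch):

```python
from segmentation_models_pytorch.decoders.manet.decoder import DecoderBlock, MFABBlock

# Both block types forward `use_norm` to their internal Conv2dReLU layers.
mfab = MFABBlock(in_channels=64, skip_channels=32, out_channels=16,
                 use_norm={"type": "batchnorm", "momentum": 0.05})
plain = DecoderBlock(in_channels=64, skip_channels=32, out_channels=16,
                     use_norm=False)  # no normalization at all
```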