Deprecate use_batchnorm in favor of generalized use_norm parameter #1095

Merged (22 commits) on Apr 5, 2025
114 changes: 91 additions & 23 deletions segmentation_models_pytorch/base/modules.py
@@ -1,3 +1,5 @@
from typing import Any, Dict, Union

import torch
import torch.nn as nn

@@ -7,43 +9,109 @@
InPlaceABN = None


def get_norm_layer(
use_norm: Union[bool, str, Dict[str, Any]], out_channels: int
) -> nn.Module:
supported_norms = ("inplace", "batchnorm", "identity", "layernorm", "instancenorm")

# Step 1. Convert to dict representation

## Check boolean
if use_norm is True:
norm_params = {"type": "batchnorm"}

elif use_norm is False:
norm_params = {"type": "identity"}

## Check string
elif isinstance(use_norm, str):
norm_str = use_norm.lower()
if norm_str == "inplace":
norm_params = {

"type": "inplace",
"activation": "leaky_relu",
"activation_param": 0.0,
}
elif norm_str in supported_norms:
norm_params = {"type": norm_str}
else:
raise ValueError(

f"Unrecognized normalization type string provided: {use_norm}. Should be in "
f"{supported_norms}"
)

## Check dict
elif isinstance(use_norm, dict):
norm_params = use_norm

else:
raise ValueError(

f"Invalid type for use_norm should either be a bool (batchnorm/identity), "
f"a string in {supported_norms}, or a dict like {{'type': 'batchnorm', **kwargs}}"
)

# Step 2. Check if the dict is valid
if "type" not in norm_params:
raise ValueError(

f"Malformed dictionary given in use_norm: {use_norm}. Should contain key 'type'."
)
if norm_params["type"] not in supported_norms:
raise ValueError(

f"Unrecognized normalization type string provided: {use_norm}. Should be in {supported_norms}"
)
if norm_params["type"] == "inplace" and InPlaceABN is None:
raise RuntimeError(

"In order to use `use_norm='inplace'` the inplace_abn package must be installed. Use:\n"
" $ pip install -U wheel setuptools\n"
" $ pip install inplace_abn --no-build-isolation\n"
"Also see: https://github.com/mapillary/inplace_abn"
)

# Step 3. Initialize the norm layer
norm_type = norm_params["type"]
norm_kwargs = {k: v for k, v in norm_params.items() if k != "type"}

if norm_type == "inplace":
norm = InPlaceABN(out_channels, **norm_kwargs)

elif norm_type == "batchnorm":
norm = nn.BatchNorm2d(out_channels, **norm_kwargs)
elif norm_type == "identity":
norm = nn.Identity()
elif norm_type == "layernorm":
norm = nn.LayerNorm(out_channels, **norm_kwargs)
elif norm_type == "instancenorm":
norm = nn.InstanceNorm2d(out_channels, **norm_kwargs)
else:
raise ValueError(f"Unrecognized normalization type: {norm_type}")

return norm


class Conv2dReLU(nn.Sequential):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
padding=0,
stride=1,
use_batchnorm=True,
in_channels: int,
out_channels: int,
kernel_size: int,
padding: int = 0,
stride: int = 1,
use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
):
if use_batchnorm == "inplace" and InPlaceABN is None:
raise RuntimeError(
"In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. "
+ "To install see: https://github.com/mapillary/inplace_abn"
)
norm = get_norm_layer(use_norm, out_channels)

is_identity = isinstance(norm, nn.Identity)
conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
bias=not (use_batchnorm),
bias=is_identity,
)
relu = nn.ReLU(inplace=True)

if use_batchnorm == "inplace":
bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0)
relu = nn.Identity()

elif use_batchnorm and use_batchnorm != "inplace":
bn = nn.BatchNorm2d(out_channels)

else:
bn = nn.Identity()
is_inplaceabn = InPlaceABN is not None and isinstance(norm, InPlaceABN)
activation = nn.Identity() if is_inplaceabn else nn.ReLU(inplace=True)

super(Conv2dReLU, self).__init__(conv, bn, relu)
super(Conv2dReLU, self).__init__(conv, norm, activation)


class SCSEModule(nn.Module):
Expand Down
49 changes: 28 additions & 21 deletions segmentation_models_pytorch/decoders/linknet/decoder.py
@@ -1,45 +1,48 @@
import torch
import torch.nn as nn

from typing import List, Optional
from typing import Any, Dict, List, Optional, Union
from segmentation_models_pytorch.base import modules


class TransposeX2(nn.Sequential):
def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True):
def __init__(
self,
in_channels: int,
out_channels: int,
use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
):
super().__init__()
layers = [
nn.ConvTranspose2d(
in_channels, out_channels, kernel_size=4, stride=2, padding=1
),
nn.ReLU(inplace=True),
]

if use_batchnorm:
layers.insert(1, nn.BatchNorm2d(out_channels))

super().__init__(*layers)
conv = nn.ConvTranspose2d(
in_channels, out_channels, kernel_size=4, stride=2, padding=1
)
norm = modules.get_norm_layer(use_norm, out_channels)
activation = nn.ReLU(inplace=True)
super().__init__(conv, norm, activation)


class DecoderBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True):
def __init__(
self,
in_channels: int,
out_channels: int,
use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
):
super().__init__()

self.block = nn.Sequential(
modules.Conv2dReLU(
in_channels,
in_channels // 4,
kernel_size=1,
use_batchnorm=use_batchnorm,
),
TransposeX2(
in_channels // 4, in_channels // 4, use_batchnorm=use_batchnorm
use_norm=use_norm,
),
TransposeX2(in_channels // 4, in_channels // 4, use_norm=use_norm),
modules.Conv2dReLU(
in_channels // 4,
out_channels,
kernel_size=1,
use_batchnorm=use_batchnorm,
use_norm=use_norm,
),
)

@@ -58,7 +61,7 @@ def __init__(
encoder_channels: List[int],
prefinal_channels: int = 32,
n_blocks: int = 5,
use_batchnorm: bool = True,
use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
):
super().__init__()

@@ -71,7 +74,11 @@

self.blocks = nn.ModuleList(
[
DecoderBlock(channels[i], channels[i + 1], use_batchnorm=use_batchnorm)
DecoderBlock(
channels[i],
channels[i + 1],
use_norm=use_norm,
)
for i in range(n_blocks)
]
)
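A small sanity-check sketch for the Linknet decoder blocks above, assuming the import path from the file header; the module counts follow from the block structure shown in this diff:

```python
import torch
import torch.nn as nn
from segmentation_models_pytorch.decoders.linknet.decoder import DecoderBlock, TransposeX2

# With use_norm=False the norm slot is nn.Identity; spatial size is doubled either way.
up = TransposeX2(64, 32, use_norm=False)
print(up(torch.randn(1, 64, 16, 16)).shape)          # torch.Size([1, 32, 32, 32])
print(any(isinstance(m, nn.Identity) for m in up))   # True

# DecoderBlock forwards the same use_norm setting to both Conv2dReLU layers
# and to the TransposeX2 upsampling block, so three norm layers are created.
block = DecoderBlock(64, 32, use_norm="instancenorm")
print(sum(isinstance(m, nn.InstanceNorm2d) for m in block.modules()))  # 3
```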
37 changes: 30 additions & 7 deletions segmentation_models_pytorch/decoders/linknet/model.py
@@ -1,4 +1,5 @@
from typing import Any, Optional, Union
import warnings
from typing import Any, Dict, Optional, Union, Callable

from segmentation_models_pytorch.base import (
ClassificationHead,
@@ -29,9 +30,22 @@
Default is 5
encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
other pretrained weights (see table with available weights for each encoder_name)
decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers
is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption.
Available options are **True, False, "inplace"**
decoder_use_norm: Specifies normalization between Conv2D and activation.
Accepts the following types:
- **True**: Defaults to `"batchnorm"`.
- **False**: No normalization (`nn.Identity`).
- **str**: Specifies normalization type using default parameters. Available values:
`"batchnorm"`, `"identity"`, `"layernorm"`, `"instancenorm"`, `"inplace"`.
- **dict**: Fully customizable normalization settings. Structure:
```python
{"type": <norm_type>, **kwargs}
```
where `norm_type` corresponds to the normalization type (see above), and `kwargs` are passed directly to the normalization layer as defined in the PyTorch documentation.

**Example**:
```python
decoder_use_norm={"type": "layernorm", "eps": 1e-2}
```
in_channels: A number of input channels for the model, default is 3 (RGB images)
classes: A number of classes for output mask (or you can think as a number of channels of output mask)
activation: An activation function to apply after the final convolution layer.
@@ -60,10 +74,10 @@
encoder_name: str = "resnet34",
encoder_depth: int = 5,
encoder_weights: Optional[str] = "imagenet",
decoder_use_batchnorm: bool = True,
decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
in_channels: int = 3,
classes: int = 1,
activation: Optional[Union[str, callable]] = None,
activation: Optional[Union[str, Callable]] = None,
aux_params: Optional[dict] = None,
**kwargs: dict[str, Any],
):
@@ -74,6 +88,15 @@
"Encoder `{}` is not supported for Linknet".format(encoder_name)
)

decoder_use_batchnorm = kwargs.pop("decoder_use_batchnorm", None)
if decoder_use_batchnorm is not None:
warnings.warn(

"The usage of decoder_use_batchnorm is deprecated. Please modify your code for decoder_use_norm",
DeprecationWarning,
stacklevel=2,
)
decoder_use_norm = decoder_use_batchnorm

self.encoder = get_encoder(
encoder_name,
in_channels=in_channels,
@@ -86,7 +109,7 @@
encoder_channels=self.encoder.out_channels,
n_blocks=encoder_depth,
prefinal_channels=32,
use_batchnorm=decoder_use_batchnorm,
use_norm=decoder_use_norm,
)

self.segmentation_head = SegmentationHead(
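At the model level, the change can be exercised as in the hedged sketch below (assuming the package imports as `segmentation_models_pytorch`; `encoder_weights=None` avoids a weight download for a quick smoke test):

```python
import warnings
import segmentation_models_pytorch as smp

# New-style argument: a dict is forwarded to every decoder normalization layer.
model = smp.Linknet(
    encoder_name="resnet34",
    encoder_weights=None,
    decoder_use_norm={"type": "instancenorm", "affine": True},
    classes=2,
)

# Old-style argument still works, but is remapped onto decoder_use_norm by the
# shim above and now emits a DeprecationWarning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy = smp.Linknet(
        encoder_name="resnet34",
        encoder_weights=None,
        decoder_use_batchnorm=True,
    )
print([str(w.message) for w in caught if issubclass(w.category, DeprecationWarning)])
```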
27 changes: 15 additions & 12 deletions segmentation_models_pytorch/decoders/manet/decoder.py
@@ -1,9 +1,9 @@
from typing import Any, Dict, List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import List, Optional

from segmentation_models_pytorch.base import modules as md


@@ -49,7 +49,7 @@ def __init__(
in_channels: int,
skip_channels: int,
out_channels: int,
use_batchnorm: bool = True,
use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
reduction: int = 16,
):
# MFABBlock is just a modified version of SE-blocks, one for skip, one for input
Expand All @@ -60,10 +60,13 @@ def __init__(
in_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
use_norm=use_norm,
),
md.Conv2dReLU(
in_channels, skip_channels, kernel_size=1, use_batchnorm=use_batchnorm
in_channels,
skip_channels,
kernel_size=1,
use_norm=use_norm,
),
)
reduced_channels = max(1, skip_channels // reduction)
@@ -87,14 +90,14 @@ def __init__(
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
use_norm=use_norm,
)
self.conv2 = md.Conv2dReLU(
out_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
use_norm=use_norm,
)

def forward(
@@ -119,22 +122,22 @@ def __init__(
in_channels: int,
skip_channels: int,
out_channels: int,
use_batchnorm: bool = True,
use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
):
super().__init__()
self.conv1 = md.Conv2dReLU(
in_channels + skip_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
use_norm=use_norm,
)
self.conv2 = md.Conv2dReLU(
out_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
use_norm=use_norm,
)

def forward(
@@ -155,7 +158,7 @@ def __init__(
decoder_channels: List[int],
n_blocks: int = 5,
reduction: int = 16,
use_batchnorm: bool = True,
use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
pab_channels: int = 64,
):
super().__init__()
@@ -182,7 +185,7 @@ def __init__(
self.center = PABBlock(head_channels, pab_channels=pab_channels)

# combine decoder keyword arguments
kwargs = dict(use_batchnorm=use_batchnorm) # no attention type here
kwargs = dict(use_norm=use_norm) # no attention type here
blocks = [
MFABBlock(in_ch, skip_ch, out_ch, reduction=reduction, **kwargs)
if skip_ch > 0
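The MAnet decoder receives the same `use_norm` plumbing; a construction-only sketch follows, with the `MAnetDecoder` class name assumed from the repository layout since its definition sits outside this hunk:

```python
from segmentation_models_pytorch.decoders.manet.decoder import MAnetDecoder

# Channel list mimics a resnet34 encoder at depth 5.
decoder = MAnetDecoder(
    encoder_channels=[3, 64, 64, 128, 256, 512],
    decoder_channels=[256, 128, 64, 32, 16],
    n_blocks=5,
    reduction=16,
    use_norm={"type": "batchnorm", "eps": 1e-3},
    pab_channels=64,
)
# All decoder blocks now share the same normalization settings via use_norm.
print(sum(p.numel() for p in decoder.parameters()))
```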