Commit 44505fd
🚨🚨🚨 Deprecate use_batchnorm in favor of generalized use_norm parameter (#1095)
* Deprecate use_batchnorm in favor of generalized use_norm parameter
* First fix following review
* Add test to verify before / after functionality
  Fix issue in Linknet
  Handle norm and relu on its own
  Align pspnet
  Add use_norm in upernet
  Remove groupnorm possibility
  Update doc
* Set use_batchnorm default to None so that use_norm is the default
  Make warning visible by changing filter and add a test for it
  Fix before/after test so that the value is checked, not just the shape of the tensor
* Changes following review
* Revert default value to batchnorm
  Fix typo in decoder_use_norm doc
  Add description to invalid type error
* Minor style fixes
* Fixup
* Fix bias term
* Refine error message
* Move deprecation on top + some typehint fixes
* Redesign ConvBnRelu block
* Fix kernel_size type
* Better type hints
* Fix InplaceABN test
* Fix segformer
* Minor fix
* Fix deprecation tests
* Bump tolerance a bit
* Move validation on top (important)
* Fix deprecation message

---------

Co-authored-by: qubvel <[email protected]>
1 parent 930a163 commit 44505fd

File tree

18 files changed · +518 −141 lines changed
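In user code, the migration looks like the sketch below. It uses `Linknet`, whose updated signature appears in the diffs that follow (the `smp` import alias is the usual convention, not part of this commit); the old keyword keeps working but now emits a `DeprecationWarning`:

```python
import segmentation_models_pytorch as smp

# Before: boolean flag (or the special string "inplace")
# model = smp.Linknet("resnet34", decoder_use_batchnorm=True)

# After: a generalized parameter accepting bool, str, or dict
model = smp.Linknet("resnet34", decoder_use_norm="batchnorm")  # same as True
model = smp.Linknet("resnet34", decoder_use_norm=False)        # nn.Identity
model = smp.Linknet(
    "resnet34",
    # extra dict keys are forwarded to the norm layer, here nn.InstanceNorm2d
    decoder_use_norm={"type": "instancenorm", "affine": True},
)
```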

segmentation_models_pytorch/base/modules.py (+91 −23)
```diff
@@ -1,3 +1,5 @@
+from typing import Any, Dict, Union
+
 import torch
 import torch.nn as nn
@@ -7,43 +9,109 @@
     InPlaceABN = None
 
 
+def get_norm_layer(
+    use_norm: Union[bool, str, Dict[str, Any]], out_channels: int
+) -> nn.Module:
+    supported_norms = ("inplace", "batchnorm", "identity", "layernorm", "instancenorm")
+
+    # Step 1. Convert to dict representation
+
+    ## Check boolean
+    if use_norm is True:
+        norm_params = {"type": "batchnorm"}
+    elif use_norm is False:
+        norm_params = {"type": "identity"}
+
+    ## Check string
+    elif isinstance(use_norm, str):
+        norm_str = use_norm.lower()
+        if norm_str == "inplace":
+            norm_params = {
+                "type": "inplace",
+                "activation": "leaky_relu",
+                "activation_param": 0.0,
+            }
+        elif norm_str in supported_norms:
+            norm_params = {"type": norm_str}
+        else:
+            raise ValueError(
+                f"Unrecognized normalization type string provided: {use_norm}. Should be in "
+                f"{supported_norms}"
+            )
+
+    ## Check dict
+    elif isinstance(use_norm, dict):
+        norm_params = use_norm
+
+    else:
+        raise ValueError(
+            f"Invalid type for use_norm: should either be a bool (batchnorm/identity), "
+            f"a string in {supported_norms}, or a dict like {{'type': 'batchnorm', **kwargs}}"
+        )
+
+    # Step 2. Check if the dict is valid
+    if "type" not in norm_params:
+        raise ValueError(
+            f"Malformed dictionary given in use_norm: {use_norm}. Should contain key 'type'."
+        )
+    if norm_params["type"] not in supported_norms:
+        raise ValueError(
+            f"Unrecognized normalization type string provided: {use_norm}. Should be in {supported_norms}"
+        )
+    if norm_params["type"] == "inplace" and InPlaceABN is None:
+        raise RuntimeError(
+            "In order to use `use_norm='inplace'` the inplace_abn package must be installed. Use:\n"
+            "  $ pip install -U wheel setuptools\n"
+            "  $ pip install inplace_abn --no-build-isolation\n"
+            "Also see: https://github.com/mapillary/inplace_abn"
+        )
+
+    # Step 3. Initialize the norm layer
+    norm_type = norm_params["type"]
+    norm_kwargs = {k: v for k, v in norm_params.items() if k != "type"}
+
+    if norm_type == "inplace":
+        norm = InPlaceABN(out_channels, **norm_kwargs)
+    elif norm_type == "batchnorm":
+        norm = nn.BatchNorm2d(out_channels, **norm_kwargs)
+    elif norm_type == "identity":
+        norm = nn.Identity()
+    elif norm_type == "layernorm":
+        norm = nn.LayerNorm(out_channels, **norm_kwargs)
+    elif norm_type == "instancenorm":
+        norm = nn.InstanceNorm2d(out_channels, **norm_kwargs)
+    else:
+        raise ValueError(f"Unrecognized normalization type: {norm_type}")
+
+    return norm
+
+
 class Conv2dReLU(nn.Sequential):
     def __init__(
         self,
-        in_channels,
-        out_channels,
-        kernel_size,
-        padding=0,
-        stride=1,
-        use_batchnorm=True,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        padding: int = 0,
+        stride: int = 1,
+        use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
     ):
-        if use_batchnorm == "inplace" and InPlaceABN is None:
-            raise RuntimeError(
-                "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. "
-                + "To install see: https://github.com/mapillary/inplace_abn"
-            )
+        norm = get_norm_layer(use_norm, out_channels)
 
+        is_identity = isinstance(norm, nn.Identity)
         conv = nn.Conv2d(
             in_channels,
             out_channels,
             kernel_size,
             stride=stride,
             padding=padding,
-            bias=not (use_batchnorm),
+            bias=is_identity,
         )
-        relu = nn.ReLU(inplace=True)
-
-        if use_batchnorm == "inplace":
-            bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0)
-            relu = nn.Identity()
 
-        elif use_batchnorm and use_batchnorm != "inplace":
-            bn = nn.BatchNorm2d(out_channels)
-
-        else:
-            bn = nn.Identity()
+        is_inplaceabn = InPlaceABN is not None and isinstance(norm, InPlaceABN)
+        activation = nn.Identity() if is_inplaceabn else nn.ReLU(inplace=True)
 
-        super(Conv2dReLU, self).__init__(conv, bn, relu)
+        super(Conv2dReLU, self).__init__(conv, norm, activation)
 
 
 class SCSEModule(nn.Module):
```
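A quick way to sanity-check the redesigned block (a sketch against the code above; the bias assertions reflect the `bias=is_identity` change, which only keeps the conv bias when no normalization follows):

```python
import torch
import torch.nn as nn
from segmentation_models_pytorch.base.modules import Conv2dReLU, get_norm_layer

# get_norm_layer normalizes every accepted spelling into a layer instance
assert isinstance(get_norm_layer(True, 16), nn.BatchNorm2d)
assert isinstance(get_norm_layer(False, 16), nn.Identity)
assert isinstance(
    get_norm_layer({"type": "instancenorm", "affine": True}, 16), nn.InstanceNorm2d
)

# The conv bias is redundant when a norm layer follows, so it is dropped
with_norm = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm="batchnorm")
without_norm = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm=False)
assert with_norm[0].bias is None
assert without_norm[0].bias is not None

# Shapes are unchanged by the normalization choice
x = torch.randn(2, 3, 32, 32)
assert with_norm(x).shape == (2, 16, 32, 32)
```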

segmentation_models_pytorch/decoders/linknet/decoder.py (+28 −21)
```diff
@@ -1,45 +1,48 @@
 import torch
 import torch.nn as nn
 
-from typing import List, Optional
+from typing import Any, Dict, List, Optional, Union
 from segmentation_models_pytorch.base import modules
 
 
 class TransposeX2(nn.Sequential):
-    def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
+    ):
         super().__init__()
-        layers = [
-            nn.ConvTranspose2d(
-                in_channels, out_channels, kernel_size=4, stride=2, padding=1
-            ),
-            nn.ReLU(inplace=True),
-        ]
-
-        if use_batchnorm:
-            layers.insert(1, nn.BatchNorm2d(out_channels))
-
-        super().__init__(*layers)
+        conv = nn.ConvTranspose2d(
+            in_channels, out_channels, kernel_size=4, stride=2, padding=1
+        )
+        norm = modules.get_norm_layer(use_norm, out_channels)
+        activation = nn.ReLU(inplace=True)
+        super().__init__(conv, norm, activation)
 
 
 class DecoderBlock(nn.Module):
-    def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
+    ):
         super().__init__()
 
         self.block = nn.Sequential(
             modules.Conv2dReLU(
                 in_channels,
                 in_channels // 4,
                 kernel_size=1,
-                use_batchnorm=use_batchnorm,
-            ),
-            TransposeX2(
-                in_channels // 4, in_channels // 4, use_batchnorm=use_batchnorm
+                use_norm=use_norm,
             ),
+            TransposeX2(in_channels // 4, in_channels // 4, use_norm=use_norm),
             modules.Conv2dReLU(
                 in_channels // 4,
                 out_channels,
                 kernel_size=1,
-                use_batchnorm=use_batchnorm,
+                use_norm=use_norm,
             ),
         )
 
@@ -58,7 +61,7 @@ def __init__(
         encoder_channels: List[int],
         prefinal_channels: int = 32,
         n_blocks: int = 5,
-        use_batchnorm: bool = True,
+        use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
     ):
         super().__init__()
 
@@ -71,7 +74,11 @@ def __init__(
 
         self.blocks = nn.ModuleList(
             [
-                DecoderBlock(channels[i], channels[i + 1], use_batchnorm=use_batchnorm)
+                DecoderBlock(
+                    channels[i],
+                    channels[i + 1],
+                    use_norm=use_norm,
+                )
                 for i in range(n_blocks)
             ]
         )
```
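For example (a sketch assuming the module path shown above), `TransposeX2` now threads any supported norm between the transposed convolution and the ReLU:

```python
import torch
from segmentation_models_pytorch.decoders.linknet.decoder import TransposeX2

# kernel_size=4, stride=2, padding=1 doubles the spatial resolution;
# the dict form forwards `momentum` straight to nn.BatchNorm2d
up = TransposeX2(64, 32, use_norm={"type": "batchnorm", "momentum": 0.05})
x = torch.randn(1, 64, 16, 16)
assert up(x).shape == (1, 32, 32, 32)
```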

segmentation_models_pytorch/decoders/linknet/model.py (+30 −7)
````diff
@@ -1,4 +1,5 @@
-from typing import Any, Optional, Union
+import warnings
+from typing import Any, Dict, Optional, Union, Callable
 
 from segmentation_models_pytorch.base import (
     ClassificationHead,
@@ -29,9 +30,22 @@ class Linknet(SegmentationModel):
             Default is 5
         encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
             other pretrained weights (see table with available weights for each encoder_name)
-        decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers
-            is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption.
-            Available options are **True, False, "inplace"**
+        decoder_use_norm: Specifies normalization between Conv2D and activation.
+            Accepts the following types:
+            - **True**: Defaults to `"batchnorm"`.
+            - **False**: No normalization (`nn.Identity`).
+            - **str**: Specifies normalization type using default parameters. Available values:
+              `"batchnorm"`, `"identity"`, `"layernorm"`, `"instancenorm"`, `"inplace"`.
+            - **dict**: Fully customizable normalization settings. Structure:
+              ```python
+              {"type": <norm_type>, **kwargs}
+              ```
+              where `norm_type` corresponds to the normalization type (see above), and `kwargs` are passed directly to the normalization layer as defined in the PyTorch documentation.
+
+            **Example**:
+            ```python
+            decoder_use_norm={"type": "layernorm", "eps": 1e-2}
+            ```
         in_channels: A number of input channels for the model, default is 3 (RGB images)
         classes: A number of classes for output mask (or you can think as a number of channels of output mask)
         activation: An activation function to apply after the final convolution layer.
@@ -60,10 +74,10 @@ def __init__(
         encoder_name: str = "resnet34",
         encoder_depth: int = 5,
         encoder_weights: Optional[str] = "imagenet",
-        decoder_use_batchnorm: bool = True,
+        decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
         in_channels: int = 3,
         classes: int = 1,
-        activation: Optional[Union[str, callable]] = None,
+        activation: Optional[Union[str, Callable]] = None,
         aux_params: Optional[dict] = None,
         **kwargs: dict[str, Any],
     ):
@@ -74,6 +88,15 @@ def __init__(
                 "Encoder `{}` is not supported for Linknet".format(encoder_name)
             )
 
+        decoder_use_batchnorm = kwargs.pop("decoder_use_batchnorm", None)
+        if decoder_use_batchnorm is not None:
+            warnings.warn(
+                "The usage of decoder_use_batchnorm is deprecated. Please modify your code to use decoder_use_norm instead.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            decoder_use_norm = decoder_use_batchnorm
+
         self.encoder = get_encoder(
             encoder_name,
             in_channels=in_channels,
@@ -86,7 +109,7 @@ def __init__(
             encoder_channels=self.encoder.out_channels,
             n_blocks=encoder_depth,
             prefinal_channels=32,
-            use_batchnorm=decoder_use_batchnorm,
+            use_norm=decoder_use_norm,
         )
 
         self.segmentation_head = SegmentationHead(
````
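The deprecation path can be exercised directly. A sketch against the signature above: the old keyword travels through `**kwargs` and is remapped to `decoder_use_norm` before the decoder is built, while a warning is raised:

```python
import warnings
import segmentation_models_pytorch as smp

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Old keyword still builds a working model...
    model = smp.Linknet("resnet34", encoder_weights=None, decoder_use_batchnorm=False)

# ...but a DeprecationWarning tells callers to migrate to decoder_use_norm
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```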

segmentation_models_pytorch/decoders/manet/decoder.py (+15 −12)
```diff
@@ -1,9 +1,9 @@
+from typing import Any, Dict, List, Optional, Union
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from typing import List, Optional
-
 from segmentation_models_pytorch.base import modules as md
 
 
@@ -49,7 +49,7 @@ def __init__(
         in_channels: int,
         skip_channels: int,
         out_channels: int,
-        use_batchnorm: bool = True,
+        use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
         reduction: int = 16,
     ):
         # MFABBlock is just a modified version of SE-blocks, one for skip, one for input
@@ -60,10 +60,13 @@ def __init__(
                 in_channels,
                 kernel_size=3,
                 padding=1,
-                use_batchnorm=use_batchnorm,
+                use_norm=use_norm,
             ),
             md.Conv2dReLU(
-                in_channels, skip_channels, kernel_size=1, use_batchnorm=use_batchnorm
+                in_channels,
+                skip_channels,
+                kernel_size=1,
+                use_norm=use_norm,
             ),
         )
         reduced_channels = max(1, skip_channels // reduction)
@@ -87,14 +90,14 @@ def __init__(
             out_channels,
             kernel_size=3,
             padding=1,
-            use_batchnorm=use_batchnorm,
+            use_norm=use_norm,
         )
         self.conv2 = md.Conv2dReLU(
             out_channels,
             out_channels,
             kernel_size=3,
             padding=1,
-            use_batchnorm=use_batchnorm,
+            use_norm=use_norm,
         )
 
     def forward(
@@ -119,22 +122,22 @@ def __init__(
         in_channels: int,
         skip_channels: int,
         out_channels: int,
-        use_batchnorm: bool = True,
+        use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
     ):
         super().__init__()
         self.conv1 = md.Conv2dReLU(
             in_channels + skip_channels,
             out_channels,
             kernel_size=3,
             padding=1,
-            use_batchnorm=use_batchnorm,
+            use_norm=use_norm,
         )
         self.conv2 = md.Conv2dReLU(
             out_channels,
             out_channels,
             kernel_size=3,
             padding=1,
-            use_batchnorm=use_batchnorm,
+            use_norm=use_norm,
         )
 
     def forward(
@@ -155,7 +158,7 @@ def __init__(
         decoder_channels: List[int],
         n_blocks: int = 5,
         reduction: int = 16,
-        use_batchnorm: bool = True,
+        use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
         pab_channels: int = 64,
     ):
         super().__init__()
@@ -182,7 +185,7 @@ def __init__(
         self.center = PABBlock(head_channels, pab_channels=pab_channels)
 
         # combine decoder keyword arguments
-        kwargs = dict(use_batchnorm=use_batchnorm)  # no attention type here
+        kwargs = dict(use_norm=use_norm)  # no attention type here
         blocks = [
             MFABBlock(in_ch, skip_ch, out_ch, reduction=reduction, **kwargs)
             if skip_ch > 0
```
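The MAnet decoder blocks accept the same spellings. A hypothetical end-to-end check, assuming the `MAnet` model wrapper (among the 18 changed files but not part of this excerpt) forwards `decoder_use_norm` the same way `Linknet` does:

```python
import torch
import segmentation_models_pytorch as smp

# Assumption: MAnet's model class exposes decoder_use_norm like Linknet;
# its model.py is not shown in this excerpt
model = smp.MAnet(
    "resnet34",
    encoder_weights=None,
    decoder_use_norm={"type": "instancenorm"},
)
mask = model(torch.randn(1, 3, 64, 64))
assert mask.shape == (1, 1, 64, 64)  # default classes=1, full-resolution mask
```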
