# Deprecate `use_batchnorm` in favor of generalized `use_norm` parameter (#1095)
Changes from 6 commits
**`segmentation_models_pytorch/base/modules.py`**

```diff
@@ -1,3 +1,5 @@
+from typing import Any, Dict, Union
+
 import torch
 import torch.nn as nn
```
```diff
@@ -6,6 +8,64 @@
 except ImportError:
     InPlaceABN = None
 
 
+def get_norm_layer(use_norm: Union[bool, str, Dict[str, Any]], out_channels: int) -> nn.Module:
+    supported_norms = ("inplace", "batchnorm", "identity", "layernorm", "instancenorm")
+
+    if use_norm is True:
+        norm_params = {"type": "batchnorm"}
+    elif use_norm is False:
+        norm_params = {"type": "identity"}
+    elif isinstance(use_norm, str):
+        norm_str = use_norm.lower()
+        if norm_str == "inplace":
+            norm_params = {
+                "type": "inplace",
+                "activation": "leaky_relu",
+                "activation_param": 0.0,
+            }
+        elif norm_str in ("batchnorm", "identity", "layernorm", "instancenorm"):
+            norm_params = {"type": norm_str}
+        else:
+            raise ValueError(
+                f"Unrecognized normalization type string provided: {use_norm}. "
+                f"Should be in {supported_norms}"
+            )
+    elif isinstance(use_norm, dict):
+        norm_params = use_norm
+    else:
+        raise ValueError(
+            "use_norm must be a dictionary, boolean, or string. "
+            "Please refer to the documentation."
+        )
+
+    if "type" not in norm_params:
+        raise ValueError(
+            f"Malformed dictionary given in use_norm: {use_norm}. Should contain key 'type'."
+        )
+    if norm_params["type"] not in supported_norms:
+        raise ValueError(
+            f"Unrecognized normalization type string provided: {use_norm}. "
+            f"Should be in {supported_norms}"
+        )
+
+    norm_type = norm_params["type"]
+    extra_kwargs = {k: v for k, v in norm_params.items() if k != "type"}
+
+    if norm_type == "inplace" and InPlaceABN is None:
+        raise RuntimeError(
+            "In order to use `use_batchnorm='inplace'` or `use_norm='inplace'` "
+            "the inplace_abn package must be installed. "
+            "To install see: https://github.com/mapillary/inplace_abn"
+        )
+    elif norm_type == "inplace":
+        norm = InPlaceABN(out_channels, **extra_kwargs)
+    elif norm_type == "batchnorm":
+        norm = nn.BatchNorm2d(out_channels, **extra_kwargs)
+    elif norm_type == "identity":
+        norm = nn.Identity()
+    elif norm_type == "layernorm":
+        norm = nn.LayerNorm(out_channels, **extra_kwargs)
+    elif norm_type == "instancenorm":
+        norm = nn.InstanceNorm2d(out_channels, **extra_kwargs)
+    else:
+        raise ValueError(f"Unrecognized normalization type: {norm_type}")
+
+    return norm
+
+
 class Conv2dReLU(nn.Sequential):
     def __init__(
```

> **Reviewer** (on the error messages): Let's have a more descriptive error here, I mean specify what kind of string and dict structure it should be.
>
> **Author:** Made a proposal.
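For reference, this is how the accepted values resolve in practice. A minimal sketch, assuming this branch of the package is installed; the import path follows the file shown above:

```python
from segmentation_models_pytorch.base.modules import get_norm_layer

print(get_norm_layer(True, 64))            # BatchNorm2d(64, eps=1e-05, momentum=0.1, ...)
print(get_norm_layer(False, 64))           # Identity()
print(get_norm_layer("instancenorm", 64))  # InstanceNorm2d(64, ...)
print(get_norm_layer({"type": "layernorm", "eps": 1e-2}, 64))  # LayerNorm((64,), eps=0.01, ...)

# A dict without a "type" key raises:
# ValueError: Malformed dictionary given in use_norm: {'eps': 0.01}. Should contain key 'type'.
```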
```diff
@@ -15,35 +75,24 @@ def __init__(
         kernel_size,
         padding=0,
         stride=1,
-        use_batchnorm=True,
+        use_norm="batchnorm",
     ):
-        if use_batchnorm == "inplace" and InPlaceABN is None:
-            raise RuntimeError(
-                "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. "
-                + "To install see: https://github.com/mapillary/inplace_abn"
-            )
-
+        norm = get_norm_layer(use_norm, out_channels)
         conv = nn.Conv2d(
             in_channels,
             out_channels,
             kernel_size,
             stride=stride,
             padding=padding,
-            bias=not (use_batchnorm),
+            # Norms that subtract a mean make the conv bias redundant.
+            bias=norm._get_name() not in ("BatchNorm2d", "InPlaceABN"),
         )
-        relu = nn.ReLU(inplace=True)
 
-        if use_batchnorm == "inplace":
-            bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0)
+        # InPlaceABN fuses its own activation, so the separate ReLU is skipped.
+        if norm._get_name() == "InPlaceABN":
             relu = nn.Identity()
-
-        elif use_batchnorm and use_batchnorm != "inplace":
-            bn = nn.BatchNorm2d(out_channels)
-
         else:
-            bn = nn.Identity()
+            relu = nn.ReLU(inplace=True)
 
-        super(Conv2dReLU, self).__init__(conv, bn, relu)
+        super(Conv2dReLU, self).__init__(conv, norm, relu)
```
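A quick smoke test of the refactored block, again a sketch against this branch rather than a released API:

```python
import torch

from segmentation_models_pytorch.base.modules import Conv2dReLU

# "type" selects the norm layer; the remaining keys are forwarded to its constructor.
block = Conv2dReLU(
    in_channels=3,
    out_channels=16,
    kernel_size=3,
    padding=1,
    use_norm={"type": "instancenorm", "affine": True},
)
print(block(torch.randn(2, 3, 32, 32)).shape)  # torch.Size([2, 16, 32, 32])
```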
**`segmentation_models_pytorch/decoders/linknet/model.py`**
```diff
@@ -1,4 +1,5 @@
-from typing import Any, Optional, Union
+import warnings
+from typing import Any, Dict, Optional, Union
 
 from segmentation_models_pytorch.base import (
     ClassificationHead,
```
````diff
@@ -29,9 +30,22 @@ class Linknet(SegmentationModel):
             Default is 5
         encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
             other pretrained weights (see table with available weights for each encoder_name)
-        decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers
-            is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption.
-            Available options are **True, False, "inplace"**
+        decoder_use_norm: Specifies normalization between Conv2D and activation.
+            Accepts the following types:
+            - **True**: Defaults to `"batchnorm"`.
+            - **False**: No normalization (`nn.Identity`).
+            - **str**: Specifies normalization type using default parameters. Available values:
+              `"batchnorm"`, `"identity"`, `"layernorm"`, `"instancenorm"`, `"inplace"`.
+            - **dict**: Fully customizable normalization settings. Structure:
+              ```python
+              {"type": <norm_type>, **kwargs}
+              ```
+              where `"type"` selects the normalization type (see above) and the remaining
+              `kwargs` are passed directly to the normalization layer as defined in the
+              PyTorch documentation.
+
+            **Example**:
+            ```python
+            use_norm={"type": "layernorm", "eps": 1e-2}
+            ```
         in_channels: A number of input channels for the model, default is 3 (RGB images)
         classes: A number of classes for output mask (or you can think as a number of channels of output mask)
         activation: An activation function to apply after the final convolution layer.
````

> **Reviewer:** Thanks for the detailed docstring, really appreciate it!
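Following the docstring, model construction would look roughly like this (a sketch; `smp` is the usual package alias, and `encoder_weights=None` just avoids a weights download):

```python
import segmentation_models_pytorch as smp

# True is shorthand for "batchnorm", so these two are equivalent.
model_a = smp.Linknet("resnet34", encoder_weights=None, decoder_use_norm=True)
model_b = smp.Linknet("resnet34", encoder_weights=None, decoder_use_norm="batchnorm")

# The dict form forwards extra kwargs to the chosen norm layer.
model_c = smp.Linknet(
    "resnet34",
    encoder_weights=None,
    decoder_use_norm={"type": "instancenorm", "affine": True},
)
```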
```diff
@@ -60,7 +74,7 @@ def __init__(
         encoder_name: str = "resnet34",
         encoder_depth: int = 5,
         encoder_weights: Optional[str] = "imagenet",
-        decoder_use_batchnorm: bool = True,
+        decoder_use_norm: Union[bool, str, Dict[str, Any]] = True,
         in_channels: int = 3,
         classes: int = 1,
         activation: Optional[Union[str, callable]] = None,
```
```diff
@@ -82,11 +96,20 @@ def __init__(
             **kwargs,
         )
 
+        decoder_use_batchnorm = kwargs.pop("decoder_use_batchnorm", None)
+        if decoder_use_batchnorm is not None:
+            warnings.warn(
+                "The usage of decoder_use_batchnorm is deprecated. "
+                "Please use decoder_use_norm instead.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            decoder_use_norm = decoder_use_batchnorm
+
         self.decoder = LinknetDecoder(
             encoder_channels=self.encoder.out_channels,
             n_blocks=encoder_depth,
             prefinal_channels=32,
-            use_batchnorm=decoder_use_batchnorm,
+            use_norm=decoder_use_norm,
         )
 
         self.segmentation_head = SegmentationHead(
```
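The remapping itself is easy to verify in isolation. The helper below is hypothetical (not part of the library) and only mirrors the pattern used in `__init__` above:

```python
import warnings

def resolve_decoder_norm(decoder_use_norm=True, **kwargs):
    # Hypothetical stand-in for the deprecation shim in Linknet.__init__:
    # the legacy keyword arrives via **kwargs and overrides the new one.
    decoder_use_batchnorm = kwargs.pop("decoder_use_batchnorm", None)
    if decoder_use_batchnorm is not None:
        warnings.warn(
            "The usage of decoder_use_batchnorm is deprecated. "
            "Please use decoder_use_norm instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        decoder_use_norm = decoder_use_batchnorm
    return decoder_use_norm

print(resolve_decoder_norm(decoder_use_batchnorm="inplace"))  # warns, returns "inplace"
print(resolve_decoder_norm(decoder_use_norm="layernorm"))     # no warning, "layernorm"
```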
> **Reviewer:** Nice!