AutoEncoder using the EfficientNet #257

Open · wants to merge 10 commits into master
182 changes: 169 additions & 13 deletions efficientnet_pytorch/model.py
100755 → 100644
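
This patch adds an EfficientNetAutoEncoder that reuses the EfficientNet encoder and mirrors it with a transposed-convolution decoder. A minimal usage sketch, following the EfficientNetAutoEncoder docstring below; the shape comments assume efficientnet-b0 at 224x224 input, where the 8-channel bottleneck over a 7x7 feature map gives a latent of size 8 * 7 * 7 = 392:

import torch
from efficientnet_pytorch.model import EfficientNetAutoEncoder

# per the class docstring; pretrained weights cover the encoder side
model = EfficientNetAutoEncoder.from_pretrained('efficientnet-b0')
model.eval()
inputs = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    reconstruction, latent = model(inputs)
# expected: reconstruction is (1, 3, 224, 224), latent is (1, 392)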
@@ -40,50 +40,56 @@ class MBConvBlock(nn.Module):
block_args (namedtuple): BlockArgs, defined in utils.py.
global_params (namedtuple): GlobalParam, defined in utils.py.
image_size (tuple or list): [image_height, image_width].
decoder_mode (bool): If True, reverse the block (transposed convolution).
decoder_output_image_size (tuple or list): Target [image_height, image_width] after the transposed depthwise convolution; used when decoder_mode is True.

References:
[1] https://arxiv.org/abs/1704.04861 (MobileNet v1)
[2] https://arxiv.org/abs/1801.04381 (MobileNet v2)
[3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
"""

def __init__(self, block_args, global_params, image_size=None, decoder_mode=False, decoder_output_image_size=None):
super().__init__()
self._block_args = block_args
self._bn_mom = 1 - global_params.batch_norm_momentum # pytorch's difference from tensorflow
self._bn_eps = global_params.batch_norm_epsilon
self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
self.id_skip = block_args.id_skip # whether to use skip connection and drop connect
self.decoder_mode = decoder_mode

# Expansion phase (Inverted Bottleneck)
inp = self._block_args.input_filters # number of input channels
oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels
if self._block_args.expand_ratio != 1:
Conv2d = get_same_padding_conv2d(image_size=image_size, transposed=self.decoder_mode)
self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
# image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size

# Depthwise convolution phase
k = self._block_args.kernel_size
s = self._block_args.stride
if self.decoder_mode:
# assert decoder_output_image_size
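# a transposed conv upsamples, so padding must be computed for the target output size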
image_size = decoder_output_image_size
Conv2d = get_same_padding_conv2d(image_size=image_size, transposed=self.decoder_mode)
self._depthwise_conv = Conv2d(
in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise
kernel_size=k, stride=s, bias=False)
self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
if not self.decoder_mode:
image_size = calculate_output_image_size(image_size, s)

# Squeeze and Excitation layer, if desired
if self.has_se:
Conv2d = get_same_padding_conv2d(image_size=(1, 1), transposed=self.decoder_mode)
num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)

# Pointwise convolution phase
final_oup = self._block_args.output_filters
Conv2d = get_same_padding_conv2d(image_size=image_size, transposed=self.decoder_mode)
self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
self._swish = MemoryEfficientSwish()
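
For reference between hunks: a hedged sketch of running a single MBConvBlock forward and then mirroring it with the new decoder_mode flag. The block index and shapes assume efficientnet-b0's default configuration, and the filter flipping follows what the EfficientNetAutoEncoder constructor below does:

import torch
from efficientnet_pytorch.model import MBConvBlock
from efficientnet_pytorch.utils import get_model_params

blocks_args, global_params = get_model_params('efficientnet-b0', None)
args = blocks_args[1]  # a stride-2 block: 16 -> 24 filters
enc = MBConvBlock(args, global_params, image_size=(112, 112))
# flip input/output filters for the mirrored decoder block
dec_args = args._replace(input_filters=args.output_filters,
                         output_filters=args.input_filters)
dec = MBConvBlock(dec_args, global_params, image_size=(56, 56),
                  decoder_mode=True, decoder_output_image_size=(112, 112))
x = torch.rand(1, 16, 112, 112)
y = enc(x)      # expected (1, 24, 56, 56)
x_hat = dec(y)  # expected (1, 16, 112, 112)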
@@ -152,9 +158,7 @@ class EfficientNet(nn.Module):
[1] https://arxiv.org/abs/1905.11946 (EfficientNet)

Example:
>>> import torch
>>> from efficientnet.model import EfficientNet
>>> inputs = torch.rand(1, 3, 224, 224)
>>> model = EfficientNet.from_pretrained('efficientnet-b0')
@@ -170,8 +174,8 @@ def __init__(self, blocks_args=None, global_params=None):
self._blocks_args = blocks_args

# Batch norm parameters
self._bn_mom = bn_mom = 1 - self._global_params.batch_norm_momentum
self._bn_eps = bn_eps = self._global_params.batch_norm_epsilon

# Get stem static or dynamic convolution depending on image size
image_size = global_params.image_size
@@ -186,6 +190,7 @@ def __init__(self, blocks_args=None, global_params=None):

# Build blocks
self._blocks = nn.ModuleList([])
self._blocks_image_size = [image_size]
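# record each block's output image size so the decoder can mirror the encoder's spatial dimensions (used by EfficientNetAutoEncoder below)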
for block_args in self._blocks_args:

# Update block input and output filters based on depth multiplier.
@@ -198,6 +203,7 @@ def __init__(self, blocks_args=None, global_params=None):
# The first block needs to take care of stride and filter size increase.
self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
image_size = calculate_output_image_size(image_size, block_args.stride)
self._blocks_image_size.append(image_size)
if block_args.num_repeat > 1: # modify block_args to keep same output size
block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
for _ in range(block_args.num_repeat - 1):
@@ -217,6 +223,10 @@ def __init__(self, blocks_args=None, global_params=None):
self._fc = nn.Linear(out_channels, self._global_params.num_classes)
self._swish = MemoryEfficientSwish()

self._image_size = image_size
self._last_block_args = block_args
self._last_out_channels = out_channels

def set_swish(self, memory_efficient=True):
"""Sets swish function as memory efficient (for training) or standard (for export).

@@ -239,6 +249,8 @@ def extract_endpoints(self, inputs):
Dictionary of last intermediate features
with reduction levels i in [1, 2, 3, 4, 5].
Example:


>>> import torch
>>> from efficientnet.model import EfficientNet
>>> inputs = torch.rand(1, 3, 224, 224)
@@ -284,7 +296,6 @@ def extract_features(self, inputs):
"""
# Stem
x = self._swish(self._bn0(self._conv_stem(inputs)))

# Blocks
for idx, block in enumerate(self._blocks):
drop_connect_rate = self._global_params.drop_connect_rate
@@ -294,7 +305,6 @@ def extract_features(self, inputs):

# Head
x = self._swish(self._bn1(self._conv_head(x)))

return x

def forward(self, inputs):
@@ -309,6 +319,7 @@ def forward(self, inputs):
"""
# Convolution layers
x = self.extract_features(inputs)

# Pooling and final linear layer
x = self._avg_pooling(x)
if self._global_params.include_top:
@@ -413,3 +424,148 @@ def _change_in_channels(self, in_channels):
Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size)
out_channels = round_filters(32, self._global_params)
self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)

class EfficientNetAutoEncoder(EfficientNet):
"""EfficientNet AutoEncoder model.
Most easily loaded with the .from_name or .from_pretrained methods.

Args:
blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks.
global_params (namedtuple): A set of GlobalParams shared between blocks.

References:
[1] https://arxiv.org/abs/1905.11946 (EfficientNet)

Example:


>>> import torch
>>> from efficientnet.model import EfficientNetAutoEncoder
>>> inputs = torch.rand(1, 3, 224, 224)
>>> model = EfficientNetAutoEncoder.from_pretrained('efficientnet-b0')
>>> model.eval()
>>> ae_output, latent_fc_output = model(inputs)
"""

def __init__(self, blocks_args=None, global_params=None):
super().__init__(blocks_args=blocks_args, global_params=global_params)
bn_mom = self._bn_mom
bn_eps = self._bn_eps
image_size = self._image_size
block_args = self._last_block_args

Conv2d = get_same_padding_conv2d(image_size=image_size)
self._feature_downsample = Conv2d(self._last_out_channels, 8, kernel_size=1, bias=False)
self._downsample_bn = nn.BatchNorm2d(num_features=8, momentum=bn_mom, eps=bn_eps)
self._feature_upsample = Conv2d(8, self._last_out_channels, kernel_size=1, bias=False)
self._upsample_bn = nn.BatchNorm2d(num_features=self._last_out_channels, momentum=bn_mom, eps=bn_eps)
self.feature_size = 8 * image_size[0]**2
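# e.g. a 224x224 efficientnet-b0 input yields a 7x7 head output, so feature_size = 8 * 7 * 7 = 392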

# EfficientNet Decoder
# use dynamic image size for decoder
TransposedConv2d = get_same_padding_conv2d(image_size=image_size, transposed=True)

# Stem
# self._decoder_conv_stem mirrors self._conv_head
in_channels = round_filters(1280, self._global_params)
out_channels = block_args.output_filters # output of final block
self._decoder_conv_stem = TransposedConv2d(in_channels, out_channels, kernel_size=1, bias=False)
image_size = calculate_output_image_size(image_size, 1, transposed=True)
self._decoder_bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
# image_size = calculate_output_image_size(image_size, 2)

# Build blocks
self._decoder_blocks = nn.ModuleList([])
assert len(self._blocks_image_size) == len(self._blocks_args) + 1
self._blocks_image_size = list(reversed(self._blocks_image_size))
for i, block_args in enumerate(reversed(self._blocks_args)):
image_size = self._blocks_image_size[i]
# Update block input and output filters based on depth multiplier.
# input/output filters are flipped here to support deconvolution
block_args = block_args._replace(
input_filters=round_filters(block_args.output_filters, self._global_params),
output_filters=round_filters(block_args.input_filters, self._global_params),
num_repeat=round_repeats(block_args.num_repeat, self._global_params)
)
# The first block needs to take care of stride and filter size changes.
self._decoder_blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size,
decoder_mode=True, decoder_output_image_size=self._blocks_image_size[i+1]))
image_size = self._blocks_image_size[i+1]
if block_args.num_repeat > 1: # modify block_args to keep same output size
block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
for _ in range(block_args.num_repeat - 1):
self._decoder_blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size,
decoder_mode=True, decoder_output_image_size=image_size))
# image_size = calculate_output_image_size(image_size, block_args.stride) # stride = 1

# Head
in_channels = round_filters(32, self._global_params) # mirrors the encoder stem's output channels
out_channels = 3 # rgb
TransposedConv2d = get_same_padding_conv2d(image_size=global_params.image_size, transposed=True)
self._decoder_conv_head = TransposedConv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
self._decoder_bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)

def extract_features(self, inputs):
"""use convolution layer to extract feature,
with additional down-sample layer to get 1280 hidden feature.

Args:
inputs (tensor): Input tensor.

Returns:
Down-sampled (8-channel) output of the final
convolution layer in the efficientnet model.
"""
x = super().extract_features(inputs)
x = self._swish(self._downsample_bn(self._feature_downsample(x)))
return x


def decode_features(self, inputs):
"""decoder portion of this autoencoder.

Args:
inputs (tensor): Input tensor to the decoder,
usually from self.extract_features

Returns:
Reconstruction of the original input tensor.
"""
# upsample
x = self._swish(self._upsample_bn(self._feature_upsample(inputs)))
# Stem
x = self._swish(self._decoder_bn0(self._decoder_conv_stem(x)))
# Blocks
for idx, block in enumerate(self._decoder_blocks):
drop_connect_rate = self._global_params.drop_connect_rate
if drop_connect_rate:
# scale drop connect rate
drop_connect_rate *= float(idx) / len(self._decoder_blocks)
x = block(x, drop_connect_rate=drop_connect_rate)

# Head
x = self._swish(self._decoder_bn1(self._decoder_conv_head(x)))
return x


def forward(self, inputs):
"""EfficientNet AutoEncoder's forward function.
Calls extract_features to extract features,
then calls decode_features to reconstruct the original inputs.

Args:
inputs (tensor): Input tensor.

Returns:
(AE output tensor, latent representation tensor)
"""
# Convolution layers
x = self.extract_features(inputs)

# Flatten to the latent representation
latent_rep = x.flatten(start_dim=1)

# Deconvolution - decoder
x = self.decode_features(x)
return x, latent_rep
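
Not part of the diff: a hedged sketch of one training step for this autoencoder, assuming the forward signature above (reconstruction plus flattened latent). The MSE reconstruction loss and Adam optimizer are illustrative choices, not taken from the PR:

import torch
import torch.nn.functional as F
from efficientnet_pytorch.model import EfficientNetAutoEncoder

model = EfficientNetAutoEncoder.from_name('efficientnet-b0')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

images = torch.rand(4, 3, 224, 224)       # stand-in batch
reconstruction, latent = model(images)     # assumes decoder reproduces the input resolution
loss = F.mse_loss(reconstruction, images)  # pixel-wise reconstruction loss
optimizer.zero_grad()
loss.backward()
optimizer.step()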