-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from rubythalib33/devel/v0.2.0
Devel/v0.2.0
- Loading branch information
Showing
10 changed files
with
464 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -127,3 +127,4 @@ dmypy.json | |
|
||
# Pyre type checker | ||
.pyre/ | ||
.vscode |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# Inside of setup.cfg | ||
[metadata] | ||
description-file = README.md |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
|
||
setup( | ||
name='torchvision_3d', | ||
version='0.1.0', | ||
version='0.2.0', | ||
description='3D CNN for PyTorch with imagenet pretrained models', | ||
author='Ruby Abdullah', | ||
author_email='[email protected]', | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,6 @@ | ||
# Public model namespace for torchvision_3d.
# FIX: removed a duplicated `resnet` wildcard import (it appeared twice).
from torchvision_3d.models.alexnet import *
from torchvision_3d.models.vgg import *
from torchvision_3d.models.resnet import *
from torchvision_3d.models.densenet import *
from torchvision_3d.models.squeezenet import *
from torchvision_3d.models.mobilenetv2 import *
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import torch.utils.checkpoint as cp | ||
from torch import Tensor | ||
|
||
class _DenseLayer3D(nn.Sequential): | ||
def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient=False, model=None): | ||
super(_DenseLayer3D, self).__init__() | ||
self.add_module('norm1', nn.BatchNorm3d(num_input_features, affine=True, eps=1e-05, momentum=0.1, track_running_stats=True)) | ||
self.add_module('relu1', nn.ReLU(inplace=True)) | ||
conv1 = nn.Conv3d(num_input_features, bn_size * | ||
growth_rate, kernel_size=1, stride=1, bias=False) | ||
conv1.weight.data = torch.stack([model.conv1.weight.data] , dim=2) if model is not None else conv1.weight.data | ||
self.add_module('conv1', conv1) | ||
self.add_module('norm2', nn.BatchNorm3d(bn_size * | ||
growth_rate, affine=True, eps=1e-05, momentum=0.1, track_running_stats=True)) | ||
self.add_module('relu2', nn.ReLU(inplace=True)) | ||
conv2 = nn.Conv3d(bn_size * growth_rate, growth_rate, kernel_size=(1,3,3), stride=1, padding=(0,1,1), bias=False) | ||
conv2.weight.data = torch.stack([model.conv2.weight.data] , dim=2) if model is not None else conv2.weight.data | ||
self.add_module('conv2', conv2) | ||
|
||
self.drop_rate = float(drop_rate) | ||
self.memory_efficient = memory_efficient | ||
self.bn_size = bn_size | ||
|
||
def bn_function(self, inputs): | ||
concated_features = torch.cat(inputs, 1) | ||
bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features))) # noqa: T484 | ||
return bottleneck_output | ||
|
||
# todo: rewrite when torchscript supports any | ||
def any_requires_grad(self, input) -> bool: | ||
for tensor in input: | ||
if tensor.requires_grad: | ||
return True | ||
return False | ||
|
||
@torch.jit.unused # noqa: T484 | ||
def call_checkpoint_bottleneck(self, input): | ||
def closure(*inputs): | ||
return self.bn_function(inputs) | ||
|
||
return cp.checkpoint(closure, *input) | ||
|
||
def forward(self, input): # noqa: F811 | ||
if isinstance(input, Tensor): | ||
prev_features = [input] | ||
else: | ||
prev_features = input | ||
|
||
if self.memory_efficient and self.any_requires_grad(prev_features): | ||
if torch.jit.is_scripting(): | ||
raise Exception("Memory Efficient not supported in JIT") | ||
|
||
bottleneck_output = self.call_checkpoint_bottleneck(prev_features) | ||
else: | ||
bottleneck_output = self.bn_function(prev_features) | ||
|
||
new_features = self.conv2(self.relu2(self.norm2(bottleneck_output))) | ||
if self.drop_rate > 0: | ||
new_features = F.dropout(new_features, p=self.drop_rate, | ||
training=self.training) | ||
return new_features | ||
|
||
class _DenseBlock3D(nn.Module):
    """A stack of ``_DenseLayer3D`` modules with dense (concatenating) connectivity.

    When ``model`` (a pretrained 2D dense block) is given, each new 3D layer
    is paired with the corresponding pretrained child so its weights can be
    inflated.
    """

    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, memory_efficient=False, model=None):
        super(_DenseBlock3D, self).__init__()
        if model is not None:
            pretrained_children = list(model.children())[:num_layers]
        else:
            pretrained_children = [None] * num_layers
        for idx, source in enumerate(pretrained_children):
            layer = _DenseLayer3D(
                num_input_features + idx * growth_rate,
                growth_rate=growth_rate,
                bn_size=bn_size,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient,
                model=source,
            )
            self.add_module('denselayer%d' % (idx + 1), layer)

    def forward(self, init_features):
        """Feed the growing list of feature maps through every layer in order."""
        features = [init_features]
        for _, layer in self.named_children():
            features.append(layer(features))
        return torch.cat(features, 1)
|
||
class _Transition(nn.Sequential): | ||
def __init__(self, num_input_features: int, num_output_features: int, model=None) -> None: | ||
super(_Transition, self).__init__() | ||
self.add_module('norm', nn.BatchNorm3d(num_input_features)) | ||
self.add_module('relu', nn.ReLU(inplace=True)) | ||
conv = nn.Conv3d(num_input_features, num_output_features, | ||
kernel_size=1, stride=1, bias=False) | ||
conv.weight.data = torch.stack([model.weight.data] , dim=2) if model is not None else conv.weight.data | ||
self.add_module('conv', conv) | ||
|
||
|
||
self.add_module('pool', nn.AvgPool3d(kernel_size=(1,2,2), stride=(1,2,2))) | ||
|
||
class DenseNet3D(nn.Module):
    """3D DenseNet built by inflating a torchvision 2D DenseNet.

    Every 2D module in the pretrained feature stack is replaced by a 3D
    counterpart with a unit temporal kernel, so a depth-1 clip reproduces
    the 2D network's computation.

    Args:
        type: one of 'densenet121', 'densenet161', 'densenet169',
            'densenet201'. (Parameter name kept for backward compatibility,
            even though it shadows the builtin.)
        pretrained: if True, copy the ImageNet-pretrained conv weights into
            the inflated 3D convolutions.

    Raises:
        NotImplementedError: if ``type`` is not a supported variant.
    """

    def __init__(self, type, pretrained=False):
        super().__init__()

        if type == 'densenet121':
            from torchvision.models import densenet121
            model_instance = densenet121(pretrained=pretrained)
        elif type == 'densenet161':
            from torchvision.models import densenet161
            model_instance = densenet161(pretrained=pretrained)
        elif type == 'densenet169':
            from torchvision.models import densenet169
            model_instance = densenet169(pretrained=pretrained)
        elif type == 'densenet201':
            from torchvision.models import densenet201
            model_instance = densenet201(pretrained=pretrained)
        else:
            raise NotImplementedError('type only support for densenet121, densenet161, densenet169, densenet201')

        # Temporarily hold the 2D stack; init_features() replaces it with
        # the inflated 3D equivalent.
        self.features = model_instance.features
        self.pretrained = pretrained
        self.features = self.init_features()

    def init_features(self):
        """Convert the stored 2D feature stack into an nn.Sequential of 3D modules.

        NOTE(review): only conv weights are copied from the pretrained model;
        BatchNorm affine parameters and running stats are re-initialised —
        confirm this is intended.
        """
        features = []
        for model in self.features.children():
            if isinstance(model, nn.Conv2d):
                model_temp = nn.Conv3d(in_channels=model.in_channels, out_channels=model.out_channels, kernel_size=(1, *model.kernel_size), stride=(1, *model.stride), padding=(0, *model.padding), bias=False)
                if self.pretrained:
                    # Inflate the 2D kernel by adding a unit depth axis.
                    model_temp.weight.data = torch.stack([model.weight.data], dim=2)
                features.append(model_temp)
            elif isinstance(model, nn.MaxPool2d):
                # Pool spatially only; leave the temporal dimension intact.
                model_temp = nn.MaxPool3d(kernel_size=[1, model.kernel_size, model.kernel_size], stride=[1, model.stride, model.stride], padding=[0, model.padding, model.padding])
                features.append(model_temp)
            elif isinstance(model, nn.ReLU):
                features.append(model)
            elif isinstance(model, nn.BatchNorm2d):
                features.append(nn.BatchNorm3d(num_features=model.num_features))
            elif model._get_name() == '_DenseBlock':
                # Recover the block's hyper-parameters from its first layer.
                num_layers = len(model)
                num_input_features = model.denselayer1.conv1.in_channels
                growth_rate = model.denselayer1.conv2.out_channels
                bn_size = model.denselayer1.conv2.in_channels // growth_rate
                drop_rate = model.denselayer1.drop_rate
                memory_efficient = model.denselayer1.memory_efficient
                layer = _DenseBlock3D(num_layers, num_input_features, bn_size, growth_rate, drop_rate, memory_efficient,
                                      model=model if self.pretrained else None)
                features.append(layer)
            elif model._get_name() == '_Transition':
                num_input_features = model.conv.in_channels
                num_output_features = model.conv.out_channels
                # BUG FIX: the pretrained transition conv was previously
                # discarded — pass the 2D conv so _Transition inflates it.
                layer = _Transition(num_input_features, num_output_features,
                                    model=model.conv if self.pretrained else None)
                features.append(layer)

        return nn.Sequential(*features)

    def forward(self, x):
        """Run the inflated feature stack on ``x`` of shape (N, C, T, H, W)."""
        return self.features(x)
|
||
|
||
# Smoke test: run one depth-1 RGB clip through the inflated network.
if __name__ == '__main__':
    # Input layout is (batch, channels, depth/frames, height, width);
    # depth 1 mimics a single 2D image.
    inputs = torch.randn(1, 3, 1, 224, 224)
    # NOTE: pretrained=True downloads ImageNet weights via torchvision.
    model = DenseNet3D('densenet121', pretrained=True)
    outputs = model(inputs)
    print(outputs.shape)
Oops, something went wrong.