diff --git a/.gitignore b/.gitignore
index b6e4761..cd97309 100644
--- a/.gitignore
+++ b/.gitignore
@@ -127,3 +127,4 @@ dmypy.json
 
 # Pyre type checker
 .pyre/
+.vscode
\ No newline at end of file
diff --git a/README.md b/README.md
index 3a28977..0b92810 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,8 @@
 # 3D-Torchvision
 3D torchvision with ImageNet Pretrained
 
+![](./assets/logo.png)
+
 ## it can be used for:
 1. Video Embedding
 2. Action Recognition
@@ -13,20 +15,36 @@ cd 3D-Torchvision
 python setup.py install
 ```
 
+## Update v0.2.0
+1. Added more ResNet3D variants: ```resnext50_32x4d, resnext101_32x8d, wide_resnet50_2, wide_resnet101_2```
+2. Added support for DenseNet3D, SqueezeNet3D and MobileNetV2_3D
+
 ## Model Ready
 ```python
 # 1. AlexNet3D:
 from torchvision_3d.models import AlexNet3D:
-model = AlexNet3D()
+model = AlexNet3D(pretrained=True)
 
 # 2. VGG3D
 from torchvision_3d.models import VGG3D
-model = VGG3D(type='vgg11') #type can be vgg11, vgg16, vgg19, vgg11_bn, vgg16_bn, vgg19_bn
+model = VGG3D(type='vgg11', pretrained=True) # type can be vgg11, vgg16, vgg19, vgg11_bn, vgg16_bn, vgg19_bn
 
 #3. ResNet3D
 from torchvision_3d.models import ResNet3D
-model = ResNet3D(type='resnet50') #type can be resnet18, resnet34, resnet50, resnet101, resnet152
+model = ResNet3D(type='resnet50', pretrained=True) # type can be resnet18, resnet34, resnet50, resnet101, resnet152, resnext50_32x4d, resnext101_32x8d, wide_resnet50_2, wide_resnet101_2
+
+# 4. DenseNet3D
+from torchvision_3d.models import DenseNet3D
+model = DenseNet3D(type='densenet121', pretrained=True) # type can be densenet121, densenet161, densenet169, densenet201
+
+# 5. MobileNetV2_3D
+from torchvision_3d.models import MobileNetV2_3D
+model = MobileNetV2_3D(pretrained=True)
+
+# 6. SqueezeNet3D
+from torchvision_3d.models import SqueezeNet3D
+model = SqueezeNet3D(type='squeezenet1_0', pretrained=True) # type can be squeezenet1_0, squeezenet1_1
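+
+# Input format: every model takes a 5D clip shaped (batch, channels, frames, height, width).
+# A minimal sketch, reusing the model built above:
+import torch
+clip = torch.randn(1, 3, 8, 224, 224)  # an 8-frame RGB clip
+features = model(clip)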
 ```
\ No newline at end of file
diff --git a/assets/logo.png b/assets/logo.png
new file mode 100644
index 0000000..a853590
Binary files /dev/null and b/assets/logo.png differ
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..9d5f797
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,3 @@
+# Inside of setup.cfg
+[metadata]
+description-file = README.md
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 092771a..ce7ffe7 100644
--- a/setup.py
+++ b/setup.py
@@ -2,7 +2,7 @@
 setup(
     name='torchvision_3d',
-    version='0.1.0',
+    version='0.2.0',
     description='3D CNN for PyTorch with imagenet pretrained models',
     author='Ruby Abdullah',
     author_email='rubyabdullah14@gmail.com',
diff --git a/torchvision_3d/models/__init__.py b/torchvision_3d/models/__init__.py
index 9df073e..2d20150 100644
--- a/torchvision_3d/models/__init__.py
+++ b/torchvision_3d/models/__init__.py
@@ -1,3 +1,6 @@
 from torchvision_3d.models.alexnet import *
 from torchvision_3d.models.vgg import *
-from torchvision_3d.models.resnet import *
\ No newline at end of file
+from torchvision_3d.models.resnet import *
+from torchvision_3d.models.densenet import *
+from torchvision_3d.models.squeezenet import *
+from torchvision_3d.models.mobilenetv2 import *
\ No newline at end of file
diff --git a/torchvision_3d/models/densenet.py b/torchvision_3d/models/densenet.py
new file mode 100644
index 0000000..21b22cc
--- /dev/null
+++ b/torchvision_3d/models/densenet.py
@@ -0,0 +1,182 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as cp
+from torch import Tensor
+
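+# A note on weight inflation, used throughout this repo: a pretrained 2D
+# kernel of shape (out, in, kH, kW) becomes a 3D kernel of shape
+# (out, in, 1, kH, kW) by stacking along a new depth axis:
+#
+#     w3d = torch.stack([w2d], dim=2)
+#
+# With a temporal kernel size of 1, the 3D conv then reproduces the 2D
+# computation frame by frame. Note that BatchNorm layers are re-initialized
+# rather than copied, so outputs will not match the 2D model exactly.
+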
+class _DenseLayer3D(nn.Sequential):
+    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient=False, model=None):
+        super(_DenseLayer3D, self).__init__()
+        self.add_module('norm1', nn.BatchNorm3d(num_input_features))
+        self.add_module('relu1', nn.ReLU(inplace=True))
+        conv1 = nn.Conv3d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)
+        if model is not None:
+            conv1.weight.data = torch.stack([model.conv1.weight.data], dim=2)
+        self.add_module('conv1', conv1)
+        self.add_module('norm2', nn.BatchNorm3d(bn_size * growth_rate))
+        self.add_module('relu2', nn.ReLU(inplace=True))
+        conv2 = nn.Conv3d(bn_size * growth_rate, growth_rate, kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1), bias=False)
+        if model is not None:
+            conv2.weight.data = torch.stack([model.conv2.weight.data], dim=2)
+        self.add_module('conv2', conv2)
+
+        self.drop_rate = float(drop_rate)
+        self.memory_efficient = memory_efficient
+        self.bn_size = bn_size
+
+    def bn_function(self, inputs):
+        concated_features = torch.cat(inputs, 1)
+        bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features)))  # noqa: T484
+        return bottleneck_output
+
+    # todo: rewrite when torchscript supports any
+    def any_requires_grad(self, input) -> bool:
+        for tensor in input:
+            if tensor.requires_grad:
+                return True
+        return False
+
+    @torch.jit.unused  # noqa: T484
+    def call_checkpoint_bottleneck(self, input):
+        def closure(*inputs):
+            return self.bn_function(inputs)
+
+        return cp.checkpoint(closure, *input)
+
+    def forward(self, input):  # noqa: F811
+        if isinstance(input, Tensor):
+            prev_features = [input]
+        else:
+            prev_features = input
+
+        if self.memory_efficient and self.any_requires_grad(prev_features):
+            if torch.jit.is_scripting():
+                raise Exception("Memory Efficient not supported in JIT")
+            bottleneck_output = self.call_checkpoint_bottleneck(prev_features)
+        else:
+            bottleneck_output = self.bn_function(prev_features)
+
+        new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
+        if self.drop_rate > 0:
+            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
+        return new_features
+
+class _DenseBlock3D(nn.Module):
+    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, memory_efficient=False, model=None):
+        super(_DenseBlock3D, self).__init__()
+        # when a pretrained 2D block is given, pair each 3D layer with its 2D counterpart
+        model_children = list(model.children()) if model is not None else [None] * num_layers
+        for i, model_child in enumerate(model_children[:num_layers]):
+            layer = _DenseLayer3D(
+                num_input_features + i * growth_rate,
+                growth_rate=growth_rate,
+                bn_size=bn_size,
+                drop_rate=drop_rate,
+                memory_efficient=memory_efficient,
+                model=model_child
+            )
+            self.add_module('denselayer%d' % (i + 1), layer)
+
+    def forward(self, init_features):
+        features = [init_features]
+        for name, layer in self.named_children():
+            new_features = layer(features)
+            features.append(new_features)
+        return torch.cat(features, 1)
+
+class _Transition(nn.Sequential):
+    def __init__(self, num_input_features: int, num_output_features: int, model=None) -> None:
+        super(_Transition, self).__init__()
+        self.add_module('norm', nn.BatchNorm3d(num_input_features))
+        self.add_module('relu', nn.ReLU(inplace=True))
+        conv = nn.Conv3d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False)
+        if model is not None:
+            # model is the pretrained 2D transition conv
+            conv.weight.data = torch.stack([model.weight.data], dim=2)
+        self.add_module('conv', conv)
+        self.add_module('pool', nn.AvgPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)))
+
+class DenseNet3D(nn.Module):
+    def __init__(self, type, pretrained=False):
+        super().__init__()
+
+        if type == 'densenet121':
+            from torchvision.models import densenet121
+            model_instance = densenet121(pretrained=pretrained)
+        elif type == 'densenet161':
+            from torchvision.models import densenet161
+            model_instance = densenet161(pretrained=pretrained)
+        elif type == 'densenet169':
+            from torchvision.models import densenet169
+            model_instance = densenet169(pretrained=pretrained)
+        elif type == 'densenet201':
+            from torchvision.models import densenet201
+            model_instance = densenet201(pretrained=pretrained)
+        else:
+            raise NotImplementedError('type only supports densenet121, densenet161, densenet169, densenet201')
+
+        self.features = model_instance.features
+        self.pretrained = pretrained
+        self.features = self.init_features()
+
+    def init_features(self):
+        features = []
+        for model in self.features.children():
+            if isinstance(model, nn.Conv2d):
+                model_temp = nn.Conv3d(in_channels=model.in_channels, out_channels=model.out_channels, kernel_size=(1, *model.kernel_size), stride=(1, *model.stride), padding=(0, *model.padding), bias=False)
+                if self.pretrained:
+                    model_temp.weight.data = torch.stack([model.weight.data], dim=2)
+                features.append(model_temp)
+            elif isinstance(model, nn.MaxPool2d):
+                model_temp = nn.MaxPool3d(kernel_size=[1, model.kernel_size, model.kernel_size], stride=[1, model.stride, model.stride], padding=[0, model.padding, model.padding])
+                features.append(model_temp)
+            elif isinstance(model, nn.ReLU):
+                features.append(model)
+            elif isinstance(model, nn.BatchNorm2d):
+                model_temp = nn.BatchNorm3d(num_features=model.num_features)
+                features.append(model_temp)
+            elif model._get_name() == '_DenseBlock':
+                num_layers = len(model)
+                num_input_features = model.denselayer1.conv1.in_channels
+                growth_rate = model.denselayer1.conv2.out_channels
+                bn_size = model.denselayer1.conv2.in_channels // growth_rate
+                drop_rate = model.denselayer1.drop_rate
+                memory_efficient = model.denselayer1.memory_efficient
+                layer = _DenseBlock3D(num_layers, num_input_features, bn_size, growth_rate, drop_rate, memory_efficient, model=model if self.pretrained else None)
+                features.append(layer)
+            elif model._get_name() == '_Transition':
+                num_input_features = model.conv.in_channels
+                num_output_features = model.conv.out_channels
+                # pass the pretrained 2D transition conv so its weights are inflated as well
+                layer = _Transition(num_input_features, num_output_features, model=model.conv if self.pretrained else None)
+                features.append(layer)
+
+        return nn.Sequential(*features)
+
+    def forward(self, x):
+        for model in self.features.children():
+            x = model(x)
+        return x
+
+
+if __name__ == '__main__':
+    inputs = torch.randn(1, 3, 1, 224, 224)
+    model = DenseNet3D('densenet121', pretrained=True)
+    outputs = model(inputs)
+    print(outputs.shape)
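+    # A longer clip should keep its frame dimension, since every kernel here
+    # has temporal extent 1 (the expected shape below is an assumption based
+    # on that).
+    clip = torch.randn(1, 3, 8, 224, 224)
+    print(model(clip).shape)  # expected: torch.Size([1, 1024, 8, 7, 7])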
diff --git a/torchvision_3d/models/mobilenetv2.py b/torchvision_3d/models/mobilenetv2.py
new file mode 100644
index 0000000..4fd934b
--- /dev/null
+++ b/torchvision_3d/models/mobilenetv2.py
@@ -0,0 +1,139 @@
+import torch
+import torch.nn as nn
+
+class ConvNormActivation3D(torch.nn.Sequential):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int = 3,
+        stride: int = 1,
+        padding = None,
+        groups: int = 1,
+        norm_layer = torch.nn.BatchNorm3d,
+        activation_layer = torch.nn.ReLU,
+        dilation: int = 1,
+        inplace: bool = True,
+        model = None
+    ) -> None:
+        if padding is None:
+            padding = (kernel_size - 1) // 2 * dilation
+        layers = [torch.nn.Conv3d(in_channels, out_channels, (1, kernel_size, kernel_size), (1, stride, stride), (0, padding, padding),
+                                  dilation=dilation, groups=groups, bias=norm_layer is None)]
+        if model is not None:
+            # inflate the pretrained 2D kernel to a depth-1 3D kernel
+            layers[0].weight.data = torch.stack([model[0].weight.data], dim=2)
+        if norm_layer is not None:
+            layers.append(norm_layer(out_channels))
+        if activation_layer is not None:
+            layers.append(activation_layer(inplace=inplace))
+        super().__init__(*layers)
+        self.out_channels = out_channels
+
+class InvertedResidual3D(nn.Module):
+    def __init__(
+        self,
+        inp: int,
+        oup: int,
+        stride: int,
+        expand_ratio: int,
+        norm_layer = None,
+        model = None
+    ) -> None:
+        super(InvertedResidual3D, self).__init__()
+        self.stride = stride
+        assert stride in [1, 2]
+
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm3d
+
+        hidden_dim = int(round(inp * expand_ratio))
+        self.use_res_connect = self.stride == 1 and inp == oup
+
+        # `model`, when given, is the pretrained 2D block's `conv` Sequential:
+        # [pw ConvNormActivation (only if expand_ratio != 1), dw ConvNormActivation,
+        #  pw-linear Conv2d, BatchNorm2d] -- hence the negative indices below,
+        # which are valid for both layouts
+        layers = []
+        if expand_ratio != 1:
+            # pw
+            layers.append(ConvNormActivation3D(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer,
+                                               activation_layer=nn.ReLU6, model=model[0] if model is not None else None))
+        # pw-linear
+        pw_linear = nn.Conv3d(hidden_dim, oup, 1, 1, 0, bias=False)
+        if model is not None:
+            pw_linear.weight.data = torch.stack([model[-2].weight.data], dim=2)
+        layers.extend([
+            # dw
+            ConvNormActivation3D(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer,
+                                 activation_layer=nn.ReLU6, model=model[-3] if model is not None else None),
+            pw_linear,
+            norm_layer(oup),
+        ])
+        self.conv = nn.Sequential(*layers)
+        self.out_channels = oup
+        self._is_cn = stride > 1
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.use_res_connect:
+            return x + self.conv(x)
+        else:
+            return self.conv(x)
+
+class MobileNetV2_3D(nn.Module):
+    def __init__(self, pretrained=False):
+        super().__init__()
+        from torchvision.models import mobilenet_v2
+
+        self.pretrained = pretrained
+
+        model_instance = mobilenet_v2(pretrained=pretrained)
+        self.features = model_instance.features
+        self.features = self.init_features()
+
+    def init_features(self):
+        features = []
+        for model in self.features.children():
+            if model._get_name() == 'ConvNormActivation':
+                in_channels = model[0].in_channels
+                out_channels = model[0].out_channels
+                kernel_size = model[0].kernel_size[0]
+                stride = model[0].stride[0]
+                padding = model[0].padding[0] if model[0].padding is not None else None
+                dilation = model[0].dilation[0]
+                layer = ConvNormActivation3D(in_channels=in_channels, out_channels=out_channels, dilation=dilation,
+                                             activation_layer=nn.ReLU6, kernel_size=kernel_size, stride=stride,
+                                             padding=padding, model=model if self.pretrained else None)
+                features.append(layer)
+            elif model._get_name() == 'InvertedResidual':
+                inp = model.conv[0][0].in_channels
+                oup = model.conv[-2].out_channels
+                stride = model.stride
+                expand_ratio = model.conv[-2].in_channels // inp
+                # hand the pretrained 2D `conv` Sequential to the 3D block so it
+                # can inflate the pointwise and depthwise kernels
+                layer = InvertedResidual3D(inp, oup, stride, expand_ratio, model=model.conv if self.pretrained else None)
+                features.append(layer)
+
+        return nn.Sequential(*features)
+
+    def forward(self, x):
+        for model in self.features.children():
+            x = model(x)
+        return x
+
+if __name__ == '__main__':
+    inputs = torch.randn(1, 3, 1, 224, 224)
+    model = MobileNetV2_3D(pretrained=True)
+    outputs = model(inputs)
+    print(outputs.shape)
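+    # A clip with several frames should pass straight through the frame
+    # dimension, since every kernel here has temporal extent 1 (the expected
+    # shape below is an assumption based on that).
+    clip = torch.randn(1, 3, 8, 224, 224)
+    print(model(clip).shape)  # expected: torch.Size([1, 1280, 8, 7, 7])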
diff --git a/torchvision_3d/models/resnet.py b/torchvision_3d/models/resnet.py
index f41d502..8624a76 100644
--- a/torchvision_3d/models/resnet.py
+++ b/torchvision_3d/models/resnet.py
@@ -22,7 +22,7 @@ def __init__(self, inplanes, planes, stride=1, base_width=64, groups = 1, downsample=None, model_instance=None):
             self.conv1.weight.data = torch.stack([model_instance.conv1.weight.data], dim=2)
             self.conv2.weight.data = torch.stack([model_instance.conv2.weight.data], dim=2)
             self.conv3.weight.data = torch.stack([model_instance.conv3.weight.data], dim=2)
-
+
             if self.downsample is not None:
                 self.downsample = nn.Sequential(
                     nn.Conv3d(inplanes, planes * 4, kernel_size=[1,1,1], stride=[1,*model_instance.downsample[0].stride], padding=[0,*model_instance.downsample[0].padding], bias=False),
@@ -94,9 +94,12 @@ def forward(self, x):
 class ResNet3D(nn.Module):
     def __init__(self, type, pretrained=False):
         super(ResNet3D, self).__init__()
-        assert type in ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'], 'type only support for resnet18, resnet34, resnet50, resnet101, resnet152'
+        assert type in ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2'], 'type only supports resnet18, resnet34, resnet50, resnet101, resnet152, resnext50_32x4d, resnext101_32x8d, wide_resnet50_2, wide_resnet101_2'
         self.type = type
         self.features = None
+
+        self.groups = 1
+        self.base_width = 64
 
         if type == 'resnet18':
             from torchvision.models.resnet import resnet18
@@ -113,6 +116,24 @@ def __init__(self, type, pretrained=False):
         elif type == 'resnet152':
             from torchvision.models.resnet import resnet152
             self.features = self.init_layer(resnet152(pretrained=pretrained))
+        elif type == 'resnext50_32x4d':
+            from torchvision.models.resnet import resnext50_32x4d
+            self.features = self.init_layer(resnext50_32x4d(pretrained=pretrained))
+            self.groups = 32
+            self.base_width = 4
+        elif type == 'resnext101_32x8d':
+            from torchvision.models.resnet import resnext101_32x8d
+            self.features = self.init_layer(resnext101_32x8d(pretrained=pretrained))
+            self.groups = 32
+            self.base_width = 8
+        elif type == 'wide_resnet50_2':
+            from torchvision.models.resnet import wide_resnet50_2
+            self.features = self.init_layer(wide_resnet50_2(pretrained=pretrained))
+            self.base_width = 64 * 2
+        elif type == 'wide_resnet101_2':
+            from torchvision.models.resnet import wide_resnet101_2
+            self.features = self.init_layer(wide_resnet101_2(pretrained=pretrained))
+            self.base_width = 64 * 2
 
         self.features = self.init_features()
 
@@ -145,7 +166,7 @@ def init_features(self):
                 child_temp = BasicBlock3D(inplanes=child.conv1.in_channels, planes=child.conv1.out_channels, stride=child.conv1.stride, downsample=child.downsample, model_instance=child)
                 features_child.append(child_temp)
             elif child._get_name() == 'Bottleneck':
-                child_temp = Bottleneck3D(inplanes=child.conv1.in_channels, planes=child.conv1.out_channels, stride=child.conv2.stride, downsample=child.downsample, model_instance=child)
+                child_temp = Bottleneck3D(inplanes=child.conv1.in_channels, planes=child.conv3.out_channels//4, stride=child.conv2.stride, downsample=child.downsample, groups=self.groups, base_width=self.base_width, model_instance=child)
                 features_child.append(child_temp)
 
             features.append(nn.Sequential(*features_child))
@@ -159,7 +180,5 @@ def forward(self, x):
 if __name__ == '__main__':
     from torchvision.models.resnet import resnet50
     sample = torch.randn(1,3,1,224,224)
-    model = ResNet3D('resnet50')
-    model_18 = resnet50(pretrained=False)
-    print(model_18)
+    model = ResNet3D('resnet101')
     print(model(sample).shape)
\ No newline at end of file
diff --git a/torchvision_3d/models/squeezenet.py b/torchvision_3d/models/squeezenet.py
new file mode 100644
index 0000000..8452e2c
--- /dev/null
+++ b/torchvision_3d/models/squeezenet.py
@@ -0,0 +1,88 @@
+import torch
+import torch.nn as nn
+
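+# Fire3D mirrors torchvision's 2D Fire module: a 1x1x1 "squeeze" conv feeds
+# two parallel "expand" convs (1x1x1 and 1x3x3) whose outputs are concatenated
+# along the channel axis. When `model` (a pretrained 2D Fire module) is given,
+# its conv weights and biases are copied into the 3D convs.
+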
+class Fire3D(nn.Module):
+    def __init__(
+        self,
+        inplanes: int,
+        squeeze_planes: int,
+        expand1x1_planes: int,
+        expand3x3_planes: int,
+        model = None
+    ) -> None:
+        super(Fire3D, self).__init__()
+        self.inplanes = inplanes
+        self.squeeze = nn.Conv3d(inplanes, squeeze_planes, kernel_size=1)
+        self.squeeze_activation = nn.ReLU(inplace=True)
+        self.expand1x1 = nn.Conv3d(squeeze_planes, expand1x1_planes, kernel_size=1)
+        self.expand1x1_activation = nn.ReLU(inplace=True)
+        self.expand3x3 = nn.Conv3d(squeeze_planes, expand3x3_planes, kernel_size=(1, 3, 3), padding=(0, 1, 1))
+        self.expand3x3_activation = nn.ReLU(inplace=True)
+        if model is not None:
+            # inflate the 2D kernels; biases are per-channel vectors and can be copied as-is
+            self.squeeze.weight.data = torch.stack([model.squeeze.weight.data], dim=2)
+            self.squeeze.bias.data = model.squeeze.bias.data
+            self.expand1x1.weight.data = torch.stack([model.expand1x1.weight.data], dim=2)
+            self.expand1x1.bias.data = model.expand1x1.bias.data
+            self.expand3x3.weight.data = torch.stack([model.expand3x3.weight.data], dim=2)
+            self.expand3x3.bias.data = model.expand3x3.bias.data
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.squeeze_activation(self.squeeze(x))
+        return torch.cat([
+            self.expand1x1_activation(self.expand1x1(x)),
+            self.expand3x3_activation(self.expand3x3(x))
+        ], 1)
+
+class SqueezeNet3D(nn.Module):
+    def __init__(self, type, pretrained=False):
+        super().__init__()
+
+        if type == 'squeezenet1_0':
+            from torchvision.models import squeezenet1_0
+            model_instance = squeezenet1_0(pretrained=pretrained)
+        elif type == 'squeezenet1_1':
+            from torchvision.models import squeezenet1_1
+            model_instance = squeezenet1_1(pretrained=pretrained)
+        else:
+            raise NotImplementedError('type only supports squeezenet1_0, squeezenet1_1')
+
+        self.features = model_instance.features
+        self.pretrained = pretrained
+        self.features = self.init_features()
+
+    def init_features(self):
+        features = []
+        for model in self.features.children():
+            if isinstance(model, nn.Conv2d):
+                # SqueezeNet convs carry biases, so keep them when converting to 3D
+                model_temp = nn.Conv3d(in_channels=model.in_channels, out_channels=model.out_channels, kernel_size=(1, *model.kernel_size), stride=(1, *model.stride), padding=(0, *model.padding), bias=model.bias is not None)
+                if self.pretrained:
+                    model_temp.weight.data = torch.stack([model.weight.data], dim=2)
+                    if model.bias is not None:
+                        model_temp.bias.data = model.bias.data
+                features.append(model_temp)
+            elif isinstance(model, nn.MaxPool2d):
+                model_temp = nn.MaxPool3d(kernel_size=[1, model.kernel_size, model.kernel_size], stride=[1, model.stride, model.stride], padding=[0, model.padding, model.padding], ceil_mode=model.ceil_mode)
+                features.append(model_temp)
+            elif isinstance(model, nn.ReLU):
+                features.append(model)
+            elif isinstance(model, nn.BatchNorm2d):
+                model_temp = nn.BatchNorm3d(num_features=model.num_features)
+                features.append(model_temp)
+            elif model._get_name() == 'Fire':
+                inplanes = model.inplanes
+                squeeze_planes = model.squeeze.out_channels
+                expand1x1_planes = model.expand1x1.out_channels
+                expand3x3_planes = model.expand3x3.out_channels
+                features.append(Fire3D(inplanes, squeeze_planes, expand1x1_planes, expand3x3_planes, model if self.pretrained else None))
+
+        return nn.Sequential(*features)
+
+    def forward(self, x):
+        x = self.features(x)
+        return x
+
+if __name__ == '__main__':
+    inputs = torch.randn(1, 3, 1, 224, 224)
+    model = SqueezeNet3D('squeezenet1_0', pretrained=True)
+    outputs = model(inputs)
+    print(outputs.shape)
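+    # Multi-frame clips should pass through with the frame dimension intact,
+    # since every kernel here has temporal extent 1 (the expected shape below
+    # is an assumption based on that).
+    clip = torch.randn(1, 3, 8, 224, 224)
+    print(model(clip).shape)  # expected: torch.Size([1, 512, 8, 13, 13])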