-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathResNetModel.py
36 lines (25 loc) · 929 Bytes
/
ResNetModel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from resnet import resnet18, resnet34, resnet50
class ResNetModel(nn.Module):
    """ResNet-34 backbone with global average pooling and an optional linear head.

    Args:
        num_classes: If given, an ``nn.Linear(512, num_classes)`` classifier is
            attached as ``self.fc`` and ``forward`` also returns logits. If
            ``None`` (default), no ``fc`` attribute is created and ``forward``
            returns only the pooled features.
        pretrained: Forwarded to the project-local ``resnet34`` factory.
            Defaults to ``True``, matching the previous hard-coded behaviour.
    """

    def __init__(self, num_classes=None, pretrained=True):
        super().__init__()
        # NOTE(review): `resnet34` comes from the project-local `resnet` module;
        # forward() assumes it returns an un-pooled [N, C, H, W] feature map
        # rather than class scores -- confirm against resnet.py.
        self.base = resnet34(pretrained=pretrained)
        # Assumed channel count of the backbone's final feature map (512 for a
        # standard ResNet-18/34; a ResNet-50 backbone would need 2048).
        planes = 512
        if num_classes is not None:
            self.fc = nn.Linear(planes, num_classes)
            # Use the in-place initialisers; the un-suffixed `xavier_uniform` /
            # `constant` names are deprecated aliases that emit warnings.
            init.xavier_uniform_(self.fc.weight)
            init.constant_(self.fc.bias, 0.1)

    def forward(self, x):
        """Run the backbone and globally average-pool its output.

        Args:
            x: Input batch passed straight to ``self.base``.

        Returns:
            ``global_feat`` of shape ``[N, C]`` when no classifier head exists,
            otherwise the tuple ``(global_feat, logits)`` where ``logits`` is
            ``self.fc(global_feat)``.
        """
        # Backbone output, expected shape [N, C, H, W].
        feat = self.base(x)
        # Pool over the full spatial extent -> [N, C, 1, 1].
        global_feat = F.avg_pool2d(feat, feat.size()[2:])
        # Drop the singleton spatial dims -> [N, C].
        global_feat = global_feat.flatten(1)
        # `fc` only exists when num_classes was supplied to __init__.
        if hasattr(self, 'fc'):
            logits = self.fc(global_feat)
            return global_feat, logits
        return global_feat