diff --git a/classification/models/imagenet/resnet_sge.py b/classification/models/imagenet/resnet_sge.py
index 548cc67..93cac1d 100644
--- a/classification/models/imagenet/resnet_sge.py
+++ b/classification/models/imagenet/resnet_sge.py
@@ -11,23 +11,37 @@ def __init__(self, groups = 64):
         self.weight = Parameter(torch.zeros(1, groups, 1, 1))
         self.bias = Parameter(torch.ones(1, groups, 1, 1))
         self.sig = nn.Sigmoid()
+        self.gn = nn.GroupNorm(1, 1)
 
+    # By GroupNorm
     def forward(self, x): # (b, c, h, w)
         b, c, h, w = x.size()
         x = x.view(b * self.groups, -1, h, w)
         xn = x * self.avg_pool(x)
         xn = xn.sum(dim=1, keepdim=True)
-        t = xn.view(b * self.groups, -1)
-        t = t - t.mean(dim=1, keepdim=True)
-        std = t.std(dim=1, keepdim=True) + 1e-5
-        t = t / std
-        t = t.view(b, self.groups, h, w)
-        t = t * self.weight + self.bias
-        t = t.view(b * self.groups, 1, h, w)
+
+        xn = xn.view(b * self.groups, -1, h, w)
+        t = self.gn(xn)
         x = x * self.sig(t)
         x = x.view(b, c, h, w)
         return x
 
+    # def forward(self, x): # (b, c, h, w)
+    #     b, c, h, w = x.size()
+    #     x = x.view(b * self.groups, -1, h, w)
+    #     xn = x * self.avg_pool(x)
+    #     xn = xn.sum(dim=1, keepdim=True)
+    #     t = xn.view(b * self.groups, -1)
+    #     t = t - t.mean(dim=1, keepdim=True)
+    #     std = t.std(dim=1, keepdim=True) + 1e-5
+    #     t = t / std
+    #     t = t.view(b, self.groups, h, w)
+    #     t = t * self.weight + self.bias
+    #     t = t.view(b * self.groups, 1, h, w)
+    #     x = x * self.sig(t)
+    #     x = x.view(b, c, h, w)
+    #     return x
+
 def conv3x3(in_planes, out_planes, stride=1):
     """3x3 convolution with padding"""
     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,