
Commit 0276c7b

Modify to support visualizing NL_MAP; based on pytorch 1.2.0
1 parent 3e7d168 commit 0276c7b

34 files changed (+865 / -14 lines)

README.md

+42-14
@@ -3,40 +3,68 @@
## Statement
- You can find different kinds of non-local block in **lib/**.
- You can **visualize** the Non-local Attention Map by following the **Running Steps** shown below.
- The code is tested on the MNIST dataset. You can select the type of non-local block in **lib/network.py**.
- If there is something wrong in my code, please contact me, thanks!

## Environment
- python 3.7.3
- pytorch 1.2.0
- opencv 3.4.2

## Visualization
1. In the **first** Non-local Layer.

   ![](nl_map_vis/nl_map_1/37.png)![](nl_map_vis/nl_map_1/44.png)![](nl_map_vis/nl_map_1/46.png)![](nl_map_vis/nl_map_1/110.png)![](nl_map_vis/nl_map_1/161.png)

2. In the **second** Non-local Layer.

   ![](nl_map_vis/nl_map_2/1.png)![](nl_map_vis/nl_map_2/8.png)![](nl_map_vis/nl_map_2/10.png)![](nl_map_vis/nl_map_2/18.png)![](nl_map_vis/nl_map_2/38.png)

## Running Steps
1. Select the type of non-local block in **lib/network.py**.
    ```
    from lib.non_local_concatenation import NONLocalBlock2D
    from lib.non_local_gaussian import NONLocalBlock2D
    from lib.non_local_embedded_gaussian import NONLocalBlock2D
    from lib.non_local_dot_product import NONLocalBlock2D
    ```
2. Run **demo_MNIST_train.py** with one GPU or multiple GPUs to train the network. The weights will then be saved in **weights/**.
    ```
    CUDA_VISIBLE_DEVICES=0,1 python demo_MNIST_train.py
    ```
3. Run **nl_map_save.py** to save the NL_MAP of one test sample in **nl_map_vis/**.
    ```
    CUDA_VISIBLE_DEVICES=0,1 python nl_map_save.py
    ```
4. Go into **nl_map_vis/** and run **nl_map_vis.py** to visualize the NL_MAP. (Tip: if the non-local type you selected is **non_local_concatenation** or **non_local_dot_product**, which have no softmax operation, you may need to normalize the NL_MAP in the visualization code; see the sketch after this list.)
    ```
    python nl_map_vis.py
    ```
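
The normalization mentioned in step 4 only matters for the block types whose pairwise map is not already softmax-normalized. A minimal sketch of that post-processing, assuming **nl_map_save.py** stored the first layer's NL_MAP as a tensor of shape `(batch, N, N')` under a hypothetical file name `nl_map_vis/nl_map_1.pth`:

```
import torch
import numpy as np
import cv2

# Hypothetical path: wherever nl_map_save.py stored the first layer's map.
nl_map = torch.load('nl_map_vis/nl_map_1.pth', map_location='cpu')  # (1, N, N')

# Attention that query position `idx` pays to every (sub-sampled) position.
idx = 100
att = nl_map[0, idx].float().numpy()

# non_local_concatenation / non_local_dot_product have no softmax,
# so rescale the row to [0, 1] before turning it into an image.
att = (att - att.min()) / (att.max() - att.min() + 1e-8)

# Reshape to the spatial size of the sub-sampled feature map and upsample for viewing.
side = int(np.sqrt(att.shape[0]))
att_img = cv2.resize(att.reshape(side, side), (28, 28), interpolation=cv2.INTER_NEAREST)
cv2.imwrite('att_%d.png' % idx, (att_img * 255).astype(np.uint8))
```
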
## Update Records
1. Figure out how to implement the **concatenation** type, and add the code to **lib/**.
2. Fix the bug in **lib/non_local.py** (old version) when using multi-gpu. Someone shared the reason with me, and you can find it [here](https://github.com/pytorch/pytorch/issues/8637).
3. Fix the error of 3D pooling in **lib/non_local.py** (old version). Thanks to [**protein27**](https://github.com/AlexHex7/Non-local_pytorch/issues/17) for pointing it out.
4. For convenience, I split **lib/non_local.py** into four python files, and moved the old versions (**lib/non_local.py** and **lib/non_local_simple_version.py**) into **lib/backup/**.
5. Modify the code to support pytorch 0.4.1, and move the code supporting pytorch 0.3.1 to **Non-Local_pytorch_0.3.1/**.
6. Test the code with pytorch 1.1.0 and it works.
7. Move the code supporting pytorch 0.4.1 and 1.1.0 to **Non-Local_pytorch_0.4.1_to_1.1.0/** (in fact, I think it can also support pytorch 1.2.0).
8. In order to visualize the NL_MAP, some code has been slightly modified. **nl_map_save.py** is added to save the NL_MAP of the two Non-local layers for one test sample, and **Non-local_pytorch/nl_map_vis.py** is added to visualize the NL_MAP; a sketch of what the saving step might look like follows this list. Besides, the code now supports pytorch 1.2.0.
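
The actual **nl_map_save.py** is not part of this excerpt, but given the `forward_with_nl_map` method added to **lib/network.py** below, a minimal sketch of the saving step could look like the following (the weight path comes from the commented-out lines in **demo_MNIST_train.py**; the output file names are assumptions):

```
import torch
import torchvision
from lib.network import Network

# One MNIST test sample.
test_data = torchvision.datasets.MNIST(root='./mnist/', train=False, download=True,
                                       transform=torchvision.transforms.ToTensor())
img, label = test_data[0]
img = img.unsqueeze(0)  # (1, 1, 28, 28)

net = Network()
net.load_state_dict(torch.load('weights/net.pth', map_location='cpu'))  # assumed weight path
net.eval()

with torch.no_grad():
    _, (nl_map_1, nl_map_2) = net.forward_with_nl_map(img)

# Each map holds one row of attention weights per query position: (batch, N, N').
print(nl_map_1.size(), nl_map_2.size())
torch.save(nl_map_1, 'nl_map_vis/nl_map_1.pth')  # assumed output names
torch.save(nl_map_2, 'nl_map_vis/nl_map_2.pth')
```
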
## Todo
- Experiments on Charades dataset.

demo_MNIST_train.py

+81
@@ -0,0 +1,81 @@
import torch
import torch.utils.data as Data
import torchvision
from lib.network import Network
from torch import nn
import time


train_data = torchvision.datasets.MNIST(root='./mnist', train=True,
                                        transform=torchvision.transforms.ToTensor(),
                                        download=True)
test_data = torchvision.datasets.MNIST(root='./mnist/',
                                       transform=torchvision.transforms.ToTensor(),
                                       train=False)

train_loader = Data.DataLoader(dataset=train_data, batch_size=128, shuffle=True)
test_loader = Data.DataLoader(dataset=test_data, batch_size=128, shuffle=False)

train_batch_num = len(train_loader)
test_batch_num = len(test_loader)

net = Network()
if torch.cuda.is_available():
    net = nn.DataParallel(net)
    net.cuda()

opt = torch.optim.Adam(net.parameters(), lr=0.001)
loss_func = nn.CrossEntropyLoss()

for epoch_index in range(10):
    st = time.time()

    # Training phase.
    torch.set_grad_enabled(True)
    net.train()
    for train_batch_index, (img_batch, label_batch) in enumerate(train_loader):
        if torch.cuda.is_available():
            img_batch = img_batch.cuda()
            label_batch = label_batch.cuda()

        predict = net(img_batch)
        loss = loss_func(predict, label_batch)

        net.zero_grad()
        loss.backward()
        opt.step()

    print('(LR:%f) Time of an epoch:%.4fs' % (opt.param_groups[0]['lr'], time.time()-st))

    # Evaluation phase.
    torch.set_grad_enabled(False)
    net.eval()
    total_loss = []
    total_acc = 0
    total_sample = 0

    for test_batch_index, (img_batch, label_batch) in enumerate(test_loader):
        if torch.cuda.is_available():
            img_batch = img_batch.cuda()
            label_batch = label_batch.cuda()

        predict = net(img_batch)
        loss = loss_func(predict, label_batch)

        predict = predict.argmax(dim=1)
        acc = (predict == label_batch).sum()

        total_loss.append(loss)
        total_acc += acc
        total_sample += img_batch.size(0)

    net.train()

    mean_acc = total_acc.item() * 1.0 / total_sample
    mean_loss = sum(total_loss) / len(total_loss)

    print('[Test] epoch[%d/%d] acc:%.4f%% loss:%.4f\n'
          % (epoch_index, 10, mean_acc * 100, mean_loss.item()))

# weight_path = 'weights/net.pth'
# print('Save Net weights to', weight_path)
# net.cpu()
# torch.save(net.state_dict(), weight_path)
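
The weight-saving lines above are left commented out in this commit, while the Running Steps say the trained weights end up in **weights/**. A minimal sketch of enabling that, appended to the end of the script above (the path `weights/net.pth` comes from the commented code; creating the directory and unwrapping `nn.DataParallel` are my additions):

```
import os

os.makedirs('weights', exist_ok=True)
weight_path = 'weights/net.pth'
print('Save Net weights to', weight_path)

# Unwrap nn.DataParallel (if it was used) so the saved keys carry no 'module.' prefix.
model_to_save = net.module if isinstance(net, nn.DataParallel) else net
torch.save(model_to_save.cpu().state_dict(), weight_path)
```
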

lib/network.py

+79
@@ -0,0 +1,79 @@
from torch import nn
# from lib.non_local_concatenation import NONLocalBlock2D
# from lib.non_local_gaussian import NONLocalBlock2D
from lib.non_local_embedded_gaussian import NONLocalBlock2D
# from lib.non_local_dot_product import NONLocalBlock2D


class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()

        self.conv_1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        self.nl_1 = NONLocalBlock2D(in_channels=32)
        self.conv_2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        self.nl_2 = NONLocalBlock2D(in_channels=64)
        self.conv_3 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        self.fc = nn.Sequential(
            nn.Linear(in_features=128*3*3, out_features=256),
            nn.ReLU(),
            nn.Dropout(0.5),

            nn.Linear(in_features=256, out_features=10)
        )

    def forward(self, x):
        batch_size = x.size(0)

        feature_1 = self.conv_1(x)
        nl_feature_1 = self.nl_1(feature_1)

        feature_2 = self.conv_2(nl_feature_1)
        nl_feature_2 = self.nl_2(feature_2)

        output = self.conv_3(nl_feature_2).view(batch_size, -1)
        output = self.fc(output)

        return output

    def forward_with_nl_map(self, x):
        batch_size = x.size(0)

        feature_1 = self.conv_1(x)
        nl_feature_1, nl_map_1 = self.nl_1(feature_1, return_nl_map=True)

        feature_2 = self.conv_2(nl_feature_1)
        nl_feature_2, nl_map_2 = self.nl_2(feature_2, return_nl_map=True)

        output = self.conv_3(nl_feature_2).view(batch_size, -1)
        output = self.fc(output)

        return output, [nl_map_1, nl_map_2]


if __name__ == '__main__':
    import torch

    img = torch.randn(3, 1, 28, 28)
    net = Network()
    out = net(img)
    print(out.size())
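
A quick way to exercise the newly added `forward_with_nl_map` path, in the same spirit as the `__main__` block above (the exact map shapes depend on the non-local variant and its `sub_sample` setting, so the comments below are only indicative):

```
import torch
from lib.network import Network

net = Network()
img = torch.randn(3, 1, 28, 28)

# Logits plus the pairwise attention maps of the two non-local layers.
out, (nl_map_1, nl_map_2) = net.forward_with_nl_map(img)
print(out.size())       # torch.Size([3, 10])
print(nl_map_1.size())  # (batch, N, N'): 14*14 queries over the sub-sampled key positions
print(nl_map_2.size())  # (batch, N, N') for the second block
```
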

lib/non_local_concatenation.py

+150
@@ -0,0 +1,150 @@
import torch
from torch import nn
from torch.nn import functional as F


class _NonLocalBlockND(nn.Module):
    def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
        super(_NonLocalBlockND, self).__init__()

        assert dimension in [1, 2, 3]

        self.dimension = dimension
        self.sub_sample = sub_sample

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool_layer = nn.MaxPool1d(kernel_size=(2))
            bn = nn.BatchNorm1d

        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)

        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                bn(self.in_channels)
            )
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)

        self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0)

        self.concat_project = nn.Sequential(
            nn.Conv2d(self.inter_channels * 2, 1, 1, 1, 0, bias=False),
            nn.ReLU()
        )

        if sub_sample:
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)

    def forward(self, x, return_nl_map=False):
        '''
        :param x: (b, c, t, h, w)
        :param return_nl_map: if True return z, nl_map, else only return z.
        :return:
        '''
        batch_size = x.size(0)

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        # (b, c, N, 1)
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1, 1)
        # (b, c, 1, N)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, 1, -1)

        h = theta_x.size(2)
        w = phi_x.size(3)
        theta_x = theta_x.repeat(1, 1, 1, w)
        phi_x = phi_x.repeat(1, 1, h, 1)

        concat_feature = torch.cat([theta_x, phi_x], dim=1)
        f = self.concat_project(concat_feature)
        b, _, h, w = f.size()
        f = f.view(b, h, w)

        N = f.size(-1)
        f_div_C = f / N

        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x

        if return_nl_map:
            return z, f_div_C
        return z


class NONLocalBlock1D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock1D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=1, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


class NONLocalBlock2D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock2D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=2, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


class NONLocalBlock3D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock3D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=3, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


if __name__ == '__main__':
    import torch

    for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]:
        img = torch.zeros(2, 3, 20)
        net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
        out = net(img)
        print(out.size())

        img = torch.zeros(2, 3, 20, 20)
        net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
        out = net(img)
        print(out.size())

        img = torch.randn(2, 3, 8, 20, 20)
        net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
        out = net(img)
        print(out.size())
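
For reference, the pairwise function this concatenation variant computes, with the symbols mirroring the `theta`, `phi`, `g` and `concat_project` layers above (this is the concatenation form of the non-local operation; N' denotes the number of sub-sampled key positions, which is why the code divides `f` by `f.size(-1)`):

```
f(\mathbf{x}_i, \mathbf{x}_j) = \mathrm{ReLU}\!\left(w_f^{\top}\,[\,\theta(\mathbf{x}_i),\ \phi(\mathbf{x}_j)\,]\right), \qquad
\mathbf{y}_i = \frac{1}{N'} \sum_j f(\mathbf{x}_i, \mathbf{x}_j)\, g(\mathbf{x}_j), \qquad
\mathbf{z}_i = W_z\,\mathbf{y}_i + \mathbf{x}_i
```

Here [·, ·] is channel-wise concatenation, w_f is the 1×1 convolution inside `concat_project`, and the residual projection W_z corresponds to `self.W`.
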
