Skip to content

Commit 3e7d168

Browse files
committed
package code of pytorch 0.4.1 and 1.1.0
1 parent 0b122c1 commit 3e7d168

File tree

7 files changed

+8
-7
lines changed

7 files changed

+8
-7
lines changed

demo_MNIST.py renamed to Non-Local_pytorch_0.4.1_to_1.1.0/demo_MNIST.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66
import time
77

88

9-
def calc_acc(x, y):
10-
x = torch.max(x, dim=-1)[1]
11-
accuracy = sum(x == y) / x.size(0)
12-
return accuracy
9+
# def calc_acc(x, y):
10+
# x = torch.max(x, dim=-1)[1]
11+
# accuracy = sum(x == y) / x.size(0)
12+
# return accuracy
1313

1414

1515
train_data = torchvision.datasets.MNIST(root='./mnist', train=True,
@@ -44,7 +44,7 @@ def calc_acc(x, y):
4444
label_batch = label_batch.cuda()
4545

4646
predict = net(img_batch)
47-
acc = calc_acc(predict.cpu().data, label_batch.cpu().data)
47+
# acc = calc_acc(predict.cpu().data, label_batch.cpu().data)
4848
loss = loss_func(predict, label_batch)
4949

5050
net.zero_grad()
File renamed without changes.

README.md

+3-2
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ in **lib/network.py**.
99

1010

1111
## Environment
12-
- python 3.6 or python 3.7.3
13-
- pytorch 0.4.1 or pytorch 1.1.0
12+
- python 3.7.3
13+
- pytorch 1.2.0
1414

1515

1616
## Update Records
@@ -25,6 +25,7 @@ old versions (**lib/non_loca.py** and **lib/non_local_simple_version.py**) into
2525
5. Modify the code to support pytorch 0.4.1, and move the code supporting pytorch 0.3.1 \
2626
to **Non-Local_pytorch_0.3.1/**.
2727
6. Test the code with pytorch 1.1.0 and it works.
28+
7. Move the code supporting pytorch 0.4.1 and 1.1.0 to **Non-Local_pytorch_0.4.1_to_1.1.0/**.
2829

2930
## Running Steps
3031
1. Select the type of non-local block in **lib/network.py**.

0 commit comments

Comments
 (0)