Skip to content

Commit f4f7fdf

Browse files
committed
Update 01-convolutional-nn
1 parent e8c25b1 commit f4f7fdf

File tree

1 file changed

+204
-0
lines changed
  • tutorials/02-intermediate/01-convolutional-nn

1 file changed

+204
-0
lines changed
Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.optim as optim
4+
5+
import torchvision
6+
import torchvision.transforms as transforms
7+
8+
import matplotlib.pyplot as plt
9+
import numpy as np
10+
11+
# ========================================= #
12+
# Load and normalize the data #
13+
# ========================================= #
14+
15+
transform = transforms.Compose([
16+
transforms.ToTensor(),
17+
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
18+
])
19+
20+
batch_size = 4
21+
train_data = torchvision.datasets.CIFAR10(
22+
root='data',
23+
train=True,
24+
download=True,
25+
transform=transform
26+
)
27+
28+
test_data = torchvision.datasets.CIFAR10(
29+
root='data',
30+
train=False,
31+
download=True,
32+
transform=transform
33+
)
34+
35+
train_loader = torch.utils.data.DataLoader(
36+
dataset=train_data,
37+
batch_size=batch_size,
38+
shuffle=True
39+
)
40+
41+
test_loader = torch.utils.data.DataLoader(
42+
dataset=test_data,
43+
batch_size=batch_size,
44+
shuffle=False
45+
)
46+
47+
classes = ('plane', 'car', 'bird', 'cat',
48+
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
49+
50+
51+
# ========================================= #
52+
# Visualize Training Images #
53+
# ========================================= #
54+
55+
def imshow(img):
56+
img = img / 2 + 0.5 # unnormalize
57+
npimg = img.numpy()
58+
plt.imshow(np.transpose(npimg, (1, 2, 0)))
59+
plt.show()
60+
61+
62+
# Get some random training images
63+
images, labels = next(iter(train_loader))
64+
imshow(torchvision.utils.make_grid(images))
65+
66+
# Print labels
67+
print(' '.join('%5s' % classes[labels[j]] for j in range(batch_size)))
68+
69+
70+
# ========================================= #
71+
# Define Convolutional Neural Network #
72+
# ========================================= #
73+
74+
class Net(nn.Module):
75+
def __init__(self):
76+
super(Net, self).__init__()
77+
self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
78+
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
79+
self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
80+
self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
81+
self.fc2 = nn.Linear(in_features=120, out_features=84)
82+
self.fc3 = nn.Linear(in_features=84, out_features=10)
83+
84+
def forward(self, x):
85+
x = self.conv1(x)
86+
x = torch.relu(x)
87+
x = self.pool(x)
88+
89+
x = self.conv2(x)
90+
x = torch.relu(x)
91+
x = self.pool(x)
92+
93+
x = x.view(-1, 16 * 5 * 5)
94+
95+
x = self.fc1(x)
96+
x = torch.relu(x)
97+
98+
x = self.fc2(x)
99+
x = torch.relu(x)
100+
101+
x = self.fc3(x)
102+
103+
return x
104+
105+
106+
net = Net()
107+
108+
# ========================================= #
109+
# Define a Loss function and Optimizer #
110+
# ========================================= #
111+
112+
criterion = nn.CrossEntropyLoss()
113+
optimizer = optim.SGD(params=net.parameters(), lr=0.001, momentum=0.9)
114+
115+
# ========================================= #
116+
# Train the network #
117+
# ========================================= #
118+
119+
for epoch in range(2): # loop over the dataset multiple times
120+
121+
running_loss = 0.0
122+
for i, data in enumerate(train_loader, 0):
123+
# get the inputs; data is a list of [inputs, labels]
124+
inputs, labels = data
125+
126+
# zero the parameter gradients
127+
optimizer.zero_grad()
128+
129+
# forward + backward + optimize
130+
outputs = net(inputs)
131+
loss = criterion(outputs, labels)
132+
loss.backward()
133+
optimizer.step()
134+
135+
# print statistics
136+
running_loss += loss.item()
137+
if i % 2000 == 1999: # print every 2000 mini-batches
138+
print('[%d, %5d] loss: %.3f' %
139+
(epoch + 1, i + 1, running_loss / 2000))
140+
running_loss = 0.0
141+
142+
print('Finished Training')
143+
144+
# Save the trained model
145+
146+
PATH = './cifar_net.pth'
147+
torch.save(net.state_dict(), PATH)
148+
149+
# ========================================= #
150+
# Test the network #
151+
# ========================================= #
152+
153+
# Show test images
154+
images, labels = next(iter(test_loader))
155+
imshow(torchvision.utils.make_grid(images))
156+
print('Ground Truth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))
157+
158+
# Load the model
159+
net = Net()
160+
net.load_state_dict(torch.load(PATH))
161+
162+
outputs = net(images)
163+
_, predicted = torch.max(outputs, 1)
164+
print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(4)))
165+
166+
# How network performs on whole test data
167+
correct = 0
168+
total = 0
169+
# For testing we don't need calculate gradients
170+
with torch.no_grad():
171+
for data in test_loader:
172+
images, labels = data
173+
outputs = net(images)
174+
_, predicted = torch.max(outputs.data, 1)
175+
total += labels.size(0)
176+
correct += (predicted == labels).sum().item()
177+
178+
print(f'Accuracy of the network on test data: {100 * correct / total}')
179+
180+
# ========================================= #
181+
# Class-based Accuracy #
182+
# ========================================= #
183+
184+
# prepare to count predictions for each class
185+
correct_pred = {classname: 0 for classname in classes}
186+
total_pred = {classname: 0 for classname in classes}
187+
188+
# again no gradients needed
189+
with torch.no_grad():
190+
for data in test_loader:
191+
images, labels = data
192+
outputs = net(images)
193+
_, predictions = torch.max(outputs, 1)
194+
# collect the correct predictions for each class
195+
for label, prediction in zip(labels, predictions):
196+
if label == prediction:
197+
correct_pred[classes[label]] += 1
198+
total_pred[classes[label]] += 1
199+
200+
# print accuracy for each class
201+
for classname, correct_count in correct_pred.items():
202+
accuracy = 100 * float(correct_count) / total_pred[classname]
203+
print("Accuracy for class {:5s} is: {:.1f} %".format(classname,
204+
accuracy))

0 commit comments

Comments
 (0)