Commit 78820ee

Merge branch 'master' of github.com:CSCfi/pytorch-ddp-examples

2 parents 1f01d2d + bbaead3
File tree

1 file changed: +80 -0 lines changed


mnist_mp.py

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
from datetime import datetime
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import multiprocessing as mp
import torch
import torch.nn as nn
import torchvision.transforms as transforms


class ConvNet(nn.Module):
    # Small two-layer CNN for 28x28 single-channel MNIST images.
    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        # Two 2x2 max-poolings shrink 28x28 to 7x7, with 32 channels.
        self.fc = nn.Linear(7*7*32, num_classes)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)  # flatten to (batch, 7*7*32)
        out = self.fc(out)
        return out


def train(batch_size):
    # Train one ConvNet on MNIST with the given batch size; runs in its
    # own process, so each worker builds its own model and data loader.
    num_epochs = 100

    torch.manual_seed(0)
    verbose = True

    model = ConvNet().cuda()

    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

    train_dataset = MNIST(root='./data', train=True,
                          transform=transforms.ToTensor(), download=True)
    # shuffle=False keeps the data order identical across processes,
    # so the runs differ only in batch size.
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                              shuffle=False, num_workers=0, pin_memory=True)

    start = datetime.now()
    for epoch in range(num_epochs):
        tot_loss = 0
        for i, (images, labels) in enumerate(train_loader):
            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)

            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            tot_loss += loss.item()

        if verbose:
            # i+1 is the number of batches seen in this epoch.
            print('Epoch [{}/{}], batch_size={} average loss: {:.4f}'.format(
                epoch + 1,
                num_epochs,
                batch_size,
                tot_loss / (i+1)))
    if verbose:
        print("Training completed in: " + str(datetime.now() - start))


if __name__ == '__main__':
    # PyTorch requires the 'spawn' (or 'forkserver') start method when
    # subprocesses use CUDA; the default 'fork' on Linux is unsafe here.
    mp.set_start_method('spawn')

    bs_list = [16, 32, 64, 128]
    num_processes = 4

    # One worker per batch size: four independent trainings in parallel.
    with mp.Pool(processes=num_processes) as pool:
        pool.map(train, bs_list)
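As committed, every pool worker calls .cuda() without selecting a device, so all four trainings share GPU 0. Below is a minimal sketch of one way to spread the workers over several GPUs, written as a replacement for the __main__ block above; the train_on_device wrapper and the round-robin batch-size-to-GPU mapping are illustrative assumptions, not part of this commit:

def train_on_device(args):
    # Hypothetical wrapper: pin this worker to one GPU, then reuse train().
    batch_size, device = args
    torch.cuda.set_device(device)  # later .cuda() calls use this device
    train(batch_size)


if __name__ == '__main__':
    mp.set_start_method('spawn')  # required for CUDA in subprocesses

    bs_list = [16, 32, 64, 128]
    # Round-robin the jobs over however many GPUs are visible.
    jobs = [(bs, i % torch.cuda.device_count())
            for i, bs in enumerate(bs_list)]
    with mp.Pool(processes=len(jobs)) as pool:
        pool.map(train_on_device, jobs)

With four GPUs available, each batch size gets its own device; on a single-GPU node the mapping degenerates to the committed behaviour of sharing device 0.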
