
Commit 1f01d2d

Add PyTorch profiler example
1 parent 0c4b58a commit 1f01d2d

File tree

2 files changed (+137, -0 lines)


mnist_ddp_profiler.py

Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
# Based on multiprocessing example from
# https://yangkky.github.io/2019/07/08/distributed-pytorch-tutorial.html

from datetime import datetime
import argparse
import os
import torch
import torch.nn as nn
import torch.distributed as dist
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
from torch.profiler import profile, record_function, ProfilerActivity


class ConvNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc = nn.Linear(7*7*32, num_classes)
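        # 28x28 MNIST inputs are halved by each MaxPool2d (28 -> 14 -> 7),
        # and layer2 outputs 32 channels, hence 7*7*32 flattened features.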

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out


def train(num_epochs):
    dist.init_process_group(backend='nccl')

    torch.manual_seed(0)
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)

    verbose = dist.get_rank() == 0  # print only on global_rank==0

    prof = profile(
        schedule=torch.profiler.schedule(
            skip_first=10,
            wait=5,
            warmup=1,
            active=3,
            repeat=1),
        on_trace_ready=torch.profiler.tensorboard_trace_handler('./logs/profiler'),
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        record_shapes=True,   # record shapes of operator inputs
        profile_memory=True,  # track tensor memory allocation/deallocation
        with_stack=True,      # record source code information
        with_flops=True,      # estimate FLOPS of operators
    )
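    # The schedule above skips the first 10 steps entirely, then waits 5
    # steps, warms up for 1, and records 3 consecutive steps; repeat=1
    # stops profiling after that single cycle, so one trace covering
    # three training iterations is written to ./logs/profiler.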

    model = ConvNet().cuda()
    batch_size = 100

    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(), 1e-4)

    model = DistributedDataParallel(model, device_ids=[local_rank])

    train_dataset = MNIST(root='./data', train=True,
                          transform=transforms.ToTensor(), download=True)
    train_sampler = DistributedSampler(train_dataset)
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                              shuffle=False, num_workers=0, pin_memory=True,
                              sampler=train_sampler)
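    # shuffle=False because DistributedSampler already shuffles and shards
    # the data across ranks; for real training, train_sampler.set_epoch(epoch)
    # should be called each epoch so the shuffling differs between epochs.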

    start = datetime.now()
    prof.start()
    for epoch in range(num_epochs):
        tot_loss = 0
        for i, (images, labels) in enumerate(train_loader):
            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)

            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            prof.step()  # advance the profiler schedule by one step

            tot_loss += loss.item()

        if verbose:
            print('Epoch [{}/{}], average loss: {:.4f}'.format(
                epoch + 1,
                num_epochs,
                tot_loss / (i+1)))
    prof.stop()

    if verbose:
        print("Training completed in: " + str(datetime.now() - start))


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', default=2, type=int, metavar='N',
                        help='number of total epochs to run')
    args = parser.parse_args()

    train(args.epochs)


if __name__ == '__main__':
    main()
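
Once the run finishes, the traces written by tensorboard_trace_handler can be opened with TensorBoard (torch-tb-profiler plugin) via tensorboard --logdir=./logs/profiler. A plain-text summary can also be printed directly; a minimal sketch, to be placed after prof.stop() in the script above:

    # Print the ten operators that consumed the most GPU time, on rank 0
    # only, to avoid duplicated output from every DDP process:
    if verbose:
        print(prof.key_averages().table(sort_by="cuda_time_total",
                                        row_limit=10))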

run-ddp-gpu1-profiler.sh

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
#!/bin/bash
#SBATCH --account=project_2001659
#SBATCH --partition=gputest
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=10
#SBATCH --mem=64G
#SBATCH --time=15
#SBATCH --gres=gpu:v100:1

module purge
module load pytorch

# Old way with torch.distributed.run
# srun python3 -m torch.distributed.run --standalone --nnodes=1 --nproc_per_node=4 mnist_ddp.py --epochs=100

# New way with torchrun
srun torchrun --standalone --nnodes=1 --nproc_per_node=1 mnist_ddp_profiler.py --epochs=100
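
The script above profiles on a single V100. To profile the DDP code across several GPUs on one node, the GPU reservation and --nproc_per_node would be raised together; a sketch, assuming the cluster offers four-GPU nodes and a gpu partition alongside gputest:

#SBATCH --partition=gpu
#SBATCH --gres=gpu:v100:4

srun torchrun --standalone --nnodes=1 --nproc_per_node=4 mnist_ddp_profiler.py --epochs=100

Each rank then writes its own trace file into ./logs/profiler, and the TensorBoard profiler plugin lists them per worker.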
