
Commit 10e8119

added fsdp mnist example
1 parent 26de419 commit 10e8119

20 files changed: +205 -0 lines
13 files renamed without changes.

Diff for: distributed/fsdp/fsdp-mnist/README.md (+7)

@@ -0,0 +1,7 @@

## FSDP MNIST

To run a simple MNIST example with FSDP:

```bash
python fsdp_mnist.py
```
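Note: the script spawns one process per available CUDA device via `torch.multiprocessing.spawn` and initializes the process group with the NCCL backend, so it requires a machine with at least one GPU.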

Diff for: distributed/fsdp/fsdp-mnist/fsdp_mnist.py (+198)

@@ -0,0 +1,198 @@
# Based on: https://github.com/pytorch/examples/blob/master/mnist/main.py
import os
import argparse
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

from torch.optim.lr_scheduler import StepLR

import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    CPUOffload,
    BackwardPrefetch,
)
from torch.distributed.fsdp.wrap import (
    size_based_auto_wrap_policy,
    enable_wrap,
    wrap,
)

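# Each spawned worker joins the same process group; MASTER_ADDR/MASTER_PORT
# tell the ranks where to rendezvous for initialization.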
def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()

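# The ConvNet from the upstream MNIST example, unchanged; FSDP shards its
# parameters according to the auto-wrap policy configured in fsdp_main().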
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

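# ddp_loss holds (summed loss, sample count); all_reduce sums it across ranks
# so rank 0 can report the global average training loss for the epoch.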
def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None):
    model.train()
    ddp_loss = torch.zeros(2).to(rank)
    if sampler:
        sampler.set_epoch(epoch)
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(rank), target.to(rank)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target, reduction='sum')
        loss.backward()
        optimizer.step()
        ddp_loss[0] += loss.item()
        ddp_loss[1] += len(data)

    dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
    if rank == 0:
        print('Train Epoch: {} \tLoss: {:.6f}'.format(epoch, ddp_loss[0] / ddp_loss[1]))

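# Evaluation mirrors train(): (loss sum, correct count, sample count) are
# accumulated locally and all-reduced so rank 0 prints global metrics.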
def test(model, rank, world_size, test_loader):
    model.eval()
    ddp_loss = torch.zeros(3).to(rank)
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(rank), target.to(rank)
            output = model(data)
            ddp_loss[0] += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            ddp_loss[1] += pred.eq(target.view_as(pred)).sum().item()
            ddp_loss[2] += len(data)

    dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)

    if rank == 0:
        test_loss = ddp_loss[0] / ddp_loss[2]
        print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
            test_loss, int(ddp_loss[1]), int(ddp_loss[2]),
            100. * ddp_loss[1] / ddp_loss[2]))

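# Entry point for each spawned process; `rank` is the process index assigned
# by mp.spawn and doubles as the CUDA device index for this worker.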
def fsdp_main(rank, world_size, args):
    setup(rank, world_size)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    dataset1 = datasets.MNIST('../data', train=True, download=True,
                              transform=transform)
    dataset2 = datasets.MNIST('../data', train=False,
                              transform=transform)

    sampler1 = DistributedSampler(dataset1, rank=rank, num_replicas=world_size, shuffle=True)
    sampler2 = DistributedSampler(dataset2, rank=rank, num_replicas=world_size)

    train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1}
    test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2}
    cuda_kwargs = {'num_workers': 2,
                   'pin_memory': True,
                   'shuffle': False}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
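    # Submodules holding at least min_num_params parameters are wrapped in
    # their own FSDP unit, so layers are gathered and freed individually
    # rather than all at once.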
    my_auto_wrap_policy = functools.partial(
        size_based_auto_wrap_policy, min_num_params=20000
    )
    torch.cuda.set_device(rank)

    init_start_event = torch.cuda.Event(enable_timing=True)
    init_end_event = torch.cuda.Event(enable_timing=True)

    model = Net().to(rank)

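    # Wrap the model in FSDP. CPUOffload(offload_params=True) keeps the
    # sharded parameters on CPU, moving them to GPU only around each use,
    # trading throughput for GPU memory.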
    model = FSDP(model,
        auto_wrap_policy=my_auto_wrap_policy,
        cpu_offload=CPUOffload(offload_params=True)
    )

    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    init_start_event.record()
    for epoch in range(1, args.epochs + 1):
        train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1)
        test(model, rank, world_size, test_loader)
        scheduler.step()

    init_end_event.record()

    if rank == 0:
        print(f"CUDA event elapsed time: {init_start_event.elapsed_time(init_end_event) / 1000}sec")
        print(f"{model}")

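    # Every rank participates in state_dict() (it gathers the sharded
    # parameters into a full state dict); only rank 0 writes the checkpoint.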
    if args.save_model:
        # use a barrier to make sure training is done on all ranks
        dist.barrier()
        states = model.state_dict()
        if rank == 0:
            torch.save(states, "mnist_cnn.pt")

    cleanup()

if __name__ == '__main__':
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()

    torch.manual_seed(args.seed)

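    # Spawn one worker per visible GPU; mp.spawn passes each process its
    # index as the first argument (the `rank` parameter of fsdp_main).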
    WORLD_SIZE = torch.cuda.device_count()
    mp.spawn(fsdp_main,
             args=(WORLD_SIZE, args),
             nprocs=WORLD_SIZE,
             join=True)
