# Based on: https://github.com/pytorch/examples/blob/master/mnist/main.py
import os
import argparse
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms


from torch.optim.lr_scheduler import StepLR

import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    CPUOffload,
    BackwardPrefetch,
)
from torch.distributed.fsdp.wrap import (
    size_based_auto_wrap_policy,
    enable_wrap,
    wrap,
)

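# setup()/cleanup() create and tear down the default process group for each rank.
# MASTER_ADDR/MASTER_PORT are hardcoded for a single-node run; the NCCL backend
# assumes one CUDA device per spawned process.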
def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

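# The same small ConvNet used by the upstream MNIST example: two conv layers,
# two dropout layers, and two fully connected layers with a log-softmax output.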
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

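# train() runs one epoch on this rank's shard of the data. Each rank accumulates
# its summed loss and sample count in ddp_loss, then all_reduce produces the
# global totals so rank 0 can print the true average training loss.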
def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None):
    model.train()
    ddp_loss = torch.zeros(2).to(rank)
    if sampler:
        sampler.set_epoch(epoch)
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(rank), target.to(rank)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target, reduction='sum')
        loss.backward()
        optimizer.step()
        ddp_loss[0] += loss.item()
        ddp_loss[1] += len(data)

    dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
    if rank == 0:
        print('Train Epoch: {} \tLoss: {:.6f}'.format(epoch, ddp_loss[0] / ddp_loss[1]))

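# test() evaluates on this rank's shard of the test set. ddp_loss holds the summed
# loss, the number of correct predictions, and the number of samples seen, which
# are all_reduced across ranks before rank 0 prints the aggregate metrics.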
def test(model, rank, world_size, test_loader):
    model.eval()
    correct = 0
    ddp_loss = torch.zeros(3).to(rank)
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(rank), target.to(rank)
            output = model(data)
            ddp_loss[0] += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            ddp_loss[1] += pred.eq(target.view_as(pred)).sum().item()
            ddp_loss[2] += len(data)

    dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)

    if rank == 0:
        test_loss = ddp_loss[0] / ddp_loss[2]
        print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
            test_loss, int(ddp_loss[1]), int(ddp_loss[2]),
            100. * ddp_loss[1] / ddp_loss[2]))

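# fsdp_main() is the per-process entry point spawned once per GPU. It joins the
# process group, builds DistributedSamplers so each rank sees a disjoint shard of
# MNIST, wraps the model in FSDP, and runs the train/test loop.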
def fsdp_main(rank, world_size, args):
    setup(rank, world_size)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    dataset1 = datasets.MNIST('../data', train=True, download=True,
                              transform=transform)
    dataset2 = datasets.MNIST('../data', train=False,
                              transform=transform)

    sampler1 = DistributedSampler(dataset1, rank=rank, num_replicas=world_size, shuffle=True)
    sampler2 = DistributedSampler(dataset2, rank=rank, num_replicas=world_size)

    train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1}
    test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2}
    cuda_kwargs = {'num_workers': 2,
                   'pin_memory': True,
                   'shuffle': False}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

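    # size_based_auto_wrap_policy tells FSDP to wrap any submodule whose parameters
    # exceed min_num_params (20k here) in its own FSDP unit, so sharding happens at
    # a finer granularity than the whole model. CPUOffload(offload_params=True)
    # keeps sharded parameters in host memory between uses. Note: recent PyTorch
    # releases expose this constructor argument as auto_wrap_policy (the name used
    # below); very old releases called it fsdp_auto_wrap_policy. The imported
    # BackwardPrefetch is unused here, but could optionally be passed, e.g.
    # backward_prefetch=BackwardPrefetch.BACKWARD_PRE, to overlap the next
    # parameter all-gather with the current backward computation.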
    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
    my_auto_wrap_policy = functools.partial(
        size_based_auto_wrap_policy, min_num_params=20000
    )
    torch.cuda.set_device(rank)

    init_start_event = torch.cuda.Event(enable_timing=True)
    init_end_event = torch.cuda.Event(enable_timing=True)

    model = Net().to(rank)

    model = FSDP(model,
                 auto_wrap_policy=my_auto_wrap_policy,
                 cpu_offload=CPUOffload(offload_params=True))

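    # The optimizer must be constructed after the model is wrapped in FSDP, so it
    # holds references to the flattened, sharded parameters rather than the
    # original unwrapped ones.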
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    init_start_event.record()
    for epoch in range(1, args.epochs + 1):
        train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1)
        test(model, rank, world_size, test_loader)
        scheduler.step()

    init_end_event.record()

    if rank == 0:
        # make sure the end event has completed before reading the timer
        init_end_event.synchronize()
        print(f"CUDA event elapsed time: {init_start_event.elapsed_time(init_end_event) / 1000} sec")
        print(f"{model}")

    if args.save_model:
        # use a barrier to make sure training is done on all ranks
        dist.barrier()
        states = model.state_dict()
        if rank == 0:
            torch.save(states, "mnist_cnn.pt")

    cleanup()

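# Standard single-node launcher: parse the same CLI flags as the upstream MNIST
# example, then spawn one fsdp_main process per visible GPU.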
if __name__ == '__main__':
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()

    torch.manual_seed(args.seed)

    WORLD_SIZE = torch.cuda.device_count()
    mp.spawn(fsdp_main,
             args=(WORLD_SIZE, args),
             nprocs=WORLD_SIZE,
             join=True)