@@ -104,10 +104,7 @@ def main():
 
     args.distributed = args.world_size > 1 or args.multiprocessing_distributed
 
-    if torch.cuda.is_available():
-        ngpus_per_node = torch.cuda.device_count()
-    else:
-        ngpus_per_node = 1
+    ngpus_per_node = torch.cuda.device_count()
     if args.multiprocessing_distributed:
         # Since we have ngpus_per_node processes per node, the total world_size
         # needs to be adjusted accordingly
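The simplified `+` line assumes at least one visible CUDA device. For context, a minimal sketch (not part of this diff, and `launch` is a hypothetical helper) of how `ngpus_per_node` is consumed further down in `main()` in this example:

# Illustrative sketch only; mirrors how the example spawns one worker per GPU.
import torch
import torch.multiprocessing as mp

def launch(args, main_worker):
    ngpus_per_node = torch.cuda.device_count()  # assumes CUDA is present
    if args.multiprocessing_distributed:
        # One process per GPU on this node, so the global world size
        # scales by the per-node GPU count.
        args.world_size = ngpus_per_node * args.world_size
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        main_worker(args.gpu, ngpus_per_node, args)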
@@ -144,33 +141,29 @@ def main_worker(gpu, ngpus_per_node, args):
     print("=> creating model '{}'".format(args.arch))
     model = models.__dict__[args.arch]()
 
-    if not torch.cuda.is_available() and not torch.backends.mps.is_available():
+    if not torch.cuda.is_available():
         print('using CPU, this will be slow')
     elif args.distributed:
         # For multiprocessing distributed, DistributedDataParallel constructor
         # should always set the single device scope, otherwise,
         # DistributedDataParallel will use all available devices.
-        if torch.cuda.is_available():
-            if args.gpu is not None:
-                torch.cuda.set_device(args.gpu)
-                model.cuda(args.gpu)
-                # When using a single GPU per process and per
-                # DistributedDataParallel, we need to divide the batch size
-                # ourselves based on the total number of GPUs of the current node.
-                args.batch_size = int(args.batch_size / ngpus_per_node)
-                args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
-                model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
-            else:
-                model.cuda()
-                # DistributedDataParallel will divide and allocate batch_size to all
-                # available GPUs if device_ids are not set
-                model = torch.nn.parallel.DistributedDataParallel(model)
-    elif args.gpu is not None and torch.cuda.is_available():
+        if args.gpu is not None:
+            torch.cuda.set_device(args.gpu)
+            model.cuda(args.gpu)
+            # When using a single GPU per process and per
+            # DistributedDataParallel, we need to divide the batch size
+            # ourselves based on the total number of GPUs of the current node.
+            args.batch_size = int(args.batch_size / ngpus_per_node)
+            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
+            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+        else:
+            model.cuda()
+            # DistributedDataParallel will divide and allocate batch_size to all
+            # available GPUs if device_ids are not set
+            model = torch.nn.parallel.DistributedDataParallel(model)
+    elif args.gpu is not None:
         torch.cuda.set_device(args.gpu)
         model = model.cuda(args.gpu)
-    elif torch.backends.mps.is_available():
-        device = torch.device("mps")
-        model = model.to(device)
     else:
         # DataParallel will divide and allocate batch_size to all available GPUs
         if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
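The retained `+` branch pins each worker process to one GPU and splits the per-node batch size and DataLoader workers across the node's GPUs. A small worked example of that arithmetic (values are illustrative, not from the diff):

# Hypothetical node with 4 GPUs and CLI totals of 256 / 8.
ngpus_per_node = 4
batch_size, workers = 256, 8
per_process_batch = int(batch_size / ngpus_per_node)                        # 64
per_process_workers = int((workers + ngpus_per_node - 1) / ngpus_per_node)  # ceil(8/4) = 2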
@@ -179,17 +172,8 @@ def main_worker(gpu, ngpus_per_node, args):
         else:
             model = torch.nn.DataParallel(model).cuda()
 
-    if torch.cuda.is_available():
-        if args.gpu:
-            device = torch.device('cuda:{}'.format(args.gpu))
-        else:
-            device = torch.device("cuda")
-    elif torch.backends.mps.is_available():
-        device = torch.device("mps")
-    else:
-        device = torch.device("cpu")
     # define loss function (criterion), optimizer, and learning rate scheduler
-    criterion = nn.CrossEntropyLoss().to(device)
+    criterion = nn.CrossEntropyLoss().cuda(args.gpu)
 
     optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                 momentum=args.momentum,
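With the explicit device-selection block gone, the loss is built directly on the worker's GPU. A minimal sketch of the intent (the helper name is hypothetical; `gpu` is either `None` for the current CUDA device or an integer index):

import torch
import torch.nn as nn

def build_criterion(gpu=None):
    # .cuda(None) places the module on the current CUDA device;
    # an int places it on that specific device.
    return nn.CrossEntropyLoss().cuda(gpu)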
@@ -204,7 +188,7 @@ def main_worker(gpu, ngpus_per_node, args):
             print("=> loading checkpoint '{}'".format(args.resume))
             if args.gpu is None:
                 checkpoint = torch.load(args.resume)
-            elif torch.cuda.is_available():
+            else:
                 # Map model to be loaded to specified single gpu.
                 loc = 'cuda:{}'.format(args.gpu)
                 checkpoint = torch.load(args.resume, map_location=loc)
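The `else:` branch now assumes CUDA whenever a specific GPU was requested; `map_location` remaps the saved storages onto that device. A hedged sketch of the same pattern in isolation (the helper name is hypothetical; the real example also restores best_acc1, optimizer, and scheduler state):

import torch

def load_resume_checkpoint(path, gpu=None):
    if gpu is None:
        return torch.load(path)
    # Remap tensors saved from any device onto this worker's GPU.
    return torch.load(path, map_location='cuda:{}'.format(gpu))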
@@ -318,13 +302,10 @@ def train(train_loader, model, criterion, optimizer, epoch, args):
         # measure data loading time
         data_time.update(time.time() - end)
 
-        if args.gpu is not None and torch.cuda.is_available():
+        if args.gpu is not None:
             images = images.cuda(args.gpu, non_blocking=True)
-        elif not args.gpu and torch.cuda.is_available():
+        if torch.cuda.is_available():
             target = target.cuda(args.gpu, non_blocking=True)
-        elif torch.backends.mps.is_available():
-            images = images.to('mps')
-            target = target.to('mps')
 
         # compute output
         output = model(images)
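The simplified branch copies each batch to the worker's GPU with `non_blocking=True`, which only overlaps the copy with compute when the host tensors are in pinned memory (the example's DataLoaders pass `pin_memory=True`). A standalone sketch of the same transfer (the helper name is hypothetical):

import torch

def move_batch(images, target, gpu=None):
    if gpu is not None:
        images = images.cuda(gpu, non_blocking=True)
    if torch.cuda.is_available():
        target = target.cuda(gpu, non_blocking=True)
    return images, target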
@@ -356,11 +337,8 @@ def run_validate(loader, base_progress=0):
            end = time.time()
            for i, (images, target) in enumerate(loader):
                i = base_progress + i
-                if args.gpu is not None and torch.cuda.is_available():
+                if args.gpu is not None:
                    images = images.cuda(args.gpu, non_blocking=True)
-                if torch.backends.mps.is_available():
-                    images = images.to('mps')
-                    target = target.to('mps')
                if torch.cuda.is_available():
                    target = target.cuda(args.gpu, non_blocking=True)
 
@@ -443,12 +421,7 @@ def update(self, val, n=1):
         self.avg = self.sum / self.count
 
     def all_reduce(self):
-        if torch.cuda.is_available():
-            device = torch.device("cuda")
-        elif torch.backends.mps.is_available():
-            device = torch.device("mps")
-        else:
-            device = torch.device("cpu")
+        device = "cuda" if torch.cuda.is_available() else "cpu"
         total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device)
         dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
         self.sum, self.count = total.tolist()
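The `+` line keeps the reduction buffer on CUDA when available, which the NCCL backend requires; on a CPU-only setup (e.g. Gloo) it falls back to a CPU tensor. A self-contained sketch of the same reduction (assumes the process group is already initialized; the function name is hypothetical):

import torch
import torch.distributed as dist

def all_reduce_sum_count(local_sum, local_count):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    total = torch.tensor([local_sum, local_count], dtype=torch.float32, device=device)
    dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
    global_sum, global_count = total.tolist()
    return global_sum, global_count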