diff --git a/imagenet/main.py b/imagenet/main.py index cc32d50733..1644ef0577 100644 --- a/imagenet/main.py +++ b/imagenet/main.py @@ -270,7 +270,7 @@ def main_worker(gpu, ngpus_per_node, args): num_workers=args.workers, pin_memory=True, sampler=val_sampler) if args.evaluate: - validate(val_loader, model, criterion, args) + validate(val_loader, model, criterion, device, args) return for epoch in range(args.start_epoch, args.epochs): @@ -281,7 +281,7 @@ def main_worker(gpu, ngpus_per_node, args): train(train_loader, model, criterion, optimizer, epoch, device, args) # evaluate on validation set - acc1 = validate(val_loader, model, criterion, args) + acc1 = validate(val_loader, model, criterion, device, args) scheduler.step() @@ -347,21 +347,15 @@ def train(train_loader, model, criterion, optimizer, epoch, device, args): progress.display(i + 1) -def validate(val_loader, model, criterion, args): +def validate(val_loader, model, criterion, device, args): def run_validate(loader, base_progress=0): with torch.no_grad(): end = time.time() for i, (images, target) in enumerate(loader): i = base_progress + i - if args.gpu is not None and torch.cuda.is_available(): - images = images.cuda(args.gpu, non_blocking=True) - if torch.backends.mps.is_available(): - images = images.to('mps') - target = target.to('mps') - if torch.cuda.is_available(): - target = target.cuda(args.gpu, non_blocking=True) - + images = images.to(device, non_blocking=True) + target = target.to(device, non_blocking=True) # compute output output = model(images) loss = criterion(output, target)