diff --git a/test_classification.py b/test_classification.py index 79ac09723..5e120aed5 100644 --- a/test_classification.py +++ b/test_classification.py @@ -38,11 +38,13 @@ def test(model, loader, num_class=40, vote_num=1): class_acc = np.zeros((num_class, 3)) for j, (points, target) in tqdm(enumerate(loader), total=len(loader)): + vote_pool = torch.zeros(target.size()[0], num_class) + if not args.use_cpu: - points, target = points.cuda(), target.cuda() + points, target, vote_pool = points.cuda(), target.cuda(), vote_pool.cuda() points = points.transpose(2, 1) - vote_pool = torch.zeros(target.size()[0], num_class).cuda() + for _ in range(vote_num): pred, _ = classifier(points) @@ -102,7 +104,9 @@ def log_string(str): if not args.use_cpu: classifier = classifier.cuda() - checkpoint = torch.load(str(experiment_dir) + '/checkpoints/best_model.pth') + torch_load_map_location = torch.device('cpu') if args.use_cpu else torch.device('cuda') + checkpoint = torch.load(str(experiment_dir) + '/checkpoints/best_model.pth', map_location=torch_load_map_location) + classifier.load_state_dict(checkpoint['model_state_dict']) with torch.no_grad():