Skip to content

Commit

Permalink
Bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
LLXXTTT committed Jun 26, 2020
1 parent 2d0c3a5 commit 543ff9f
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
3 changes: 2 additions & 1 deletion eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ def eval(args):
torch.set_grad_enabled(False)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = get_model(args.model, args.dataset)
model_state = torch.load(args.checkpoint)['model_state']
model_state = torch.load(args.checkpoint,
map_location=device)['model_state']
model.load_state_dict(model_state)
model.to(device)
model.eval()
Expand Down
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def train(args):
if os.path.isfile(args.resume):
print("Loading model and optimizer from checkpoint '{}'".format\
(args.resume))
checkpoint = torch.load(args.resume)
checkpoint = torch.load(args.resume, map_location=device)
model.load_state_dict(checkpoint['model_state'])
optimizer.load_state_dict(checkpoint['optimizer_state'])
print("Loaded checkpoint '{}' (epoch{})".format(args.resume,
Expand Down

0 comments on commit 543ff9f

Please sign in to comment.