Skip to content

Commit 7a936ef

Browse files
Update generate.py
1 parent 190db4f commit 7a936ef

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

generate.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,17 +39,20 @@ def parse(self):
3939
args = parser.parse()
4040
print('*'*98)
4141

42-
generator = DualStyleGAN(1024, 512, 8, 2, res_index=6).to(device)
42+
generator = DualStyleGAN(1024, 512, 8, 2, res_index=6)
4343
generator.eval()
4444
icptc = ICPTrainer(np.empty([0,512*11]), 128)
4545
icpts = ICPTrainer(np.empty([0,512*7]), 128)
4646

47-
ckpt = torch.load(os.path.join(args.model_path, args.style, args.model_name))
47+
ckpt = torch.load(os.path.join(args.model_path, args.style, args.model_name), map_location=lambda storage, loc: storage)
4848
generator.load_state_dict(ckpt["g_ema"])
49+
generator = generator.to(device)
4950

50-
ckpt = torch.load(os.path.join(args.model_path, args.style, args.sampler_name))
51+
ckpt = torch.load(os.path.join(args.model_path, args.style, args.sampler_name), map_location=lambda storage, loc: storage)
5152
icptc.icp.netT.load_state_dict(ckpt['color'])
5253
icpts.icp.netT.load_state_dict(ckpt['structure'])
54+
icptc.icp.netT = icptc.icp.netT.to(device)
55+
icpts.icp.netT = icpts.icp.netT.to(device)
5356

5457
print('Load models successfully!')
5558

@@ -81,4 +84,4 @@ def parse(self):
8184
for i in range(args.batch):
8285
save_image(img_gen[i].cpu(), os.path.join(args.output_path, args.name+'_%02d.jpg'%(i)))
8386

84-
print('Save images successfully!')
87+
print('Save images successfully!')

0 commit comments

Comments
 (0)