@@ -39,17 +39,20 @@ def parse(self):
39
39
args = parser .parse ()
40
40
print ('*' * 98 )
41
41
42
- generator = DualStyleGAN (1024 , 512 , 8 , 2 , res_index = 6 ). to ( device )
42
+ generator = DualStyleGAN (1024 , 512 , 8 , 2 , res_index = 6 )
43
43
generator .eval ()
44
44
icptc = ICPTrainer (np .empty ([0 ,512 * 11 ]), 128 )
45
45
icpts = ICPTrainer (np .empty ([0 ,512 * 7 ]), 128 )
46
46
47
- ckpt = torch .load (os .path .join (args .model_path , args .style , args .model_name ))
47
+ ckpt = torch .load (os .path .join (args .model_path , args .style , args .model_name ), map_location = lambda storage , loc : storage )
48
48
generator .load_state_dict (ckpt ["g_ema" ])
49
+ generator = generator .to (device )
49
50
50
- ckpt = torch .load (os .path .join (args .model_path , args .style , args .sampler_name ))
51
+ ckpt = torch .load (os .path .join (args .model_path , args .style , args .sampler_name ), map_location = lambda storage , loc : storage )
51
52
icptc .icp .netT .load_state_dict (ckpt ['color' ])
52
53
icpts .icp .netT .load_state_dict (ckpt ['structure' ])
54
+ icptc .icp .netT = icptc .icp .netT .to (device )
55
+ icpts .icp .netT = icpts .icp .netT .to (device )
53
56
54
57
print ('Load models successfully!' )
55
58
@@ -81,4 +84,4 @@ def parse(self):
81
84
for i in range (args .batch ):
82
85
save_image (img_gen [i ].cpu (), os .path .join (args .output_path , args .name + '_%02d.jpg' % (i )))
83
86
84
- print ('Save images successfully!' )
87
+ print ('Save images successfully!' )
0 commit comments