Skip to content

Commit bfc8207

Browse files
使用 Colaboratory 创建
1 parent 7a936ef commit bfc8207

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

Diff for: notebooks/inference_playground.ipynb

+7-3
Original file line numberDiff line numberDiff line change
@@ -312,17 +312,19 @@
312312
")\n",
313313
"\n",
314314
"# load DualStyleGAN\n",
315-
"generator = DualStyleGAN(1024, 512, 8, 2, res_index=6).cuda()\n",
315+
"generator = DualStyleGAN(1024, 512, 8, 2, res_index=6)\n",
316316
"generator.eval()\n",
317-
"ckpt = torch.load(os.path.join(MODEL_DIR, style_type, 'generator.pt'))\n",
317+
"ckpt = torch.load(os.path.join(MODEL_DIR, style_type, 'generator.pt'), map_location=lambda storage, loc: storage)\n",
318318
"generator.load_state_dict(ckpt[\"g_ema\"])\n",
319+
"generator = generator.to(device)\n",
319320
"\n",
320321
"# load encoder\n",
321322
"model_path = os.path.join(MODEL_DIR, 'encoder.pt')\n",
322323
"ckpt = torch.load(model_path, map_location='cpu')\n",
323324
"opts = ckpt['opts']\n",
324325
"opts['checkpoint_path'] = model_path\n",
325326
"opts = Namespace(**opts)\n",
327+
"opts.device = device\n",
326328
"encoder = pSp(opts)\n",
327329
"encoder.eval()\n",
328330
"encoder.cuda()\n",
@@ -333,9 +335,11 @@
333335
"# load sampler network\n",
334336
"icptc = ICPTrainer(np.empty([0,512*11]), 128)\n",
335337
"icpts = ICPTrainer(np.empty([0,512*7]), 128)\n",
336-
"ckpt = torch.load(os.path.join(MODEL_DIR, style_type, 'sampler.pt'))\n",
338+
"ckpt = torch.load(os.path.join(MODEL_DIR, style_type, 'sampler.pt'), map_location=lambda storage, loc: storage)\n",
337339
"icptc.icp.netT.load_state_dict(ckpt['color'])\n",
338340
"icpts.icp.netT.load_state_dict(ckpt['structure'])\n",
341+
"icptc.icp.netT = icptc.icp.netT.to(device)\n",
342+
"icpts.icp.netT = icpts.icp.netT.to(device)\n",
339343
"\n",
340344
"print('Model successfully loaded!')"
341345
]

0 commit comments

Comments
 (0)