|
312 | 312 | ")\n",
|
313 | 313 | "\n",
|
314 | 314 | "# load DualStyleGAN\n",
|
315 |
| - "generator = DualStyleGAN(1024, 512, 8, 2, res_index=6).cuda()\n", |
| 315 | + "generator = DualStyleGAN(1024, 512, 8, 2, res_index=6)\n", |
316 | 316 | "generator.eval()\n",
|
317 |
| - "ckpt = torch.load(os.path.join(MODEL_DIR, style_type, 'generator.pt'))\n", |
| 317 | + "ckpt = torch.load(os.path.join(MODEL_DIR, style_type, 'generator.pt'), map_location=lambda storage, loc: storage)\n", |
318 | 318 | "generator.load_state_dict(ckpt[\"g_ema\"])\n",
|
| 319 | + "generator = generator.to(device)\n", |
319 | 320 | "\n",
|
320 | 321 | "# load encoder\n",
|
321 | 322 | "model_path = os.path.join(MODEL_DIR, 'encoder.pt')\n",
|
322 | 323 | "ckpt = torch.load(model_path, map_location='cpu')\n",
|
323 | 324 | "opts = ckpt['opts']\n",
|
324 | 325 | "opts['checkpoint_path'] = model_path\n",
|
325 | 326 | "opts = Namespace(**opts)\n",
|
| 327 | + "opts.device = device\n", |
326 | 328 | "encoder = pSp(opts)\n",
|
327 | 329 | "encoder.eval()\n",
|
328 | 330 | "encoder.cuda()\n",
|
|
333 | 335 | "# load sampler network\n",
|
334 | 336 | "icptc = ICPTrainer(np.empty([0,512*11]), 128)\n",
|
335 | 337 | "icpts = ICPTrainer(np.empty([0,512*7]), 128)\n",
|
336 |
| - "ckpt = torch.load(os.path.join(MODEL_DIR, style_type, 'sampler.pt'))\n", |
| 338 | + "ckpt = torch.load(os.path.join(MODEL_DIR, style_type, 'sampler.pt'), map_location=lambda storage, loc: storage)\n", |
337 | 339 | "icptc.icp.netT.load_state_dict(ckpt['color'])\n",
|
338 | 340 | "icpts.icp.netT.load_state_dict(ckpt['structure'])\n",
|
| 341 | + "icptc.icp.netT = icptc.icp.netT.to(device)\n", |
| 342 | + "icpts.icp.netT = icpts.icp.netT.to(device)\n", |
339 | 343 | "\n",
|
340 | 344 | "print('Model successfully loaded!')"
|
341 | 345 | ]
|
|
0 commit comments