|
327 | 327 | "opts.device = device\n",
|
328 | 328 | "encoder = pSp(opts)\n",
|
329 | 329 | "encoder.eval()\n",
|
330 |
| - "encoder.cuda()\n", |
| 330 | + "encoder = encoder.to(device)\n", |
331 | 331 | "\n",
|
332 | 332 | "# load extrinsic style code\n",
|
333 | 333 | "exstyles = np.load(os.path.join(MODEL_DIR, style_type, MODEL_PATHS[style_type+'-S'][\"name\"]), allow_pickle='TRUE').item()\n",
|
|
462 | 462 | "outputs": [],
|
463 | 463 | "source": [
|
464 | 464 | "if if_align_face:\n",
|
465 |
| - " I = transform(run_alignment(image_path)).unsqueeze(dim=0).cuda()\n", |
| 465 | + " I = transform(run_alignment(image_path)).unsqueeze(dim=0).to(device)\n", |
466 | 466 | "else:\n",
|
467 |
| - " I = F.adaptive_avg_pool2d(load_image(image_path).cuda(), 256)" |
| 467 | + " I = F.adaptive_avg_pool2d(load_image(image_path).to(device), 256)" |
468 | 468 | ]
|
469 | 469 | },
|
470 | 470 | {
|
|
612 | 612 | " z_plus_latent=True, return_z_plus_latent=True, resize=False) \n",
|
613 | 613 | " img_rec = torch.clamp(img_rec.detach(), -1, 1)\n",
|
614 | 614 | " \n",
|
615 |
| - " latent = torch.tensor(exstyles[stylename]).repeat(2,1,1).cuda()\n", |
| 615 | + " latent = torch.tensor(exstyles[stylename]).repeat(2,1,1).to(device)\n", |
616 | 616 | "    # latent[0] for both color and structure transfer and latent[1] for only structure transfer\n",
|
617 | 617 | " latent[1,7:18] = instyle[0,7:18]\n",
|
618 | 618 | " exstyle = generator.generator.style(latent.reshape(latent.shape[0]*latent.shape[1], latent.shape[2])).reshape(latent.shape)\n",
|
|
824 | 824 | ],
|
825 | 825 | "source": [
|
826 | 826 | "with torch.no_grad():\n",
|
827 |
| - " latent = torch.tensor(exstyles[stylename]).repeat(6,1,1).cuda()\n", |
828 |
| - " latent2 = torch.tensor(exstyles[stylename2]).repeat(6,1,1).cuda()\n", |
829 |
| - " fuse_weight = torch.arange(6).reshape(6,1,1).cuda() / 5.0\n", |
| 827 | + " latent = torch.tensor(exstyles[stylename]).repeat(6,1,1).to(device)\n", |
| 828 | + " latent2 = torch.tensor(exstyles[stylename2]).repeat(6,1,1).to(device)\n", |
| 829 | + " fuse_weight = torch.arange(6).reshape(6,1,1).to(device) / 5.0\n", |
830 | 830 | " fuse_latent = latent * fuse_weight + latent2 * (1-fuse_weight)\n",
|
831 | 831 | " exstyle = generator.generator.style(fuse_latent.reshape(fuse_latent.shape[0]*fuse_latent.shape[1], fuse_latent.shape[2])).reshape(fuse_latent.shape)\n",
|
832 | 832 | " \n",
|
|
871 | 871 | "batch = 6 # sample 6 style codes\n",
|
872 | 872 | "\n",
|
873 | 873 | "with torch.no_grad():\n",
|
874 |
| - " instyle = torch.randn(6, 512).cuda()\n", |
| 874 | + " instyle = torch.randn(6, 512).to(device)\n", |
875 | 875 | " # sample structure codes\n",
|
876 |
| - " res_in = icpts.icp.netT(torch.randn(batch, 128).cuda()).reshape(-1,7,512)\n", |
| 876 | + " res_in = icpts.icp.netT(torch.randn(batch, 128).to(device)).reshape(-1,7,512)\n", |
877 | 877 | " # sample color codes\n",
|
878 |
| - " ada_in = icptc.icp.netT(torch.randn(batch, 128).cuda()).reshape(-1,11,512)\n", |
| 878 | + " ada_in = icptc.icp.netT(torch.randn(batch, 128).to(device)).reshape(-1,11,512)\n", |
879 | 879 | "\n",
|
880 | 880 | " # concatenate two codes to form the complete extrinsic style code\n",
|
881 | 881 | " latent = torch.cat((res_in, ada_in), dim=1)\n",
|
|
0 commit comments