|
327 | 327 | "opts.device = device\n", |
328 | 328 | "encoder = pSp(opts)\n", |
329 | 329 | "encoder.eval()\n", |
330 | | - "encoder.cuda()\n", |
| 330 | + "encoder = encoder.to(device)\n", |
331 | 331 | "\n", |
332 | 332 | "# load extrinsic style code\n", |
333 | 333 | "exstyles = np.load(os.path.join(MODEL_DIR, style_type, MODEL_PATHS[style_type+'-S'][\"name\"]), allow_pickle='TRUE').item()\n", |
|
462 | 462 | "outputs": [], |
463 | 463 | "source": [ |
464 | 464 | "if if_align_face:\n", |
465 | | - " I = transform(run_alignment(image_path)).unsqueeze(dim=0).cuda()\n", |
| 465 | + " I = transform(run_alignment(image_path)).unsqueeze(dim=0).to(device)\n", |
466 | 466 | "else:\n", |
467 | | - " I = F.adaptive_avg_pool2d(load_image(image_path).cuda(), 256)" |
| 467 | + " I = F.adaptive_avg_pool2d(load_image(image_path).to(device), 256)" |
468 | 468 | ] |
469 | 469 | }, |
470 | 470 | { |
|
612 | 612 | " z_plus_latent=True, return_z_plus_latent=True, resize=False) \n", |
613 | 613 | " img_rec = torch.clamp(img_rec.detach(), -1, 1)\n", |
614 | 614 | " \n", |
615 | | - " latent = torch.tensor(exstyles[stylename]).repeat(2,1,1).cuda()\n", |
| 615 | + " latent = torch.tensor(exstyles[stylename]).repeat(2,1,1).to(device)\n", |
616 | 616 | " # latent[0] for both color and structrue transfer and latent[1] for only structrue transfer\n", |
617 | 617 | " latent[1,7:18] = instyle[0,7:18]\n", |
618 | 618 | " exstyle = generator.generator.style(latent.reshape(latent.shape[0]*latent.shape[1], latent.shape[2])).reshape(latent.shape)\n", |
|
824 | 824 | ], |
825 | 825 | "source": [ |
826 | 826 | "with torch.no_grad():\n", |
827 | | - " latent = torch.tensor(exstyles[stylename]).repeat(6,1,1).cuda()\n", |
828 | | - " latent2 = torch.tensor(exstyles[stylename2]).repeat(6,1,1).cuda()\n", |
829 | | - " fuse_weight = torch.arange(6).reshape(6,1,1).cuda() / 5.0\n", |
| 827 | + " latent = torch.tensor(exstyles[stylename]).repeat(6,1,1).to(device)\n", |
| 828 | + " latent2 = torch.tensor(exstyles[stylename2]).repeat(6,1,1).to(device)\n", |
| 829 | + " fuse_weight = torch.arange(6).reshape(6,1,1).to(device) / 5.0\n", |
830 | 830 | " fuse_latent = latent * fuse_weight + latent2 * (1-fuse_weight)\n", |
831 | 831 | " exstyle = generator.generator.style(fuse_latent.reshape(fuse_latent.shape[0]*fuse_latent.shape[1], fuse_latent.shape[2])).reshape(fuse_latent.shape)\n", |
832 | 832 | " \n", |
|
871 | 871 | "batch = 6 # sample 6 style codes\n", |
872 | 872 | "\n", |
873 | 873 | "with torch.no_grad():\n", |
874 | | - " instyle = torch.randn(6, 512).cuda()\n", |
| 874 | + " instyle = torch.randn(6, 512).to(device)\n", |
875 | 875 | " # sample structure codes\n", |
876 | | - " res_in = icpts.icp.netT(torch.randn(batch, 128).cuda()).reshape(-1,7,512)\n", |
| 876 | + " res_in = icpts.icp.netT(torch.randn(batch, 128).to(device)).reshape(-1,7,512)\n", |
877 | 877 | " # sample color codes\n", |
878 | | - " ada_in = icptc.icp.netT(torch.randn(batch, 128).cuda()).reshape(-1,11,512)\n", |
| 878 | + " ada_in = icptc.icp.netT(torch.randn(batch, 128).to(device)).reshape(-1,11,512)\n", |
879 | 879 | "\n", |
880 | 880 | " # concatenate two codes to form the complete extrinsic style code\n", |
881 | 881 | " latent = torch.cat((res_in, ada_in), dim=1)\n", |
|
0 commit comments