Skip to content

Commit 0044eb2

Browse files
Created with Colaboratory
1 parent bfc8207 commit 0044eb2

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

Diff for: notebooks/inference_playground.ipynb

+10-10
Original file line number | Diff line number | Diff line change
@@ -327,7 +327,7 @@
327327
"opts.device = device\n",
328328
"encoder = pSp(opts)\n",
329329
"encoder.eval()\n",
330-
"encoder.cuda()\n",
330+
"encoder = encoder.to(device)\n",
331331
"\n",
332332
"# load extrinsic style code\n",
333333
"exstyles = np.load(os.path.join(MODEL_DIR, style_type, MODEL_PATHS[style_type+'-S'][\"name\"]), allow_pickle='TRUE').item()\n",
@@ -462,9 +462,9 @@
462462
"outputs": [],
463463
"source": [
464464
"if if_align_face:\n",
465-
" I = transform(run_alignment(image_path)).unsqueeze(dim=0).cuda()\n",
465+
" I = transform(run_alignment(image_path)).unsqueeze(dim=0).to(device)\n",
466466
"else:\n",
467-
" I = F.adaptive_avg_pool2d(load_image(image_path).cuda(), 256)"
467+
" I = F.adaptive_avg_pool2d(load_image(image_path).to(device), 256)"
468468
]
469469
},
470470
{
@@ -612,7 +612,7 @@
612612
" z_plus_latent=True, return_z_plus_latent=True, resize=False) \n",
613613
" img_rec = torch.clamp(img_rec.detach(), -1, 1)\n",
614614
" \n",
615-
" latent = torch.tensor(exstyles[stylename]).repeat(2,1,1).cuda()\n",
615+
" latent = torch.tensor(exstyles[stylename]).repeat(2,1,1).to(device)\n",
616616
" # latent[0] for both color and structrue transfer and latent[1] for only structrue transfer\n",
617617
" latent[1,7:18] = instyle[0,7:18]\n",
618618
" exstyle = generator.generator.style(latent.reshape(latent.shape[0]*latent.shape[1], latent.shape[2])).reshape(latent.shape)\n",
@@ -824,9 +824,9 @@
824824
],
825825
"source": [
826826
"with torch.no_grad():\n",
827-
" latent = torch.tensor(exstyles[stylename]).repeat(6,1,1).cuda()\n",
828-
" latent2 = torch.tensor(exstyles[stylename2]).repeat(6,1,1).cuda()\n",
829-
" fuse_weight = torch.arange(6).reshape(6,1,1).cuda() / 5.0\n",
827+
" latent = torch.tensor(exstyles[stylename]).repeat(6,1,1).to(device)\n",
828+
" latent2 = torch.tensor(exstyles[stylename2]).repeat(6,1,1).to(device)\n",
829+
" fuse_weight = torch.arange(6).reshape(6,1,1).to(device) / 5.0\n",
830830
" fuse_latent = latent * fuse_weight + latent2 * (1-fuse_weight)\n",
831831
" exstyle = generator.generator.style(fuse_latent.reshape(fuse_latent.shape[0]*fuse_latent.shape[1], fuse_latent.shape[2])).reshape(fuse_latent.shape)\n",
832832
" \n",
@@ -871,11 +871,11 @@
871871
"batch = 6 # sample 6 style codes\n",
872872
"\n",
873873
"with torch.no_grad():\n",
874-
" instyle = torch.randn(6, 512).cuda()\n",
874+
" instyle = torch.randn(6, 512).to(device)\n",
875875
" # sample structure codes\n",
876-
" res_in = icpts.icp.netT(torch.randn(batch, 128).cuda()).reshape(-1,7,512)\n",
876+
" res_in = icpts.icp.netT(torch.randn(batch, 128).to(device)).reshape(-1,7,512)\n",
877877
" # sample color codes\n",
878-
" ada_in = icptc.icp.netT(torch.randn(batch, 128).cuda()).reshape(-1,11,512)\n",
878+
" ada_in = icptc.icp.netT(torch.randn(batch, 128).to(device)).reshape(-1,11,512)\n",
879879
"\n",
880880
" # concatenate two codes to form the complete extrinsic style code\n",
881881
" latent = torch.cat((res_in, ada_in), dim=1)\n",

0 commit comments

Comments (0)