Skip to content

Commit b93acba

Browse files
support w+ psp encoder
1 parent 1309f94 commit b93acba

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

style_transfer.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -117,11 +117,13 @@ def run_alignment(args):
117117

118118
stylename = list(exstyles.keys())[args.style_id]
119119
latent = torch.tensor(exstyles[stylename]).to(device)
120-
if args.preserve_color:
120+
if args.preserve_color and not args.wplus:
121121
latent[:,7:18] = instyle[:,7:18]
122122
# extrinsic styte code
123123
exstyle = generator.generator.style(latent.reshape(latent.shape[0]*latent.shape[1], latent.shape[2])).reshape(latent.shape)
124-
124+
if args.preserve_color and args.wplus:
125+
exstyle[:,7:18] = instyle[:,7:18]
126+
125127
# load style image if it exists
126128
S = None
127129
if os.path.exists(os.path.join(args.data_path, args.style, 'images/train', stylename)):

0 commit comments

Comments
 (0)