import argparse
import os

import cv2
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

from src.facepose.mp_utils import LMKExtractor
from src.facepose.draw_utils import FaceMeshVisualizer
from src.facepose.motion_utils import motion_sync
from src.facematting.u2net_matting import U2NET
from src.decalib.deca import DECA
from src.decalib.utils import util
from src.decalib.utils.config import cfg as deca_cfg
from src.decalib.utils.tensor_cropper import transform_points


class FaceMatting:
    def __init__(self) -> None:
        self.net = U2NET(3, 1).cuda()
        self.net.load_state_dict(torch.load("./src/facematting/u2net_human_seg.pth"))
        self.net.eval()

    def portrait_matting(self, rgb_image):
        # Resize to the U2-Net input size, add a batch dim, and apply
        # ImageNet mean/std normalization per channel.
        rgb_image = cv2.resize(rgb_image, (320, 320))[None] / 255
        rgb_image[..., 0] = (rgb_image[..., 0] - 0.485) / 0.229
        rgb_image[..., 1] = (rgb_image[..., 1] - 0.456) / 0.224
        rgb_image[..., 2] = (rgb_image[..., 2] - 0.406) / 0.225
        rgb_image_th = torch.tensor(rgb_image, dtype=torch.float32).cuda().permute(0, 3, 1, 2)
        with torch.no_grad():
            d1, d2, d3, d4, d5, d6, d7 = self.net(rgb_image_th)
        # Min-max normalize the network's primary output, then binarize at 0.5.
        pred = d1[:, 0, :, :]
        ma = torch.max(pred)
        mi = torch.min(pred)
        alpha = (pred - mi) / (ma - mi)
        alpha = alpha.detach().cpu().numpy()[0]
        alpha[alpha > 0.5] = 255
        alpha[alpha <= 0.5] = 0
        alpha = np.dstack([alpha, alpha, alpha])
        alpha = cv2.resize(alpha, (512, 512))
        # Dilate slightly so the mask fully covers the hair/body silhouette.
        alpha = cv2.dilate(alpha, np.ones([7, 7], np.uint8))
        return alpha
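
# A minimal usage sketch for FaceMatting (the "portrait.png" path is
# hypothetical; assumes a CUDA device and the bundled u2net_human_seg.pth
# weights):
#
#     fm = FaceMatting()
#     rgb = np.array(Image.open("portrait.png").convert("RGB"))
#     alpha = fm.portrait_matting(rgb)  # (512, 512, 3) hard mask in {0, 255}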

class FaceImageRender:
    def __init__(self) -> None:
        # Init DECA and the FLAME face-region masks used during rendering.
        self.deca = DECA(config=deca_cfg)
        f_mask = np.load('./src/decalib/data/FLAME_masks_face-id.pkl', allow_pickle=True, encoding='latin1')
        v_mask = np.load('./src/decalib/data/FLAME_masks.pkl', allow_pickle=True, encoding='latin1')
        self.mask = {
            'v_mask': v_mask['face'].tolist(),
            'f_mask': f_mask['face'].tolist()
        }

    def image_to_3dcoeff(self, rgb_image):
        with torch.no_grad():
            codedict, detected_flag = self.deca.img_to_3dcoeff(rgb_image)
        return codedict

    def render_shape(self, shape, exp, pose, cam, light, tform, h, w):
        with torch.no_grad():
            # All parameters come from a DECA codedict.
            verts, landmarks2d, landmarks3d = self.deca.flame(shape_params=shape, expression_params=exp, pose_params=pose)

            # Weak-perspective projection; flip y/z into image coordinates.
            trans_verts = util.batch_orth_proj(verts, cam)
            trans_verts[:, :, 1:] = -trans_verts[:, :, 1:]

            # Map points from DECA's crop space back to the original frame.
            points_scale = [self.deca.image_size, self.deca.image_size]
            trans_verts = transform_points(trans_verts, tform, points_scale, [h, w])

            shape_images, _, grid, alpha_images, albedo_images = self.deca.render.render_shape(verts, trans_verts, h=h, w=w, lights=light, images=None, return_grid=True, mask=self.mask)
            shape_images = shape_images.permute(0, 2, 3, 1).clamp(0, 1).detach().cpu().numpy()[0] * 255
            albedo_images = albedo_images.permute(0, 2, 3, 1).clamp(0, 1).detach().cpu().numpy()[0] * 255
        return shape_images, albedo_images
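
    # Note on the projection above: util.batch_orth_proj applies DECA's
    # weak-perspective camera, roughly
    #     x' = s * (x_xy + t),  with cam = [s, tx, ty],
    # and the subsequent sign flip converts FLAME's y-up camera frame to
    # image coordinates (y grows downward). transform_points then undoes the
    # face-crop similarity transform (tform) so the mesh lands on the
    # original (h, w) frame.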

    def render_shape_with_light(self, codedict, target_light=None):
        # Re-render the captured face under its own light by default, or
        # under an externally supplied lighting code.
        if target_light is None:
            target_light = codedict["light"]
        shape, exp, pose = codedict["shape"], codedict["exp"], codedict["pose"]
        cam, tform, h, w = codedict["cam"], codedict["tform"], codedict["height"], codedict["width"]
        shape_image, _ = self.render_shape(shape, exp, pose, cam, target_light, tform, h, w)
        return shape_image

    def render_motion_single(self, image):
        codedict = self.image_to_3dcoeff(image)
        shading = self.render_shape_with_light(codedict)
        return shading

    def render_motion_single_with_light(self, image, target_light_image):
        codedict = self.image_to_3dcoeff(image)
        target_light = self.image_to_3dcoeff(target_light_image)["light"]
        shading = self.render_shape_with_light(codedict, target_light=target_light)
        return shading

    def render_motion_sync_relative(self, ref_image, driver_frames, target_light_image):
        ref_codedict = self.image_to_3dcoeff(ref_image)
        target_light = self.image_to_3dcoeff(target_light_image)["light"]

        drv_codedict_list = []
        shading_frames = []
        for drv_frm in tqdm(driver_frames):
            drv_codedict = self.image_to_3dcoeff(drv_frm)
            drv_codedict_list.append(drv_codedict)

        # Alternative anchor selection: pick the driver frame whose head pose
        # is closest to the reference, instead of the first frame.
        # best_dist = 10000
        # best_pose = None
        # for idx, drv_codedict in enumerate(drv_codedict_list):
        #     dist = torch.mean(torch.abs(ref_codedict["pose"] - drv_codedict["pose"]))
        #     if dist < best_dist:
        #         best_dist = dist
        #         best_pose = drv_codedict["pose"]
        best_pose = drv_codedict_list[0]["pose"]
        best_exp = drv_codedict_list[0]["exp"]
        for drv_codedict in drv_codedict_list:
            # Transfer only the driver's motion deltas onto the reference.
            relative_pose = drv_codedict["pose"] - best_pose + ref_codedict["pose"]
            relative_exp = drv_codedict["exp"] - best_exp + ref_codedict["exp"]
            shape, exp, pose = ref_codedict["shape"], relative_exp, relative_pose
            cam, tform, h, w = ref_codedict["cam"], ref_codedict["tform"], ref_codedict["height"], ref_codedict["width"]
            shape_image, _ = self.render_shape(shape, exp, pose, cam, target_light, tform, h, w)
            shading_frames.append(shape_image)
        return shading_frames
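
    # The relative mode above retargets motion as
    #     pose_t = drv_pose_t - drv_pose_0 + ref_pose
    #     exp_t  = drv_exp_t  - drv_exp_0  + ref_exp
    # i.e. the first driver frame acts as the neutral anchor and only the
    # driver's deltas are applied to the reference identity, which avoids
    # snapping the head to the driver's absolute pose on frame one.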

    def render_motion_sync(self, ref_image, driver_frames, target_light_image):
        # Absolute transfer: the driver's expression and pose are applied
        # directly, with identity, camera, and crop taken from the reference.
        ref_codedict = self.image_to_3dcoeff(ref_image)
        target_light = self.image_to_3dcoeff(target_light_image)["light"]

        drv_codedict_list = []
        shading_frames = []
        for drv_frm in tqdm(driver_frames):
            drv_codedict = self.image_to_3dcoeff(drv_frm)
            drv_codedict_list.append(drv_codedict)

        for drv_codedict in drv_codedict_list:
            shape, exp, pose = ref_codedict["shape"], drv_codedict["exp"], drv_codedict["pose"]
            cam, tform, h, w = ref_codedict["cam"], ref_codedict["tform"], ref_codedict["height"], ref_codedict["width"]
            shape_image, _ = self.render_shape(shape, exp, pose, cam, target_light, tform, h, w)
            shading_frames.append(shape_image)
        return shading_frames
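
# A minimal usage sketch for FaceImageRender (ref_img, drv_frames, and
# light_img are hypothetical 512x512 uint8 RGB arrays; assumes the DECA
# assets are in place):
#
#     fir = FaceImageRender()
#     shading = fir.render_motion_sync_relative(ref_img, drv_frames, light_img)
#     # -> one (512, 512, 3) shading image per driver frame, lit by light_img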


class FaceKPDetector:
    def __init__(self) -> None:
        self.vis = FaceMeshVisualizer(draw_iris=False, draw_mouse=True, draw_eye=True, draw_nose=True, draw_eyebrow=True, draw_pupil=True)
        self.lmk_extractor = LMKExtractor()

    def motion_sync(self, ref_image, driver_frames):
        # Landmarks are detected in BGR at 512x512, then the driver sequence
        # is retargeted onto the reference detection.
        ref_image = cv2.cvtColor(ref_image, cv2.COLOR_RGB2BGR)
        ref_frame = cv2.resize(ref_image, (512, 512))
        ref_det = self.lmk_extractor(ref_frame)

        sequence_driver_det = []
        try:
            for frame in tqdm(driver_frames):
                frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                frame = cv2.resize(frame, (512, 512))
                result = self.lmk_extractor(frame)
                assert result is not None, "bad video, face not detected"
                sequence_driver_det.append(result)
        except Exception:
            print("face detection failed")
            exit()

        sequence_det_ms = motion_sync(sequence_driver_det, ref_det)
        pose_frames = [self.vis.draw_landmarks((512, 512), i, normed=False) for i in sequence_det_ms]
        return pose_frames

    def motion_self(self, driver_frames):
        # Draw each driver frame's own landmarks without retargeting.
        pose_frames = []
        try:
            for frame in tqdm(driver_frames):
                frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                frame = cv2.resize(frame, (512, 512))
                frame_det = self.lmk_extractor(frame)
                kpmap = self.vis.draw_landmarks((512, 512), frame_det["lmks"], normed=True)
                pose_frames.append(kpmap)
        except Exception:
            print("face detection failed")
            exit()

        return pose_frames

    def single_kp(self, image):
        frame_det = self.lmk_extractor(image)
        kpmap = self.vis.draw_landmarks((512, 512), frame_det["lmks"], normed=True)
        return kpmap
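
# Usage sketch for FaceKPDetector (drv_frames and ref_img are hypothetical
# 512x512 uint8 RGB arrays):
#
#     fkpd = FaceKPDetector()
#     kpmaps = fkpd.motion_self(drv_frames)  # one landmark map per frame
#     ref_kp = fkpd.single_kp(ref_img)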


class InferVideo:
    def __init__(self) -> None:
        self.vis = FaceMeshVisualizer(draw_iris=False, draw_mouse=True, draw_eye=True, draw_nose=True, draw_eyebrow=True, draw_pupil=True)
        self.lmk_extractor = LMKExtractor()
        self.fm = FaceMatting()
        self.fir = FaceImageRender()
        self.fkpd = FaceKPDetector()

    def inference(self, source_path, light_path, video_path, save_path, motion_align="relative"):
        # Unpack the driving video into numbered frames.
        tmp_path = "resources/target/"
        if os.path.exists(tmp_path):
            os.system(f"rm -r {tmp_path}")
        os.mkdir(tmp_path)
        os.system(f"ffmpeg -i {video_path} {tmp_path}/%5d.png")

        # Motion sync: load the reference, the target lighting, and all
        # driver frames at 512x512, dropping any alpha channel.
        source_image = np.array(Image.open(source_path).resize([512, 512]))[..., :3]
        target_lighting = np.array(Image.open(light_path).resize([512, 512]))[..., :3]
        driver_frames = [np.array(Image.open(os.path.join(tmp_path, str(i).zfill(5) + ".png")).resize([512, 512])) for i in range(1, 1 + len(os.listdir(tmp_path)))]

        aligned_kpmaps = self.fkpd.motion_self(driver_frames)
        alpha = self.fm.portrait_matting(source_image)

        if motion_align == "relative":
            aligned_shading = self.fir.render_motion_sync_relative(source_image, driver_frames, target_lighting)
        else:
            aligned_shading = self.fir.render_motion_sync(source_image, driver_frames, target_lighting)

        # Overwrite each extracted frame with a side-by-side strip.
        for idx, (drv_frame, kpmap, shading) in tqdm(enumerate(zip(driver_frames, aligned_kpmaps, aligned_shading))):
            img = np.concatenate([source_image, alpha, drv_frame, kpmap, shading], axis=1)
            Image.fromarray(np.uint8(img)).save(f"{tmp_path}/{str(idx + 1).zfill(5)}.png")

        # Frame 00000 is the reference rendered under its own light.
        source_kp = self.fkpd.single_kp(source_image)
        source_shading = self.fir.render_motion_single_with_light(source_image, source_image)
        img = np.concatenate([source_image, alpha, source_image, source_kp, source_shading], axis=1)
        Image.fromarray(np.uint8(img)).save(f"{tmp_path}/{str(0).zfill(5)}.png")
        os.system(f"ffmpeg -r 20 -i {tmp_path}/%05d.png -pix_fmt yuv420p -c:v libx264 {save_path} -y")
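
# Each frame that inference() writes is a 512x2560 strip (five 512x512
# panels) laid out as
#     [source | alpha matte | driver frame | landmark map | shading hint]
# with frame 00000 holding the reference under its own light; ffmpeg then
# reassembles the numbered strips into save_path at 20 fps.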


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--video_path", type=str, default="resources/WDA_DebbieDingell1_000.mp4", help="driving video path")
    parser.add_argument("--source_path", type=str, default="resources/reference.png", help="reference image path")
    parser.add_argument("--light_path", type=str, default="resources/target_lighting1.png", help="target lighting image path")
    parser.add_argument("--save_path", type=str, default="resources/shading.mp4", help="output path for the shading-hint video")
    parser.add_argument("--motion_align", type=str, default="relative", help="motion alignment mode: 'relative' or absolute")
    args = parser.parse_args()

    # Build the pipeline (which loads the models) only after arguments parse.
    iv = InferVideo()
    iv.inference(source_path=args.source_path, light_path=args.light_path, video_path=args.video_path, save_path=args.save_path, motion_align=args.motion_align)
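
# Example invocation (the script name is whatever this file is saved as;
# paths are the repo defaults):
#     python infer_video.py --video_path resources/WDA_DebbieDingell1_000.mp4 \
#         --source_path resources/reference.png \
#         --light_path resources/target_lighting1.png \
#         --save_path resources/shading.mp4 --motion_align relative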
259
+
260
+
0 commit comments