Commit 07b3bd8

Authored Mar 2, 2025
Add files via upload
1 parent fb693e4

File tree: 2 files changed, +445 -0 lines changed

 

inference.py (+185 lines)
@@ -0,0 +1,185 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import argparse

import cv2
import numpy as np
from PIL import Image
import torch
import torch.utils.checkpoint
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from diffusers import AutoencoderKL, EulerDiscreteScheduler

from src.modules.head_net import HeadNet
from src.modules.light_net import LightNet
from src.modules.ref_net import RefNet
from src.modules.unet import UNetSpatioTemporalConditionModel
from src.pipelines.pipeline_relightalbepa_composer import RelightablepaPipeline


pretrained_model_name_or_path = "../../stable-video-diffusion-img2vid-xt"

# Load the scheduler, feature extractor, image encoder, VAE, UNet, and condition embedders.
noise_scheduler = EulerDiscreteScheduler.from_pretrained(
    pretrained_model_name_or_path, subfolder="scheduler")
feature_extractor = CLIPImageProcessor.from_pretrained(
    pretrained_model_name_or_path, subfolder="feature_extractor"
)
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    pretrained_model_name_or_path, subfolder="image_encoder", variant="fp16"
)
vae = AutoencoderKL.from_pretrained(
    pretrained_model_name_or_path, subfolder="sd-vae-ft-mse")
# The UNet is built from its config only; the fine-tuned weights are loaded below.
unet = UNetSpatioTemporalConditionModel.from_config(
    pretrained_model_name_or_path,
    subfolder="unet",
    low_cpu_mem_usage=True,
)
head_embedder = HeadNet(noise_latent_channels=320)
light_embedder = LightNet(noise_latent_channels=320)
ref_embedder = RefNet(noise_latent_channels=320)

# Freeze every component; this script only runs inference.
vae.requires_grad_(False)
image_encoder.requires_grad_(False)
unet.requires_grad_(False)
head_embedder.requires_grad_(False)
light_embedder.requires_grad_(False)
ref_embedder.requires_grad_(False)

# Load the fine-tuned weights.
unet.load_state_dict(torch.load("outputs/checkpoint-29000/unet.pth"))
head_embedder.load_state_dict(torch.load("outputs/checkpoint-29000/head_embedder.pth"))
light_embedder.load_state_dict(torch.load("outputs/checkpoint-29000/light_embedder.pth"))
ref_embedder.load_state_dict(torch.load("outputs/checkpoint-29000/app_embedder.pth"))

weight_dtype = torch.float16
device = "cuda"

image_encoder.to(device, dtype=weight_dtype)
vae.to(device, dtype=weight_dtype)
unet.to(device, dtype=weight_dtype)
head_embedder.to(device, dtype=weight_dtype)
light_embedder.to(device, dtype=weight_dtype)
ref_embedder.to(device, dtype=weight_dtype)

# Assemble the pipeline from the pretrained base model and the fine-tuned components.
pipeline = RelightablepaPipeline.from_pretrained(
    pretrained_model_name_or_path,
    unet=unet,
    image_encoder=image_encoder,
    vae=vae,
    head_embedder=head_embedder,
    light_embedder=light_embedder,
    ref_embedder=ref_embedder,
    torch_dtype=weight_dtype,
)
pipeline = pipeline.to(device)
pipeline.set_progress_bar_config(disable=False)


def portrait_animation_and_relighting(video_path, save_path, guidance, inference_steps, driving_mode="relighting"):
    path = "resources/target/"
    path_tmp = "resources/tmp/"
    if not os.path.exists(path):
        os.makedirs(path)
    else:
        os.system(f"rm -r {path}/*")

    if not os.path.exists(path_tmp):
        os.makedirs(path_tmp)
    else:
        os.system(f"rm -r {path_tmp}/*")

    # Split the preprocessed condition video into individual frames.
    os.system(f"ffmpeg -i {video_path} {path}/%5d.png")

    pixel_values = []
    pixel_head = []
    pixel_values_light = []
    # Each preprocessed frame is a horizontal strip of 512x512 tiles:
    # [reference | matte | driving frame | landmark map | shading hint].
    img = np.array(Image.open(path + "00001.png"))
    # img = cv2.resize(img, (img.shape[1], img.shape[0]))
    pixel_ref_values = img[:, :512]
    pixel_ref_mask = img[:, 512:1024]
    pixel_ref_mask = cv2.resize(pixel_ref_mask, (64, 64))
    # pixel_ref_mask = np.ones_like(pixel_ref_mask) * 255

    for i in range(1, len(os.listdir(path)) + 1):
        img = np.array(Image.open(f"{path}/{str(i).zfill(5)}.png"), dtype=np.uint8)
        # img = cv2.resize(img, (img.shape[1], img.shape[0]))
        pixel_values.append(img[:, 1024:1536][None])         # driving frame
        pixel_head.append(img[:, 1536:2048][None])           # landmark map
        pixel_values_light.append(img[:, 2048:2560][None])   # shading hint

    pixel_values = torch.tensor(np.concatenate(pixel_values, axis=0)[None]).to(device, dtype=weight_dtype).permute(0, 1, 4, 2, 3) / 127.5 - 1.0
    pixel_head = torch.tensor(np.concatenate(pixel_head, axis=0)[None]).to(device, dtype=weight_dtype).permute(0, 1, 4, 2, 3) / 255.0
    pixel_values_light = torch.tensor(np.concatenate(pixel_values_light, axis=0)[None]).to(device, dtype=weight_dtype).permute(0, 1, 4, 2, 3) / 255.0

    pixel_ref_values = torch.tensor(pixel_ref_values[None, None]).repeat(1, pixel_values.size(1), 1, 1, 1).to(device, dtype=weight_dtype).permute(0, 1, 4, 2, 3) / 127.5 - 1.0
    pixel_ref_mask = torch.tensor(pixel_ref_mask[None, None]).repeat(1, pixel_values.size(1), 1, 1, 1).to(device, dtype=weight_dtype).permute(0, 1, 4, 2, 3)[:, :, 0:1] / 255.0

    num_frames = pixel_values.size(1)
    pixel_pil = [Image.fromarray(np.uint8((pixel_values.permute(0, 1, 3, 4, 2).cpu().numpy()[0, i] + 1) * 127.5)) for i in range(num_frames)]
    heads_pil = [Image.fromarray(np.uint8((pixel_head.permute(0, 1, 3, 4, 2).cpu().numpy()[0, i]) * 255)) for i in range(num_frames)]
    lights_drv_pil = [Image.fromarray(np.uint8((pixel_values_light.permute(0, 1, 3, 4, 2).cpu().numpy()[0, i]) * 255)) for i in range(num_frames)]
    reference_pil = [Image.fromarray(np.uint8((pixel_ref_values.permute(0, 1, 3, 4, 2).cpu().numpy()[0, 0] + 1) * 127.5))]

    # Conditional / unconditional inputs for classifier-free guidance.
    if driving_mode == "relighting":
        model_args = [{"image_head": None, "image_light": pixel_values_light, "image_ref": pixel_ref_values},  # cond
                      {"image_head": None, "image_light": None, "image_ref": pixel_ref_values}]                # uncond
    elif driving_mode == "landmark":
        model_args = [{"image_head": pixel_head, "image_light": None, "image_ref": pixel_ref_values},  # cond
                      {"image_head": None, "image_light": None, "image_ref": None}]                    # uncond
    else:
        model_args = [{"image_head": None, "image_light": pixel_values_light, "image_ref": pixel_ref_values},  # cond
                      {"image_head": None, "image_light": None, "image_ref": None}]                            # uncond

    frames = pipeline(
        reference_pil, model_args=model_args, image_mask=pixel_ref_mask,
        num_frames=pixel_head.size(1),
        tile_size=16, tile_overlap=6,
        height=512, width=512, fps=7,
        noise_aug_strength=0.02, num_inference_steps=inference_steps,
        generator=None, min_guidance_scale=guidance,
        max_guidance_scale=guidance, decode_chunk_size=8, output_type="pt", device="cuda"
    ).frames.cpu()
    video_frames = (frames.permute(0, 1, 3, 4, 2) * 255.0).to(torch.uint8).numpy()[0]

    final = []
    for i in range(pixel_head.size(1)):
        img = video_frames[i]
        head = np.array(heads_pil[i])
        light = np.array(lights_drv_pil[i])
        tar = np.array(pixel_pil[i])
        ref = np.array(reference_pil[0])
        # final.append(np.concatenate([ref, head, light, img, tar], axis=1))
        # Save [reference | shading | result | driving frame] side by side.
        Image.fromarray(np.uint8(np.concatenate([ref, light, img, tar], axis=1))).save(f"{path_tmp}/{str(i).zfill(5)}.png")

    os.system(f"ffmpeg -r 20 -i {path_tmp}/%05d.png -pix_fmt yuv420p -c:v libx264 {save_path} -y")
    # torchvision.io.write_video(save_path, final, fps=20, video_codec='h264', options={'crf': '10'})


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--video_path", type=str, default="resources/shading.mp4", help="preprocessed condition video (reference, matte, driving frames, landmarks, shading) produced by preprocess.py")
    parser.add_argument("--save_path", type=str, default="result.mp4", help="result save path")
    parser.add_argument("--guidance", type=float, default=4.5, help="guidance scale (controls lighting intensity)")
    parser.add_argument("--inference_steps", type=int, default=25, help="diffusion reverse sampling steps")

    args = parser.parse_args()

    portrait_animation_and_relighting(video_path=args.video_path, save_path=args.save_path, guidance=args.guidance, inference_steps=args.inference_steps, driving_mode="relighting")
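For reference, a minimal sketch of the per-frame layout that inference.py slices above; the tile order matches what preprocess.py (below) writes. The array and the variable names (strip, matte, driving, landmark, shading) are illustrative placeholders, not real data or identifiers from the repository:

# Hypothetical illustration of the 512 x 2560 condition strip consumed by inference.py.
import numpy as np

strip = np.zeros((512, 2560, 3), dtype=np.uint8)  # placeholder frame, not real data
reference = strip[:, 0:512]        # source portrait      -> pixel_ref_values
matte     = strip[:, 512:1024]     # foreground matte     -> pixel_ref_mask
driving   = strip[:, 1024:1536]    # driving video frame  -> pixel_values
landmark  = strip[:, 1536:2048]    # landmark map         -> pixel_head
shading   = strip[:, 2048:2560]    # lighting/shading hint -> pixel_values_light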

preprocess.py (+260 lines)
@@ -0,0 +1,260 @@
from src.facepose.mp_utils import LMKExtractor
from src.facepose.draw_utils import FaceMeshVisualizer
from src.facepose.motion_utils import motion_sync
from src.facematting.u2net_matting import U2NET
from src.decalib.utils import util
from src.decalib.utils.tensor_cropper import transform_points
from src.decalib.deca import DECA
from src.decalib.utils.config import cfg as deca_cfg
from PIL import Image
from tqdm import tqdm
import numpy as np
import torch
import cv2
import os
import argparse


class FaceMatting:
    """U2NET-based portrait matting used to build the foreground matte."""

    def __init__(self) -> None:
        self.net = U2NET(3, 1).cuda()
        self.net.load_state_dict(torch.load("./src/facematting/u2net_human_seg.pth"))

    def portrait_matting(self, rgb_image):
        # ImageNet mean/std normalization, applied per channel before adding the batch dim.
        rgb_image = cv2.resize(rgb_image, (320, 320)) / 255.0
        rgb_image[:, :, 0] = (rgb_image[:, :, 0] - 0.485) / 0.229
        rgb_image[:, :, 1] = (rgb_image[:, :, 1] - 0.456) / 0.224
        rgb_image[:, :, 2] = (rgb_image[:, :, 2] - 0.406) / 0.225
        rgb_image_th = torch.tensor(rgb_image[None], dtype=torch.float32).cuda().permute(0, 3, 1, 2)
        with torch.no_grad():
            d1, d2, d3, d4, d5, d6, d7 = self.net(rgb_image_th)
        # Normalize the predicted saliency map to [0, 1], then binarize it.
        pred = d1[:, 0, :, :]
        ma = torch.max(pred)
        mi = torch.min(pred)
        alpha = (pred - mi) / (ma - mi)
        alpha = alpha.detach().cpu().numpy()[0]
        alpha[alpha > 0.5] = 255
        alpha[alpha <= 0.5] = 0
        alpha = np.dstack([alpha, alpha, alpha])
        alpha = cv2.resize(alpha, (512, 512))
        alpha = cv2.dilate(alpha, np.ones([7, 7]))
        return alpha


class FaceImageRender:
    """DECA-based renderer that produces shading hints under a target lighting."""

    def __init__(self) -> None:
        # Init DECA and the FLAME face-region masks used for rendering.
        self.deca = DECA(config=deca_cfg)
        f_mask = np.load('./src/decalib/data/FLAME_masks_face-id.pkl', allow_pickle=True, encoding='latin1')
        v_mask = np.load('./src/decalib/data/FLAME_masks.pkl', allow_pickle=True, encoding='latin1')
        self.mask = {
            'v_mask': v_mask['face'].tolist(),
            'f_mask': f_mask['face'].tolist()
        }

    def image_to_3dcoeff(self, rgb_image):
        with torch.no_grad():
            codedict, detected_flag = self.deca.img_to_3dcoeff(rgb_image)
        return codedict

    def render_shape(self, shape, exp, pose, cam, light, tform, h, w):
        with torch.no_grad():
            # All parameters come from a DECA codedict.
            verts, landmarks2d, landmarks3d = self.deca.flame(shape_params=shape, expression_params=exp, pose_params=pose)

            # Orthographic projection into image space.
            trans_verts = util.batch_orth_proj(verts, cam)
            trans_verts[:, :, 1:] = -trans_verts[:, :, 1:]

            points_scale = [self.deca.image_size, self.deca.image_size]
            trans_verts = transform_points(trans_verts, tform, points_scale, [h, w])

            shape_images, _, grid, alpha_images, albedo_images = self.deca.render.render_shape(verts, trans_verts, h=h, w=w, lights=light, images=None, return_grid=True, mask=self.mask)
            shape_images = shape_images.permute(0, 2, 3, 1).clamp(0, 1).detach().cpu().numpy()[0] * 255
            albedo_images = albedo_images.permute(0, 2, 3, 1).clamp(0, 1).detach().cpu().numpy()[0] * 255
        return shape_images, albedo_images

    def render_shape_with_light(self, codedict, target_light=None):
        if target_light is None:
            target_light = codedict["light"]
        shape, exp, pose = codedict["shape"], codedict["exp"], codedict["pose"]
        cam, tform, h, w = codedict["cam"], codedict["tform"], codedict["height"], codedict["width"]
        shape_image, albedo_image = self.render_shape(shape, exp, pose, cam, target_light, tform, h, w)
        return shape_image

    def render_motion_single(self, image):
        codedict = self.image_to_3dcoeff(image)
        shading = self.render_shape_with_light(codedict)
        return shading

    def render_motion_single_with_light(self, image, target_light_image):
        codedict = self.image_to_3dcoeff(image)
        target_light = self.image_to_3dcoeff(target_light_image)["light"]
        shading = self.render_shape_with_light(codedict, target_light=target_light)
        return shading

    # NOTE: this definition is shadowed by the second render_motion_sync further down,
    # which is the one in effect at runtime.
    def render_motion_sync(self, ref_image, driver_frames, target_light_image):
        ref_code_dict = self.image_to_3dcoeff(ref_image)
        target_light = self.image_to_3dcoeff(target_light_image)["light"]

        shading_frames = []
        for drv_frm in tqdm(driver_frames):
            codedict = self.image_to_3dcoeff(drv_frm)
            shape, exp, pose = ref_code_dict["shape"], ref_code_dict["exp"], codedict["pose"]
            cam, tform, h, w = ref_code_dict["cam"], ref_code_dict["tform"], ref_code_dict["height"], ref_code_dict["width"]
            shape_image, albedo_image = self.render_shape(shape, exp, pose, cam, target_light, tform, h, w)
            shading_frames.append(shape_image)
        return shading_frames

    def render_motion_sync_relative(self, ref_image, driver_frames, target_light_image):
        # Relative mode: transfer pose/expression deltas (w.r.t. the first driving frame)
        # onto the reference identity.
        ref_codedict = self.image_to_3dcoeff(ref_image)
        target_light = self.image_to_3dcoeff(target_light_image)["light"]

        drv_codedict_list = []
        shading_frames = []
        for drv_frm in tqdm(driver_frames):
            drv_codedict = self.image_to_3dcoeff(drv_frm)
            drv_codedict_list.append(drv_codedict)

        # best_dist = 10000
        # best_pose = None
        # for idx, drv_codedict in enumerate(drv_codedict_list):
        #     dist = torch.mean(torch.abs(ref_codedict["pose"] - drv_codedict["pose"]))
        #     if dist < best_dist:
        #         best_dist = dist
        #         best_pose = drv_codedict["pose"]
        best_pose = drv_codedict_list[0]["pose"]
        best_exp = drv_codedict_list[0]["exp"]
        for drv_codedict in drv_codedict_list:
            relative_pose = drv_codedict["pose"] - best_pose + ref_codedict["pose"]
            relative_exp = drv_codedict["exp"] - best_exp + ref_codedict["exp"]
            shape, exp, pose = ref_codedict["shape"], relative_exp, relative_pose
            cam, tform, h, w = ref_codedict["cam"], ref_codedict["tform"], ref_codedict["height"], ref_codedict["width"]
            shape_image, albedo_image = self.render_shape(shape, exp, pose, cam, target_light, tform, h, w)
            shading_frames.append(shape_image)
        return shading_frames

    def render_motion_sync(self, ref_image, driver_frames, target_light_image):
        # Absolute mode: take expression and pose directly from each driving frame.
        ref_codedict = self.image_to_3dcoeff(ref_image)
        target_light = self.image_to_3dcoeff(target_light_image)["light"]

        drv_codedict_list = []
        shading_frames = []
        for drv_frm in tqdm(driver_frames):
            drv_codedict = self.image_to_3dcoeff(drv_frm)
            drv_codedict_list.append(drv_codedict)

        for drv_codedict in drv_codedict_list:
            shape, exp, pose = ref_codedict["shape"], drv_codedict["exp"], drv_codedict["pose"]
            cam, tform, h, w = ref_codedict["cam"], ref_codedict["tform"], ref_codedict["height"], ref_codedict["width"]
            shape_image, albedo_image = self.render_shape(shape, exp, pose, cam, target_light, tform, h, w)
            shading_frames.append(shape_image)
        return shading_frames


class FaceKPDetector:
    """MediaPipe landmark extraction and landmark-map drawing."""

    def __init__(self) -> None:
        self.vis = FaceMeshVisualizer(draw_iris=False, draw_mouse=True, draw_eye=True, draw_nose=True, draw_eyebrow=True, draw_pupil=True)
        self.lmk_extractor = LMKExtractor()

    def motion_sync(self, ref_image, driver_frames):
        ref_image = cv2.cvtColor(ref_image, cv2.COLOR_RGB2BGR)
        ref_frame = cv2.resize(ref_image, (512, 512))
        ref_det = self.lmk_extractor(ref_frame)

        sequence_driver_det = []
        try:
            for frame in tqdm(driver_frames):
                frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                frame = cv2.resize(frame, (512, 512))
                result = self.lmk_extractor(frame)
                assert result is not None, "bad video, face not detected"
                sequence_driver_det.append(result)
        except Exception:
            print("face detection failed")
            exit()

        sequence_det_ms = motion_sync(sequence_driver_det, ref_det)
        pose_frames = [self.vis.draw_landmarks((512, 512), i, normed=False) for i in sequence_det_ms]
        return pose_frames

    def motion_self(self, driver_frames):
        pose_frames = []
        try:
            for frame in tqdm(driver_frames):
                frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                frame = cv2.resize(frame, (512, 512))
                frame_det = self.lmk_extractor(frame)
                kpmap = self.vis.draw_landmarks((512, 512), frame_det["lmks"], normed=True)
                pose_frames.append(kpmap)
        except Exception:
            print("face detection failed")
            exit()

        return pose_frames

    def single_kp(self, image):
        frame_det = self.lmk_extractor(image)
        kpmap = self.vis.draw_landmarks((512, 512), frame_det["lmks"], normed=True)
        return kpmap


class InferVideo:
    """Builds the concatenated condition video consumed by inference.py."""

    def __init__(self) -> None:
        self.vis = FaceMeshVisualizer(draw_iris=False, draw_mouse=True, draw_eye=True, draw_nose=True, draw_eyebrow=True, draw_pupil=True)
        self.lmk_extractor = LMKExtractor()

        self.fm = FaceMatting()

        self.fir = FaceImageRender()

        self.fkpd = FaceKPDetector()

    def inference(self, source_path, light_path, video_path, save_path, motion_align="relative"):
        tmp_path = "resources/target/"

        if os.path.exists(tmp_path):
            os.system(f"rm -r {tmp_path}")

        os.mkdir(tmp_path)
        os.system(f"ffmpeg -i {video_path} {tmp_path}/%5d.png")

        # Motion sync: extract landmarks and shading hints from the driving frames.
        source_image = np.array(Image.open(source_path).resize([512, 512]))[..., :3]
        target_lighting = np.array(Image.open(light_path).resize([512, 512]))[..., :3]

        driver_frames = [np.array(Image.open(os.path.join(tmp_path, str(i).zfill(5) + ".png")).resize([512, 512])) for i in range(1, 1 + len(os.listdir(tmp_path)))]

        aligned_kpmaps = self.fkpd.motion_self(driver_frames)

        alpha = self.fm.portrait_matting(source_image)

        if motion_align == "relative":
            aligned_shading = self.fir.render_motion_sync_relative(source_image, driver_frames, target_lighting)
        else:
            aligned_shading = self.fir.render_motion_sync(source_image, driver_frames, target_lighting)

        # Each output frame is the horizontal strip [source | matte | driving frame | landmark | shading].
        for idx, (drv_frame, kpmap, shading) in tqdm(enumerate(zip(driver_frames, aligned_kpmaps, aligned_shading))):
            img = np.concatenate([source_image, alpha, drv_frame, kpmap, shading], axis=1)
            Image.fromarray(np.uint8(img)).save(f"{tmp_path}/{str(idx + 1).zfill(5)}.png")

        # Frame 00000 stores the source image with its own landmarks and shading.
        source_kp = self.fkpd.single_kp(source_image)
        source_shading = self.fir.render_motion_single_with_light(source_image, source_image)

        img = np.concatenate([source_image, alpha, source_image, source_kp, source_shading], axis=1)
        Image.fromarray(np.uint8(img)).save(f"{tmp_path}/{str(0).zfill(5)}.png")
        os.system(f"ffmpeg -r 20 -i {tmp_path}/%05d.png -pix_fmt yuv420p -c:v libx264 {save_path} -y")


if __name__ == "__main__":
    iv = InferVideo()

    parser = argparse.ArgumentParser()
    parser.add_argument("--video_path", type=str, default="resources/WDA_DebbieDingell1_000.mp4", help="driving video path")
    parser.add_argument("--source_path", type=str, default="resources/reference.png", help="reference image path")
    parser.add_argument("--light_path", type=str, default="resources/target_lighting1.png", help="target lighting image path")
    parser.add_argument("--save_path", type=str, default="resources/shading.mp4", help="output path for the condition (shading hints) video")
    parser.add_argument("--motion_align", type=str, default="relative", help="motion alignment mode: relative or absolute")
    args = parser.parse_args()

    iv.inference(source_path=args.source_path, light_path=args.light_path, video_path=args.video_path, save_path=args.save_path, motion_align=args.motion_align)
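A minimal end-to-end sketch of how the two scripts chain together, assuming the default paths, resources, and checkpoints referenced above are in place (this driver script is not part of the commit): preprocess.py writes the condition video, which inference.py then consumes.

# Hypothetical driver script: run preprocessing, then relighting inference.
import os

# Build resources/shading.mp4 from a driving video, a reference portrait, and a target lighting image.
os.system(
    "python preprocess.py"
    " --source_path resources/reference.png"
    " --video_path resources/WDA_DebbieDingell1_000.mp4"
    " --light_path resources/target_lighting1.png"
    " --save_path resources/shading.mp4"
)

# Render the relit portrait video from the condition video.
os.system(
    "python inference.py"
    " --video_path resources/shading.mp4"
    " --save_path result.mp4"
    " --guidance 4.5"
    " --inference_steps 25"
)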
