"""Run `bash scripts/download_models.sh` first to download the model weights."""
import os
import shutil
from argparse import Namespace

from src.utils.preprocess import CropAndExtract
from src.test_audio2coeff import Audio2Coeff
from src.facerender.animate import AnimateFromCoeff
from src.generate_batch import get_data
from src.generate_facerender_batch import get_facerender_data
from cog import BasePredictor, Input, Path

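# All pretrained weights below are expected under this directory.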
checkpoints = "checkpoints"


class Predictor(BasePredictor):
    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""
        device = "cuda"

        path_of_lm_croper = os.path.join(
            checkpoints, "shape_predictor_68_face_landmarks.dat"
        )
        path_of_net_recon_model = os.path.join(checkpoints, "epoch_20.pth")
        dir_of_BFM_fitting = os.path.join(checkpoints, "BFM_Fitting")
        wav2lip_checkpoint = os.path.join(checkpoints, "wav2lip.pth")

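        # Note: the "auido2*" spellings below match the checkpoint and config
        # filenames shipped with the upstream weights; do not "fix" them.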
        audio2pose_checkpoint = os.path.join(checkpoints, "auido2pose_00140-model.pth")
        audio2pose_yaml_path = os.path.join("src", "config", "auido2pose.yaml")

        audio2exp_checkpoint = os.path.join(checkpoints, "auido2exp_00300-model.pth")
        audio2exp_yaml_path = os.path.join("src", "config", "auido2exp.yaml")

        free_view_checkpoint = os.path.join(
            checkpoints, "facevid2vid_00189-model.pth.tar"
        )

        # Cropper and 3DMM coefficient extractor for the source image
        # and any reference videos.
        self.preprocess_model = CropAndExtract(
            path_of_lm_croper, path_of_net_recon_model, dir_of_BFM_fitting, device
        )

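        # Maps the driven audio to head-pose and expression coefficients.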
        self.audio_to_coeff = Audio2Coeff(
            audio2pose_checkpoint,
            audio2pose_yaml_path,
            audio2exp_checkpoint,
            audio2exp_yaml_path,
            wav2lip_checkpoint,
            device,
        )

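        # Two renderer variants: "full" uses the still-mode mapping and config so
        # the animated crop can be pasted back into the original frame; "others"
        # covers the "crop" and "resize" preprocess modes.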
        self.animate_from_coeff = {
            "full": AnimateFromCoeff(
                free_view_checkpoint,
                os.path.join(checkpoints, "mapping_00109-model.pth.tar"),
                os.path.join("src", "config", "facerender_still.yaml"),
                device,
            ),
            "others": AnimateFromCoeff(
                free_view_checkpoint,
                os.path.join(checkpoints, "mapping_00229-model.pth.tar"),
                os.path.join("src", "config", "facerender.yaml"),
                device,
            ),
        }

    def predict(
        self,
        source_image: Path = Input(
            description="Upload the source image; it can be a video (video.mp4) or a picture (picture.png)",
        ),
        driven_audio: Path = Input(
            description="Upload the driving audio; accepts .wav and .mp4 files",
        ),
        enhancer: str = Input(
            description="Choose a face enhancer",
            choices=["gfpgan", "RestoreFormer"],
            default="gfpgan",
        ),
        preprocess: str = Input(
            description="How to preprocess the image",
            choices=["crop", "resize", "full"],
            default="full",
        ),
        ref_eyeblink: Path = Input(
            description="Path to a reference video providing eye blinking",
            default=None,
        ),
        ref_pose: Path = Input(
            description="Path to a reference video providing the head pose",
            default=None,
        ),
        still: bool = Input(
            description="Crop back to the original image for full-body animation when preprocess is 'full'",
            default=True,
        ),
    ) -> Path:
        """Run a single prediction on the model"""

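        # Pick the renderer matching the chosen preprocess mode (see setup()).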
        animate_from_coeff = (
            self.animate_from_coeff["full"]
            if preprocess == "full"
            else self.animate_from_coeff["others"]
        )

        device = "cuda"
        args = load_default()
        args.pic_path = str(source_image)
        args.audio_path = str(driven_audio)
        args.still = still
        args.ref_eyeblink = None if ref_eyeblink is None else str(ref_eyeblink)
        args.ref_pose = None if ref_pose is None else str(ref_pose)

        # Crop the image and extract 3DMM coefficients from it,
        # starting from a fresh results directory.
        results_dir = "results"
        if os.path.exists(results_dir):
            shutil.rmtree(results_dir)
        os.makedirs(results_dir)
        first_frame_dir = os.path.join(results_dir, "first_frame_dir")
        os.makedirs(first_frame_dir)

        print("3DMM Extraction for source image")
        first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(
            args.pic_path, first_frame_dir, preprocess, source_image_flag=True
        )
        if first_coeff_path is None:
            raise ValueError("Can't get the coeffs of the input")

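        # Optionally extract 3DMM coefficient sequences from the reference
        # videos that drive eye blinking and head pose.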
        if ref_eyeblink is not None:
            ref_eyeblink_videoname = os.path.splitext(
                os.path.split(ref_eyeblink)[-1]
            )[0]
            ref_eyeblink_frame_dir = os.path.join(results_dir, ref_eyeblink_videoname)
            os.makedirs(ref_eyeblink_frame_dir, exist_ok=True)
            print("3DMM Extraction for the reference video providing eye blinking")
            ref_eyeblink_coeff_path, _, _ = self.preprocess_model.generate(
                ref_eyeblink, ref_eyeblink_frame_dir
            )
        else:
            ref_eyeblink_coeff_path = None

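        # Reuse the eyeblink coefficients when the same video drives both.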
        if ref_pose is not None:
            if ref_pose == ref_eyeblink:
                ref_pose_coeff_path = ref_eyeblink_coeff_path
            else:
                ref_pose_videoname = os.path.splitext(os.path.split(ref_pose)[-1])[0]
                ref_pose_frame_dir = os.path.join(results_dir, ref_pose_videoname)
                os.makedirs(ref_pose_frame_dir, exist_ok=True)
                print("3DMM Extraction for the reference video providing pose")
                ref_pose_coeff_path, _, _ = self.preprocess_model.generate(
                    ref_pose, ref_pose_frame_dir
                )
        else:
            ref_pose_coeff_path = None

        # audio2coeff: predict motion coefficients from the audio
        batch = get_data(
            first_coeff_path,
            args.audio_path,
            device,
            ref_eyeblink_coeff_path,
            still=still,
        )
        coeff_path = self.audio_to_coeff.generate(
            batch, results_dir, args.pose_style, ref_pose_coeff_path
        )
        # coeff2video: render the talking-head video from the coefficients
        print("coeff2video")
        data = get_facerender_data(
            coeff_path,
            crop_pic_path,
            first_coeff_path,
            args.audio_path,
            args.batch_size,
            args.input_yaw,
            args.input_pitch,
            args.input_roll,
            expression_scale=args.expression_scale,
            still_mode=still,
            preprocess=preprocess,
        )
        animate_from_coeff.generate(
            data,
            results_dir,
            args.pic_path,
            crop_info,
            enhancer=enhancer,
            background_enhancer=args.background_enhancer,
            preprocess=preprocess,
        )

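        # Copy the enhanced result to a fixed path that Cog can return.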
        output = "/tmp/out.mp4"
        mp4_name = [f for f in os.listdir(results_dir) if "enhanced.mp4" in f][0]
        shutil.copy(os.path.join(results_dir, mp4_name), output)

        return Path(output)


def load_default():
    """Default inference arguments that are not exposed through the Cog inputs."""
    return Namespace(
        pose_style=0,
        batch_size=2,
        expression_scale=1.0,
        input_yaw=None,
        input_pitch=None,
        input_roll=None,
        background_enhancer=None,
        face3dvis=False,
        net_recon="resnet50",
        init_path=None,
        use_last_fc=False,
        bfm_folder="./checkpoints/BFM_Fitting/",
        bfm_model="BFM_model_front.mat",
        focal=1015.0,
        center=112.0,
        camera_d=10.0,
        z_near=5.0,
        z_far=15.0,
    )
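

# A minimal usage sketch (assuming Cog is installed and the weights have been
# downloaded into ./checkpoints); the input filenames here are hypothetical:
#
#   cog predict -i source_image=@face.png -i driven_audio=@speech.wav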