forked from facebookresearch/meshtalk
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathanimate_face.py
101 lines (93 loc) · 3.38 KB
/
animate_face.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""
Copyright (c) Facebook, Inc. and its affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
import argparse
import numpy as np
import torch as th
from utils.renderer import Renderer
from utils.helpers import smooth_geom, load_mask, get_template_verts, load_audio, audio_chunking
from models.vertex_unet import VertexUnet
from models.context_model import ContextModel
from models.encoders import MultimodalEncoder
# Command-line interface. Every option is a plain string with a default, so the
# script can be launched without any arguments.
parser = argparse.ArgumentParser()
_cli_options = (
    # (flag, default value, help text)
    ("--model_dir", "pretrained_models", "directory containing the models to load"),
    ("--audio_file", "assets/example_sentence.wav", "wave file to use for face animation"),
    ("--face_template", "assets/face_template.obj", ".obj file containing neutral template mesh"),
    ("--output", "video.mp4", "video output file"),
)
for _flag, _default, _help in _cli_options:
    parser.add_argument(_flag, type=str, default=_default, help=_help)
# Parsed once at start-up; the rest of the script reads its paths from `args`.
args = parser.parse_args()
"""
load assets
"""
print("load assets...")
# Template vertices and the driving audio are taken from the command line.
template_verts = get_template_verts(args.face_template)
audio = load_audio(args.audio_file)
# Mean / standard-deviation arrays, wrapped as tensors without copying.
# They are not moved to the GPU here.
mean = th.from_numpy(np.load("assets/face_mean.npy"))
stddev = th.from_numpy(np.load("assets/face_std.npy"))
# Masks are read as float32 and moved to the GPU right away.
forehead_mask = th.from_numpy(load_mask("assets/forehead_mask.txt", dtype=np.float32)).cuda()
neck_mask = th.from_numpy(load_mask("assets/neck_mask.txt", dtype=np.float32)).cuda()
# Build the renderer from the same .obj file the template vertices were read
# from. This path used to be hard-coded to "assets/face_template.obj", which
# meant a custom --face_template was silently ignored by the renderer. The
# default value of the option is that same path, so default runs are unchanged.
# NOTE(review): assumes Renderer only needs what any valid template .obj
# provides (e.g. topology) — confirm against utils.renderer.
renderer = Renderer(args.face_template)
"""
load models
"""
print("load models...")


def _restore_on_gpu(net):
    # Call the network's own load() with --model_dir, move it to the GPU and
    # put it in eval mode. The object passed in is handed back unchanged.
    net.load(args.model_dir)
    net.cuda().eval()
    return net


# Keyword arguments that recur across the three constructors below.
_categorical = {"classes": 128, "heads": 16}
_mesh = {"n_vertices": 6172, "mean": mean, "stddev": stddev}

# Each network is constructed and then immediately restored, in the same
# order as before: geometry U-Net, context model, multimodal encoder.
geom_unet = _restore_on_gpu(VertexUnet(**_categorical, **_mesh))
context_model = _restore_on_gpu(ContextModel(**_categorical, audio_dim=128))
encoder = _restore_on_gpu(
    MultimodalEncoder(**_categorical, expression_dim=128, audio_dim=128, **_mesh)
)
"""
generate and render sequence
"""
print("animate face mesh...")

# Chunk the loaded audio before it is fed to the encoder.
# NOTE(review): presumably one 16000-sample window per video frame at 30 fps —
# confirm against utils.helpers.audio_chunking.
audio = audio_chunking(audio, frame_rate=30, chunk_size=16000)

# Pure inference: none of the network calls below needs gradient tracking.
with th.no_grad():
    audio_batch = audio.cuda().unsqueeze(0)  # add a leading batch dimension
    audio_enc = encoder.audio_encoder(audio_batch)["code"]
    # argmax=False is passed through to ContextModel.sample unchanged.
    one_hot = context_model.sample(audio_enc, argmax=False)["one_hot"]
    # Second axis of the sampled codes; presumably the number of frames.
    num_frames = one_hot.shape[1]
    # Broadcast the single template to one copy per step along that axis;
    # contiguous() materialises the expanded view into real memory.
    template_batch = template_verts.cuda().view(1, 1, 6172, 3)
    geom = template_batch.expand(-1, num_frames, -1, -1).contiguous()
    result = geom_unet(geom, one_hot)["geom"].squeeze(0)

# Smooth the result with the forehead mask first, then the neck mask.
for _mask in (forehead_mask, neck_mask):
    result = smooth_geom(result, _mask)

# Render the sequence to a video file; the audio file path is passed along too.
print("render...")
renderer.to_video(result, args.audio_file, args.output)
print("done")