###########################################
## FLAME Video Tracker.                   #
## -------------------------------------- #
## Author: Peizhi Yan                     #
## Update: 11/12/2024                     #
###########################################
#
# Copyright (C) Peizhi Yan. 2024
#

import os
import random
import pickle

import numpy as np
import cv2
import torch
from tqdm import tqdm

from utils.video_utils import video_to_images
from tracker_base import Tracker


def sample_frames(frames, N):
    # If N is greater than or equal to the number of frames, return the entire list
    if N >= len(frames):
        return frames
    # Otherwise, return a random sample of N frames
    return random.sample(frames, N)


def track_video_legacy(tracker_cfg):
    # load video frames as images
    frames = video_to_images(video_path=tracker_cfg['video_path'],
                             original_fps=tracker_cfg['original_fps'],
                             subsample_fps=tracker_cfg['subsample_fps'])

    video_path = tracker_cfg['video_path']
    save_path = tracker_cfg['save_path']
    video_base_name = os.path.basename(video_path)
    video_name = os.path.splitext(video_base_name)[0]  # remove the file extension
    result_save_path = os.path.join(save_path, video_name)
    if not os.path.exists(result_save_path):
        os.makedirs(result_save_path)  # create the output path if it does not exist

    ###########################
    ## Setup FLAME Tracker    #
    ###########################
    tracker = Tracker(tracker_cfg)

    ###########################
    ## Estimate Global Shape  #
    ###########################
    print('Estimating global shape code')
    MAX_SAMPLE_SIZE = 100
    # take a random subset of frames to estimate the global shape code
    frames_subset = sample_frames(frames, MAX_SAMPLE_SIZE)
    with torch.no_grad():
        mean_shape_code = torch.zeros([1, 100], dtype=torch.float32).to(tracker.device)
        counter = 0
        for i in tqdm(range(len(frames_subset))):
            img = frames_subset[i]
            deca_dict = tracker.run_deca(img)  # run DECA reconstruction
            if deca_dict is not None:
                mean_shape_code += deca_dict['shape']
                counter += 1
        if counter == 0:
            mean_shape_code = None  # DECA failed on every sampled frame
        else:
            mean_shape_code /= counter  # compute the average shape code

    #######################
    # process all frames  #
    #######################
    print(f'Processing video: {video_path}')
    prev_ret_dict = None
    for fid in tqdm(range(len(frames))):
        # # Skip processed files (optional)
        # if os.path.exists(os.path.join(result_save_path, f'{fid}_compare.jpg')):
        #     prev_ret_dict = None
        #     continue

        # fit on the current frame, initialized from the previous frame's result
        ret_dict = tracker.run(img=frames[fid], realign=True,
                               prev_ret_dict=prev_ret_dict, shape_code=mean_shape_code)
        prev_ret_dict = ret_dict
        if ret_dict is None:
            continue

        # save (note: the .npy file is written with pickle, not np.save)
        save_file_path = os.path.join(result_save_path, f'{fid}.npy')
        with open(save_file_path, 'wb') as f:
            pickle.dump(ret_dict, f)

        # check result: reconstruct from the saved parameters and save the visualization
        with torch.no_grad():
            with open(save_file_path, 'rb') as f:
                loaded_params = pickle.load(f)
            img = ret_dict['img']
            result_img = np.zeros([256, 2 * 256, 3], dtype=np.uint8)
            # GT image
            gt_img = cv2.resize(np.asarray(img), (256, 256))
            gt_img = np.clip(np.array(gt_img, dtype=np.uint8), 0, 255)  # uint8
            gt_img = cv2.cvtColor(gt_img, cv2.COLOR_RGB2BGR)
            result_img[:, :256, :] = gt_img
            # rendered with texture but canonical camera pose
            rendered = np.clip(cv2.resize(loaded_params['img_rendered'], (256, 256)), 0, 255)
            rendered = cv2.cvtColor(rendered, cv2.COLOR_RGB2BGR)
            result_img[:, 256:256 * 2, :] = rendered
            cv2.imwrite(os.path.join(result_save_path, f'{fid}_compare.jpg'), result_img)
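

# Note: the per-frame results above are plain Python dicts written with
# `pickle.dump` under a `.npy` extension, so reading them back with
# `pickle.load` (as the visualization check does) is the reliable route.
# The helper below is a minimal sketch of that pattern; `load_tracked_frame`
# is a hypothetical name for illustration, not part of the tracker's API.
def load_tracked_frame(result_save_path, fid):
    # build the same path that track_video / track_video_legacy writes to
    save_file_path = os.path.join(result_save_path, f'{fid}.npy')
    with open(save_file_path, 'rb') as f:
        return pickle.load(f)  # the saved per-frame result dict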


def track_video(tracker_cfg):
    # load video frames as images
    frames = video_to_images(video_path=tracker_cfg['video_path'],
                             original_fps=tracker_cfg['original_fps'],
                             subsample_fps=tracker_cfg['subsample_fps'])

    video_path = tracker_cfg['video_path']
    save_path = tracker_cfg['save_path']
    video_base_name = os.path.basename(video_path)
    video_name = os.path.splitext(video_base_name)[0]  # remove the file extension
    result_save_path = os.path.join(save_path, video_name)
    if not os.path.exists(result_save_path):
        os.makedirs(result_save_path)  # create the output path if it does not exist

    ###########################
    ## Setup FLAME Tracker    #
    ###########################
    tracker = Tracker(tracker_cfg)

    # frames = frames[:30]  # for debugging only

    #######################
    # process all frames  #
    #######################
    print(f'Processing video: {video_path}')
    ret_dict_all = tracker.run_all_images(imgs=frames, realign=True)

    #################
    # save results  #
    #################
    print(f'Saving results: {video_path}')
    NUM_OF_RESULTS = len(ret_dict_all['shape'])
    for fid in tqdm(range(NUM_OF_RESULTS)):
        # gather this frame's entry from every per-key batch
        ret_dict = {}
        for key in ret_dict_all.keys():
            ret_dict[key] = ret_dict_all[key][fid]

        # save (note: the .npy file is written with pickle, not np.save)
        save_file_path = os.path.join(result_save_path, f'{fid}.npy')
        with open(save_file_path, 'wb') as f:
            pickle.dump(ret_dict, f)

        # check result: reconstruct from the saved parameters and save the visualization
        with torch.no_grad():
            with open(save_file_path, 'rb') as f:
                loaded_params = pickle.load(f)
            img = ret_dict['img']
            result_img = np.zeros([256, 2 * 256, 3], dtype=np.uint8)
            # GT image
            gt_img = cv2.resize(np.asarray(img), (256, 256))
            gt_img = np.clip(np.array(gt_img, dtype=np.uint8), 0, 255)  # uint8
            gt_img = cv2.cvtColor(gt_img, cv2.COLOR_RGB2BGR)
            result_img[:, :256, :] = gt_img
            # rendered with texture but canonical camera pose
            rendered = np.clip(cv2.resize(loaded_params['img_rendered'], (256, 256)), 0, 255)
            rendered = cv2.cvtColor(rendered, cv2.COLOR_RGB2BGR)
            result_img[:, 256:256 * 2, :] = rendered
            cv2.imwrite(os.path.join(result_save_path, f'{fid}_compare.jpg'), result_img)
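

# A minimal usage sketch, not part of the original script. The four config
# keys below are the only ones this file reads directly; `Tracker(tracker_cfg)`
# almost certainly expects additional keys (device, model/checkpoint paths,
# etc.), so the paths and values here are hypothetical placeholders.
if __name__ == '__main__':
    tracker_cfg = {
        'video_path': './assets/demo.mp4',  # hypothetical input video
        'save_path': './results',           # results go to ./results/<video_name>/
        'original_fps': 30,                 # fps of the source video
        'subsample_fps': 30,                # fps kept after subsampling
        # ... plus whatever Tracker(tracker_cfg) additionally requires
    }
    track_video(tracker_cfg)  # or track_video_legacy(tracker_cfg) for per-frame fitting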