Skip to content

Commit 014eb22

Browse files
committed
video: code formatting
1 parent 63f0c9e commit 014eb22

File tree

1 file changed

+150
-87
lines changed

1 file changed

+150
-87
lines changed

dlclivegui/video.py

+150-87
Original file line numberDiff line numberDiff line change
@@ -15,31 +15,32 @@
1515
from tqdm import tqdm
1616

1717

18-
def create_labeled_video(video_file,
19-
ts_file,
20-
dlc_file,
21-
out_dir=None,
22-
save_images=False,
23-
cut=(0, np.Inf),
24-
crop=None,
25-
cmap='bmy',
26-
radius=3,
27-
lik_thresh=0.5,
28-
write_ts=False,
29-
write_scale=2,
30-
display=False,
31-
progress=True,
32-
label=True):
18+
def create_labeled_video(
19+
data_dir,
20+
out_dir=None,
21+
dlc_online=True,
22+
save_images=False,
23+
cut=(0, np.Inf),
24+
crop=None,
25+
cmap="bmy",
26+
radius=3,
27+
lik_thresh=0.5,
28+
write_ts=False,
29+
write_scale=2,
30+
write_pos="bottom-left",
31+
write_ts_offset=0,
32+
display=False,
33+
progress=True,
34+
label=True,
35+
):
3336
""" Create a labeled video from DeepLabCut-live-GUI recording
3437
3538
Parameters
3639
----------
37-
video_file : str
38-
path to video file
39-
ts_file : str
40-
path to timestamps file
41-
dlc_file : str
42-
path to DeepLabCut file
40+
data_dir : str
41+
path to data directory
42+
dlc_online : bool, optional
43+
flag indicating dlc keypoints from online tracking, using DeepLabCut-live-GUI, or offline tracking, using :func:`dlclive.benchmark_videos`
4344
out_file : str, optional
4445
path for output file. If None, output file will be "'video_file'_LABELED.avi". by default None. If NOn
4546
save_images : bool, optional
@@ -63,50 +64,86 @@ def create_labeled_video(video_file,
6364
if frames cannot be read from the video file
6465
"""
6566

67+
base_dir = os.path.basename(data_dir)
68+
video_file = os.path.normpath(f"{data_dir}/{base_dir}_VIDEO.avi")
69+
ts_file = os.path.normpath(f"{data_dir}/{base_dir}_TS.npy")
70+
dlc_file = (
71+
os.path.normpath(f"{data_dir}/{base_dir}_DLC.hdf5")
72+
if dlc_online
73+
else os.path.normpath(f"{data_dir}/{base_dir}_VIDEO_DLCLIVE_POSES.h5")
74+
)
75+
6676
cap = cv2.VideoCapture(video_file)
6777
cam_frame_times = np.load(ts_file)
6878
n_frames = cam_frame_times.size
6979

70-
7180
lab = "LABELED" if label else "UNLABELED"
7281
if out_dir:
73-
out_file = f"{out_dir}/{os.path.splitext(os.path.basename(video_file))[0]}_{lab}.avi"
74-
out_times_file = f"{out_dir}/{os.path.splitext(os.path.basename(ts_file))[0]}_{lab}.npy"
82+
out_file = (
83+
f"{out_dir}/{os.path.splitext(os.path.basename(video_file))[0]}_{lab}.avi"
84+
)
85+
out_times_file = (
86+
f"{out_dir}/{os.path.splitext(os.path.basename(ts_file))[0]}_{lab}.npy"
87+
)
7588
else:
7689
out_file = f"{os.path.splitext(video_file)[0]}_{lab}.avi"
7790
out_times_file = f"{os.path.splitext(ts_file)[0]}_{lab}.npy"
7891

7992
os.makedirs(os.path.normpath(os.path.dirname(out_file)), exist_ok=True)
80-
93+
8194
if save_images:
8295
im_dir = os.path.splitext(out_file)[0]
8396
os.makedirs(im_dir, exist_ok=True)
8497

85-
im_size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
98+
im_size = (
99+
int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
100+
int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
101+
)
86102
if crop is not None:
87-
crop = np.max(np.vstack((crop, [0, im_size[1], 0, im_size[0]])), axis=0)
88-
im_size = (crop[3]-crop[2], crop[1]-crop[0])
103+
crop[0] = crop[0] if crop[0] > 0 else 0
104+
crop[1] = crop[1] if crop[1] > 0 else im_size[1]
105+
crop[2] = crop[2] if crop[2] > 0 else 0
106+
crop[3] = crop[3] if crop[3] > 0 else im_size[0]
107+
im_size = (crop[3] - crop[2], crop[1] - crop[0])
89108

90-
fourcc = cv2.VideoWriter_fourcc(*'DIVX')
109+
fourcc = cv2.VideoWriter_fourcc(*"DIVX")
91110
fps = cap.get(cv2.CAP_PROP_FPS)
92111
vwriter = cv2.VideoWriter(out_file, fourcc, fps, im_size)
93112
label_times = []
94-
113+
95114
if write_ts:
96115
ts_font = cv2.FONT_HERSHEY_PLAIN
97-
ts_w = 0 if crop is None else crop[0]
98-
ts_h = im_size[1] if crop is None else crop[1]
116+
117+
if "left" in write_pos:
118+
ts_w = 0
119+
else:
120+
ts_w = (
121+
im_size[0] if crop is None else (crop[3] - crop[2]) - (55 * write_scale)
122+
)
123+
124+
if "bottom" in write_pos:
125+
ts_h = im_size[1] if crop is None else (crop[1] - crop[0])
126+
else:
127+
ts_h = 0 if crop is None else crop[0] + (12 * write_scale)
128+
99129
ts_coord = (ts_w, ts_h)
100130
ts_color = (255, 255, 255)
101131
ts_size = 2
102132

103133
poses = pd.read_hdf(dlc_file)
104-
pose_times = poses['pose_time']
105-
poses = poses.melt(id_vars=['frame_time', 'pose_time'])
106-
bodyparts = poses['bodyparts'].unique()
134+
if dlc_online:
135+
pose_times = poses["pose_time"]
136+
else:
137+
poses["frame_time"] = cam_frame_times
138+
poses["pose_time"] = cam_frame_times
139+
poses = poses.melt(id_vars=["frame_time", "pose_time"])
140+
bodyparts = poses["bodyparts"].unique()
107141

108142
all_colors = getattr(cc, cmap)
109-
colors = [ImageColor.getcolor(c, "RGB")[::-1] for c in all_colors[::int(len(all_colors)/bodyparts.size)]]
143+
colors = [
144+
ImageColor.getcolor(c, "RGB")[::-1]
145+
for c in all_colors[:: int(len(all_colors) / bodyparts.size)]
146+
]
110147

111148
ind = 0
112149
vid_time = 0
@@ -116,47 +153,70 @@ def create_labeled_video(video_file,
116153
vid_time = cur_time - cam_frame_times[0]
117154
ret, frame = cap.read()
118155
ind += 1
119-
156+
120157
if not ret:
121-
raise Exception(f"Could not read frame = {ind+1} at time = {cur_time-cam_frame_times[0]}.")
122-
123-
124-
frame_times_sub = cam_frame_times[(cam_frame_times-cam_frame_times[0] > cut[0]) & (cam_frame_times-cam_frame_times[0] < cut[1])]
125-
iterator = tqdm(range(ind, ind+frame_times_sub.size)) if progress else range(ind, ind+frame_times_sub.size)
158+
raise Exception(
159+
f"Could not read frame = {ind+1} at time = {cur_time-cam_frame_times[0]}."
160+
)
161+
162+
frame_times_sub = cam_frame_times[
163+
(cam_frame_times - cam_frame_times[0] > cut[0])
164+
& (cam_frame_times - cam_frame_times[0] < cut[1])
165+
]
166+
iterator = (
167+
tqdm(range(ind, ind + frame_times_sub.size))
168+
if progress
169+
else range(ind, ind + frame_times_sub.size)
170+
)
126171
this_pose = np.zeros((bodyparts.size, 3))
127172

128173
for i in iterator:
129174

130175
cur_time = cam_frame_times[i]
131176
vid_time = cur_time - cam_frame_times[0]
132177
ret, frame = cap.read()
133-
178+
134179
if not ret:
135-
raise Exception(f"Could not read frame = {i+1} at time = {cur_time-cam_frame_times[0]}.")
180+
raise Exception(
181+
f"Could not read frame = {i+1} at time = {cur_time-cam_frame_times[0]}."
182+
)
136183

137-
poses_before_index = np.where(pose_times < cur_time)[0]
138-
if poses_before_index.size > 0:
139-
cur_pose_time = pose_times[poses_before_index[-1]]
140-
this_pose = poses[poses['pose_time']==cur_pose_time]
184+
if dlc_online:
185+
poses_before_index = np.where(pose_times < cur_time)[0]
186+
if poses_before_index.size > 0:
187+
cur_pose_time = pose_times[poses_before_index[-1]]
188+
this_pose = poses[poses["pose_time"] == cur_pose_time]
189+
else:
190+
this_pose = poses[poses["frame_time"] == cur_time]
141191

142192
if label:
143193
for j in range(bodyparts.size):
144-
this_bp = this_pose[this_pose['bodyparts'] == bodyparts[j]]['value'].values
194+
this_bp = this_pose[this_pose["bodyparts"] == bodyparts[j]][
195+
"value"
196+
].values
145197
if this_bp[2] > lik_thresh:
146198
x = int(this_bp[0])
147199
y = int(this_bp[1])
148200
frame = cv2.circle(frame, (x, y), radius, colors[j], thickness=-1)
149-
201+
150202
if crop is not None:
151-
frame = frame[crop[0]:crop[1], crop[2]:crop[3]]
203+
frame = frame[crop[0] : crop[1], crop[2] : crop[3]]
152204

153205
if write_ts:
154-
frame = cv2.putText(frame, f"{vid_time:0.3f}", ts_coord, ts_font, write_scale, ts_color, ts_size)
206+
frame = cv2.putText(
207+
frame,
208+
f"{(vid_time-write_ts_offset):0.3f}",
209+
ts_coord,
210+
ts_font,
211+
write_scale,
212+
ts_color,
213+
ts_size,
214+
)
155215

156216
if display:
157-
cv2.imshow('DLC Live Labeled Video', frame)
217+
cv2.imshow("DLC Live Labeled Video", frame)
158218
cv2.waitKey(1)
159-
219+
160220
vwriter.write(frame)
161221
label_times.append(cur_time)
162222
if save_images:
@@ -165,7 +225,7 @@ def create_labeled_video(video_file,
165225

166226
if display:
167227
cv2.destroyAllWindows()
168-
228+
169229
vwriter.release()
170230
np.save(out_times_file, label_times)
171231

@@ -176,37 +236,40 @@ def main():
176236
import os
177237

178238
parser = argparse.ArgumentParser()
179-
parser.add_argument('file', type=str)
180-
parser.add_argument('-o', '--out-dir', type=str, default=None)
181-
parser.add_argument('-s', '--save-images', action='store_true')
182-
parser.add_argument('-u', '--cut', nargs='+', type=float, default=[0, np.Inf])
183-
parser.add_argument('-c', '--crop', nargs='+', type=int, default=None)
184-
parser.add_argument('-m', '--cmap', type=str, default='bmy')
185-
parser.add_argument('-r', '--radius', type=int, default=3)
186-
parser.add_argument('-l', '--lik-thresh', type=float, default=0.5)
187-
parser.add_argument('-w', '--write-ts', action='store_true')
188-
parser.add_argument('--write-scale', type=int, default=2)
189-
parser.add_argument('-d', '--display', action='store_true')
190-
parser.add_argument('--no-progress', action='store_false')
191-
parser.add_argument('--no-label', action='store_false')
239+
parser.add_argument("dir", type=str)
240+
parser.add_argument("-o", "--out-dir", type=str, default=None)
241+
parser.add_argument("--dlc-offline", action="store_true")
242+
parser.add_argument("-s", "--save-images", action="store_true")
243+
parser.add_argument("-u", "--cut", nargs="+", type=float, default=[0, np.Inf])
244+
parser.add_argument("-c", "--crop", nargs="+", type=int, default=None)
245+
parser.add_argument("-m", "--cmap", type=str, default="bmy")
246+
parser.add_argument("-r", "--radius", type=int, default=3)
247+
parser.add_argument("-l", "--lik-thresh", type=float, default=0.5)
248+
parser.add_argument("-w", "--write-ts", action="store_true")
249+
parser.add_argument("--write-scale", type=int, default=2)
250+
parser.add_argument("--write-pos", type=str, default="bottom-left")
251+
parser.add_argument("--write-ts-offset", type=float, default=0.0)
252+
parser.add_argument("-d", "--display", action="store_true")
253+
parser.add_argument("--no-progress", action="store_false")
254+
parser.add_argument("--no-label", action="store_false")
192255
args = parser.parse_args()
193256

194-
vid_file = os.path.normpath(f"{args.file}_VIDEO.avi")
195-
ts_file = os.path.normpath(f"{args.file}_TS.npy")
196-
dlc_file = os.path.normpath(f"{args.file}_DLC.hdf5")
197-
198-
create_labeled_video(vid_file,
199-
ts_file,
200-
dlc_file,
201-
out_dir=args.out_dir,
202-
save_images=args.save_images,
203-
cut=tuple(args.cut),
204-
crop=args.crop,
205-
cmap=args.cmap,
206-
radius=args.radius,
207-
lik_thresh=args.lik_thresh,
208-
write_ts=args.write_ts,
209-
write_scale=args.write_scale,
210-
display=args.display,
211-
progress=args.no_progress,
212-
label=args.no_label)
257+
create_labeled_video(
258+
args.dir,
259+
out_dir=args.out_dir,
260+
dlc_online=(not args.dlc_offline),
261+
save_images=args.save_images,
262+
cut=tuple(args.cut),
263+
crop=args.crop,
264+
cmap=args.cmap,
265+
radius=args.radius,
266+
lik_thresh=args.lik_thresh,
267+
write_ts=args.write_ts,
268+
write_scale=args.write_scale,
269+
write_pos=args.write_pos,
270+
write_ts_offset=args.write_ts_offset,
271+
display=args.display,
272+
progress=args.no_progress,
273+
label=args.no_label,
274+
)
275+

0 commit comments

Comments
 (0)