5
5
import os
6
6
import sys
7
7
from pathlib import Path
8
+ from textwrap import dedent
8
9
from typing import TYPE_CHECKING
9
10
11
+ from omegaconf import OmegaConf
12
+
13
+ # Don't import anything from torch or lightning_pose until needed.
14
+ # These imports are slow and delay CLI help text outputs.
15
+ # if TYPE_CHECKING allows use of imports for type annotations, without
16
+ # actually invoking the import at runtime.
10
17
if TYPE_CHECKING :
11
18
from lightning_pose .model import Model
12
19
@@ -45,6 +52,11 @@ def _build_parser():
45
52
"If not specified, defaults to "
46
53
"./outputs/{YYYY-MM-DD}/{HH:MM:SS}/" ,
47
54
)
55
+ train_parser .add_argument (
56
+ "--detector_model" ,
57
+ type = types .existing_model_dir ,
58
+ help = "If specified, uses cropped training data in the detector model's directory." ,
59
+ )
48
60
train_parser .add_argument (
49
61
"--overrides" ,
50
62
nargs = "*" ,
@@ -94,13 +106,86 @@ def _build_parser():
94
106
" uses the labels to compute pixel error.\n "
95
107
" saves outputs to `image_preds/<csv_file_name>`\n " ,
96
108
)
109
+ predict_parser .add_argument (
110
+ "--overrides" ,
111
+ nargs = "*" ,
112
+ metavar = "KEY=VALUE" ,
113
+ help = "overrides attributes of the config file. Uses hydra syntax:\n "
114
+ "https://hydra.cc/docs/advanced/override_grammar/basic/" ,
115
+ )
97
116
98
117
post_prediction_args = predict_parser .add_argument_group ("post-prediction" )
99
118
post_prediction_args .add_argument (
100
119
"--skip_viz" ,
101
120
action = "store_true" ,
102
121
help = "skip generating prediction-annotated images/videos" ,
103
122
)
123
+
124
+ # Crop command
125
+ crop_parser = subparsers .add_parser (
126
+ "crop" ,
127
+ description = dedent (
128
+ """\
129
+ Crops a video or labeled frames based on model predictions.
130
+ Requires model predictions to already have been generated using `litpose predict`.
131
+
132
+ Cropped videos are saved to:
133
+ <model_dir>/
134
+ └── video_preds/
135
+ ├── <video_filename>.csv (predictions)
136
+ ├── <video_filename>_bbox.csv (bbox)
137
+ └── remapped_<video_filename>.csv (TODO move to remap command)
138
+ └── cropped_videos/
139
+ └── cropped_<video_filename>.mp4 (cropped video)
140
+
141
+ Cropped images are saved to:
142
+ <model_dir>/
143
+ └── image_preds/
144
+ └── <csv_file_name>/
145
+ ├── predictions.csv
146
+ ├── bbox.csv (bbox)
147
+ └── cropped_<csv_file_name>.csv (cropped labels)
148
+ └── cropped_images/
149
+ └── a/b/c/<image_name>.png (cropped images)\
150
+ """
151
+ ),
152
+ usage = "litpose crop <model_dir> <input_path:video|csv>... --crop_ratio=CROP_RATIO --anchor_keypoints=x,y,z" ,
153
+ )
154
+ crop_parser .add_argument (
155
+ "model_dir" , type = types .existing_model_dir , help = "path to a model directory"
156
+ )
157
+
158
+ crop_parser .add_argument (
159
+ "input_path" , type = Path , nargs = "+" , help = "one or more files"
160
+ )
161
+ crop_parser .add_argument (
162
+ "--crop_ratio" ,
163
+ type = float ,
164
+ default = 2.0 ,
165
+ help = "Crop a bounding box this much larger than the animal. Default is 2." ,
166
+ )
167
+ crop_parser .add_argument (
168
+ "--anchor_keypoints" ,
169
+ type = str ,
170
+ default = "" , # Or a reasonable default like "0,0,0" if appropriate
171
+ help = "Comma-separated list of anchor keypoint names, defaults to all keypoints" ,
172
+ )
173
+
174
+ remap_parser = subparsers .add_parser (
175
+ "remap" ,
176
+ description = dedent (
177
+ """\
178
+ Remaps predictions from cropped to original coordinate space.
179
+ Requires model predictions to already have been generated using `litpose predict`.
180
+
181
+ Remapped predictions are saved as "remapped_{preds_file}" in the same folder as preds_file.
182
+ """
183
+ ),
184
+ usage = "litpose remap <preds_file> <bbox_file>" ,
185
+ )
186
+ remap_parser .add_argument ("preds_file" , type = Path , help = "path to a prediction file" )
187
+ remap_parser .add_argument ("bbox_file" , type = Path , help = "path to a bbox file" )
188
+
104
189
return parser
105
190
106
191
@@ -120,6 +205,84 @@ def main():
120
205
elif args .command == "predict" :
121
206
_predict (args )
122
207
208
+ elif args .command == "crop" :
209
+ _crop (args )
210
+
211
+ elif args .command == "remap" :
212
+ _remap_preds (args )
213
+
214
+
215
def _crop(args: argparse.Namespace):
    """Crop videos (.mp4) or labeled frames (.csv) based on model predictions.

    For each input path, reads the corresponding prediction file previously
    generated by `litpose predict`, computes per-frame bounding boxes, and
    writes cropped outputs plus a bbox file into the model's directory.

    Args:
        args: parsed CLI namespace with `model_dir`, `input_path` (list of
            .mp4/.csv paths), `crop_ratio` (float > 1), and
            `anchor_keypoints` (comma-separated string, may be empty).

    Raises:
        ValueError: if `crop_ratio` is not greater than 1.
        NotImplementedError: for input files that are neither .mp4 nor .csv.
    """
    # Delayed imports: torch/lightning_pose are slow and would delay CLI help.
    import lightning_pose.utils.cropzoom as cz
    from lightning_pose.model import Model

    model_dir = args.model_dir
    model = Model.from_dir(model_dir)

    # Make both cropped_images and cropped_videos dirs. Reason: After this, the user
    # will train a pose model, and current code in io utils checks that both
    # data_dir and videos_dir are present. If we just create one or the other,
    # the check will fail.
    model.cropped_data_dir().mkdir(parents=True, exist_ok=True)
    model.cropped_videos_dir().mkdir(parents=True, exist_ok=True)

    input_paths = [Path(p) for p in args.input_path]

    detector_cfg = OmegaConf.create(
        {
            "crop_ratio": args.crop_ratio,
            "anchor_keypoints": (
                args.anchor_keypoints.split(",") if args.anchor_keypoints else []
            ),
        }
    )
    # Explicit exception rather than `assert`: asserts are stripped under
    # `python -O`, and this validates user-supplied input.
    if detector_cfg.crop_ratio <= 1:
        raise ValueError(
            f"--crop_ratio must be greater than 1, got {detector_cfg.crop_ratio}"
        )

    for input_path in input_paths:
        if input_path.suffix == ".mp4":
            # Predictions and bbox live next to each other in video_preds/.
            input_preds_file = model.video_preds_dir() / (input_path.stem + ".csv")
            output_bbox_file = model.video_preds_dir() / (
                input_path.stem + "_bbox.csv"
            )
            output_file = model.cropped_videos_dir() / ("cropped_" + input_path.name)

            cz.generate_cropped_video(
                input_video_file=input_path,
                input_preds_file=input_preds_file,
                detector_cfg=detector_cfg,
                output_bbox_file=output_bbox_file,
                output_file=output_file,
            )
        elif input_path.suffix == ".csv":
            # Labeled-frame predictions live under image_preds/<csv_file_name>/.
            preds_dir = model.image_preds_dir() / input_path.name
            input_data_dir = Path(model.config.cfg.data.data_dir)
            cropped_data_dir = model.cropped_data_dir()

            output_bbox_file = preds_dir / "bbox.csv"
            output_csv_file_path = preds_dir / ("cropped_" + input_path.name)
            input_preds_file = preds_dir / "predictions.csv"
            cz.generate_cropped_labeled_frames(
                input_data_dir=input_data_dir,
                input_csv_file=input_path,
                input_preds_file=input_preds_file,
                detector_cfg=detector_cfg,
                output_data_dir=cropped_data_dir,
                output_bbox_file=output_bbox_file,
                output_csv_file=output_csv_file_path,
            )
        else:
            raise NotImplementedError("Only mp4 and csv files are supported.")
273
+
274
+
275
def _remap_preds(args: argparse.Namespace):
    """Remap predictions from cropped coordinates back to the original space.

    Reads ``args.preds_file`` (cropped-space predictions) and
    ``args.bbox_file`` (per-frame bounding boxes), and writes the remapped
    predictions next to the input as ``remapped_<preds_file_name>``.
    """
    # Delayed import: lightning_pose pulls in torch, which is slow to load.
    import lightning_pose.utils.cropzoom as cz

    remapped_name = "remapped_" + args.preds_file.name
    cz.generate_cropped_csv_file(
        input_csv_file=args.preds_file,
        input_bbox_file=args.bbox_file,
        output_csv_file=args.preds_file.with_name(remapped_name),
    )
285
+
123
286
124
287
def _train (args : argparse .Namespace ):
125
288
import hydra
@@ -142,11 +305,32 @@ def _train(args: argparse.Namespace):
142
305
cfg = hydra .compose (config_name = args .config_file .stem , overrides = args .overrides )
143
306
144
307
# Delay this import because it's slow.
308
+ from lightning_pose .model import Model
145
309
from lightning_pose .train import train
146
310
147
311
# TODO: Move some aspects of directory mgmt to the train function.
148
312
output_dir .mkdir (parents = True , exist_ok = True )
149
313
# Maintain legacy hydra chdir until downstream no longer depends on it.
314
+
315
+ if args .detector_model :
316
+ # create detector model object before chdir so that relative path is resolved correctly
317
+ detector_model = Model .from_dir (args .detector_model )
318
+ import copy
319
+
320
+ cfg = copy .deepcopy (cfg )
321
+ cfg .data .data_dir = str (detector_model .cropped_data_dir ())
322
+ cfg .data .video_dir = str (detector_model .cropped_videos_dir ())
323
+ if isinstance (cfg .data .csv_file , str ):
324
+ cfg .data .csv_file = str (
325
+ detector_model .cropped_csv_file_path (cfg .data .csv_file )
326
+ )
327
+ else :
328
+ cfg .data .csv_file = [
329
+ str (detector_model .cropped_csv_file_path (f ))
330
+ for f in cfg .data .csv_file
331
+ ]
332
+ cfg .eval .test_videos_directory = cfg .data .video_dir
333
+
150
334
os .chdir (output_dir )
151
335
train (cfg )
152
336
@@ -155,7 +339,7 @@ def _predict(args: argparse.Namespace):
155
339
# Delay this import because it's slow.
156
340
from lightning_pose .model import Model
157
341
158
- model = Model .from_dir (args .model_dir )
342
+ model = Model .from_dir2 (args .model_dir , hydra_overrides = args . overrides )
159
343
input_paths = [Path (p ) for p in args .input_path ]
160
344
161
345
for p in input_paths :
0 commit comments