Commit 69535a9

Add support for cropzoom pipeline (#250)
* Cropzoom PR
* try re-enabling cz test
* Revert "try re-enabling cz test" (reverts commit 8bb3bbf)
* fixups
* further docs
* bump version
1 parent d1cd064 commit 69535a9

20 files changed: +646 −233

cli/main.py (+185 −1)
@@ -5,8 +5,15 @@
 import os
 import sys
 from pathlib import Path
+from textwrap import dedent
 from typing import TYPE_CHECKING
 
+from omegaconf import OmegaConf
+
+# Don't import anything from torch or lightning_pose until needed.
+# These imports are slow and delay CLI help text outputs.
+# `if TYPE_CHECKING` allows use of imports for type annotations, without
+# actually invoking the import at runtime.
 if TYPE_CHECKING:
     from lightning_pose.model import Model
 
@@ -45,6 +52,11 @@ def _build_parser():
         "If not specified, defaults to "
         "./outputs/{YYYY-MM-DD}/{HH:MM:SS}/",
     )
+    train_parser.add_argument(
+        "--detector_model",
+        type=types.existing_model_dir,
+        help="If specified, uses cropped training data in the detector model's directory.",
+    )
     train_parser.add_argument(
         "--overrides",
         nargs="*",
@@ -94,13 +106,86 @@ def _build_parser():
         " uses the labels to compute pixel error.\n"
         " saves outputs to `image_preds/<csv_file_name>`\n",
     )
+    predict_parser.add_argument(
+        "--overrides",
+        nargs="*",
+        metavar="KEY=VALUE",
+        help="overrides attributes of the config file. Uses hydra syntax:\n"
+        "https://hydra.cc/docs/advanced/override_grammar/basic/",
+    )
 
     post_prediction_args = predict_parser.add_argument_group("post-prediction")
     post_prediction_args.add_argument(
         "--skip_viz",
         action="store_true",
         help="skip generating prediction-annotated images/videos",
     )
+
+    # Crop command
+    crop_parser = subparsers.add_parser(
+        "crop",
+        description=dedent(
+            """\
+            Crops a video or labeled frames based on model predictions.
+            Requires model predictions to already have been generated using `litpose predict`.
+
+            Cropped videos are saved to:
+            <model_dir>/
+            └── video_preds/
+                ├── <video_filename>.csv (predictions)
+                ├── <video_filename>_bbox.csv (bbox)
+                └── remapped_<video_filename>.csv (TODO move to remap command)
+            └── cropped_videos/
+                └── cropped_<video_filename>.mp4 (cropped video)
+
+            Cropped images are saved to:
+            <model_dir>/
+            └── image_preds/
+                └── <csv_file_name>/
+                    ├── predictions.csv
+                    ├── bbox.csv (bbox)
+                    └── cropped_<csv_file_name>.csv (cropped labels)
+            └── cropped_images/
+                └── a/b/c/<image_name>.png (cropped images)\
+            """
+        ),
+        usage="litpose crop <model_dir> <input_path:video|csv>... --crop_ratio=CROP_RATIO --anchor_keypoints=x,y,z",
+    )
+    crop_parser.add_argument(
+        "model_dir", type=types.existing_model_dir, help="path to a model directory"
+    )
+
+    crop_parser.add_argument(
+        "input_path", type=Path, nargs="+", help="one or more files"
+    )
+    crop_parser.add_argument(
+        "--crop_ratio",
+        type=float,
+        default=2.0,
+        help="Crop a bounding box this much larger than the animal. Default is 2.",
+    )
+    crop_parser.add_argument(
+        "--anchor_keypoints",
+        type=str,
+        default="",
+        help="Comma-separated list of anchor keypoint names, defaults to all keypoints",
+    )
+
+    remap_parser = subparsers.add_parser(
+        "remap",
+        description=dedent(
+            """\
+            Remaps predictions from cropped to original coordinate space.
+            Requires model predictions to already have been generated using `litpose predict`.
+
+            Remapped predictions are saved as "remapped_{preds_file}" in the same folder as preds_file.
+            """
+        ),
+        usage="litpose remap <preds_file> <bbox_file>",
+    )
+    remap_parser.add_argument("preds_file", type=Path, help="path to a prediction file")
+    remap_parser.add_argument("bbox_file", type=Path, help="path to a bbox file")
+
     return parser
 
 
@@ -120,6 +205,84 @@ def main():
     elif args.command == "predict":
         _predict(args)
 
+    elif args.command == "crop":
+        _crop(args)
+
+    elif args.command == "remap":
+        _remap_preds(args)
+
+
+def _crop(args: argparse.Namespace):
+    import lightning_pose.utils.cropzoom as cz
+    from lightning_pose.model import Model
+
+    model_dir = args.model_dir
+    model = Model.from_dir(model_dir)
+
+    # Make both cropped_images and cropped_videos dirs. Reason: After this, the user
+    # will train a pose model, and current code in io utils checks that both
+    # data_dir and videos_dir are present. If we just create one or the other,
+    # the check will fail.
+    model.cropped_data_dir().mkdir(parents=True, exist_ok=True)
+    model.cropped_videos_dir().mkdir(parents=True, exist_ok=True)
+
+    input_paths = [Path(p) for p in args.input_path]
+
+    detector_cfg = OmegaConf.create(
+        {
+            "crop_ratio": args.crop_ratio,
+            "anchor_keypoints": args.anchor_keypoints.split(",") if args.anchor_keypoints else [],
+        }
+    )
+    assert detector_cfg.crop_ratio > 1
+
+    for input_path in input_paths:
+        if input_path.suffix == ".mp4":
+            input_preds_file = model.video_preds_dir() / (input_path.stem + ".csv")
+            output_bbox_file = model.video_preds_dir() / (
+                input_path.stem + "_bbox.csv"
+            )
+            output_file = model.cropped_videos_dir() / ("cropped_" + input_path.name)
+
+            cz.generate_cropped_video(
+                input_video_file=input_path,
+                input_preds_file=input_preds_file,
+                detector_cfg=detector_cfg,
+                output_bbox_file=output_bbox_file,
+                output_file=output_file,
+            )
+        elif input_path.suffix == ".csv":
+            preds_dir = model.image_preds_dir() / input_path.name
+            input_data_dir = Path(model.config.cfg.data.data_dir)
+            cropped_data_dir = model.cropped_data_dir()
+
+            output_bbox_file = preds_dir / "bbox.csv"
+            output_csv_file_path = preds_dir / ("cropped_" + input_path.name)
+            input_preds_file = preds_dir / "predictions.csv"
+            cz.generate_cropped_labeled_frames(
+                input_data_dir=input_data_dir,
+                input_csv_file=input_path,
+                input_preds_file=input_preds_file,
+                detector_cfg=detector_cfg,
+                output_data_dir=cropped_data_dir,
+                output_bbox_file=output_bbox_file,
+                output_csv_file=output_csv_file_path,
+            )
+        else:
+            raise NotImplementedError("Only mp4 and csv files are supported.")
+
+
+def _remap_preds(args: argparse.Namespace):
+    import lightning_pose.utils.cropzoom as cz
+
+    output_file = args.preds_file.with_name("remapped_" + args.preds_file.name)
+
+    cz.generate_cropped_csv_file(
+        input_csv_file=args.preds_file,
+        input_bbox_file=args.bbox_file,
+        output_csv_file=output_file,
+    )
+
 
 def _train(args: argparse.Namespace):
     import hydra
@@ -142,11 +305,32 @@ def _train(args: argparse.Namespace):
     cfg = hydra.compose(config_name=args.config_file.stem, overrides=args.overrides)
 
     # Delay this import because it's slow.
+    from lightning_pose.model import Model
     from lightning_pose.train import train
 
     # TODO: Move some aspects of directory mgmt to the train function.
     output_dir.mkdir(parents=True, exist_ok=True)
     # Maintain legacy hydra chdir until downstream no longer depends on it.
+
+    if args.detector_model:
+        # Create the detector model object before chdir so that relative paths resolve correctly.
+        detector_model = Model.from_dir(args.detector_model)
+        import copy
+
+        cfg = copy.deepcopy(cfg)
+        cfg.data.data_dir = str(detector_model.cropped_data_dir())
+        cfg.data.video_dir = str(detector_model.cropped_videos_dir())
+        if isinstance(cfg.data.csv_file, str):
+            cfg.data.csv_file = str(
+                detector_model.cropped_csv_file_path(cfg.data.csv_file)
+            )
+        else:
+            cfg.data.csv_file = [
+                str(detector_model.cropped_csv_file_path(f))
+                for f in cfg.data.csv_file
+            ]
+        cfg.eval.test_videos_directory = cfg.data.video_dir
+
     os.chdir(output_dir)
     train(cfg)
 
@@ -155,7 +339,7 @@ def _predict(args: argparse.Namespace):
     # Delay this import because it's slow.
    from lightning_pose.model import Model
 
-    model = Model.from_dir(args.model_dir)
+    model = Model.from_dir2(args.model_dir, hydra_overrides=args.overrides)
     input_paths = [Path(p) for p in args.input_path]
 
     for p in input_paths:
@@ -0,0 +1,6 @@
1+
generate_cropped_csv_file
2+
=========================
3+
4+
.. currentmodule:: lightning_pose.utils.cropzoom
5+
6+
.. autofunction:: generate_cropped_csv_file

docs/api/lightning_pose.utils.scripts.calculate_steps_per_epoch.rst (new file, +6)

calculate_steps_per_epoch
=========================

.. currentmodule:: lightning_pose.utils.scripts

.. autofunction:: calculate_steps_per_epoch

docs/api/lightning_pose.utils.scripts.calculate_train_batches.rst (deleted, −6)

docs/source/api.rst (+1 −1)

@@ -67,7 +67,7 @@ API Reference:
 
 .. autoclass:: lightning_pose.model.Model
    :members:
-   :exclude-members: __init__, PredictionResult, predict_on_label_csv_internal
+   :exclude-members: __init__, PredictionResult, from_dir2
 
 
 Lightning Pose Internal API

New documentation file (+123)

########################
Cropzoom pipeline
########################

For setups where an animal is freely moving in a large arena,
it's advantageous to crop around the animal before running pose estimation.
Lightning-pose calls this technique "cropzoom". This document describes how
to set up such a pipeline.

Conceptual overview
===================

A cropzoom pipeline consists of two lightning-pose models:
a "detector model" and a "pose model".

* The detector model operates on the full image of the arena.
* The pose model operates on the cropped animal.

These two models are trained and used for prediction like any other
lightning-pose model. We provide additional tools that help you compose them:

* ``litpose crop``: Given the detector model's predictions, crops around the animal.
* ``litpose remap``: Given the pose model's predictions and the crop bounding boxes,
  remaps the predictions to the original coordinate space.
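
Conceptually, remapping is a translation: if a frame's crop has its top-left
corner at :math:`(x_0, y_0)` in the original image, a keypoint predicted at
:math:`(x, y)` in the cropped frame maps back to

.. math::

   (x_\text{orig}, y_\text{orig}) = (x + x_0,\, y + y_0)

(assuming predictions are expressed in the cropped frame's pixel coordinates;
any resizing of the crop before pose estimation introduces a corresponding
scale factor).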

Training
--------

Training involves three steps:

1. Train a "detector model".
2. Crop the training data using the detector model's predictions.
3. Train a "pose model" on the cropped data.

Inference
---------

Inference involves four steps:

1. Predict using the "detector model".
2. Crop the data using the above predictions.
3. Predict on the cropped data using the "pose model".
4. Remap the pose model's predictions to the original coordinate space.
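
The crop and remap steps can also be driven from Python; they correspond to
functions in ``lightning_pose.utils.cropzoom``. Below is a minimal sketch for
a video, mirroring what the CLI commands do internally (all paths are
illustrative):

.. code-block:: python

    from pathlib import Path

    from omegaconf import OmegaConf

    import lightning_pose.utils.cropzoom as cz
    from lightning_pose.model import Model

    detector = Model.from_dir("outputs/chickadee/cropzoom/detector_0")
    video = Path("data/videos/test_vid.short.mp4")

    # Step 2: crop the video around the detector model's predictions.
    bbox_file = detector.video_preds_dir() / (video.stem + "_bbox.csv")
    cz.generate_cropped_video(
        input_video_file=video,
        input_preds_file=detector.video_preds_dir() / (video.stem + ".csv"),
        detector_cfg=OmegaConf.create({"crop_ratio": 2.0, "anchor_keypoints": []}),
        output_bbox_file=bbox_file,
        output_file=detector.cropped_videos_dir() / ("cropped_" + video.name),
    )

    # Step 4: remap the pose model's predictions back to the full frame.
    pose_preds = Path(
        "outputs/chickadee/cropzoom/pose_supervised_0/video_preds/"
        "cropped_test_vid.short.csv"
    )
    cz.generate_cropped_csv_file(
        input_csv_file=pose_preds,
        input_bbox_file=bbox_file,
        output_csv_file=pose_preds.with_name("remapped_" + pose_preds.name),
    )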

Example
=======

This is a basic example of how you can set up a cropzoom pipeline.
Paths to CSV and MP4 files below should be replaced with your files.
The example is illustrative only. In practice you might want to modify it,
for instance by:

1. Using a different model type, backbone, or ``image_resize_dims`` for
   your detector model and pose model. This can be accomplished using
   different config files for the two models.
2. Limiting ``train_frames`` and ``max_epochs`` for testing purposes.

We'll use some bash variables to avoid repeating paths below:

.. code-block:: bash

    MODEL_DIR=outputs/chickadee/cropzoom
    DETECTOR_MODEL=detector_0
    POSE_MODEL=pose_supervised_0

Training script
---------------

.. code-block:: bash

    #!/bin/bash

    # Train the detector model.
    litpose train config.yaml --output_dir $MODEL_DIR/$DETECTOR_MODEL

    # Crop data for pose model training.
    litpose crop $MODEL_DIR/$DETECTOR_MODEL data/CollectedData.csv

    # Train the pose model on the cropped data.
    litpose train config.yaml --output_dir $MODEL_DIR/$POSE_MODEL \
        --detector_model=$MODEL_DIR/$DETECTOR_MODEL
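
Under the hood, ``--detector_model`` rewires the pose model's training config
to point at the detector model's cropped outputs before training starts. A
sketch of what it does, per ``cli/main.py`` in this commit (the config path is
illustrative, and the real code also handles a list of CSV files):

.. code-block:: python

    from omegaconf import OmegaConf

    from lightning_pose.model import Model

    cfg = OmegaConf.load("config.yaml")  # the pose model's training config
    detector_model = Model.from_dir("outputs/chickadee/cropzoom/detector_0")

    # Redirect training inputs to the detector model's cropped outputs.
    cfg.data.data_dir = str(detector_model.cropped_data_dir())
    cfg.data.video_dir = str(detector_model.cropped_videos_dir())
    cfg.data.csv_file = str(detector_model.cropped_csv_file_path(cfg.data.csv_file))
    cfg.eval.test_videos_directory = cfg.data.video_dir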

Prediction on videos script
---------------------------

.. code-block:: bash

    #!/bin/bash

    litpose predict $MODEL_DIR/$DETECTOR_MODEL data/videos/test_vid.short.mp4

    litpose crop $MODEL_DIR/$DETECTOR_MODEL data/videos/test_vid.short.mp4

    litpose predict $MODEL_DIR/$POSE_MODEL $MODEL_DIR/$DETECTOR_MODEL/cropped_videos/cropped_test_vid.short.mp4

    litpose remap $MODEL_DIR/$POSE_MODEL/video_preds/cropped_test_vid.short.csv \
        $MODEL_DIR/$DETECTOR_MODEL/video_preds/test_vid.short_bbox.csv

Prediction on OOD labeled data
------------------------------

Say you have new labeled data for OOD animals at ``data/CollectedData_new.csv``,
and you want to predict on these frames as well as compute pixel error.

.. code-block:: bash

    #!/bin/bash

    litpose predict $MODEL_DIR/$DETECTOR_MODEL data/CollectedData_new.csv

    litpose crop $MODEL_DIR/$DETECTOR_MODEL data/CollectedData_new.csv

    litpose predict $MODEL_DIR/$POSE_MODEL \
        $MODEL_DIR/$DETECTOR_MODEL/image_preds/CollectedData_new.csv/cropped_CollectedData_new.csv

    litpose remap $MODEL_DIR/$POSE_MODEL/image_preds/cropped_CollectedData_new.csv/predictions.csv \
        $MODEL_DIR/$DETECTOR_MODEL/image_preds/CollectedData_new.csv/bbox.csv

Limitations
===========

* Pose models do not yet support the PCA Multiview loss.
