Skip to content

Commit 284a2f0

Browse files
authored
Samplers tutorial (#341)
1 parent d016842 commit 284a2f0

File tree

6 files changed

+262
-7
lines changed

6 files changed

+262
-7
lines changed

docs/source/api_ref_decoders.rst

+3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ torchcodec.decoders
77
.. currentmodule:: torchcodec.decoders
88

99

10+
For a tutorial, see: :ref:`sphx_glr_generated_examples_basic_example.py`.
11+
12+
1013
.. autosummary::
1114
:toctree: generated/
1215
:nosignatures:

docs/source/api_ref_samplers.rst

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ torchcodec.samplers
66

77
.. currentmodule:: torchcodec.samplers
88

9+
For a tutorial, see: :ref:`sphx_glr_generated_examples_sampling.py`.
910

1011
.. autosummary::
1112
:toctree: generated/

docs/source/conf.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
sphinx_gallery_conf = {
6161
"examples_dirs": "../../examples/", # path to your example scripts
6262
"gallery_dirs": "generated_examples", # path to where to save gallery generated output
63-
"filename_pattern": "basic*",
63+
"filename_pattern": ".py",
6464
"backreferences_dir": "gen_modules/backreferences",
6565
"doc_module": ("torchcodec",),
6666
"remove_config_comments": True,

docs/source/index.rst

+3-3
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,12 @@ We achieve these capabilities through:
4343
A simple video decoding example
4444

4545
.. grid-item-card:: :octicon:`file-code;1em`
46-
API Reference
46+
Clip sampling
4747
:img-top: _static/img/card-background.svg
48-
:link: api_ref_torchcodec.html
48+
:link: generated_examples/sampling.html
4949
:link-type: url
5050

51-
The API reference for TorchCodec
51+
How to sample video clips
5252

5353
.. toctree::
5454
:maxdepth: 1

examples/sampling.py

+251
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
=========================
9+
How to sample video clips
10+
=========================
11+
12+
In this example, we'll learn how to sample video :term:`clips` from a video. A
13+
clip generally denotes a sequence or batch of frames, and is typically passed as
14+
input to video models.
15+
"""
16+
17+
# %%
18+
# First, a bit of boilerplate: we'll download a video from the web, and define a
19+
# plotting utility. You can ignore that part and jump right below to
20+
# :ref:`sampling_tuto_start`.
21+
22+
from typing import Optional
23+
import torch
24+
import requests
25+
26+
27+
# Video source: https://www.pexels.com/video/dog-eating-854132/
28+
# License: CC0. Author: Coverr.
29+
url = "https://videos.pexels.com/video-files/854132/854132-sd_640_360_25fps.mp4"
30+
response = requests.get(url, headers={"User-Agent": ""})
31+
if response.status_code != 200:
32+
raise RuntimeError(f"Failed to download video. {response.status_code = }.")
33+
34+
raw_video_bytes = response.content
35+
36+
37+
def plot(frames: torch.Tensor, title : Optional[str] = None):
38+
try:
39+
from torchvision.utils import make_grid
40+
from torchvision.transforms.v2.functional import to_pil_image
41+
import matplotlib.pyplot as plt
42+
except ImportError:
43+
print("Cannot plot, please run `pip install torchvision matplotlib`")
44+
return
45+
46+
plt.rcParams["savefig.bbox"] = 'tight'
47+
fig, ax = plt.subplots()
48+
ax.imshow(to_pil_image(make_grid(frames)))
49+
ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
50+
if title is not None:
51+
ax.set_title(title)
52+
plt.tight_layout()
53+
54+
55+
# %%
56+
# .. _sampling_tuto_start:
57+
#
58+
# Creating a decoder
59+
# ------------------
60+
#
61+
# Sampling clips from a video always starts by creating a
62+
# :class:`~torchcodec.decoders.VideoDecoder` object. If you're not already
63+
# familiar with :class:`~torchcodec.decoders.VideoDecoder`, take a quick look
64+
# at: :ref:`sphx_glr_generated_examples_basic_example.py`.
65+
from torchcodec.decoders import VideoDecoder
66+
67+
# You can also pass a path to a local file!
68+
decoder = VideoDecoder(raw_video_bytes)
69+
70+
# %%
71+
# Sampling basics
72+
# ---------------
73+
#
74+
# We can now use our decoder to sample clips. Let's first look at a simple
75+
# example: all other samplers follow similar APIs and principles. We'll use
76+
# :func:`~torchcodec.samplers.clips_at_random_indices` to sample clips that
77+
# start at random indices.
78+
79+
from torchcodec.samplers import clips_at_random_indices
80+
81+
# The sampler's RNG is controlled by PyTorch's RNG. We set a seed for this
82+
# tutorial to be reproducible across runs, but note that hard-coding a seed for
83+
# a training run is generally not recommended.
84+
torch.manual_seed(0)
85+
86+
clips = clips_at_random_indices(
87+
decoder,
88+
num_clips=5,
89+
num_frames_per_clip=4,
90+
num_indices_between_frames=3,
91+
)
92+
clips
93+
94+
# %%
95+
# The output of the sampler is a sequence of clips, represented as
96+
# :class:`~torchcodec.FrameBatch` object. In this object, we have different
97+
# fields:
98+
#
99+
# - ``data``: a 5D uint8 tensor representing the frame data. Its shape is
100+
# (num_clips, num_frames_per_clip, ...) where ... is either (C, H, W) or (H,
101+
# W, C), depending on the ``dimension_order`` parameter of the
102+
# :class:`~torchcodec.decoders.VideoDecoder`. This is typically what would get
103+
# passed to the model.
104+
# - ``pts_seconds``: a 2D float tensor of shape (num_clips, num_frames_per_clip)
105+
# giving the starting timestamps of each frame within each clip, in seconds.
106+
# - ``duration_seconds``: a 2D float tensor of shape (num_clips,
107+
# num_frames_per_clip) giving the duration of each frame within each clip, in
108+
# seconds.
109+
110+
plot(clips[0].data)
111+
112+
# %%
113+
# Indexing and manipulating clips
114+
# -------------------------------
115+
#
116+
# Clips are :class:`~torchcodec.FrameBatch` objects, and they support native
117+
# pytorch indexing semantics (including fancy indexing). This makes it easy to
118+
# filter clips based on a given criterion. For example, from the clips above we
119+
# can easily filter out those that start *after* a specific timestamp:
120+
clip_starts = clips.pts_seconds[:, 0]
121+
clip_starts
122+
123+
# %%
124+
clips_starting_after_five_seconds = clips[clip_starts > 5]
125+
clips_starting_after_five_seconds
126+
127+
# %%
128+
every_other_clip = clips[::2]
129+
every_other_clip
130+
131+
# %%
132+
#
133+
# .. note::
134+
# A more natural and efficient way to get clips after a given timestamp is to
135+
# rely on the sampling range parameters, which we'll cover later in :ref:`sampling_range`.
136+
#
137+
# Index-based and Time-based samplers
138+
# -----------------------------------
139+
#
140+
# So far we've used :func:`~torchcodec.samplers.clips_at_random_indices`.
141+
# TorchCodec supports additional samplers, which fall under two main categories:
142+
#
143+
# Index-based samplers:
144+
#
145+
# - :func:`~torchcodec.samplers.clips_at_random_indices`
146+
# - :func:`~torchcodec.samplers.clips_at_regular_indices`
147+
#
148+
# Time-based samplers:
149+
#
150+
# - :func:`~torchcodec.samplers.clips_at_random_timestamps`
151+
# - :func:`~torchcodec.samplers.clips_at_regular_timestamps`
152+
#
153+
# All these samplers follow similar APIs and the time-based samplers have
154+
# analogous parameters to the index-based ones. Both sampler types generally
155+
# offer comparable performance in terms of speed.
156+
#
157+
# .. note::
158+
# Is it better to use a time-based sampler or an index-based sampler? The
159+
# index-based samplers have arguably slightly simpler APIs and their behavior
160+
# is possibly simpler to understand and control, because of the discrete
161+
# nature of indices. For videos with constant fps, an index-based sampler
162+
# behaves exactly the same as a time-based sampler. For videos with variable
163+
# fps however (as is often the case), relying on indices may under/over sample
164+
# some regions in the video, which may lead to undesirable side effects when
165+
# training a model. Using a time-based sampler ensures uniform sampling
166+
# characteristics along the time dimension.
167+
#
168+
169+
# %%
170+
# .. _sampling_range:
171+
#
172+
# Advanced parameters: sampling range
173+
# -----------------------------------
174+
#
175+
# Sometimes, we may not want to sample clips from an entire video. We may only
176+
# be interested in clips that start within a smaller interval. In samplers, the
177+
# ``sampling_range_start`` and ``sampling_range_end`` parameters control the
178+
# sampling range: they define where we allow clips to *start*. There are two
179+
# important things to keep in mind:
180+
#
181+
# - ``sampling_range_end`` is an *open* upper-bound: clips may only start within
182+
# [sampling_range_start, sampling_range_end).
183+
# - Because these parameters define where a clip can start, clips may contain
184+
# frames *after* ``sampling_range_end``!
185+
186+
from torchcodec.samplers import clips_at_regular_timestamps
187+
188+
clips = clips_at_regular_timestamps(
189+
decoder,
190+
seconds_between_clip_starts=1,
191+
num_frames_per_clip=4,
192+
seconds_between_frames=0.5,
193+
sampling_range_start=2,
194+
sampling_range_end=5
195+
)
196+
clips
197+
198+
# %%
199+
# Advanced parameters: policy
200+
# ---------------------------
201+
#
202+
# Depending on the length or duration of the video and on the sampling
203+
# parameters, the sampler may try to sample frames *beyond* the end of the
204+
# video. The ``policy`` parameter defines how such invalid frames should be
205+
# replaced with valid
206+
# frames.
207+
from torchcodec.samplers import clips_at_random_timestamps
208+
209+
end_of_video = decoder.metadata.end_stream_seconds
210+
print(f"{end_of_video = }")
211+
212+
# %%
213+
torch.manual_seed(0)
214+
clips = clips_at_random_timestamps(
215+
decoder,
216+
num_clips=1,
217+
num_frames_per_clip=5,
218+
seconds_between_frames=0.4,
219+
sampling_range_start=end_of_video - 1,
220+
sampling_range_end=end_of_video,
221+
policy="repeat_last",
222+
)
223+
clips.pts_seconds
224+
225+
# %%
226+
# We see above that the end of the video is at 13.8s. The sampler tries to
227+
# sample frames at timestamps [13.28, 13.68, 14.08, ...] but 14.08 is an invalid
228+
# timestamp, beyond the end of the video. With the "repeat_last" policy, which is the
229+
# default, the sampler simply repeats the last frame at 13.68 seconds to
230+
# construct the clip.
231+
#
232+
# An alternative policy is "wrap": the sampler then wraps-around the clip and repeats the first few valid frames as necessary:
233+
234+
torch.manual_seed(0)
235+
clips = clips_at_random_timestamps(
236+
decoder,
237+
num_clips=1,
238+
num_frames_per_clip=5,
239+
seconds_between_frames=0.4,
240+
sampling_range_start=end_of_video - 1,
241+
sampling_range_end=end_of_video,
242+
policy="wrap",
243+
)
244+
clips.pts_seconds
245+
246+
# %%
247+
# By default, the value of ``sampling_range_end`` is automatically set such that
248+
# the sampler *doesn't* try to sample frames beyond the end of the video: the
249+
# default value ensures that clips start early enough before the end. This means
250+
# that by default, the policy parameter rarely comes into action, and most users
251+
# probably don't need to worry too much about it.

src/torchcodec/_frame.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,9 @@ class FrameBatch(Iterable):
5959
"""Multiple video frames with associated metadata.
6060
6161
The ``data`` tensor is typically 4D for sequences of frames (NHWC or NCHW),
62-
or 5D for sequences of clips, as returned by the :ref:`samplers <samplers>`.
63-
When ``data`` is 4D (resp. 5D) the ``pts_seconds`` and ``duration_seconds``
64-
tensors are 1D (resp. 2D).
62+
or 5D for sequences of clips, as returned by the :ref:`samplers
63+
<sphx_glr_generated_examples_sampling.py>`. When ``data`` is 4D (resp. 5D)
64+
the ``pts_seconds`` and ``duration_seconds`` tensors are 1D (resp. 2D).
6565
"""
6666

6767
data: Tensor

0 commit comments

Comments
 (0)