Skip to content

Commit 42fa254

Browse files
authored
Add docstrings for samplers (#338)
1 parent 39463b8 commit 42fa254

File tree

7 files changed

+191
-8
lines changed

7 files changed

+191
-8
lines changed

docs/source/api_ref_samplers.rst

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
.. _samplers:
2+
3+
===================
4+
torchcodec.samplers
5+
===================
6+
7+
.. currentmodule:: torchcodec.samplers
8+
9+
10+
.. autosummary::
11+
:toctree: generated/
12+
:nosignatures:
13+
:template: function.rst
14+
15+
clips_at_regular_indices
16+
clips_at_random_indices
17+
clips_at_regular_timestamps
18+
clips_at_random_timestamps

docs/source/glossary.rst

+7
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,10 @@ Glossary
1717
A scan corresponds to an entire pass over a video file, with the purpose
1818
of retrieving metadata about the different streams and frames. **It does
1919
not involve decoding**, so it is a lot cheaper than decoding the file.
20+
21+
clips
22+
A clip is a sequence of frames, usually in :term:`pts` order. The frames
23+
may not necessarily be consecutive. A clip is represented as a 4D
24+
:class:`~torchcodec.FrameBatch`. A group of clips, which is what the
25+
:ref:`samplers <samplers>` return, is represented as 5D
26+
:class:`~torchcodec.FrameBatch`.

docs/source/index.rst

+1
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,4 @@ We achieve these capabilities through:
7575

7676
api_ref_torchcodec
7777
api_ref_decoders
78+
api_ref_samplers

src/torchcodec/_frame.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,20 @@ def __repr__(self):
5656

5757
@dataclass
5858
class FrameBatch(Iterable):
59-
"""Multiple video frames with associated metadata."""
59+
"""Multiple video frames with associated metadata.
60+
61+
The ``data`` tensor is typically 4D for sequences of frames (NHWC or NCHW),
62+
or 5D for sequences of clips, as returned by the :ref:`samplers <samplers>`.
63+
When ``data`` is 4D (resp. 5D) the ``pts_seconds`` and ``duration_seconds``
64+
tensors are 1D (resp. 2D).
65+
"""
6066

6167
data: Tensor
62-
"""The frames data as (4-D ``torch.Tensor``)."""
68+
"""The frames data (``torch.Tensor`` of uint8)."""
6369
pts_seconds: Tensor
64-
"""The :term:`pts` of the frame, in seconds (1-D ``torch.Tensor`` of floats)."""
70+
"""The :term:`pts` of the frame, in seconds (``torch.Tensor`` of floats)."""
6571
duration_seconds: Tensor
66-
"""The duration of the frame, in seconds (1-D ``torch.Tensor`` of floats)."""
72+
"""The duration of the frame, in seconds (``torch.Tensor`` of floats)."""
6773

6874
def __post_init__(self):
6975
# This is called after __init__() when a FrameBatch is created. We can

src/torchcodec/samplers/_common.py

+13
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,16 @@ def _reshape_4d_framebatch_into_5d(
6969
pts_seconds=frames.pts_seconds.view(num_clips, num_frames_per_clip),
7070
duration_seconds=frames.duration_seconds.view(num_clips, num_frames_per_clip),
7171
)
72+
73+
74+
_FRAMEBATCH_RETURN_DOCS = """
75+
Returns:
76+
FrameBatch:
77+
The sampled :term:`clips`, as a 5D :class:`~torchcodec.FrameBatch`.
78+
The shape of the ``data`` field is (``num_clips``,
79+
``num_frames_per_clips``, ...) where ... is (H, W, C) or (C, H, W)
80+
depending on the ``dimension_order`` parameter of
81+
:class:`~torchcodec.decoders.VideoDecoder`. The shape of the
82+
``pts_seconds`` and ``duration_seconds`` fields is (``num_clips``,
83+
``num_frames_per_clips``).
84+
"""

src/torchcodec/samplers/_index_based.py

+57-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from torchcodec import FrameBatch
66
from torchcodec.decoders import VideoDecoder
77
from torchcodec.samplers._common import (
8+
_FRAMEBATCH_RETURN_DOCS,
89
_POLICY_FUNCTION_TYPE,
910
_POLICY_FUNCTIONS,
1011
_reshape_4d_framebatch_into_5d,
@@ -194,6 +195,7 @@ def clips_at_random_indices(
194195
sampling_range_end: Optional[int] = None, # interval is [start, end).
195196
policy: Literal["repeat_last", "wrap", "error"] = "repeat_last",
196197
) -> FrameBatch:
198+
# See docstring below
197199
return _generic_index_based_sampler(
198200
kind="random",
199201
decoder=decoder,
@@ -216,7 +218,7 @@ def clips_at_regular_indices(
216218
sampling_range_end: Optional[int] = None, # interval is [start, end).
217219
policy: Literal["repeat_last", "wrap", "error"] = "repeat_last",
218220
) -> FrameBatch:
219-
221+
# See docstring below
220222
return _generic_index_based_sampler(
221223
kind="regular",
222224
decoder=decoder,
@@ -227,3 +229,57 @@ def clips_at_regular_indices(
227229
sampling_range_end=sampling_range_end,
228230
policy=policy,
229231
)
232+
233+
234+
_COMMON_DOCS = f"""
235+
Args:
236+
decoder (VideoDecoder): The :class:`~torchcodec.decoders.VideoDecoder`
237+
instance to sample clips from.
238+
num_clips (int, optional): The number of clips to return. Default: 1.
239+
num_frames_per_clip (int, optional): The number of frames per clips. Default: 1.
240+
num_indices_between_frames(int, optional): The number of indices between
241+
the frames *within* a clip. Default: 1, which means frames are
242+
consecutive. This is sometimes refered-to as "dilation".
243+
sampling_range_start (int, optional): The start of the sampling range,
244+
which defines the first index that a clip may *start* at. Default:
245+
0, i.e. the start of the video.
246+
sampling_range_end (int or None, optional): The end of the sampling
247+
range, which defines the last index that a clip may *start* at. This
248+
value is exclusive, i.e. a clip may only start within
249+
[``sampling_range_start``, ``sampling_range_end``). If None
250+
(default), the value is set automatically such that the clips never
251+
span beyond the end of the video. For example if the last valid
252+
index in a video is 99 and the clips span 10 frames, this value is
253+
set to 99 - 10 + 1 = 90. Negative values are accepted and are
254+
equivalent to ``len(video) - val``. When a clip spans beyond the end
255+
of the video, the ``policy`` parameter defines how to construct such
256+
clip.
257+
policy (str, optional): Defines how to construct clips that span beyond
258+
the end of the video. This is best described with an example:
259+
assuming the last valid index in a video is 99, and a clip was
260+
sampled to start at index 95, with ``num_frames_per_clip=5`` and
261+
``num_indices_between_frames=2``, the indices of the frames in the
262+
clip are supposed to be [95, 97, 99, 101, 103]. But 101 and 103 are
263+
invalid indices, so the ``policy`` parameter defines how to replace
264+
those frames, with valid indices:
265+
266+
- "repeat_last": repeats the last valid frame of the clip. We would
267+
get [95, 97, 99, 99, 99].
268+
- "wrap": wraps around to the beginning of the clip. We would get
269+
[95, 97, 99, 95, 97].
270+
- "error": raises an error.
271+
272+
Default is "repeat_last". Note that when ``sampling_range_end=None``
273+
(default), this policy parameter is unlikely to be relevant.
274+
275+
{_FRAMEBATCH_RETURN_DOCS}
276+
"""
277+
278+
clips_at_random_indices.__doc__ = f"""Sample :term:`clips` at random indices.
279+
{_COMMON_DOCS}
280+
"""
281+
282+
283+
clips_at_regular_indices.__doc__ = f"""Sample :term:`clips` at regular (equally-spaced) indices.
284+
{_COMMON_DOCS}
285+
"""

src/torchcodec/samplers/_time_based.py

+85-3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from torchcodec import FrameBatch
66
from torchcodec.samplers._common import (
7+
_FRAMEBATCH_RETURN_DOCS,
78
_POLICY_FUNCTION_TYPE,
89
_POLICY_FUNCTIONS,
910
_reshape_4d_framebatch_into_5d,
@@ -156,7 +157,7 @@ def _generic_time_based_sampler(
156157
# None means "begining", which may not always be 0
157158
sampling_range_start: Optional[float],
158159
sampling_range_end: Optional[float], # interval is [start, end).
159-
policy: str = "repeat_last",
160+
policy: Literal["repeat_last", "wrap", "error"] = "repeat_last",
160161
) -> FrameBatch:
161162
# Note: *everywhere*, sampling_range_end denotes the upper bound of where a
162163
# clip can start. This is an *open* upper bound, i.e. we will make sure no
@@ -226,8 +227,9 @@ def clips_at_random_timestamps(
226227
# None means "begining", which may not always be 0
227228
sampling_range_start: Optional[float] = None,
228229
sampling_range_end: Optional[float] = None, # interval is [start, end).
229-
policy: str = "repeat_last",
230+
policy: Literal["repeat_last", "wrap", "error"] = "repeat_last",
230231
) -> FrameBatch:
232+
# See docstring below
231233
return _generic_time_based_sampler(
232234
kind="random",
233235
decoder=decoder,
@@ -250,8 +252,9 @@ def clips_at_regular_timestamps(
250252
# None means "begining", which may not always be 0
251253
sampling_range_start: Optional[float] = None,
252254
sampling_range_end: Optional[float] = None, # interval is [start, end).
253-
policy: str = "repeat_last",
255+
policy: Literal["repeat_last", "wrap", "error"] = "repeat_last",
254256
) -> FrameBatch:
257+
# See docstring below
255258
return _generic_time_based_sampler(
256259
kind="regular",
257260
decoder=decoder,
@@ -263,3 +266,82 @@ def clips_at_regular_timestamps(
263266
sampling_range_end=sampling_range_end,
264267
policy=policy,
265268
)
269+
270+
271+
_COMMON_DOCS = """
272+
{maybe_note}
273+
274+
Args:
275+
decoder (VideoDecoder): The :class:`~torchcodec.decoders.VideoDecoder`
276+
instance to sample clips from.
277+
{num_clips_or_seconds_between_clip_starts}
278+
num_frames_per_clip (int, optional): The number of frames per clips. Default: 1.
279+
seconds_between_frames (float or None, optional): The time (in seconds)
280+
between each frame within a clip. More accurately, this defines the
281+
time between the *frame sampling point*, i.e. the timestamps at
282+
which we sample the frames. Because frames span intervals in time ,
283+
the resulting start of frames within a clip may not be exactly
284+
spaced by ``seconds_between_frames`` - but on average, they will be.
285+
Default is None, which is set to the average frame duration
286+
(``1/average_fps``).
287+
sampling_range_start (float or None, optional): The start of the
288+
sampling range, which defines the first timestamp (in seconds) that
289+
a clip may *start* at. Default: None, which corresponds to the start
290+
of the video. (Note: some videos start at negative values, which is
291+
why the default is not 0).
292+
sampling_range_end (float or None, optional): The end of the sampling
293+
range, which defines the last timestamp (in seconds) that a clip may
294+
*start* at. This value is exclusive, i.e. a clip may only start within
295+
[``sampling_range_start``, ``sampling_range_end``). If None
296+
(default), the value is set automatically such that the clips never
297+
span beyond the end of the video, i.e. it is set to
298+
``end_video_seconds - (num_frames_per_clip - 1) *
299+
seconds_between_frames``. When a clip spans beyond the end of the
300+
video, the ``policy`` parameter defines how to construct such clip.
301+
policy (str, optional): Defines how to construct clips that span beyond
302+
the end of the video. This is best described with an example:
303+
assuming the last valid (seekable) timestamp in a video is 10.9, and
304+
a clip was sampled to start at timestamp 10.5, with
305+
``num_frames_per_clip=5`` and ``seconds_between_frames=0.2``, the
306+
sampling timestamps of the frames in the clip are supposed to be
307+
[10.5, 10.7, 10.9, 11.1, 11.2]. But 11.1 and 11.2 are invalid
308+
timestamps, so the ``policy`` parameter defines how to replace those
309+
frames, with valid sampling timestamps:
310+
311+
- "repeat_last": repeats the last valid frame of the clip. We would
312+
get frames sampled at timestamps [10.5, 10.7, 10.9, 10.9, 10.9].
313+
- "wrap": wraps around to the beginning of the clip. We would get
314+
frames sampled at timestamps [10.5, 10.7, 10.9, 10.5, 10.7].
315+
- "error": raises an error.
316+
317+
Default is "repeat_last". Note that when ``sampling_range_end=None``
318+
(default), this policy parameter is unlikely to be relevant.
319+
320+
{return_docs}
321+
"""
322+
323+
324+
_NUM_CLIPS_DOCS = """
325+
num_clips (int, optional): The number of clips to return. Default: 1.
326+
"""
327+
clips_at_random_timestamps.__doc__ = f"""Sample :term:`clips` at random timestamps.
328+
{_COMMON_DOCS.format(maybe_note="", num_clips_or_seconds_between_clip_starts=_NUM_CLIPS_DOCS, return_docs=_FRAMEBATCH_RETURN_DOCS)}
329+
"""
330+
331+
332+
_SECONDS_BETWEEN_CLIP_STARTS = """
333+
seconds_between_clip_starts (float): The space (in seconds) between each
334+
clip start.
335+
"""
336+
337+
_NOTE_DOCS = """
338+
.. note::
339+
For consistency with existing sampling APIs (such as torchvision), this
340+
sampler takes a ``seconds_between_clip_starts`` parameter instead of
341+
``num_clips``. If you find that supporting ``num_clips`` would be
342+
useful, please let us know by `opening a feature request
343+
<https://github.com/pytorch/torchcodec/issues?q=is:open+is:issue>`_.
344+
"""
345+
clips_at_regular_timestamps.__doc__ = f"""Sample :term:`clips` at regular (equally-spaced) timestamps.
346+
{_COMMON_DOCS.format(maybe_note=_NOTE_DOCS, num_clips_or_seconds_between_clip_starts=_SECONDS_BETWEEN_CLIP_STARTS, return_docs=_FRAMEBATCH_RETURN_DOCS)}
347+
"""

0 commit comments

Comments
 (0)