Skip to content

Commit 1fd20b2

Browse files
authored
Add AudioDecoder.get_samples_played_in_range() public method (#555)
1 parent 0a20541 commit 1fd20b2

File tree

8 files changed

+239
-7
lines changed

8 files changed

+239
-7
lines changed

src/torchcodec/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
# Note: usort wants to put Frame and FrameBatch after decoders and samplers,
88
# but that results in circular import.
9-
from ._frame import Frame, FrameBatch # usort:skip # noqa
9+
from ._frame import AudioSamples, Frame, FrameBatch # usort:skip # noqa
1010
from . import decoders, samplers # noqa
1111

1212
try:

src/torchcodec/_frame.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313

1414
def _frame_repr(self):
15-
# Utility to replace Frame and FrameBatch __repr__ method. This prints the
15+
# Utility to replace __repr__ method of dataclasses below. This prints the
1616
# shape of the .data tensor rather than printing the (potentially very long)
1717
# data tensor itself.
1818
s = self.__class__.__name__ + ":\n"
@@ -114,3 +114,28 @@ def __len__(self):
114114

115115
def __repr__(self):
116116
return _frame_repr(self)
117+
118+
119+
@dataclass
120+
class AudioSamples(Iterable):
121+
"""Audio samples with associated metadata."""
122+
123+
# TODO-AUDIO: docs
124+
data: Tensor
125+
pts_seconds: float
126+
sample_rate: int
127+
128+
def __post_init__(self):
129+
# This is called after __init__() when a Frame is created. We can run
130+
# input validation checks here.
131+
if not self.data.ndim == 2:
132+
raise ValueError(f"data must be 2-dimensional, got {self.data.shape = }")
133+
self.pts_seconds = float(self.pts_seconds)
134+
self.sample_rate = int(self.sample_rate)
135+
136+
def __iter__(self) -> Iterator[Union[Tensor, float]]:
137+
for field in dataclasses.fields(self):
138+
yield getattr(self, field.name)
139+
140+
def __repr__(self):
141+
return _frame_repr(self)

src/torchcodec/decoders/_audio_decoder.py

+68
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from torch import Tensor
1111

12+
from torchcodec import AudioSamples
1213
from torchcodec.decoders import _core as core
1314
from torchcodec.decoders._decoder_utils import (
1415
create_decoder,
@@ -37,3 +38,70 @@ def __init__(
3738
) = get_and_validate_stream_metadata(
3839
decoder=self._decoder, stream_index=stream_index, media_type="audio"
3940
)
41+
assert isinstance(self.metadata, core.AudioStreamMetadata) # mypy
42+
43+
def get_samples_played_in_range(
44+
self, start_seconds: float, stop_seconds: Optional[float] = None
45+
) -> AudioSamples:
46+
"""TODO-AUDIO docs"""
47+
if stop_seconds is not None and not start_seconds <= stop_seconds:
48+
raise ValueError(
49+
f"Invalid start seconds: {start_seconds}. It must be less than or equal to stop seconds ({stop_seconds})."
50+
)
51+
if not self._begin_stream_seconds <= start_seconds < self._end_stream_seconds:
52+
raise ValueError(
53+
f"Invalid start seconds: {start_seconds}. "
54+
f"It must be greater than or equal to {self._begin_stream_seconds} "
55+
f"and less than or equal to {self._end_stream_seconds}."
56+
)
57+
frames, first_pts = core.get_frames_by_pts_in_range_audio(
58+
self._decoder,
59+
start_seconds=start_seconds,
60+
stop_seconds=stop_seconds,
61+
)
62+
first_pts = first_pts.item()
63+
64+
# x = frame boundaries
65+
#
66+
# first_pts last_pts
67+
# v v
68+
# ....x..........x..........x...........x..........x..........x.....
69+
# ^ ^
70+
# start_seconds stop_seconds
71+
#
72+
# We want to return the samples in [start_seconds, stop_seconds). But
73+
# because the core API is based on frames, the `frames` tensor contains
74+
# the samples in [first_pts, last_pts)
75+
# So we do some basic math to figure out the position of the view that
76+
# we'll return.
77+
78+
# TODO: sample_rate is either the original one from metadata, or the
79+
# user-specified one (NIY)
80+
assert isinstance(self.metadata, core.AudioStreamMetadata) # mypy
81+
sample_rate = self.metadata.sample_rate
82+
83+
# TODO: metadata's sample_rate should probably not be Optional
84+
assert sample_rate is not None # mypy.
85+
86+
if first_pts < start_seconds:
87+
offset_beginning = round((start_seconds - first_pts) * sample_rate)
88+
output_pts_seconds = start_seconds
89+
else:
90+
# In normal cases we'll have first_pts <= start_pts, but in some
91+
# edge cases it's possible to have first_pts > start_seconds,
92+
# typically if the stream's first frame's pts isn't exactly 0.
93+
offset_beginning = 0
94+
output_pts_seconds = first_pts
95+
96+
num_samples = frames.shape[1]
97+
last_pts = first_pts + num_samples / self.metadata.sample_rate
98+
if stop_seconds is not None and stop_seconds < last_pts:
99+
offset_end = num_samples - round((last_pts - stop_seconds) * sample_rate)
100+
else:
101+
offset_end = num_samples
102+
103+
return AudioSamples(
104+
data=frames[:, offset_beginning:offset_end],
105+
pts_seconds=output_pts_seconds,
106+
sample_rate=sample_rate,
107+
)

src/torchcodec/decoders/_core/VideoDecoder.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -854,7 +854,7 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
854854

855855
if (startSeconds == stopSeconds) {
856856
// For consistency with video
857-
return AudioFramesOutput{torch::empty({0}), 0.0};
857+
return AudioFramesOutput{torch::empty({0, 0}), 0.0};
858858
}
859859

860860
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];

src/torchcodec/decoders/_core/VideoDecoder.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ class VideoDecoder {
147147
// DECODING AND SEEKING APIs
148148
// --------------------------------------------------------------------------
149149

150-
// All public decoding entry points return either a FrameOutput or a
150+
// All public video decoding entry points return either a FrameOutput or a
151151
// FrameBatchOutput.
152152
// They are the equivalent of the user-facing Frame and FrameBatch classes in
153153
// Python. They contain RGB decoded frames along with some associated data

test/decoders/test_decoders.py

+123
Original file line numberDiff line numberDiff line change
@@ -955,3 +955,126 @@ def test_metadata(self, asset):
955955
)
956956
assert decoder.metadata.sample_rate == asset.sample_rate
957957
assert decoder.metadata.num_channels == asset.num_channels
958+
959+
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
960+
def test_error(self, asset):
961+
decoder = AudioDecoder(asset.path)
962+
963+
with pytest.raises(ValueError, match="Invalid start seconds"):
964+
decoder.get_samples_played_in_range(start_seconds=-1300)
965+
966+
with pytest.raises(ValueError, match="Invalid start seconds"):
967+
decoder.get_samples_played_in_range(start_seconds=9999)
968+
969+
with pytest.raises(ValueError, match="Invalid start seconds"):
970+
decoder.get_samples_played_in_range(start_seconds=3, stop_seconds=2)
971+
972+
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
973+
@pytest.mark.parametrize("stop_seconds", (None, "duration", 99999999))
974+
def test_get_all_samples(self, asset, stop_seconds):
975+
decoder = AudioDecoder(asset.path)
976+
977+
if stop_seconds == "duration":
978+
stop_seconds = asset.duration_seconds
979+
980+
samples = decoder.get_samples_played_in_range(
981+
start_seconds=0, stop_seconds=stop_seconds
982+
)
983+
984+
reference_frames = asset.get_frame_data_by_range(
985+
start=0, stop=asset.get_frame_index(pts_seconds=asset.duration_seconds) + 1
986+
)
987+
988+
torch.testing.assert_close(samples.data, reference_frames)
989+
assert samples.sample_rate == asset.sample_rate
990+
991+
# TODO there's a bug with NASA_AUDIO_MP3: https://github.com/pytorch/torchcodec/issues/553
992+
expected_pts = (
993+
0.072
994+
if asset is NASA_AUDIO_MP3
995+
else asset.get_frame_info(idx=0).pts_seconds
996+
)
997+
assert samples.pts_seconds == expected_pts
998+
999+
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
1000+
def test_at_frame_boundaries(self, asset):
1001+
decoder = AudioDecoder(asset.path)
1002+
1003+
start_frame_index, stop_frame_index = 10, 40
1004+
start_seconds = asset.get_frame_info(start_frame_index).pts_seconds
1005+
stop_seconds = asset.get_frame_info(stop_frame_index).pts_seconds
1006+
1007+
samples = decoder.get_samples_played_in_range(
1008+
start_seconds=start_seconds, stop_seconds=stop_seconds
1009+
)
1010+
1011+
reference_frames = asset.get_frame_data_by_range(
1012+
start=start_frame_index, stop=stop_frame_index
1013+
)
1014+
1015+
assert samples.pts_seconds == start_seconds
1016+
num_samples = samples.data.shape[1]
1017+
assert (
1018+
num_samples
1019+
== reference_frames.shape[1]
1020+
== (stop_seconds - start_seconds) * decoder.metadata.sample_rate
1021+
)
1022+
torch.testing.assert_close(samples.data, reference_frames)
1023+
assert samples.sample_rate == asset.sample_rate
1024+
1025+
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
1026+
def test_not_at_frame_boundaries(self, asset):
1027+
decoder = AudioDecoder(asset.path)
1028+
1029+
start_frame_index, stop_frame_index = 10, 40
1030+
start_frame_info = asset.get_frame_info(start_frame_index)
1031+
stop_frame_info = asset.get_frame_info(stop_frame_index)
1032+
start_seconds = start_frame_info.pts_seconds + (
1033+
start_frame_info.duration_seconds / 2
1034+
)
1035+
stop_seconds = stop_frame_info.pts_seconds + (
1036+
stop_frame_info.duration_seconds / 2
1037+
)
1038+
samples = decoder.get_samples_played_in_range(
1039+
start_seconds=start_seconds, stop_seconds=stop_seconds
1040+
)
1041+
1042+
reference_frames = asset.get_frame_data_by_range(
1043+
start=start_frame_index, stop=stop_frame_index + 1
1044+
)
1045+
1046+
assert samples.pts_seconds == start_seconds
1047+
num_samples = samples.data.shape[1]
1048+
assert num_samples < reference_frames.shape[1]
1049+
assert (
1050+
num_samples == (stop_seconds - start_seconds) * decoder.metadata.sample_rate
1051+
)
1052+
assert samples.sample_rate == asset.sample_rate
1053+
1054+
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
1055+
def test_start_equals_stop(self, asset):
1056+
decoder = AudioDecoder(asset.path)
1057+
samples = decoder.get_samples_played_in_range(start_seconds=3, stop_seconds=3)
1058+
assert samples.data.shape == (0, 0)
1059+
1060+
def test_frame_start_is_not_zero(self):
1061+
# For NASA_AUDIO_MP3, the first frame is not at 0, it's at 0.072 [1].
1062+
# So if we request start = 0.05, we shouldn't be truncating anything.
1063+
#
1064+
# [1] well, really it's at 0.138125, not 0.072 (see
1065+
# https://github.com/pytorch/torchcodec/issues/553), but for the purpose
1066+
# of this test it doesn't matter.
1067+
1068+
asset = NASA_AUDIO_MP3
1069+
start_seconds = 0.05 # this is less than the first frame's pts
1070+
stop_frame_index = 10
1071+
stop_seconds = asset.get_frame_info(stop_frame_index).pts_seconds
1072+
1073+
decoder = AudioDecoder(asset.path)
1074+
1075+
samples = decoder.get_samples_played_in_range(
1076+
start_seconds=start_seconds, stop_seconds=stop_seconds
1077+
)
1078+
1079+
reference_frames = asset.get_frame_data_by_range(start=0, stop=stop_frame_index)
1080+
torch.testing.assert_close(samples.data, reference_frames)

test/decoders/test_ops.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -742,7 +742,7 @@ def test_decode_start_equal_stop(self, asset):
742742
frames, pts_seconds = get_frames_by_pts_in_range_audio(
743743
decoder, start_seconds=1, stop_seconds=1
744744
)
745-
assert frames.shape == (0,)
745+
assert frames.shape == (0, 0)
746746
assert pts_seconds == 0
747747

748748
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))

test/test_frame_dataclasses.py

+18-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import pytest
22
import torch
3-
from torchcodec import Frame, FrameBatch
3+
from torchcodec import AudioSamples, Frame, FrameBatch
44

55

6-
def test_frame_unpacking():
6+
def test_unpacking():
77
data, pts_seconds, duration_seconds = Frame(torch.rand(3, 4, 5), 2, 3) # noqa
8+
data, pts_seconds, sample_rate = AudioSamples(torch.rand(2, 4), 2, 16_000)
89

910

1011
def test_frame_error():
@@ -139,3 +140,18 @@ def test_framebatch_indexing():
139140
fb_fancy = fb[[[0], [1]]] # select T=0 and N=1.
140141
assert isinstance(fb_fancy, FrameBatch)
141142
assert fb_fancy.data.shape == (1, C, H, W)
143+
144+
145+
def test_audio_samples_error():
146+
with pytest.raises(ValueError, match="data must be 2-dimensional"):
147+
AudioSamples(
148+
data=torch.rand(1),
149+
pts_seconds=1,
150+
sample_rate=16_000,
151+
)
152+
with pytest.raises(ValueError, match="data must be 2-dimensional"):
153+
AudioSamples(
154+
data=torch.rand(1, 2, 3),
155+
pts_seconds=1,
156+
sample_rate=16_000,
157+
)

0 commit comments

Comments
 (0)