Skip to content

Commit d75fc58

Browse files
authored
Return pts of first frame in audio API (#552)
1 parent c6de04a commit d75fc58

File tree

6 files changed

+51
-17
lines changed

6 files changed

+51
-17
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

+13-5
Original file line numberDiff line numberDiff line change
@@ -838,7 +838,7 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange(
838838
return frameBatchOutput;
839839
}
840840

841-
torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio(
841+
VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
842842
double startSeconds,
843843
std::optional<double> stopSecondsOptional) {
844844
validateActiveStream(AVMEDIA_TYPE_AUDIO);
@@ -854,7 +854,7 @@ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio(
854854

855855
if (startSeconds == stopSeconds) {
856856
// For consistency with video
857-
return torch::empty({0});
857+
return AudioFramesOutput{torch::empty({0}), 0.0};
858858
}
859859

860860
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
@@ -871,17 +871,24 @@ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio(
871871
// TODO-AUDIO Pre-allocate a long-enough tensor instead of creating a vec +
872872
// cat(). This would save a copy. We know the duration of the output and the
873873
// sample rate, so in theory we know the number of output samples.
874-
std::vector<torch::Tensor> tensors;
874+
std::vector<torch::Tensor> frames;
875875

876+
double firstFramePtsSeconds = std::numeric_limits<double>::max();
876877
auto stopPts = secondsToClosestPts(stopSeconds, streamInfo.timeBase);
877878
auto finished = false;
878879
while (!finished) {
879880
try {
880881
AVFrameStream avFrameStream = decodeAVFrame([startPts](AVFrame* avFrame) {
881882
return startPts < avFrame->pts + getDuration(avFrame);
882883
});
884+
// TODO: it's not great that we are getting a FrameOutput, which is
885+
// intended for videos. We should consider bypassing
886+
// convertAVFrameToFrameOutput and directly calling
887+
// convertAudioAVFrameToFrameOutputOnCPU.
883888
auto frameOutput = convertAVFrameToFrameOutput(avFrameStream);
884-
tensors.push_back(frameOutput.data);
889+
firstFramePtsSeconds =
890+
std::min(firstFramePtsSeconds, frameOutput.ptsSeconds);
891+
frames.push_back(frameOutput.data);
885892
} catch (const EndOfFileException& e) {
886893
finished = true;
887894
}
@@ -895,7 +902,8 @@ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio(
895902
finished |= (streamInfo.lastDecodedAvFramePts) <= stopPts &&
896903
(stopPts <= lastDecodedAvFrameEnd);
897904
}
898-
return torch::cat(tensors, 1);
905+
906+
return AudioFramesOutput{torch::cat(frames, 1), firstFramePtsSeconds};
899907
}
900908

901909
// --------------------------------------------------------------------------

src/torchcodec/decoders/_core/VideoDecoder.h

+6-1
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,11 @@ class VideoDecoder {
170170
const StreamMetadata& streamMetadata);
171171
};
172172

173+
// Result of decoding a range of audio frames: the decoded samples plus the pts
// of the first decoded frame.
struct AudioFramesOutput {
174+
torch::Tensor data; // shape is (numChannels, numSamples)
175+
double ptsSeconds; // pts of the first frame, in seconds
176+
};
177+
173178
// Places the cursor at the first frame on or after the position in seconds.
174179
// Calling getNextFrame() will return the first frame at
175180
// or after this position.
@@ -222,7 +227,7 @@ class VideoDecoder {
222227
double stopSeconds);
223228

224229
// TODO-AUDIO: Should accept sampleRate
225-
torch::Tensor getFramesPlayedInRangeAudio(
230+
AudioFramesOutput getFramesPlayedInRangeAudio(
226231
double startSeconds,
227232
std::optional<double> stopSecondsOptional = std::nullopt);
228233

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

+12-3
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ TORCH_LIBRARY(torchcodec_ns, m) {
4848
m.def(
4949
"get_frames_by_pts_in_range(Tensor(a!) decoder, *, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)");
5050
m.def(
51-
"get_frames_by_pts_in_range_audio(Tensor(a!) decoder, *, float start_seconds, float? stop_seconds) -> Tensor");
51+
"get_frames_by_pts_in_range_audio(Tensor(a!) decoder, *, float start_seconds, float? stop_seconds) -> (Tensor, Tensor)");
5252
m.def(
5353
"get_frames_by_pts(Tensor(a!) decoder, *, float[] timestamps) -> (Tensor, Tensor, Tensor)");
5454
m.def("_get_key_frame_indices(Tensor(a!) decoder) -> Tensor");
@@ -94,6 +94,13 @@ OpsFrameBatchOutput makeOpsFrameBatchOutput(
9494
return std::make_tuple(batch.data, batch.ptsSeconds, batch.durationSeconds);
9595
}
9696

97+
// Converts the decoder-level AudioFramesOutput into the tuple form returned by
// the ops: (frames data, pts of the first frame in seconds).
OpsAudioFramesOutput makeOpsAudioFramesOutput(
98+
VideoDecoder::AudioFramesOutput& audioFrames) {
99+
return std::make_tuple(
100+
audioFrames.data,
101+
// ptsSeconds is a plain double; wrap it in a float64 tensor so that both
// tuple elements are tensors.
torch::tensor(audioFrames.ptsSeconds, torch::dtype(torch::kFloat64)));
102+
}
103+
97104
VideoDecoder::SeekMode seekModeFromString(std::string_view seekMode) {
98105
if (seekMode == "exact") {
99106
return VideoDecoder::SeekMode::exact;
@@ -290,12 +297,14 @@ OpsFrameBatchOutput get_frames_by_pts_in_range(
290297
return makeOpsFrameBatchOutput(result);
291298
}
292299

293-
torch::Tensor get_frames_by_pts_in_range_audio(
300+
OpsAudioFramesOutput get_frames_by_pts_in_range_audio(
294301
at::Tensor& decoder,
295302
double start_seconds,
296303
std::optional<double> stop_seconds) {
297304
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
298-
return videoDecoder->getFramesPlayedInRangeAudio(start_seconds, stop_seconds);
305+
auto result =
306+
videoDecoder->getFramesPlayedInRangeAudio(start_seconds, stop_seconds);
307+
return makeOpsAudioFramesOutput(result);
299308
}
300309

301310
std::string quoteValue(const std::string& value) {

src/torchcodec/decoders/_core/VideoDecoderOps.h

+7-1
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,12 @@ using OpsFrameOutput = std::tuple<at::Tensor, at::Tensor, at::Tensor>;
7474
// single float.
7575
using OpsFrameBatchOutput = std::tuple<at::Tensor, at::Tensor, at::Tensor>;
7676

77+
// The elements of this tuple are all tensors that represent the concatenation
78+
// of multiple audio frames:
79+
// 1. The frames data (concatenated)
80+
// 2. A scalar float64 tensor holding the pts of the first frame, in seconds.
81+
using OpsAudioFramesOutput = std::tuple<at::Tensor, at::Tensor>;
82+
7783
// Return the frame that is visible at a given timestamp in seconds. Each frame
7884
// in FFMPEG has a presentation timestamp and a duration. The frame visible at a
7985
// given timestamp T has T >= PTS and T < PTS + Duration.
@@ -112,7 +118,7 @@ OpsFrameBatchOutput get_frames_by_pts_in_range(
112118
double start_seconds,
113119
double stop_seconds);
114120

115-
torch::Tensor get_frames_by_pts_in_range_audio(
121+
OpsAudioFramesOutput get_frames_by_pts_in_range_audio(
116122
at::Tensor& decoder,
117123
double start_seconds,
118124
std::optional<double> stop_seconds = std::nullopt);

src/torchcodec/decoders/_core/ops.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -271,9 +271,9 @@ def get_frames_by_pts_in_range_audio_abstract(
271271
*,
272272
start_seconds: float,
273273
stop_seconds: Optional[float] = None,
274-
) -> torch.Tensor:
274+
) -> Tuple[torch.Tensor, torch.Tensor]:
275275
image_size = [get_ctx().new_dynamic_size() for _ in range(4)]
276-
return torch.empty(image_size)
276+
return (torch.empty(image_size), torch.empty([], dtype=torch.float))
277277

278278

279279
@register_fake("torchcodec_ns::_get_key_frame_indices")

test/decoders/test_ops.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -691,19 +691,23 @@ def test_get_frames_by_pts_in_range_audio(self, range, asset):
691691
decoder = create_from_file(str(asset.path), seek_mode="approximate")
692692
add_audio_stream(decoder)
693693

694-
frames = get_frames_by_pts_in_range_audio(
694+
frames, pts_seconds = get_frames_by_pts_in_range_audio(
695695
decoder, start_seconds=start_seconds, stop_seconds=stop_seconds
696696
)
697-
698697
torch.testing.assert_close(frames, reference_frames)
699698

699+
if range == "at_frames_boundaries":
700+
assert pts_seconds == start_seconds
701+
elif range == "not_at_frames_boundaries":
702+
assert pts_seconds == start_frame_info.pts_seconds
703+
700704
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
701705
def test_decode_epsilon_range(self, asset):
702706
decoder = create_from_file(str(asset.path), seek_mode="approximate")
703707
add_audio_stream(decoder)
704708

705709
start_seconds = 5
706-
frames = get_frames_by_pts_in_range_audio(
710+
frames, *_ = get_frames_by_pts_in_range_audio(
707711
decoder, start_seconds=start_seconds, stop_seconds=start_seconds + 1e-5
708712
)
709713
torch.testing.assert_close(
@@ -720,7 +724,7 @@ def test_decode_just_one_frame_at_boundaries(self, asset):
720724

721725
start_seconds = asset.get_frame_info(idx=10).pts_seconds
722726
stop_seconds = asset.get_frame_info(idx=11).pts_seconds
723-
frames = get_frames_by_pts_in_range_audio(
727+
frames, pts_seconds = get_frames_by_pts_in_range_audio(
724728
decoder, start_seconds=start_seconds, stop_seconds=stop_seconds
725729
)
726730
torch.testing.assert_close(
@@ -729,15 +733,17 @@ def test_decode_just_one_frame_at_boundaries(self, asset):
729733
asset.get_frame_index(pts_seconds=start_seconds)
730734
),
731735
)
736+
assert pts_seconds == start_seconds
732737

733738
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
734739
def test_decode_start_equal_stop(self, asset):
735740
decoder = create_from_file(str(asset.path), seek_mode="approximate")
736741
add_audio_stream(decoder)
737-
frames = get_frames_by_pts_in_range_audio(
742+
frames, pts_seconds = get_frames_by_pts_in_range_audio(
738743
decoder, start_seconds=1, stop_seconds=1
739744
)
740745
assert frames.shape == (0,)
746+
assert pts_seconds == 0
741747

742748
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
743749
def test_multiple_calls(self, asset):

0 commit comments

Comments
 (0)