From 16d698fd355b10a17464322cfbd836200f013b52 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Mon, 16 Dec 2024 07:41:23 -0800 Subject: [PATCH 01/56] Start implementation of approximate mode --- .../decoders/_core/VideoDecoder.cpp | 58 ++++++++++++++----- src/torchcodec/decoders/_core/VideoDecoder.h | 20 ++++++- 2 files changed, 62 insertions(+), 16 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 4c98980c..f000bf45 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -216,14 +216,16 @@ bool VideoDecoder::SwsContextKey::operator!=( return !(*this == other); } -VideoDecoder::VideoDecoder(const std::string& videoFilePath) { +VideoDecoder::VideoDecoder(const std::string& videoFilePath, SeekMode seek) + : seekMode_(seek) { AVInput input = createAVFormatContextFromFilePath(videoFilePath); formatContext_ = std::move(input.formatContext); initializeDecoder(); } -VideoDecoder::VideoDecoder(const void* buffer, size_t length) { +VideoDecoder::VideoDecoder(const void* buffer, size_t length, SeekMode seek) + : seekMode_(seek) { TORCH_CHECK(buffer != nullptr, "Video buffer cannot be nullptr!"); AVInput input = createAVFormatContextFromBuffer(buffer, length); @@ -1033,13 +1035,15 @@ void VideoDecoder::validateScannedAllStreams(const std::string& msg) { } void VideoDecoder::validateFrameIndex( - const StreamInfo& stream, + const StreamInfo& streamInfo, + const StreamMetadata& streamMetadata, int64_t frameIndex) { + int64_t framesSize = getFramesSize(streamInfo, streamMetadata); TORCH_CHECK( - frameIndex >= 0 && frameIndex < stream.allFrames.size(), + frameIndex >= 0 && frameIndex < framesSize, "Invalid frame index=" + std::to_string(frameIndex) + - " for streamIndex=" + std::to_string(stream.streamIndex) + - " numFrames=" + std::to_string(stream.allFrames.size())); + " for streamIndex=" + std::to_string(streamInfo.streamIndex) + + " numFrames=" + std::to_string(framesSize)); } VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex( @@ -1050,6 +1054,32 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex( return output; } +int64_t VideoDecoder::getPts( + const StreamInfo& streamInfo, + const StreamMetadata& streamMetadata, + int64_t frameIndex) { + switch (seekMode_) { + case SeekMode::EXACT: + return streamInfo.allFrames[frameIndex].pts; + case SeekMode::APPROXIMATE: + return secondsToClosestPts( + frameIndex / streamMetadata.averageFps.value(), streamInfo.timeBase); + default: + throw std::runtime_error("Unknown SeekMode"); + } +} + +int64_t VideoDecoder::getFramesSize(const StreamInfo& streamInfo, const StreamMetadata& streamMetadata) { + switch(seekMode_) { + case SeekMode::EXACT: + return streamInfo.allFrames.size(); + case SeekMode::APPROXIMATE: + return streamMetadata.numFrames.value(); + default: + throw std::runtime_error("Unknown SeekMode"); + } +} + VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndexInternal( int streamIndex, int64_t frameIndex, @@ -1057,11 +1087,12 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndexInternal( validateUserProvidedStreamIndex(streamIndex); validateScannedAllStreams("getFrameAtIndex"); - const auto& stream = streams_[streamIndex]; - validateFrameIndex(stream, frameIndex); + const auto& streamInfo = streams_[streamIndex]; + const auto& streamMetadata = containerMetadata_.streams[streamIndex]; + validateFrameIndex(streamInfo, streamMetadata, frameIndex); - int64_t pts = 
stream.allFrames[frameIndex].pts; - setCursorPtsInSeconds(ptsToSeconds(pts, stream.timeBase)); + int64_t pts = getPts(streamInfo, streamMetadata, frameIndex); + setCursorPtsInSeconds(ptsToSeconds(pts, streamInfo.timeBase)); return getNextFrameOutputNoDemuxInternal(preAllocatedOutputTensor); } @@ -1332,10 +1363,11 @@ double VideoDecoder::getPtsSecondsForFrame( validateUserProvidedStreamIndex(streamIndex); validateScannedAllStreams("getPtsSecondsForFrame"); - const auto& stream = streams_[streamIndex]; - validateFrameIndex(stream, frameIndex); + const auto& streamInfo = streams_[streamIndex]; + const auto& streamMetadata = containerMetadata_.streams[streamIndex]; + validateFrameIndex(streamInfo, streamMetadata, frameIndex); - return ptsToSeconds(stream.allFrames[frameIndex].pts, stream.timeBase); + return ptsToSeconds(streamInfo.allFrames[frameIndex].pts, streamInfo.timeBase); } int VideoDecoder::convertFrameToBufferUsingSwsScale( diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index fcd1b17c..5cf211ab 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -50,12 +50,19 @@ class VideoDecoder { // CONSTRUCTION API // -------------------------------------------------------------------------- + enum class SeekMode { EXACT, APPROXIMATE }; + // Creates a VideoDecoder from the video at videoFilePath. - explicit VideoDecoder(const std::string& videoFilePath); + explicit VideoDecoder( + const std::string& videoFilePath, + SeekMode seek = SeekMode::EXACT); // Creates a VideoDecoder from a given buffer. Note that the buffer is not // owned by the VideoDecoder. - explicit VideoDecoder(const void* buffer, size_t length); + explicit VideoDecoder( + const void* buffer, + size_t length, + SeekMode seek = SeekMode::EXACT); static std::unique_ptr createFromFilePath( const std::string& videoFilePath); @@ -368,13 +375,19 @@ class VideoDecoder { void initializeDecoder(); void validateUserProvidedStreamIndex(uint64_t streamIndex); void validateScannedAllStreams(const std::string& msg); - void validateFrameIndex(const StreamInfo& stream, int64_t frameIndex); + void validateFrameIndex(const StreamInfo& streamInfo, const StreamMetadata& streamMetadata, int64_t frameIndex); // Creates and initializes a filter graph for a stream. The filter graph can // do rescaling and color conversion. void initializeFilterGraph( StreamInfo& streamInfo, int expectedOutputHeight, int expectedOutputWidth); + + int64_t getFramesSize(const StreamInfo& streamInfo, const StreamMetadata& streamMetadata); + int64_t getPts( + const StreamInfo& stream, + const StreamMetadata& streamMetadata, + int64_t frameIndex); void maybeSeekToBeforeDesiredPts(); RawDecodedOutput getDecodedOutputWithFilter( std::function); @@ -404,6 +417,7 @@ class VideoDecoder { DecodedOutput getNextFrameOutputNoDemuxInternal( std::optional preAllocatedOutputTensor = std::nullopt); + SeekMode seekMode_; ContainerMetadata containerMetadata_; UniqueAVFormatContext formatContext_; std::map streams_; From d95b128488fd34e9787e5ed598bf2fa1c418f70c Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Thu, 19 Dec 2024 12:01:14 -0800 Subject: [PATCH 02/56] Initial seek mode implementation in VideoDecoder. 
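Seek mode controls how the index-based APIs locate frames. In exact mode we
scan the file up front and read each frame's pts from the scanned index
(allFrames); in approximate mode we skip the scan and derive both the frame
count and a frame's pts from header metadata (numFrames and averageFps).
A minimal Python sketch of the approximate-mode arithmetic in getPts() --
the helper name mirrors the C++ secondsToClosestPts(), and the sample
numbers are hypothetical:

    from fractions import Fraction

    def seconds_to_closest_pts(seconds: float, time_base: Fraction) -> int:
        # A pts is expressed in units of the stream's time base.
        return round(seconds / time_base)

    def approximate_pts(frame_index: int, average_fps: float, time_base: Fraction) -> int:
        # Approximate mode assumes frame i starts i / fps seconds in.
        return seconds_to_closest_pts(frame_index / average_fps, time_base)

    # Frame 180 of a 30000/1001 fps stream with a 1/30000 time base:
    print(approximate_pts(180, 30000 / 1001, Fraction(1, 30000)))  # 180180

The trade-off: construction is cheap because no scan is needed, but the
index-to-pts mapping is only as reliable as the container's metadata.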
---
 .../decoders/_core/VideoDecoder.cpp          | 245 ++++++++++++------
 src/torchcodec/decoders/_core/VideoDecoder.h |  24 +-
 2 files changed, 177 insertions(+), 92 deletions(-)

diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp
index 7ca87342..94eca202 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -204,6 +204,22 @@ VideoDecoder::BatchDecodedOutput::BatchDecodedOutput(
   frames = allocateEmptyHWCTensor(height, width, options.device, numFrames);
 }
 
+VideoDecoder::BatchDecodedOutput::BatchDecodedOutput(
+    const std::vector<torch::Tensor>& inFrames,
+    std::vector<double>& inPtsSeconds,
+    std::vector<double>& inDurationSeconds)
+    : frames(torch::stack(inFrames)),
+      ptsSeconds(torch::from_blob(
+                     inPtsSeconds.data(),
+                     inPtsSeconds.size(),
+                     {torch::kFloat64})
+                     .clone()),
+      durationSeconds(torch::from_blob(
+                          inDurationSeconds.data(),
+                          inDurationSeconds.size(),
+                          {torch::kFloat64})
+                          .clone()) {}
+
 bool VideoDecoder::DecodedFrameContext::operator==(
     const VideoDecoder::DecodedFrameContext& other) {
   return decodedWidth == other.decodedWidth && decodedHeight == other.decodedHeight &&
@@ -306,14 +322,15 @@ void VideoDecoder::initializeDecoder() {
 }
 
 std::unique_ptr<VideoDecoder> VideoDecoder::createFromFilePath(
-    const std::string& videoFilePath) {
-  return std::unique_ptr<VideoDecoder>(new VideoDecoder(videoFilePath));
+    const std::string& videoFilePath, SeekMode seekMode) {
+  return std::unique_ptr<VideoDecoder>(new VideoDecoder(videoFilePath, seekMode));
 }
 
 std::unique_ptr<VideoDecoder> VideoDecoder::createFromBuffer(
     const void* buffer,
-    size_t length) {
-  return std::unique_ptr<VideoDecoder>(new VideoDecoder(buffer, length));
+    size_t length,
+    SeekMode seekMode) {
+  return std::unique_ptr<VideoDecoder>(new VideoDecoder(buffer, length, seekMode));
 }
 
 void VideoDecoder::createFilterGraph(
@@ -1085,9 +1102,9 @@ int64_t VideoDecoder::getPts(
     const StreamMetadata& streamMetadata,
     int64_t frameIndex) {
   switch (seekMode_) {
-    case SeekMode::EXACT:
+    case SeekMode::exact:
       return streamInfo.allFrames[frameIndex].pts;
-    case SeekMode::APPROXIMATE:
+    case SeekMode::approximate:
       return secondsToClosestPts(
           frameIndex / streamMetadata.averageFps.value(), streamInfo.timeBase);
     default:
@@ -1095,11 +1112,13 @@
   }
 }
 
-int64_t VideoDecoder::getFramesSize(const StreamInfo& streamInfo, const StreamMetadata& streamMetadata) {
-  switch(seekMode_) {
-    case SeekMode::EXACT:
+int64_t VideoDecoder::getFramesSize(
+    const StreamInfo& streamInfo,
+    const StreamMetadata& streamMetadata) {
+  switch (seekMode_) {
+    case SeekMode::exact:
       return streamInfo.allFrames.size();
-    case SeekMode::APPROXIMATE:
+    case SeekMode::approximate:
       return streamMetadata.numFrames.value();
     default:
       throw std::runtime_error("Unknown SeekMode");
@@ -1111,7 +1130,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndexInternal(
     int64_t frameIndex,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
   validateUserProvidedStreamIndex(streamIndex);
-  validateScannedAllStreams("getFrameAtIndex");
+  validateScannedAllStreams("getFrameAtIndex"); // converted
 
   const auto& streamInfo = streams_[streamIndex];
   const auto& streamMetadata = containerMetadata_.streams[streamIndex];
@@ -1126,7 +1145,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
     int streamIndex,
     const std::vector<int64_t>& frameIndices) {
   validateUserProvidedStreamIndex(streamIndex);
-  validateScannedAllStreams("getFramesAtIndices");
+  validateScannedAllStreams("getFramesAtIndices"); // converted
 
   auto indicesAreSorted = std::is_sorted(frameIndices.begin(),
frameIndices.end()); @@ -1156,7 +1175,8 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( for (auto f = 0; f < frameIndices.size(); ++f) { auto indexInOutput = indicesAreSorted ? f : argsort[f]; auto indexInVideo = frameIndices[indexInOutput]; - if (indexInVideo < 0 || indexInVideo >= stream.allFrames.size()) { + if (indexInVideo < 0 || + indexInVideo >= getFramesSize(stream, streamMetadata)) { throw std::runtime_error( "Invalid frame index=" + std::to_string(indexInVideo)); } @@ -1184,41 +1204,61 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesPlayedByTimestamps( int streamIndex, const std::vector& timestamps) { validateUserProvidedStreamIndex(streamIndex); - validateScannedAllStreams("getFramesPlayedByTimestamps"); - - // The frame played at timestamp t and the one played at timestamp `t + - // eps` are probably the same frame, with the same index. The easiest way to - // avoid decoding that unique frame twice is to convert the input timestamps - // to indices, and leverage the de-duplication logic of getFramesAtIndices. - // This means this function requires a scan. - // TODO: longer term, we should implement this without requiring a scan + validateScannedAllStreams("getFramesPlayedByTimestamps"); // converted const auto& streamMetadata = containerMetadata_.streams[streamIndex]; const auto& stream = streams_[streamIndex]; double minSeconds = streamMetadata.minPtsSecondsFromScan.value(); double maxSeconds = streamMetadata.maxPtsSecondsFromScan.value(); - std::vector frameIndices(timestamps.size()); - for (auto i = 0; i < timestamps.size(); ++i) { - auto framePts = timestamps[i]; - TORCH_CHECK( - framePts >= minSeconds && framePts < maxSeconds, - "frame pts is " + std::to_string(framePts) + "; must be in range [" + - std::to_string(minSeconds) + ", " + std::to_string(maxSeconds) + - ")."); + if (seekMode_ == SeekMode::exact) { + // The frame played at timestamp t and the one played at timestamp `t + + // eps` are probably the same frame, with the same index. The easiest way to + // avoid decoding that unique frame twice is to convert the input timestamps + // to indices, and leverage the de-duplication logic of getFramesAtIndices. - auto it = std::lower_bound( - stream.allFrames.begin(), - stream.allFrames.end(), - framePts, - [&stream](const FrameInfo& info, double framePts) { - return ptsToSeconds(info.nextPts, stream.timeBase) <= framePts; - }); - int64_t frameIndex = it - stream.allFrames.begin(); - frameIndices[i] = frameIndex; - } + std::vector frameIndices(timestamps.size()); + for (auto i = 0; i < timestamps.size(); ++i) { + auto framePts = timestamps[i]; + TORCH_CHECK( + framePts >= minSeconds && framePts < maxSeconds, + "frame pts is " + std::to_string(framePts) + "; must be in range [" + + std::to_string(minSeconds) + ", " + std::to_string(maxSeconds) + + ")."); + + auto it = std::lower_bound( + stream.allFrames.begin(), + stream.allFrames.end(), + framePts, + [&stream](const FrameInfo& info, double framePts) { + return ptsToSeconds(info.nextPts, stream.timeBase) <= framePts; + }); + int64_t frameIndex = it - stream.allFrames.begin(); + frameIndices[i] = frameIndex; + } + + return getFramesAtIndices(streamIndex, frameIndices); - return getFramesAtIndices(streamIndex, frameIndices); + } else if (seekMode_ == SeekMode::approximate) { + // TODO: Figure out if we can be smarter than just iterating over the + // timestamps one-by-one. 
+ + const auto& streamMetadata = containerMetadata_.streams[streamIndex]; + const auto& options = stream.options; + BatchDecodedOutput output(timestamps.size(), options, streamMetadata); + + for (auto i = 0; i < timestamps.size(); ++i) { + DecodedOutput singleOut = getFramePlayedAtTimestampNoDemux(timestamps[i]); + output.ptsSeconds[i] = singleOut.ptsSeconds; + output.durationSeconds[i] = singleOut.durationSeconds; + } + + output.frames = maybePermuteHWC2CHW(streamIndex, output.frames); + return output; + + } else { + throw std::runtime_error("Unknown SeekMode"); + } } VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange( @@ -1227,17 +1267,17 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange( int64_t stop, int64_t step) { validateUserProvidedStreamIndex(streamIndex); - validateScannedAllStreams("getFramesInRange"); + validateScannedAllStreams("getFramesInRange"); // converted const auto& streamMetadata = containerMetadata_.streams[streamIndex]; const auto& stream = streams_[streamIndex]; + int64_t framesSize = getFramesSize(stream, streamMetadata); TORCH_CHECK( start >= 0, "Range start, " + std::to_string(start) + " is less than 0."); TORCH_CHECK( - stop <= stream.allFrames.size(), + stop <= framesSize, "Range stop, " + std::to_string(stop) + - ", is more than the number of frames, " + - std::to_string(stream.allFrames.size())); + ", is more than the number of frames, " + std::to_string(framesSize)); TORCH_CHECK( step > 0, "Step must be greater than 0; is " + std::to_string(step)); @@ -1261,7 +1301,7 @@ VideoDecoder::getFramesPlayedByTimestampInRange( double startSeconds, double stopSeconds) { validateUserProvidedStreamIndex(streamIndex); - validateScannedAllStreams("getFramesPlayedByTimestampInRange"); + validateScannedAllStreams("getFramesPlayedByTimestampInRange"); // converted const auto& streamMetadata = containerMetadata_.streams[streamIndex]; double minSeconds = streamMetadata.minPtsSecondsFromScan.value(); @@ -1308,46 +1348,82 @@ VideoDecoder::getFramesPlayedByTimestampInRange( return output; } - // Note that we look at nextPts for a frame, and not its pts or duration. Our - // abstract player displays frames starting at the pts for that frame until - // the pts for the next frame. There are two consequences: - // - // 1. We ignore the duration for a frame. A frame is played until the - // next frame replaces it. This model is robust to durations being 0 or - // incorrect; our source of truth is the pts for frames. If duration is - // accurate, the nextPts for a frame would be equivalent to pts + duration. - // 2. In order to establish if the start of an interval maps to a particular - // frame, we need to figure out if it is ordered after the frame's pts, but - // before the next frames's pts. - auto startFrame = std::lower_bound( - stream.allFrames.begin(), - stream.allFrames.end(), - startSeconds, - [&stream](const FrameInfo& info, double start) { - return ptsToSeconds(info.nextPts, stream.timeBase) <= start; - }); + if (seekMode_ == SeekMode::exact) { + // Note that we look at nextPts for a frame, and not its pts or duration. + // Our abstract player displays frames starting at the pts for that frame + // until the pts for the next frame. There are two consequences: + // + // 1. We ignore the duration for a frame. A frame is played until the + // next frame replaces it. This model is robust to durations being 0 or + // incorrect; our source of truth is the pts for frames. 
If duration is
+    //    accurate, the nextPts for a frame would be equivalent to pts +
+    //    duration.
+    // 2. In order to establish if the start of an interval maps to a
+    //    particular frame, we need to figure out if it is ordered after the
+    //    frame's pts, but before the next frame's pts.
+
+    auto startFrame = std::lower_bound(
+        stream.allFrames.begin(),
+        stream.allFrames.end(),
+        startSeconds,
+        [&stream](const FrameInfo& info, double start) {
+          return ptsToSeconds(info.nextPts, stream.timeBase) <= start;
+        });
 
-  auto stopFrame = std::upper_bound(
-      stream.allFrames.begin(),
-      stream.allFrames.end(),
-      stopSeconds,
-      [&stream](double stop, const FrameInfo& info) {
-        return stop <= ptsToSeconds(info.pts, stream.timeBase);
-      });
+    auto stopFrame = std::upper_bound(
+        stream.allFrames.begin(),
+        stream.allFrames.end(),
+        stopSeconds,
+        [&stream](double stop, const FrameInfo& info) {
+          return stop <= ptsToSeconds(info.pts, stream.timeBase);
+        });
 
-  int64_t startFrameIndex = startFrame - stream.allFrames.begin();
-  int64_t stopFrameIndex = stopFrame - stream.allFrames.begin();
-  int64_t numFrames = stopFrameIndex - startFrameIndex;
-  BatchDecodedOutput output(numFrames, options, streamMetadata);
-  for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) {
-    DecodedOutput singleOut =
-        getFrameAtIndexInternal(streamIndex, i, output.frames[f]);
-    output.ptsSeconds[f] = singleOut.ptsSeconds;
-    output.durationSeconds[f] = singleOut.durationSeconds;
-  }
-  output.frames = maybePermuteHWC2CHW(streamIndex, output.frames);
+    int64_t startFrameIndex = startFrame - stream.allFrames.begin();
+    int64_t stopFrameIndex = stopFrame - stream.allFrames.begin();
+    int64_t numFrames = stopFrameIndex - startFrameIndex;
+    BatchDecodedOutput output(numFrames, options, streamMetadata);
+    for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) {
+      DecodedOutput singleOut =
+          getFrameAtIndexInternal(streamIndex, i, output.frames[f]);
+      output.ptsSeconds[f] = singleOut.ptsSeconds;
+      output.durationSeconds[f] = singleOut.durationSeconds;
+    }
+    output.frames = maybePermuteHWC2CHW(streamIndex, output.frames);
 
-  return output;
+    return output;
+
+  } else if (seekMode_ == SeekMode::approximate) {
+    // Because we can only discover when to stop by doing the actual decoding,
+    // we can't pre-allocate the correct dimensions for our BatchDecodedOutput;
+    // we don't yet know N, the number of frames. So we have to store all of the
+    // decoded frames in a vector, and construct the final data tensor after.
+
+    // TODO: Figure out if there is a better way of doing this. That is, we store
+    // everything in vectors and then call torch::stack and torch::tensor.clone
+    // after the fact. We can't preallocate the final tensor because we don't
+    // know how many frames we're going to decode up front. 
+ + setCursorPtsInSeconds(startSeconds); + DecodedOutput singleOut = getNextFrameNoDemux(); + + std::vector frames = {singleOut.frame}; + std::vector ptsSeconds = {singleOut.ptsSeconds}; + std::vector durationSeconds = {singleOut.durationSeconds}; + + while (singleOut.ptsSeconds <= stopSeconds) { + singleOut = getNextFrameNoDemux(); + frames.push_back(singleOut.frame); + ptsSeconds.push_back(singleOut.ptsSeconds); + durationSeconds.push_back(singleOut.durationSeconds); + } + + BatchDecodedOutput output(frames, ptsSeconds, durationSeconds); + output.frames = maybePermuteHWC2CHW(streamIndex, output.frames); + return output; + + } else { + throw std::runtime_error("Unknown SeekMode"); + } } VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() { @@ -1387,13 +1463,14 @@ double VideoDecoder::getPtsSecondsForFrame( int streamIndex, int64_t frameIndex) { validateUserProvidedStreamIndex(streamIndex); - validateScannedAllStreams("getPtsSecondsForFrame"); + validateScannedAllStreams("getPtsSecondsForFrame"); // keeping? const auto& streamInfo = streams_[streamIndex]; const auto& streamMetadata = containerMetadata_.streams[streamIndex]; validateFrameIndex(streamInfo, streamMetadata, frameIndex); - return ptsToSeconds(streamInfo.allFrames[frameIndex].pts, streamInfo.timeBase); + return ptsToSeconds( + streamInfo.allFrames[frameIndex].pts, streamInfo.timeBase); } void VideoDecoder::createSwsContext( diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 91080341..4b35c961 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -50,26 +50,27 @@ class VideoDecoder { // CONSTRUCTION API // -------------------------------------------------------------------------- - enum class SeekMode { EXACT, APPROXIMATE }; + enum class SeekMode { exact, approximate }; // Creates a VideoDecoder from the video at videoFilePath. explicit VideoDecoder( const std::string& videoFilePath, - SeekMode seek = SeekMode::EXACT); + SeekMode); // Creates a VideoDecoder from a given buffer. Note that the buffer is not // owned by the VideoDecoder. explicit VideoDecoder( const void* buffer, size_t length, - SeekMode seek = SeekMode::EXACT); + SeekMode); static std::unique_ptr createFromFilePath( - const std::string& videoFilePath); + const std::string& videoFilePath, SeekMode seekMode = SeekMode::exact); static std::unique_ptr createFromBuffer( const void* buffer, - size_t length); + size_t length, + SeekMode seekMode = SeekMode::exact); // -------------------------------------------------------------------------- // VIDEO METADATA QUERY API @@ -242,9 +243,14 @@ class VideoDecoder { torch::Tensor durationSeconds; explicit BatchDecodedOutput( - int64_t numFrames, - const VideoStreamDecoderOptions& options, - const StreamMetadata& metadata); + int64_t numFrames, + const VideoStreamDecoderOptions& options, + const StreamMetadata& metadata); + + explicit BatchDecodedOutput( + const std::vector& inFrames, + std::vector& inPtsSeconds, + std::vector& inDurationSeconds); }; // Returns frames at the given indices for a given stream as a single stacked @@ -376,6 +382,7 @@ class VideoDecoder { void validateUserProvidedStreamIndex(uint64_t streamIndex); void validateScannedAllStreams(const std::string& msg); void validateFrameIndex(const StreamInfo& streamInfo, const StreamMetadata& streamMetadata, int64_t frameIndex); + // Creates and initializes a filter graph for a stream. 
The filter graph can // do rescaling and color conversion. void createFilterGraph( @@ -384,6 +391,7 @@ class VideoDecoder { int expectedOutputWidth); int64_t getFramesSize(const StreamInfo& streamInfo, const StreamMetadata& streamMetadata); + int64_t getPts( const StreamInfo& stream, const StreamMetadata& streamMetadata, From 35f2e596c1afb7a076a723eb686c29fb20b4a142 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Thu, 19 Dec 2024 19:19:26 -0800 Subject: [PATCH 03/56] Added Python side support, extended tests. --- .../decoders/benchmark_decoders_library.py | 36 +++-- .../decoders/_core/VideoDecoder.cpp | 125 +++++++++++++----- src/torchcodec/decoders/_core/VideoDecoder.h | 39 +++--- .../decoders/_core/VideoDecoderOps.cpp | 53 ++++++-- .../decoders/_core/VideoDecoderOps.h | 13 +- src/torchcodec/decoders/_core/_metadata.py | 2 +- .../decoders/_core/video_decoder_ops.py | 12 +- src/torchcodec/decoders/_video_decoder.py | 76 +++++++---- test/decoders/test_metadata.py | 5 +- test/decoders/test_video_decoder.py | 97 +++++++++----- test/decoders/test_video_decoder_ops.py | 63 ++++----- 11 files changed, 335 insertions(+), 186 deletions(-) diff --git a/benchmarks/decoders/benchmark_decoders_library.py b/benchmarks/decoders/benchmark_decoders_library.py index be60d8e8..5b4e8a02 100644 --- a/benchmarks/decoders/benchmark_decoders_library.py +++ b/benchmarks/decoders/benchmark_decoders_library.py @@ -22,7 +22,6 @@ get_frames_by_pts, get_json_metadata, get_next_frame, - scan_all_streams_to_update_metadata, seek_to_pts, ) @@ -154,8 +153,7 @@ def __init__(self, num_threads=None, color_conversion_library=None, device="cpu" self._device = device def decode_frames(self, video_file, pts_list): - decoder = create_from_file(video_file) - scan_all_streams_to_update_metadata(decoder) + decoder = create_from_file(video_file, seek_mode="exact") _add_video_stream( decoder, num_threads=self._num_threads, @@ -170,7 +168,7 @@ def decode_frames(self, video_file, pts_list): return frames def decode_first_n_frames(self, video_file, n): - decoder = create_from_file(video_file) + decoder = create_from_file(video_file, seek_mode="approximate") _add_video_stream( decoder, num_threads=self._num_threads, @@ -197,7 +195,7 @@ def __init__(self, num_threads=None, color_conversion_library=None, device="cpu" self.transforms_v2 = transforms_v2 def decode_frames(self, video_file, pts_list): - decoder = create_from_file(video_file) + decoder = create_from_file(video_file, seek_mode="approximate") num_threads = int(self._num_threads) if self._num_threads else 0 _add_video_stream( decoder, @@ -216,7 +214,7 @@ def decode_frames(self, video_file, pts_list): def decode_first_n_frames(self, video_file, n): num_threads = int(self._num_threads) if self._num_threads else 0 - decoder = create_from_file(video_file) + decoder = create_from_file(video_file, seek_mode="approximate") _add_video_stream( decoder, num_threads=num_threads, @@ -233,7 +231,7 @@ def decode_first_n_frames(self, video_file, n): def decode_and_resize(self, video_file, pts_list, height, width, device): num_threads = int(self._num_threads) if self._num_threads else 1 - decoder = create_from_file(video_file) + decoder = create_from_file(video_file, seek_mode="approximate") _add_video_stream( decoder, num_threads=num_threads, @@ -263,8 +261,7 @@ def __init__(self, num_threads=None, color_conversion_library=None, device="cpu" self._device = device def decode_frames(self, video_file, pts_list): - decoder = create_from_file(video_file) - 
scan_all_streams_to_update_metadata(decoder) + decoder = create_from_file(video_file, seek_mode="exact") _add_video_stream( decoder, num_threads=self._num_threads, @@ -279,8 +276,7 @@ def decode_frames(self, video_file, pts_list): return frames def decode_first_n_frames(self, video_file, n): - decoder = create_from_file(video_file) - scan_all_streams_to_update_metadata(decoder) + decoder = create_from_file(video_file, seek_mode="exact") _add_video_stream( decoder, num_threads=self._num_threads, @@ -297,9 +293,10 @@ def decode_first_n_frames(self, video_file, n): class TorchCodecPublic(AbstractDecoder): - def __init__(self, num_ffmpeg_threads=None, device="cpu"): + def __init__(self, num_ffmpeg_threads=None, device="cpu", seek_mode="exact"): self._num_ffmpeg_threads = num_ffmpeg_threads self._device = device + self._seek_mode = seek_mode from torchvision.transforms import v2 as transforms_v2 @@ -310,7 +307,7 @@ def decode_frames(self, video_file, pts_list): int(self._num_ffmpeg_threads) if self._num_ffmpeg_threads else 0 ) decoder = VideoDecoder( - video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device + video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device, seek_mode=self._seek_mode ) return decoder.get_frames_played_at(pts_list) @@ -319,7 +316,7 @@ def decode_first_n_frames(self, video_file, n): int(self._num_ffmpeg_threads) if self._num_ffmpeg_threads else 0 ) decoder = VideoDecoder( - video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device + video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device, seek_mode=self._seek_mode ) frames = [] count = 0 @@ -335,7 +332,7 @@ def decode_and_resize(self, video_file, pts_list, height, width, device): int(self._num_ffmpeg_threads) if self._num_ffmpeg_threads else 1 ) decoder = VideoDecoder( - video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device + video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device, seek_mode=self._seek_mode ) frames = decoder.get_frames_played_at(pts_list) frames = self.transforms_v2.functional.resize(frames.data, (height, width)) @@ -343,9 +340,10 @@ def decode_and_resize(self, video_file, pts_list, height, width, device): class TorchCodecPublicNonBatch(AbstractDecoder): - def __init__(self, num_ffmpeg_threads=None, device="cpu"): + def __init__(self, num_ffmpeg_threads=None, device="cpu", seek_mode="exact"): self._num_ffmpeg_threads = num_ffmpeg_threads self._device = device + self._seek_mode = seek_mode from torchvision.transforms import v2 as transforms_v2 @@ -356,7 +354,7 @@ def decode_frames(self, video_file, pts_list): int(self._num_ffmpeg_threads) if self._num_ffmpeg_threads else 0 ) decoder = VideoDecoder( - video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device + video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device, seek_mode=self._seek_mode ) frames = [] @@ -370,7 +368,7 @@ def decode_first_n_frames(self, video_file, n): int(self._num_ffmpeg_threads) if self._num_ffmpeg_threads else 0 ) decoder = VideoDecoder( - video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device + video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device, seek_mode=self._seek_mode ) frames = [] count = 0 @@ -386,7 +384,7 @@ def decode_and_resize(self, video_file, pts_list, height, width, device): int(self._num_ffmpeg_threads) if self._num_ffmpeg_threads else 1 ) decoder = VideoDecoder( - video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device + video_file, 
num_ffmpeg_threads=num_ffmpeg_threads, device=self._device, seek_mode=self._seek_mode ) frames = [] diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 94eca202..aa05b65a 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -318,19 +318,26 @@ void VideoDecoder::initializeDecoder() { containerMetadata_.bestAudioStreamIndex = bestAudioStream; } + if (seekMode_ == SeekMode::exact) { + scanFileAndUpdateMetadataAndIndex(); + } + initialized_ = true; } std::unique_ptr VideoDecoder::createFromFilePath( - const std::string& videoFilePath, SeekMode seekMode) { - return std::unique_ptr(new VideoDecoder(videoFilePath, seekMode)); + const std::string& videoFilePath, + SeekMode seekMode) { + return std::unique_ptr( + new VideoDecoder(videoFilePath, seekMode)); } std::unique_ptr VideoDecoder::createFromBuffer( const void* buffer, size_t length, SeekMode seekMode) { - return std::unique_ptr(new VideoDecoder(buffer, length, seekMode)); + return std::unique_ptr( + new VideoDecoder(buffer, length, seekMode)); } void VideoDecoder::createFilterGraph( @@ -454,6 +461,7 @@ void VideoDecoder::addVideoStreamDecoder( " is already active."); } TORCH_CHECK(formatContext_.get() != nullptr); + AVCodecPtr codec = nullptr; int streamNumber = av_find_best_stream( formatContext_.get(), @@ -466,21 +474,35 @@ void VideoDecoder::addVideoStreamDecoder( throw std::invalid_argument("No valid stream found in input file."); } TORCH_CHECK(codec != nullptr); + + StreamMetadata& streamMetadata = containerMetadata_.streams[streamNumber]; + if (seekMode_ == SeekMode::approximate && + !streamMetadata.averageFps.has_value()) { + throw std::runtime_error( + "Seek mode is approximate, but stream " + std::to_string(streamNumber) + + " does not have an average fps in its metadata."); + } + StreamInfo& streamInfo = streams_[streamNumber]; streamInfo.streamIndex = streamNumber; streamInfo.timeBase = formatContext_->streams[streamNumber]->time_base; streamInfo.stream = formatContext_->streams[streamNumber]; + if (streamInfo.stream->codecpar->codec_type != AVMEDIA_TYPE_VIDEO) { throw std::invalid_argument( "Stream with index " + std::to_string(streamNumber) + " is not a video stream."); } + AVCodecContext* codecContext = avcodec_alloc_context3(codec); - codecContext->thread_count = options.ffmpegThreadCount.value_or(0); TORCH_CHECK(codecContext != nullptr); + codecContext->thread_count = options.ffmpegThreadCount.value_or(0); streamInfo.codecContext.reset(codecContext); + int retVal = avcodec_parameters_to_context( streamInfo.codecContext.get(), streamInfo.stream->codecpar); + TORCH_CHECK_EQ(retVal, AVSUCCESS); + if (options.device.type() == torch::kCPU) { // No more initialization needed for CPU. 
} else if (options.device.type() == torch::kCUDA) { @@ -488,16 +510,16 @@ void VideoDecoder::addVideoStreamDecoder( } else { TORCH_CHECK(false, "Invalid device type: " + options.device.str()); } - TORCH_CHECK_EQ(retVal, AVSUCCESS); + retVal = avcodec_open2(streamInfo.codecContext.get(), codec, nullptr); if (retVal < AVSUCCESS) { throw std::invalid_argument(getFFMPEGErrorStringFromErrorCode(retVal)); } + codecContext->time_base = streamInfo.stream->time_base; activeStreamIndices_.insert(streamNumber); updateMetadataWithCodecContext(streamInfo.streamIndex, codecContext); streamInfo.options = options; - int width = options.width.value_or(codecContext->width); // By default, we want to use swscale for color conversion because it is // faster. However, it has width requirements, so we may need to fall back @@ -506,6 +528,7 @@ void VideoDecoder::addVideoStreamDecoder( // swscale's width requirements to be violated. We don't expose the ability to // choose color conversion library publicly; we only use this ability // internally. + int width = options.width.value_or(codecContext->width); auto defaultLibrary = getDefaultColorConversionLibrary(width); streamInfo.colorConversionLibrary = options.colorConversionLibrary.value_or(defaultLibrary); @@ -550,7 +573,7 @@ int VideoDecoder::getKeyFrameIndexForPtsUsingScannedIndex( } void VideoDecoder::scanFileAndUpdateMetadataAndIndex() { - if (scanned_all_streams_) { + if (scannedAllStreams_) { return; } while (true) { @@ -624,7 +647,7 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() { } } } - scanned_all_streams_ = true; + scannedAllStreams_ = true; } int VideoDecoder::getKeyFrameIndexForPts( @@ -1021,6 +1044,15 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU( VideoDecoder::DecodedOutput VideoDecoder::getFramePlayedAtTimestampNoDemux( double seconds) { + auto output = getFramePlayedAtTimestampNoDemuxInternal(seconds); + output.frame = maybePermuteHWC2CHW(output.streamIndex, output.frame); + return output; +} + +VideoDecoder::DecodedOutput +VideoDecoder::getFramePlayedAtTimestampNoDemuxInternal( + double seconds, + std::optional preAllocatedOutputTensor) { for (auto& [streamIndex, stream] : streams_) { double frameStartTime = ptsToSeconds(stream.currentPts, stream.timeBase); double frameEndTime = ptsToSeconds( @@ -1032,6 +1064,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFramePlayedAtTimestampNoDemux( break; } } + setCursorPtsInSeconds(seconds); RawDecodedOutput rawOutput = getDecodedOutputWithFilter( [seconds, this](int frameStreamIndex, AVFrame* frame) { @@ -1051,10 +1084,9 @@ VideoDecoder::DecodedOutput VideoDecoder::getFramePlayedAtTimestampNoDemux( } return seconds >= frameStartTime && seconds < frameEndTime; }); + // Convert the frame to tensor. 
- auto output = convertAVFrameToDecodedOutput(rawOutput); - output.frame = maybePermuteHWC2CHW(output.streamIndex, output.frame); - return output; + return convertAVFrameToDecodedOutput(rawOutput, preAllocatedOutputTensor); } void VideoDecoder::validateUserProvidedStreamIndex(uint64_t streamIndex) { @@ -1071,7 +1103,7 @@ void VideoDecoder::validateUserProvidedStreamIndex(uint64_t streamIndex) { } void VideoDecoder::validateScannedAllStreams(const std::string& msg) { - if (!scanned_all_streams_) { + if (!scannedAllStreams_) { throw std::runtime_error( "Must scan all streams to update metadata before calling " + msg); } @@ -1130,7 +1162,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndexInternal( int64_t frameIndex, std::optional preAllocatedOutputTensor) { validateUserProvidedStreamIndex(streamIndex); - validateScannedAllStreams("getFrameAtIndex"); // converted + //validateScannedAllStreams("getFrameAtIndex"); // converted const auto& streamInfo = streams_[streamIndex]; const auto& streamMetadata = containerMetadata_.streams[streamIndex]; @@ -1145,7 +1177,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( int streamIndex, const std::vector& frameIndices) { validateUserProvidedStreamIndex(streamIndex); - validateScannedAllStreams("getFramesAtIndices"); // converted + //validateScannedAllStreams("getFramesAtIndices"); // converted auto indicesAreSorted = std::is_sorted(frameIndices.begin(), frameIndices.end()); @@ -1204,14 +1236,15 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesPlayedByTimestamps( int streamIndex, const std::vector& timestamps) { validateUserProvidedStreamIndex(streamIndex); - validateScannedAllStreams("getFramesPlayedByTimestamps"); // converted + //validateScannedAllStreams("getFramesPlayedByTimestamps"); // converted const auto& streamMetadata = containerMetadata_.streams[streamIndex]; const auto& stream = streams_[streamIndex]; - double minSeconds = streamMetadata.minPtsSecondsFromScan.value(); - double maxSeconds = streamMetadata.maxPtsSecondsFromScan.value(); if (seekMode_ == SeekMode::exact) { + double minSeconds = streamMetadata.minPtsSecondsFromScan.value(); + double maxSeconds = streamMetadata.maxPtsSecondsFromScan.value(); + // The frame played at timestamp t and the one played at timestamp `t + // eps` are probably the same frame, with the same index. The easiest way to // avoid decoding that unique frame twice is to convert the input timestamps @@ -1240,6 +1273,9 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesPlayedByTimestamps( return getFramesAtIndices(streamIndex, frameIndices); } else if (seekMode_ == SeekMode::approximate) { + double minSeconds = 0; + double maxSeconds = streamMetadata.durationSeconds.value(); + // TODO: Figure out if we can be smarter than just iterating over the // timestamps one-by-one. 
@@ -1248,7 +1284,15 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesPlayedByTimestamps( BatchDecodedOutput output(timestamps.size(), options, streamMetadata); for (auto i = 0; i < timestamps.size(); ++i) { - DecodedOutput singleOut = getFramePlayedAtTimestampNoDemux(timestamps[i]); + auto framePts = timestamps[i]; + TORCH_CHECK( + framePts >= minSeconds && framePts < maxSeconds, + "frame pts is " + std::to_string(framePts) + "; must be in range [" + + std::to_string(minSeconds) + ", " + std::to_string(maxSeconds) + + ")."); + + DecodedOutput singleOut = getFramePlayedAtTimestampNoDemuxInternal( + framePts, output.frames[i]); output.ptsSeconds[i] = singleOut.ptsSeconds; output.durationSeconds[i] = singleOut.durationSeconds; } @@ -1267,7 +1311,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange( int64_t stop, int64_t step) { validateUserProvidedStreamIndex(streamIndex); - validateScannedAllStreams("getFramesInRange"); // converted + //validateScannedAllStreams("getFramesInRange"); // converted const auto& streamMetadata = containerMetadata_.streams[streamIndex]; const auto& stream = streams_[streamIndex]; @@ -1301,26 +1345,14 @@ VideoDecoder::getFramesPlayedByTimestampInRange( double startSeconds, double stopSeconds) { validateUserProvidedStreamIndex(streamIndex); - validateScannedAllStreams("getFramesPlayedByTimestampInRange"); // converted + //validateScannedAllStreams("getFramesPlayedByTimestampInRange"); // converted const auto& streamMetadata = containerMetadata_.streams[streamIndex]; - double minSeconds = streamMetadata.minPtsSecondsFromScan.value(); - double maxSeconds = streamMetadata.maxPtsSecondsFromScan.value(); TORCH_CHECK( startSeconds <= stopSeconds, "Start seconds (" + std::to_string(startSeconds) + ") must be less than or equal to stop seconds (" + std::to_string(stopSeconds) + "."); - TORCH_CHECK( - startSeconds >= minSeconds && startSeconds < maxSeconds, - "Start seconds is " + std::to_string(startSeconds) + - "; must be in range [" + std::to_string(minSeconds) + ", " + - std::to_string(maxSeconds) + ")."); - TORCH_CHECK( - stopSeconds <= maxSeconds, - "Stop seconds (" + std::to_string(stopSeconds) + - "; must be less than or equal to " + std::to_string(maxSeconds) + - ")."); const auto& stream = streams_[streamIndex]; const auto& options = stream.options; @@ -1349,6 +1381,19 @@ VideoDecoder::getFramesPlayedByTimestampInRange( } if (seekMode_ == SeekMode::exact) { + double minSeconds = streamMetadata.minPtsSecondsFromScan.value(); + double maxSeconds = streamMetadata.maxPtsSecondsFromScan.value(); + TORCH_CHECK( + startSeconds >= minSeconds && startSeconds < maxSeconds, + "Start seconds is " + std::to_string(startSeconds) + + "; must be in range [" + std::to_string(minSeconds) + ", " + + std::to_string(maxSeconds) + ")."); + TORCH_CHECK( + stopSeconds <= maxSeconds, + "Stop seconds (" + std::to_string(stopSeconds) + + "; must be less than or equal to " + std::to_string(maxSeconds) + + ")."); + // Note that we look at nextPts for a frame, and not its pts or duration. // Our abstract player displays frames starting at the pts for that frame // until the pts for the next frame. 
There are two consequences: @@ -1393,6 +1438,19 @@ VideoDecoder::getFramesPlayedByTimestampInRange( return output; } else if (seekMode_ == SeekMode::approximate) { + double minSeconds = 0; + double maxSeconds = streamMetadata.durationSeconds.value(); + TORCH_CHECK( + startSeconds >= minSeconds && startSeconds < maxSeconds, + "Start seconds is " + std::to_string(startSeconds) + + "; must be in range [" + std::to_string(minSeconds) + ", " + + std::to_string(maxSeconds) + ")."); + TORCH_CHECK( + stopSeconds <= maxSeconds, + "Stop seconds (" + std::to_string(stopSeconds) + + "; must be less than or equal to " + std::to_string(maxSeconds) + + ")."); + // Because we can only discover when to stop by doing the actual decoding, // we can't pre-allocate the correct dimensions for our BatchDecodedOutput; // we don't yet know N, the number of frames. So we have to store all of the @@ -1410,7 +1468,7 @@ VideoDecoder::getFramesPlayedByTimestampInRange( std::vector ptsSeconds = {singleOut.ptsSeconds}; std::vector durationSeconds = {singleOut.durationSeconds}; - while (singleOut.ptsSeconds <= stopSeconds) { + while (singleOut.ptsSeconds < stopSeconds) { singleOut = getNextFrameNoDemux(); frames.push_back(singleOut.frame); ptsSeconds.push_back(singleOut.ptsSeconds); @@ -1418,7 +1476,6 @@ VideoDecoder::getFramesPlayedByTimestampInRange( } BatchDecodedOutput output(frames, ptsSeconds, durationSeconds); - output.frames = maybePermuteHWC2CHW(streamIndex, output.frames); return output; } else { diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 4b35c961..2ecc28ac 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -53,19 +53,15 @@ class VideoDecoder { enum class SeekMode { exact, approximate }; // Creates a VideoDecoder from the video at videoFilePath. - explicit VideoDecoder( - const std::string& videoFilePath, - SeekMode); + explicit VideoDecoder(const std::string& videoFilePath, SeekMode); // Creates a VideoDecoder from a given buffer. Note that the buffer is not // owned by the VideoDecoder. 
- explicit VideoDecoder( - const void* buffer, - size_t length, - SeekMode); + explicit VideoDecoder(const void* buffer, size_t length, SeekMode); static std::unique_ptr createFromFilePath( - const std::string& videoFilePath, SeekMode seekMode = SeekMode::exact); + const std::string& videoFilePath, + SeekMode seekMode = SeekMode::exact); static std::unique_ptr createFromBuffer( const void* buffer, @@ -243,14 +239,14 @@ class VideoDecoder { torch::Tensor durationSeconds; explicit BatchDecodedOutput( - int64_t numFrames, - const VideoStreamDecoderOptions& options, - const StreamMetadata& metadata); + int64_t numFrames, + const VideoStreamDecoderOptions& options, + const StreamMetadata& metadata); explicit BatchDecodedOutput( - const std::vector& inFrames, - std::vector& inPtsSeconds, - std::vector& inDurationSeconds); + const std::vector& inFrames, + std::vector& inPtsSeconds, + std::vector& inDurationSeconds); }; // Returns frames at the given indices for a given stream as a single stacked @@ -381,7 +377,10 @@ class VideoDecoder { void initializeDecoder(); void validateUserProvidedStreamIndex(uint64_t streamIndex); void validateScannedAllStreams(const std::string& msg); - void validateFrameIndex(const StreamInfo& streamInfo, const StreamMetadata& streamMetadata, int64_t frameIndex); + void validateFrameIndex( + const StreamInfo& streamInfo, + const StreamMetadata& streamMetadata, + int64_t frameIndex); // Creates and initializes a filter graph for a stream. The filter graph can // do rescaling and color conversion. @@ -390,7 +389,9 @@ class VideoDecoder { int expectedOutputHeight, int expectedOutputWidth); - int64_t getFramesSize(const StreamInfo& streamInfo, const StreamMetadata& streamMetadata); + int64_t getFramesSize( + const StreamInfo& streamInfo, + const StreamMetadata& streamMetadata); int64_t getPts( const StreamInfo& stream, @@ -402,6 +403,10 @@ class VideoDecoder { const DecodedFrameContext& frameContext, const enum AVColorSpace colorspace); + DecodedOutput getFramePlayedAtTimestampNoDemuxInternal( + double seconds, + std::optional preAllocatedOutputTensor = std::nullopt); + void maybeSeekToBeforeDesiredPts(); RawDecodedOutput getDecodedOutputWithFilter( std::function); @@ -447,7 +452,7 @@ class VideoDecoder { // Stores the AVIOContext for the input buffer. std::unique_ptr ioBytesContext_; // Whether or not we have already scanned all streams to update the metadata. - bool scanned_all_streams_ = false; + bool scannedAllStreams_ = false; // Tracks that we've already been initialized. bool initialized_ = false; }; diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index 7f6bd3b3..3a035d7c 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -27,8 +27,9 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.impl_abstract_pystub( "torchcodec.decoders._core.video_decoder_ops", "//pytorch/torchcodec:torchcodec"); - m.def("create_from_file(str filename) -> Tensor"); - m.def("create_from_tensor(Tensor video_tensor) -> Tensor"); + m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor"); + m.def( + "create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor"); m.def( "_add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None, str? 
color_conversion_library=None) -> ()"); m.def( @@ -88,31 +89,67 @@ OpsBatchDecodedOutput makeOpsBatchDecodedOutput( VideoDecoder::BatchDecodedOutput& batch) { return std::make_tuple(batch.frames, batch.ptsSeconds, batch.durationSeconds); } + +VideoDecoder::SeekMode seekModeFromString(std::string_view seekMode) { + if (seekMode == "exact") { + return VideoDecoder::SeekMode::exact; + } else if (seekMode == "approximate") { + return VideoDecoder::SeekMode::approximate; + } else { + throw std::runtime_error("Invalid seek mode: " + std::string(seekMode)); + } +} + } // namespace // ============================== // Implementations for the operators // ============================== -at::Tensor create_from_file(std::string_view filename) { +at::Tensor create_from_file( + std::string_view filename, + std::optional seek_mode) { std::string filenameStr(filename); + + VideoDecoder::SeekMode realSeek = VideoDecoder::SeekMode::exact; + if (seek_mode.has_value()) { + realSeek = seekModeFromString(seek_mode.value()); + } + std::unique_ptr uniqueDecoder = - VideoDecoder::createFromFilePath(filenameStr); + VideoDecoder::createFromFilePath(filenameStr, realSeek); + return wrapDecoderPointerToTensor(std::move(uniqueDecoder)); } -at::Tensor create_from_tensor(at::Tensor video_tensor) { +at::Tensor create_from_tensor( + at::Tensor video_tensor, + std::optional seek_mode) { TORCH_CHECK(video_tensor.is_contiguous(), "video_tensor must be contiguous"); void* buffer = video_tensor.mutable_data_ptr(); size_t length = video_tensor.numel(); + + VideoDecoder::SeekMode realSeek = VideoDecoder::SeekMode::exact; + if (seek_mode.has_value()) { + realSeek = seekModeFromString(seek_mode.value()); + } + std::unique_ptr videoDecoder = - VideoDecoder::createFromBuffer(buffer, length); + VideoDecoder::createFromBuffer(buffer, length, realSeek); return wrapDecoderPointerToTensor(std::move(videoDecoder)); } -at::Tensor create_from_buffer(const void* buffer, size_t length) { +at::Tensor create_from_buffer( + const void* buffer, + size_t length, + std::optional seek_mode) { + VideoDecoder::SeekMode realSeek = VideoDecoder::SeekMode::exact; + if (seek_mode.has_value()) { + realSeek = seekModeFromString(seek_mode.value()); + } + std::unique_ptr uniqueDecoder = - VideoDecoder::createFromBuffer(buffer, length); + VideoDecoder::createFromBuffer(buffer, length, realSeek); return wrapDecoderPointerToTensor(std::move(uniqueDecoder)); } diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.h b/src/torchcodec/decoders/_core/VideoDecoderOps.h index 7717a48b..60635094 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.h +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.h @@ -20,13 +20,20 @@ namespace facebook::torchcodec { // auto decoderTensor = createDecoderOp.call(videoPath); // Create a VideoDecoder from file and wrap the pointer in a tensor. -at::Tensor create_from_file(std::string_view filename); +at::Tensor create_from_file( + std::string_view filename, + std::optional seek_mode = std::nullopt); -at::Tensor create_from_tensor(at::Tensor video_tensor); +at::Tensor create_from_tensor( + at::Tensor video_tensor, + std::optional seek_mode = std::nullopt); // This API is C++ only and will not be exposed via custom ops, use // videodecoder_create_from_bytes in Python -at::Tensor create_from_buffer(const void* buffer, size_t length); +at::Tensor create_from_buffer( + const void* buffer, + size_t length, + std::optional seek_mode = std::nullopt); // Add a new video stream at `stream_index` using the provided options. 
void add_video_stream(
diff --git a/src/torchcodec/decoders/_core/_metadata.py b/src/torchcodec/decoders/_core/_metadata.py
index 4d5e9a31..e0400fdf 100644
--- a/src/torchcodec/decoders/_core/_metadata.py
+++ b/src/torchcodec/decoders/_core/_metadata.py
@@ -172,4 +172,4 @@ def get_video_metadata(decoder: torch.Tensor) -> VideoMetadata:
 
 
 def get_video_metadata_from_header(filename: Union[str, pathlib.Path]) -> VideoMetadata:
-    return get_video_metadata(create_from_file(str(filename)))
+    return get_video_metadata(create_from_file(str(filename), seek_mode="approximate"))
diff --git a/src/torchcodec/decoders/_core/video_decoder_ops.py b/src/torchcodec/decoders/_core/video_decoder_ops.py
index d3f8e9a6..16e63392 100644
--- a/src/torchcodec/decoders/_core/video_decoder_ops.py
+++ b/src/torchcodec/decoders/_core/video_decoder_ops.py
@@ -94,25 +94,29 @@ def load_torchcodec_extension():
 # =============================
 # Functions not related to custom ops, but similar implementation to c++ ops
 # =============================
-def create_from_bytes(video_bytes: bytes) -> torch.Tensor:
+def create_from_bytes(
+    video_bytes: bytes, seek_mode: Optional[str] = None
+) -> torch.Tensor:
     with warnings.catch_warnings():
         # Ignore warning stating that the underlying video_bytes buffer is
         # non-writable.
         warnings.filterwarnings("ignore", category=UserWarning)
         buffer = torch.frombuffer(video_bytes, dtype=torch.uint8)
-    return create_from_tensor(buffer)
+    return create_from_tensor(buffer, seek_mode)
 
 
 # ==============================
 # Abstract impl for the operators. Needed by torch.compile.
 # ==============================
 @register_fake("torchcodec_ns::create_from_file")
-def create_from_file_abstract(filename: str) -> torch.Tensor:
+def create_from_file_abstract(filename: str, seek_mode: Optional[str]) -> torch.Tensor:
     return torch.empty([], dtype=torch.long)
 
 
 @register_fake("torchcodec_ns::create_from_tensor")
-def create_from_tensor_abstract(video_tensor: torch.Tensor) -> torch.Tensor:
+def create_from_tensor_abstract(
+    video_tensor: torch.Tensor, seek_mode: Optional[str]
+) -> torch.Tensor:
     return torch.empty([], dtype=torch.long)
 
 
diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py
index 7930756a..5d73c651 100644
--- a/src/torchcodec/decoders/_video_decoder.py
+++ b/src/torchcodec/decoders/_video_decoder.py
@@ -50,6 +50,10 @@ class VideoDecoder:
             Passing 0 lets FFmpeg decide on the number of threads. Default: 1.
         device (str or torch.device, optional): The device to use for decoding. Default: "cpu".
+        seek_mode (str, optional): Determines if index-based frame access will be "exact" or
+            "approximate". Exact guarantees that requesting frame i will always return frame i,
+            but doing so requires an initial scan of the file. Approximate avoids scanning the
+            file, but uses the file's metadata to calculate where i probably is. Default: "exact".
 
     Attributes:
@@ -67,15 +71,23 @@ def __init__(
         dimension_order: Literal["NCHW", "NHWC"] = "NCHW",
         num_ffmpeg_threads: int = 1,
         device: Optional[Union[str, device]] = "cpu",
+        seek_mode: Literal["exact", "approximate"] = "exact",
     ):
+        allowed_seek_modes = ("exact", "approximate")
+        if seek_mode not in allowed_seek_modes:
+            raise ValueError(
+                f"Invalid seek mode ({seek_mode}). "
+                f"Supported values are {', '.join(allowed_seek_modes)}."
+            )
+
         if isinstance(source, str):
-            self._decoder = core.create_from_file(source)
+            self._decoder = core.create_from_file(source, seek_mode)
         elif isinstance(source, Path):
-            self._decoder = core.create_from_file(str(source))
+            self._decoder = core.create_from_file(str(source), seek_mode)
         elif isinstance(source, bytes):
-            self._decoder = core.create_from_bytes(source)
+            self._decoder = core.create_from_bytes(source, seek_mode)
         elif isinstance(source, Tensor):
-            self._decoder = core.create_from_tensor(source)
+            self._decoder = core.create_from_tensor(source, seek_mode)
         else:
             raise TypeError(
                 f"Unknown source type: {type(source)}. "
@@ -92,7 +104,6 @@ def __init__(
         if num_ffmpeg_threads is None:
            raise ValueError(f"{num_ffmpeg_threads = } should be an int.")
 
-        core.scan_all_streams_to_update_metadata(self._decoder)
         core.add_video_stream(
             self._decoder,
             stream_index=stream_index,
@@ -105,25 +116,44 @@ def __init__(
             self._decoder, stream_index
         )
 
-        if self.metadata.num_frames_from_content is None:
-            raise ValueError(
-                "The number of frames is unknown. " + _ERROR_REPORTING_INSTRUCTIONS
-            )
-        self._num_frames = self.metadata.num_frames_from_content
+        if seek_mode == "exact":
+            if self.metadata.num_frames_from_content is None:
+                raise ValueError(
+                    "The number of frames is unknown. " + _ERROR_REPORTING_INSTRUCTIONS
+                )
+            self._num_frames = self.metadata.num_frames_from_content
+
+            if self.metadata.begin_stream_seconds is None:
+                raise ValueError(
+                    "The minimum pts value in seconds is unknown. "
+                    + _ERROR_REPORTING_INSTRUCTIONS
+                )
+            self._begin_stream_seconds = self.metadata.begin_stream_seconds
+
+            if self.metadata.end_stream_seconds is None:
+                raise ValueError(
+                    "The maximum pts value in seconds is unknown. "
+                    + _ERROR_REPORTING_INSTRUCTIONS
+                )
+            self._end_stream_seconds = self.metadata.end_stream_seconds
+        elif seek_mode == "approximate":
+            if self.metadata.num_frames_from_header is None:
+                raise ValueError(
+                    "The number of frames is unknown. " + _ERROR_REPORTING_INSTRUCTIONS
+                )
+            self._num_frames = self.metadata.num_frames_from_header
+
+            self._begin_stream_seconds = 0
+
+            if self.metadata.duration_seconds_from_header is None:
+                raise ValueError(
+                    "The maximum pts value in seconds is unknown. "
+                    + _ERROR_REPORTING_INSTRUCTIONS
+                )
+            self._end_stream_seconds = self.metadata.duration_seconds_from_header
 
-        if self.metadata.begin_stream_seconds is None:
-            raise ValueError(
-                "The minimum pts value in seconds is unknown. "
-                + _ERROR_REPORTING_INSTRUCTIONS
-            )
-        self._begin_stream_seconds = self.metadata.begin_stream_seconds
-
-        if self.metadata.end_stream_seconds is None:
-            raise ValueError(
-                "The maximum pts value in seconds is unknown. 
" - + _ERROR_REPORTING_INSTRUCTIONS - ) - self._end_stream_seconds = self.metadata.end_stream_seconds + else: + raise ValueError(f"Invalid seek mode: {seek_mode}.") def __len__(self) -> int: return self._num_frames diff --git a/test/decoders/test_metadata.py b/test/decoders/test_metadata.py index 1ccceb62..be2af376 100644 --- a/test/decoders/test_metadata.py +++ b/test/decoders/test_metadata.py @@ -21,9 +21,10 @@ def _get_video_metadata(path, with_scan: bool): - decoder = create_from_file(str(path)) if with_scan: - scan_all_streams_to_update_metadata(decoder) + decoder = create_from_file(str(path), seek_mode="exact") + else: + decoder = create_from_file(str(path), seek_mode="approximate") return get_video_metadata(decoder) diff --git a/test/decoders/test_video_decoder.py b/test/decoders/test_video_decoder.py index 1ff8266c..d4bbfb10 100644 --- a/test/decoders/test_video_decoder.py +++ b/test/decoders/test_video_decoder.py @@ -16,7 +16,8 @@ class TestVideoDecoder: @pytest.mark.parametrize("source_kind", ("str", "path", "tensor", "bytes")) - def test_create(self, source_kind): + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_create(self, source_kind, seek_mode): if source_kind == "str": source = str(NASA_VIDEO.path) elif source_kind == "path": @@ -30,12 +31,11 @@ def test_create(self, source_kind): else: raise ValueError("Oops, double check the parametrization of this test!") - decoder = VideoDecoder(source) + decoder = VideoDecoder(source, seek_mode=seek_mode) assert isinstance(decoder.metadata, _core.VideoStreamMetadata) assert ( len(decoder) == decoder._num_frames - == decoder.metadata.num_frames_from_content == 390 ) assert decoder.stream_index == decoder.metadata.stream_index == 3 @@ -55,11 +55,18 @@ def test_create_fails(self): with pytest.raises(ValueError, match="No valid stream found"): decoder = VideoDecoder(NASA_VIDEO.path, stream_index=1) # noqa + with pytest.raises(ValueError, match="Invalid seek mode"): + decoder = VideoDecoder(NASA_VIDEO.path, seek_mode="blah") # noqa + @pytest.mark.parametrize("num_ffmpeg_threads", (1, 4)) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_getitem_int(self, num_ffmpeg_threads, device): + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_getitem_int(self, num_ffmpeg_threads, device, seek_mode): decoder = VideoDecoder( - NASA_VIDEO.path, num_ffmpeg_threads=num_ffmpeg_threads, device=device + NASA_VIDEO.path, + num_ffmpeg_threads=num_ffmpeg_threads, + device=device, + seek_mode=seek_mode, ) ref_frame0 = NASA_VIDEO.get_frame_data_by_index(0).to(device) @@ -103,8 +110,9 @@ def test_getitem_numpy_int(self): assert_frames_equal(ref_frame180, decoder[numpy.uint32(180)]) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_getitem_slice(self, device): - decoder = VideoDecoder(NASA_VIDEO.path, device=device) + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_getitem_slice(self, device, seek_mode): + decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) # ensure that the degenerate case of a range of size 1 works @@ -244,8 +252,9 @@ def test_getitem_slice(self, device): assert_frames_equal(sliced, ref) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_getitem_fails(self, device): - decoder = VideoDecoder(NASA_VIDEO.path, device=device) + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_getitem_fails(self, device, seek_mode): + decoder = VideoDecoder(NASA_VIDEO.path, device=device, 
seek_mode=seek_mode) with pytest.raises(IndexError, match="out of bounds"): frame = decoder[1000] # noqa @@ -260,8 +269,9 @@ def test_getitem_fails(self, device): frame = decoder[2.3] # noqa @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_iteration(self, device): - decoder = VideoDecoder(NASA_VIDEO.path, device=device) + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_iteration(self, device, seek_mode): + decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) ref_frame0 = NASA_VIDEO.get_frame_data_by_index(0).to(device) ref_frame1 = NASA_VIDEO.get_frame_data_by_index(1).to(device) @@ -305,8 +315,9 @@ def test_iteration_slow(self): assert iterations == len(decoder) == 390 @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_get_frame_at(self, device): - decoder = VideoDecoder(NASA_VIDEO.path, device=device) + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_get_frame_at(self, device, seek_mode): + decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) ref_frame9 = NASA_VIDEO.get_frame_data_by_index(9).to(device) frame9 = decoder.get_frame_at(9) @@ -348,8 +359,9 @@ def test_get_frame_at_tuple_unpacking(self, device): assert frame.duration_seconds == duration @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_get_frame_at_fails(self, device): - decoder = VideoDecoder(NASA_VIDEO.path, device=device) + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_get_frame_at_fails(self, device, seek_mode): + decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) with pytest.raises(IndexError, match="out of bounds"): frame = decoder.get_frame_at(-1) # noqa @@ -358,8 +370,9 @@ def test_get_frame_at_fails(self, device): frame = decoder.get_frame_at(10000) # noqa @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_get_frames_at(self, device): - decoder = VideoDecoder(NASA_VIDEO.path, device=device) + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_get_frames_at(self, device, seek_mode): + decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) frames = decoder.get_frames_at([35, 25]) @@ -397,8 +410,9 @@ def test_get_frames_at(self, device): ) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_get_frames_at_fails(self, device): - decoder = VideoDecoder(NASA_VIDEO.path, device=device) + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_get_frames_at_fails(self, device, seek_mode): + decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) with pytest.raises(RuntimeError, match="Invalid frame index=-1"): decoder.get_frames_at([-1]) @@ -410,8 +424,9 @@ def test_get_frames_at_fails(self, device): decoder.get_frames_at([0.3]) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_get_frame_played_at(self, device): - decoder = VideoDecoder(NASA_VIDEO.path, device=device) + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_get_frame_played_at(self, device, seek_mode): + decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) ref_frame_played_at_6 = NASA_VIDEO.get_frame_data_by_index(180).to(device) assert_frames_equal( @@ -431,13 +446,16 @@ def test_get_frame_played_at_h265(self): # We don't parametrize with CUDA because the current GPUs on CI do not # support x265: # https://github.com/pytorch/torchcodec/pull/350#issuecomment-2465011730 - decoder = VideoDecoder(H265_VIDEO.path) + # 
Note that because our internal fix-up depends on the key frame index, it + # only works in exact seeking mode. + decoder = VideoDecoder(H265_VIDEO.path, seek_mode="exact") ref_frame6 = H265_VIDEO.get_frame_data_by_index(5) assert_frames_equal(ref_frame6, decoder.get_frame_played_at(0.5).data) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_get_frame_played_at_fails(self, device): - decoder = VideoDecoder(NASA_VIDEO.path, device=device) + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_get_frame_played_at_fails(self, device, seek_mode): + decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) with pytest.raises(IndexError, match="Invalid pts in seconds"): frame = decoder.get_frame_played_at(-1.0) # noqa @@ -446,9 +464,10 @@ def test_get_frame_played_at_fails(self, device): frame = decoder.get_frame_played_at(100.0) # noqa @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_get_frames_played_at(self, device): + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_get_frames_played_at(self, device, seek_mode): - decoder = VideoDecoder(NASA_VIDEO.path, device=device) + decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) # Note: We know the frame at ~0.84s has index 25, the one at 1.16s has # index 35. We use those indices as reference to test against. @@ -462,6 +481,7 @@ def test_get_frames_played_at(self, device): assert_frames_equal( frames.data[i], NASA_VIDEO.get_frame_data_by_index(reference_indices[i]).to(device), + msg=f"index {i}", ) assert frames.pts_seconds.device.type == "cpu" @@ -483,8 +503,9 @@ def test_get_frames_played_at(self, device): ) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_get_frames_played_at_fails(self, device): - decoder = VideoDecoder(NASA_VIDEO.path, device=device) + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_get_frames_played_at_fails(self, device, seek_mode): + decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) with pytest.raises(RuntimeError, match="must be in range"): decoder.get_frames_played_at([-1]) @@ -497,9 +518,10 @@ def test_get_frames_played_at_fails(self, device): @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("stream_index", [0, 3, None]) - def test_get_frames_in_range(self, stream_index, device): + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_get_frames_in_range(self, stream_index, device, seek_mode): decoder = VideoDecoder( - NASA_VIDEO.path, stream_index=stream_index, device=device + NASA_VIDEO.path, stream_index=stream_index, device=device, seek_mode=seek_mode ) # test degenerate case where we only actually get 1 frame @@ -610,9 +632,10 @@ def test_get_frames_in_range(self, stream_index, device): ), ) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_dimension_order(self, dimension_order, frame_getter, device): + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_dimension_order(self, dimension_order, frame_getter, device, seek_mode): decoder = VideoDecoder( - NASA_VIDEO.path, dimension_order=dimension_order, device=device + NASA_VIDEO.path, dimension_order=dimension_order, device=device, seek_mode=seek_mode ) frame = frame_getter(decoder) @@ -634,9 +657,10 @@ def test_dimension_order_fails(self): @pytest.mark.parametrize("stream_index", [0, 3, None]) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_get_frames_by_pts_in_range(self, stream_index, 
device): + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_get_frames_by_pts_in_range(self, stream_index, device, seek_mode): decoder = VideoDecoder( - NASA_VIDEO.path, stream_index=stream_index, device=device + NASA_VIDEO.path, stream_index=stream_index, device=device, seek_mode=seek_mode ) # Note that we are comparing the results of VideoDecoder's method: @@ -769,8 +793,9 @@ def test_get_frames_by_pts_in_range(self, stream_index, device): assert_frames_equal(all_frames.data, decoder[:]) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_get_frames_by_pts_in_range_fails(self, device): - decoder = VideoDecoder(NASA_VIDEO.path, device=device) + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_get_frames_by_pts_in_range_fails(self, device, seek_mode): + decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) with pytest.raises(ValueError, match="Invalid start seconds"): frame = decoder.get_frames_played_in_range(100.0, 1.0) # noqa diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index 9baf6a39..41e893fd 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -32,7 +32,6 @@ get_frames_in_range, get_json_metadata, get_next_frame, - scan_all_streams_to_update_metadata, seek_to_pts, ) @@ -66,7 +65,7 @@ def seek(self, pts: float): class TestOps: @pytest.mark.parametrize("device", cpu_and_cuda()) def test_seek_and_next(self, device): - decoder = create_from_file(str(NASA_VIDEO.path)) + decoder = create_from_file(str(NASA_VIDEO.path), seek_mode="approximate") add_video_stream(decoder, device=device) frame0, _, _ = get_next_frame(decoder) reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0) @@ -83,8 +82,7 @@ def test_seek_and_next(self, device): @pytest.mark.parametrize("device", cpu_and_cuda()) def test_seek_to_negative_pts(self, device): - decoder = create_from_file(str(NASA_VIDEO.path)) - scan_all_streams_to_update_metadata(decoder) + decoder = create_from_file(str(NASA_VIDEO.path), seek_mode="exact") add_video_stream(decoder, device=device) frame0, _, _ = get_next_frame(decoder) reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0) @@ -96,7 +94,7 @@ def test_seek_to_negative_pts(self, device): @pytest.mark.parametrize("device", cpu_and_cuda()) def test_get_frame_at_pts(self, device): - decoder = create_from_file(str(NASA_VIDEO.path)) + decoder = create_from_file(str(NASA_VIDEO.path), seek_mode="approximate") add_video_stream(decoder, device=device) # This frame has pts=6.006 and duration=0.033367, so it should be visible # at timestamps in the range [6.006, 6.039367) (not including the last timestamp). 
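The comment above relies on the half-open display model used throughout the
decoder: a frame with pts t and duration d is shown for [t, t + d), and the
next frame takes over exactly at t + d, which is why the last timestamp is
excluded. A minimal sketch of that predicate (plain Python, illustrative
rather than actual torchcodec code; the 0.5s duration is chosen to keep the
float arithmetic exact):

    def frame_is_visible_at(pts: float, duration: float, query: float) -> bool:
        # Half-open interval: visible at its own pts, replaced exactly
        # when the next frame starts at pts + duration.
        return pts <= query < pts + duration

    assert frame_is_visible_at(6.0, 0.5, 6.0)        # first visible instant
    assert frame_is_visible_at(6.0, 0.5, 6.25)       # mid-display
    assert not frame_is_visible_at(6.0, 0.5, 6.5)    # next frame's pts
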
@@ -120,8 +118,7 @@ def test_get_frame_at_pts(self, device): @pytest.mark.parametrize("device", cpu_and_cuda()) def test_get_frame_at_index(self, device): - decoder = create_from_file(str(NASA_VIDEO.path)) - scan_all_streams_to_update_metadata(decoder) + decoder = create_from_file(str(NASA_VIDEO.path), seek_mode="exact") add_video_stream(decoder, device=device) frame0, _, _ = get_frame_at_index(decoder, stream_index=3, frame_index=0) reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0) @@ -135,8 +132,7 @@ def test_get_frame_at_index(self, device): @pytest.mark.parametrize("device", cpu_and_cuda()) def test_get_frame_with_info_at_index(self, device): - decoder = create_from_file(str(NASA_VIDEO.path)) - scan_all_streams_to_update_metadata(decoder) + decoder = create_from_file(str(NASA_VIDEO.path), seek_mode="exact") add_video_stream(decoder, device=device) frame6, pts, duration = get_frame_at_index( decoder, stream_index=3, frame_index=180 @@ -150,8 +146,7 @@ def test_get_frame_with_info_at_index(self, device): @pytest.mark.parametrize("device", cpu_and_cuda()) def test_get_frames_at_indices(self, device): - decoder = create_from_file(str(NASA_VIDEO.path)) - scan_all_streams_to_update_metadata(decoder) + decoder = create_from_file(str(NASA_VIDEO.path), seek_mode="exact") add_video_stream(decoder, device=device) frames0and180, *_ = get_frames_at_indices( decoder, stream_index=3, frame_indices=[0, 180] @@ -165,9 +160,8 @@ def test_get_frames_at_indices(self, device): @pytest.mark.parametrize("device", cpu_and_cuda()) def test_get_frames_at_indices_unsorted_indices(self, device): - decoder = create_from_file(str(NASA_VIDEO.path)) + decoder = create_from_file(str(NASA_VIDEO.path), seek_mode="exact") _add_video_stream(decoder, device=device) - scan_all_streams_to_update_metadata(decoder) stream_index = 3 frame_indices = [2, 0, 1, 0, 2] @@ -197,9 +191,8 @@ def test_get_frames_at_indices_unsorted_indices(self, device): @pytest.mark.parametrize("device", cpu_and_cuda()) def test_get_frames_by_pts(self, device): - decoder = create_from_file(str(NASA_VIDEO.path)) + decoder = create_from_file(str(NASA_VIDEO.path), seek_mode="exact") _add_video_stream(decoder, device=device) - scan_all_streams_to_update_metadata(decoder) stream_index = 3 # Note: 13.01 should give the last video frame for the NASA video @@ -232,8 +225,7 @@ def test_pts_apis_against_index_ref(self, device): # Get all frames in the video, then query all frames with all time-based # APIs exactly where those frames are supposed to start. We assert that # we get the expected frame. 
- decoder = create_from_file(str(NASA_VIDEO.path)) - scan_all_streams_to_update_metadata(decoder) + decoder = create_from_file(str(NASA_VIDEO.path), seek_mode="exact") add_video_stream(decoder, device=device) metadata = get_json_metadata(decoder) @@ -289,8 +281,7 @@ def test_pts_apis_against_index_ref(self, device): @pytest.mark.parametrize("device", cpu_and_cuda()) def test_get_frames_in_range(self, device): - decoder = create_from_file(str(NASA_VIDEO.path)) - scan_all_streams_to_update_metadata(decoder) + decoder = create_from_file(str(NASA_VIDEO.path), seek_mode="exact") add_video_stream(decoder, device=device) # ensure that the degenerate case of a range of size 1 works @@ -340,7 +331,7 @@ def test_get_frames_in_range(self, device): @pytest.mark.parametrize("device", cpu_and_cuda()) def test_throws_exception_at_eof(self, device): - decoder = create_from_file(str(NASA_VIDEO.path)) + decoder = create_from_file(str(NASA_VIDEO.path), seek_mode="approximate") add_video_stream(decoder, device=device) seek_to_pts(decoder, 12.979633) last_frame, _, _ = get_next_frame(decoder) @@ -351,7 +342,7 @@ def test_throws_exception_at_eof(self, device): @pytest.mark.parametrize("device", cpu_and_cuda()) def test_throws_exception_if_seek_too_far(self, device): - decoder = create_from_file(str(NASA_VIDEO.path)) + decoder = create_from_file(str(NASA_VIDEO.path), seek_mode="approximate") add_video_stream(decoder, device=device) # pts=12.979633 is the last frame in the video. seek_to_pts(decoder, 12.979633 + 1.0e-4) @@ -373,7 +364,7 @@ def get_frame1_and_frame_time6(decoder): # NB: create needs to happen outside the torch.compile region, # for now. Otherwise torch.compile constant-props it. - decoder = create_from_file(str(NASA_VIDEO.path)) + decoder = create_from_file(str(NASA_VIDEO.path), seek_mode="approximate") frame0, frame_time6 = get_frame1_and_frame_time6(decoder) reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0) reference_frame_time6 = NASA_VIDEO.get_frame_data_by_index( @@ -451,8 +442,7 @@ def test_video_get_json_metadata(self): assert metadata_dict["bitRate"] == 324915.0 def test_video_get_json_metadata_with_stream(self): - decoder = create_from_file(str(NASA_VIDEO.path)) - scan_all_streams_to_update_metadata(decoder) + decoder = create_from_file(str(NASA_VIDEO.path), seek_mode="exact") add_video_stream(decoder) metadata = get_json_metadata(decoder) metadata_dict = json.loads(metadata) @@ -479,8 +469,7 @@ def test_get_ffmpeg_version(self): assert "ffmpeg_version" in ffmpeg_dict def test_frame_pts_equality(self): - decoder = create_from_file(str(NASA_VIDEO.path)) - scan_all_streams_to_update_metadata(decoder) + decoder = create_from_file(str(NASA_VIDEO.path), seek_mode="exact") add_video_stream(decoder) # Note that for all of these tests, we store the return value of @@ -500,7 +489,7 @@ def test_frame_pts_equality(self): @pytest.mark.parametrize("color_conversion_library", ("filtergraph", "swscale")) def test_color_conversion_library(self, color_conversion_library): - decoder = create_from_file(str(NASA_VIDEO.path)) + decoder = create_from_file(str(NASA_VIDEO.path), seek_mode="approximate") _add_video_stream(decoder, color_conversion_library=color_conversion_library) frame0, *_ = get_next_frame(decoder) reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0) @@ -525,8 +514,7 @@ def test_color_conversion_library(self, color_conversion_library): def test_color_conversion_library_with_scaling( self, input_video, width_scaling_factor, height_scaling_factor ): - decoder = 
create_from_file(str(input_video.path)) - scan_all_streams_to_update_metadata(decoder) + decoder = create_from_file(str(input_video.path), seek_mode="exact") add_video_stream(decoder) metadata = get_json_metadata(decoder) metadata_dict = json.loads(metadata) @@ -540,7 +528,7 @@ def test_color_conversion_library_with_scaling( if height_scaling_factor != 1.0: assert target_height != input_video.height - filtergraph_decoder = create_from_file(str(input_video.path)) + filtergraph_decoder = create_from_file(str(input_video.path), seek_mode="approximate") _add_video_stream( filtergraph_decoder, width=target_width, @@ -549,7 +537,7 @@ def test_color_conversion_library_with_scaling( ) filtergraph_frame0, _, _ = get_next_frame(filtergraph_decoder) - swscale_decoder = create_from_file(str(input_video.path)) + swscale_decoder = create_from_file(str(input_video.path), seek_mode="approximate") _add_video_stream( swscale_decoder, width=target_width, @@ -564,13 +552,12 @@ def test_color_conversion_library_with_scaling( def test_color_conversion_library_with_dimension_order( self, dimension_order, color_conversion_library ): - decoder = create_from_file(str(NASA_VIDEO.path)) + decoder = create_from_file(str(NASA_VIDEO.path), seek_mode="exact") _add_video_stream( decoder, color_conversion_library=color_conversion_library, dimension_order=dimension_order, ) - scan_all_streams_to_update_metadata(decoder) frame0_ref = NASA_VIDEO.get_frame_data_by_index(0) if dimension_order == "NHWC": @@ -650,8 +637,7 @@ def test_color_conversion_library_with_generated_videos( ] subprocess.check_call(command) - decoder = create_from_file(str(video_path)) - scan_all_streams_to_update_metadata(decoder) + decoder = create_from_file(str(video_path), seek_mode="exact") add_video_stream(decoder) metadata = get_json_metadata(decoder) metadata_dict = json.loads(metadata) @@ -665,7 +651,7 @@ def test_color_conversion_library_with_generated_videos( if height_scaling_factor != 1.0: assert target_height != height - filtergraph_decoder = create_from_file(str(video_path)) + filtergraph_decoder = create_from_file(str(video_path), seek_mode="approximate") _add_video_stream( filtergraph_decoder, width=target_width, @@ -674,7 +660,7 @@ def test_color_conversion_library_with_generated_videos( ) filtergraph_frame0, _, _ = get_next_frame(filtergraph_decoder) - auto_decoder = create_from_file(str(video_path)) + auto_decoder = create_from_file(str(video_path), seek_mode="approximate") add_video_stream( auto_decoder, width=target_width, @@ -685,8 +671,7 @@ def test_color_conversion_library_with_generated_videos( @needs_cuda def test_cuda_decoder(self): - decoder = create_from_file(str(NASA_VIDEO.path)) - scan_all_streams_to_update_metadata(decoder) + decoder = create_from_file(str(NASA_VIDEO.path), seek_mode="exact") add_video_stream(decoder, device="cuda") frame0, pts, duration = get_next_frame(decoder) assert frame0.device.type == "cuda" From 8c9aeac716302117d15eaaf8fa5a49914d3d6258 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Thu, 19 Dec 2024 19:34:05 -0800 Subject: [PATCH 04/56] Apply lints --- .../decoders/benchmark_decoders_library.py | 30 +++++++++++++++---- .../decoders/_core/VideoDecoder.cpp | 15 +++++----- src/torchcodec/decoders/_video_decoder.py | 4 +-- test/decoders/test_metadata.py | 1 - test/decoders/test_video_decoder.py | 21 ++++++++----- test/decoders/test_video_decoder_ops.py | 8 +++-- 6 files changed, 53 insertions(+), 26 deletions(-) diff --git a/benchmarks/decoders/benchmark_decoders_library.py 
b/benchmarks/decoders/benchmark_decoders_library.py index 5b4e8a02..e8fcd7d3 100644 --- a/benchmarks/decoders/benchmark_decoders_library.py +++ b/benchmarks/decoders/benchmark_decoders_library.py @@ -307,7 +307,10 @@ def decode_frames(self, video_file, pts_list): int(self._num_ffmpeg_threads) if self._num_ffmpeg_threads else 0 ) decoder = VideoDecoder( - video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device, seek_mode=self._seek_mode + video_file, + num_ffmpeg_threads=num_ffmpeg_threads, + device=self._device, + seek_mode=self._seek_mode, ) return decoder.get_frames_played_at(pts_list) @@ -316,7 +319,10 @@ def decode_first_n_frames(self, video_file, n): int(self._num_ffmpeg_threads) if self._num_ffmpeg_threads else 0 ) decoder = VideoDecoder( - video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device, seek_mode=self._seek_mode + video_file, + num_ffmpeg_threads=num_ffmpeg_threads, + device=self._device, + seek_mode=self._seek_mode, ) frames = [] count = 0 @@ -332,7 +338,10 @@ def decode_and_resize(self, video_file, pts_list, height, width, device): int(self._num_ffmpeg_threads) if self._num_ffmpeg_threads else 1 ) decoder = VideoDecoder( - video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device, seek_mode=self._seek_mode + video_file, + num_ffmpeg_threads=num_ffmpeg_threads, + device=self._device, + seek_mode=self._seek_mode, ) frames = decoder.get_frames_played_at(pts_list) frames = self.transforms_v2.functional.resize(frames.data, (height, width)) @@ -354,7 +363,10 @@ def decode_frames(self, video_file, pts_list): int(self._num_ffmpeg_threads) if self._num_ffmpeg_threads else 0 ) decoder = VideoDecoder( - video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device, seek_mode=self._seek_mode + video_file, + num_ffmpeg_threads=num_ffmpeg_threads, + device=self._device, + seek_mode=self._seek_mode, ) frames = [] @@ -368,7 +380,10 @@ def decode_first_n_frames(self, video_file, n): int(self._num_ffmpeg_threads) if self._num_ffmpeg_threads else 0 ) decoder = VideoDecoder( - video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device, seek_mode=self._seek_mode + video_file, + num_ffmpeg_threads=num_ffmpeg_threads, + device=self._device, + seek_mode=self._seek_mode, ) frames = [] count = 0 @@ -384,7 +399,10 @@ def decode_and_resize(self, video_file, pts_list, height, width, device): int(self._num_ffmpeg_threads) if self._num_ffmpeg_threads else 1 ) decoder = VideoDecoder( - video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device, seek_mode=self._seek_mode + video_file, + num_ffmpeg_threads=num_ffmpeg_threads, + device=self._device, + seek_mode=self._seek_mode, ) frames = [] diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index aa05b65a..86ee7286 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1162,7 +1162,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndexInternal( int64_t frameIndex, std::optional preAllocatedOutputTensor) { validateUserProvidedStreamIndex(streamIndex); - //validateScannedAllStreams("getFrameAtIndex"); // converted + // validateScannedAllStreams("getFrameAtIndex"); // converted const auto& streamInfo = streams_[streamIndex]; const auto& streamMetadata = containerMetadata_.streams[streamIndex]; @@ -1177,7 +1177,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( int streamIndex, const std::vector& frameIndices) { 
validateUserProvidedStreamIndex(streamIndex); - //validateScannedAllStreams("getFramesAtIndices"); // converted + // validateScannedAllStreams("getFramesAtIndices"); // converted auto indicesAreSorted = std::is_sorted(frameIndices.begin(), frameIndices.end()); @@ -1236,7 +1236,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesPlayedByTimestamps( int streamIndex, const std::vector& timestamps) { validateUserProvidedStreamIndex(streamIndex); - //validateScannedAllStreams("getFramesPlayedByTimestamps"); // converted + // validateScannedAllStreams("getFramesPlayedByTimestamps"); // converted const auto& streamMetadata = containerMetadata_.streams[streamIndex]; const auto& stream = streams_[streamIndex]; @@ -1291,8 +1291,8 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesPlayedByTimestamps( std::to_string(minSeconds) + ", " + std::to_string(maxSeconds) + ")."); - DecodedOutput singleOut = getFramePlayedAtTimestampNoDemuxInternal( - framePts, output.frames[i]); + DecodedOutput singleOut = + getFramePlayedAtTimestampNoDemuxInternal(framePts, output.frames[i]); output.ptsSeconds[i] = singleOut.ptsSeconds; output.durationSeconds[i] = singleOut.durationSeconds; } @@ -1311,7 +1311,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange( int64_t stop, int64_t step) { validateUserProvidedStreamIndex(streamIndex); - //validateScannedAllStreams("getFramesInRange"); // converted + // validateScannedAllStreams("getFramesInRange"); // converted const auto& streamMetadata = containerMetadata_.streams[streamIndex]; const auto& stream = streams_[streamIndex]; @@ -1345,7 +1345,8 @@ VideoDecoder::getFramesPlayedByTimestampInRange( double startSeconds, double stopSeconds) { validateUserProvidedStreamIndex(streamIndex); - //validateScannedAllStreams("getFramesPlayedByTimestampInRange"); // converted + // validateScannedAllStreams("getFramesPlayedByTimestampInRange"); // + // converted const auto& streamMetadata = containerMetadata_.streams[streamIndex]; TORCH_CHECK( diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index 5d73c651..2f36d06d 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -116,7 +116,7 @@ def __init__( self._decoder, stream_index ) - if seek_mode is "exact": + if seek_mode == "exact": if self.metadata.num_frames_from_content is None: raise ValueError( "The number of frames is unknown. " + _ERROR_REPORTING_INSTRUCTIONS @@ -136,7 +136,7 @@ def __init__( + _ERROR_REPORTING_INSTRUCTIONS ) self._end_stream_seconds = self.metadata.end_stream_seconds - elif seek_mode is "approximate": + elif seek_mode == "approximate": if self.metadata.num_frames_from_header is None: raise ValueError( "The number of frames is unknown. 
" + _ERROR_REPORTING_INSTRUCTIONS diff --git a/test/decoders/test_metadata.py b/test/decoders/test_metadata.py index be2af376..83505b66 100644 --- a/test/decoders/test_metadata.py +++ b/test/decoders/test_metadata.py @@ -13,7 +13,6 @@ get_ffmpeg_library_versions, get_video_metadata, get_video_metadata_from_header, - scan_all_streams_to_update_metadata, VideoStreamMetadata, ) diff --git a/test/decoders/test_video_decoder.py b/test/decoders/test_video_decoder.py index d4bbfb10..395699fb 100644 --- a/test/decoders/test_video_decoder.py +++ b/test/decoders/test_video_decoder.py @@ -33,11 +33,7 @@ def test_create(self, source_kind, seek_mode): decoder = VideoDecoder(source, seek_mode=seek_mode) assert isinstance(decoder.metadata, _core.VideoStreamMetadata) - assert ( - len(decoder) - == decoder._num_frames - == 390 - ) + assert len(decoder) == decoder._num_frames == 390 assert decoder.stream_index == decoder.metadata.stream_index == 3 assert decoder.metadata.duration_seconds == pytest.approx(13.013) assert decoder.metadata.average_fps == pytest.approx(29.970029) @@ -521,7 +517,10 @@ def test_get_frames_played_at_fails(self, device, seek_mode): @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_get_frames_in_range(self, stream_index, device, seek_mode): decoder = VideoDecoder( - NASA_VIDEO.path, stream_index=stream_index, device=device, seek_mode=seek_mode + NASA_VIDEO.path, + stream_index=stream_index, + device=device, + seek_mode=seek_mode, ) # test degenerate case where we only actually get 1 frame @@ -635,7 +634,10 @@ def test_get_frames_in_range(self, stream_index, device, seek_mode): @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_dimension_order(self, dimension_order, frame_getter, device, seek_mode): decoder = VideoDecoder( - NASA_VIDEO.path, dimension_order=dimension_order, device=device, seek_mode=seek_mode + NASA_VIDEO.path, + dimension_order=dimension_order, + device=device, + seek_mode=seek_mode, ) frame = frame_getter(decoder) @@ -660,7 +662,10 @@ def test_dimension_order_fails(self): @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_get_frames_by_pts_in_range(self, stream_index, device, seek_mode): decoder = VideoDecoder( - NASA_VIDEO.path, stream_index=stream_index, device=device, seek_mode=seek_mode + NASA_VIDEO.path, + stream_index=stream_index, + device=device, + seek_mode=seek_mode, ) # Note that we are comparing the results of VideoDecoder's method: diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index 41e893fd..f12f08ac 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -528,7 +528,9 @@ def test_color_conversion_library_with_scaling( if height_scaling_factor != 1.0: assert target_height != input_video.height - filtergraph_decoder = create_from_file(str(input_video.path), seek_mode="approximate") + filtergraph_decoder = create_from_file( + str(input_video.path), seek_mode="approximate" + ) _add_video_stream( filtergraph_decoder, width=target_width, @@ -537,7 +539,9 @@ def test_color_conversion_library_with_scaling( ) filtergraph_frame0, _, _ = get_next_frame(filtergraph_decoder) - swscale_decoder = create_from_file(str(input_video.path), seek_mode="approximate") + swscale_decoder = create_from_file( + str(input_video.path), seek_mode="approximate" + ) _add_video_stream( swscale_decoder, width=target_width, From 921b82285b6c7f21f8a6dfd23eb24263e9919f2a Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Thu, 
19 Dec 2024 19:44:12 -0800 Subject: [PATCH 05/56] Default C++ tests to approximate mode --- test/decoders/VideoDecoderTest.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/decoders/VideoDecoderTest.cpp b/test/decoders/VideoDecoderTest.cpp index 1fbbd9a0..f996cd54 100644 --- a/test/decoders/VideoDecoderTest.cpp +++ b/test/decoders/VideoDecoderTest.cpp @@ -50,9 +50,9 @@ class VideoDecoderTest : public testing::TestWithParam { content_ = outputStringStream.str(); void* buffer = content_.data(); size_t length = outputStringStream.str().length(); - return VideoDecoder::createFromBuffer(buffer, length); + return VideoDecoder::createFromBuffer(buffer, length, VideoDecoder::SeekMode::approximate); } else { - return VideoDecoder::createFromFilePath(filepath); + return VideoDecoder::createFromFilePath(filepath, VideoDecoder::SeekMode::approximate); } } std::string content_; From b34928292d8bddc13503bce122743d9f0577d2f1 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Thu, 19 Dec 2024 19:44:53 -0800 Subject: [PATCH 06/56] Apply lints --- test/decoders/VideoDecoderTest.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/decoders/VideoDecoderTest.cpp b/test/decoders/VideoDecoderTest.cpp index f996cd54..bed95a84 100644 --- a/test/decoders/VideoDecoderTest.cpp +++ b/test/decoders/VideoDecoderTest.cpp @@ -50,9 +50,11 @@ class VideoDecoderTest : public testing::TestWithParam { content_ = outputStringStream.str(); void* buffer = content_.data(); size_t length = outputStringStream.str().length(); - return VideoDecoder::createFromBuffer(buffer, length, VideoDecoder::SeekMode::approximate); + return VideoDecoder::createFromBuffer( + buffer, length, VideoDecoder::SeekMode::approximate); } else { - return VideoDecoder::createFromFilePath(filepath, VideoDecoder::SeekMode::approximate); + return VideoDecoder::createFromFilePath( + filepath, VideoDecoder::SeekMode::approximate); } } std::string content_; From 081a5bb629e6624cd772ceea3354bbca253d0ce4 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 20 Dec 2024 10:16:06 -0800 Subject: [PATCH 07/56] Updated metadata; all tests pass. 
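The key behavioral change below is in getFramesPlayedByTimestampInRange(): we
only know the range is complete once we have decoded the first frame at or
past stopSeconds, so the loop deliberately decodes one frame past the range
and treats end-of-file as a normal terminator (stopSeconds may equal the
stream duration, in which case the extra decode hits EOF). A rough sketch of
that control flow in Python, where EndOfFile, seek_and_decode, and
get_next_frame are illustrative stand-ins for the C++ internals:

    class EndOfFile(Exception):
        pass

    def frames_played_in_range(decoder, start_seconds, stop_seconds):
        frame = decoder.seek_and_decode(start_seconds)  # frame playing at start
        frames = []
        eof = False
        while frame.pts_seconds < stop_seconds and not eof:
            frames.append(frame)
            try:
                frame = decoder.get_next_frame()  # may land one frame past the range
            except EndOfFile:
                eof = True
        return frames
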
--- .../decoders/_core/VideoDecoder.cpp | 72 ++++++++++++------- src/torchcodec/decoders/_core/VideoDecoder.h | 4 +- src/torchcodec/decoders/_core/_metadata.py | 45 +++++++++--- src/torchcodec/decoders/_video_decoder.py | 55 +++++--------- test/decoders/test_metadata.py | 4 +- test/samplers/test_samplers.py | 19 +---- 6 files changed, 108 insertions(+), 91 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 86ee7286..2a3b3284 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1170,7 +1170,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndexInternal( int64_t pts = getPts(streamInfo, streamMetadata, frameIndex); setCursorPtsInSeconds(ptsToSeconds(pts, streamInfo.timeBase)); - return getNextFrameOutputNoDemuxInternal(preAllocatedOutputTensor); + return getNextFrameNoDemuxInternal(preAllocatedOutputTensor); } VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( @@ -1252,19 +1252,19 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesPlayedByTimestamps( std::vector frameIndices(timestamps.size()); for (auto i = 0; i < timestamps.size(); ++i) { - auto framePts = timestamps[i]; + auto frameSeconds = timestamps[i]; TORCH_CHECK( - framePts >= minSeconds && framePts < maxSeconds, - "frame pts is " + std::to_string(framePts) + "; must be in range [" + - std::to_string(minSeconds) + ", " + std::to_string(maxSeconds) + - ")."); + frameSeconds >= minSeconds && frameSeconds < maxSeconds, + "frame pts is " + std::to_string(frameSeconds) + + "; must be in range [" + std::to_string(minSeconds) + ", " + + std::to_string(maxSeconds) + ")."); auto it = std::lower_bound( stream.allFrames.begin(), stream.allFrames.end(), - framePts, - [&stream](const FrameInfo& info, double framePts) { - return ptsToSeconds(info.nextPts, stream.timeBase) <= framePts; + frameSeconds, + [&stream](const FrameInfo& info, double frameSeconds) { + return ptsToSeconds(info.nextPts, stream.timeBase) <= frameSeconds; }); int64_t frameIndex = it - stream.allFrames.begin(); frameIndices[i] = frameIndex; @@ -1284,15 +1284,15 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesPlayedByTimestamps( BatchDecodedOutput output(timestamps.size(), options, streamMetadata); for (auto i = 0; i < timestamps.size(); ++i) { - auto framePts = timestamps[i]; + auto frameSeconds = timestamps[i]; TORCH_CHECK( - framePts >= minSeconds && framePts < maxSeconds, - "frame pts is " + std::to_string(framePts) + "; must be in range [" + - std::to_string(minSeconds) + ", " + std::to_string(maxSeconds) + - ")."); + frameSeconds >= minSeconds && frameSeconds < maxSeconds, + "frame pts is " + std::to_string(frameSeconds) + + "; must be in range [" + std::to_string(minSeconds) + ", " + + std::to_string(maxSeconds) + ")."); - DecodedOutput singleOut = - getFramePlayedAtTimestampNoDemuxInternal(framePts, output.frames[i]); + DecodedOutput singleOut = getFramePlayedAtTimestampNoDemuxInternal( + frameSeconds, output.frames[i]); output.ptsSeconds[i] = singleOut.ptsSeconds; output.durationSeconds[i] = singleOut.durationSeconds; } @@ -1462,21 +1462,43 @@ VideoDecoder::getFramesPlayedByTimestampInRange( // after the fact. We can't preallocate the final tensor because we don't // know how many frames we're going to decode up front. 
- setCursorPtsInSeconds(startSeconds); - DecodedOutput singleOut = getNextFrameNoDemux(); + DecodedOutput singleOut = + getFramePlayedAtTimestampNoDemuxInternal(startSeconds); - std::vector frames = {singleOut.frame}; - std::vector ptsSeconds = {singleOut.ptsSeconds}; - std::vector durationSeconds = {singleOut.durationSeconds}; + std::vector frames; + std::vector ptsSeconds; + std::vector durationSeconds; - while (singleOut.ptsSeconds < stopSeconds) { - singleOut = getNextFrameNoDemux(); + // Note that we only know we've decoded all frames in the range when we have + // decoded the first frame outside of the range. That is, we have to decode + // one frame past where we want to stop, and conclude from its pts that all + // of the prior frames comprises our range. That means we decode one extra + // frame; we don't return it, but we decode it. + // + // This algorithm works fine except when stopSeconds is the duration of the + // video. In that case, we're going to hit the end-of-file exception. + // + // We could avoid decoding an extra frame, and the end-of-file exception, by + // using the currently decoded frame's duration to know that the next frame + // is outside of our range. This would be more efficient. However, up until + // now we have avoided relying on a frame's duration to determine if a frame + // is played during a time window. So there is a potential TODO here where + // we relax that principle and just do the math to avoid the extra decode. + bool eof = false; + while (singleOut.ptsSeconds < stopSeconds && !eof) { frames.push_back(singleOut.frame); ptsSeconds.push_back(singleOut.ptsSeconds); durationSeconds.push_back(singleOut.durationSeconds); + + try { + singleOut = getNextFrameNoDemuxInternal(); + } catch (EndOfFileException e) { + eof = true; + } } BatchDecodedOutput output(frames, ptsSeconds, durationSeconds); + output.frames = maybePermuteHWC2CHW(streamIndex, output.frames); return output; } else { @@ -1494,12 +1516,12 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() { } VideoDecoder::DecodedOutput VideoDecoder::getNextFrameNoDemux() { - auto output = getNextFrameOutputNoDemuxInternal(); + auto output = getNextFrameNoDemuxInternal(); output.frame = maybePermuteHWC2CHW(output.streamIndex, output.frame); return output; } -VideoDecoder::DecodedOutput VideoDecoder::getNextFrameOutputNoDemuxInternal( +VideoDecoder::DecodedOutput VideoDecoder::getNextFrameNoDemuxInternal( std::optional preAllocatedOutputTensor) { auto rawOutput = getNextRawDecodedOutputNoDemux(); return convertAVFrameToDecodedOutput(rawOutput, preAllocatedOutputTensor); diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 2ecc28ac..5726b8c5 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -161,7 +161,7 @@ class VideoDecoder { // ---- SINGLE FRAME SEEK AND DECODING API ---- // Places the cursor at the first frame on or after the position in seconds. - // Calling getNextFrameOutputNoDemuxInternal() will return the first frame at + // Calling getNextFrameNoDemuxInternal() will return the first frame at // or after this position. 
void setCursorPtsInSeconds(double seconds); // This is an internal structure that is used to store the decoded output @@ -433,7 +433,7 @@ class VideoDecoder { DecodedOutput& output, std::optional preAllocatedOutputTensor = std::nullopt); - DecodedOutput getNextFrameOutputNoDemuxInternal( + DecodedOutput getNextFrameNoDemuxInternal( std::optional preAllocatedOutputTensor = std::nullopt); SeekMode seekMode_; diff --git a/src/torchcodec/decoders/_core/_metadata.py b/src/torchcodec/decoders/_core/_metadata.py index e0400fdf..424d3dbc 100644 --- a/src/torchcodec/decoders/_core/_metadata.py +++ b/src/torchcodec/decoders/_core/_metadata.py @@ -37,12 +37,12 @@ class VideoStreamMetadata: content (the scan doesn't involve decoding). This is more accurate than ``num_frames_from_header``. We recommend using the ``num_frames`` attribute instead. (int or None).""" - begin_stream_seconds: Optional[float] + begin_stream_seconds_from_content: Optional[float] """Beginning of the stream, in seconds (float or None). Conceptually, this corresponds to the first frame's :term:`pts`. It is computed as min(frame.pts) across all frames in the stream. Usually, this is equal to 0.""" - end_stream_seconds: Optional[float] + end_stream_seconds_from_content: Optional[float] """End of the stream, in seconds (float or None). Conceptually, this corresponds to last_frame.pts + last_frame.duration. It is computed as max(frame.pts + frame.duration) across all frames in the @@ -81,9 +81,15 @@ def duration_seconds(self) -> Optional[float]: from the actual frames if a :term:`scan` was performed. Otherwise we fall back to ``duration_seconds_from_header``. """ - if self.end_stream_seconds is None or self.begin_stream_seconds is None: + if ( + self.end_stream_seconds_from_content is None + or self.begin_stream_seconds_from_content is None + ): return self.duration_seconds_from_header - return self.end_stream_seconds - self.begin_stream_seconds + return ( + self.end_stream_seconds_from_content + - self.begin_stream_seconds_from_content + ) @property def average_fps(self) -> Optional[float]: @@ -92,12 +98,29 @@ def average_fps(self) -> Optional[float]: Otherwise we fall back to ``average_fps_from_header``. """ if ( - self.end_stream_seconds is None - or self.begin_stream_seconds is None + self.end_stream_seconds_from_content is None + or self.begin_stream_seconds_from_content is None or self.num_frames is None ): return self.average_fps_from_header - return self.num_frames / (self.end_stream_seconds - self.begin_stream_seconds) + return self.num_frames / ( + self.end_stream_seconds_from_content + - self.begin_stream_seconds_from_content + ) + + @property + def begin_stream_seconds(self) -> float: + """TODO.""" + if self.begin_stream_seconds_from_content is None: + return 0 + return self.begin_stream_seconds_from_content + + @property + def end_stream_seconds(self) -> Optional[float]: + """TODO.""" + if self.end_stream_seconds_from_content is None: + return self.duration_seconds + return self.end_stream_seconds_from_content def __repr__(self): # Overridden because properites are not printed by default. 
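
The fallback properties above are what let approximate mode run off header
data alone: begin_stream_seconds degrades to 0, end_stream_seconds degrades
to duration_seconds, and average_fps degrades to the header value whenever
the *_from_content fields were never populated by a scan. Spelled out as a
standalone sketch (illustrative functions mirroring the properties, not the
class itself):

    def stream_bounds(meta):
        begin = meta.begin_stream_seconds_from_content
        end = meta.end_stream_seconds_from_content
        if begin is None:
            begin = 0  # no scan: assume the stream starts at pts 0
        if end is None:
            end = meta.duration_seconds  # header-derived when unscanned
        return begin, end

    def average_fps(meta):
        if (meta.begin_stream_seconds_from_content is None
                or meta.end_stream_seconds_from_content is None
                or meta.num_frames is None):
            return meta.average_fps_from_header  # may itself be None
        begin, end = stream_bounds(meta)
        return meta.num_frames / (end - begin)
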
@@ -152,8 +175,12 @@ def get_video_metadata(decoder: torch.Tensor) -> VideoMetadata: bit_rate=stream_dict.get("bitRate"), num_frames_from_header=stream_dict.get("numFrames"), num_frames_from_content=stream_dict.get("numFramesFromScan"), - begin_stream_seconds=stream_dict.get("minPtsSecondsFromScan"), - end_stream_seconds=stream_dict.get("maxPtsSecondsFromScan"), + begin_stream_seconds_from_content=stream_dict.get( + "minPtsSecondsFromScan" + ), + end_stream_seconds_from_content=stream_dict.get( + "maxPtsSecondsFromScan" + ), codec=stream_dict.get("codec"), width=stream_dict.get("width"), height=stream_dict.get("height"), diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index 2f36d06d..a0561423 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -116,44 +116,25 @@ def __init__( self._decoder, stream_index ) - if seek_mode == "exact": - if self.metadata.num_frames_from_content is None: - raise ValueError( - "The number of frames is unknown. " + _ERROR_REPORTING_INSTRUCTIONS - ) - self._num_frames = self.metadata.num_frames_from_content - - if self.metadata.begin_stream_seconds is None: - raise ValueError( - "The minimum pts value in seconds is unknown. " - + _ERROR_REPORTING_INSTRUCTIONS - ) - self._begin_stream_seconds = self.metadata.begin_stream_seconds - - if self.metadata.end_stream_seconds is None: - raise ValueError( - "The maximum pts value in seconds is unknown. " - + _ERROR_REPORTING_INSTRUCTIONS - ) - self._end_stream_seconds = self.metadata.end_stream_seconds - elif seek_mode == "approximate": - if self.metadata.num_frames_from_header is None: - raise ValueError( - "The number of frames is unknown. " + _ERROR_REPORTING_INSTRUCTIONS - ) - self._num_frames = self.metadata.num_frames_from_header - - self._begin_stream_seconds = 0 - - if self.metadata.duration_seconds_from_header is None: - raise ValueError( - "The maximum pts value in seconds is unknown. " - + _ERROR_REPORTING_INSTRUCTIONS - ) - self._end_stream_seconds = self.metadata.duration_seconds_from_header + if self.metadata.num_frames is None: + raise ValueError( + "The number of frames is unknown. " + _ERROR_REPORTING_INSTRUCTIONS + ) + self._num_frames = self.metadata.num_frames - else: - raise ValueError(f"Invalid seek mode: {seek_mode}.") + if self.metadata.begin_stream_seconds is None: + raise ValueError( + "The minimum pts value in seconds is unknown. " + + _ERROR_REPORTING_INSTRUCTIONS + ) + self._begin_stream_seconds = self.metadata.begin_stream_seconds + + if self.metadata.end_stream_seconds is None: + raise ValueError( + "The maximum pts value in seconds is unknown. 
" + + _ERROR_REPORTING_INSTRUCTIONS + ) + self._end_stream_seconds = self.metadata.end_stream_seconds def __len__(self) -> int: return self._num_frames diff --git a/test/decoders/test_metadata.py b/test/decoders/test_metadata.py index 83505b66..8a6830ad 100644 --- a/test/decoders/test_metadata.py +++ b/test/decoders/test_metadata.py @@ -92,8 +92,8 @@ def test_num_frames_fallback( bit_rate=123, num_frames_from_header=num_frames_from_header, num_frames_from_content=num_frames_from_content, - begin_stream_seconds=0, - end_stream_seconds=4, + begin_stream_seconds_from_content=0, + end_stream_seconds_from_content=4, codec="whatever", width=123, height=321, diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py index d5c7eb44..94225574 100644 --- a/test/samplers/test_samplers.py +++ b/test/samplers/test_samplers.py @@ -590,23 +590,10 @@ def restore_metadata(): decoder.metadata = original_metadata with restore_metadata(): - decoder.metadata.begin_stream_seconds = None - with pytest.raises( - ValueError, match="Could not infer stream end and start from video metadata" - ): - sampler(decoder) - - with restore_metadata(): - decoder.metadata.end_stream_seconds = None - with pytest.raises( - ValueError, match="Could not infer stream end and start from video metadata" - ): - sampler(decoder) - - with restore_metadata(): - decoder.metadata.begin_stream_seconds = None - decoder.metadata.end_stream_seconds = None + decoder.metadata.begin_stream_seconds_from_content = None + decoder.metadata.end_stream_seconds_from_content = None decoder.metadata.average_fps_from_header = None + decoder.metadata.duration_seconds_from_header = None with pytest.raises(ValueError, match="Could not infer average fps"): sampler(decoder) From 802b881b92cd5c859cf966f7b23da9d39da670c6 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 20 Dec 2024 11:27:14 -0800 Subject: [PATCH 08/56] Removed commened out code. 
--- src/torchcodec/decoders/_core/VideoDecoder.cpp | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 2a3b3284..9cc1dd46 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1162,7 +1162,6 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndexInternal( int64_t frameIndex, std::optional preAllocatedOutputTensor) { validateUserProvidedStreamIndex(streamIndex); - // validateScannedAllStreams("getFrameAtIndex"); // converted const auto& streamInfo = streams_[streamIndex]; const auto& streamMetadata = containerMetadata_.streams[streamIndex]; @@ -1177,7 +1176,6 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( int streamIndex, const std::vector& frameIndices) { validateUserProvidedStreamIndex(streamIndex); - // validateScannedAllStreams("getFramesAtIndices"); // converted auto indicesAreSorted = std::is_sorted(frameIndices.begin(), frameIndices.end()); @@ -1236,7 +1234,6 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesPlayedByTimestamps( int streamIndex, const std::vector& timestamps) { validateUserProvidedStreamIndex(streamIndex); - // validateScannedAllStreams("getFramesPlayedByTimestamps"); // converted const auto& streamMetadata = containerMetadata_.streams[streamIndex]; const auto& stream = streams_[streamIndex]; @@ -1311,7 +1308,6 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange( int64_t stop, int64_t step) { validateUserProvidedStreamIndex(streamIndex); - // validateScannedAllStreams("getFramesInRange"); // converted const auto& streamMetadata = containerMetadata_.streams[streamIndex]; const auto& stream = streams_[streamIndex]; @@ -1345,8 +1341,6 @@ VideoDecoder::getFramesPlayedByTimestampInRange( double startSeconds, double stopSeconds) { validateUserProvidedStreamIndex(streamIndex); - // validateScannedAllStreams("getFramesPlayedByTimestampInRange"); // - // converted const auto& streamMetadata = containerMetadata_.streams[streamIndex]; TORCH_CHECK( From 911a3bcff650dadcbbee69fcfe7c8c14e4814bf3 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 20 Dec 2024 13:43:51 -0800 Subject: [PATCH 09/56] Consolidated logic for timestamp batch. Big perf win. 
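The perf win comes from funneling both seek modes through
getFramesAtIndices(): timestamps are first mapped to frame indices (a binary
search over the scan table in exact mode, floor(t * averageFps) in
approximate mode), so the existing index path can de-duplicate timestamps
that land on the same frame and decode each unique frame only once. An
illustrative Python equivalent of the mapping, with bisect standing in for
std::lower_bound:

    import bisect
    import math

    def timestamp_to_index(seconds, seek_mode, next_pts_seconds=None, average_fps=None):
        # next_pts_seconds: sorted per-frame "next frame's pts" values from the scan.
        if seek_mode == "exact":
            # Index of the first frame whose nextPts is strictly greater than
            # `seconds`, i.e. the frame being displayed at `seconds`.
            return bisect.bisect_right(next_pts_seconds, seconds)
        elif seek_mode == "approximate":
            return math.floor(seconds * average_fps)
        raise ValueError("Unknown seek mode")
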
--- .../decoders/benchmark_decoders_library.py | 2 +- .../decoders/_core/VideoDecoder.cpp | 95 +++++++++---------- src/torchcodec/decoders/_core/VideoDecoder.h | 3 + 3 files changed, 50 insertions(+), 50 deletions(-) diff --git a/benchmarks/decoders/benchmark_decoders_library.py b/benchmarks/decoders/benchmark_decoders_library.py index e8fcd7d3..b0b69dfd 100644 --- a/benchmarks/decoders/benchmark_decoders_library.py +++ b/benchmarks/decoders/benchmark_decoders_library.py @@ -349,7 +349,7 @@ def decode_and_resize(self, video_file, pts_list, height, width, device): class TorchCodecPublicNonBatch(AbstractDecoder): - def __init__(self, num_ffmpeg_threads=None, device="cpu", seek_mode="exact"): + def __init__(self, num_ffmpeg_threads=None, device="cpu", seek_mode="approximate"): self._num_ffmpeg_threads = num_ffmpeg_threads self._device = device self._seek_mode = seek_mode diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 9cc1dd46..07e25739 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1157,6 +1157,28 @@ int64_t VideoDecoder::getFramesSize( } } +double VideoDecoder::getMinSeconds(const StreamMetadata& streamMetadata) { + switch (seekMode_) { + case SeekMode::exact: + return streamMetadata.minPtsSecondsFromScan.value(); + case SeekMode::approximate: + return 0; + default: + throw std::runtime_error("Unknown SeekMode"); + } +} + +double VideoDecoder::getMaxSeconds(const StreamMetadata& streamMetadata) { + switch (seekMode_) { + case SeekMode::exact: + return streamMetadata.maxPtsSecondsFromScan.value(); + case SeekMode::approximate: + return streamMetadata.durationSeconds.value(); + default: + throw std::runtime_error("Unknown SeekMode"); + } +} + VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndexInternal( int streamIndex, int64_t frameIndex, @@ -1238,24 +1260,25 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesPlayedByTimestamps( const auto& streamMetadata = containerMetadata_.streams[streamIndex]; const auto& stream = streams_[streamIndex]; - if (seekMode_ == SeekMode::exact) { - double minSeconds = streamMetadata.minPtsSecondsFromScan.value(); - double maxSeconds = streamMetadata.maxPtsSecondsFromScan.value(); + double minSeconds = getMinSeconds(streamMetadata); + double maxSeconds = getMaxSeconds(streamMetadata); - // The frame played at timestamp t and the one played at timestamp `t + - // eps` are probably the same frame, with the same index. The easiest way to - // avoid decoding that unique frame twice is to convert the input timestamps - // to indices, and leverage the de-duplication logic of getFramesAtIndices. + // The frame played at timestamp t and the one played at timestamp `t + + // eps` are probably the same frame, with the same index. The easiest way to + // avoid decoding that unique frame twice is to convert the input timestamps + // to indices, and leverage the de-duplication logic of getFramesAtIndices. 
- std::vector frameIndices(timestamps.size()); - for (auto i = 0; i < timestamps.size(); ++i) { - auto frameSeconds = timestamps[i]; - TORCH_CHECK( - frameSeconds >= minSeconds && frameSeconds < maxSeconds, - "frame pts is " + std::to_string(frameSeconds) + - "; must be in range [" + std::to_string(minSeconds) + ", " + - std::to_string(maxSeconds) + ")."); + std::vector frameIndices(timestamps.size()); + for (auto i = 0; i < timestamps.size(); ++i) { + auto frameSeconds = timestamps[i]; + TORCH_CHECK( + frameSeconds >= minSeconds && frameSeconds < maxSeconds, + "frame pts is " + std::to_string(frameSeconds) + + "; must be in range [" + std::to_string(minSeconds) + ", " + + std::to_string(maxSeconds) + ")."); + int64_t frameIndex = -1; + if (seekMode_ == SeekMode::exact) { auto it = std::lower_bound( stream.allFrames.begin(), stream.allFrames.end(), @@ -1263,43 +1286,17 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesPlayedByTimestamps( [&stream](const FrameInfo& info, double frameSeconds) { return ptsToSeconds(info.nextPts, stream.timeBase) <= frameSeconds; }); - int64_t frameIndex = it - stream.allFrames.begin(); - frameIndices[i] = frameIndex; - } - - return getFramesAtIndices(streamIndex, frameIndices); - - } else if (seekMode_ == SeekMode::approximate) { - double minSeconds = 0; - double maxSeconds = streamMetadata.durationSeconds.value(); - - // TODO: Figure out if we can be smarter than just iterating over the - // timestamps one-by-one. - - const auto& streamMetadata = containerMetadata_.streams[streamIndex]; - const auto& options = stream.options; - BatchDecodedOutput output(timestamps.size(), options, streamMetadata); - - for (auto i = 0; i < timestamps.size(); ++i) { - auto frameSeconds = timestamps[i]; - TORCH_CHECK( - frameSeconds >= minSeconds && frameSeconds < maxSeconds, - "frame pts is " + std::to_string(frameSeconds) + - "; must be in range [" + std::to_string(minSeconds) + ", " + - std::to_string(maxSeconds) + ")."); - - DecodedOutput singleOut = getFramePlayedAtTimestampNoDemuxInternal( - frameSeconds, output.frames[i]); - output.ptsSeconds[i] = singleOut.ptsSeconds; - output.durationSeconds[i] = singleOut.durationSeconds; + frameIndex = it - stream.allFrames.begin(); + } else if (seekMode_ == SeekMode::approximate) { + frameIndex = std::floor(frameSeconds * streamMetadata.averageFps.value()); + } else { + throw std::runtime_error("Unknown SeekMode"); } - output.frames = maybePermuteHWC2CHW(streamIndex, output.frames); - return output; - - } else { - throw std::runtime_error("Unknown SeekMode"); + frameIndices[i] = frameIndex; } + + return getFramesAtIndices(streamIndex, frameIndices); } VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange( diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 5726b8c5..68bb4305 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -398,6 +398,9 @@ class VideoDecoder { const StreamMetadata& streamMetadata, int64_t frameIndex); + double getMinSeconds(const StreamMetadata& streamMetadata); + double getMaxSeconds(const StreamMetadata& streamMetadata); + void createSwsContext( StreamInfo& streamInfo, const DecodedFrameContext& frameContext, From 7267b5abb59d440654bad9ea08543e9390d4cfd8 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 20 Dec 2024 17:37:18 -0800 Subject: [PATCH 10/56] Consolidated logic for timestamp range. 
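Same consolidation, now for [startSeconds, stopSeconds) ranges:
secondsToIndexLowerBound() finds the frame playing at the start (lower bound
on nextPts) and secondsToIndexUpperBound() finds one past the last frame that
starts before the stop (upper bound on pts); approximate mode uses floor and
ceil on t * averageFps instead. A bisect-based sketch of the exact-mode pair
(names illustrative):

    import bisect

    def seconds_to_index_range(start, stop, pts, next_pts):
        # pts / next_pts: sorted per-frame start times and next-frame start times.
        lo = bisect.bisect_right(next_pts, start)  # frame displayed at `start`
        hi = bisect.bisect_left(pts, stop)         # first frame with pts >= stop
        return lo, hi  # half-open index range [lo, hi)
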
--- .../decoders/_core/VideoDecoder.cpp | 207 ++++++++---------- src/torchcodec/decoders/_core/VideoDecoder.h | 10 + 2 files changed, 96 insertions(+), 121 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 07e25739..c77bdfd8 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1179,6 +1179,52 @@ double VideoDecoder::getMaxSeconds(const StreamMetadata& streamMetadata) { } } +int64_t VideoDecoder::secondsToIndexLowerBound( + double seconds, + const StreamInfo& streamInfo, + const StreamMetadata& streamMetadata) { + switch (seekMode_) { + case SeekMode::exact: { + auto frame = std::lower_bound( + streamInfo.allFrames.begin(), + streamInfo.allFrames.end(), + seconds, + [&streamInfo](const FrameInfo& info, double start) { + return ptsToSeconds(info.nextPts, streamInfo.timeBase) <= start; + }); + + return frame - streamInfo.allFrames.begin(); + } + case SeekMode::approximate: + return std::floor(seconds * streamMetadata.averageFps.value()); + default: + throw std::runtime_error("Unknown SeekMode"); + } +} + +int64_t VideoDecoder::secondsToIndexUpperBound( + double seconds, + const StreamInfo& streamInfo, + const StreamMetadata& streamMetadata) { + switch (seekMode_) { + case SeekMode::exact: { + auto frame = std::upper_bound( + streamInfo.allFrames.begin(), + streamInfo.allFrames.end(), + seconds, + [&streamInfo](double stop, const FrameInfo& info) { + return stop <= ptsToSeconds(info.pts, streamInfo.timeBase); + }); + + return frame - streamInfo.allFrames.begin(); + } + case SeekMode::approximate: + return std::ceil(seconds * streamMetadata.averageFps.value()); + default: + throw std::runtime_error("Unknown SeekMode"); + } +} + VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndexInternal( int streamIndex, int64_t frameIndex, @@ -1372,129 +1418,48 @@ VideoDecoder::getFramesPlayedByTimestampInRange( return output; } - if (seekMode_ == SeekMode::exact) { - double minSeconds = streamMetadata.minPtsSecondsFromScan.value(); - double maxSeconds = streamMetadata.maxPtsSecondsFromScan.value(); - TORCH_CHECK( - startSeconds >= minSeconds && startSeconds < maxSeconds, - "Start seconds is " + std::to_string(startSeconds) + - "; must be in range [" + std::to_string(minSeconds) + ", " + - std::to_string(maxSeconds) + ")."); - TORCH_CHECK( - stopSeconds <= maxSeconds, - "Stop seconds (" + std::to_string(stopSeconds) + - "; must be less than or equal to " + std::to_string(maxSeconds) + - ")."); - - // Note that we look at nextPts for a frame, and not its pts or duration. - // Our abstract player displays frames starting at the pts for that frame - // until the pts for the next frame. There are two consequences: - // - // 1. We ignore the duration for a frame. A frame is played until the - // next frame replaces it. This model is robust to durations being 0 or - // incorrect; our source of truth is the pts for frames. If duration is - // accurate, the nextPts for a frame would be equivalent to pts + - // duration. - // 2. In order to establish if the start of an interval maps to a - // particular frame, we need to figure out if it is ordered after the - // frame's pts, but before the next frames's pts. 
- - auto startFrame = std::lower_bound( - stream.allFrames.begin(), - stream.allFrames.end(), - startSeconds, - [&stream](const FrameInfo& info, double start) { - return ptsToSeconds(info.nextPts, stream.timeBase) <= start; - }); - - auto stopFrame = std::upper_bound( - stream.allFrames.begin(), - stream.allFrames.end(), - stopSeconds, - [&stream](double stop, const FrameInfo& info) { - return stop <= ptsToSeconds(info.pts, stream.timeBase); - }); - - int64_t startFrameIndex = startFrame - stream.allFrames.begin(); - int64_t stopFrameIndex = stopFrame - stream.allFrames.begin(); - int64_t numFrames = stopFrameIndex - startFrameIndex; - BatchDecodedOutput output(numFrames, options, streamMetadata); - for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) { - DecodedOutput singleOut = - getFrameAtIndexInternal(streamIndex, i, output.frames[f]); - output.ptsSeconds[f] = singleOut.ptsSeconds; - output.durationSeconds[f] = singleOut.durationSeconds; - } - output.frames = maybePermuteHWC2CHW(streamIndex, output.frames); - - return output; - - } else if (seekMode_ == SeekMode::approximate) { - double minSeconds = 0; - double maxSeconds = streamMetadata.durationSeconds.value(); - TORCH_CHECK( - startSeconds >= minSeconds && startSeconds < maxSeconds, - "Start seconds is " + std::to_string(startSeconds) + - "; must be in range [" + std::to_string(minSeconds) + ", " + - std::to_string(maxSeconds) + ")."); - TORCH_CHECK( - stopSeconds <= maxSeconds, - "Stop seconds (" + std::to_string(stopSeconds) + - "; must be less than or equal to " + std::to_string(maxSeconds) + - ")."); - - // Because we can only discover when to stop by doing the actual decoding, - // we can't pre-allocate the correct dimensions for our BatchDecodedOutput; - // we don't yet know N, the number of frames. So we have to store all of the - // decoded frames in a vector, and construct the final data tensor after. - - // TODO: Figure out if there is a better of doing this. That is, we store - // everything in vectors and then call torch::stack and torch::tensor.clone - // after the fact. We can't preallocate the final tensor because we don't - // know how many frames we're going to decode up front. - + double minSeconds = getMinSeconds(streamMetadata); + double maxSeconds = getMaxSeconds(streamMetadata); + TORCH_CHECK( + startSeconds >= minSeconds && startSeconds < maxSeconds, + "Start seconds is " + std::to_string(startSeconds) + + "; must be in range [" + std::to_string(minSeconds) + ", " + + std::to_string(maxSeconds) + ")."); + TORCH_CHECK( + stopSeconds <= maxSeconds, + "Stop seconds (" + std::to_string(stopSeconds) + + "; must be less than or equal to " + std::to_string(maxSeconds) + + ")."); + + // Note that we look at nextPts for a frame, and not its pts or duration. + // Our abstract player displays frames starting at the pts for that frame + // until the pts for the next frame. There are two consequences: + // + // 1. We ignore the duration for a frame. A frame is played until the + // next frame replaces it. This model is robust to durations being 0 or + // incorrect; our source of truth is the pts for frames. If duration is + // accurate, the nextPts for a frame would be equivalent to pts + + // duration. + // 2. In order to establish if the start of an interval maps to a + // particular frame, we need to figure out if it is ordered after the + // frame's pts, but before the next frames's pts. 
+ + int64_t startFrameIndex = + secondsToIndexLowerBound(startSeconds, stream, streamMetadata); + int64_t stopFrameIndex = + secondsToIndexUpperBound(stopSeconds, stream, streamMetadata); + int64_t numFrames = stopFrameIndex - startFrameIndex; + + BatchDecodedOutput output(numFrames, options, streamMetadata); + for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) { DecodedOutput singleOut = - getFramePlayedAtTimestampNoDemuxInternal(startSeconds); - - std::vector frames; - std::vector ptsSeconds; - std::vector durationSeconds; - - // Note that we only know we've decoded all frames in the range when we have - // decoded the first frame outside of the range. That is, we have to decode - // one frame past where we want to stop, and conclude from its pts that all - // of the prior frames comprises our range. That means we decode one extra - // frame; we don't return it, but we decode it. - // - // This algorithm works fine except when stopSeconds is the duration of the - // video. In that case, we're going to hit the end-of-file exception. - // - // We could avoid decoding an extra frame, and the end-of-file exception, by - // using the currently decoded frame's duration to know that the next frame - // is outside of our range. This would be more efficient. However, up until - // now we have avoided relying on a frame's duration to determine if a frame - // is played during a time window. So there is a potential TODO here where - // we relax that principle and just do the math to avoid the extra decode. - bool eof = false; - while (singleOut.ptsSeconds < stopSeconds && !eof) { - frames.push_back(singleOut.frame); - ptsSeconds.push_back(singleOut.ptsSeconds); - durationSeconds.push_back(singleOut.durationSeconds); - - try { - singleOut = getNextFrameNoDemuxInternal(); - } catch (EndOfFileException e) { - eof = true; - } - } - - BatchDecodedOutput output(frames, ptsSeconds, durationSeconds); - output.frames = maybePermuteHWC2CHW(streamIndex, output.frames); - return output; - - } else { - throw std::runtime_error("Unknown SeekMode"); + getFrameAtIndexInternal(streamIndex, i, output.frames[f]); + output.ptsSeconds[f] = singleOut.ptsSeconds; + output.durationSeconds[f] = singleOut.durationSeconds; } + output.frames = maybePermuteHWC2CHW(streamIndex, output.frames); + + return output; } VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() { diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 68bb4305..187e199c 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -401,6 +401,16 @@ class VideoDecoder { double getMinSeconds(const StreamMetadata& streamMetadata); double getMaxSeconds(const StreamMetadata& streamMetadata); + int64_t secondsToIndexLowerBound( + double seconds, + const StreamInfo& streamInfo, + const StreamMetadata& streamMetadata); + + int64_t secondsToIndexUpperBound( + double seconds, + const StreamInfo& streamInfo, + const StreamMetadata& streamMetadata); + void createSwsContext( StreamInfo& streamInfo, const DecodedFrameContext& frameContext, From ae44f784ae64dce79b298ed89899f8bd860e648f Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 20 Dec 2024 18:06:03 -0800 Subject: [PATCH 11/56] More mode consolidation. 
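
With the timestamp-to-index mapping shared, getFramesPlayedByTimestamps no longer needs its own approximate branch: every requested timestamp becomes an index via secondsToIndexLowerBound, and the batch is delegated to getFramesAtIndices, whose de-duplication logic now covers both modes. That also makes the vector-based BatchDecodedOutput constructor, which only the old approximate path used, dead code. A minimal sketch of why the index detour pays off, using hypothetical numbers rather than code from this patch:

// Minimal sketch with hypothetical numbers, not code from this patch:
// timestamps that fall inside the same frame's display interval collapse to
// one index, so that frame is decoded once instead of once per timestamp.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const double averageFps = 25.0;  // pretend stream metadata
  const std::vector<double> timestamps = {6.000, 6.010, 6.020, 7.500};

  std::vector<int64_t> frameIndices;
  for (double t : timestamps) {
    // Approximate-mode mapping; exact mode reaches the same kind of index
    // through the scanned pts table instead.
    frameIndices.push_back(static_cast<int64_t>(std::floor(t * averageFps)));
  }

  // At 25 fps a frame is displayed for 40ms, so the first three timestamps
  // all map to index 150: only two distinct frames (150 and 187) need
  // decoding.
  for (int64_t index : frameIndices) {
    std::printf("index=%lld\n", static_cast<long long>(index));
  }
  return 0;
}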
--- .../decoders/_core/VideoDecoder.cpp | 35 ++----------------- src/torchcodec/decoders/_core/VideoDecoder.h | 5 --- 2 files changed, 2 insertions(+), 38 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index c77bdfd8..09b5b115 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -204,22 +204,6 @@ VideoDecoder::BatchDecodedOutput::BatchDecodedOutput( frames = allocateEmptyHWCTensor(height, width, options.device, numFrames); } -VideoDecoder::BatchDecodedOutput::BatchDecodedOutput( - const std::vector& inFrames, - std::vector& inPtsSeconds, - std::vector& inDurationSeconds) - : frames(torch::stack(inFrames)), - ptsSeconds(torch::from_blob( - inPtsSeconds.data(), - inPtsSeconds.size(), - {torch::kFloat64}) - .clone()), - durationSeconds(torch::from_blob( - inDurationSeconds.data(), - inDurationSeconds.size(), - {torch::kFloat64}) - .clone()) {} - bool VideoDecoder::DecodedFrameContext::operator==( const VideoDecoder::DecodedFrameContext& other) { return decodedWidth == other.decodedWidth && decodedHeight == decodedHeight && @@ -1323,23 +1307,8 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesPlayedByTimestamps( "; must be in range [" + std::to_string(minSeconds) + ", " + std::to_string(maxSeconds) + ")."); - int64_t frameIndex = -1; - if (seekMode_ == SeekMode::exact) { - auto it = std::lower_bound( - stream.allFrames.begin(), - stream.allFrames.end(), - frameSeconds, - [&stream](const FrameInfo& info, double frameSeconds) { - return ptsToSeconds(info.nextPts, stream.timeBase) <= frameSeconds; - }); - frameIndex = it - stream.allFrames.begin(); - } else if (seekMode_ == SeekMode::approximate) { - frameIndex = std::floor(frameSeconds * streamMetadata.averageFps.value()); - } else { - throw std::runtime_error("Unknown SeekMode"); - } - - frameIndices[i] = frameIndex; + frameIndices[i] = + secondsToIndexLowerBound(frameSeconds, stream, streamMetadata); } return getFramesAtIndices(streamIndex, frameIndices); diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 187e199c..8ee09ff2 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -242,11 +242,6 @@ class VideoDecoder { int64_t numFrames, const VideoStreamDecoderOptions& options, const StreamMetadata& metadata); - - explicit BatchDecodedOutput( - const std::vector& inFrames, - std::vector& inPtsSeconds, - std::vector& inDurationSeconds); }; // Returns frames at the given indices for a given stream as a single stacked From f62af46dd2b1e539ba5cb5d7ea3801db50c51be7 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 10 Jan 2025 16:29:01 -0800 Subject: [PATCH 12/56] Provide constructor param names --- src/torchcodec/decoders/_core/VideoDecoder.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index ede19e66..35c7c48a 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -53,11 +53,11 @@ class VideoDecoder { enum class SeekMode { exact, approximate }; // Creates a VideoDecoder from the video at videoFilePath. - explicit VideoDecoder(const std::string& videoFilePath, SeekMode); + explicit VideoDecoder(const std::string& videoFilePath, SeekMode seek); // Creates a VideoDecoder from a given buffer. 
Note that the buffer is not // owned by the VideoDecoder. - explicit VideoDecoder(const void* buffer, size_t length, SeekMode); + explicit VideoDecoder(const void* buffer, size_t length, SeekMode seek); static std::unique_ptr createFromFilePath( const std::string& videoFilePath, From d15dfa0f510950d6e6ecedc8136e633122c27879 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 10 Jan 2025 16:32:46 -0800 Subject: [PATCH 13/56] getFramesSize -> getNumFrames --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 16 ++++++++-------- src/torchcodec/decoders/_core/VideoDecoder.h | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 6008c69b..03815ac2 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1103,12 +1103,12 @@ void VideoDecoder::validateFrameIndex( const StreamInfo& streamInfo, const StreamMetadata& streamMetadata, int64_t frameIndex) { - int64_t framesSize = getFramesSize(streamInfo, streamMetadata); + int64_t numFrames = getNumFrames(streamInfo, streamMetadata); TORCH_CHECK( - frameIndex >= 0 && frameIndex < framesSize, + frameIndex >= 0 && frameIndex < numFrames, "Invalid frame index=" + std::to_string(frameIndex) + " for streamIndex=" + std::to_string(streamInfo.streamIndex) + - " numFrames=" + std::to_string(framesSize)); + " numFrames=" + std::to_string(numFrames)); } VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex( @@ -1134,7 +1134,7 @@ int64_t VideoDecoder::getPts( } } -int64_t VideoDecoder::getFramesSize( +int64_t VideoDecoder::getNumFrames( const StreamInfo& streamInfo, const StreamMetadata& streamMetadata) { switch (seekMode_) { @@ -1264,7 +1264,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( auto indexInOutput = indicesAreSorted ? 
f : argsort[f];
     auto indexInVideo = frameIndices[indexInOutput];
     if (indexInVideo < 0 ||
-        indexInVideo >= getFramesSize(stream, streamMetadata)) {
+        indexInVideo >= getNumFrames(stream, streamMetadata)) {
       throw std::runtime_error(
           "Invalid frame index=" + std::to_string(indexInVideo));
     }
@@ -1329,13 +1329,13 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
   const auto& streamMetadata = containerMetadata_.streams[streamIndex];
   const auto& stream = streams_[streamIndex];
-  int64_t framesSize = getFramesSize(stream, streamMetadata);
+  int64_t numFrames = getNumFrames(stream, streamMetadata);
   TORCH_CHECK(
       start >= 0, "Range start, " + std::to_string(start) + " is less than 0.");
   TORCH_CHECK(
-      stop <= framesSize,
+      stop <= numFrames,
       "Range stop, " + std::to_string(stop) +
-          ", is more than the number of frames, " + std::to_string(framesSize));
+          ", is more than the number of frames, " + std::to_string(numFrames));
   TORCH_CHECK(
       step > 0, "Step must be greater than 0; is " + std::to_string(step));
diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h
index 35c7c48a..9f8ea129 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.h
+++ b/src/torchcodec/decoders/_core/VideoDecoder.h
@@ -384,7 +384,7 @@ class VideoDecoder {
       int expectedOutputHeight,
       int expectedOutputWidth);

-  int64_t getFramesSize(
+  int64_t getNumFrames(
       const StreamInfo& streamInfo,
       const StreamMetadata& streamMetadata);

From 8879c324e1d9f76e9cf1a8b2fa10334168d1fe30 Mon Sep 17 00:00:00 2001
From: Scott Schneider
Date: Fri, 10 Jan 2025 17:20:10 -0800
Subject: [PATCH 14/56] Use seek_mode to parameterize metadata tests

---
 test/decoders/test_metadata.py | 13 +++++--------
 1 file changed, 5 insertions(+), 8 deletions(-)

diff --git a/test/decoders/test_metadata.py b/test/decoders/test_metadata.py
index 8a6830ad..e4caad37 100644
--- a/test/decoders/test_metadata.py
+++ b/test/decoders/test_metadata.py
@@ -19,11 +19,8 @@
 from ..utils import NASA_VIDEO


-def _get_video_metadata(path, with_scan: bool):
-    if with_scan:
-        decoder = create_from_file(str(path), seek_mode="exact")
-    else:
-        decoder = create_from_file(str(path), seek_mode="approximate")
+def _get_video_metadata(path, seek_mode):
+    decoder = create_from_file(str(path), seek_mode=seek_mode)
     return get_video_metadata(decoder)

@@ -31,13 +28,13 @@ def _get_video_metadata(path, with_scan: bool):
     "metadata_getter",
     (
         get_video_metadata_from_header,
-        functools.partial(_get_video_metadata, with_scan=False),
-        functools.partial(_get_video_metadata, with_scan=True),
+        functools.partial(_get_video_metadata, seek_mode="approximate"),
+        functools.partial(_get_video_metadata, seek_mode="exact"),
     ),
 )
 def test_get_metadata(metadata_getter):
     with_scan = (
-        metadata_getter.keywords["with_scan"]
+        metadata_getter.keywords["seek_mode"] == "exact"
         if isinstance(metadata_getter, functools.partial)
         else False
     )

From 8443de5bca3fb0ea3a9b2c61a60e5c749eabba67 Mon Sep 17 00:00:00 2001
From: Scott Schneider
Date: Fri, 10 Jan 2025 17:30:40 -0800
Subject: [PATCH 15/56] stream -> streamInfo

---
 src/torchcodec/decoders/_core/VideoDecoder.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h
index 9f8ea129..092565e4 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.h
+++ b/src/torchcodec/decoders/_core/VideoDecoder.h
@@ -389,7 +389,7 @@ class VideoDecoder {
   int64_t getPts(
-      const
StreamInfo& stream, + const StreamInfo& streamInfo, const StreamMetadata& streamMetadata, int64_t frameIndex); From e34ca31e07242dbab4cd8aae2e34077663186ce9 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 10 Jan 2025 18:01:37 -0800 Subject: [PATCH 16/56] seek -> seekMode --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 8 ++++---- src/torchcodec/decoders/_core/VideoDecoder.h | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 03815ac2..723a8938 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -217,16 +217,16 @@ bool VideoDecoder::DecodedFrameContext::operator!=( return !(*this == other); } -VideoDecoder::VideoDecoder(const std::string& videoFilePath, SeekMode seek) - : seekMode_(seek) { +VideoDecoder::VideoDecoder(const std::string& videoFilePath, SeekMode seekMode) + : seekMode_(seekMode) { AVInput input = createAVFormatContextFromFilePath(videoFilePath); formatContext_ = std::move(input.formatContext); initializeDecoder(); } -VideoDecoder::VideoDecoder(const void* buffer, size_t length, SeekMode seek) - : seekMode_(seek) { +VideoDecoder::VideoDecoder(const void* buffer, size_t length, SeekMode seekMode) + : seekMode_(seekMode) { TORCH_CHECK(buffer != nullptr, "Video buffer cannot be nullptr!"); AVInput input = createAVFormatContextFromBuffer(buffer, length); diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 092565e4..d6736fe3 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -53,11 +53,11 @@ class VideoDecoder { enum class SeekMode { exact, approximate }; // Creates a VideoDecoder from the video at videoFilePath. - explicit VideoDecoder(const std::string& videoFilePath, SeekMode seek); + explicit VideoDecoder(const std::string& videoFilePath, SeekMode seekMode); // Creates a VideoDecoder from a given buffer. Note that the buffer is not // owned by the VideoDecoder. 
- explicit VideoDecoder(const void* buffer, size_t length, SeekMode seek); + explicit VideoDecoder(const void* buffer, size_t length, SeekMode seekMode); static std::unique_ptr createFromFilePath( const std::string& videoFilePath, From 67c1225ce4ae742b7c8e504ce391d20efbd9c661 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 10 Jan 2025 18:09:45 -0800 Subject: [PATCH 17/56] remove getFramePlayedAtTimestampNoDemuxInternal --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 13 +++---------- src/torchcodec/decoders/_core/VideoDecoder.h | 4 ---- 2 files changed, 3 insertions(+), 14 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 723a8938..c8b60943 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1034,15 +1034,6 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU( VideoDecoder::DecodedOutput VideoDecoder::getFramePlayedAtTimestampNoDemux( double seconds) { - auto output = getFramePlayedAtTimestampNoDemuxInternal(seconds); - output.frame = maybePermuteHWC2CHW(output.streamIndex, output.frame); - return output; -} - -VideoDecoder::DecodedOutput -VideoDecoder::getFramePlayedAtTimestampNoDemuxInternal( - double seconds, - std::optional preAllocatedOutputTensor) { for (auto& [streamIndex, stream] : streams_) { double frameStartTime = ptsToSeconds(stream.currentPts, stream.timeBase); double frameEndTime = ptsToSeconds( @@ -1076,7 +1067,9 @@ VideoDecoder::getFramePlayedAtTimestampNoDemuxInternal( }); // Convert the frame to tensor. - return convertAVFrameToDecodedOutput(rawOutput, preAllocatedOutputTensor); + DecodedOutput output = convertAVFrameToDecodedOutput(rawOutput); + output.frame = maybePermuteHWC2CHW(output.streamIndex, output.frame); + return output; } void VideoDecoder::validateUserProvidedStreamIndex(uint64_t streamIndex) { diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index d6736fe3..f1be3c8b 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -411,10 +411,6 @@ class VideoDecoder { const DecodedFrameContext& frameContext, const enum AVColorSpace colorspace); - DecodedOutput getFramePlayedAtTimestampNoDemuxInternal( - double seconds, - std::optional preAllocatedOutputTensor = std::nullopt); - void maybeSeekToBeforeDesiredPts(); RawDecodedOutput getDecodedOutputWithFilter( std::function); From 06bb2c3a2e2b5e5b4eb5b3a82b158eeae8767d66 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Tue, 14 Jan 2025 11:22:22 -0800 Subject: [PATCH 18/56] Rationalize time based samplers and valid metadata --- src/torchcodec/samplers/_time_based.py | 11 ++++++----- test/samplers/test_samplers.py | 9 +++++++++ 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/src/torchcodec/samplers/_time_based.py b/src/torchcodec/samplers/_time_based.py index 2b531e53..03d11575 100644 --- a/src/torchcodec/samplers/_time_based.py +++ b/src/torchcodec/samplers/_time_based.py @@ -38,12 +38,13 @@ def _validate_params_time_based( "Could not infer average fps from video metadata. " "Try using an index-based sampler instead." ) - if ( - decoder.metadata.end_stream_seconds is None - or decoder.metadata.begin_stream_seconds is None - ): + + # Note that metadata.begin_stream_seconds is a property that will always yield a valid + # value; if it is not present in the actual metadata, the metadata object will return 0. 
+ # Hence, we do not test for it here and only test metadata.end_stream_seconds. + if decoder.metadata.end_stream_seconds is None: raise ValueError( - "Could not infer stream end and start from video metadata. " + "Could not infer stream end from video metadata. " "Try using an index-based sampler instead." ) diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py index 94225574..ee3f7276 100644 --- a/test/samplers/test_samplers.py +++ b/test/samplers/test_samplers.py @@ -589,6 +589,15 @@ def restore_metadata(): finally: decoder.metadata = original_metadata + with restore_metadata(): + decoder.metadata.end_stream_seconds_from_content = None + decoder.metadata.duration_seconds_from_header = None + decoder.metadata.duration_seconds_from_content = None + with pytest.raises( + ValueError, match="Could not infer stream end from video metadata" + ): + sampler(decoder) + with restore_metadata(): decoder.metadata.begin_stream_seconds_from_content = None decoder.metadata.end_stream_seconds_from_content = None From 8ed3c5e5c4df2b041ad7c39a6c6e6b46ac46dd53 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Tue, 14 Jan 2025 12:04:32 -0800 Subject: [PATCH 19/56] Refactor setting and using scanned number of frames --- .../decoders/_core/VideoDecoder.cpp | 55 +++++++++++-------- src/torchcodec/decoders/_core/VideoDecoder.h | 5 +- 2 files changed, 34 insertions(+), 26 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 6659fa08..c831e571 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -570,41 +570,51 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() { if (scannedAllStreams_) { return; } + while (true) { + // Get the next packet. UniqueAVPacket packet(av_packet_alloc()); int ffmpegStatus = av_read_frame(formatContext_.get(), packet.get()); + if (ffmpegStatus == AVERROR_EOF) { break; } + if (ffmpegStatus != AVSUCCESS) { throw std::runtime_error( "Failed to read frame from input file: " + getFFMPEGErrorStringFromErrorCode(ffmpegStatus)); } - int streamIndex = packet->stream_index; if (packet->flags & AV_PKT_FLAG_DISCARD) { continue; } - auto& stream = containerMetadata_.streams[streamIndex]; - stream.minPtsFromScan = - std::min(stream.minPtsFromScan.value_or(INT64_MAX), packet->pts); - stream.maxPtsFromScan = std::max( - stream.maxPtsFromScan.value_or(INT64_MIN), - packet->pts + packet->duration); - stream.numFramesFromScan = stream.numFramesFromScan.value_or(0) + 1; - FrameInfo frameInfo; - frameInfo.pts = packet->pts; + // We got a valid packet. Let's figure out what stream it belongs to and + // record its relevant metadata. + int streamIndex = packet->stream_index; + auto& streamMetadata = containerMetadata_.streams[streamIndex]; + streamMetadata.minPtsFromScan = std::min( + streamMetadata.minPtsFromScan.value_or(INT64_MAX), packet->pts); + streamMetadata.maxPtsFromScan = std::max( + streamMetadata.maxPtsFromScan.value_or(INT64_MIN), + packet->pts + packet->duration); + FrameInfo frameInfo{.pts = packet->pts}; if (packet->flags & AV_PKT_FLAG_KEY) { streams_[streamIndex].keyFrames.push_back(frameInfo); } streams_[streamIndex].allFrames.push_back(frameInfo); } + + // Set all per-stream metadata that requires knowing the content of all + // packets. 
for (int i = 0; i < containerMetadata_.streams.size(); ++i) { auto& streamMetadata = containerMetadata_.streams[i]; auto stream = formatContext_->streams[i]; + + streamMetadata.numFramesFromScan = streams_[i].allFrames.size(); + if (streamMetadata.minPtsFromScan.has_value()) { streamMetadata.minPtsSecondsFromScan = *streamMetadata.minPtsFromScan * av_q2d(stream->time_base); @@ -614,6 +624,8 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() { *streamMetadata.maxPtsFromScan * av_q2d(stream->time_base); } } + + // Reset the seek-cursor back to the beginning. int ffmepgStatus = avformat_seek_file(formatContext_.get(), 0, INT64_MIN, 0, 0, 0); if (ffmepgStatus < 0) { @@ -621,6 +633,8 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() { "Could not seek file to pts=0: " + getFFMPEGErrorStringFromErrorCode(ffmepgStatus)); } + + // Sort all frames by their pts. for (auto& [streamIndex, stream] : streams_) { std::sort( stream.keyFrames.begin(), @@ -641,6 +655,7 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() { } } } + scannedAllStreams_ = true; } @@ -1098,14 +1113,13 @@ void VideoDecoder::validateScannedAllStreams(const std::string& msg) { } void VideoDecoder::validateFrameIndex( - const StreamInfo& streamInfo, const StreamMetadata& streamMetadata, int64_t frameIndex) { - int64_t numFrames = getNumFrames(streamInfo, streamMetadata); + int64_t numFrames = getNumFrames(streamMetadata); TORCH_CHECK( frameIndex >= 0 && frameIndex < numFrames, "Invalid frame index=" + std::to_string(frameIndex) + - " for streamIndex=" + std::to_string(streamInfo.streamIndex) + + " for streamIndex=" + std::to_string(streamMetadata.streamIndex) + " numFrames=" + std::to_string(numFrames)); } @@ -1132,12 +1146,10 @@ int64_t VideoDecoder::getPts( } } -int64_t VideoDecoder::getNumFrames( - const StreamInfo& streamInfo, - const StreamMetadata& streamMetadata) { +int64_t VideoDecoder::getNumFrames(const StreamMetadata& streamMetadata) { switch (seekMode_) { case SeekMode::exact: - return streamInfo.allFrames.size(); + return streamMetadata.numFramesFromScan.value(); case SeekMode::approximate: return streamMetadata.numFrames.value(); default: @@ -1221,7 +1233,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndexInternal( const auto& streamInfo = streams_[streamIndex]; const auto& streamMetadata = containerMetadata_.streams[streamIndex]; - validateFrameIndex(streamInfo, streamMetadata, frameIndex); + validateFrameIndex(streamMetadata, frameIndex); int64_t pts = getPts(streamInfo, streamMetadata, frameIndex); setCursorPtsInSeconds(ptsToSeconds(pts, streamInfo.timeBase)); @@ -1261,8 +1273,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( for (auto f = 0; f < frameIndices.size(); ++f) { auto indexInOutput = indicesAreSorted ? 
f : argsort[f]; auto indexInVideo = frameIndices[indexInOutput]; - if (indexInVideo < 0 || - indexInVideo >= getNumFrames(stream, streamMetadata)) { + if (indexInVideo < 0 || indexInVideo >= getNumFrames(streamMetadata)) { throw std::runtime_error( "Invalid frame index=" + std::to_string(indexInVideo)); } @@ -1327,7 +1338,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange( const auto& streamMetadata = containerMetadata_.streams[streamIndex]; const auto& stream = streams_[streamIndex]; - int64_t numFrames = getNumFrames(stream, streamMetadata); + int64_t numFrames = getNumFrames(streamMetadata); TORCH_CHECK( start >= 0, "Range start, " + std::to_string(start) + " is less than 0."); TORCH_CHECK( @@ -1476,7 +1487,7 @@ double VideoDecoder::getPtsSecondsForFrame( const auto& streamInfo = streams_[streamIndex]; const auto& streamMetadata = containerMetadata_.streams[streamIndex]; - validateFrameIndex(streamInfo, streamMetadata, frameIndex); + validateFrameIndex(streamMetadata, frameIndex); return ptsToSeconds( streamInfo.allFrames[frameIndex].pts, streamInfo.timeBase); diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index f1be3c8b..5b5e79b2 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -373,7 +373,6 @@ class VideoDecoder { void validateUserProvidedStreamIndex(uint64_t streamIndex); void validateScannedAllStreams(const std::string& msg); void validateFrameIndex( - const StreamInfo& streamInfo, const StreamMetadata& streamMetadata, int64_t frameIndex); @@ -384,9 +383,7 @@ class VideoDecoder { int expectedOutputHeight, int expectedOutputWidth); - int64_t getNumFrames( - const StreamInfo& streamInfo, - const StreamMetadata& streamMetadata); + int64_t getNumFrames(const StreamMetadata& streamMetadata); int64_t getPts( const StreamInfo& streamInfo, From bc10db838b109de238eb8730c6a420ae2a13ab9f Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Tue, 14 Jan 2025 12:13:49 -0800 Subject: [PATCH 20/56] Tweak FrameInfo struct initialization --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index c831e571..3436a95b 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -600,7 +600,9 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() { streamMetadata.maxPtsFromScan.value_or(INT64_MIN), packet->pts + packet->duration); - FrameInfo frameInfo{.pts = packet->pts}; + // Note that we set the other value in this struct, nextPts, only after + // we have scanned all packets and sorted by pts. 
+    FrameInfo frameInfo = {.pts = packet->pts};
     if (packet->flags & AV_PKT_FLAG_KEY) {
       streams_[streamIndex].keyFrames.push_back(frameInfo);
     }

From a6a2b6a4afd104c640d3ee78c89b8738c2c2c295 Mon Sep 17 00:00:00 2001
From: Scott Schneider
Date: Tue, 14 Jan 2025 12:28:40 -0800
Subject: [PATCH 21/56] Use validateFrameIndex in getFramesAtIndices

---
 src/torchcodec/decoders/_core/VideoDecoder.cpp | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp
index 3436a95b..7e2c1001 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -1275,10 +1275,8 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
   for (auto f = 0; f < frameIndices.size(); ++f) {
     auto indexInOutput = indicesAreSorted ? f : argsort[f];
     auto indexInVideo = frameIndices[indexInOutput];
-    if (indexInVideo < 0 || indexInVideo >= getNumFrames(streamMetadata)) {
-      throw std::runtime_error(
-          "Invalid frame index=" + std::to_string(indexInVideo));
-    }
+    validateFrameIndex(streamMetadata, indexInVideo);
+
     if ((f > 0) && (indexInVideo == previousIndexInVideo)) {
       // Avoid decoding the same frame twice
       auto previousIndexInOutput = indicesAreSorted ? f - 1 : argsort[f - 1];

From 3cd6842abb3bf834b2aea5cf2c0794899cbd4242 Mon Sep 17 00:00:00 2001
From: Scott Schneider
Date: Tue, 14 Jan 2025 15:32:01 -0500
Subject: [PATCH 22/56] Update src/torchcodec/decoders/_video_decoder.py

Co-authored-by: Nicolas Hug
---
 src/torchcodec/decoders/_video_decoder.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py
index a0561423..ff1e98e6 100644
--- a/src/torchcodec/decoders/_video_decoder.py
+++ b/src/torchcodec/decoders/_video_decoder.py
@@ -50,10 +50,10 @@ class VideoDecoder:
             Passing 0 lets FFmpeg decide on the number of threads. Default: 1.
         device (str or torch.device, optional): The device to use for decoding. Default: "cpu".
-        seek_mode (str, optional): Determines if index-based frame access will be "exact" or
+        seek_mode (str, optional): Determines if frame access will be "exact" or
             "approximate". Exact guarantees that requesting frame i will always returns frame i,
-            but doing so requires an initial scan of the file. Approximate avoids scanning the
-            file, but uses the file's metadata to calculate where i probably is. Default: "exact".
+            but doing so requires an initial :term:`scan` of the file. Approximate is faster as it avoids scanning the
+            file, but less accurate uses the file's metadata to calculate where i probably is. Default: "exact".

     Attributes:

From 9a4640406da4a18c1a72e8f6c39d66ad48d53719 Mon Sep 17 00:00:00 2001
From: Scott Schneider
Date: Tue, 14 Jan 2025 12:33:17 -0800
Subject: [PATCH 23/56] Tweak VideoDecoder doc string

---
 src/torchcodec/decoders/_video_decoder.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py
index ff1e98e6..46279135 100644
--- a/src/torchcodec/decoders/_video_decoder.py
+++ b/src/torchcodec/decoders/_video_decoder.py
@@ -51,9 +51,9 @@ class VideoDecoder:
             Default: 1.
         device (str or torch.device, optional): The device to use for decoding. Default: "cpu".
         seek_mode (str, optional): Determines if frame access will be "exact" or
-            "approximate".
Exact guarantees that requesting frame i will always returns frame i, + "approximate". Exact guarantees that requesting frame i will always return frame i, but doing so requires an initial :term:`scan` of the file. Approximate is faster as it avoids scanning the - file, but less accurate uses the file's metadata to calculate where i probably is. Default: "exact". + file, but less accurate as it uses the file's metadata to calculate where i probably is. Default: "exact". Attributes: From f4c001bf1e82d657de5a2bd9002dc45a9d5dfd63 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Wed, 15 Jan 2025 12:39:30 -0800 Subject: [PATCH 24/56] Remove comment --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 7e2c1001..7247c952 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1483,7 +1483,7 @@ double VideoDecoder::getPtsSecondsForFrame( int streamIndex, int64_t frameIndex) { validateUserProvidedStreamIndex(streamIndex); - validateScannedAllStreams("getPtsSecondsForFrame"); // keeping? + validateScannedAllStreams("getPtsSecondsForFrame"); const auto& streamInfo = streams_[streamIndex]; const auto& streamMetadata = containerMetadata_.streams[streamIndex]; From abad57b830bedf7be8c532e1d3818d993215aee1 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 17 Jan 2025 08:17:04 -0800 Subject: [PATCH 25/56] FrameInfo struct initialization --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 55c95f30..92577fad 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -603,7 +603,7 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() { // Note that we set the other value in this struct, nextPts, only after // we have scanned all packets and sorted by pts. 
- FrameInfo frameInfo = {.pts = packet->pts}; + FrameInfo frameInfo = {packet->pts}; if (packet->flags & AV_PKT_FLAG_KEY) { streams_[streamIndex].keyFrames.push_back(frameInfo); } From 737e1b6a8568609fde01d2222cec64eb96519636 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Tue, 21 Jan 2025 07:48:19 -0800 Subject: [PATCH 26/56] Remove explicit setting of seek_mode in unrelated tests --- test/decoders/test_video_decoder_ops.py | 48 ++++++++++++------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index f12f08ac..f1c68938 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -65,7 +65,7 @@ def seek(self, pts: float): class TestOps: @pytest.mark.parametrize("device", cpu_and_cuda()) def test_seek_and_next(self, device): - decoder = create_from_file(str(NASA_VIDEO.path), seek_mode="approximate") + decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder, device=device) frame0, _, _ = get_next_frame(decoder) reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0) @@ -82,7 +82,7 @@ def test_seek_and_next(self, device): @pytest.mark.parametrize("device", cpu_and_cuda()) def test_seek_to_negative_pts(self, device): - decoder = create_from_file(str(NASA_VIDEO.path), seek_mode="exact") + decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder, device=device) frame0, _, _ = get_next_frame(decoder) reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0) @@ -94,7 +94,7 @@ def test_seek_to_negative_pts(self, device): @pytest.mark.parametrize("device", cpu_and_cuda()) def test_get_frame_at_pts(self, device): - decoder = create_from_file(str(NASA_VIDEO.path), seek_mode="approximate") + decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder, device=device) # This frame has pts=6.006 and duration=0.033367, so it should be visible # at timestamps in the range [6.006, 6.039367) (not including the last timestamp). 
@@ -118,7 +118,7 @@ def test_get_frame_at_pts(self, device): @pytest.mark.parametrize("device", cpu_and_cuda()) def test_get_frame_at_index(self, device): - decoder = create_from_file(str(NASA_VIDEO.path), seek_mode="exact") + decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder, device=device) frame0, _, _ = get_frame_at_index(decoder, stream_index=3, frame_index=0) reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0) @@ -132,7 +132,7 @@ def test_get_frame_at_index(self, device): @pytest.mark.parametrize("device", cpu_and_cuda()) def test_get_frame_with_info_at_index(self, device): - decoder = create_from_file(str(NASA_VIDEO.path), seek_mode="exact") + decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder, device=device) frame6, pts, duration = get_frame_at_index( decoder, stream_index=3, frame_index=180 @@ -146,7 +146,7 @@ def test_get_frame_with_info_at_index(self, device): @pytest.mark.parametrize("device", cpu_and_cuda()) def test_get_frames_at_indices(self, device): - decoder = create_from_file(str(NASA_VIDEO.path), seek_mode="exact") + decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder, device=device) frames0and180, *_ = get_frames_at_indices( decoder, stream_index=3, frame_indices=[0, 180] @@ -160,7 +160,7 @@ def test_get_frames_at_indices(self, device): @pytest.mark.parametrize("device", cpu_and_cuda()) def test_get_frames_at_indices_unsorted_indices(self, device): - decoder = create_from_file(str(NASA_VIDEO.path), seek_mode="exact") + decoder = create_from_file(str(NASA_VIDEO.path)) _add_video_stream(decoder, device=device) stream_index = 3 @@ -191,7 +191,7 @@ def test_get_frames_at_indices_unsorted_indices(self, device): @pytest.mark.parametrize("device", cpu_and_cuda()) def test_get_frames_by_pts(self, device): - decoder = create_from_file(str(NASA_VIDEO.path), seek_mode="exact") + decoder = create_from_file(str(NASA_VIDEO.path)) _add_video_stream(decoder, device=device) stream_index = 3 @@ -225,7 +225,7 @@ def test_pts_apis_against_index_ref(self, device): # Get all frames in the video, then query all frames with all time-based # APIs exactly where those frames are supposed to start. We assert that # we get the expected frame. 
- decoder = create_from_file(str(NASA_VIDEO.path), seek_mode="exact") + decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder, device=device) metadata = get_json_metadata(decoder) @@ -281,7 +281,7 @@ def test_pts_apis_against_index_ref(self, device): @pytest.mark.parametrize("device", cpu_and_cuda()) def test_get_frames_in_range(self, device): - decoder = create_from_file(str(NASA_VIDEO.path), seek_mode="exact") + decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder, device=device) # ensure that the degenerate case of a range of size 1 works @@ -331,7 +331,7 @@ def test_get_frames_in_range(self, device): @pytest.mark.parametrize("device", cpu_and_cuda()) def test_throws_exception_at_eof(self, device): - decoder = create_from_file(str(NASA_VIDEO.path), seek_mode="approximate") + decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder, device=device) seek_to_pts(decoder, 12.979633) last_frame, _, _ = get_next_frame(decoder) @@ -342,7 +342,7 @@ def test_throws_exception_at_eof(self, device): @pytest.mark.parametrize("device", cpu_and_cuda()) def test_throws_exception_if_seek_too_far(self, device): - decoder = create_from_file(str(NASA_VIDEO.path), seek_mode="approximate") + decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder, device=device) # pts=12.979633 is the last frame in the video. seek_to_pts(decoder, 12.979633 + 1.0e-4) @@ -364,7 +364,7 @@ def get_frame1_and_frame_time6(decoder): # NB: create needs to happen outside the torch.compile region, # for now. Otherwise torch.compile constant-props it. - decoder = create_from_file(str(NASA_VIDEO.path), seek_mode="approximate") + decoder = create_from_file(str(NASA_VIDEO.path)) frame0, frame_time6 = get_frame1_and_frame_time6(decoder) reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0) reference_frame_time6 = NASA_VIDEO.get_frame_data_by_index( @@ -442,7 +442,7 @@ def test_video_get_json_metadata(self): assert metadata_dict["bitRate"] == 324915.0 def test_video_get_json_metadata_with_stream(self): - decoder = create_from_file(str(NASA_VIDEO.path), seek_mode="exact") + decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder) metadata = get_json_metadata(decoder) metadata_dict = json.loads(metadata) @@ -469,7 +469,7 @@ def test_get_ffmpeg_version(self): assert "ffmpeg_version" in ffmpeg_dict def test_frame_pts_equality(self): - decoder = create_from_file(str(NASA_VIDEO.path), seek_mode="exact") + decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder) # Note that for all of these tests, we store the return value of @@ -489,7 +489,7 @@ def test_frame_pts_equality(self): @pytest.mark.parametrize("color_conversion_library", ("filtergraph", "swscale")) def test_color_conversion_library(self, color_conversion_library): - decoder = create_from_file(str(NASA_VIDEO.path), seek_mode="approximate") + decoder = create_from_file(str(NASA_VIDEO.path)) _add_video_stream(decoder, color_conversion_library=color_conversion_library) frame0, *_ = get_next_frame(decoder) reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0) @@ -514,7 +514,7 @@ def test_color_conversion_library(self, color_conversion_library): def test_color_conversion_library_with_scaling( self, input_video, width_scaling_factor, height_scaling_factor ): - decoder = create_from_file(str(input_video.path), seek_mode="exact") + decoder = create_from_file(str(input_video.path)) add_video_stream(decoder) metadata = get_json_metadata(decoder) metadata_dict = json.loads(metadata) 
@@ -529,7 +529,7 @@ def test_color_conversion_library_with_scaling( assert target_height != input_video.height filtergraph_decoder = create_from_file( - str(input_video.path), seek_mode="approximate" + str(input_video.path) ) _add_video_stream( filtergraph_decoder, @@ -540,7 +540,7 @@ def test_color_conversion_library_with_scaling( filtergraph_frame0, _, _ = get_next_frame(filtergraph_decoder) swscale_decoder = create_from_file( - str(input_video.path), seek_mode="approximate" + str(input_video.path) ) _add_video_stream( swscale_decoder, @@ -556,7 +556,7 @@ def test_color_conversion_library_with_scaling( def test_color_conversion_library_with_dimension_order( self, dimension_order, color_conversion_library ): - decoder = create_from_file(str(NASA_VIDEO.path), seek_mode="exact") + decoder = create_from_file(str(NASA_VIDEO.path)) _add_video_stream( decoder, color_conversion_library=color_conversion_library, @@ -641,7 +641,7 @@ def test_color_conversion_library_with_generated_videos( ] subprocess.check_call(command) - decoder = create_from_file(str(video_path), seek_mode="exact") + decoder = create_from_file(str(video_path)) add_video_stream(decoder) metadata = get_json_metadata(decoder) metadata_dict = json.loads(metadata) @@ -655,7 +655,7 @@ def test_color_conversion_library_with_generated_videos( if height_scaling_factor != 1.0: assert target_height != height - filtergraph_decoder = create_from_file(str(video_path), seek_mode="approximate") + filtergraph_decoder = create_from_file(str(video_path)) _add_video_stream( filtergraph_decoder, width=target_width, @@ -664,7 +664,7 @@ def test_color_conversion_library_with_generated_videos( ) filtergraph_frame0, _, _ = get_next_frame(filtergraph_decoder) - auto_decoder = create_from_file(str(video_path), seek_mode="approximate") + auto_decoder = create_from_file(str(video_path)) add_video_stream( auto_decoder, width=target_width, @@ -675,7 +675,7 @@ def test_color_conversion_library_with_generated_videos( @needs_cuda def test_cuda_decoder(self): - decoder = create_from_file(str(NASA_VIDEO.path), seek_mode="exact") + decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder, device="cuda") frame0, pts, duration = get_next_frame(decoder) assert frame0.device.type == "cuda" From 32a0f8fa0bc7606c2f2920d045f19639b124b28b Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 22 Jan 2025 11:03:41 +0000 Subject: [PATCH 27/56] Handle stream names --- .../decoders/_core/VideoDecoder.cpp | 158 +++++++++--------- src/torchcodec/decoders/_core/VideoDecoder.h | 2 +- 2 files changed, 80 insertions(+), 80 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 18089a64..24098393 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -124,7 +124,7 @@ VideoDecoder::ColorConversionLibrary getDefaultColorConversionLibrary( torch::Tensor VideoDecoder::maybePermuteHWC2CHW( int streamIndex, torch::Tensor& hwcTensor) { - if (streams_[streamIndex].options.dimensionOrder == "NHWC") { + if (streamInfos_[streamIndex].options.dimensionOrder == "NHWC") { return hwcTensor; } auto numDimensions = hwcTensor.dim(); @@ -195,10 +195,10 @@ VideoDecoder::VideoStreamDecoderOptions::VideoStreamDecoderOptions( VideoDecoder::BatchDecodedOutput::BatchDecodedOutput( int64_t numFrames, const VideoStreamDecoderOptions& options, - const StreamMetadata& metadata) + const StreamMetadata& streamMetadata) : ptsSeconds(torch::empty({numFrames}, 
{torch::kFloat64})), durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) { - auto frameDims = getHeightAndWidthFromOptionsOrMetadata(options, metadata); + auto frameDims = getHeightAndWidthFromOptionsOrMetadata(options, streamMetadata); int height = frameDims.height; int width = frameDims.width; frames = allocateEmptyHWCTensor(height, width, options.device, numFrames); @@ -253,40 +253,40 @@ void VideoDecoder::initializeDecoder() { } for (unsigned int i = 0; i < formatContext_->nb_streams; i++) { - AVStream* stream = formatContext_->streams[i]; - StreamMetadata meta; + AVStream* avStream = formatContext_->streams[i]; + StreamMetadata streamMetadata; TORCH_CHECK( - static_cast(i) == stream->index, + static_cast(i) == avStream->index, "Our stream index, " + std::to_string(i) + ", does not match AVStream's index, " + - std::to_string(stream->index) + "."); - meta.streamIndex = i; - meta.mediaType = stream->codecpar->codec_type; - meta.codecName = avcodec_get_name(stream->codecpar->codec_id); - meta.bitRate = stream->codecpar->bit_rate; + std::to_string(avStream->index) + "."); + streamMetadata.streamIndex = i; + streamMetadata.mediaType = avStream->codecpar->codec_type; + streamMetadata.codecName = avcodec_get_name(avStream->codecpar->codec_id); + streamMetadata.bitRate = avStream->codecpar->bit_rate; - int64_t frameCount = stream->nb_frames; + int64_t frameCount = avStream->nb_frames; if (frameCount > 0) { - meta.numFrames = frameCount; + streamMetadata.numFrames = frameCount; } - if (stream->duration > 0 && stream->time_base.den > 0) { - meta.durationSeconds = av_q2d(stream->time_base) * stream->duration; + if (avStream->duration > 0 && avStream->time_base.den > 0) { + streamMetadata.durationSeconds = av_q2d(avStream->time_base) * avStream->duration; } - double fps = av_q2d(stream->r_frame_rate); + double fps = av_q2d(avStream->r_frame_rate); if (fps > 0) { - meta.averageFps = fps; + streamMetadata.averageFps = fps; } - if (stream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO) { + if (avStream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO) { containerMetadata_.numVideoStreams++; - } else if (stream->codecpar->codec_type == AVMEDIA_TYPE_AUDIO) { + } else if (avStream->codecpar->codec_type == AVMEDIA_TYPE_AUDIO) { containerMetadata_.numAudioStreams++; } - containerMetadata_.streams.push_back(meta); + containerMetadata_.streams.push_back(streamMetadata); } if (formatContext_->duration > 0) { @@ -437,50 +437,50 @@ void VideoDecoder::createFilterGraph( int VideoDecoder::getBestStreamIndex(AVMediaType mediaType) { AVCodecPtr codec = nullptr; - int streamNumber = + int streamIndex = av_find_best_stream(formatContext_.get(), mediaType, -1, -1, &codec, 0); - return streamNumber; + return streamIndex; } void VideoDecoder::addVideoStreamDecoder( - int preferredStreamNumber, + int preferredStreamIndex, const VideoStreamDecoderOptions& options) { - if (activeStreamIndices_.count(preferredStreamNumber) > 0) { + if (activeStreamIndices_.count(preferredStreamIndex) > 0) { throw std::invalid_argument( - "Stream with index " + std::to_string(preferredStreamNumber) + + "Stream with index " + std::to_string(preferredStreamIndex) + " is already active."); } TORCH_CHECK(formatContext_.get() != nullptr); AVCodecPtr codec = nullptr; - int streamNumber = av_find_best_stream( + int streamIndex = av_find_best_stream( formatContext_.get(), AVMEDIA_TYPE_VIDEO, - preferredStreamNumber, + preferredStreamIndex, -1, &codec, 0); - if (streamNumber < 0) { + if (streamIndex < 0) { throw std::invalid_argument("No valid 
stream found in input file."); } TORCH_CHECK(codec != nullptr); - StreamMetadata& streamMetadata = containerMetadata_.streams[streamNumber]; + StreamMetadata& streamMetadata = containerMetadata_.streams[streamIndex]; if (seekMode_ == SeekMode::approximate && !streamMetadata.averageFps.has_value()) { throw std::runtime_error( - "Seek mode is approximate, but stream " + std::to_string(streamNumber) + + "Seek mode is approximate, but stream " + std::to_string(streamIndex) + " does not have an average fps in its metadata."); } - StreamInfo& streamInfo = streams_[streamNumber]; - streamInfo.streamIndex = streamNumber; - streamInfo.timeBase = formatContext_->streams[streamNumber]->time_base; - streamInfo.stream = formatContext_->streams[streamNumber]; + StreamInfo& streamInfo = streamInfos_[streamIndex]; + streamInfo.streamIndex = streamIndex; + streamInfo.timeBase = formatContext_->streams[streamIndex]->time_base; + streamInfo.stream = formatContext_->streams[streamIndex]; if (streamInfo.stream->codecpar->codec_type != AVMEDIA_TYPE_VIDEO) { throw std::invalid_argument( - "Stream with index " + std::to_string(streamNumber) + + "Stream with index " + std::to_string(streamIndex) + " is not a video stream."); } @@ -512,7 +512,7 @@ void VideoDecoder::addVideoStreamDecoder( } codecContext->time_base = streamInfo.stream->time_base; - activeStreamIndices_.insert(streamNumber); + activeStreamIndices_.insert(streamIndex); updateMetadataWithCodecContext(streamInfo.streamIndex, codecContext); streamInfo.options = options; @@ -607,9 +607,9 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() { // we have scanned all packets and sorted by pts. FrameInfo frameInfo = {packet->pts}; if (packet->flags & AV_PKT_FLAG_KEY) { - streams_[streamIndex].keyFrames.push_back(frameInfo); + streamInfos_[streamIndex].keyFrames.push_back(frameInfo); } - streams_[streamIndex].allFrames.push_back(frameInfo); + streamInfos_[streamIndex].allFrames.push_back(frameInfo); } // Set all per-stream metadata that requires knowing the content of all @@ -619,7 +619,7 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() { auto& streamMetadata = containerMetadata_.streams[streamIndex]; auto avStream = formatContext_->streams[streamIndex]; - streamMetadata.numFramesFromScan = streams_[streamIndex].allFrames.size(); + streamMetadata.numFramesFromScan = streamInfos_[streamIndex].allFrames.size(); if (streamMetadata.minPtsFromScan.has_value()) { streamMetadata.minPtsSecondsFromScan = @@ -641,7 +641,7 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() { } // Sort all frames by their pts. - for (auto& [streamIndex, streamInfo] : streams_) { + for (auto& [streamIndex, streamInfo] : streamInfos_) { std::sort( streamInfo.keyFrames.begin(), streamInfo.keyFrames.end(), @@ -732,7 +732,7 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() { return; } for (int streamIndex : activeStreamIndices_) { - StreamInfo& streamInfo = streams_[streamIndex]; + StreamInfo& streamInfo = streamInfos_[streamIndex]; // clang-format off: clang format clashes streamInfo.discardFramesBeforePts = secondsToClosestPts(*maybeDesiredPts_, streamInfo.timeBase); // clang-format on @@ -743,7 +743,7 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() { // works. 
bool mustSeek = false; for (int streamIndex : activeStreamIndices_) { - StreamInfo& streamInfo = streams_[streamIndex]; + StreamInfo& streamInfo = streamInfos_[streamIndex]; int64_t desiredPtsForStream = *maybeDesiredPts_ * streamInfo.timeBase.den; if (!canWeAvoidSeekingForStream( streamInfo, streamInfo.currentPts, desiredPtsForStream)) { @@ -756,7 +756,7 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() { return; } int firstActiveStreamIndex = *activeStreamIndices_.begin(); - const auto& firstStreamInfo = streams_[firstActiveStreamIndex]; + const auto& firstStreamInfo = streamInfos_[firstActiveStreamIndex]; int64_t desiredPts = secondsToClosestPts(*maybeDesiredPts_, firstStreamInfo.timeBase); @@ -786,7 +786,7 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() { } decodeStats_.numFlushes++; for (int streamIndex : activeStreamIndices_) { - StreamInfo& streamInfo = streams_[streamIndex]; + StreamInfo& streamInfo = streamInfos_[streamIndex]; avcodec_flush_buffers(streamInfo.codecContext.get()); } } @@ -810,7 +810,7 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getDecodedOutputWithFilter( frameStreamIndex = -1; bool gotPermanentErrorOnAnyActiveStream = false; for (int streamIndex : activeStreamIndices_) { - StreamInfo& streamInfo = streams_[streamIndex]; + StreamInfo& streamInfo = streamInfos_[streamIndex]; ffmpegStatus = avcodec_receive_frame(streamInfo.codecContext.get(), frame.get()); bool gotNonRetriableError = @@ -849,7 +849,7 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getDecodedOutputWithFilter( // End of file reached. We must drain all codecs by sending a nullptr // packet. for (int streamIndex : activeStreamIndices_) { - StreamInfo& streamInfo = streams_[streamIndex]; + StreamInfo& streamInfo = streamInfos_[streamIndex]; ffmpegStatus = avcodec_send_packet( streamInfo.codecContext.get(), /*avpkt=*/nullptr); @@ -872,7 +872,7 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getDecodedOutputWithFilter( continue; } ffmpegStatus = avcodec_send_packet( - streams_[packet->stream_index].codecContext.get(), packet.get()); + streamInfos_[packet->stream_index].codecContext.get(), packet.get()); decodeStats_.numPacketsSentToDecoder++; if (ffmpegStatus < AVSUCCESS) { throw std::runtime_error( @@ -896,9 +896,9 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getDecodedOutputWithFilter( // haven't received as frames. Eventually we will either hit AVERROR_EOF from // av_receive_frame() or the user will have seeked to a different location in // the file and that will flush the decoder. 
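// For readers new to libavcodec: the loop above is an instance of FFmpeg's
// standard send/receive decoding protocol. Stripped of the per-stream
// bookkeeping and frame filtering, the core pattern is roughly this generic
// sketch (fmtCtx/codecCtx assumed already opened; error handling trimmed):
AVPacket* packet = av_packet_alloc();
AVFrame* frame = av_frame_alloc();
int status = AVERROR(EAGAIN);
while (status == AVERROR(EAGAIN)) {
  status = avcodec_receive_frame(codecCtx, frame);
  if (status == AVERROR(EAGAIN)) {
    // The decoder needs more input: demux the next packet and feed it in.
    if (av_read_frame(fmtCtx, packet) == AVERROR_EOF) {
      avcodec_send_packet(codecCtx, nullptr); // EOF reached: start draining.
    } else {
      avcodec_send_packet(codecCtx, packet);
      av_packet_unref(packet);
    }
  }
}
// status is now 0 (frame holds a decoded frame), AVERROR_EOF (decoder fully
// drained), or a genuine decode error. Cleanup shown for completeness; a
// real caller would consume frame first.
av_frame_free(&frame);
av_packet_free(&packet);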
- StreamInfo& activeStream = streams_[frameStreamIndex]; - activeStream.currentPts = frame->pts; - activeStream.currentDuration = getDuration(frame); + StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex]; + activeStreamInfo.currentPts = frame->pts; + activeStreamInfo.currentDuration = getDuration(frame); RawDecodedOutput rawOutput; rawOutput.streamIndex = frameStreamIndex; rawOutput.frame = std::move(frame); @@ -913,7 +913,7 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput( int streamIndex = rawOutput.streamIndex; AVFrame* frame = rawOutput.frame.get(); output.streamIndex = streamIndex; - auto& streamInfo = streams_[streamIndex]; + auto& streamInfo = streamInfos_[streamIndex]; TORCH_CHECK(streamInfo.stream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO); output.ptsSeconds = ptsToSeconds(frame->pts, formatContext_->streams[streamIndex]->time_base); @@ -952,7 +952,7 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU( std::optional preAllocatedOutputTensor) { int streamIndex = rawOutput.streamIndex; AVFrame* frame = rawOutput.frame.get(); - auto& streamInfo = streams_[streamIndex]; + auto& streamInfo = streamInfos_[streamIndex]; auto frameDims = getHeightAndWidthFromOptionsOrAVFrame(streamInfo.options, *frame); @@ -1050,10 +1050,10 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU( VideoDecoder::DecodedOutput VideoDecoder::getFramePlayedAtTimestampNoDemux( double seconds) { - for (auto& [streamIndex, stream] : streams_) { - double frameStartTime = ptsToSeconds(stream.currentPts, stream.timeBase); + for (auto& [streamIndex, streamInfo] : streamInfos_) { + double frameStartTime = ptsToSeconds(streamInfo.currentPts, streamInfo.timeBase); double frameEndTime = ptsToSeconds( - stream.currentPts + stream.currentDuration, stream.timeBase); + streamInfo.currentPts + streamInfo.currentDuration, streamInfo.timeBase); if (seconds >= frameStartTime && seconds < frameEndTime) { // We are in the same frame as the one we just returned. However, since we // don't cache it locally, we have to rewind back. 
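// The ptsToSeconds() / secondsToClosestPts() helpers used throughout are
// not part of this diff. From their call sites they are presumably thin
// wrappers over the stream's rational time base, along these lines (a
// sketch; the exact rounding behavior is an assumption):
double ptsToSeconds(int64_t pts, const AVRational& timeBase) {
  // e.g. pts=90000 with a 1/90000 time base -> 1.0 seconds.
  return static_cast<double>(pts) * av_q2d(timeBase);
}

int64_t secondsToClosestPts(double seconds, const AVRational& timeBase) {
  // Inverse mapping, rounded to the nearest representable pts tick
  // (requires <cmath>).
  return static_cast<int64_t>(std::round(seconds / av_q2d(timeBase)));
}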
@@ -1065,10 +1065,10 @@ VideoDecoder::DecodedOutput VideoDecoder::getFramePlayedAtTimestampNoDemux( setCursorPtsInSeconds(seconds); RawDecodedOutput rawOutput = getDecodedOutputWithFilter( [seconds, this](int frameStreamIndex, AVFrame* frame) { - StreamInfo& stream = streams_[frameStreamIndex]; - double frameStartTime = ptsToSeconds(frame->pts, stream.timeBase); + StreamInfo& streamInfo = streamInfos_[frameStreamIndex]; + double frameStartTime = ptsToSeconds(frame->pts, streamInfo.timeBase); double frameEndTime = - ptsToSeconds(frame->pts + getDuration(frame), stream.timeBase); + ptsToSeconds(frame->pts + getDuration(frame), streamInfo.timeBase); if (frameStartTime > seconds) { // FFMPEG seeked past the frame we are looking for even though we // set max_ts to be our needed timestamp in avformat_seek_file() @@ -1096,7 +1096,7 @@ void VideoDecoder::validateUserProvidedStreamIndex(int streamIndex) { "; valid indices are in the range [0, " + std::to_string(streamsSize) + ")."); TORCH_CHECK( - streams_.count(streamIndex) > 0, + streamInfos_.count(streamIndex) > 0, "Provided stream index=" + std::to_string(streamIndex) + " was not previously added."); } @@ -1227,7 +1227,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndexInternal( std::optional preAllocatedOutputTensor) { validateUserProvidedStreamIndex(streamIndex); - const auto& streamInfo = streams_[streamIndex]; + const auto& streamInfo = streamInfos_[streamIndex]; const auto& streamMetadata = containerMetadata_.streams[streamIndex]; validateFrameIndex(streamMetadata, frameIndex); @@ -1261,8 +1261,8 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( } const auto& streamMetadata = containerMetadata_.streams[streamIndex]; - const auto& stream = streams_[streamIndex]; - const auto& options = stream.options; + const auto& streamInfo = streamInfos_[streamIndex]; + const auto& options = streamInfo.options; BatchDecodedOutput output(frameIndices.size(), options, streamMetadata); auto previousIndexInVideo = -1; @@ -1298,7 +1298,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesPlayedByTimestamps( validateUserProvidedStreamIndex(streamIndex); const auto& streamMetadata = containerMetadata_.streams[streamIndex]; - const auto& stream = streams_[streamIndex]; + const auto& streamInfo = streamInfos_[streamIndex]; double minSeconds = getMinSeconds(streamMetadata); double maxSeconds = getMaxSeconds(streamMetadata); @@ -1318,7 +1318,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesPlayedByTimestamps( std::to_string(maxSeconds) + ")."); frameIndices[i] = - secondsToIndexLowerBound(frameSeconds, stream, streamMetadata); + secondsToIndexLowerBound(frameSeconds, streamInfo, streamMetadata); } return getFramesAtIndices(streamIndex, frameIndices); @@ -1332,7 +1332,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange( validateUserProvidedStreamIndex(streamIndex); const auto& streamMetadata = containerMetadata_.streams[streamIndex]; - const auto& stream = streams_[streamIndex]; + const auto& streamInfo = streamInfos_[streamIndex]; int64_t numFrames = getNumFrames(streamMetadata); TORCH_CHECK( start >= 0, "Range start, " + std::to_string(start) + " is less than 0."); @@ -1344,7 +1344,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange( step > 0, "Step must be greater than 0; is " + std::to_string(step)); int64_t numOutputFrames = std::ceil((stop - start) / double(step)); - const auto& options = stream.options; + const auto& options = streamInfo.options; BatchDecodedOutput 
output(numOutputFrames, options, streamMetadata); for (int64_t i = start, f = 0; i < stop; i += step, ++f) { @@ -1371,8 +1371,8 @@ VideoDecoder::getFramesPlayedByTimestampInRange( ") must be less than or equal to stop seconds (" + std::to_string(stopSeconds) + "."); - const auto& stream = streams_[streamIndex]; - const auto& options = stream.options; + const auto& streamInfo = streamInfos_[streamIndex]; + const auto& options = streamInfo.options; // Special case needed to implement a half-open range. At first glance, this // may seem unnecessary, as our search for stopFrame can return the end, and @@ -1424,9 +1424,9 @@ VideoDecoder::getFramesPlayedByTimestampInRange( // frame's pts, but before the next frames's pts. int64_t startFrameIndex = - secondsToIndexLowerBound(startSeconds, stream, streamMetadata); + secondsToIndexLowerBound(startSeconds, streamInfo, streamMetadata); int64_t stopFrameIndex = - secondsToIndexUpperBound(stopSeconds, stream, streamMetadata); + secondsToIndexUpperBound(stopSeconds, streamInfo, streamMetadata); int64_t numFrames = stopFrameIndex - startFrameIndex; BatchDecodedOutput output(numFrames, options, streamMetadata); @@ -1444,8 +1444,8 @@ VideoDecoder::getFramesPlayedByTimestampInRange( VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() { auto rawOutput = getDecodedOutputWithFilter([this](int frameStreamIndex, AVFrame* frame) { - StreamInfo& activeStream = streams_[frameStreamIndex]; - return frame->pts >= activeStream.discardFramesBeforePts; + StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex]; + return frame->pts >= activeStreamInfo.discardFramesBeforePts; }); return rawOutput; } @@ -1480,7 +1480,7 @@ double VideoDecoder::getPtsSecondsForFrame( validateUserProvidedStreamIndex(streamIndex); validateScannedAllStreams("getPtsSecondsForFrame"); - const auto& streamInfo = streams_[streamIndex]; + const auto& streamInfo = streamInfos_[streamIndex]; const auto& streamMetadata = containerMetadata_.streams[streamIndex]; validateFrameIndex(streamMetadata, frameIndex); @@ -1538,8 +1538,8 @@ int VideoDecoder::convertFrameToTensorUsingSwsScale( int streamIndex, const AVFrame* frame, torch::Tensor& outputTensor) { - StreamInfo& activeStream = streams_[streamIndex]; - SwsContext* swsContext = activeStream.swsContext.get(); + StreamInfo& activeStreamInfo = streamInfos_[streamIndex]; + SwsContext* swsContext = activeStreamInfo.swsContext.get(); uint8_t* pointers[4] = { outputTensor.data_ptr(), nullptr, nullptr, nullptr}; int expectedOutputWidth = outputTensor.sizes()[1]; @@ -1558,7 +1558,7 @@ int VideoDecoder::convertFrameToTensorUsingSwsScale( torch::Tensor VideoDecoder::convertFrameToTensorUsingFilterGraph( int streamIndex, const AVFrame* frame) { - FilterState& filterState = streams_[streamIndex].filterState; + FilterState& filterState = streamInfos_[streamIndex].filterState; int ffmpegStatus = av_buffersrc_write_frame(filterState.sourceContext, frame); if (ffmpegStatus < AVSUCCESS) { throw std::runtime_error("Failed to add frame to buffer source context"); @@ -1583,11 +1583,11 @@ torch::Tensor VideoDecoder::convertFrameToTensorUsingFilterGraph( } VideoDecoder::~VideoDecoder() { - for (auto& [streamIndex, stream] : streams_) { - auto& device = stream.options.device; + for (auto& [streamIndex, streamInfo] : streamInfos_) { + auto& device = streamInfo.options.device; if (device.type() == torch::kCPU) { } else if (device.type() == torch::kCUDA) { - releaseContextOnCuda(device, stream.codecContext.get()); + releaseContextOnCuda(device, 
streamInfo.codecContext.get()); } else { TORCH_CHECK(false, "Invalid device type: " + device.str()); } diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index ee723695..f3234560 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -414,7 +414,7 @@ class VideoDecoder { SeekMode seekMode_; ContainerMetadata containerMetadata_; UniqueAVFormatContext formatContext_; - std::map streams_; + std::map streamInfos_; // Stores the stream indices of the active streams, i.e. the streams we are // decoding and returning to the user. std::set activeStreamIndices_; From 64f459571308a9c196f1ce94e4d0ef83801db9ed Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 22 Jan 2025 11:19:06 +0000 Subject: [PATCH 28/56] Handle frame names --- .../decoders/_core/VideoDecoder.cpp | 78 +++++++++---------- src/torchcodec/decoders/_core/VideoDecoder.h | 8 +- 2 files changed, 43 insertions(+), 43 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 24098393..ba527f31 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -802,7 +802,7 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getDecodedOutputWithFilter( maybeDesiredPts_ = std::nullopt; } // Need to get the next frame or error from PopFrame. - UniqueAVFrame frame(av_frame_alloc()); + UniqueAVFrame avFrame(av_frame_alloc()); int ffmpegStatus = AVSUCCESS; bool reachedEOF = false; int frameStreamIndex = -1; @@ -812,7 +812,7 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getDecodedOutputWithFilter( for (int streamIndex : activeStreamIndices_) { StreamInfo& streamInfo = streamInfos_[streamIndex]; ffmpegStatus = - avcodec_receive_frame(streamInfo.codecContext.get(), frame.get()); + avcodec_receive_frame(streamInfo.codecContext.get(), avFrame.get()); bool gotNonRetriableError = ffmpegStatus != AVSUCCESS && ffmpegStatus != AVERROR(EAGAIN); if (gotNonRetriableError) { @@ -829,7 +829,7 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getDecodedOutputWithFilter( } decodeStats_.numFramesReceivedByDecoder++; bool gotNeededFrame = ffmpegStatus == AVSUCCESS && - filterFunction(frameStreamIndex, frame.get()); + filterFunction(frameStreamIndex, avFrame.get()); if (gotNeededFrame) { break; } else if (ffmpegStatus == AVSUCCESS) { @@ -897,11 +897,11 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getDecodedOutputWithFilter( // av_receive_frame() or the user will have seeked to a different location in // the file and that will flush the decoder. StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex]; - activeStreamInfo.currentPts = frame->pts; - activeStreamInfo.currentDuration = getDuration(frame); + activeStreamInfo.currentPts = avFrame->pts; + activeStreamInfo.currentDuration = getDuration(avFrame); RawDecodedOutput rawOutput; rawOutput.streamIndex = frameStreamIndex; - rawOutput.frame = std::move(frame); + rawOutput.frame = std::move(avFrame); return rawOutput; } @@ -911,14 +911,14 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput( // Convert the frame to tensor. 
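// A note on UniqueAVFrame, which this commit standardizes on for decoded
// frames: its definition is not in this patch, but it is presumably a
// std::unique_ptr alias with an FFmpeg-aware deleter, roughly:
struct AVFrameDeleter {
  void operator()(AVFrame* avFrame) const {
    av_frame_free(&avFrame);
  }
};
using UniqueAVFrame = std::unique_ptr<AVFrame, AVFrameDeleter>;
// That is what lets rawOutput.avFrame be handed off with std::move() above,
// with no manual cleanup needed on the error paths.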
DecodedOutput output; int streamIndex = rawOutput.streamIndex; - AVFrame* frame = rawOutput.frame.get(); + AVFrame* avFrame = rawOutput.frame.get(); output.streamIndex = streamIndex; auto& streamInfo = streamInfos_[streamIndex]; TORCH_CHECK(streamInfo.stream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO); output.ptsSeconds = - ptsToSeconds(frame->pts, formatContext_->streams[streamIndex]->time_base); + ptsToSeconds(avFrame->pts, formatContext_->streams[streamIndex]->time_base); output.durationSeconds = ptsToSeconds( - getDuration(frame), formatContext_->streams[streamIndex]->time_base); + getDuration(avFrame), formatContext_->streams[streamIndex]->time_base); // TODO: we should fold preAllocatedOutputTensor into RawDecodedOutput. if (streamInfo.options.device.type() == torch::kCPU) { convertAVFrameToDecodedOutputOnCPU( @@ -951,11 +951,11 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU( DecodedOutput& output, std::optional preAllocatedOutputTensor) { int streamIndex = rawOutput.streamIndex; - AVFrame* frame = rawOutput.frame.get(); + AVFrame* avFrame = rawOutput.frame.get(); auto& streamInfo = streamInfos_[streamIndex]; auto frameDims = - getHeightAndWidthFromOptionsOrAVFrame(streamInfo.options, *frame); + getHeightAndWidthFromOptionsOrAVFrame(streamInfo.options, *avFrame); int expectedOutputHeight = frameDims.height; int expectedOutputWidth = frameDims.width; @@ -981,10 +981,10 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU( // resolution to change mid-stream. Finally, we want to reuse the colorspace // conversion objects as much as possible for performance reasons. enum AVPixelFormat frameFormat = - static_cast(frame->format); + static_cast(avFrame->format); auto frameContext = DecodedFrameContext{ - frame->width, - frame->height, + avFrame->width, + avFrame->height, frameFormat, expectedOutputWidth, expectedOutputHeight}; @@ -994,11 +994,11 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU( expectedOutputHeight, expectedOutputWidth, torch::kCPU)); if (!streamInfo.swsContext || streamInfo.prevFrameContext != frameContext) { - createSwsContext(streamInfo, frameContext, frame->colorspace); + createSwsContext(streamInfo, frameContext, avFrame->colorspace); streamInfo.prevFrameContext = frameContext; } int resultHeight = - convertFrameToTensorUsingSwsScale(streamIndex, frame, outputTensor); + convertAVFrameToTensorUsingSwsScale(streamIndex, avFrame, outputTensor); // If this check failed, it would mean that the frame wasn't reshaped to // the expected height. // TODO: Can we do the same check for width? @@ -1018,7 +1018,7 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU( createFilterGraph(streamInfo, expectedOutputHeight, expectedOutputWidth); streamInfo.prevFrameContext = frameContext; } - outputTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame); + outputTensor = convertAVFrameToTensorUsingFilterGraph(streamIndex, avFrame); // Similarly to above, if this check fails it means the frame wasn't // reshaped to its expected dimensions by filtergraph. 
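// The prevFrameContext comparisons above implement a small cache: the
// SwsContext or filter graph is rebuilt only when the decoded frame's
// geometry or pixel format changes mid-stream, and reused otherwise. Based
// on the brace-initializer above, the key struct presumably looks like the
// following sketch (field names are assumptions):
struct DecodedFrameContext {
  int decodedWidth;
  int decodedHeight;
  AVPixelFormat decodedFormat;
  int expectedWidth;
  int expectedHeight;

  bool operator==(const DecodedFrameContext& other) const {
    return decodedWidth == other.decodedWidth &&
        decodedHeight == other.decodedHeight &&
        decodedFormat == other.decodedFormat &&
        expectedWidth == other.expectedWidth &&
        expectedHeight == other.expectedHeight;
  }
};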
@@ -1064,11 +1064,11 @@ VideoDecoder::DecodedOutput VideoDecoder::getFramePlayedAtTimestampNoDemux( setCursorPtsInSeconds(seconds); RawDecodedOutput rawOutput = getDecodedOutputWithFilter( - [seconds, this](int frameStreamIndex, AVFrame* frame) { + [seconds, this](int frameStreamIndex, AVFrame* avFrame) { StreamInfo& streamInfo = streamInfos_[frameStreamIndex]; - double frameStartTime = ptsToSeconds(frame->pts, streamInfo.timeBase); + double frameStartTime = ptsToSeconds(avFrame->pts, streamInfo.timeBase); double frameEndTime = - ptsToSeconds(frame->pts + getDuration(frame), streamInfo.timeBase); + ptsToSeconds(avFrame->pts + getDuration(avFrame), streamInfo.timeBase); if (frameStartTime > seconds) { // FFMPEG seeked past the frame we are looking for even though we // set max_ts to be our needed timestamp in avformat_seek_file() @@ -1443,9 +1443,9 @@ VideoDecoder::getFramesPlayedByTimestampInRange( VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() { auto rawOutput = - getDecodedOutputWithFilter([this](int frameStreamIndex, AVFrame* frame) { + getDecodedOutputWithFilter([this](int frameStreamIndex, AVFrame* avFrame) { StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex]; - return frame->pts >= activeStreamInfo.discardFramesBeforePts; + return avFrame->pts >= activeStreamInfo.discardFramesBeforePts; }); return rawOutput; } @@ -1534,9 +1534,9 @@ void VideoDecoder::createSwsContext( streamInfo.swsContext.reset(swsContext); } -int VideoDecoder::convertFrameToTensorUsingSwsScale( +int VideoDecoder::convertAVFrameToTensorUsingSwsScale( int streamIndex, - const AVFrame* frame, + const AVFrame* avFrame, torch::Tensor& outputTensor) { StreamInfo& activeStreamInfo = streamInfos_[streamIndex]; SwsContext* swsContext = activeStreamInfo.swsContext.get(); @@ -1546,40 +1546,40 @@ int VideoDecoder::convertFrameToTensorUsingSwsScale( int linesizes[4] = {expectedOutputWidth * 3, 0, 0, 0}; int resultHeight = sws_scale( swsContext, - frame->data, - frame->linesize, + avFrame->data, + avFrame->linesize, 0, - frame->height, + avFrame->height, pointers, linesizes); return resultHeight; } -torch::Tensor VideoDecoder::convertFrameToTensorUsingFilterGraph( +torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph( int streamIndex, - const AVFrame* frame) { + const AVFrame* avFrame) { FilterState& filterState = streamInfos_[streamIndex].filterState; - int ffmpegStatus = av_buffersrc_write_frame(filterState.sourceContext, frame); + int ffmpegStatus = av_buffersrc_write_frame(filterState.sourceContext, avFrame); if (ffmpegStatus < AVSUCCESS) { throw std::runtime_error("Failed to add frame to buffer source context"); } - UniqueAVFrame filteredFrame(av_frame_alloc()); + UniqueAVFrame filteredAVFrame(av_frame_alloc()); ffmpegStatus = - av_buffersink_get_frame(filterState.sinkContext, filteredFrame.get()); - TORCH_CHECK_EQ(filteredFrame->format, AV_PIX_FMT_RGB24); + av_buffersink_get_frame(filterState.sinkContext, filteredAVFrame.get()); + TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24); - auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredFrame.get()); + auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get()); int height = frameDims.height; int width = frameDims.width; std::vector shape = {height, width, 3}; - std::vector strides = {filteredFrame->linesize[0], 3, 1}; - AVFrame* filteredFramePtr = filteredFrame.release(); - auto deleter = [filteredFramePtr](void*) { - UniqueAVFrame frameToDelete(filteredFramePtr); + std::vector strides 
= {filteredAVFrame->linesize[0], 3, 1}; + AVFrame* filteredAVFramePtr = filteredAVFrame.release(); + auto deleter = [filteredAVFramePtr](void*) { + UniqueAVFrame avFrameToDelete(filteredAVFramePtr); }; return torch::from_blob( - filteredFramePtr->data[0], shape, strides, deleter, {torch::kUInt8}); + filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8}); } VideoDecoder::~VideoDecoder() { diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index f3234560..bca6843e 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -393,12 +393,12 @@ class VideoDecoder { int streamIndex, AVCodecContext* codecContext); void populateVideoMetadataFromStreamIndex(int streamIndex); - torch::Tensor convertFrameToTensorUsingFilterGraph( + torch::Tensor convertAVFrameToTensorUsingFilterGraph( int streamIndex, - const AVFrame* frame); - int convertFrameToTensorUsingSwsScale( + const AVFrame* avFrame); + int convertAVFrameToTensorUsingSwsScale( int streamIndex, - const AVFrame* frame, + const AVFrame* avFrame, torch::Tensor& outputTensor); DecodedOutput convertAVFrameToDecodedOutput( RawDecodedOutput& rawOutput, From c6290737698d14a6f2e92b5ca2502536fa84279e Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 22 Jan 2025 11:25:00 +0000 Subject: [PATCH 29/56] Streams again --- .../decoders/_core/VideoDecoder.cpp | 30 +++++++++---------- src/torchcodec/decoders/_core/VideoDecoder.h | 2 +- .../decoders/_core/VideoDecoderOps.cpp | 14 ++++----- 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index ba527f31..7ecc7342 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -286,7 +286,7 @@ void VideoDecoder::initializeDecoder() { containerMetadata_.numAudioStreams++; } - containerMetadata_.streams.push_back(streamMetadata); + containerMetadata_.streamMetadatas.push_back(streamMetadata); } if (formatContext_->duration > 0) { @@ -465,7 +465,7 @@ void VideoDecoder::addVideoStreamDecoder( } TORCH_CHECK(codec != nullptr); - StreamMetadata& streamMetadata = containerMetadata_.streams[streamIndex]; + StreamMetadata& streamMetadata = containerMetadata_.streamMetadatas[streamIndex]; if (seekMode_ == SeekMode::approximate && !streamMetadata.averageFps.has_value()) { throw std::runtime_error( @@ -532,10 +532,10 @@ void VideoDecoder::addVideoStreamDecoder( void VideoDecoder::updateMetadataWithCodecContext( int streamIndex, AVCodecContext* codecContext) { - containerMetadata_.streams[streamIndex].width = codecContext->width; - containerMetadata_.streams[streamIndex].height = codecContext->height; + containerMetadata_.streamMetadatas[streamIndex].width = codecContext->width; + containerMetadata_.streamMetadatas[streamIndex].height = codecContext->height; auto codedId = codecContext->codec_id; - containerMetadata_.streams[streamIndex].codecName = + containerMetadata_.streamMetadatas[streamIndex].codecName = std::string(avcodec_get_name(codedId)); } @@ -594,7 +594,7 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() { // We got a valid packet. Let's figure out what stream it belongs to and // record its relevant metadata. 
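// Aside: the FrameInfo entries collected by this scan are what exact seek
// mode searches later. The secondsToIndexLowerBound() helper used by the
// batch APIs in this series is not shown; for the exact-mode case it can
// presumably reduce to a binary search over the sorted pts index, roughly
// as below (simplified signature: the real helper also takes the stream
// metadata so it can dispatch on seek mode):
int64_t secondsToIndexLowerBound(
    double seconds,
    const StreamInfo& streamInfo) {
  int64_t targetPts = secondsToClosestPts(seconds, streamInfo.timeBase);
  auto it = std::lower_bound(
      streamInfo.allFrames.begin(),
      streamInfo.allFrames.end(),
      targetPts,
      [](const FrameInfo& frameInfo, int64_t pts) {
        return frameInfo.pts < pts;
      });
  return it - streamInfo.allFrames.begin();
}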
int streamIndex = packet->stream_index; - auto& streamMetadata = containerMetadata_.streams[streamIndex]; + auto& streamMetadata = containerMetadata_.streamMetadatas[streamIndex]; streamMetadata.minPtsFromScan = std::min( streamMetadata.minPtsFromScan.value_or(INT64_MAX), packet->pts); streamMetadata.maxPtsFromScan = std::max( @@ -614,9 +614,9 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() { // Set all per-stream metadata that requires knowing the content of all // packets. - for (size_t streamIndex = 0; streamIndex < containerMetadata_.streams.size(); + for (size_t streamIndex = 0; streamIndex < containerMetadata_.streamMetadatas.size(); ++streamIndex) { - auto& streamMetadata = containerMetadata_.streams[streamIndex]; + auto& streamMetadata = containerMetadata_.streamMetadatas[streamIndex]; auto avStream = formatContext_->streams[streamIndex]; streamMetadata.numFramesFromScan = streamInfos_[streamIndex].allFrames.size(); @@ -1089,7 +1089,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFramePlayedAtTimestampNoDemux( } void VideoDecoder::validateUserProvidedStreamIndex(int streamIndex) { - int streamsSize = static_cast(containerMetadata_.streams.size()); + int streamsSize = static_cast(containerMetadata_.streamMetadatas.size()); TORCH_CHECK( streamIndex >= 0 && streamIndex < streamsSize, "Invalid stream index=" + std::to_string(streamIndex) + @@ -1228,7 +1228,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndexInternal( validateUserProvidedStreamIndex(streamIndex); const auto& streamInfo = streamInfos_[streamIndex]; - const auto& streamMetadata = containerMetadata_.streams[streamIndex]; + const auto& streamMetadata = containerMetadata_.streamMetadatas[streamIndex]; validateFrameIndex(streamMetadata, frameIndex); int64_t pts = getPts(streamInfo, streamMetadata, frameIndex); @@ -1260,7 +1260,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( }); } - const auto& streamMetadata = containerMetadata_.streams[streamIndex]; + const auto& streamMetadata = containerMetadata_.streamMetadatas[streamIndex]; const auto& streamInfo = streamInfos_[streamIndex]; const auto& options = streamInfo.options; BatchDecodedOutput output(frameIndices.size(), options, streamMetadata); @@ -1297,7 +1297,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesPlayedByTimestamps( const std::vector& timestamps) { validateUserProvidedStreamIndex(streamIndex); - const auto& streamMetadata = containerMetadata_.streams[streamIndex]; + const auto& streamMetadata = containerMetadata_.streamMetadatas[streamIndex]; const auto& streamInfo = streamInfos_[streamIndex]; double minSeconds = getMinSeconds(streamMetadata); @@ -1331,7 +1331,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange( int64_t step) { validateUserProvidedStreamIndex(streamIndex); - const auto& streamMetadata = containerMetadata_.streams[streamIndex]; + const auto& streamMetadata = containerMetadata_.streamMetadatas[streamIndex]; const auto& streamInfo = streamInfos_[streamIndex]; int64_t numFrames = getNumFrames(streamMetadata); TORCH_CHECK( @@ -1364,7 +1364,7 @@ VideoDecoder::getFramesPlayedByTimestampInRange( double stopSeconds) { validateUserProvidedStreamIndex(streamIndex); - const auto& streamMetadata = containerMetadata_.streams[streamIndex]; + const auto& streamMetadata = containerMetadata_.streamMetadatas[streamIndex]; TORCH_CHECK( startSeconds <= stopSeconds, "Start seconds (" + std::to_string(startSeconds) + @@ -1481,7 +1481,7 @@ double VideoDecoder::getPtsSecondsForFrame( 
validateScannedAllStreams("getPtsSecondsForFrame"); const auto& streamInfo = streamInfos_[streamIndex]; - const auto& streamMetadata = containerMetadata_.streams[streamIndex]; + const auto& streamMetadata = containerMetadata_.streamMetadatas[streamIndex]; validateFrameIndex(streamMetadata, frameIndex); return ptsToSeconds( diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index bca6843e..06870f4f 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -104,7 +104,7 @@ class VideoDecoder { std::optional height; }; struct ContainerMetadata { - std::vector streams; + std::vector streamMetadatas; int numAudioStreams = 0; int numVideoStreams = 0; // Note that this is the container-level duration, which is usually the max diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index c63a89e7..b8b62ce0 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -344,9 +344,9 @@ std::string get_json_metadata(at::Tensor& decoder) { // serialize the metadata into a string std::stringstream ss; double durationSeconds = 0; if (maybeBestVideoStreamIndex.has_value() && - videoMetadata.streams[*maybeBestVideoStreamIndex] + videoMetadata.streamMetadatas[*maybeBestVideoStreamIndex] .durationSeconds.has_value()) { - durationSeconds = videoMetadata.streams[*maybeBestVideoStreamIndex] + durationSeconds = videoMetadata.streamMetadatas[*maybeBestVideoStreamIndex] .durationSeconds.value_or(0); } else { // Fallback to container-level duration if stream duration is not found. @@ -359,7 +359,7 @@ std::string get_json_metadata(at::Tensor& decoder) { } if (maybeBestVideoStreamIndex.has_value()) { - auto streamMetadata = videoMetadata.streams[*maybeBestVideoStreamIndex]; + auto streamMetadata = videoMetadata.streamMetadatas[*maybeBestVideoStreamIndex]; if (streamMetadata.numFramesFromScan.has_value()) { metadataMap["numFrames"] = std::to_string(*streamMetadata.numFramesFromScan); @@ -423,7 +423,7 @@ std::string get_container_json_metadata(at::Tensor& decoder) { std::to_string(*containerMetadata.bestAudioStreamIndex); } - map["numStreams"] = std::to_string(containerMetadata.streams.size()); + map["numStreams"] = std::to_string(containerMetadata.streamMetadatas.size()); return mapToJson(map); } @@ -432,13 +432,13 @@ std::string get_stream_json_metadata( at::Tensor& decoder, int64_t stream_index) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); - auto streams = videoDecoder->getContainerMetadata().streams; + auto streamMetadatas = videoDecoder->getContainerMetadata().streamMetadatas; if (stream_index < 0 || - stream_index >= static_cast(streams.size())) { + stream_index >= static_cast(streamMetadatas.size())) { throw std::out_of_range( "stream_index out of bounds: " + std::to_string(stream_index)); } - auto streamMetadata = streams[stream_index]; + auto streamMetadata = streamMetadatas[stream_index]; std::map map; From 85569fb14607c3210ad2ae893e9fe27928fd5dab Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 22 Jan 2025 11:34:07 +0000 Subject: [PATCH 30/56] metadata names --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 6 +++--- src/torchcodec/decoders/_core/VideoDecoder.h | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 7ecc7342..d4d17c61 100644 --- 
a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1600,10 +1600,10 @@ FrameDims getHeightAndWidthFromResizedAVFrame(const AVFrame& resizedAVFrame) { FrameDims getHeightAndWidthFromOptionsOrMetadata( const VideoDecoder::VideoStreamDecoderOptions& options, - const VideoDecoder::StreamMetadata& metadata) { + const VideoDecoder::StreamMetadata& streamMetadata) { return FrameDims( - options.height.value_or(*metadata.height), - options.width.value_or(*metadata.width)); + options.height.value_or(*streamMetadata.height), + options.width.value_or(*streamMetadata.width)); } FrameDims getHeightAndWidthFromOptionsOrAVFrame( diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 06870f4f..d54ca0e1 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -215,7 +215,7 @@ class VideoDecoder { explicit BatchDecodedOutput( int64_t numFrames, const VideoStreamDecoderOptions& options, - const StreamMetadata& metadata); + const StreamMetadata& streamMetadata); }; // Returns frames at the given indices for a given stream as a single stacked @@ -489,7 +489,7 @@ FrameDims getHeightAndWidthFromResizedAVFrame(const AVFrame& resizedAVFrame); FrameDims getHeightAndWidthFromOptionsOrMetadata( const VideoDecoder::VideoStreamDecoderOptions& options, - const VideoDecoder::StreamMetadata& metadata); + const VideoDecoder::StreamMetadata& streamMetadata); FrameDims getHeightAndWidthFromOptionsOrAVFrame( const VideoDecoder::VideoStreamDecoderOptions& options, From d914615914b451416c8cf04a7ba7cef34ec5657b Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 22 Jan 2025 12:32:29 +0000 Subject: [PATCH 31/56] Handle options --- .../decoders/_core/CPUOnlyDevice.cpp | 2 +- src/torchcodec/decoders/_core/CudaDevice.cpp | 2 +- .../decoders/_core/DeviceInterface.h | 2 +- .../decoders/_core/VideoDecoder.cpp | 76 +++++++++---------- src/torchcodec/decoders/_core/VideoDecoder.h | 20 ++--- .../decoders/_core/VideoDecoderOps.cpp | 20 ++--- test/decoders/VideoDecoderTest.cpp | 20 ++--- 7 files changed, 71 insertions(+), 71 deletions(-) diff --git a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp index 213d507e..996e36cb 100644 --- a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp +++ b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp @@ -16,7 +16,7 @@ namespace facebook::torchcodec { void convertAVFrameToDecodedOutputOnCuda( const torch::Device& device, - [[maybe_unused]] const VideoDecoder::VideoStreamDecoderOptions& options, + [[maybe_unused]] const VideoDecoder::VideoStreamOptions& options, [[maybe_unused]] VideoDecoder::RawDecodedOutput& rawOutput, [[maybe_unused]] VideoDecoder::DecodedOutput& output, [[maybe_unused]] std::optional preAllocatedOutputTensor) { diff --git a/src/torchcodec/decoders/_core/CudaDevice.cpp b/src/torchcodec/decoders/_core/CudaDevice.cpp index a786be90..a10cdb64 100644 --- a/src/torchcodec/decoders/_core/CudaDevice.cpp +++ b/src/torchcodec/decoders/_core/CudaDevice.cpp @@ -185,7 +185,7 @@ void initializeContextOnCuda( void convertAVFrameToDecodedOutputOnCuda( const torch::Device& device, - const VideoDecoder::VideoStreamDecoderOptions& options, + const VideoDecoder::VideoStreamOptions& options, VideoDecoder::RawDecodedOutput& rawOutput, VideoDecoder::DecodedOutput& output, std::optional preAllocatedOutputTensor) { diff --git a/src/torchcodec/decoders/_core/DeviceInterface.h 
b/src/torchcodec/decoders/_core/DeviceInterface.h index 5ae201d2..dc0d30c6 100644 --- a/src/torchcodec/decoders/_core/DeviceInterface.h +++ b/src/torchcodec/decoders/_core/DeviceInterface.h @@ -31,7 +31,7 @@ void initializeContextOnCuda( void convertAVFrameToDecodedOutputOnCuda( const torch::Device& device, - const VideoDecoder::VideoStreamDecoderOptions& options, + const VideoDecoder::VideoStreamOptions& options, VideoDecoder::RawDecodedOutput& rawOutput, VideoDecoder::DecodedOutput& output, std::optional preAllocatedOutputTensor = std::nullopt); diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index d4d17c61..6f6a40e3 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -124,7 +124,7 @@ VideoDecoder::ColorConversionLibrary getDefaultColorConversionLibrary( torch::Tensor VideoDecoder::maybePermuteHWC2CHW( int streamIndex, torch::Tensor& hwcTensor) { - if (streamInfos_[streamIndex].options.dimensionOrder == "NHWC") { + if (streamInfos_[streamIndex].videoStreamOptions.dimensionOrder == "NHWC") { return hwcTensor; } auto numDimensions = hwcTensor.dim(); @@ -141,7 +141,7 @@ torch::Tensor VideoDecoder::maybePermuteHWC2CHW( } } -VideoDecoder::VideoStreamDecoderOptions::VideoStreamDecoderOptions( +VideoDecoder::VideoStreamOptions::VideoStreamOptions( const std::string& optionsString) { std::vector tokens = splitStringWithDelimiters(optionsString, ","); @@ -194,14 +194,14 @@ VideoDecoder::VideoStreamDecoderOptions::VideoStreamDecoderOptions( VideoDecoder::BatchDecodedOutput::BatchDecodedOutput( int64_t numFrames, - const VideoStreamDecoderOptions& options, + const VideoStreamOptions& videoStreamOptions, const StreamMetadata& streamMetadata) : ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})), durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) { - auto frameDims = getHeightAndWidthFromOptionsOrMetadata(options, streamMetadata); + auto frameDims = getHeightAndWidthFromOptionsOrMetadata(videoStreamOptions, streamMetadata); int height = frameDims.height; int width = frameDims.width; - frames = allocateEmptyHWCTensor(height, width, options.device, numFrames); + frames = allocateEmptyHWCTensor(height, width, videoStreamOptions.device, numFrames); } bool VideoDecoder::DecodedFrameContext::operator==( @@ -338,9 +338,9 @@ void VideoDecoder::createFilterGraph( filterState.filterGraph.reset(avfilter_graph_alloc()); TORCH_CHECK(filterState.filterGraph.get() != nullptr); - if (streamInfo.options.ffmpegThreadCount.has_value()) { + if (streamInfo.videoStreamOptions.ffmpegThreadCount.has_value()) { filterState.filterGraph->nb_threads = - streamInfo.options.ffmpegThreadCount.value(); + streamInfo.videoStreamOptions.ffmpegThreadCount.value(); } const AVFilter* buffersrc = avfilter_get_by_name("buffer"); @@ -444,7 +444,7 @@ int VideoDecoder::getBestStreamIndex(AVMediaType mediaType) { void VideoDecoder::addVideoStreamDecoder( int preferredStreamIndex, - const VideoStreamDecoderOptions& options) { + const VideoStreamOptions& videoStreamOptions) { if (activeStreamIndices_.count(preferredStreamIndex) > 0) { throw std::invalid_argument( "Stream with index " + std::to_string(preferredStreamIndex) + @@ -484,26 +484,26 @@ void VideoDecoder::addVideoStreamDecoder( " is not a video stream."); } - if (options.device.type() == torch::kCUDA) { - codec = findCudaCodec(options.device, streamInfo.stream->codecpar->codec_id) + if (videoStreamOptions.device.type() == torch::kCUDA) { + codec = 
findCudaCodec(videoStreamOptions.device, streamInfo.stream->codecpar->codec_id) .value_or(codec); } AVCodecContext* codecContext = avcodec_alloc_context3(codec); TORCH_CHECK(codecContext != nullptr); - codecContext->thread_count = options.ffmpegThreadCount.value_or(0); + codecContext->thread_count = videoStreamOptions.ffmpegThreadCount.value_or(0); streamInfo.codecContext.reset(codecContext); int retVal = avcodec_parameters_to_context( streamInfo.codecContext.get(), streamInfo.stream->codecpar); TORCH_CHECK_EQ(retVal, AVSUCCESS); - if (options.device.type() == torch::kCPU) { + if (videoStreamOptions.device.type() == torch::kCPU) { // No more initialization needed for CPU. - } else if (options.device.type() == torch::kCUDA) { - initializeContextOnCuda(options.device, codecContext); + } else if (videoStreamOptions.device.type() == torch::kCUDA) { + initializeContextOnCuda(videoStreamOptions.device, codecContext); } else { - TORCH_CHECK(false, "Invalid device type: " + options.device.str()); + TORCH_CHECK(false, "Invalid device type: " + videoStreamOptions.device.str()); } retVal = avcodec_open2(streamInfo.codecContext.get(), codec, nullptr); @@ -514,7 +514,7 @@ void VideoDecoder::addVideoStreamDecoder( codecContext->time_base = streamInfo.stream->time_base; activeStreamIndices_.insert(streamIndex); updateMetadataWithCodecContext(streamInfo.streamIndex, codecContext); - streamInfo.options = options; + streamInfo.videoStreamOptions = videoStreamOptions; // By default, we want to use swscale for color conversion because it is // faster. However, it has width requirements, so we may need to fall back @@ -523,10 +523,10 @@ void VideoDecoder::addVideoStreamDecoder( // swscale's width requirements to be violated. We don't expose the ability to // choose color conversion library publicly; we only use this ability // internally. - int width = options.width.value_or(codecContext->width); + int width = videoStreamOptions.width.value_or(codecContext->width); auto defaultLibrary = getDefaultColorConversionLibrary(width); streamInfo.colorConversionLibrary = - options.colorConversionLibrary.value_or(defaultLibrary); + videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary); } void VideoDecoder::updateMetadataWithCodecContext( @@ -920,19 +920,19 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput( output.durationSeconds = ptsToSeconds( getDuration(avFrame), formatContext_->streams[streamIndex]->time_base); // TODO: we should fold preAllocatedOutputTensor into RawDecodedOutput. 
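// Context for this TODO: preAllocatedOutputTensor is threaded through so
// that the batch APIs (getFramesAtIndices, getFramesInRange, ...) can
// allocate one large NHWC tensor up front and have each frame decoded
// directly into its own slice, avoiding a per-frame allocation and copy.
// The allocator behind that batch tensor, allocateEmptyHWCTensor(), is
// declared but not defined in the hunks shown; given its call sites, a
// sketch of it (assumed, not the actual body):
torch::Tensor allocateEmptyHWCTensor(
    int height,
    int width,
    torch::Device device,
    std::optional<int64_t> numFrames = std::nullopt) {
  auto tensorOptions =
      torch::TensorOptions().dtype(torch::kUInt8).device(device);
  if (numFrames.has_value()) {
    // Batch variant: [N, H, W, C], uninitialized uint8.
    return torch::empty({numFrames.value(), height, width, 3}, tensorOptions);
  }
  return torch::empty({height, width, 3}, tensorOptions);
}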
- if (streamInfo.options.device.type() == torch::kCPU) { + if (streamInfo.videoStreamOptions.device.type() == torch::kCPU) { convertAVFrameToDecodedOutputOnCPU( rawOutput, output, preAllocatedOutputTensor); - } else if (streamInfo.options.device.type() == torch::kCUDA) { + } else if (streamInfo.videoStreamOptions.device.type() == torch::kCUDA) { convertAVFrameToDecodedOutputOnCuda( - streamInfo.options.device, - streamInfo.options, + streamInfo.videoStreamOptions.device, + streamInfo.videoStreamOptions, rawOutput, output, preAllocatedOutputTensor); } else { TORCH_CHECK( - false, "Invalid device type: " + streamInfo.options.device.str()); + false, "Invalid device type: " + streamInfo.videoStreamOptions.device.str()); } return output; } @@ -955,7 +955,7 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU( auto& streamInfo = streamInfos_[streamIndex]; auto frameDims = - getHeightAndWidthFromOptionsOrAVFrame(streamInfo.options, *avFrame); + getHeightAndWidthFromOptionsOrAVFrame(streamInfo.videoStreamOptions, *avFrame); int expectedOutputHeight = frameDims.height; int expectedOutputWidth = frameDims.width; @@ -1262,8 +1262,8 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( const auto& streamMetadata = containerMetadata_.streamMetadatas[streamIndex]; const auto& streamInfo = streamInfos_[streamIndex]; - const auto& options = streamInfo.options; - BatchDecodedOutput output(frameIndices.size(), options, streamMetadata); + const auto& videoStreamOptions = streamInfo.videoStreamOptions; + BatchDecodedOutput output(frameIndices.size(), videoStreamOptions, streamMetadata); auto previousIndexInVideo = -1; for (size_t f = 0; f < frameIndices.size(); ++f) { @@ -1344,8 +1344,8 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange( step > 0, "Step must be greater than 0; is " + std::to_string(step)); int64_t numOutputFrames = std::ceil((stop - start) / double(step)); - const auto& options = streamInfo.options; - BatchDecodedOutput output(numOutputFrames, options, streamMetadata); + const auto& videoStreamOptions = streamInfo.videoStreamOptions; + BatchDecodedOutput output(numOutputFrames, videoStreamOptions, streamMetadata); for (int64_t i = start, f = 0; i < stop; i += step, ++f) { DecodedOutput singleOut = @@ -1372,7 +1372,7 @@ VideoDecoder::getFramesPlayedByTimestampInRange( std::to_string(stopSeconds) + "."); const auto& streamInfo = streamInfos_[streamIndex]; - const auto& options = streamInfo.options; + const auto& videoStreamOptions = streamInfo.videoStreamOptions; // Special case needed to implement a half-open range. At first glance, this // may seem unnecessary, as our search for stopFrame can return the end, and @@ -1392,7 +1392,7 @@ VideoDecoder::getFramesPlayedByTimestampInRange( // values of the intervals will map to the same frame indices below. Hence, we // need this special case below. 
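// Worked example of the edge case above, assuming a 10 fps stream so that
// frame k spans [k / 10, (k + 1) / 10) seconds:
//   getFramesPlayedByTimestampInRange(idx, 0.25, 0.26) -> frame 2 only;
//   getFramesPlayedByTimestampInRange(idx, 0.25, 0.25) -> must be empty.
// Without the early return below, both bounds of the empty call would
// resolve to startFrameIndex = 2 and stopFrameIndex = 3, wrongly yielding
// one frame.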
if (startSeconds == stopSeconds) { - BatchDecodedOutput output(0, options, streamMetadata); + BatchDecodedOutput output(0, videoStreamOptions, streamMetadata); output.frames = maybePermuteHWC2CHW(streamIndex, output.frames); return output; } @@ -1429,7 +1429,7 @@ VideoDecoder::getFramesPlayedByTimestampInRange( secondsToIndexUpperBound(stopSeconds, streamInfo, streamMetadata); int64_t numFrames = stopFrameIndex - startFrameIndex; - BatchDecodedOutput output(numFrames, options, streamMetadata); + BatchDecodedOutput output(numFrames, videoStreamOptions, streamMetadata); for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) { DecodedOutput singleOut = getFrameAtIndexInternal(streamIndex, i, output.frames[f]); @@ -1584,7 +1584,7 @@ torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph( VideoDecoder::~VideoDecoder() { for (auto& [streamIndex, streamInfo] : streamInfos_) { - auto& device = streamInfo.options.device; + auto& device = streamInfo.videoStreamOptions.device; if (device.type() == torch::kCPU) { } else if (device.type() == torch::kCUDA) { releaseContextOnCuda(device, streamInfo.codecContext.get()); @@ -1599,19 +1599,19 @@ FrameDims getHeightAndWidthFromResizedAVFrame(const AVFrame& resizedAVFrame) { } FrameDims getHeightAndWidthFromOptionsOrMetadata( - const VideoDecoder::VideoStreamDecoderOptions& options, + const VideoDecoder::VideoStreamOptions& videoStreamOptions, const VideoDecoder::StreamMetadata& streamMetadata) { return FrameDims( - options.height.value_or(*streamMetadata.height), - options.width.value_or(*streamMetadata.width)); + videoStreamOptions.height.value_or(*streamMetadata.height), + videoStreamOptions.width.value_or(*streamMetadata.width)); } FrameDims getHeightAndWidthFromOptionsOrAVFrame( - const VideoDecoder::VideoStreamDecoderOptions& options, + const VideoDecoder::VideoStreamOptions& videoStreamOptions, const AVFrame& avFrame) { return FrameDims( - options.height.value_or(avFrame.height), - options.width.value_or(avFrame.width)); + videoStreamOptions.height.value_or(avFrame.height), + videoStreamOptions.width.value_or(avFrame.width)); } torch::Tensor allocateEmptyHWCTensor( diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index d54ca0e1..6fb486a1 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -130,9 +130,9 @@ class VideoDecoder { // Use the libswscale library for color conversion. SWSCALE }; - struct VideoStreamDecoderOptions { - VideoStreamDecoderOptions() {} - explicit VideoStreamDecoderOptions(const std::string& optionsString); + struct VideoStreamOptions { + VideoStreamOptions() {} + explicit VideoStreamOptions(const std::string& optionsString); // Number of threads we pass to FFMPEG for decoding. // 0 means FFMPEG will choose the number of threads automatically to fully // utilize all cores. If not set, it will be the default FFMPEG behavior for @@ -149,13 +149,13 @@ class VideoDecoder { // By default we use CPU for decoding for both C++ and python users. 
torch::Device device = torch::kCPU; }; - struct AudioStreamDecoderOptions {}; + struct AudioStreamOptions {}; void addVideoStreamDecoder( int streamIndex, - const VideoStreamDecoderOptions& options = VideoStreamDecoderOptions()); + const VideoStreamOptions& videoStreamOptions = VideoStreamOptions()); void addAudioStreamDecoder( int streamIndex, - const AudioStreamDecoderOptions& options = AudioStreamDecoderOptions()); + const AudioStreamOptions& audioStreamOptions = AudioStreamOptions()); torch::Tensor maybePermuteHWC2CHW(int streamIndex, torch::Tensor& hwcTensor); @@ -214,7 +214,7 @@ class VideoDecoder { explicit BatchDecodedOutput( int64_t numFrames, - const VideoStreamDecoderOptions& options, + const VideoStreamOptions& videoStreamOptions, const StreamMetadata& streamMetadata); }; @@ -313,7 +313,7 @@ class VideoDecoder { // this pts to the user when they request a frame. // We update this field if the user requested a seek. int64_t discardFramesBeforePts = INT64_MIN; - VideoStreamDecoderOptions options; + VideoStreamOptions videoStreamOptions; // The filter state associated with this stream (for video streams). The // actual graph will be nullptr for inactive streams. FilterState filterState; @@ -488,11 +488,11 @@ struct FrameDims { FrameDims getHeightAndWidthFromResizedAVFrame(const AVFrame& resizedAVFrame); FrameDims getHeightAndWidthFromOptionsOrMetadata( - const VideoDecoder::VideoStreamDecoderOptions& options, + const VideoDecoder::VideoStreamOptions& videoStreamOptions, const VideoDecoder::StreamMetadata& streamMetadata); FrameDims getHeightAndWidthFromOptionsOrAVFrame( - const VideoDecoder::VideoStreamDecoderOptions& options, + const VideoDecoder::VideoStreamOptions& videoStreamOptions, const AVFrame& avFrame); torch::Tensor allocateEmptyHWCTensor( diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index b8b62ce0..de4fa44a 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -180,23 +180,23 @@ void _add_video_stream( std::optional stream_index, std::optional device, std::optional color_conversion_library) { - VideoDecoder::VideoStreamDecoderOptions options; - options.width = width; - options.height = height; - options.ffmpegThreadCount = num_threads; + VideoDecoder::VideoStreamOptions videoStreamOptions; + videoStreamOptions.width = width; + videoStreamOptions.height = height; + videoStreamOptions.ffmpegThreadCount = num_threads; if (dimension_order.has_value()) { std::string stdDimensionOrder{dimension_order.value()}; TORCH_CHECK(stdDimensionOrder == "NHWC" || stdDimensionOrder == "NCHW"); - options.dimensionOrder = stdDimensionOrder; + videoStreamOptions.dimensionOrder = stdDimensionOrder; } if (color_conversion_library.has_value()) { std::string stdColorConversionLibrary{color_conversion_library.value()}; if (stdColorConversionLibrary == "filtergraph") { - options.colorConversionLibrary = + videoStreamOptions.colorConversionLibrary = VideoDecoder::ColorConversionLibrary::FILTERGRAPH; } else if (stdColorConversionLibrary == "swscale") { - options.colorConversionLibrary = + videoStreamOptions.colorConversionLibrary = VideoDecoder::ColorConversionLibrary::SWSCALE; } else { throw std::runtime_error( @@ -206,10 +206,10 @@ void _add_video_stream( } if (device.has_value()) { if (device.value() == "cpu") { - options.device = torch::Device(torch::kCPU); + videoStreamOptions.device = torch::Device(torch::kCPU); } else if (device.value().rfind("cuda", 0) == 
0) { // starts with "cuda" std::string deviceStr(device.value()); - options.device = torch::Device(deviceStr); + videoStreamOptions.device = torch::Device(deviceStr); } else { throw std::runtime_error( "Invalid device=" + std::string(device.value()) + @@ -218,7 +218,7 @@ void _add_video_stream( } auto videoDecoder = unwrapTensorToGetDecoder(decoder); - videoDecoder->addVideoStreamDecoder(stream_index.value_or(-1), options); + videoDecoder->addVideoStreamDecoder(stream_index.value_or(-1), videoStreamOptions); } void seek_to_pts(at::Tensor& decoder, double seconds) { diff --git a/test/decoders/VideoDecoderTest.cpp b/test/decoders/VideoDecoderTest.cpp index e3b3f1e3..04d8c45d 100644 --- a/test/decoders/VideoDecoderTest.cpp +++ b/test/decoders/VideoDecoderTest.cpp @@ -146,10 +146,10 @@ TEST(VideoDecoderTest, RespectsWidthAndHeightFromOptions) { std::string path = getResourcePath("nasa_13013.mp4"); std::unique_ptr decoder = VideoDecoder::createFromFilePath(path); - VideoDecoder::VideoStreamDecoderOptions streamOptions; - streamOptions.width = 100; - streamOptions.height = 120; - decoder->addVideoStreamDecoder(-1, streamOptions); + VideoDecoder::VideoStreamOptions videoStreamOptions; + videoStreamOptions.width = 100; + videoStreamOptions.height = 120; + decoder->addVideoStreamDecoder(-1, videoStreamOptions); torch::Tensor tensor = decoder->getNextFrameNoDemux().frame; EXPECT_EQ(tensor.sizes(), std::vector({3, 120, 100})); } @@ -158,9 +158,9 @@ TEST(VideoDecoderTest, RespectsOutputTensorDimensionOrderFromOptions) { std::string path = getResourcePath("nasa_13013.mp4"); std::unique_ptr decoder = VideoDecoder::createFromFilePath(path); - VideoDecoder::VideoStreamDecoderOptions streamOptions; - streamOptions.dimensionOrder = "NHWC"; - decoder->addVideoStreamDecoder(-1, streamOptions); + VideoDecoder::VideoStreamOptions videoStreamOptions; + videoStreamOptions.dimensionOrder = "NHWC"; + decoder->addVideoStreamDecoder(-1, videoStreamOptions); torch::Tensor tensor = decoder->getNextFrameNoDemux().frame; EXPECT_EQ(tensor.sizes(), std::vector({270, 480, 3})); } @@ -232,7 +232,7 @@ TEST_P(VideoDecoderTest, DecodesFramesInABatchInNHWC) { *ourDecoder->getContainerMetadata().bestVideoStreamIndex; ourDecoder->addVideoStreamDecoder( bestVideoStreamIndex, - VideoDecoder::VideoStreamDecoderOptions("dimension_order=NHWC")); + VideoDecoder::VideoStreamOptions("dimension_order=NHWC")); // Frame with index 180 corresponds to timestamp 6.006. 
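  // (Consistent with a nominal NTSC rate of 30000/1001 ≈ 29.97 fps:
  // 180 * 1001 / 30000 = 6.006 exactly. The clip's actual frame rate is not
  // stated in this patch, so that rate is an inference.)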
auto output = ourDecoder->getFramesAtIndices(bestVideoStreamIndex, {0, 180}); auto tensor = output.frames; @@ -398,7 +398,7 @@ TEST_P(VideoDecoderTest, PreAllocatedTensorFilterGraph) { *ourDecoder->getContainerMetadata().bestVideoStreamIndex; ourDecoder->addVideoStreamDecoder( bestVideoStreamIndex, - VideoDecoder::VideoStreamDecoderOptions( + VideoDecoder::VideoStreamOptions( "color_conversion_library=filtergraph")); auto output = ourDecoder->getFrameAtIndexInternal( bestVideoStreamIndex, 0, preAllocatedOutputTensor); @@ -416,7 +416,7 @@ TEST_P(VideoDecoderTest, PreAllocatedTensorSwscale) { *ourDecoder->getContainerMetadata().bestVideoStreamIndex; ourDecoder->addVideoStreamDecoder( bestVideoStreamIndex, - VideoDecoder::VideoStreamDecoderOptions( + VideoDecoder::VideoStreamOptions( "color_conversion_library=swscale")); auto output = ourDecoder->getFrameAtIndexInternal( bestVideoStreamIndex, 0, preAllocatedOutputTensor); From 621a64c9059eb578c91de4fea2b35d5c2d064562 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 22 Jan 2025 12:36:23 +0000 Subject: [PATCH 32/56] Frame again --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 6 +++--- src/torchcodec/decoders/_core/VideoDecoder.h | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 6f6a40e3..7a095faa 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -901,7 +901,7 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getDecodedOutputWithFilter( activeStreamInfo.currentDuration = getDuration(avFrame); RawDecodedOutput rawOutput; rawOutput.streamIndex = frameStreamIndex; - rawOutput.frame = std::move(avFrame); + rawOutput.avFrame = std::move(avFrame); return rawOutput; } @@ -911,7 +911,7 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput( // Convert the frame to tensor. DecodedOutput output; int streamIndex = rawOutput.streamIndex; - AVFrame* avFrame = rawOutput.frame.get(); + AVFrame* avFrame = rawOutput.avFrame.get(); output.streamIndex = streamIndex; auto& streamInfo = streamInfos_[streamIndex]; TORCH_CHECK(streamInfo.stream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO); @@ -951,7 +951,7 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU( DecodedOutput& output, std::optional preAllocatedOutputTensor) { int streamIndex = rawOutput.streamIndex; - AVFrame* avFrame = rawOutput.frame.get(); + AVFrame* avFrame = rawOutput.avFrame.get(); auto& streamInfo = streamInfos_[streamIndex]; auto frameDims = diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 6fb486a1..332f9fda 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -168,7 +168,7 @@ class VideoDecoder { // Note that AVFrame itself doesn't retain the streamIndex. struct RawDecodedOutput { // The actual decoded output as a unique pointer to an AVFrame. - UniqueAVFrame frame; + UniqueAVFrame avFrame; // The stream index of the decoded frame. 
int streamIndex; }; From f247884cc73a1a31899faa213b62382e903a2826 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 22 Jan 2025 12:36:42 +0000 Subject: [PATCH 33/56] Lint --- .../decoders/_core/VideoDecoder.cpp | 61 ++++++++++++------- .../decoders/_core/VideoDecoderOps.cpp | 6 +- test/decoders/VideoDecoderTest.cpp | 6 +- 3 files changed, 44 insertions(+), 29 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 7a095faa..10134d9c 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -198,10 +198,12 @@ VideoDecoder::BatchDecodedOutput::BatchDecodedOutput( const StreamMetadata& streamMetadata) : ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})), durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) { - auto frameDims = getHeightAndWidthFromOptionsOrMetadata(videoStreamOptions, streamMetadata); + auto frameDims = getHeightAndWidthFromOptionsOrMetadata( + videoStreamOptions, streamMetadata); int height = frameDims.height; int width = frameDims.width; - frames = allocateEmptyHWCTensor(height, width, videoStreamOptions.device, numFrames); + frames = allocateEmptyHWCTensor( + height, width, videoStreamOptions.device, numFrames); } bool VideoDecoder::DecodedFrameContext::operator==( @@ -272,7 +274,8 @@ void VideoDecoder::initializeDecoder() { } if (avStream->duration > 0 && avStream->time_base.den > 0) { - streamMetadata.durationSeconds = av_q2d(avStream->time_base) * avStream->duration; + streamMetadata.durationSeconds = + av_q2d(avStream->time_base) * avStream->duration; } double fps = av_q2d(avStream->r_frame_rate); @@ -465,7 +468,8 @@ void VideoDecoder::addVideoStreamDecoder( } TORCH_CHECK(codec != nullptr); - StreamMetadata& streamMetadata = containerMetadata_.streamMetadatas[streamIndex]; + StreamMetadata& streamMetadata = + containerMetadata_.streamMetadatas[streamIndex]; if (seekMode_ == SeekMode::approximate && !streamMetadata.averageFps.has_value()) { throw std::runtime_error( @@ -485,8 +489,10 @@ void VideoDecoder::addVideoStreamDecoder( } if (videoStreamOptions.device.type() == torch::kCUDA) { - codec = findCudaCodec(videoStreamOptions.device, streamInfo.stream->codecpar->codec_id) - .value_or(codec); + codec = + findCudaCodec( + videoStreamOptions.device, streamInfo.stream->codecpar->codec_id) + .value_or(codec); } AVCodecContext* codecContext = avcodec_alloc_context3(codec); @@ -503,7 +509,8 @@ void VideoDecoder::addVideoStreamDecoder( } else if (videoStreamOptions.device.type() == torch::kCUDA) { initializeContextOnCuda(videoStreamOptions.device, codecContext); } else { - TORCH_CHECK(false, "Invalid device type: " + videoStreamOptions.device.str()); + TORCH_CHECK( + false, "Invalid device type: " + videoStreamOptions.device.str()); } retVal = avcodec_open2(streamInfo.codecContext.get(), codec, nullptr); @@ -614,12 +621,14 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() { // Set all per-stream metadata that requires knowing the content of all // packets. 
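// (Recall the split this series establishes: numFramesFromScan is counted
// here from packets actually seen and backs exact seek mode, while
// approximate mode leans on header-level values such as the
// nb_frames-derived numFrames and averageFps; hence the earlier guard that
// rejects approximate mode when averageFps is missing.)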
-  for (size_t streamIndex = 0; streamIndex < containerMetadata_.streamMetadatas.size();
+  for (size_t streamIndex = 0;
+       streamIndex < containerMetadata_.streamMetadatas.size();
        ++streamIndex) {
     auto& streamMetadata = containerMetadata_.streamMetadatas[streamIndex];
     auto avStream = formatContext_->streams[streamIndex];

-    streamMetadata.numFramesFromScan = streamInfos_[streamIndex].allFrames.size();
+    streamMetadata.numFramesFromScan =
+        streamInfos_[streamIndex].allFrames.size();

     if (streamMetadata.minPtsFromScan.has_value()) {
       streamMetadata.minPtsSecondsFromScan =
@@ -915,8 +924,8 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
   output.streamIndex = streamIndex;
   auto& streamInfo = streamInfos_[streamIndex];
   TORCH_CHECK(streamInfo.stream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO);
-  output.ptsSeconds =
-      ptsToSeconds(avFrame->pts, formatContext_->streams[streamIndex]->time_base);
+  output.ptsSeconds = ptsToSeconds(
+      avFrame->pts, formatContext_->streams[streamIndex]->time_base);
   output.durationSeconds = ptsToSeconds(
       getDuration(avFrame), formatContext_->streams[streamIndex]->time_base);
   // TODO: we should fold preAllocatedOutputTensor into RawDecodedOutput.
@@ -932,7 +941,8 @@
       preAllocatedOutputTensor);
   } else {
     TORCH_CHECK(
-        false, "Invalid device type: " + streamInfo.videoStreamOptions.device.str());
+        false,
+        "Invalid device type: " + streamInfo.videoStreamOptions.device.str());
   }
   return output;
 }
@@ -954,8 +964,8 @@
   AVFrame* avFrame = rawOutput.avFrame.get();
   auto& streamInfo = streamInfos_[streamIndex];

-  auto frameDims =
-      getHeightAndWidthFromOptionsOrAVFrame(streamInfo.videoStreamOptions, *avFrame);
+  auto frameDims = getHeightAndWidthFromOptionsOrAVFrame(
+      streamInfo.videoStreamOptions, *avFrame);
   int expectedOutputHeight = frameDims.height;
   int expectedOutputWidth = frameDims.width;
@@ -1051,9 +1061,11 @@ VideoDecoder::DecodedOutput VideoDecoder::getFramePlayedAtTimestampNoDemux(
     double seconds) {
   for (auto& [streamIndex, streamInfo] : streamInfos_) {
-    double frameStartTime = ptsToSeconds(streamInfo.currentPts, streamInfo.timeBase);
+    double frameStartTime =
+        ptsToSeconds(streamInfo.currentPts, streamInfo.timeBase);
     double frameEndTime = ptsToSeconds(
-        streamInfo.currentPts + streamInfo.currentDuration, streamInfo.timeBase);
+        streamInfo.currentPts + streamInfo.currentDuration,
+        streamInfo.timeBase);
     if (seconds >= frameStartTime && seconds < frameEndTime) {
       // We are in the same frame as the one we just returned. However, since we
       // don't cache it locally, we have to rewind back.
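A note on the check being reformatted above: a frame answers a timestamp query T exactly when pts <= T < pts + duration, with both endpoints converted to seconds through the stream's time base. The following self-contained sketch illustrates that half-open visibility window; TimeBase, isPlayedAt, and the 90 kHz constants are hypothetical stand-ins for FFmpeg's AVRational and the decoder's internals, not torchcodec API.

    #include <cstdint>
    #include <iostream>

    // Hypothetical stand-in for FFmpeg's AVRational time base.
    struct TimeBase {
      int num;
      int den;
    };

    double ptsToSeconds(int64_t pts, TimeBase timeBase) {
      return static_cast<double>(pts) * timeBase.num / timeBase.den;
    }

    // A frame is played at `seconds` when it lies in [pts, pts + duration).
    bool isPlayedAt(
        int64_t pts, int64_t duration, TimeBase timeBase, double seconds) {
      double frameStartTime = ptsToSeconds(pts, timeBase);
      double frameEndTime = ptsToSeconds(pts + duration, timeBase);
      return seconds >= frameStartTime && seconds < frameEndTime;
    }

    int main() {
      TimeBase timeBase{1, 90000}; // 90 kHz, a common video time base
      int64_t pts = 450000;        // 5.0 seconds
      int64_t duration = 90000;    // 1.0 second
      std::cout << isPlayedAt(pts, duration, timeBase, 5.0) << '\n';   // 1
      std::cout << isPlayedAt(pts, duration, timeBase, 5.999) << '\n'; // 1
      std::cout << isPlayedAt(pts, duration, timeBase, 6.0) << '\n';   // 0
    }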
@@ -1067,8 +1079,8 @@ VideoDecoder::DecodedOutput VideoDecoder::getFramePlayedAtTimestampNoDemux(
       [seconds, this](int frameStreamIndex, AVFrame* avFrame) {
         StreamInfo& streamInfo = streamInfos_[frameStreamIndex];
         double frameStartTime = ptsToSeconds(avFrame->pts, streamInfo.timeBase);
-        double frameEndTime =
-            ptsToSeconds(avFrame->pts + getDuration(avFrame), streamInfo.timeBase);
+        double frameEndTime = ptsToSeconds(
+            avFrame->pts + getDuration(avFrame), streamInfo.timeBase);
         if (frameStartTime > seconds) {
           // FFMPEG seeked past the frame we are looking for even though we
           // set max_ts to be our needed timestamp in avformat_seek_file()
@@ -1263,7 +1275,8 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
   const auto& streamMetadata = containerMetadata_.streamMetadatas[streamIndex];
   const auto& streamInfo = streamInfos_[streamIndex];
   const auto& videoStreamOptions = streamInfo.videoStreamOptions;
-  BatchDecodedOutput output(frameIndices.size(), videoStreamOptions, streamMetadata);
+  BatchDecodedOutput output(
+      frameIndices.size(), videoStreamOptions, streamMetadata);

   auto previousIndexInVideo = -1;
   for (size_t f = 0; f < frameIndices.size(); ++f) {
@@ -1345,7 +1358,8 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
   int64_t numOutputFrames = std::ceil((stop - start) / double(step));
   const auto& videoStreamOptions = streamInfo.videoStreamOptions;
-  BatchDecodedOutput output(numOutputFrames, videoStreamOptions, streamMetadata);
+  BatchDecodedOutput output(
+      numOutputFrames, videoStreamOptions, streamMetadata);

   for (int64_t i = start, f = 0; i < stop; i += step, ++f) {
     DecodedOutput singleOut =
@@ -1442,8 +1456,8 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
 }

 VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() {
-  auto rawOutput =
-      getDecodedOutputWithFilter([this](int frameStreamIndex, AVFrame* avFrame) {
+  auto rawOutput = getDecodedOutputWithFilter(
+      [this](int frameStreamIndex, AVFrame* avFrame) {
         StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex];
         return avFrame->pts >= activeStreamInfo.discardFramesBeforePts;
       });
@@ -1559,7 +1573,8 @@ torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(
     int streamIndex,
     const AVFrame* avFrame) {
   FilterState& filterState = streamInfos_[streamIndex].filterState;
-  int ffmpegStatus = av_buffersrc_write_frame(filterState.sourceContext, avFrame);
+  int ffmpegStatus =
+      av_buffersrc_write_frame(filterState.sourceContext, avFrame);
   if (ffmpegStatus < AVSUCCESS) {
     throw std::runtime_error("Failed to add frame to buffer source context");
   }
diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp
index de4fa44a..f8716a79 100644
--- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp
@@ -218,7 +218,8 @@ void _add_video_stream(
   }

   auto videoDecoder = unwrapTensorToGetDecoder(decoder);
-  videoDecoder->addVideoStreamDecoder(stream_index.value_or(-1), videoStreamOptions);
+  videoDecoder->addVideoStreamDecoder(
+      stream_index.value_or(-1), videoStreamOptions);
 }

 void seek_to_pts(at::Tensor& decoder, double seconds) {
@@ -359,7 +360,8 @@ std::string get_json_metadata(at::Tensor& decoder) {
   }

   if (maybeBestVideoStreamIndex.has_value()) {
-    auto streamMetadata = videoMetadata.streamMetadatas[*maybeBestVideoStreamIndex];
+    auto streamMetadata =
+        videoMetadata.streamMetadatas[*maybeBestVideoStreamIndex];
     if (streamMetadata.numFramesFromScan.has_value()) {
metadataMap["numFrames"] = std::to_string(*streamMetadata.numFramesFromScan); diff --git a/test/decoders/VideoDecoderTest.cpp b/test/decoders/VideoDecoderTest.cpp index 04d8c45d..93163911 100644 --- a/test/decoders/VideoDecoderTest.cpp +++ b/test/decoders/VideoDecoderTest.cpp @@ -398,8 +398,7 @@ TEST_P(VideoDecoderTest, PreAllocatedTensorFilterGraph) { *ourDecoder->getContainerMetadata().bestVideoStreamIndex; ourDecoder->addVideoStreamDecoder( bestVideoStreamIndex, - VideoDecoder::VideoStreamOptions( - "color_conversion_library=filtergraph")); + VideoDecoder::VideoStreamOptions("color_conversion_library=filtergraph")); auto output = ourDecoder->getFrameAtIndexInternal( bestVideoStreamIndex, 0, preAllocatedOutputTensor); EXPECT_EQ(output.frame.data_ptr(), preAllocatedOutputTensor.data_ptr()); @@ -416,8 +415,7 @@ TEST_P(VideoDecoderTest, PreAllocatedTensorSwscale) { *ourDecoder->getContainerMetadata().bestVideoStreamIndex; ourDecoder->addVideoStreamDecoder( bestVideoStreamIndex, - VideoDecoder::VideoStreamOptions( - "color_conversion_library=swscale")); + VideoDecoder::VideoStreamOptions("color_conversion_library=swscale")); auto output = ourDecoder->getFrameAtIndexInternal( bestVideoStreamIndex, 0, preAllocatedOutputTensor); EXPECT_EQ(output.frame.data_ptr(), preAllocatedOutputTensor.data_ptr()); From 55a58402cd1406c979aad05e9e5cffaa360b2b0a Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 22 Jan 2025 12:44:08 +0000 Subject: [PATCH 34/56] More videoStreamOptions --- src/torchcodec/decoders/_core/CPUOnlyDevice.cpp | 2 +- src/torchcodec/decoders/_core/CudaDevice.cpp | 7 ++++--- src/torchcodec/decoders/_core/DeviceInterface.h | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp index 996e36cb..9e9ae9e2 100644 --- a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp +++ b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp @@ -16,7 +16,7 @@ namespace facebook::torchcodec { void convertAVFrameToDecodedOutputOnCuda( const torch::Device& device, - [[maybe_unused]] const VideoDecoder::VideoStreamOptions& options, + [[maybe_unused]] const VideoDecoder::VideoStreamOptions& videoStreamOptions, [[maybe_unused]] VideoDecoder::RawDecodedOutput& rawOutput, [[maybe_unused]] VideoDecoder::DecodedOutput& output, [[maybe_unused]] std::optional preAllocatedOutputTensor) { diff --git a/src/torchcodec/decoders/_core/CudaDevice.cpp b/src/torchcodec/decoders/_core/CudaDevice.cpp index a10cdb64..9204936a 100644 --- a/src/torchcodec/decoders/_core/CudaDevice.cpp +++ b/src/torchcodec/decoders/_core/CudaDevice.cpp @@ -185,7 +185,7 @@ void initializeContextOnCuda( void convertAVFrameToDecodedOutputOnCuda( const torch::Device& device, - const VideoDecoder::VideoStreamOptions& options, + const VideoDecoder::VideoStreamOptions& videoStreamOptions, VideoDecoder::RawDecodedOutput& rawOutput, VideoDecoder::DecodedOutput& output, std::optional preAllocatedOutputTensor) { @@ -195,7 +195,8 @@ void convertAVFrameToDecodedOutputOnCuda( src->format == AV_PIX_FMT_CUDA, "Expected format to be AV_PIX_FMT_CUDA, got " + std::string(av_get_pix_fmt_name((AVPixelFormat)src->format))); - auto frameDims = getHeightAndWidthFromOptionsOrAVFrame(options, *src); + auto frameDims = + getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, *src); int height = frameDims.height; int width = frameDims.width; torch::Tensor& dst = output.frame; @@ -212,7 +213,7 @@ void convertAVFrameToDecodedOutputOnCuda( "x3, got ", shape); } else { - dst 
-    dst = allocateEmptyHWCTensor(height, width, options.device);
+    dst = allocateEmptyHWCTensor(height, width, videoStreamOptions.device);
   }

   // Use the user-requested GPU for running the NPP kernel.
diff --git a/src/torchcodec/decoders/_core/DeviceInterface.h b/src/torchcodec/decoders/_core/DeviceInterface.h
index dc0d30c6..e027ce8e 100644
--- a/src/torchcodec/decoders/_core/DeviceInterface.h
+++ b/src/torchcodec/decoders/_core/DeviceInterface.h
@@ -31,7 +31,7 @@ void initializeContextOnCuda(

 void convertAVFrameToDecodedOutputOnCuda(
     const torch::Device& device,
-    const VideoDecoder::VideoStreamOptions& options,
+    const VideoDecoder::VideoStreamOptions& videoStreamOptions,
     VideoDecoder::RawDecodedOutput& rawOutput,
     VideoDecoder::DecodedOutput& output,
     std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);

From 79ea16771b4120407f3aa32c7dda157d9f294a5b Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Wed, 22 Jan 2025 16:58:03 +0000
Subject: [PATCH 35/56] reduce diff

---
 src/torchcodec/decoders/_core/VideoDecoder.cpp | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp
index 6b40f1ad..59f13f2e 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -468,15 +468,6 @@ void VideoDecoder::addVideoStreamDecoder(
   }
   TORCH_CHECK(avCodec != nullptr);

-  StreamMetadata& streamMetadata =
-      containerMetadata_.streamMetadatas[streamIndex];
-  if (seekMode_ == SeekMode::approximate &&
-      !streamMetadata.averageFps.has_value()) {
-    throw std::runtime_error(
-        "Seek mode is approximate, but stream " + std::to_string(streamIndex) +
-        " does not have an average fps in its metadata.");
-  }
-
   StreamInfo& streamInfo = streamInfos_[streamIndex];
   streamInfo.streamIndex = streamIndex;
   streamInfo.timeBase = formatContext_->streams[streamIndex]->time_base;
@@ -495,6 +486,15 @@
             .value_or(avCodec));
   }

+  StreamMetadata& streamMetadata =
+      containerMetadata_.streamMetadatas[streamIndex];
+  if (seekMode_ == SeekMode::approximate &&
+      !streamMetadata.averageFps.has_value()) {
+    throw std::runtime_error(
+        "Seek mode is approximate, but stream " + std::to_string(streamIndex) +
+        " does not have an average fps in its metadata.");
+  }
+
   AVCodecContext* codecContext = avcodec_alloc_context3(avCodec);
   TORCH_CHECK(codecContext != nullptr);
   codecContext->thread_count = videoStreamOptions.ffmpegThreadCount.value_or(0);

From 404b2e45b8c581aedbd87cba79defa289a066274 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Wed, 22 Jan 2025 17:08:23 +0000
Subject: [PATCH 36/56] Fix C++ tests

---
 test/decoders/VideoDecoderTest.cpp | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/test/decoders/VideoDecoderTest.cpp b/test/decoders/VideoDecoderTest.cpp
index 93163911..2a4c1d93 100644
--- a/test/decoders/VideoDecoderTest.cpp
+++ b/test/decoders/VideoDecoderTest.cpp
@@ -72,8 +72,8 @@ TEST_P(VideoDecoderTest, ReturnsFpsAndDurationForVideoInMetadata) {
 #else
   EXPECT_NEAR(metadata.bitRate.value(), 324915, 1e-1);
 #endif
-  EXPECT_EQ(metadata.streams.size(), 6);
-  const auto& videoStream = metadata.streams[3];
+  EXPECT_EQ(metadata.streamMetadatas.size(), 6);
+  const auto& videoStream = metadata.streamMetadatas[3];
   EXPECT_EQ(videoStream.mediaType, AVMEDIA_TYPE_VIDEO);
   EXPECT_EQ(videoStream.codecName, "h264");
   EXPECT_NEAR(*videoStream.averageFps, 29.97f, 1e-1);
@@ -85,7 +85,7 @@ TEST_P(VideoDecoderTest, ReturnsFpsAndDurationForVideoInMetadata) {
   EXPECT_FALSE(videoStream.numFramesFromScan.has_value());
   decoder->scanFileAndUpdateMetadataAndIndex();
   metadata = decoder->getContainerMetadata();
-  const auto& videoStream1 = metadata.streams[3];
+  const auto& videoStream1 = metadata.streamMetadatas[3];
   EXPECT_EQ(*videoStream1.minPtsSecondsFromScan, 0);
   EXPECT_EQ(*videoStream1.maxPtsSecondsFromScan, 13.013);
   EXPECT_EQ(*videoStream1.numFramesFromScan, 390);
@@ -428,9 +428,9 @@ TEST_P(VideoDecoderTest, GetAudioMetadata) {
   VideoDecoder::ContainerMetadata metadata = decoder->getContainerMetadata();
   EXPECT_EQ(metadata.numAudioStreams, 1);
   EXPECT_EQ(metadata.numVideoStreams, 0);
-  EXPECT_EQ(metadata.streams.size(), 1);
+  EXPECT_EQ(metadata.streamMetadatas.size(), 1);

-  const auto& audioStream = metadata.streams[0];
+  const auto& audioStream = metadata.streamMetadatas[0];
   EXPECT_EQ(audioStream.mediaType, AVMEDIA_TYPE_AUDIO);
   EXPECT_NEAR(*audioStream.durationSeconds, 13.25, 1e-1);
 }

From 941d6a358612d009e5ba264b5f8f4b7626fd00f5 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Wed, 22 Jan 2025 17:10:18 +0000
Subject: [PATCH 37/56] Fix CUDA?

---
 src/torchcodec/decoders/_core/CudaDevice.cpp | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/src/torchcodec/decoders/_core/CudaDevice.cpp b/src/torchcodec/decoders/_core/CudaDevice.cpp
index 9204936a..c7c00fa1 100644
--- a/src/torchcodec/decoders/_core/CudaDevice.cpp
+++ b/src/torchcodec/decoders/_core/CudaDevice.cpp
@@ -189,14 +189,14 @@ void convertAVFrameToDecodedOutputOnCuda(
     VideoDecoder::RawDecodedOutput& rawOutput,
     VideoDecoder::DecodedOutput& output,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
-  AVFrame* src = rawOutput.frame.get();
+  AVFrame* avFrame = rawOutput.avFrame.get();

   TORCH_CHECK(
-      src->format == AV_PIX_FMT_CUDA,
+      avFrame->format == AV_PIX_FMT_CUDA,
       "Expected format to be AV_PIX_FMT_CUDA, got " +
-          std::string(av_get_pix_fmt_name((AVPixelFormat)src->format)));
+          std::string(av_get_pix_fmt_name((AVPixelFormat)avFrame->format)));
   auto frameDims =
-      getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, *src);
+      getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, *avFrame);
   int height = frameDims.height;
   int width = frameDims.width;
   torch::Tensor& dst = output.frame;
@@ -220,21 +220,21 @@ void convertAVFrameToDecodedOutputOnCuda(
   c10::cuda::CUDAGuard deviceGuard(device);

   NppiSize oSizeROI = {width, height};
-  Npp8u* input[2] = {src->data[0], src->data[1]};
+  Npp8u* input[2] = {avFrame->data[0], avFrame->data[1]};

   auto start = std::chrono::high_resolution_clock::now();
   NppStatus status;
-  if (src->colorspace == AVColorSpace::AVCOL_SPC_BT709) {
+  if (avFrame->colorspace == AVColorSpace::AVCOL_SPC_BT709) {
     status = nppiNV12ToRGB_709CSC_8u_P2C3R(
         input,
-        src->linesize[0],
+        avFrame->linesize[0],
         static_cast<Npp8u*>(dst.data_ptr()),
         dst.stride(0),
         oSizeROI);
   } else {
     status = nppiNV12ToRGB_8u_P2C3R(
         input,
-        src->linesize[0],
+        avFrame->linesize[0],
         static_cast<Npp8u*>(dst.data_ptr()),
         dst.stride(0),
         oSizeROI);

From a7c5711faef3a4c4c3621af6b99fe36800e54943 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Thu, 23 Jan 2025 09:35:51 +0000
Subject: [PATCH 38/56] Use allStreamMetadata

---
 .../decoders/_core/VideoDecoder.cpp       | 38 +++++++++++--------
 src/torchcodec/decoders/_core/VideoDecoder.h |  2 +-
 .../decoders/_core/VideoDecoderOps.cpp    | 19 ++++++----
 test/decoders/VideoDecoderTest.cpp        | 10 ++---
 4 files changed, 40 insertions(+), 29 deletions(-)

diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp
index ac6a936d..4be92abc 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -289,7 +289,7 @@ void VideoDecoder::initializeDecoder() {
       containerMetadata_.numAudioStreams++;
     }

-    containerMetadata_.streamMetadatas.push_back(streamMetadata);
+    containerMetadata_.allStreamMetadata.push_back(streamMetadata);
   }

   if (formatContext_->duration > 0) {
@@ -487,7 +487,7 @@ void VideoDecoder::addVideoStreamDecoder(
   }

   StreamMetadata& streamMetadata =
-      containerMetadata_.streamMetadatas[streamIndex];
+      containerMetadata_.allStreamMetadata[streamIndex];
   if (seekMode_ == SeekMode::approximate &&
       !streamMetadata.averageFps.has_value()) {
     throw std::runtime_error(
@@ -539,10 +539,11 @@ void VideoDecoder::addVideoStreamDecoder(
 void VideoDecoder::updateMetadataWithCodecContext(
     int streamIndex,
     AVCodecContext* codecContext) {
-  containerMetadata_.streamMetadatas[streamIndex].width = codecContext->width;
-  containerMetadata_.streamMetadatas[streamIndex].height = codecContext->height;
+  containerMetadata_.allStreamMetadata[streamIndex].width = codecContext->width;
+  containerMetadata_.allStreamMetadata[streamIndex].height =
+      codecContext->height;
   auto codedId = codecContext->codec_id;
-  containerMetadata_.streamMetadatas[streamIndex].codecName =
+  containerMetadata_.allStreamMetadata[streamIndex].codecName =
       std::string(avcodec_get_name(codedId));
 }

@@ -603,7 +604,7 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() {
     // We got a valid packet. Let's figure out what stream it belongs to and
    // record its relevant metadata.
     int streamIndex = packet->stream_index;
-    auto& streamMetadata = containerMetadata_.streamMetadatas[streamIndex];
+    auto& streamMetadata = containerMetadata_.allStreamMetadata[streamIndex];
     streamMetadata.minPtsFromScan = std::min(
         streamMetadata.minPtsFromScan.value_or(INT64_MAX), packet->pts);
     streamMetadata.maxPtsFromScan = std::max(
@@ -624,9 +625,9 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() {
   // Set all per-stream metadata that requires knowing the content of all
   // packets.
   for (size_t streamIndex = 0;
-       streamIndex < containerMetadata_.streamMetadatas.size();
+       streamIndex < containerMetadata_.allStreamMetadata.size();
        ++streamIndex) {
-    auto& streamMetadata = containerMetadata_.streamMetadatas[streamIndex];
+    auto& streamMetadata = containerMetadata_.allStreamMetadata[streamIndex];
     auto avStream = formatContext_->streams[streamIndex];

     streamMetadata.numFramesFromScan =
@@ -1104,7 +1105,8 @@ VideoDecoder::DecodedOutput VideoDecoder::getFramePlayedAtTimestampNoDemux(
 }

 void VideoDecoder::validateUserProvidedStreamIndex(int streamIndex) {
-  int streamsSize = static_cast<int>(containerMetadata_.streamMetadatas.size());
+  int streamsSize =
+      static_cast<int>(containerMetadata_.allStreamMetadata.size());
   TORCH_CHECK(
       streamIndex >= 0 && streamIndex < streamsSize,
       "Invalid stream index=" + std::to_string(streamIndex) +
@@ -1243,7 +1245,8 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndexInternal(
   validateUserProvidedStreamIndex(streamIndex);

   const auto& streamInfo = streamInfos_[streamIndex];
-  const auto& streamMetadata = containerMetadata_.streamMetadatas[streamIndex];
+  const auto& streamMetadata =
+      containerMetadata_.allStreamMetadata[streamIndex];
   validateFrameIndex(streamMetadata, frameIndex);

   int64_t pts = getPts(streamInfo, streamMetadata, frameIndex);
@@ -1275,7 +1278,8 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
     });
   }

-  const auto& streamMetadata = containerMetadata_.streamMetadatas[streamIndex];
+  const auto& streamMetadata =
+      containerMetadata_.allStreamMetadata[streamIndex];
   const auto& streamInfo = streamInfos_[streamIndex];
   const auto& videoStreamOptions = streamInfo.videoStreamOptions;
   BatchDecodedOutput output(
@@ -1313,7 +1317,8 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesPlayedByTimestamps(
     const std::vector<double>& timestamps) {
   validateUserProvidedStreamIndex(streamIndex);

-  const auto& streamMetadata = containerMetadata_.streamMetadatas[streamIndex];
+  const auto& streamMetadata =
+      containerMetadata_.allStreamMetadata[streamIndex];
   const auto& streamInfo = streamInfos_[streamIndex];

   double minSeconds = getMinSeconds(streamMetadata);
@@ -1347,7 +1352,8 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
     int64_t step) {
   validateUserProvidedStreamIndex(streamIndex);

-  const auto& streamMetadata = containerMetadata_.streamMetadatas[streamIndex];
+  const auto& streamMetadata =
+      containerMetadata_.allStreamMetadata[streamIndex];
   const auto& streamInfo = streamInfos_[streamIndex];
   int64_t numFrames = getNumFrames(streamMetadata);
   TORCH_CHECK(
@@ -1381,7 +1387,8 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
     double stopSeconds) {
   validateUserProvidedStreamIndex(streamIndex);

-  const auto& streamMetadata = containerMetadata_.streamMetadatas[streamIndex];
+  const auto& streamMetadata =
+      containerMetadata_.allStreamMetadata[streamIndex];
   TORCH_CHECK(
       startSeconds <= stopSeconds,
       "Start seconds (" + std::to_string(startSeconds) +
@@ -1498,7 +1505,8 @@ double VideoDecoder::getPtsSecondsForFrame(
   validateScannedAllStreams("getPtsSecondsForFrame");

   const auto& streamInfo = streamInfos_[streamIndex];
-  const auto& streamMetadata = containerMetadata_.streamMetadatas[streamIndex];
+  const auto& streamMetadata =
+      containerMetadata_.allStreamMetadata[streamIndex];
   validateFrameIndex(streamMetadata, frameIndex);

   return ptsToSeconds(
diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h
index 332f9fda..c690c83c 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.h
+++ b/src/torchcodec/decoders/_core/VideoDecoder.h
@@ -104,7 +104,7 @@ class VideoDecoder {
     std::optional height;
   };
   struct ContainerMetadata {
-    std::vector<StreamMetadata> streamMetadatas;
+    std::vector<StreamMetadata> allStreamMetadata;
     int numAudioStreams = 0;
     int numVideoStreams = 0;
     // Note that this is the container-level duration, which is usually the max
diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp
index f8716a79..41b9f4c7 100644
--- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp
@@ -345,10 +345,11 @@ std::string get_json_metadata(at::Tensor& decoder) {
   // serialize the metadata into a string std::stringstream ss;
   double durationSeconds = 0;
   if (maybeBestVideoStreamIndex.has_value() &&
-      videoMetadata.streamMetadatas[*maybeBestVideoStreamIndex]
+      videoMetadata.allStreamMetadata[*maybeBestVideoStreamIndex]
           .durationSeconds.has_value()) {
-    durationSeconds = videoMetadata.streamMetadatas[*maybeBestVideoStreamIndex]
-                          .durationSeconds.value_or(0);
+    durationSeconds =
+        videoMetadata.allStreamMetadata[*maybeBestVideoStreamIndex]
+            .durationSeconds.value_or(0);
   } else {
     // Fallback to container-level duration if stream duration is not found.
     durationSeconds = videoMetadata.durationSeconds.value_or(0);
@@ -361,7 +362,7 @@ std::string get_json_metadata(at::Tensor& decoder) {

   if (maybeBestVideoStreamIndex.has_value()) {
     auto streamMetadata =
-        videoMetadata.streamMetadatas[*maybeBestVideoStreamIndex];
+        videoMetadata.allStreamMetadata[*maybeBestVideoStreamIndex];
     if (streamMetadata.numFramesFromScan.has_value()) {
       metadataMap["numFrames"] =
           std::to_string(*streamMetadata.numFramesFromScan);
@@ -425,7 +426,8 @@ std::string get_container_json_metadata(at::Tensor& decoder) {
         std::to_string(*containerMetadata.bestAudioStreamIndex);
   }

-  map["numStreams"] = std::to_string(containerMetadata.streamMetadatas.size());
+  map["numStreams"] =
+      std::to_string(containerMetadata.allStreamMetadata.size());

   return mapToJson(map);
 }

@@ -434,13 +436,14 @@ std::string get_stream_json_metadata(
     at::Tensor& decoder,
     int64_t stream_index) {
   auto videoDecoder = unwrapTensorToGetDecoder(decoder);
-  auto streamMetadatas = videoDecoder->getContainerMetadata().streamMetadatas;
+  auto allStreamMetadata =
+      videoDecoder->getContainerMetadata().allStreamMetadata;
   if (stream_index < 0 ||
-      stream_index >= static_cast<int64_t>(streamMetadatas.size())) {
+      stream_index >= static_cast<int64_t>(allStreamMetadata.size())) {
     throw std::out_of_range(
         "stream_index out of bounds: " + std::to_string(stream_index));
   }
-  auto streamMetadata = streamMetadatas[stream_index];
+  auto streamMetadata = allStreamMetadata[stream_index];

   std::map<std::string, std::string> map;
diff --git a/test/decoders/VideoDecoderTest.cpp b/test/decoders/VideoDecoderTest.cpp
index 2a4c1d93..bcc1c0c8 100644
--- a/test/decoders/VideoDecoderTest.cpp
+++ b/test/decoders/VideoDecoderTest.cpp
@@ -72,8 +72,8 @@ TEST_P(VideoDecoderTest, ReturnsFpsAndDurationForVideoInMetadata) {
 #else
   EXPECT_NEAR(metadata.bitRate.value(), 324915, 1e-1);
 #endif
-  EXPECT_EQ(metadata.streamMetadatas.size(), 6);
-  const auto& videoStream = metadata.streamMetadatas[3];
+  EXPECT_EQ(metadata.allStreamMetadata.size(), 6);
+  const auto& videoStream = metadata.allStreamMetadata[3];
   EXPECT_EQ(videoStream.mediaType, AVMEDIA_TYPE_VIDEO);
   EXPECT_EQ(videoStream.codecName, "h264");
   EXPECT_NEAR(*videoStream.averageFps, 29.97f, 1e-1);
@@ -85,7 +85,7 @@ TEST_P(VideoDecoderTest, ReturnsFpsAndDurationForVideoInMetadata) {
   EXPECT_FALSE(videoStream.numFramesFromScan.has_value());
   decoder->scanFileAndUpdateMetadataAndIndex();
   metadata = decoder->getContainerMetadata();
-  const auto& videoStream1 = metadata.streamMetadatas[3];
+  const auto& videoStream1 = metadata.allStreamMetadata[3];
   EXPECT_EQ(*videoStream1.minPtsSecondsFromScan, 0);
   EXPECT_EQ(*videoStream1.maxPtsSecondsFromScan, 13.013);
   EXPECT_EQ(*videoStream1.numFramesFromScan, 390);
@@ -428,9 +428,9 @@ TEST_P(VideoDecoderTest, GetAudioMetadata) {
   VideoDecoder::ContainerMetadata metadata = decoder->getContainerMetadata();
   EXPECT_EQ(metadata.numAudioStreams, 1);
   EXPECT_EQ(metadata.numVideoStreams, 0);
-  EXPECT_EQ(metadata.streamMetadatas.size(), 1);
+  EXPECT_EQ(metadata.allStreamMetadata.size(), 1);

-  const auto& audioStream = metadata.streamMetadatas[0];
+  const auto& audioStream = metadata.allStreamMetadata[0];
   EXPECT_EQ(audioStream.mediaType, AVMEDIA_TYPE_AUDIO);
   EXPECT_NEAR(*audioStream.durationSeconds, 13.25, 1e-1);
 }

From 16d5e520bc03c9c44e1e0bf946f66838d12b337d Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Thu, 23 Jan 2025 09:46:04 +0000
Subject: [PATCH 39/56] Rename BatchDecodedOutput into FrameBatchOutput

---
 .../decoders/_core/VideoDecoder.cpp       | 18 +++++------
 src/torchcodec/decoders/_core/VideoDecoder.h | 31 ++++++++++---------
 .../decoders/_core/VideoDecoderOps.cpp    | 20 ++++++------
 .../decoders/_core/VideoDecoderOps.h      | 10 +++---
 4 files changed, 41 insertions(+), 38 deletions(-)

diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp
index 4be92abc..1f7a6d64 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -192,7 +192,7 @@ VideoDecoder::VideoStreamOptions::VideoStreamOptions(
   }
 }

-VideoDecoder::BatchDecodedOutput::BatchDecodedOutput(
+VideoDecoder::FrameBatchOutput::FrameBatchOutput(
     int64_t numFrames,
     const VideoStreamOptions& videoStreamOptions,
     const StreamMetadata& streamMetadata)
@@ -1254,7 +1254,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndexInternal(
   return getNextFrameNoDemuxInternal(preAllocatedOutputTensor);
 }

-VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
+VideoDecoder::FrameBatchOutput VideoDecoder::getFramesAtIndices(
     int streamIndex,
     const std::vector<int64_t>& frameIndices) {
   validateUserProvidedStreamIndex(streamIndex);
@@ -1282,7 +1282,7 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesAtIndices(
       containerMetadata_.allStreamMetadata[streamIndex];
   const auto& streamInfo = streamInfos_[streamIndex];
   const auto& videoStreamOptions = streamInfo.videoStreamOptions;
-  BatchDecodedOutput output(
+  FrameBatchOutput output(
       frameIndices.size(), videoStreamOptions, streamMetadata);

   auto previousIndexInVideo = -1;
@@ -1312,7 +1312,7 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesAtIndices(
   return output;
 }

-VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesPlayedByTimestamps(
+VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedByTimestamps(
     int streamIndex,
     const std::vector<double>& timestamps) {
   validateUserProvidedStreamIndex(streamIndex);
@@ -1345,7 +1345,7 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedByTimestamps(
   return getFramesAtIndices(streamIndex, frameIndices);
 }

-VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
+VideoDecoder::FrameBatchOutput VideoDecoder::getFramesInRange(
     int streamIndex,
     int64_t start,
     int64_t stop,
@@ -1367,7 +1367,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
   int64_t numOutputFrames = std::ceil((stop - start) / double(step));
   const auto& videoStreamOptions = streamInfo.videoStreamOptions;
-  BatchDecodedOutput output(
+  FrameBatchOutput output(
       numOutputFrames, videoStreamOptions, streamMetadata);

   for (int64_t i = start, f = 0; i < stop; i += step, ++f) {
     DecodedOutput singleOut =
@@ -1380,7 +1380,7 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesInRange(
   return output;
 }

-VideoDecoder::BatchDecodedOutput
+VideoDecoder::FrameBatchOutput
 VideoDecoder::getFramesPlayedByTimestampInRange(
     int streamIndex,
     double startSeconds,
@@ -1416,7 +1416,7 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
   // values of the intervals will map to the same frame indices below. Hence, we
   // need this special case below.
   if (startSeconds == stopSeconds) {
-    BatchDecodedOutput output(0, videoStreamOptions, streamMetadata);
+    FrameBatchOutput output(0, videoStreamOptions, streamMetadata);
     output.frames = maybePermuteHWC2CHW(streamIndex, output.frames);
     return output;
   }
@@ -1453,7 +1453,7 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
       secondsToIndexUpperBound(stopSeconds, streamInfo, streamMetadata);
   int64_t numFrames = stopFrameIndex - startFrameIndex;

-  BatchDecodedOutput output(numFrames, videoStreamOptions, streamMetadata);
+  FrameBatchOutput output(numFrames, videoStreamOptions, streamMetadata);
   for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) {
     DecodedOutput singleOut =
         getFrameAtIndexInternal(streamIndex, i, output.frames[f]);
diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h
index c690c83c..707f2568 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.h
+++ b/src/torchcodec/decoders/_core/VideoDecoder.h
@@ -172,6 +172,7 @@ class VideoDecoder {
     // The stream index of the decoded frame.
     int streamIndex;
   };
+
   struct DecodedOutput {
     // The actual decoded output as a Tensor.
     torch::Tensor frame;
@@ -183,6 +184,18 @@ class VideoDecoder {
     // The duration of the decoded frame in seconds.
     double durationSeconds;
   };
+
+  struct FrameBatchOutput {
+    torch::Tensor frames;
+    torch::Tensor ptsSeconds;
+    torch::Tensor durationSeconds;
+
+    explicit FrameBatchOutput(
+        int64_t numFrames,
+        const VideoStreamOptions& videoStreamOptions,
+        const StreamMetadata& streamMetadata);
+  };
+
   class EndOfFileException : public std::runtime_error {
    public:
     explicit EndOfFileException(const std::string& msg)
@@ -207,24 +220,14 @@ class VideoDecoder {
       int streamIndex,
       int64_t frameIndex,
       std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
-  struct BatchDecodedOutput {
-    torch::Tensor frames;
-    torch::Tensor ptsSeconds;
-    torch::Tensor durationSeconds;
-
-    explicit BatchDecodedOutput(
-        int64_t numFrames,
-        const VideoStreamOptions& videoStreamOptions,
-        const StreamMetadata& streamMetadata);
-  };

   // Returns frames at the given indices for a given stream as a single stacked
   // Tensor.
-  BatchDecodedOutput getFramesAtIndices(
+  FrameBatchOutput getFramesAtIndices(
       int streamIndex,
       const std::vector<int64_t>& frameIndices);

-  BatchDecodedOutput getFramesPlayedByTimestamps(
+  FrameBatchOutput getFramesPlayedByTimestamps(
       int streamIndex,
       const std::vector<double>& timestamps);

   // Returns frames within a given range for a given stream as a single stacked
   // Tensor where
   // the range are:
   // [start, start+step, start+(2*step), start+(3*step), ..., stop)
   // The default for step is 1.
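The comment above pins down getFramesInRange's half-open semantics: frames at start, start+step, and so on, up to but excluding stop. A tiny sketch of the index arithmetic, mirroring the std::ceil((stop - start) / double(step)) count the decoder pre-allocates in the hunk above; rangeIndices is a made-up helper for illustration, not torchcodec API.

    #include <cmath>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Enumerate [start, start+step, start+2*step, ..., stop), stop excluded.
    std::vector<int64_t> rangeIndices(int64_t start, int64_t stop, int64_t step) {
      std::vector<int64_t> indices;
      for (int64_t i = start; i < stop; i += step) {
        indices.push_back(i);
      }
      return indices;
    }

    int main() {
      auto indices = rangeIndices(0, 10, 3); // 0, 3, 6, 9
      auto numOutputFrames =
          static_cast<int64_t>(std::ceil((10 - 0) / double(3))); // 4
      std::cout << indices.size() << " frames, pre-allocated batch size "
                << numOutputFrames << '\n';
    }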
-  BatchDecodedOutput
+  FrameBatchOutput
   getFramesInRange(int streamIndex, int64_t start, int64_t stop, int64_t step);

   // Returns frames within a given pts range for a given stream as a single
   // stacked tensor.
   // Valid values for startSeconds and stopSeconds are:
   //
   // [minPtsSecondsFromScan, maxPtsSecondsFromScan)
-  BatchDecodedOutput getFramesPlayedByTimestampInRange(
+  FrameBatchOutput getFramesPlayedByTimestampInRange(
       int streamIndex,
       double startSeconds,
       double stopSeconds);
diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp
index 41b9f4c7..87b0a03d 100644
--- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp
@@ -85,8 +85,8 @@ OpsDecodedOutput makeOpsDecodedOutput(VideoDecoder::DecodedOutput& frame) {
       torch::tensor(frame.durationSeconds, torch::dtype(torch::kFloat64)));
 }

-OpsBatchDecodedOutput makeOpsBatchDecodedOutput(
-    VideoDecoder::BatchDecodedOutput& batch) {
+OpsFrameBatchOutput makeOpsFrameBatchOutput(
+    VideoDecoder::FrameBatchOutput& batch) {
   return std::make_tuple(batch.frames, batch.ptsSeconds, batch.durationSeconds);
 }

@@ -258,7 +258,7 @@ OpsDecodedOutput get_frame_at_index(
   return makeOpsDecodedOutput(result);
 }

-OpsBatchDecodedOutput get_frames_at_indices(
+OpsFrameBatchOutput get_frames_at_indices(
     at::Tensor& decoder,
     int64_t stream_index,
     at::IntArrayRef frame_indices) {
@@ -266,10 +266,10 @@ OpsFrameBatchOutput get_frames_at_indices(
   std::vector<int64_t> frameIndicesVec(
       frame_indices.begin(), frame_indices.end());
   auto result = videoDecoder->getFramesAtIndices(stream_index, frameIndicesVec);
-  return makeOpsBatchDecodedOutput(result);
+  return makeOpsFrameBatchOutput(result);
 }

-OpsBatchDecodedOutput get_frames_in_range(
+OpsFrameBatchOutput get_frames_in_range(
     at::Tensor& decoder,
     int64_t stream_index,
     int64_t start,
@@ -278,9 +278,9 @@ OpsBatchDecodedOutput get_frames_in_range(
   auto videoDecoder = unwrapTensorToGetDecoder(decoder);
   auto result = videoDecoder->getFramesInRange(
       stream_index, start, stop, step.value_or(1));
-  return makeOpsBatchDecodedOutput(result);
+  return makeOpsFrameBatchOutput(result);
 }

-OpsBatchDecodedOutput get_frames_by_pts(
+OpsFrameBatchOutput get_frames_by_pts(
     at::Tensor& decoder,
     int64_t stream_index,
     at::ArrayRef<double> timestamps) {
@@ -288,10 +288,10 @@ OpsFrameBatchOutput get_frames_by_pts(
   std::vector<double> timestampsVec(timestamps.begin(), timestamps.end());
   auto result =
       videoDecoder->getFramesPlayedByTimestamps(stream_index, timestampsVec);
-  return makeOpsBatchDecodedOutput(result);
+  return makeOpsFrameBatchOutput(result);
 }

-OpsBatchDecodedOutput get_frames_by_pts_in_range(
+OpsFrameBatchOutput get_frames_by_pts_in_range(
     at::Tensor& decoder,
     int64_t stream_index,
     double start_seconds,
@@ -299,7 +299,7 @@ OpsFrameBatchOutput get_frames_by_pts_in_range(
   auto videoDecoder = unwrapTensorToGetDecoder(decoder);
   auto result = videoDecoder->getFramesPlayedByTimestampInRange(
       stream_index, start_seconds, stop_seconds);
-  return makeOpsBatchDecodedOutput(result);
+  return makeOpsFrameBatchOutput(result);
 }

 std::string quoteValue(const std::string& value) {
diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.h b/src/torchcodec/decoders/_core/VideoDecoderOps.h
index 60635094..fc5927b9 100644
--- a/src/torchcodec/decoders/_core/VideoDecoderOps.h
+++ b/src/torchcodec/decoders/_core/VideoDecoderOps.h
@@ -75,7 +75,7 @@ using OpsDecodedOutput = std::tuple<at::Tensor, at::Tensor, at::Tensor>;
 // float.
 // 3. Tensor of N durationis in seconds, where each duration is a
 // single float.
-using OpsBatchDecodedOutput = std::tuple<at::Tensor, at::Tensor, at::Tensor>;
+using OpsFrameBatchOutput = std::tuple<at::Tensor, at::Tensor, at::Tensor>;

 // Return the frame that is visible at a given timestamp in seconds. Each frame
 // in FFMPEG has a presentation timestamp and a duration. The frame visible at a
@@ -83,7 +83,7 @@ using OpsBatchDecodedOutput = std::tuple<at::Tensor, at::Tensor, at::Tensor>;
 OpsDecodedOutput get_frame_at_pts(at::Tensor& decoder, double seconds);

 // Return the frames at given ptss for a given stream
-OpsBatchDecodedOutput get_frames_by_pts(
+OpsFrameBatchOutput get_frames_by_pts(
     at::Tensor& decoder,
     int64_t stream_index,
     at::ArrayRef<double> timestamps);
@@ -99,14 +99,14 @@ OpsDecodedOutput get_frame_at_index(
 OpsDecodedOutput get_next_frame(at::Tensor& decoder);

 // Return the frames at given indices for a given stream
-OpsBatchDecodedOutput get_frames_at_indices(
+OpsFrameBatchOutput get_frames_at_indices(
     at::Tensor& decoder,
     int64_t stream_index,
     at::IntArrayRef frame_indices);

 // Return the frames inside a range as a single stacked Tensor. The range is
 // defined as [start, stop).
-OpsBatchDecodedOutput get_frames_in_range(
+OpsFrameBatchOutput get_frames_in_range(
     at::Tensor& decoder,
     int64_t stream_index,
     int64_t start,
@@ -116,7 +116,7 @@ OpsBatchDecodedOutput get_frames_in_range(
 // Return the frames inside the range as a single stacked Tensor. The range is
 // defined as [start_seconds, stop_seconds). The frames are stacked in pts
 // order.
-OpsBatchDecodedOutput get_frames_by_pts_in_range(
+OpsFrameBatchOutput get_frames_by_pts_in_range(
     at::Tensor& decoder,
     int64_t stream_index,
     double start_seconds,

From 78b095a78ccd346b345b5f6306b77bf5a269b4a0 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Thu, 23 Jan 2025 09:47:02 +0000
Subject: [PATCH 40/56] Rename RawDecodedOutput into AVFrameWithStreamIndex

---
 src/torchcodec/decoders/_core/CPUOnlyDevice.cpp |  2 +-
 src/torchcodec/decoders/_core/CudaDevice.cpp    |  2 +-
 src/torchcodec/decoders/_core/DeviceInterface.h |  2 +-
 src/torchcodec/decoders/_core/VideoDecoder.cpp  | 16 ++++++++--------
 src/torchcodec/decoders/_core/VideoDecoder.h    | 10 +++++-----
 5 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp
index 0b4be2d7..eae95ea9 100644
--- a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp
+++ b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp
@@ -17,7 +17,7 @@ namespace facebook::torchcodec {
 void convertAVFrameToDecodedOutputOnCuda(
     const torch::Device& device,
     [[maybe_unused]] const VideoDecoder::VideoStreamOptions& videoStreamOptions,
-    [[maybe_unused]] VideoDecoder::RawDecodedOutput& rawOutput,
+    [[maybe_unused]] VideoDecoder::AVFrameWithStreamIndex& rawOutput,
     [[maybe_unused]] VideoDecoder::DecodedOutput& output,
     [[maybe_unused]] std::optional<torch::Tensor> preAllocatedOutputTensor) {
   throwUnsupportedDeviceError(device);
diff --git a/src/torchcodec/decoders/_core/CudaDevice.cpp b/src/torchcodec/decoders/_core/CudaDevice.cpp
index c7c00fa1..2778d3f3 100644
--- a/src/torchcodec/decoders/_core/CudaDevice.cpp
+++ b/src/torchcodec/decoders/_core/CudaDevice.cpp
@@ -186,7 +186,7 @@ void initializeContextOnCuda(
 void convertAVFrameToDecodedOutputOnCuda(
     const torch::Device& device,
     const VideoDecoder::VideoStreamOptions& videoStreamOptions,
-    VideoDecoder::RawDecodedOutput& rawOutput,
+    VideoDecoder::AVFrameWithStreamIndex& rawOutput,
     VideoDecoder::DecodedOutput& output,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
   AVFrame* avFrame = rawOutput.avFrame.get();
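The rename in patch 40 is motivated by the struct's own comment, quoted earlier in this series: an AVFrame does not record which stream it came from, so the decoded frame and its stream index must travel together. A minimal sketch of that ownership pattern follows; Frame is a hypothetical stand-in for an FFmpeg AVFrame held through a unique pointer, not the library's type.

    #include <cstdint>
    #include <memory>
    #include <utility>

    // Hypothetical stand-in for a decoded AVFrame.
    struct Frame {
      int64_t pts = 0;
    };

    // The frame alone cannot tell you its stream, so ownership of the frame
    // and the stream index are bundled in one struct.
    struct FrameWithStreamIndex {
      std::unique_ptr<Frame> avFrame;
      int streamIndex = -1;
    };

    FrameWithStreamIndex wrap(std::unique_ptr<Frame> avFrame, int streamIndex) {
      FrameWithStreamIndex out;
      out.streamIndex = streamIndex;
      out.avFrame = std::move(avFrame); // transfer ownership into the struct
      return out;
    }

    int main() {
      auto out = wrap(std::make_unique<Frame>(), 3);
      return (out.streamIndex == 3 && out.avFrame != nullptr) ? 0 : 1;
    }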
diff --git a/src/torchcodec/decoders/_core/DeviceInterface.h b/src/torchcodec/decoders/_core/DeviceInterface.h
index 061231a8..204da050 100644
--- a/src/torchcodec/decoders/_core/DeviceInterface.h
+++ b/src/torchcodec/decoders/_core/DeviceInterface.h
@@ -32,7 +32,7 @@ void initializeContextOnCuda(
 void convertAVFrameToDecodedOutputOnCuda(
     const torch::Device& device,
     const VideoDecoder::VideoStreamOptions& videoStreamOptions,
-    VideoDecoder::RawDecodedOutput& rawOutput,
+    VideoDecoder::AVFrameWithStreamIndex& rawOutput,
     VideoDecoder::DecodedOutput& output,
     std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp
index 1f7a6d64..bf24696e 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -803,7 +803,7 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
   }
 }

-VideoDecoder::RawDecodedOutput VideoDecoder::getDecodedOutputWithFilter(
+VideoDecoder::AVFrameWithStreamIndex VideoDecoder::getDecodedOutputWithFilter(
     std::function<bool(int, AVFrame*)> filterFunction) {
   if (activeStreamIndices_.size() == 0) {
     throw std::runtime_error("No active streams configured.");
@@ -912,14 +912,14 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getDecodedOutputWithFilter(
   StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex];
   activeStreamInfo.currentPts = avFrame->pts;
   activeStreamInfo.currentDuration = getDuration(avFrame);
-  RawDecodedOutput rawOutput;
+  AVFrameWithStreamIndex rawOutput;
   rawOutput.streamIndex = frameStreamIndex;
   rawOutput.avFrame = std::move(avFrame);
   return rawOutput;
 }

 VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
-    VideoDecoder::RawDecodedOutput& rawOutput,
+    VideoDecoder::AVFrameWithStreamIndex& rawOutput,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
   // Convert the frame to tensor.
   DecodedOutput output;
   int streamIndex = rawOutput.streamIndex;
@@ -932,7 +932,7 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
       avFrame->pts, formatContext_->streams[streamIndex]->time_base);
   output.durationSeconds = ptsToSeconds(
       getDuration(avFrame), formatContext_->streams[streamIndex]->time_base);
-  // TODO: we should fold preAllocatedOutputTensor into RawDecodedOutput.
+  // TODO: we should fold preAllocatedOutputTensor into AVFrameWithStreamIndex.
   if (streamInfo.videoStreamOptions.device.type() == torch::kCPU) {
     convertAVFrameToDecodedOutputOnCPU(
         rawOutput, output, preAllocatedOutputTensor);
@@ -961,7 +961,7 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
 // Dimension order of the preAllocatedOutputTensor must be HWC, regardless of
 // `dimension_order` parameter. It's up to callers to re-shape it if needed.
 void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
-    VideoDecoder::RawDecodedOutput& rawOutput,
+    VideoDecoder::AVFrameWithStreamIndex& rawOutput,
     DecodedOutput& output,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
   int streamIndex = rawOutput.streamIndex;
@@ -1079,7 +1079,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFramePlayedAtTimestampNoDemux(
   }
   setCursorPtsInSeconds(seconds);
-  RawDecodedOutput rawOutput = getDecodedOutputWithFilter(
+  AVFrameWithStreamIndex rawOutput = getDecodedOutputWithFilter(
       [seconds, this](int frameStreamIndex, AVFrame* avFrame) {
         StreamInfo& streamInfo = streamInfos_[frameStreamIndex];
         double frameStartTime = ptsToSeconds(avFrame->pts, streamInfo.timeBase);
@@ -1465,7 +1465,7 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
   return output;
 }

-VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() {
+VideoDecoder::AVFrameWithStreamIndex VideoDecoder::getNextAVFrameWithStreamIndexNoDemux() {
   auto rawOutput = getDecodedOutputWithFilter(
       [this](int frameStreamIndex, AVFrame* avFrame) {
         StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex];
         return avFrame->pts >= activeStreamInfo.discardFramesBeforePts;
       });
@@ -1482,7 +1482,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getNextFrameNoDemux() {
 VideoDecoder::DecodedOutput VideoDecoder::getNextFrameNoDemuxInternal(
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
-  auto rawOutput = getNextRawDecodedOutputNoDemux();
+  auto rawOutput = getNextAVFrameWithStreamIndexNoDemux();
   return convertAVFrameToDecodedOutput(rawOutput, preAllocatedOutputTensor);
 }
diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h
index 707f2568..e5c1caf0 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.h
+++ b/src/torchcodec/decoders/_core/VideoDecoder.h
@@ -166,7 +166,7 @@ class VideoDecoder {
   void setCursorPtsInSeconds(double seconds);
   // This structure ensures we always keep the streamIndex and AVFrame together
   // Note that AVFrame itself doesn't retain the streamIndex.
-  struct RawDecodedOutput {
+  struct AVFrameWithStreamIndex {
     // The actual decoded output as a unique pointer to an AVFrame.
     UniqueAVFrame avFrame;
     // The stream index of the decoded frame.
@@ -386,9 +386,9 @@ class VideoDecoder {
       const enum AVColorSpace colorspace);

   void maybeSeekToBeforeDesiredPts();
-  RawDecodedOutput getDecodedOutputWithFilter(
+  AVFrameWithStreamIndex getDecodedOutputWithFilter(
       std::function<bool(int, AVFrame*)>);
-  RawDecodedOutput getNextRawDecodedOutputNoDemux();
+  AVFrameWithStreamIndex getNextAVFrameWithStreamIndexNoDemux();
   // Once we create a decoder can update the metadata with the codec context.
   // For example, for video streams, we can add the height and width of the
   // decoded stream.
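The declaration above shows the shape of the decode loop used throughout these patches: getDecodedOutputWithFilter takes a std::function predicate over (streamIndex, AVFrame*) and decodes until the predicate accepts a frame. Both call sites in this series, the discardFramesBeforePts check and the timestamp-window check, are instances of the same pattern. A compact sketch, with Frame standing in for AVFrame and a vector standing in for the demux/decode loop; decodeUntil is a made-up helper for illustration.

    #include <cstdint>
    #include <functional>
    #include <stdexcept>
    #include <vector>

    // Hypothetical stand-in for a decoded AVFrame.
    struct Frame {
      int streamIndex;
      int64_t pts;
    };

    // Keep pulling decoded frames and return the first one the filter accepts.
    Frame decodeUntil(
        const std::vector<Frame>& decoded,
        const std::function<bool(int, const Frame&)>& filter) {
      for (const Frame& frame : decoded) {
        if (filter(frame.streamIndex, frame)) {
          return frame;
        }
      }
      throw std::runtime_error("End of file reached before a frame matched.");
    }

    int main() {
      std::vector<Frame> decoded = {{0, 0}, {0, 1001}, {0, 2002}};
      int64_t discardFramesBeforePts = 1500;
      Frame frame = decodeUntil(decoded, [&](int, const Frame& f) {
        return f.pts >= discardFramesBeforePts; // same shape as the discard filter
      });
      return frame.pts == 2002 ? 0 : 1;
    }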
@@ -404,10 +404,10 @@ class VideoDecoder {
       int streamIndex,
       const AVFrame* avFrame,
       torch::Tensor& outputTensor);
-  DecodedOutput convertAVFrameToDecodedOutput(
-      RawDecodedOutput& rawOutput,
+  DecodedOutput convertAVFrameToDecodedOutput(
+      AVFrameWithStreamIndex& rawOutput,
       std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
   void convertAVFrameToDecodedOutputOnCPU(
-      RawDecodedOutput& rawOutput,
+      AVFrameWithStreamIndex& rawOutput,
       DecodedOutput& output,
       std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);

From e1e46ff1ac96a53780659122f501a287aa88c186 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Thu, 23 Jan 2025 09:47:45 +0000
Subject: [PATCH 41/56] Rename DecodedOutput into FrameOutput

---
 .../decoders/_core/CPUOnlyDevice.cpp   |  4 +-
 src/torchcodec/decoders/_core/CudaDevice.cpp |  4 +-
 .../decoders/_core/DeviceInterface.h   |  4 +-
 .../decoders/_core/VideoDecoder.cpp    | 38 +++++++++----------
 src/torchcodec/decoders/_core/VideoDecoder.h | 22 +++++------
 .../decoders/_core/VideoDecoderOps.cpp | 16 ++++----
 .../decoders/_core/VideoDecoderOps.h   |  8 ++--
 7 files changed, 48 insertions(+), 48 deletions(-)

diff --git a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp
index eae95ea9..3c17012d 100644
--- a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp
+++ b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp
@@ -14,11 +14,11 @@ namespace facebook::torchcodec {
   TORCH_CHECK(false, "Unsupported device: " + device.str());
 }

-void convertAVFrameToDecodedOutputOnCuda(
+void convertAVFrameToFrameOutputOnCuda(
     const torch::Device& device,
     [[maybe_unused]] const VideoDecoder::VideoStreamOptions& videoStreamOptions,
     [[maybe_unused]] VideoDecoder::AVFrameWithStreamIndex& rawOutput,
-    [[maybe_unused]] VideoDecoder::DecodedOutput& output,
+    [[maybe_unused]] VideoDecoder::FrameOutput& output,
     [[maybe_unused]] std::optional<torch::Tensor> preAllocatedOutputTensor) {
   throwUnsupportedDeviceError(device);
 }
diff --git a/src/torchcodec/decoders/_core/CudaDevice.cpp b/src/torchcodec/decoders/_core/CudaDevice.cpp
index 2778d3f3..3756371c 100644
--- a/src/torchcodec/decoders/_core/CudaDevice.cpp
+++ b/src/torchcodec/decoders/_core/CudaDevice.cpp
@@ -183,11 +183,11 @@ void initializeContextOnCuda(
   return;
 }

-void convertAVFrameToDecodedOutputOnCuda(
+void convertAVFrameToFrameOutputOnCuda(
     const torch::Device& device,
     const VideoDecoder::VideoStreamOptions& videoStreamOptions,
     VideoDecoder::AVFrameWithStreamIndex& rawOutput,
-    VideoDecoder::DecodedOutput& output,
+    VideoDecoder::FrameOutput& output,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
   AVFrame* avFrame = rawOutput.avFrame.get();
diff --git a/src/torchcodec/decoders/_core/DeviceInterface.h b/src/torchcodec/decoders/_core/DeviceInterface.h
index 204da050..21957194 100644
--- a/src/torchcodec/decoders/_core/DeviceInterface.h
+++ b/src/torchcodec/decoders/_core/DeviceInterface.h
@@ -29,11 +29,11 @@ void initializeContextOnCuda(
     const torch::Device& device,
     AVCodecContext* codecContext);

-void convertAVFrameToDecodedOutputOnCuda(
+void convertAVFrameToFrameOutputOnCuda(
     const torch::Device& device,
     const VideoDecoder::VideoStreamOptions& videoStreamOptions,
     VideoDecoder::AVFrameWithStreamIndex& rawOutput,
-    VideoDecoder::DecodedOutput& output,
+    VideoDecoder::FrameOutput& output,
     std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);

 void releaseContextOnCuda(
diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp
index bf24696e..d1016b65 100644
---
--- a/src/torchcodec/decoders/_core/VideoDecoder.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -803,7 +803,7 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
   }
 }

-VideoDecoder::AVFrameWithStreamIndex VideoDecoder::getDecodedOutputWithFilter(
+VideoDecoder::AVFrameWithStreamIndex VideoDecoder::getFrameOutputWithFilter(
     std::function<bool(int, AVFrame*)> filterFunction) {
   if (activeStreamIndices_.size() == 0) {
     throw std::runtime_error("No active streams configured.");
@@ -918,11 +918,11 @@ VideoDecoder::AVFrameWithStreamIndex VideoDecoder::getDecodedOutputWithFilter(
   return rawOutput;
 }

-VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
+VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
     VideoDecoder::AVFrameWithStreamIndex& rawOutput,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
   // Convert the frame to tensor.
-  DecodedOutput output;
+  FrameOutput output;
   int streamIndex = rawOutput.streamIndex;
   AVFrame* avFrame = rawOutput.avFrame.get();
   output.streamIndex = streamIndex;
@@ -934,10 +934,10 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
       getDuration(avFrame), formatContext_->streams[streamIndex]->time_base);
   // TODO: we should fold preAllocatedOutputTensor into AVFrameWithStreamIndex.
   if (streamInfo.videoStreamOptions.device.type() == torch::kCPU) {
-    convertAVFrameToDecodedOutputOnCPU(
+    convertAVFrameToFrameOutputOnCPU(
         rawOutput, output, preAllocatedOutputTensor);
   } else if (streamInfo.videoStreamOptions.device.type() == torch::kCUDA) {
-    convertAVFrameToDecodedOutputOnCuda(
+    convertAVFrameToFrameOutputOnCuda(
         streamInfo.videoStreamOptions.device,
         streamInfo.videoStreamOptions,
         rawOutput,
@@ -960,9 +960,9 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
 // TODO: Figure out whether that's possible!
 // Dimension order of the preAllocatedOutputTensor must be HWC, regardless of
 // `dimension_order` parameter. It's up to callers to re-shape it if needed.
-void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
+void VideoDecoder::convertAVFrameToFrameOutputOnCPU(
     VideoDecoder::AVFrameWithStreamIndex& rawOutput,
-    DecodedOutput& output,
+    FrameOutput& output,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
   int streamIndex = rawOutput.streamIndex;
   AVFrame* avFrame = rawOutput.avFrame.get();
@@ -1062,7 +1062,7 @@ void VideoDecoder::convertAVFrameToFrameOutputOnCPU(
   }
 }

-VideoDecoder::DecodedOutput VideoDecoder::getFramePlayedAtTimestampNoDemux(
+VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAtTimestampNoDemux(
     double seconds) {
   for (auto& [streamIndex, streamInfo] : streamInfos_) {
     double frameStartTime =
@@ -1079,7 +1079,7 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAtTimestampNoDemux(
   }
   setCursorPtsInSeconds(seconds);
-  AVFrameWithStreamIndex rawOutput = getDecodedOutputWithFilter(
+  AVFrameWithStreamIndex rawOutput = getFrameOutputWithFilter(
       [seconds, this](int frameStreamIndex, AVFrame* avFrame) {
         StreamInfo& streamInfo = streamInfos_[frameStreamIndex];
         double frameStartTime = ptsToSeconds(avFrame->pts, streamInfo.timeBase);
@@ -1099,7 +1099,7 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAtTimestampNoDemux(
       });

   // Convert the frame to tensor.
-  DecodedOutput output = convertAVFrameToDecodedOutput(rawOutput);
+  FrameOutput output = convertAVFrameToFrameOutput(rawOutput);
   output.frame = maybePermuteHWC2CHW(output.streamIndex, output.frame);
   return output;
 }
@@ -1136,7 +1136,7 @@ void VideoDecoder::validateFrameIndex(
       " numFrames=" + std::to_string(numFrames));
 }

-VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
+VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndex(
     int streamIndex,
     int64_t frameIndex) {
   auto output = getFrameAtIndexInternal(streamIndex, frameIndex);
@@ -1238,7 +1238,7 @@ int64_t VideoDecoder::secondsToIndexUpperBound(
   }
 }

-VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndexInternal(
+VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndexInternal(
     int streamIndex,
     int64_t frameIndex,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
@@ -1301,7 +1301,7 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesAtIndices(
       output.durationSeconds[indexInOutput] =
           output.durationSeconds[previousIndexInOutput];
     } else {
-      DecodedOutput singleOut = getFrameAtIndexInternal(
+      FrameOutput singleOut = getFrameAtIndexInternal(
           streamIndex, indexInVideo, output.frames[indexInOutput]);
       output.ptsSeconds[indexInOutput] = singleOut.ptsSeconds;
       output.durationSeconds[indexInOutput] = singleOut.durationSeconds;
@@ -1371,7 +1371,7 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesInRange(
       numOutputFrames, videoStreamOptions, streamMetadata);

   for (int64_t i = start, f = 0; i < stop; i += step, ++f) {
-    DecodedOutput singleOut =
+    FrameOutput singleOut =
         getFrameAtIndexInternal(streamIndex, i, output.frames[f]);
     output.ptsSeconds[f] = singleOut.ptsSeconds;
     output.durationSeconds[f] = singleOut.durationSeconds;
@@ -1455,7 +1455,7 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
   FrameBatchOutput output(numFrames, videoStreamOptions, streamMetadata);
   for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) {
-    DecodedOutput singleOut =
+    FrameOutput singleOut =
         getFrameAtIndexInternal(streamIndex, i, output.frames[f]);
     output.ptsSeconds[f] = singleOut.ptsSeconds;
     output.durationSeconds[f] = singleOut.durationSeconds;
@@ -1466,7 +1466,7 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
 }

 VideoDecoder::AVFrameWithStreamIndex VideoDecoder::getNextAVFrameWithStreamIndexNoDemux() {
-  auto rawOutput = getDecodedOutputWithFilter(
+  auto rawOutput = getFrameOutputWithFilter(
       [this](int frameStreamIndex, AVFrame* avFrame) {
         StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex];
         return avFrame->pts >= activeStreamInfo.discardFramesBeforePts;
       });
   return rawOutput;
 }

-VideoDecoder::DecodedOutput VideoDecoder::getNextFrameNoDemux() {
+VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemux() {
   auto output = getNextFrameNoDemuxInternal();
   output.frame = maybePermuteHWC2CHW(output.streamIndex, output.frame);
   return output;
 }

-VideoDecoder::DecodedOutput VideoDecoder::getNextFrameNoDemuxInternal(
+VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemuxInternal(
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
   auto rawOutput = getNextAVFrameWithStreamIndexNoDemux();
-  return convertAVFrameToDecodedOutput(rawOutput, preAllocatedOutputTensor);
+  return convertAVFrameToFrameOutput(rawOutput, preAllocatedOutputTensor);
 }

 void VideoDecoder::setCursorPtsInSeconds(double seconds) {
diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h
index e5c1caf0..a6ab56a2 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.h
+++ b/src/torchcodec/decoders/_core/VideoDecoder.h
@@ -33,7 +33,7 @@ video_decoder.addVideoStreamDecoder(-1);
 // API for seeking and frame extraction:
 // Let's extract the first frame at or after pts=5.0 seconds.
 video_decoder.setCursorPtsInSeconds(5.0);
-auto output = video_decoder->getNextDecodedOutput();
+auto output = video_decoder->getNextFrameOutput();
 torch::Tensor frame = output.frame;
 double presentation_timestamp = output.ptsSeconds;
 // Note that presentation_timestamp can be any timestamp at 5.0 or above
@@ -173,7 +173,7 @@ class VideoDecoder {
     int streamIndex;
   };

-  struct DecodedOutput {
+  struct FrameOutput {
     // The actual decoded output as a Tensor.
     torch::Tensor frame;
     // The stream index of the decoded frame. Used to distinguish
@@ -203,20 +203,20 @@ class VideoDecoder {
   };
   // Decodes the frame where the current cursor position is. It also advances
   // the cursor to the next frame.
-  DecodedOutput getNextFrameNoDemux();
+  FrameOutput getNextFrameNoDemux();
   // Decodes the first frame in any added stream that is visible at a given
   // timestamp. Frames in the video have a presentation timestamp and a
   // duration. For example, if a frame has presentation timestamp of 5.0s and a
   // duration of 1.0s, it will be visible in the timestamp range [5.0, 6.0).
   // i.e. it will be returned when this function is called with seconds=5.0 or
   // seconds=5.999, etc.
-  DecodedOutput getFramePlayedAtTimestampNoDemux(double seconds);
+  FrameOutput getFramePlayedAtTimestampNoDemux(double seconds);

-  DecodedOutput getFrameAtIndex(int streamIndex, int64_t frameIndex);
+  FrameOutput getFrameAtIndex(int streamIndex, int64_t frameIndex);
   // This is morally private but needs to be exposed for C++ tests. Once
   // getFrameAtIndex supports the preAllocatedOutputTensor parameter, we can
   // move it back to private.
-  DecodedOutput getFrameAtIndexInternal(
+  FrameOutput getFrameAtIndexInternal(
       int streamIndex,
       int64_t frameIndex,
       std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
@@ -386,7 +386,7 @@ class VideoDecoder {
       const enum AVColorSpace colorspace);

   void maybeSeekToBeforeDesiredPts();
-  AVFrameWithStreamIndex getDecodedOutputWithFilter(
+  AVFrameWithStreamIndex getFrameOutputWithFilter(
       std::function<bool(int, AVFrame*)>);
   AVFrameWithStreamIndex getNextAVFrameWithStreamIndexNoDemux();
   // Once we create a decoder can update the metadata with the codec context.
   // For example, for video streams, we can add the height and width of the
   // decoded stream.
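The preAllocatedOutputTensor parameter threaded through getFrameAtIndexInternal above lets batch calls decode directly into one slot of an already-allocated output instead of allocating a tensor per frame and stacking afterwards. A minimal libtorch sketch of that idea follows; fakeDecodeInto, the shapes, and the 30 fps timestamps are assumptions for illustration, not torchcodec's API.

    #include <torch/torch.h>

    // Stand-in for a per-frame decode that writes into a given tensor slot.
    void fakeDecodeInto(torch::Tensor slot, int64_t frameIndex) {
      slot.fill_(static_cast<uint8_t>(frameIndex)); // pretend this is pixel data
    }

    int main() {
      int64_t numFrames = 4, height = 270, width = 480;
      // Allocate the whole NHWC batch once, up front.
      torch::Tensor frames =
          torch::empty({numFrames, height, width, 3}, torch::kUInt8);
      torch::Tensor ptsSeconds = torch::empty({numFrames}, torch::kFloat64);

      for (int64_t f = 0; f < numFrames; ++f) {
        fakeDecodeInto(frames[f], f); // frames[f] views the pre-allocated slot
        ptsSeconds[f] = f / 30.0;     // hypothetical 30 fps timestamps
      }
      return frames.sizes()[0] == numFrames ? 0 : 1;
    }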
@@ -403,15 +403,15 @@ class VideoDecoder {
       int streamIndex,
       const AVFrame* avFrame,
       torch::Tensor& outputTensor);
-  DecodedOutput convertAVFrameToDecodedOutput(
+  FrameOutput convertAVFrameToFrameOutput(
       AVFrameWithStreamIndex& rawOutput,
       std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
-  void convertAVFrameToDecodedOutputOnCPU(
+  void convertAVFrameToFrameOutputOnCPU(
       AVFrameWithStreamIndex& rawOutput,
-      DecodedOutput& output,
+      FrameOutput& output,
       std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
-  DecodedOutput getNextFrameNoDemuxInternal(
+  FrameOutput getNextFrameNoDemuxInternal(
       std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
 
   SeekMode seekMode_;
diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp
index 87b0a03d..854c8b90 100644
--- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp
@@ -78,7 +78,7 @@ VideoDecoder* unwrapTensorToGetDecoder(at::Tensor& tensor) {
   return decoder;
 }
 
-OpsDecodedOutput makeOpsDecodedOutput(VideoDecoder::DecodedOutput& frame) {
+OpsFrameOutput makeOpsFrameOutput(VideoDecoder::FrameOutput& frame) {
   return std::make_tuple(
       frame.frame,
       torch::tensor(frame.ptsSeconds, torch::dtype(torch::kFloat64)),
@@ -227,9 +227,9 @@ void seek_to_pts(at::Tensor& decoder, double seconds) {
   videoDecoder->setCursorPtsInSeconds(seconds);
 }
 
-OpsDecodedOutput get_next_frame(at::Tensor& decoder) {
+OpsFrameOutput get_next_frame(at::Tensor& decoder) {
   auto videoDecoder = unwrapTensorToGetDecoder(decoder);
-  VideoDecoder::DecodedOutput result;
+  VideoDecoder::FrameOutput result;
   try {
     result = videoDecoder->getNextFrameNoDemux();
   } catch (const VideoDecoder::EndOfFileException& e) {
@@ -240,22 +240,22 @@
         "image_size is unexpected. Expected 3, got: " +
         std::to_string(result.frame.sizes().size()));
   }
-  return makeOpsDecodedOutput(result);
+  return makeOpsFrameOutput(result);
 }
 
-OpsDecodedOutput get_frame_at_pts(at::Tensor& decoder, double seconds) {
+OpsFrameOutput get_frame_at_pts(at::Tensor& decoder, double seconds) {
   auto videoDecoder = unwrapTensorToGetDecoder(decoder);
   auto result = videoDecoder->getFramePlayedAtTimestampNoDemux(seconds);
-  return makeOpsDecodedOutput(result);
+  return makeOpsFrameOutput(result);
 }
 
-OpsDecodedOutput get_frame_at_index(
+OpsFrameOutput get_frame_at_index(
     at::Tensor& decoder,
     int64_t stream_index,
     int64_t frame_index) {
   auto videoDecoder = unwrapTensorToGetDecoder(decoder);
   auto result = videoDecoder->getFrameAtIndex(stream_index, frame_index);
-  return makeOpsDecodedOutput(result);
+  return makeOpsFrameOutput(result);
 }
 
 OpsFrameBatchOutput get_frames_at_indices(
diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.h b/src/torchcodec/decoders/_core/VideoDecoderOps.h
index fc5927b9..5b25e7f6 100644
--- a/src/torchcodec/decoders/_core/VideoDecoderOps.h
+++ b/src/torchcodec/decoders/_core/VideoDecoderOps.h
@@ -64,7 +64,7 @@ void seek_to_pts(at::Tensor& decoder, double seconds);
 // 3. A single float value for the duration in seconds.
 // The reason we use Tensors for the second and third values is so we can run
 // under torch.compile().
-using OpsDecodedOutput = std::tuple<at::Tensor, at::Tensor, at::Tensor>;
+using OpsFrameOutput = std::tuple<at::Tensor, at::Tensor, at::Tensor>;
 
 // All elements of this tuple are tensors of the same leading dimension. The
// tuple represents the frames for N total frames, where N is the dimension of
@@ -80,7 +80,7 @@ using OpsFrameBatchOutput = std::tuple<at::Tensor, at::Tensor, at::Tensor>;
 // Return the frame that is visible at a given timestamp in seconds. Each frame
 // in FFMPEG has a presentation timestamp and a duration. The frame visible at a
 // given timestamp T has T >= PTS and T < PTS + Duration.
-OpsDecodedOutput get_frame_at_pts(at::Tensor& decoder, double seconds);
+OpsFrameOutput get_frame_at_pts(at::Tensor& decoder, double seconds);
 
 // Return the frames at given ptss for a given stream
 OpsFrameBatchOutput get_frames_by_pts(
@@ -89,14 +89,14 @@ OpsFrameBatchOutput get_frames_by_pts(
     at::ArrayRef<double> timestamps);
 
 // Return the frame that is visible at a given index in the video.
-OpsDecodedOutput get_frame_at_index(
+OpsFrameOutput get_frame_at_index(
     at::Tensor& decoder,
     int64_t stream_index,
     int64_t frame_index);
 
 // Get the next frame from the video as a tuple that has the frame data, pts and
 // duration as tensors.
-OpsDecodedOutput get_next_frame(at::Tensor& decoder);
+OpsFrameOutput get_next_frame(at::Tensor& decoder);
 
 // Return the frames at given indices for a given stream
 OpsFrameBatchOutput get_frames_at_indices(

From cd0a18199008fa61b836c4cfe03bd5990e5b3040 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Thu, 23 Jan 2025 09:54:11 +0000
Subject: [PATCH 42/56] Rename .frames into .data

---
 src/torchcodec/decoders/_core/VideoDecoder.cpp | 18 +++++++++---------
 src/torchcodec/decoders/_core/VideoDecoder.h   |  2 +-
 .../decoders/_core/VideoDecoderOps.cpp         |  2 +-
 3 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp
index d1016b65..a2f0f827 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -202,7 +202,7 @@ VideoDecoder::FrameBatchOutput::FrameBatchOutput(
       videoStreamOptions, streamMetadata);
   int height = frameDims.height;
   int width = frameDims.width;
-  frames = allocateEmptyHWCTensor(
+  data = allocateEmptyHWCTensor(
       height, width, videoStreamOptions.device, numFrames);
 }
 
@@ -1295,20 +1295,20 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesAtIndices(
     if ((f > 0) && (indexInVideo == previousIndexInVideo)) {
       // Avoid decoding the same frame twice
       auto previousIndexInOutput = indicesAreSorted ?
f - 1 : argsort[f - 1]; - output.frames[indexInOutput].copy_(output.frames[previousIndexInOutput]); + output.data[indexInOutput].copy_(output.data[previousIndexInOutput]); output.ptsSeconds[indexInOutput] = output.ptsSeconds[previousIndexInOutput]; output.durationSeconds[indexInOutput] = output.durationSeconds[previousIndexInOutput]; } else { FrameOutput singleOut = getFrameAtIndexInternal( - streamIndex, indexInVideo, output.frames[indexInOutput]); + streamIndex, indexInVideo, output.data[indexInOutput]); output.ptsSeconds[indexInOutput] = singleOut.ptsSeconds; output.durationSeconds[indexInOutput] = singleOut.durationSeconds; } previousIndexInVideo = indexInVideo; } - output.frames = maybePermuteHWC2CHW(streamIndex, output.frames); + output.data = maybePermuteHWC2CHW(streamIndex, output.data); return output; } @@ -1372,11 +1372,11 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesInRange( for (int64_t i = start, f = 0; i < stop; i += step, ++f) { FrameOutput singleOut = - getFrameAtIndexInternal(streamIndex, i, output.frames[f]); + getFrameAtIndexInternal(streamIndex, i, output.data[f]); output.ptsSeconds[f] = singleOut.ptsSeconds; output.durationSeconds[f] = singleOut.durationSeconds; } - output.frames = maybePermuteHWC2CHW(streamIndex, output.frames); + output.data = maybePermuteHWC2CHW(streamIndex, output.data); return output; } @@ -1417,7 +1417,7 @@ VideoDecoder::getFramesPlayedByTimestampInRange( // need this special case below. if (startSeconds == stopSeconds) { FrameBatchOutput output(0, videoStreamOptions, streamMetadata); - output.frames = maybePermuteHWC2CHW(streamIndex, output.frames); + output.data = maybePermuteHWC2CHW(streamIndex, output.data); return output; } @@ -1456,11 +1456,11 @@ VideoDecoder::getFramesPlayedByTimestampInRange( FrameBatchOutput output(numFrames, videoStreamOptions, streamMetadata); for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) { FrameOutput singleOut = - getFrameAtIndexInternal(streamIndex, i, output.frames[f]); + getFrameAtIndexInternal(streamIndex, i, output.data[f]); output.ptsSeconds[f] = singleOut.ptsSeconds; output.durationSeconds[f] = singleOut.durationSeconds; } - output.frames = maybePermuteHWC2CHW(streamIndex, output.frames); + output.data = maybePermuteHWC2CHW(streamIndex, output.data); return output; } diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index a6ab56a2..a334c302 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -186,7 +186,7 @@ class VideoDecoder { }; struct FrameBatchOutput { - torch::Tensor frames; + torch::Tensor data; torch::Tensor ptsSeconds; torch::Tensor durationSeconds; diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index 854c8b90..e64f9ab6 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -87,7 +87,7 @@ OpsFrameOutput makeOpsFrameOutput(VideoDecoder::FrameOutput& frame) { OpsFrameBatchOutput makeOpsFrameBatchOutput( VideoDecoder::FrameBatchOutput& batch) { - return std::make_tuple(batch.frames, batch.ptsSeconds, batch.durationSeconds); + return std::make_tuple(batch.data, batch.ptsSeconds, batch.durationSeconds); } VideoDecoder::SeekMode seekModeFromString(std::string_view seekMode) { From 0111bfcc1608192b94b4d687af74f6a6fde70d61 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 23 Jan 2025 10:59:16 +0000 Subject: [PATCH 43/56] 
rename .frame to .data --- .../decoders/_core/VideoDecoder.cpp | 27 +++++++++---------- src/torchcodec/decoders/_core/VideoDecoder.h | 2 +- .../decoders/_core/VideoDecoderOps.cpp | 6 ++--- test/decoders/VideoDecoderTest.cpp | 22 +++++++-------- 4 files changed, 28 insertions(+), 29 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index a2f0f827..71ecaf41 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -952,7 +952,7 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput( } // Note [preAllocatedOutputTensor with swscale and filtergraph]: -// Callers may pass a pre-allocated tensor, where the output frame tensor will +// Callers may pass a pre-allocated tensor, where the output.data tensor will // be stored. This parameter is honored in any case, but it only leads to a // speed-up when swscale is used. With swscale, we can tell ffmpeg to place the // decoded frame directly into `preAllocatedtensor.data_ptr()`. We haven't yet @@ -1023,7 +1023,7 @@ void VideoDecoder::convertAVFrameToFrameOutputOnCPU( " != ", expectedOutputHeight); - output.frame = outputTensor; + output.data = outputTensor; } else if ( streamInfo.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) { @@ -1051,9 +1051,9 @@ void VideoDecoder::convertAVFrameToFrameOutputOnCPU( // We have already validated that preAllocatedOutputTensor and // outputTensor have the same shape. preAllocatedOutputTensor.value().copy_(outputTensor); - output.frame = preAllocatedOutputTensor.value(); + output.data = preAllocatedOutputTensor.value(); } else { - output.frame = outputTensor; + output.data = outputTensor; } } else { throw std::runtime_error( @@ -1100,7 +1100,7 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAtTimestampNoDemux( // Convert the frame to tensor. 
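A rough sketch of the pre-allocation pattern this note describes, as the C++
tests further down exercise it; the tensor shape and dtype are assumptions
(HWC uint8 matching the 480x270 test video), and only the call shape comes
from the patch:

    // Allocate the destination once, then let the decoder fill it in place.
    torch::Tensor preAllocatedOutputTensor =
        torch::empty({270, 480, 3}, torch::kUInt8);
    VideoDecoder::FrameOutput out = decoder.getFrameAtIndexInternal(
        bestVideoStreamIndex, /*frameIndex=*/0, preAllocatedOutputTensor);
    // With swscale the decoded pixels land directly in the caller's buffer;
    // with filtergraph the decoder copies into it after conversion. Either
    // way the returned tensor aliases the caller's storage:
    assert(out.data.data_ptr() == preAllocatedOutputTensor.data_ptr());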
FrameOutput output = convertAVFrameToFrameOutput(rawOutput); - output.frame = maybePermuteHWC2CHW(output.streamIndex, output.frame); + output.data = maybePermuteHWC2CHW(output.streamIndex, output.data); return output; } @@ -1140,7 +1140,7 @@ VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndex( int streamIndex, int64_t frameIndex) { auto output = getFrameAtIndexInternal(streamIndex, frameIndex); - output.frame = maybePermuteHWC2CHW(streamIndex, output.frame); + output.data = maybePermuteHWC2CHW(streamIndex, output.data); return output; } @@ -1367,8 +1367,7 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesInRange( int64_t numOutputFrames = std::ceil((stop - start) / double(step)); const auto& videoStreamOptions = streamInfo.videoStreamOptions; - FrameBatchOutput output( - numOutputFrames, videoStreamOptions, streamMetadata); + FrameBatchOutput output(numOutputFrames, videoStreamOptions, streamMetadata); for (int64_t i = start, f = 0; i < stop; i += step, ++f) { FrameOutput singleOut = @@ -1380,8 +1379,7 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesInRange( return output; } -VideoDecoder::FrameBatchOutput -VideoDecoder::getFramesPlayedByTimestampInRange( +VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedByTimestampInRange( int streamIndex, double startSeconds, double stopSeconds) { @@ -1465,9 +1463,10 @@ VideoDecoder::getFramesPlayedByTimestampInRange( return output; } -VideoDecoder::AVFrameWithStreamIndex VideoDecoder::getNextAVFrameWithStreamIndexNoDemux() { - auto rawOutput = getFrameOutputWithFilter( - [this](int frameStreamIndex, AVFrame* avFrame) { +VideoDecoder::AVFrameWithStreamIndex +VideoDecoder::getNextAVFrameWithStreamIndexNoDemux() { + auto rawOutput = + getFrameOutputWithFilter([this](int frameStreamIndex, AVFrame* avFrame) { StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex]; return avFrame->pts >= activeStreamInfo.discardFramesBeforePts; }); @@ -1476,7 +1475,7 @@ VideoDecoder::AVFrameWithStreamIndex VideoDecoder::getNextAVFrameWithStreamIndex VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemux() { auto output = getNextFrameNoDemuxInternal(); - output.frame = maybePermuteHWC2CHW(output.streamIndex, output.frame); + output.data = maybePermuteHWC2CHW(output.streamIndex, output.data); return output; } diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index a334c302..3fb3e357 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -175,7 +175,7 @@ class VideoDecoder { struct FrameOutput { // The actual decoded output as a Tensor. - torch::Tensor frame; + torch::Tensor data; // The stream index of the decoded frame. Used to distinguish // between streams that are of the same type. 
int streamIndex; diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index e64f9ab6..62699280 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -80,7 +80,7 @@ VideoDecoder* unwrapTensorToGetDecoder(at::Tensor& tensor) { OpsFrameOutput makeOpsFrameOutput(VideoDecoder::FrameOutput& frame) { return std::make_tuple( - frame.frame, + frame.data, torch::tensor(frame.ptsSeconds, torch::dtype(torch::kFloat64)), torch::tensor(frame.durationSeconds, torch::dtype(torch::kFloat64))); } @@ -235,10 +235,10 @@ OpsFrameOutput get_next_frame(at::Tensor& decoder) { } catch (const VideoDecoder::EndOfFileException& e) { C10_THROW_ERROR(IndexError, e.what()); } - if (result.frame.sizes().size() != 3) { + if (result.data.sizes().size() != 3) { throw std::runtime_error( "image_size is unexpected. Expected 3, got: " + - std::to_string(result.frame.sizes().size())); + std::to_string(result.data.sizes().size())); } return makeOpsFrameOutput(result); } diff --git a/test/decoders/VideoDecoderTest.cpp b/test/decoders/VideoDecoderTest.cpp index bcc1c0c8..730732fa 100644 --- a/test/decoders/VideoDecoderTest.cpp +++ b/test/decoders/VideoDecoderTest.cpp @@ -171,11 +171,11 @@ TEST_P(VideoDecoderTest, ReturnsFirstTwoFramesOfVideo) { createDecoderFromPath(path, GetParam()); ourDecoder->addVideoStreamDecoder(-1); auto output = ourDecoder->getNextFrameNoDemux(); - torch::Tensor tensor0FromOurDecoder = output.frame; + torch::Tensor tensor0FromOurDecoder = output.data; EXPECT_EQ(tensor0FromOurDecoder.sizes(), std::vector({3, 270, 480})); EXPECT_EQ(output.ptsSeconds, 0.0); output = ourDecoder->getNextFrameNoDemux(); - torch::Tensor tensor1FromOurDecoder = output.frame; + torch::Tensor tensor1FromOurDecoder = output.data; EXPECT_EQ(tensor1FromOurDecoder.sizes(), std::vector({3, 270, 480})); EXPECT_EQ(output.ptsSeconds, 1'001. / 30'000); @@ -211,7 +211,7 @@ TEST_P(VideoDecoderTest, DecodesFramesInABatchInNCHW) { ourDecoder->addVideoStreamDecoder(bestVideoStreamIndex); // Frame with index 180 corresponds to timestamp 6.006. auto output = ourDecoder->getFramesAtIndices(bestVideoStreamIndex, {0, 180}); - auto tensor = output.frames; + auto tensor = output.data; EXPECT_EQ(tensor.sizes(), std::vector({2, 3, 270, 480})); torch::Tensor tensor0FromFFMPEG = @@ -235,7 +235,7 @@ TEST_P(VideoDecoderTest, DecodesFramesInABatchInNHWC) { VideoDecoder::VideoStreamOptions("dimension_order=NHWC")); // Frame with index 180 corresponds to timestamp 6.006. auto output = ourDecoder->getFramesAtIndices(bestVideoStreamIndex, {0, 180}); - auto tensor = output.frames; + auto tensor = output.; EXPECT_EQ(tensor.sizes(), std::vector({2, 270, 480, 3})); torch::Tensor tensor0FromFFMPEG = @@ -299,7 +299,7 @@ TEST_P(VideoDecoderTest, SeeksToFrameWithSpecificPts) { ourDecoder->addVideoStreamDecoder(-1); ourDecoder->setCursorPtsInSeconds(6.0); auto output = ourDecoder->getNextFrameNoDemux(); - torch::Tensor tensor6FromOurDecoder = output.frame; + torch::Tensor tensor6FromOurDecoder = output.data; EXPECT_EQ(output.ptsSeconds, 180'180. / 30'000); torch::Tensor tensor6FromFFMPEG = readTensorFromDisk("nasa_13013.mp4.time6.000000.pt"); @@ -315,7 +315,7 @@ TEST_P(VideoDecoderTest, SeeksToFrameWithSpecificPts) { ourDecoder->setCursorPtsInSeconds(6.1); output = ourDecoder->getNextFrameNoDemux(); - torch::Tensor tensor61FromOurDecoder = output.frame; + torch::Tensor tensor61FromOurDecoder = output.data; EXPECT_EQ(output.ptsSeconds, 183'183. 
/ 30'000); torch::Tensor tensor61FromFFMPEG = readTensorFromDisk("nasa_13013.mp4.time6.100000.pt"); @@ -335,7 +335,7 @@ TEST_P(VideoDecoderTest, SeeksToFrameWithSpecificPts) { ourDecoder->setCursorPtsInSeconds(10.0); output = ourDecoder->getNextFrameNoDemux(); - torch::Tensor tensor10FromOurDecoder = output.frame; + torch::Tensor tensor10FromOurDecoder = output.data; EXPECT_EQ(output.ptsSeconds, 300'300. / 30'000); torch::Tensor tensor10FromFFMPEG = readTensorFromDisk("nasa_13013.mp4.time10.000000.pt"); @@ -352,7 +352,7 @@ TEST_P(VideoDecoderTest, SeeksToFrameWithSpecificPts) { ourDecoder->setCursorPtsInSeconds(6.0); output = ourDecoder->getNextFrameNoDemux(); - tensor6FromOurDecoder = output.frame; + tensor6FromOurDecoder = output.data; EXPECT_EQ(output.ptsSeconds, 180'180. / 30'000); EXPECT_TRUE(torch::equal(tensor6FromOurDecoder, tensor6FromFFMPEG)); EXPECT_EQ(ourDecoder->getDecodeStats().numSeeksAttempted, 1); @@ -367,7 +367,7 @@ TEST_P(VideoDecoderTest, SeeksToFrameWithSpecificPts) { constexpr double kPtsOfLastFrameInVideoStream = 389'389. / 30'000; // ~12.9 ourDecoder->setCursorPtsInSeconds(kPtsOfLastFrameInVideoStream); output = ourDecoder->getNextFrameNoDemux(); - torch::Tensor tensor7FromOurDecoder = output.frame; + torch::Tensor tensor7FromOurDecoder = output.data; EXPECT_EQ(output.ptsSeconds, 389'389. / 30'000); torch::Tensor tensor7FromFFMPEG = readTensorFromDisk("nasa_13013.mp4.time12.979633.pt"); @@ -401,7 +401,7 @@ TEST_P(VideoDecoderTest, PreAllocatedTensorFilterGraph) { VideoDecoder::VideoStreamOptions("color_conversion_library=filtergraph")); auto output = ourDecoder->getFrameAtIndexInternal( bestVideoStreamIndex, 0, preAllocatedOutputTensor); - EXPECT_EQ(output.frame.data_ptr(), preAllocatedOutputTensor.data_ptr()); + EXPECT_EQ(output.data.data_ptr(), preAllocatedOutputTensor.data_ptr()); } TEST_P(VideoDecoderTest, PreAllocatedTensorSwscale) { @@ -418,7 +418,7 @@ TEST_P(VideoDecoderTest, PreAllocatedTensorSwscale) { VideoDecoder::VideoStreamOptions("color_conversion_library=swscale")); auto output = ourDecoder->getFrameAtIndexInternal( bestVideoStreamIndex, 0, preAllocatedOutputTensor); - EXPECT_EQ(output.frame.data_ptr(), preAllocatedOutputTensor.data_ptr()); + EXPECT_EQ(output.data.data_ptr(), preAllocatedOutputTensor.data_ptr()); } TEST_P(VideoDecoderTest, GetAudioMetadata) { From 42ea096cd087f4998012aac7cfbac43beba0eb1e Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 23 Jan 2025 11:08:23 +0000 Subject: [PATCH 44/56] Use frameBatchOutput variable name --- .../decoders/_core/VideoDecoder.cpp | 59 +++++++++++-------- 1 file changed, 33 insertions(+), 26 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 71ecaf41..d90eb792 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1282,7 +1282,7 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesAtIndices( containerMetadata_.allStreamMetadata[streamIndex]; const auto& streamInfo = streamInfos_[streamIndex]; const auto& videoStreamOptions = streamInfo.videoStreamOptions; - FrameBatchOutput output( + FrameBatchOutput frameBatchOutput( frameIndices.size(), videoStreamOptions, streamMetadata); auto previousIndexInVideo = -1; @@ -1295,21 +1295,24 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesAtIndices( if ((f > 0) && (indexInVideo == previousIndexInVideo)) { // Avoid decoding the same frame twice auto previousIndexInOutput = indicesAreSorted ? 
f - 1 : argsort[f - 1]; - output.data[indexInOutput].copy_(output.data[previousIndexInOutput]); - output.ptsSeconds[indexInOutput] = - output.ptsSeconds[previousIndexInOutput]; - output.durationSeconds[indexInOutput] = - output.durationSeconds[previousIndexInOutput]; + frameBatchOutput.data[indexInOutput].copy_( + frameBatchOutput.data[previousIndexInOutput]); + frameBatchOutput.ptsSeconds[indexInOutput] = + frameBatchOutput.ptsSeconds[previousIndexInOutput]; + frameBatchOutput.durationSeconds[indexInOutput] = + frameBatchOutput.durationSeconds[previousIndexInOutput]; } else { FrameOutput singleOut = getFrameAtIndexInternal( - streamIndex, indexInVideo, output.data[indexInOutput]); - output.ptsSeconds[indexInOutput] = singleOut.ptsSeconds; - output.durationSeconds[indexInOutput] = singleOut.durationSeconds; + streamIndex, indexInVideo, frameBatchOutput.data[indexInOutput]); + frameBatchOutput.ptsSeconds[indexInOutput] = singleOut.ptsSeconds; + frameBatchOutput.durationSeconds[indexInOutput] = + singleOut.durationSeconds; } previousIndexInVideo = indexInVideo; } - output.data = maybePermuteHWC2CHW(streamIndex, output.data); - return output; + frameBatchOutput.data = + maybePermuteHWC2CHW(streamIndex, frameBatchOutput.data); + return frameBatchOutput; } VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedByTimestamps( @@ -1367,16 +1370,18 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesInRange( int64_t numOutputFrames = std::ceil((stop - start) / double(step)); const auto& videoStreamOptions = streamInfo.videoStreamOptions; - FrameBatchOutput output(numOutputFrames, videoStreamOptions, streamMetadata); + FrameBatchOutput frameBatchOutput( + numOutputFrames, videoStreamOptions, streamMetadata); for (int64_t i = start, f = 0; i < stop; i += step, ++f) { FrameOutput singleOut = - getFrameAtIndexInternal(streamIndex, i, output.data[f]); - output.ptsSeconds[f] = singleOut.ptsSeconds; - output.durationSeconds[f] = singleOut.durationSeconds; + getFrameAtIndexInternal(streamIndex, i, frameBatchOutput.data[f]); + frameBatchOutput.ptsSeconds[f] = singleOut.ptsSeconds; + frameBatchOutput.durationSeconds[f] = singleOut.durationSeconds; } - output.data = maybePermuteHWC2CHW(streamIndex, output.data); - return output; + frameBatchOutput.data = + maybePermuteHWC2CHW(streamIndex, frameBatchOutput.data); + return frameBatchOutput; } VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedByTimestampInRange( @@ -1414,9 +1419,9 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedByTimestampInRange( // values of the intervals will map to the same frame indices below. Hence, we // need this special case below. 
if (startSeconds == stopSeconds) { - FrameBatchOutput output(0, videoStreamOptions, streamMetadata); - output.data = maybePermuteHWC2CHW(streamIndex, output.data); - return output; + FrameBatchOutput frameBatchOutput(0, videoStreamOptions, streamMetadata); + frameBatchOutput.data = maybePermuteHWC2CHW(streamIndex, frameBatchOutput.data); + return frameBatchOutput; } double minSeconds = getMinSeconds(streamMetadata); @@ -1451,16 +1456,18 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedByTimestampInRange( secondsToIndexUpperBound(stopSeconds, streamInfo, streamMetadata); int64_t numFrames = stopFrameIndex - startFrameIndex; - FrameBatchOutput output(numFrames, videoStreamOptions, streamMetadata); + FrameBatchOutput frameBatchOutput( + numFrames, videoStreamOptions, streamMetadata); for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) { FrameOutput singleOut = - getFrameAtIndexInternal(streamIndex, i, output.data[f]); - output.ptsSeconds[f] = singleOut.ptsSeconds; - output.durationSeconds[f] = singleOut.durationSeconds; + getFrameAtIndexInternal(streamIndex, i, frameBatchOutput.data[f]); + frameBatchOutput.ptsSeconds[f] = singleOut.ptsSeconds; + frameBatchOutput.durationSeconds[f] = singleOut.durationSeconds; } - output.data = maybePermuteHWC2CHW(streamIndex, output.data); + frameBatchOutput.data = + maybePermuteHWC2CHW(streamIndex, frameBatchOutput.data); - return output; + return frameBatchOutput; } VideoDecoder::AVFrameWithStreamIndex From 3a2ab2deb5f5d74f163b4175044f5ba5231a8122 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 23 Jan 2025 11:10:55 +0000 Subject: [PATCH 45/56] rename getFrameOutputWithFilter --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 6 +++--- src/torchcodec/decoders/_core/VideoDecoder.h | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index d90eb792..6ad28cac 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -803,7 +803,7 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() { } } -VideoDecoder::AVFrameWithStreamIndex VideoDecoder::getFrameOutputWithFilter( +VideoDecoder::AVFrameWithStreamIndex VideoDecoder::getAVFrameUsingFilterFunction( std::function filterFunction) { if (activeStreamIndices_.size() == 0) { throw std::runtime_error("No active streams configured."); @@ -1079,7 +1079,7 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAtTimestampNoDemux( } setCursorPtsInSeconds(seconds); - AVFrameWithStreamIndex rawOutput = getFrameOutputWithFilter( + AVFrameWithStreamIndex rawOutput = getAVFrameUsingFilterFunction( [seconds, this](int frameStreamIndex, AVFrame* avFrame) { StreamInfo& streamInfo = streamInfos_[frameStreamIndex]; double frameStartTime = ptsToSeconds(avFrame->pts, streamInfo.timeBase); @@ -1473,7 +1473,7 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedByTimestampInRange( VideoDecoder::AVFrameWithStreamIndex VideoDecoder::getNextAVFrameWithStreamIndexNoDemux() { auto rawOutput = - getFrameOutputWithFilter([this](int frameStreamIndex, AVFrame* avFrame) { + getAVFrameUsingFilterFunction([this](int frameStreamIndex, AVFrame* avFrame) { StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex]; return avFrame->pts >= activeStreamInfo.discardFramesBeforePts; }); diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 3fb3e357..ce700c17 100644 --- 
a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -386,7 +386,7 @@ class VideoDecoder { const enum AVColorSpace colorspace); void maybeSeekToBeforeDesiredPts(); - AVFrameWithStreamIndex getFrameOutputWithFilter( + AVFrameWithStreamIndex getAVFrameUsingFilterFunction( std::function); AVFrameWithStreamIndex getNextAVFrameWithStreamIndexNoDemux(); // Once we create a decoder can update the metadata with the codec context. From 05da318111b772e14fbdafc920595b404a6e49f6 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 23 Jan 2025 11:14:58 +0000 Subject: [PATCH 46/56] Use frameOutput name --- .../decoders/_core/VideoDecoder.cpp | 44 +++++++++---------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 6ad28cac..62ef001d 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -922,33 +922,33 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput( VideoDecoder::AVFrameWithStreamIndex& rawOutput, std::optional preAllocatedOutputTensor) { // Convert the frame to tensor. - FrameOutput output; + FrameOutput frameOutput; int streamIndex = rawOutput.streamIndex; AVFrame* avFrame = rawOutput.avFrame.get(); - output.streamIndex = streamIndex; + frameOutput.streamIndex = streamIndex; auto& streamInfo = streamInfos_[streamIndex]; TORCH_CHECK(streamInfo.stream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO); - output.ptsSeconds = ptsToSeconds( + frameOutput.ptsSeconds = ptsToSeconds( avFrame->pts, formatContext_->streams[streamIndex]->time_base); - output.durationSeconds = ptsToSeconds( + frameOutput.durationSeconds = ptsToSeconds( getDuration(avFrame), formatContext_->streams[streamIndex]->time_base); // TODO: we should fold preAllocatedOutputTensor into AVFrameWithStreamIndex. if (streamInfo.videoStreamOptions.device.type() == torch::kCPU) { convertAVFrameToFrameOutputOnCPU( - rawOutput, output, preAllocatedOutputTensor); + rawOutput, frameOutput, preAllocatedOutputTensor); } else if (streamInfo.videoStreamOptions.device.type() == torch::kCUDA) { convertAVFrameToFrameOutputOnCuda( streamInfo.videoStreamOptions.device, streamInfo.videoStreamOptions, rawOutput, - output, + frameOutput, preAllocatedOutputTensor); } else { TORCH_CHECK( false, "Invalid device type: " + streamInfo.videoStreamOptions.device.str()); } - return output; + return frameOutput; } // Note [preAllocatedOutputTensor with swscale and filtergraph]: @@ -1099,9 +1099,9 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAtTimestampNoDemux( }); // Convert the frame to tensor. 
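For concreteness, the ptsSeconds and durationSeconds assignments above are
plain rational arithmetic. A small sketch with values taken from the test
expectations further down, assuming ptsToSeconds takes a (pts, AVRational)
pair as its call sites suggest:

    // The test stream has a time base of 1/30'000 and a frame step of 1'001.
    AVRational timeBase = {1, 30'000};
    double pts = ptsToSeconds(180'180, timeBase);     // == 6.006 seconds
    double duration = ptsToSeconds(1'001, timeBase);  // ~= 0.0334 s, i.e.
                                                      // 30'000/1'001 ~= 29.97 fps

This is why the tests compare against literals like 180'180. / 30'000 rather
than rounded decimals.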
- FrameOutput output = convertAVFrameToFrameOutput(rawOutput); - output.data = maybePermuteHWC2CHW(output.streamIndex, output.data); - return output; + FrameOutput frameOutput = convertAVFrameToFrameOutput(rawOutput); + frameOutput.data = maybePermuteHWC2CHW(frameOutput.streamIndex, frameOutput.data); + return frameOutput; } void VideoDecoder::validateUserProvidedStreamIndex(int streamIndex) { @@ -1139,9 +1139,9 @@ void VideoDecoder::validateFrameIndex( VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndex( int streamIndex, int64_t frameIndex) { - auto output = getFrameAtIndexInternal(streamIndex, frameIndex); - output.data = maybePermuteHWC2CHW(streamIndex, output.data); - return output; + auto frameOutput = getFrameAtIndexInternal(streamIndex, frameIndex); + frameOutput.data = maybePermuteHWC2CHW(streamIndex, frameOutput.data); + return frameOutput; } int64_t VideoDecoder::getPts( @@ -1302,11 +1302,11 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesAtIndices( frameBatchOutput.durationSeconds[indexInOutput] = frameBatchOutput.durationSeconds[previousIndexInOutput]; } else { - FrameOutput singleOut = getFrameAtIndexInternal( + FrameOutput frameOutput = getFrameAtIndexInternal( streamIndex, indexInVideo, frameBatchOutput.data[indexInOutput]); - frameBatchOutput.ptsSeconds[indexInOutput] = singleOut.ptsSeconds; + frameBatchOutput.ptsSeconds[indexInOutput] = frameOutput.ptsSeconds; frameBatchOutput.durationSeconds[indexInOutput] = - singleOut.durationSeconds; + frameOutput.durationSeconds; } previousIndexInVideo = indexInVideo; } @@ -1374,10 +1374,10 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesInRange( numOutputFrames, videoStreamOptions, streamMetadata); for (int64_t i = start, f = 0; i < stop; i += step, ++f) { - FrameOutput singleOut = + FrameOutput frameOutput = getFrameAtIndexInternal(streamIndex, i, frameBatchOutput.data[f]); - frameBatchOutput.ptsSeconds[f] = singleOut.ptsSeconds; - frameBatchOutput.durationSeconds[f] = singleOut.durationSeconds; + frameBatchOutput.ptsSeconds[f] = frameOutput.ptsSeconds; + frameBatchOutput.durationSeconds[f] = frameOutput.durationSeconds; } frameBatchOutput.data = maybePermuteHWC2CHW(streamIndex, frameBatchOutput.data); @@ -1459,10 +1459,10 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedByTimestampInRange( FrameBatchOutput frameBatchOutput( numFrames, videoStreamOptions, streamMetadata); for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) { - FrameOutput singleOut = + FrameOutput frameOutput = getFrameAtIndexInternal(streamIndex, i, frameBatchOutput.data[f]); - frameBatchOutput.ptsSeconds[f] = singleOut.ptsSeconds; - frameBatchOutput.durationSeconds[f] = singleOut.durationSeconds; + frameBatchOutput.ptsSeconds[f] = frameOutput.ptsSeconds; + frameBatchOutput.durationSeconds[f] = frameOutput.durationSeconds; } frameBatchOutput.data = maybePermuteHWC2CHW(streamIndex, frameBatchOutput.data); From a6a47f16ffdb8fa36422cbd41fb31a369190f910 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 23 Jan 2025 11:16:10 +0000 Subject: [PATCH 47/56] lint --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 62ef001d..754e79f5 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -803,7 +803,8 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() { } } 
-VideoDecoder::AVFrameWithStreamIndex VideoDecoder::getAVFrameUsingFilterFunction( +VideoDecoder::AVFrameWithStreamIndex +VideoDecoder::getAVFrameUsingFilterFunction( std::function filterFunction) { if (activeStreamIndices_.size() == 0) { throw std::runtime_error("No active streams configured."); @@ -1100,7 +1101,8 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAtTimestampNoDemux( // Convert the frame to tensor. FrameOutput frameOutput = convertAVFrameToFrameOutput(rawOutput); - frameOutput.data = maybePermuteHWC2CHW(frameOutput.streamIndex, frameOutput.data); + frameOutput.data = + maybePermuteHWC2CHW(frameOutput.streamIndex, frameOutput.data); return frameOutput; } @@ -1420,7 +1422,8 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedByTimestampInRange( // need this special case below. if (startSeconds == stopSeconds) { FrameBatchOutput frameBatchOutput(0, videoStreamOptions, streamMetadata); - frameBatchOutput.data = maybePermuteHWC2CHW(streamIndex, frameBatchOutput.data); + frameBatchOutput.data = + maybePermuteHWC2CHW(streamIndex, frameBatchOutput.data); return frameBatchOutput; } @@ -1472,8 +1475,8 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedByTimestampInRange( VideoDecoder::AVFrameWithStreamIndex VideoDecoder::getNextAVFrameWithStreamIndexNoDemux() { - auto rawOutput = - getAVFrameUsingFilterFunction([this](int frameStreamIndex, AVFrame* avFrame) { + auto rawOutput = getAVFrameUsingFilterFunction( + [this](int frameStreamIndex, AVFrame* avFrame) { StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex]; return avFrame->pts >= activeStreamInfo.discardFramesBeforePts; }); From 658b727fbbe64288cd8dd9992faf813994d655ef Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 23 Jan 2025 11:24:54 +0000 Subject: [PATCH 48/56] getNextAVFrameNoDemux --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 4 ++-- src/torchcodec/decoders/_core/VideoDecoder.h | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 754e79f5..5dad144b 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1474,7 +1474,7 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedByTimestampInRange( } VideoDecoder::AVFrameWithStreamIndex -VideoDecoder::getNextAVFrameWithStreamIndexNoDemux() { +VideoDecoder::getNextAVFrameNoDemux() { auto rawOutput = getAVFrameUsingFilterFunction( [this](int frameStreamIndex, AVFrame* avFrame) { StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex]; @@ -1491,7 +1491,7 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemux() { VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemuxInternal( std::optional preAllocatedOutputTensor) { - auto rawOutput = getNextAVFrameWithStreamIndexNoDemux(); + auto rawOutput = getNextAVFrameNoDemux(); return convertAVFrameToFrameOutput(rawOutput, preAllocatedOutputTensor); } diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index ce700c17..d25874db 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -164,6 +164,7 @@ class VideoDecoder { // Calling getNextFrameNoDemuxInternal() will return the first frame at // or after this position. 
void setCursorPtsInSeconds(double seconds); + // This structure ensures we always keep the streamIndex and AVFrame together // Note that AVFrame itself doesn't retain the streamIndex. struct AVFrameWithStreamIndex { @@ -388,7 +389,7 @@ class VideoDecoder { void maybeSeekToBeforeDesiredPts(); AVFrameWithStreamIndex getAVFrameUsingFilterFunction( std::function); - AVFrameWithStreamIndex getNextAVFrameWithStreamIndexNoDemux(); + AVFrameWithStreamIndex getNextAVFrameNoDemux(); // Once we create a decoder can update the metadata with the codec context. // For example, for video streams, we can add the height and width of the // decoded stream. From 62818a7fa758a9ee6a9a6f282d3c1c8fd6c598e2 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 23 Jan 2025 11:27:03 +0000 Subject: [PATCH 49/56] more stuff --- src/torchcodec/decoders/_core/CPUOnlyDevice.cpp | 2 +- src/torchcodec/decoders/_core/CudaDevice.cpp | 4 ++-- src/torchcodec/decoders/_core/DeviceInterface.h | 2 +- src/torchcodec/decoders/_core/VideoDecoder.h | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp index 3c17012d..24a5837d 100644 --- a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp +++ b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp @@ -18,7 +18,7 @@ void convertAVFrameToFrameOutputOnCuda( const torch::Device& device, [[maybe_unused]] const VideoDecoder::VideoStreamOptions& videoStreamOptions, [[maybe_unused]] VideoDecoder::AVFrameWithStreamIndex& rawOutput, - [[maybe_unused]] VideoDecoder::FrameOutput& output, + [[maybe_unused]] VideoDecoder::FrameOutput& frameOutput, [[maybe_unused]] std::optional preAllocatedOutputTensor) { throwUnsupportedDeviceError(device); } diff --git a/src/torchcodec/decoders/_core/CudaDevice.cpp b/src/torchcodec/decoders/_core/CudaDevice.cpp index 3756371c..206555e6 100644 --- a/src/torchcodec/decoders/_core/CudaDevice.cpp +++ b/src/torchcodec/decoders/_core/CudaDevice.cpp @@ -187,7 +187,7 @@ void convertAVFrameToFrameOutputOnCuda( const torch::Device& device, const VideoDecoder::VideoStreamOptions& videoStreamOptions, VideoDecoder::AVFrameWithStreamIndex& rawOutput, - VideoDecoder::FrameOutput& output, + VideoDecoder::FrameOutput& frameOutput, std::optional preAllocatedOutputTensor) { AVFrame* avFrame = rawOutput.avFrame.get(); @@ -199,7 +199,7 @@ void convertAVFrameToFrameOutputOnCuda( getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, *avFrame); int height = frameDims.height; int width = frameDims.width; - torch::Tensor& dst = output.frame; + torch::Tensor& dst = frameOutput.data; if (preAllocatedOutputTensor.has_value()) { dst = preAllocatedOutputTensor.value(); auto shape = dst.sizes(); diff --git a/src/torchcodec/decoders/_core/DeviceInterface.h b/src/torchcodec/decoders/_core/DeviceInterface.h index 21957194..e05ac5f5 100644 --- a/src/torchcodec/decoders/_core/DeviceInterface.h +++ b/src/torchcodec/decoders/_core/DeviceInterface.h @@ -33,7 +33,7 @@ void convertAVFrameToFrameOutputOnCuda( const torch::Device& device, const VideoDecoder::VideoStreamOptions& videoStreamOptions, VideoDecoder::AVFrameWithStreamIndex& rawOutput, - VideoDecoder::FrameOutput& output, + VideoDecoder::FrameOutput& frameOutput, std::optional preAllocatedOutputTensor = std::nullopt); void releaseContextOnCuda( diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index d25874db..7bc08a23 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ 
b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -409,7 +409,7 @@ class VideoDecoder { std::optional preAllocatedOutputTensor = std::nullopt); void convertAVFrameToFrameOutputOnCPU( AVFrameWithStreamIndex& rawOutput, - FrameOutput& output, + FrameOutput& frameOutput, std::optional preAllocatedOutputTensor = std::nullopt); FrameOutput getNextFrameNoDemuxInternal( From 24d003541d1697ccd0f9ce0dfcf0e505ca288f12 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 23 Jan 2025 11:30:02 +0000 Subject: [PATCH 50/56] use avFrameWithStreamIndex --- .../decoders/_core/CPUOnlyDevice.cpp | 3 +- src/torchcodec/decoders/_core/CudaDevice.cpp | 4 +- .../decoders/_core/DeviceInterface.h | 2 +- .../decoders/_core/VideoDecoder.cpp | 40 +++++++++---------- src/torchcodec/decoders/_core/VideoDecoder.h | 4 +- 5 files changed, 27 insertions(+), 26 deletions(-) diff --git a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp index 24a5837d..61c92674 100644 --- a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp +++ b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp @@ -17,7 +17,8 @@ namespace facebook::torchcodec { void convertAVFrameToFrameOutputOnCuda( const torch::Device& device, [[maybe_unused]] const VideoDecoder::VideoStreamOptions& videoStreamOptions, - [[maybe_unused]] VideoDecoder::AVFrameWithStreamIndex& rawOutput, + [[maybe_unused]] VideoDecoder::AVFrameWithStreamIndex& + avFrameWithStreamIndex, [[maybe_unused]] VideoDecoder::FrameOutput& frameOutput, [[maybe_unused]] std::optional preAllocatedOutputTensor) { throwUnsupportedDeviceError(device); diff --git a/src/torchcodec/decoders/_core/CudaDevice.cpp b/src/torchcodec/decoders/_core/CudaDevice.cpp index 206555e6..229c5495 100644 --- a/src/torchcodec/decoders/_core/CudaDevice.cpp +++ b/src/torchcodec/decoders/_core/CudaDevice.cpp @@ -186,10 +186,10 @@ void initializeContextOnCuda( void convertAVFrameToFrameOutputOnCuda( const torch::Device& device, const VideoDecoder::VideoStreamOptions& videoStreamOptions, - VideoDecoder::AVFrameWithStreamIndex& rawOutput, + VideoDecoder::AVFrameWithStreamIndex& avFrameWithStreamIndex, VideoDecoder::FrameOutput& frameOutput, std::optional preAllocatedOutputTensor) { - AVFrame* avFrame = rawOutput.avFrame.get(); + AVFrame* avFrame = avFrameWithStreamIndex.avFrame.get(); TORCH_CHECK( avFrame->format == AV_PIX_FMT_CUDA, diff --git a/src/torchcodec/decoders/_core/DeviceInterface.h b/src/torchcodec/decoders/_core/DeviceInterface.h index e05ac5f5..bf26ed72 100644 --- a/src/torchcodec/decoders/_core/DeviceInterface.h +++ b/src/torchcodec/decoders/_core/DeviceInterface.h @@ -32,7 +32,7 @@ void initializeContextOnCuda( void convertAVFrameToFrameOutputOnCuda( const torch::Device& device, const VideoDecoder::VideoStreamOptions& videoStreamOptions, - VideoDecoder::AVFrameWithStreamIndex& rawOutput, + VideoDecoder::AVFrameWithStreamIndex& avFrameWithStreamIndex, VideoDecoder::FrameOutput& frameOutput, std::optional preAllocatedOutputTensor = std::nullopt); diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 5dad144b..89dcd77b 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -913,19 +913,19 @@ VideoDecoder::getAVFrameUsingFilterFunction( StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex]; activeStreamInfo.currentPts = avFrame->pts; activeStreamInfo.currentDuration = getDuration(avFrame); - AVFrameWithStreamIndex rawOutput; - rawOutput.streamIndex = 
frameStreamIndex; - rawOutput.avFrame = std::move(avFrame); - return rawOutput; + AVFrameWithStreamIndex avFrameWithStreamIndex; + avFrameWithStreamIndex.streamIndex = frameStreamIndex; + avFrameWithStreamIndex.avFrame = std::move(avFrame); + return avFrameWithStreamIndex; } VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput( - VideoDecoder::AVFrameWithStreamIndex& rawOutput, + VideoDecoder::AVFrameWithStreamIndex& avFrameWithStreamIndex, std::optional preAllocatedOutputTensor) { // Convert the frame to tensor. FrameOutput frameOutput; - int streamIndex = rawOutput.streamIndex; - AVFrame* avFrame = rawOutput.avFrame.get(); + int streamIndex = avFrameWithStreamIndex.streamIndex; + AVFrame* avFrame = avFrameWithStreamIndex.avFrame.get(); frameOutput.streamIndex = streamIndex; auto& streamInfo = streamInfos_[streamIndex]; TORCH_CHECK(streamInfo.stream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO); @@ -936,12 +936,12 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput( // TODO: we should fold preAllocatedOutputTensor into AVFrameWithStreamIndex. if (streamInfo.videoStreamOptions.device.type() == torch::kCPU) { convertAVFrameToFrameOutputOnCPU( - rawOutput, frameOutput, preAllocatedOutputTensor); + avFrameWithStreamIndex, frameOutput, preAllocatedOutputTensor); } else if (streamInfo.videoStreamOptions.device.type() == torch::kCUDA) { convertAVFrameToFrameOutputOnCuda( streamInfo.videoStreamOptions.device, streamInfo.videoStreamOptions, - rawOutput, + avFrameWithStreamIndex, frameOutput, preAllocatedOutputTensor); } else { @@ -962,11 +962,11 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput( // Dimension order of the preAllocatedOutputTensor must be HWC, regardless of // `dimension_order` parameter. It's up to callers to re-shape it if needed. void VideoDecoder::convertAVFrameToFrameOutputOnCPU( - VideoDecoder::AVFrameWithStreamIndex& rawOutput, + VideoDecoder::AVFrameWithStreamIndex& avFrameWithStreamIndex, FrameOutput& output, std::optional preAllocatedOutputTensor) { - int streamIndex = rawOutput.streamIndex; - AVFrame* avFrame = rawOutput.avFrame.get(); + int streamIndex = avFrameWithStreamIndex.streamIndex; + AVFrame* avFrame = avFrameWithStreamIndex.avFrame.get(); auto& streamInfo = streamInfos_[streamIndex]; auto frameDims = getHeightAndWidthFromOptionsOrAVFrame( @@ -1080,7 +1080,7 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAtTimestampNoDemux( } setCursorPtsInSeconds(seconds); - AVFrameWithStreamIndex rawOutput = getAVFrameUsingFilterFunction( + AVFrameWithStreamIndex avFrameWithStreamIndex = getAVFrameUsingFilterFunction( [seconds, this](int frameStreamIndex, AVFrame* avFrame) { StreamInfo& streamInfo = streamInfos_[frameStreamIndex]; double frameStartTime = ptsToSeconds(avFrame->pts, streamInfo.timeBase); @@ -1100,7 +1100,7 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAtTimestampNoDemux( }); // Convert the frame to tensor. 
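The AVFrameWithStreamIndex struct exists because an AVFrame does not record
which stream it came from. A minimal sketch of the pairing this commit makes
explicit everywhere, condensed from the decoder-internal call sites above
(the predicate body is the discard check used by the next-frame path):

    // Pull the next decoded AVFrame together with its stream index, keeping
    // only frames at or past the per-stream discard threshold.
    VideoDecoder::AVFrameWithStreamIndex next = getAVFrameUsingFilterFunction(
        [this](int frameStreamIndex, AVFrame* avFrame) {
          StreamInfo& streamInfo = streamInfos_[frameStreamIndex];
          return avFrame->pts >= streamInfo.discardFramesBeforePts;
        });
    // Both pieces now travel together through conversion:
    int streamIndex = next.streamIndex;
    AVFrame* avFrame = next.avFrame.get();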
- FrameOutput frameOutput = convertAVFrameToFrameOutput(rawOutput); + FrameOutput frameOutput = convertAVFrameToFrameOutput(avFrameWithStreamIndex); frameOutput.data = maybePermuteHWC2CHW(frameOutput.streamIndex, frameOutput.data); return frameOutput; @@ -1473,14 +1473,13 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedByTimestampInRange( return frameBatchOutput; } -VideoDecoder::AVFrameWithStreamIndex -VideoDecoder::getNextAVFrameNoDemux() { - auto rawOutput = getAVFrameUsingFilterFunction( +VideoDecoder::AVFrameWithStreamIndex VideoDecoder::getNextAVFrameNoDemux() { + auto avFrameWithStreamIndex = getAVFrameUsingFilterFunction( [this](int frameStreamIndex, AVFrame* avFrame) { StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex]; return avFrame->pts >= activeStreamInfo.discardFramesBeforePts; }); - return rawOutput; + return avFrameWithStreamIndex; } VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemux() { @@ -1491,8 +1490,9 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemux() { VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemuxInternal( std::optional preAllocatedOutputTensor) { - auto rawOutput = getNextAVFrameNoDemux(); - return convertAVFrameToFrameOutput(rawOutput, preAllocatedOutputTensor); + auto avFrameWithStreamIndex = getNextAVFrameNoDemux(); + return convertAVFrameToFrameOutput( + avFrameWithStreamIndex, preAllocatedOutputTensor); } void VideoDecoder::setCursorPtsInSeconds(double seconds) { diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 7bc08a23..1925bf24 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -405,10 +405,10 @@ class VideoDecoder { const AVFrame* avFrame, torch::Tensor& outputTensor); FrameOutput convertAVFrameToFrameOutput( - AVFrameWithStreamIndex& rawOutput, + AVFrameWithStreamIndex& avFrameWithStreamIndex, std::optional preAllocatedOutputTensor = std::nullopt); void convertAVFrameToFrameOutputOnCPU( - AVFrameWithStreamIndex& rawOutput, + AVFrameWithStreamIndex& avFrameWithStreamIndex, FrameOutput& frameOutput, std::optional preAllocatedOutputTensor = std::nullopt); From a1805d123fc1973f01b86aed6af3824c8990e83c Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 23 Jan 2025 11:37:38 +0000 Subject: [PATCH 51/56] Cpp tests --- test/decoders/VideoDecoderTest.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/decoders/VideoDecoderTest.cpp b/test/decoders/VideoDecoderTest.cpp index 730732fa..8e1a8388 100644 --- a/test/decoders/VideoDecoderTest.cpp +++ b/test/decoders/VideoDecoderTest.cpp @@ -150,7 +150,7 @@ TEST(VideoDecoderTest, RespectsWidthAndHeightFromOptions) { videoStreamOptions.width = 100; videoStreamOptions.height = 120; decoder->addVideoStreamDecoder(-1, videoStreamOptions); - torch::Tensor tensor = decoder->getNextFrameNoDemux().frame; + torch::Tensor tensor = decoder->getNextFrameNoDemux().data; EXPECT_EQ(tensor.sizes(), std::vector({3, 120, 100})); } @@ -161,7 +161,7 @@ TEST(VideoDecoderTest, RespectsOutputTensorDimensionOrderFromOptions) { VideoDecoder::VideoStreamOptions videoStreamOptions; videoStreamOptions.dimensionOrder = "NHWC"; decoder->addVideoStreamDecoder(-1, videoStreamOptions); - torch::Tensor tensor = decoder->getNextFrameNoDemux().frame; + torch::Tensor tensor = decoder->getNextFrameNoDemux().data; EXPECT_EQ(tensor.sizes(), std::vector({270, 480, 3})); } @@ -235,7 +235,7 @@ TEST_P(VideoDecoderTest, DecodesFramesInABatchInNHWC) { 
VideoDecoder::VideoStreamOptions("dimension_order=NHWC")); // Frame with index 180 corresponds to timestamp 6.006. auto output = ourDecoder->getFramesAtIndices(bestVideoStreamIndex, {0, 180}); - auto tensor = output.; + auto tensor = output.data; EXPECT_EQ(tensor.sizes(), std::vector({2, 270, 480, 3})); torch::Tensor tensor0FromFFMPEG = From eb1773a91f798be12b5956a703ea5923e9d963e1 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 23 Jan 2025 11:47:03 +0000 Subject: [PATCH 52/56] more --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 89dcd77b..7bc15423 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -963,7 +963,7 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput( // `dimension_order` parameter. It's up to callers to re-shape it if needed. void VideoDecoder::convertAVFrameToFrameOutputOnCPU( VideoDecoder::AVFrameWithStreamIndex& avFrameWithStreamIndex, - FrameOutput& output, + FrameOutput& frameOutput, std::optional preAllocatedOutputTensor) { int streamIndex = avFrameWithStreamIndex.streamIndex; AVFrame* avFrame = avFrameWithStreamIndex.avFrame.get(); @@ -1024,7 +1024,7 @@ void VideoDecoder::convertAVFrameToFrameOutputOnCPU( " != ", expectedOutputHeight); - output.data = outputTensor; + frameOutput.data = outputTensor; } else if ( streamInfo.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) { @@ -1052,9 +1052,9 @@ void VideoDecoder::convertAVFrameToFrameOutputOnCPU( // We have already validated that preAllocatedOutputTensor and // outputTensor have the same shape. preAllocatedOutputTensor.value().copy_(outputTensor); - output.data = preAllocatedOutputTensor.value(); + frameOutput.data = preAllocatedOutputTensor.value(); } else { - output.data = outputTensor; + frameOutput.data = outputTensor; } } else { throw std::runtime_error( From 21b8ff24f7c94dd3724ff0490bf374308a3e3c58 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 23 Jan 2025 12:58:27 +0000 Subject: [PATCH 53/56] Remove getNextAVFrameNoDemux --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 15 +++++---------- src/torchcodec/decoders/_core/VideoDecoder.h | 2 +- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 7bc15423..43c7d19f 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1473,15 +1473,6 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedByTimestampInRange( return frameBatchOutput; } -VideoDecoder::AVFrameWithStreamIndex VideoDecoder::getNextAVFrameNoDemux() { - auto avFrameWithStreamIndex = getAVFrameUsingFilterFunction( - [this](int frameStreamIndex, AVFrame* avFrame) { - StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex]; - return avFrame->pts >= activeStreamInfo.discardFramesBeforePts; - }); - return avFrameWithStreamIndex; -} - VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemux() { auto output = getNextFrameNoDemuxInternal(); output.data = maybePermuteHWC2CHW(output.streamIndex, output.data); @@ -1490,7 +1481,11 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemux() { VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemuxInternal( std::optional preAllocatedOutputTensor) { - auto avFrameWithStreamIndex = 
getNextAVFrameNoDemux(); + auto avFrameWithStreamIndex = getAVFrameUsingFilterFunction( + [this](int frameStreamIndex, AVFrame* avFrame) { + StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex]; + return avFrame->pts >= activeStreamInfo.discardFramesBeforePts; + }); return convertAVFrameToFrameOutput( avFrameWithStreamIndex, preAllocatedOutputTensor); } diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 1925bf24..08ecd14d 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -389,7 +389,7 @@ class VideoDecoder { void maybeSeekToBeforeDesiredPts(); AVFrameWithStreamIndex getAVFrameUsingFilterFunction( std::function); - AVFrameWithStreamIndex getNextAVFrameNoDemux(); + // Once we create a decoder can update the metadata with the codec context. // For example, for video streams, we can add the height and width of the // decoded stream. From dac446310b63c5a4767fb5209f258d2083fe377f Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 23 Jan 2025 13:10:42 +0000 Subject: [PATCH 54/56] Rename maybeDesiredPts_ into desiredPts_ --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 12 ++++++------ src/torchcodec/decoders/_core/VideoDecoder.h | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 4be92abc..b120cd3d 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -746,7 +746,7 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() { for (int streamIndex : activeStreamIndices_) { StreamInfo& streamInfo = streamInfos_[streamIndex]; // clang-format off: clang format clashes - streamInfo.discardFramesBeforePts = secondsToClosestPts(*maybeDesiredPts_, streamInfo.timeBase); + streamInfo.discardFramesBeforePts = secondsToClosestPts(*desiredPts_, streamInfo.timeBase); // clang-format on } @@ -756,7 +756,7 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() { bool mustSeek = false; for (int streamIndex : activeStreamIndices_) { StreamInfo& streamInfo = streamInfos_[streamIndex]; - int64_t desiredPtsForStream = *maybeDesiredPts_ * streamInfo.timeBase.den; + int64_t desiredPtsForStream = *desiredPts_ * streamInfo.timeBase.den; if (!canWeAvoidSeekingForStream( streamInfo, streamInfo.currentPts, desiredPtsForStream)) { mustSeek = true; @@ -770,7 +770,7 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() { int firstActiveStreamIndex = *activeStreamIndices_.begin(); const auto& firstStreamInfo = streamInfos_[firstActiveStreamIndex]; int64_t desiredPts = - secondsToClosestPts(*maybeDesiredPts_, firstStreamInfo.timeBase); + secondsToClosestPts(*desiredPts_, firstStreamInfo.timeBase); // For some encodings like H265, FFMPEG sometimes seeks past the point we // set as the max_ts. So we use our own index to give it the exact pts of @@ -809,9 +809,9 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getDecodedOutputWithFilter( throw std::runtime_error("No active streams configured."); } resetDecodeStats(); - if (maybeDesiredPts_.has_value()) { + if (desiredPts_.has_value()) { maybeSeekToBeforeDesiredPts(); - maybeDesiredPts_ = std::nullopt; + desiredPts_ = std::nullopt; } // Need to get the next frame or error from PopFrame. 
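The rename makes the lazy-seek protocol easier to read: setCursorPtsInSeconds
only records the target, and the next decode call consumes it. A condensed
sketch of that hand-off, assembled from the hunks above:

    // State: the pending seek target, if any.
    std::optional<double> desiredPts_;

    void VideoDecoder::setCursorPtsInSeconds(double seconds) {
      desiredPts_ = seconds;  // no I/O happens here
    }

    // At the top of the decode path: honor a pending seek exactly once.
    if (desiredPts_.has_value()) {
      maybeSeekToBeforeDesiredPts();  // may skip the seek if already close
      desiredPts_ = std::nullopt;
    }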
From dac446310b63c5a4767fb5209f258d2083fe377f Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Thu, 23 Jan 2025 13:10:42 +0000
Subject: [PATCH 54/56] Rename maybeDesiredPts_ into desiredPts_

---
 src/torchcodec/decoders/_core/VideoDecoder.cpp | 12 ++++++------
 src/torchcodec/decoders/_core/VideoDecoder.h   |  2 +-
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp
index 4be92abc..b120cd3d 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -746,7 +746,7 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
   for (int streamIndex : activeStreamIndices_) {
     StreamInfo& streamInfo = streamInfos_[streamIndex];
     // clang-format off: clang format clashes
-    streamInfo.discardFramesBeforePts = secondsToClosestPts(*maybeDesiredPts_, streamInfo.timeBase);
+    streamInfo.discardFramesBeforePts = secondsToClosestPts(*desiredPts_, streamInfo.timeBase);
     // clang-format on
   }
@@ -756,7 +756,7 @@
   bool mustSeek = false;
   for (int streamIndex : activeStreamIndices_) {
     StreamInfo& streamInfo = streamInfos_[streamIndex];
-    int64_t desiredPtsForStream = *maybeDesiredPts_ * streamInfo.timeBase.den;
+    int64_t desiredPtsForStream = *desiredPts_ * streamInfo.timeBase.den;
     if (!canWeAvoidSeekingForStream(
             streamInfo, streamInfo.currentPts, desiredPtsForStream)) {
       mustSeek = true;
@@ -770,7 +770,7 @@
   int firstActiveStreamIndex = *activeStreamIndices_.begin();
   const auto& firstStreamInfo = streamInfos_[firstActiveStreamIndex];
   int64_t desiredPts =
-      secondsToClosestPts(*maybeDesiredPts_, firstStreamInfo.timeBase);
+      secondsToClosestPts(*desiredPts_, firstStreamInfo.timeBase);

   // For some encodings like H265, FFMPEG sometimes seeks past the point we
   // set as the max_ts. So we use our own index to give it the exact pts of
@@ -809,9 +809,9 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getDecodedOutputWithFilter(
     throw std::runtime_error("No active streams configured.");
   }
   resetDecodeStats();
-  if (maybeDesiredPts_.has_value()) {
+  if (desiredPts_.has_value()) {
     maybeSeekToBeforeDesiredPts();
-    maybeDesiredPts_ = std::nullopt;
+    desiredPts_ = std::nullopt;
   }
   // Need to get the next frame or error from PopFrame.
   UniqueAVFrame avFrame(av_frame_alloc());
@@ -1487,7 +1487,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getNextFrameNoDemuxInternal(
 }

 void VideoDecoder::setCursorPtsInSeconds(double seconds) {
-  maybeDesiredPts_ = seconds;
+  desiredPts_ = seconds;
 }

 VideoDecoder::DecodeStats VideoDecoder::getDecodeStats() const {
diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h
index c690c83c..d660cdd2 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.h
+++ b/src/torchcodec/decoders/_core/VideoDecoder.h
@@ -420,7 +420,7 @@ class VideoDecoder {
   std::set<int> activeStreamIndices_;
   // Set when the user wants to seek and stores the desired pts that the user
   // wants to seek to.
-  std::optional<double> maybeDesiredPts_;
+  std::optional<double> desiredPts_;

   // Stores various internal decoding stats.
   DecodeStats decodeStats_;
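
As the hunks above show, desiredPts_ stores the seek target in seconds, and each stream converts it into ticks of its own time base. A rough sketch of the two conversions follows, assuming AVRational-style num/den semantics; TimeBase is a simplified stand-in, and torchcodec's actual secondsToClosestPts and ptsToSeconds helpers may differ in rounding details.

#include <cmath>
#include <cstdint>
#include <cstdio>

// AVRational-style time base: one pts tick lasts num/den seconds.
struct TimeBase {
  int num;
  int den;
};

double ptsToSeconds(int64_t pts, TimeBase tb) {
  return static_cast<double>(pts) * tb.num / tb.den;
}

int64_t secondsToClosestPts(double seconds, TimeBase tb) {
  return static_cast<int64_t>(std::round(seconds * tb.den / tb.num));
}

int main() {
  TimeBase tb{1, 12800}; // a time base commonly seen in H.264 streams
  // The seek target lives in seconds; each stream converts it to its ticks,
  // like the secondsToClosestPts(*desiredPts_, streamInfo.timeBase) calls.
  int64_t pts = secondsToClosestPts(6.006, tb);
  std::printf("6.006s -> pts %lld -> %fs\n",
              static_cast<long long>(pts), ptsToSeconds(pts, tb));
  return 0;
}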
From a17027c84e59983b2c56268f1f14d5a43a3b21a9 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Thu, 23 Jan 2025 13:48:28 +0000
Subject: [PATCH 55/56] this should work

---
 .../decoders/_core/VideoDecoder.cpp | 22 +------------------
 1 file changed, 1 insertion(+), 21 deletions(-)

diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp
index b1dffbb2..946dcbc1 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -1080,27 +1080,7 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAtTimestampNoDemux(
   }

   setCursorPtsInSeconds(seconds);
-  AVFrameWithStreamIndex avFrameWithStreamIndex = getAVFrameUsingFilterFunction(
-      [seconds, this](int frameStreamIndex, AVFrame* avFrame) {
-        StreamInfo& streamInfo = streamInfos_[frameStreamIndex];
-        double frameStartTime = ptsToSeconds(avFrame->pts, streamInfo.timeBase);
-        double frameEndTime = ptsToSeconds(
-            avFrame->pts + getDuration(avFrame), streamInfo.timeBase);
-        if (frameStartTime > seconds) {
-          // FFMPEG seeked past the frame we are looking for even though we
-          // set max_ts to be our needed timestamp in avformat_seek_file()
-          // in maybeSeekToBeforeDesiredPts().
-          // This could be a bug in FFMPEG: https://trac.ffmpeg.org/ticket/11137
-          // In this case we return the very next frame instead of throwing an
-          // exception.
-          // TODO: Maybe log to stderr for Debug builds?
-          return true;
-        }
-        return seconds >= frameStartTime && seconds < frameEndTime;
-      });
-
-  // Convert the frame to tensor.
-  FrameOutput frameOutput = convertAVFrameToFrameOutput(avFrameWithStreamIndex);
+  FrameOutput frameOutput = getNextFrameNoDemuxInternal();
   frameOutput.data = maybePermuteHWC2CHW(frameOutput.streamIndex, frameOutput.data);
   return frameOutput;

From 86c6ffd489043987dcc883ba714cda54b6ffc286 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Thu, 23 Jan 2025 21:08:42 +0000
Subject: [PATCH 56/56] Try

---
 .../decoders/_core/VideoDecoder.cpp | 22 +------------------
 1 file changed, 1 insertion(+), 21 deletions(-)

diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp
index 59936f9e..983bb329 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -1079,27 +1079,7 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAtNoDemux(
   }

   setCursorPtsInSeconds(seconds);
-  AVFrameStream avFrameStream = getAVFrameUsingFilterFunction(
-      [seconds, this](int frameStreamIndex, AVFrame* avFrame) {
-        StreamInfo& streamInfo = streamInfos_[frameStreamIndex];
-        double frameStartTime = ptsToSeconds(avFrame->pts, streamInfo.timeBase);
-        double frameEndTime = ptsToSeconds(
-            avFrame->pts + getDuration(avFrame), streamInfo.timeBase);
-        if (frameStartTime > seconds) {
-          // FFMPEG seeked past the frame we are looking for even though we
-          // set max_ts to be our needed timestamp in avformat_seek_file()
-          // in maybeSeekToBeforeDesiredPts().
-          // This could be a bug in FFMPEG: https://trac.ffmpeg.org/ticket/11137
-          // In this case we return the very next frame instead of throwing an
-          // exception.
-          // TODO: Maybe log to stderr for Debug builds?
-          return true;
-        }
-        return seconds >= frameStartTime && seconds < frameEndTime;
-      });
-
-  // Convert the frame to tensor.
-  FrameOutput frameOutput = convertAVFrameToFrameOutput(avFrameStream);
+  FrameOutput frameOutput = getNextFrameNoDemuxInternal();
   frameOutput.data = maybePermuteHWC2CHW(frameOutput.streamIndex, frameOutput.data);
   return frameOutput;
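
Both of the last two patches drop the lambda that matched "the frame playing at time t" with a half-open interval test, in favor of the simpler discardFramesBeforePts filter set by setCursorPtsInSeconds. For reference, here is a self-contained sketch of that interval test; Interval and framePlayedAt are illustrative stand-ins, and the real code derives the start and end times from avFrame->pts, getDuration(avFrame), and the stream time base. A frame with start time s and duration d is played at t iff s <= t < s + d.

#include <cstddef>
#include <vector>

// A frame covers [start, end), where end = start + duration in seconds.
struct Interval {
  double start;
  double end; // exclusive
};

// Returns the index of the frame playing at `seconds`, or -1 if none.
int framePlayedAt(const std::vector<Interval>& frames, double seconds) {
  for (std::size_t i = 0; i < frames.size(); ++i) {
    // Same test as the removed lambda:
    // seconds >= frameStartTime && seconds < frameEndTime
    if (seconds >= frames[i].start && seconds < frames[i].end) {
      return static_cast<int>(i);
    }
  }
  return -1;
}

int main() {
  // A 30 fps stream: frame i covers [i/30, (i+1)/30) seconds.
  std::vector<Interval> frames;
  for (int i = 0; i < 10; ++i) {
    frames.push_back({i / 30.0, (i + 1) / 30.0});
  }
  // 0.11s falls inside frame 3's interval [0.1, 0.1333...).
  return framePlayedAt(frames, 0.11) == 3 ? 0 : 1;
}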