diff --git a/src/torchcodec/_frame.py b/src/torchcodec/_frame.py
index a8fc7f5b..5c24358e 100644
--- a/src/torchcodec/_frame.py
+++ b/src/torchcodec/_frame.py
@@ -41,8 +41,8 @@ class Frame(Iterable):
     def __post_init__(self):
         # This is called after __init__() when a Frame is created. We can run
         # input validation checks here.
-        if not self.data.ndim == 3:
-            raise ValueError(f"data must be 3-dimensional, got {self.data.shape = }")
+        # if not self.data.ndim == 3:
+        #     raise ValueError(f"data must be 3-dimensional, got {self.data.shape = }")
         self.pts_seconds = float(self.pts_seconds)
         self.duration_seconds = float(self.duration_seconds)
 
diff --git a/src/torchcodec/decoders/_core/CMakeLists.txt b/src/torchcodec/decoders/_core/CMakeLists.txt
index 7c9a7487..6a3298d9 100644
--- a/src/torchcodec/decoders/_core/CMakeLists.txt
+++ b/src/torchcodec/decoders/_core/CMakeLists.txt
@@ -4,7 +4,7 @@ set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 
 find_package(Torch REQUIRED)
-set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic -Werror ${TORCH_CXX_FLAGS}")
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic ${TORCH_CXX_FLAGS}")
 find_package(Python3 ${PYTHON_VERSION} EXACT COMPONENTS Development)
 
 function(make_torchcodec_library library_name ffmpeg_target)
diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp
index f0379c6a..dfe9e2a6 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -412,6 +412,63 @@ VideoDecoder::VideoStreamOptions::VideoStreamOptions(
   }
 }
 
+void print_codecContext(AVCodecContext* cc) {
+  printf("AVCodecContext details:\n");
+  printf("Codec ID: %d\n", cc->codec_id);
+  printf("Codec Type: %d\n", cc->codec_type);
+  printf("Codec Name: %s\n", cc->codec ? cc->codec->name : "unknown");
+  printf("Bit Rate: %ld\n", cc->bit_rate);
+  printf("Time Base: %d/%d\n", cc->time_base.num, cc->time_base.den);
+  printf("GOP Size: %d\n", cc->gop_size);
+  printf("Max B-Frames: %d\n", cc->max_b_frames);
+  printf("bit_rate: %ld\n", cc->bit_rate);
+  printf("bit_rate_tolerance: %d\n", cc->bit_rate_tolerance);
+  printf("global_quality: %d\n", cc->global_quality);
+  printf("compression_level: %d\n", cc->compression_level);
+  if (cc->codec_type == AVMEDIA_TYPE_VIDEO) {
+    printf("Width: %d\n", cc->width);
+    printf("Height: %d\n", cc->height);
+    printf("Pixel Format: %s\n", av_get_pix_fmt_name(cc->pix_fmt));
+    printf("Frame Rate: %d/%d\n", cc->framerate.num, cc->framerate.den);
+  } else if (cc->codec_type == AVMEDIA_TYPE_AUDIO) {
+    printf("Sample Rate: %d\n", cc->sample_rate);
+    printf("Channels: %d\n", cc->channels);
+    printf("Channel Layout: %ld\n", cc->channel_layout);
+    printf("Sample Format: %s\n", av_get_sample_fmt_name(cc->sample_fmt));
+  }
+  printf("Profile: %d\n", cc->profile);
+  printf("Level: %d\n", cc->level);
+  printf("Flags: %d\n", cc->flags);
+  printf("Thread Count: %d\n", cc->thread_count);
+  // Additional attributes
+  printf("Skip Frame: %d\n", cc->skip_frame);
+  printf("Skip IDCT: %d\n", cc->skip_idct);
+  printf("Skip Loop Filter: %d\n", cc->skip_loop_filter);
+  printf("Error Recognition: %d\n", cc->err_recognition);
+  printf("Error Concealment: %d\n", cc->error_concealment);
+  printf("HW Device Context: %p\n", cc->hw_device_ctx);
+  printf("HW Accel: %p\n", cc->hwaccel);
+  printf("Pkt Timebase: %d/%d\n", cc->pkt_timebase.num, cc->pkt_timebase.den);
+  printf("Delay: %d\n", cc->delay);
+  printf("Extradata Size: %d\n", cc->extradata_size);
+  if (cc->extradata && cc->extradata_size > 0) {
+    printf("Extradata: ");
+    for (int i = 0; i < cc->extradata_size; i++) {
+      printf("%02X ", cc->extradata[i]);
+    }
+    printf("\n");
+  }
+  printf("RC Buffer Size: %d\n", cc->rc_buffer_size);
+  printf("RC Max Rate: %ld\n", cc->rc_max_rate);
+  printf("RC Min Rate: %ld\n", cc->rc_min_rate);
+  printf("Thread Type: %d\n", cc->thread_type);
+  printf("Ticks Per Frame: %d\n", cc->ticks_per_frame);
+  printf(
+      "Subtitle Char Encoding: %s\n",
+      cc->sub_charenc ? cc->sub_charenc : "N/A");
+  printf("\n");
+}
+
 void VideoDecoder::addVideoStreamDecoder(
     int preferredStreamIndex,
     const VideoStreamOptions& videoStreamOptions) {
@@ -421,13 +478,16 @@ void VideoDecoder::addVideoStreamDecoder(
 
   TORCH_CHECK(formatContext_.get() != nullptr);
   AVCodecOnlyUseForCallingAVFindBestStream avCodec = nullptr;
-  int streamIndex = av_find_best_stream(
-      formatContext_.get(),
-      AVMEDIA_TYPE_VIDEO,
-      preferredStreamIndex,
-      -1,
-      &avCodec,
-      0);
+//  int streamIndex = av_find_best_stream(
+//      formatContext_.get(),
+//      AVMEDIA_TYPE_AUDIO,
+//      preferredStreamIndex,
+//      -1,
+//      &avCodec,
+//      0);
+  int streamIndex = preferredStreamIndex;
+  avCodec = avcodec_find_decoder(formatContext_->streams[streamIndex]->codecpar->codec_id);
+
   if (streamIndex < 0) {
     throw std::invalid_argument("No valid stream found in input file.");
   }
@@ -438,7 +498,7 @@ void VideoDecoder::addVideoStreamDecoder(
   streamInfo.timeBase = formatContext_->streams[streamIndex]->time_base;
   streamInfo.stream = formatContext_->streams[streamIndex];
 
-  if (streamInfo.stream->codecpar->codec_type != AVMEDIA_TYPE_VIDEO) {
+  if (streamInfo.stream->codecpar->codec_type != AVMEDIA_TYPE_AUDIO) {
     throw std::invalid_argument(
         "Stream with index " + std::to_string(streamIndex) +
         " is not a video stream.");
@@ -462,7 +522,7 @@ void VideoDecoder::addVideoStreamDecoder(
 
   AVCodecContext* codecContext = avcodec_alloc_context3(avCodec);
   TORCH_CHECK(codecContext != nullptr);
-  codecContext->thread_count = videoStreamOptions.ffmpegThreadCount.value_or(0);
+//  codecContext->thread_count = videoStreamOptions.ffmpegThreadCount.value_or(0);
   streamInfo.codecContext.reset(codecContext);
 
   int retVal = avcodec_parameters_to_context(
@@ -478,12 +538,21 @@ void VideoDecoder::addVideoStreamDecoder(
         false, "Invalid device type: " + videoStreamOptions.device.str());
   }
 
+  if (!streamInfo.codecContext->channel_layout) {
+    streamInfo.codecContext->channel_layout =
+        av_get_default_channel_layout(streamInfo.codecContext->channels);
+  }
+
+  AVDictionary* opt = nullptr;
+  av_dict_set(&opt, "threads", "1", 0);
   retVal = avcodec_open2(streamInfo.codecContext.get(), avCodec, nullptr);
   if (retVal < AVSUCCESS) {
     throw std::invalid_argument(getFFMPEGErrorStringFromErrorCode(retVal));
   }
 
-  codecContext->time_base = streamInfo.stream->time_base;
+//  codecContext->time_base = streamInfo.stream->time_base;
+//  AVRational tb{0, 1};
+//  codecContext->time_base = tb;
   activeStreamIndex_ = streamIndex;
   updateMetadataWithCodecContext(streamInfo.streamIndex, codecContext);
   streamInfo.videoStreamOptions = videoStreamOptions;
@@ -495,6 +564,8 @@ void VideoDecoder::addVideoStreamDecoder(
   for (unsigned int i = 0; i < formatContext_->nb_streams; ++i) {
     if (i != static_cast<unsigned int>(activeStreamIndex_)) {
       formatContext_->streams[i]->discard = AVDISCARD_ALL;
+    } else {
+      formatContext_->streams[i]->discard = AVDISCARD_DEFAULT;
     }
   }
 
@@ -516,6 +587,8 @@ void VideoDecoder::addVideoStreamDecoder(
 
   streamInfo.colorConversionLibrary =
       videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
+
+  print_codecContext(streamInfo.codecContext.get());
 }
 
 void VideoDecoder::updateMetadataWithCodecContext(
@@ -835,7 +908,6 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange(
 // --------------------------------------------------------------------------
 // SEEKING APIs
 // --------------------------------------------------------------------------
-
 void VideoDecoder::setCursorPtsInSeconds(double seconds) {
   desiredPtsSeconds_ = seconds;
 }
@@ -923,6 +995,8 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
     desiredPts = streamInfo.keyFrames[desiredKeyFrameIndex].pts;
   }
 
+  printf("Seeking to PTS = %ld\n", desiredPts);
+
   int ffmepgStatus = avformat_seek_file(
       formatContext_.get(),
       streamInfo.streamIndex,
@@ -936,6 +1010,8 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
         getFFMPEGErrorStringFromErrorCode(ffmepgStatus));
   }
   decodeStats_.numFlushes++;
+
+  printf("Flushing\n");
 
   avcodec_flush_buffers(streamInfo.codecContext.get());
 }
@@ -943,6 +1019,64 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
 // --------------------------------------------------------------------------
 // LOW-LEVEL DECODING
 // --------------------------------------------------------------------------
+void print_packet(AVPacket* packet) {
+  printf(
+      "Packet PTS: %ld, DTS: %ld, Duration: %ld, Size: %d, Stream Index: %d\n",
+      packet->pts,
+      packet->dts,
+      packet->duration,
+      packet->size,
+      packet->stream_index);
+  // Optional: Calculate a simple checksum or hash of the packet data
+  unsigned long checksum = 0;
+  for (int i = 0; i < packet->size; i++) {
+    checksum += packet->data[i];
+  }
+  printf("Packet Checksum: %lu\n\n", checksum);
+  fflush(stdout);
+}
+
+void print_avFrame(AVFrame* avFrame) {
+  printf("Format: %d\n", avFrame->format);
+  printf("Width: %d\n", avFrame->width);
+  printf("Height: %d\n", avFrame->height);
+  printf(
+      "Channels: %d\n",
+      av_get_channel_layout_nb_channels(avFrame->channel_layout));
+  printf("Channel Layout: %ld\n", avFrame->channel_layout);
+  printf("Number of Samples: %d\n", avFrame->nb_samples);
+  printf("PTS: %ld\n", avFrame->pts);
+  printf("Packet DTS: %ld\n", avFrame->pkt_dts);
+  printf("Packet Duration: %ld\n", avFrame->pkt_duration);
+  printf("Packet Pos: %ld\n", avFrame->pkt_pos);
+  for (int i = 0; i < AV_NUM_DATA_POINTERS; i++) {
+    if (avFrame->data[i]) {
+      printf("Data[%d] Line Size: %d\n", i, avFrame->linesize[i]);
+    }
+  }
+  printf("Color Range: %d\n", avFrame->color_range);
+  printf("Color Primaries: %d\n", avFrame->color_primaries);
+  printf("Color Transfer Characteristic: %d\n", avFrame->color_trc);
+  printf("Color Space: %d\n", avFrame->colorspace);
+  printf("Chroma Location: %d\n", avFrame->chroma_location);
+  printf(
+      "Sample Aspect Ratio: %d/%d\n",
+      avFrame->sample_aspect_ratio.num,
+      avFrame->sample_aspect_ratio.den);
+  printf("Key Frame: %d\n", avFrame->key_frame);
+  printf("Picture Type: %d\n", avFrame->pict_type);
+  printf("Coded Picture Number: %d\n", avFrame->coded_picture_number);
+  printf("Display Picture Number: %d\n", avFrame->display_picture_number);
+
+  unsigned long checksum = 0;
+  for (int i = 0; i < 100; i++) {
+    checksum += avFrame->extended_data[0][i];
+  }
+  printf("Frame Checksum: %lu\n", checksum);
+  printf("\n");
+  fflush(stdout);
+}
+
 VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
     std::function<bool(AVFrame*)> filterFunction) {
   if (activeStreamIndex_ == NO_ACTIVE_STREAM) {
@@ -959,17 +1093,22 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
 
   StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
 
-  // Need to get the next frame or error from PopFrame.
   UniqueAVFrame avFrame(av_frame_alloc());
   AutoAVPacket autoAVPacket;
   int ffmpegStatus = AVSUCCESS;
   bool reachedEOF = false;
 
   while (true) {
+    if (veryFirstCall_) {
+      veryFirstCall_ = false;
+      goto av_read_frame_call;
+    }
     ffmpegStatus =
         avcodec_receive_frame(streamInfo.codecContext.get(), avFrame.get());
+    printf("output of avcodec_receive_frame: %d\n", ffmpegStatus);
 
     if (ffmpegStatus != AVSUCCESS && ffmpegStatus != AVERROR(EAGAIN)) {
       // Non-retriable error
+      // printf("Non-retriable error\n");
       break;
     }
@@ -977,6 +1116,9 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
     // Is this the kind of frame we're looking for?
     if (ffmpegStatus == AVSUCCESS && filterFunction(avFrame.get())) {
       // Yes, this is the frame we'll return; break out of the decoding loop.
+      // printf("%ld %ld\n", avFrame->pts, avFrame->duration);
+
+      // printf("Found frame\n");
       break;
     } else if (ffmpegStatus == AVSUCCESS) {
       // No, but we received a valid frame - just not the kind we're looking
@@ -984,6 +1126,7 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
       // But since we did just receive a frame, we should skip reading more
       // packets and sending them to the decoder and just try to receive more
      // frames from the decoder.
+      // printf("Got AVSUCCESS, continue\n");
       continue;
     }
 
@@ -995,6 +1138,7 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
 
     // We still haven't found the frame we're looking for. So let's read more
     // packets and send them to the decoder.
+  av_read_frame_call:
     ReferenceAVPacket packet(autoAVPacket);
     do {
       ffmpegStatus = av_read_frame(formatContext_.get(), packet.get());
@@ -1031,8 +1175,11 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
 
     // We got a valid packet. Send it to the decoder, and we'll receive it in
     // the next iteration.
+    printf("Sending packet:\n");
+    print_packet(packet.get());
     ffmpegStatus =
         avcodec_send_packet(streamInfo.codecContext.get(), packet.get());
+
     if (ffmpegStatus < AVSUCCESS) {
       throw std::runtime_error(
           "Could not push packet to decoder: " +
@@ -1061,6 +1208,8 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
   // the file and that will flush the decoder.
   streamInfo.currentPts = avFrame->pts;
   streamInfo.currentDuration = getDuration(avFrame);
+  printf("Received avFrame:\n");
+  print_avFrame(avFrame.get());
   return AVFrameStream(std::move(avFrame), activeStreamIndex_);
 }
 
@@ -1071,18 +1220,48 @@
 
 VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
     VideoDecoder::AVFrameStream& avFrameStream,
-    std::optional<torch::Tensor> preAllocatedOutputTensor) {
+    [[maybe_unused]] std::optional<torch::Tensor> preAllocatedOutputTensor) {
   // Convert the frame to tensor.
   FrameOutput frameOutput;
   int streamIndex = avFrameStream.streamIndex;
   AVFrame* avFrame = avFrameStream.avFrame.get();
   frameOutput.streamIndex = streamIndex;
   auto& streamInfo = streamInfos_[streamIndex];
-  TORCH_CHECK(streamInfo.stream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO);
+  TORCH_CHECK(streamInfo.stream->codecpar->codec_type == AVMEDIA_TYPE_AUDIO);
   frameOutput.ptsSeconds = ptsToSeconds(
       avFrame->pts, formatContext_->streams[streamIndex]->time_base);
   frameOutput.durationSeconds = ptsToSeconds(
       getDuration(avFrame), formatContext_->streams[streamIndex]->time_base);
+
+  auto numSamples = avFrame->nb_samples;
+  auto sampleRate = avFrame->sample_rate;
+  auto numChannels = avFrame->ch_layout.nb_channels;
+
+  // printf("numSamples: %d\n", numSamples);
+  // printf("sample rate: %d\n", sampleRate);
+
+  // printf("numChannels: %d\n", numChannels);
+  int bytesPerSample =
+      av_get_bytes_per_sample(streamInfo.codecContext->sample_fmt);
+  // printf("bytes per sample: %d\n", bytesPerSample);
+
+  // float32 Planar
+//  torch::Tensor data = torch::empty({numChannels, numSamples}, torch::kFloat32);
+//  for (auto channel = 0; channel < numChannels; ++channel) {
+//    float* dataFloatPtr = (float*)(avFrame->data[channel]);
+//    for (auto sampleIndex = 0; sampleIndex < numSamples; ++sampleIndex) {
+//      data[channel][sampleIndex] = dataFloatPtr[sampleIndex];
+//    }
+//  }
+  // float32 non-Planar
+  torch::Tensor data = torch::empty({numSamples, numChannels}, torch::kFloat32);
+  uint8_t* pData = static_cast<uint8_t*>(data.data_ptr());
+  memcpy(pData, avFrame->extended_data[0], numSamples * numChannels * bytesPerSample);
+  data = data.permute({1, 0});
+
+  frameOutput.data = data;
+  return frameOutput;
+
   // TODO: we should fold preAllocatedOutputTensor into AVFrameStream.
   if (streamInfo.videoStreamOptions.device.type() == torch::kCPU) {
     convertAVFrameToFrameOutputOnCPU(
@@ -1310,6 +1489,7 @@ torch::Tensor allocateEmptyHWCTensor(
 torch::Tensor VideoDecoder::maybePermuteHWC2CHW(
     int streamIndex,
     torch::Tensor& hwcTensor) {
+  return hwcTensor;
   if (streamInfos_[streamIndex].videoStreamOptions.dimensionOrder == "NHWC") {
     return hwcTensor;
   }
diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h
index 7db5a4a6..a0902ff9 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.h
+++ b/src/torchcodec/decoders/_core/VideoDecoder.h
@@ -477,6 +477,7 @@ class VideoDecoder {
   // ATTRIBUTES
   // --------------------------------------------------------------------------
 
+  bool veryFirstCall_ = true;
   SeekMode seekMode_;
   ContainerMetadata containerMetadata_;
   UniqueAVFormatContext formatContext_;
diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp
index 78ecc425..2a682e32 100644
--- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp
@@ -237,11 +237,11 @@ OpsFrameOutput get_next_frame(at::Tensor& decoder) {
   } catch (const VideoDecoder::EndOfFileException& e) {
     C10_THROW_ERROR(IndexError, e.what());
   }
-  if (result.data.sizes().size() != 3) {
-    throw std::runtime_error(
-        "image_size is unexpected. Expected 3, got: " +
-        std::to_string(result.data.sizes().size()));
-  }
+//  if (result.data.sizes().size() != 3) {
+//    throw std::runtime_error(
+//        "image_size is unexpected. Expected 3, got: " +
+//        std::to_string(result.data.sizes().size()));
+//  }
   return makeOpsFrameOutput(result);
 }
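
A note on the audio path added to convertAVFrameToFrameOutput above: the patch only handles packed (interleaved) float samples, leaves the planar branch commented out, and the final permute({1, 0}) returns a non-contiguous view. Below is a minimal standalone sketch, not part of the patch, of how both layouts could be copied into a {numChannels, numSamples} float32 tensor; avFrameToAudioTensor is a hypothetical helper name, and the sketch assumes the decoder produced AV_SAMPLE_FMT_FLT or AV_SAMPLE_FMT_FLTP data on an FFmpeg version that has the ch_layout API.

#include <cstring>
extern "C" {
#include <libavutil/frame.h>
#include <libavutil/samplefmt.h>
}
#include <torch/torch.h>

// Copy one decoded audio AVFrame into a contiguous {numChannels, numSamples}
// float32 tensor, handling both planar and packed float sample formats.
torch::Tensor avFrameToAudioTensor(const AVFrame* frame) {
  const int numSamples = frame->nb_samples;
  const int numChannels = frame->ch_layout.nb_channels;
  const auto format = static_cast<AVSampleFormat>(frame->format);
  const int bytesPerSample = av_get_bytes_per_sample(format);

  torch::Tensor data = torch::empty({numChannels, numSamples}, torch::kFloat32);

  if (av_sample_fmt_is_planar(format)) {
    // Planar (FLTP): one plane per channel, each holding numSamples contiguous floats.
    for (int ch = 0; ch < numChannels; ++ch) {
      std::memcpy(
          data[ch].data_ptr(),
          frame->extended_data[ch],
          static_cast<size_t>(numSamples) * bytesPerSample);
    }
  } else {
    // Packed (FLT): a single plane with samples interleaved as c0, c1, ..., c0, c1, ...
    // View it as {numSamples, numChannels} and transpose while copying.
    torch::Tensor interleaved = torch::from_blob(
        frame->extended_data[0], {numSamples, numChannels}, torch::kFloat32);
    data.copy_(interleaved.t());
  }
  return data;
}

The copy_() on the packed path materializes the transpose, so both branches yield a contiguous channels-first tensor rather than the permuted view the patch currently returns.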