Super basic audio decoding POC #488

Closed
wants to merge 14 commits into from
4 changes: 2 additions & 2 deletions src/torchcodec/_frame.py
@@ -41,8 +41,8 @@ class Frame(Iterable):
def __post_init__(self):
# This is called after __init__() when a Frame is created. We can run
# input validation checks here.
if not self.data.ndim == 3:
raise ValueError(f"data must be 3-dimensional, got {self.data.shape = }")
# if not self.data.ndim == 3:
# raise ValueError(f"data must be 3-dimensional, got {self.data.shape = }")
self.pts_seconds = float(self.pts_seconds)
self.duration_seconds = float(self.duration_seconds)

2 changes: 1 addition & 1 deletion src/torchcodec/decoders/_core/CMakeLists.txt
@@ -4,7 +4,7 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

find_package(Torch REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic -Werror ${TORCH_CXX_FLAGS}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic ${TORCH_CXX_FLAGS}")
find_package(Python3 ${PYTHON_VERSION} EXACT COMPONENTS Development)

function(make_torchcodec_library library_name ffmpeg_target)
208 changes: 194 additions & 14 deletions src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -412,6 +412,63 @@ VideoDecoder::VideoStreamOptions::VideoStreamOptions(
}
}

void print_codecContext(AVCodecContext* cc) {
printf("AVCodecContext details:\n");
printf("Codec ID: %d\n", cc->codec_id);
printf("Codec Type: %d\n", cc->codec_type);
printf("Codec Name: %s\n", cc->codec ? cc->codec->name : "unknown");
printf("Bit Rate: %ld\n", cc->bit_rate);
printf("Time Base: %d/%d\n", cc->time_base.num, cc->time_base.den);
printf("GOP Size: %d\n", cc->gop_size);
printf("Max B-Frames: %d\n", cc->max_b_frames);
printf("bit_rate: %d\n", cc->bit_rate);
printf("bit_rate_tolerance: %d\n", cc->bit_rate_tolerance);
printf("global_quality: %d\n", cc->global_quality);
printf("compression_level: %d\n", cc->compression_level);
if (cc->codec_type == AVMEDIA_TYPE_VIDEO) {
printf("Width: %d\n", cc->width);
printf("Height: %d\n", cc->height);
printf("Pixel Format: %s\n", av_get_pix_fmt_name(cc->pix_fmt));
printf("Frame Rate: %d/%d\n", cc->framerate.num, cc->framerate.den);
} else if (cc->codec_type == AVMEDIA_TYPE_AUDIO) {
printf("Sample Rate: %d\n", cc->sample_rate);
printf("Channels: %d\n", cc->channels);
printf("Channel Layout: %ld\n", cc->channel_layout);
printf("Sample Format: %s\n", av_get_sample_fmt_name(cc->sample_fmt));
}
printf("Profile: %d\n", cc->profile);
printf("Level: %d\n", cc->level);
printf("Flags: %d\n", cc->flags);
printf("Thread Count: %d\n", cc->thread_count);
// Additional attributes
printf("Skip Frame: %d\n", cc->skip_frame);
printf("Skip IDCT: %d\n", cc->skip_idct);
printf("Skip Loop Filter: %d\n", cc->skip_loop_filter);
printf("Error Recognition: %d\n", cc->err_recognition);
printf("Error Concealment: %d\n", cc->error_concealment);
printf("HW Device Context: %p\n", cc->hw_device_ctx);
printf("HW Accel: %p\n", cc->hwaccel);
printf("Pkt Timebase: %d/%d\n", cc->pkt_timebase.num, cc->pkt_timebase.den);
printf("Delay: %d\n", cc->delay);
printf("Extradata Size: %d\n", cc->extradata_size);
if (cc->extradata && cc->extradata_size > 0) {
printf("Extradata: ");
for (int i = 0; i < cc->extradata_size; i++) {
printf("%02X ", cc->extradata[i]);
}
printf("\n");
}
printf("RC Buffer Size: %d\n", cc->rc_buffer_size);
printf("RC Max Rate: %d\n", cc->rc_max_rate);
printf("RC Min Rate: %d\n", cc->rc_min_rate);
printf("Thread Type: %d\n", cc->thread_type);
printf("Ticks Per Frame: %d\n", cc->ticks_per_frame);
printf(
"Subtitle Char Encoding: %s\n",
cc->sub_charenc ? cc->sub_charenc : "N/A");
printf("\n");
}

void VideoDecoder::addVideoStreamDecoder(
int preferredStreamIndex,
const VideoStreamOptions& videoStreamOptions) {
@@ -421,13 +478,16 @@ void VideoDecoder::addVideoStreamDecoder(
TORCH_CHECK(formatContext_.get() != nullptr);

AVCodecOnlyUseForCallingAVFindBestStream avCodec = nullptr;
int streamIndex = av_find_best_stream(
formatContext_.get(),
AVMEDIA_TYPE_VIDEO,
preferredStreamIndex,
-1,
&avCodec,
0);
// int streamIndex = av_find_best_stream(
// formatContext_.get(),
// AVMEDIA_TYPE_AUDIO,
// preferredStreamIndex,
// -1,
// &avCodec,
// 0);
int streamIndex = preferredStreamIndex;
avCodec = avcodec_find_decoder(formatContext_->streams[streamIndex]->codecpar->codec_id);

if (streamIndex < 0) {
throw std::invalid_argument("No valid stream found in input file.");
}
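Note that with streamIndex taken straight from preferredStreamIndex, the < 0 check above no longer guards against an out-of-range index, and avcodec_find_decoder can return null. A minimal hardening sketch (illustrative only, not part of this diff):

TORCH_CHECK(
    preferredStreamIndex >= 0 &&
        preferredStreamIndex < static_cast<int>(formatContext_->nb_streams),
    "preferredStreamIndex out of range");
// avcodec_find_decoder returns nullptr when no decoder is built in for this
// codec id, which would make the later avcodec_open2 call fail.
const AVCodec* decoder = avcodec_find_decoder(
    formatContext_->streams[preferredStreamIndex]->codecpar->codec_id);
TORCH_CHECK(decoder != nullptr, "No decoder found for this stream");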
@@ -438,7 +498,7 @@ void VideoDecoder::addVideoStreamDecoder(
streamInfo.timeBase = formatContext_->streams[streamIndex]->time_base;
streamInfo.stream = formatContext_->streams[streamIndex];

if (streamInfo.stream->codecpar->codec_type != AVMEDIA_TYPE_VIDEO) {
if (streamInfo.stream->codecpar->codec_type != AVMEDIA_TYPE_AUDIO) {
throw std::invalid_argument(
"Stream with index " + std::to_string(streamIndex) +
" is not a video stream.");
@@ -462,7 +522,7 @@ void VideoDecoder::addVideoStreamDecoder(

AVCodecContext* codecContext = avcodec_alloc_context3(avCodec);
TORCH_CHECK(codecContext != nullptr);
codecContext->thread_count = videoStreamOptions.ffmpegThreadCount.value_or(0);
// codecContext->thread_count = videoStreamOptions.ffmpegThreadCount.value_or(0);
streamInfo.codecContext.reset(codecContext);

int retVal = avcodec_parameters_to_context(
@@ -478,12 +538,21 @@ void VideoDecoder::addVideoStreamDecoder(
false, "Invalid device type: " + videoStreamOptions.device.str());
}

if (!streamInfo.codecContext->channel_layout) {
streamInfo.codecContext->channel_layout =
av_get_default_channel_layout(streamInfo.codecContext->channels);
}

AVDictionary* opt = nullptr;
av_dict_set(&opt, "threads", "1", 0);
retVal = avcodec_open2(streamInfo.codecContext.get(), avCodec, nullptr);
if (retVal < AVSUCCESS) {
throw std::invalid_argument(getFFMPEGErrorStringFromErrorCode(retVal));
}

codecContext->time_base = streamInfo.stream->time_base;
// codecContext->time_base = streamInfo.stream->time_base;
// AVRational tb{0, 1};
// codecContext->time_base = tb;
activeStreamIndex_ = streamIndex;
updateMetadataWithCodecContext(streamInfo.streamIndex, codecContext);
streamInfo.videoStreamOptions = videoStreamOptions;
@@ -495,6 +564,8 @@ void VideoDecoder::addVideoStreamDecoder(
for (unsigned int i = 0; i < formatContext_->nb_streams; ++i) {
if (i != static_cast<unsigned int>(activeStreamIndex_)) {
formatContext_->streams[i]->discard = AVDISCARD_ALL;
} else {
formatContext_->streams[i]->discard = AVDISCARD_DEFAULT;
}
}

@@ -516,6 +587,8 @@

streamInfo.colorConversionLibrary =
videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);

print_codecContext(streamInfo.codecContext.get());
}

void VideoDecoder::updateMetadataWithCodecContext(
@@ -835,7 +908,6 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange(
// --------------------------------------------------------------------------
// SEEKING APIs
// --------------------------------------------------------------------------

void VideoDecoder::setCursorPtsInSeconds(double seconds) {
desiredPtsSeconds_ = seconds;
}
@@ -923,6 +995,8 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
desiredPts = streamInfo.keyFrames[desiredKeyFrameIndex].pts;
}

printf("Seeking to PTS = %ld\n", desiredPts);

int ffmepgStatus = avformat_seek_file(
formatContext_.get(),
streamInfo.streamIndex,
@@ -936,13 +1010,73 @@
getFFMPEGErrorStringFromErrorCode(ffmepgStatus));
}
decodeStats_.numFlushes++;

printf("Flushing\n");
avcodec_flush_buffers(streamInfo.codecContext.get());
}
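For context, the seek here follows the usual FFmpeg recipe: rescale the target time into the stream time base, ask the demuxer for a position at or before it, then flush the decoder so it drops state from the old position. A standalone sketch with hypothetical fmtCtx, codecCtx, streamIndex, and seconds (not this class's code):

int64_t targetPts = av_rescale_q(
    static_cast<int64_t>(seconds * AV_TIME_BASE),
    AV_TIME_BASE_Q,
    fmtCtx->streams[streamIndex]->time_base);
// min_ts = INT64_MIN and max_ts = targetPts let the demuxer land on the
// keyframe at or before the requested timestamp.
int status = avformat_seek_file(
    fmtCtx, streamIndex, INT64_MIN, targetPts, targetPts, 0);
if (status < 0) {
  throw std::runtime_error("seek failed");
}
// Without this, the decoder would keep returning frames from the old position.
avcodec_flush_buffers(codecCtx);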

// --------------------------------------------------------------------------
// LOW-LEVEL DECODING
// --------------------------------------------------------------------------

void print_packet(AVPacket* packet) {
printf(
"Packet PTS: %ld, DTS: %ld, Duration: %d, Size: %d, Stream Index: %d\n",
packet->pts,
packet->dts,
packet->duration,
packet->size,
packet->stream_index);
// Optional: Calculate a simple checksum or hash of the packet data
unsigned long checksum = 0;
for (int i = 0; i < packet->size; i++) {
checksum += packet->data[i];
}
printf("Packet Checksum: %lu\n\n", checksum);
fflush(stdout);
}

void print_avFrame(AVFrame* avFrame) {
printf("Format: %d\n", avFrame->format);
printf("Width: %d\n", avFrame->width);
printf("Height: %d\n", avFrame->height);
printf(
"Channels: %d\n",
av_get_channel_layout_nb_channels(avFrame->channel_layout));
printf("Channel Layout: %ld\n", avFrame->channel_layout);
printf("Number of Samples: %d\n", avFrame->nb_samples);
printf("PTS: %ld\n", avFrame->pts);
printf("Packet DTS: %ld\n", avFrame->pkt_dts);
printf("Packet Duration: %d\n", avFrame->pkt_duration);
printf("Packet Pos: %d\n", avFrame->pkt_pos);
for (int i = 0; i < AV_NUM_DATA_POINTERS; i++) {
if (avFrame->data[i]) {
printf("Data[%d] Line Size: %d\n", i, avFrame->linesize[i]);
}
}
printf("Color Range: %d\n", avFrame->color_range);
printf("Color Primaries: %d\n", avFrame->color_primaries);
printf("Color Transfer Characteristic: %d\n", avFrame->color_trc);
printf("Color Space: %d\n", avFrame->colorspace);
printf("Chroma Location: %d\n", avFrame->chroma_location);
printf(
"Sample Aspect Ratio: %d/%d\n",
avFrame->sample_aspect_ratio.num,
avFrame->sample_aspect_ratio.den);
printf("Key Frame: %d\n", avFrame->key_frame);
printf("Picture Type: %d\n", avFrame->pict_type);
printf("Coded Picture Number: %d\n", avFrame->coded_picture_number);
printf("Display Picture Number: %d\n", avFrame->display_picture_number);

unsigned long checksum = 0;
for (int i = 0; i < 100; i++) {
checksum += avFrame->extended_data[0][i];
}
printf("Frame Checksum: %lu\n", checksum);
printf("\n");
fflush(stdout);
}

VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
std::function<bool(AVFrame*)> filterFunction) {
if (activeStreamIndex_ == NO_ACTIVE_STREAM) {
@@ -959,31 +1093,40 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(

StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];

// Need to get the next frame or error from PopFrame.
UniqueAVFrame avFrame(av_frame_alloc());
AutoAVPacket autoAVPacket;
int ffmpegStatus = AVSUCCESS;
bool reachedEOF = false;
while (true) {
if (veryFirstCall_) {
veryFirstCall_ = false;
goto av_read_frame_call;
}
ffmpegStatus =
avcodec_receive_frame(streamInfo.codecContext.get(), avFrame.get());
printf("output of avcodec_receive_frame: %d\n", ffmpegStatus);

if (ffmpegStatus != AVSUCCESS && ffmpegStatus != AVERROR(EAGAIN)) {
// Non-retriable error
// printf("Non-retriable error\n");
break;
}

decodeStats_.numFramesReceivedByDecoder++;
// Is this the kind of frame we're looking for?
if (ffmpegStatus == AVSUCCESS && filterFunction(avFrame.get())) {
// Yes, this is the frame we'll return; break out of the decoding loop.
// printf("%ld %ld\n", avFrame->pts, avFrame->duration);

// printf("Found frame\n");
break;
} else if (ffmpegStatus == AVSUCCESS) {
// No, but we received a valid frame - just not the kind we're looking
// for. The logic below will read packets and send them to the decoder.
// But since we did just receive a frame, we should skip reading more
// packets and sending them to the decoder and just try to receive more
// frames from the decoder.
// printf("Got AVSUCCESS, continue\n");
continue;
}
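The comments above describe FFmpeg's standard decode pattern: drain frames from the decoder until it reports EAGAIN, then read and send the next packet. A minimal standalone sketch of that send/receive loop, assuming an already opened fmtCtx, codecCtx, and target streamIndex (hypothetical names, not this PR's implementation):

AVPacket* packet = av_packet_alloc();
AVFrame* frame = av_frame_alloc();
while (av_read_frame(fmtCtx, packet) >= 0) {
  if (packet->stream_index != streamIndex) {
    av_packet_unref(packet);
    continue;
  }
  avcodec_send_packet(codecCtx, packet);
  av_packet_unref(packet);
  // One packet can produce zero, one, or several frames.
  while (avcodec_receive_frame(codecCtx, frame) == 0) {
    // ... consume frame (e.g. convert its samples to a tensor) ...
    av_frame_unref(frame);
  }
}
// Drain: a null packet tells the decoder there is no more input.
avcodec_send_packet(codecCtx, nullptr);
while (avcodec_receive_frame(codecCtx, frame) == 0) {
  av_frame_unref(frame);
}
av_frame_free(&frame);
av_packet_free(&packet);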

@@ -995,6 +1138,7 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(

// We still haven't found the frame we're looking for. So let's read more
// packets and send them to the decoder.
av_read_frame_call:
ReferenceAVPacket packet(autoAVPacket);
do {
ffmpegStatus = av_read_frame(formatContext_.get(), packet.get());
@@ -1031,8 +1175,11 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(

// We got a valid packet. Send it to the decoder, and we'll receive it in
// the next iteration.
printf("Sending packet:\n");
print_packet(packet.get());
ffmpegStatus =
avcodec_send_packet(streamInfo.codecContext.get(), packet.get());

if (ffmpegStatus < AVSUCCESS) {
throw std::runtime_error(
"Could not push packet to decoder: " +
@@ -1061,6 +1208,8 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
// the file and that will flush the decoder.
streamInfo.currentPts = avFrame->pts;
streamInfo.currentDuration = getDuration(avFrame);
printf("Received avFrame:\n");
print_avFrame(avFrame.get());

return AVFrameStream(std::move(avFrame), activeStreamIndex_);
}
@@ -1071,18 +1220,48 @@

VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
VideoDecoder::AVFrameStream& avFrameStream,
std::optional<torch::Tensor> preAllocatedOutputTensor) {
[[maybe_unused]] std::optional<torch::Tensor> preAllocatedOutputTensor) {
// Convert the frame to tensor.
FrameOutput frameOutput;
int streamIndex = avFrameStream.streamIndex;
AVFrame* avFrame = avFrameStream.avFrame.get();
frameOutput.streamIndex = streamIndex;
auto& streamInfo = streamInfos_[streamIndex];
TORCH_CHECK(streamInfo.stream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO);
TORCH_CHECK(streamInfo.stream->codecpar->codec_type == AVMEDIA_TYPE_AUDIO);
frameOutput.ptsSeconds = ptsToSeconds(
avFrame->pts, formatContext_->streams[streamIndex]->time_base);
frameOutput.durationSeconds = ptsToSeconds(
getDuration(avFrame), formatContext_->streams[streamIndex]->time_base);

auto numSamples = avFrame->nb_samples;
auto sampleRate = avFrame->sample_rate;
auto numChannels = avFrame->ch_layout.nb_channels;

// printf("numSamples: %d\n", numSamples);
// printf("sample rate: %d\n", sampleRate);

// printf("numChannels: %d\n", numChannels);
int bytesPerSample =
av_get_bytes_per_sample(streamInfo.codecContext->sample_fmt);
// printf("bytes per sample: %d\n", bytesPerSample);

// float32 Planar
// torch::Tensor data = torch::empty({numChannels, numSamples}, torch::kFloat32);
// for (auto channel = 0; channel < numChannels; ++channel) {
// float* dataFloatPtr = (float*)(avFrame->data[channel]);
// for (auto sampleIndex = 0; sampleIndex < numSamples; ++sampleIndex) {
// data[channel][sampleIndex] = dataFloatPtr[sampleIndex];
// }
// }
// float32 non-Planar
torch::Tensor data = torch::empty({numSamples, numChannels}, torch::kFloat32);
uint8_t* pData = static_cast<uint8_t*>(data.data_ptr());
memcpy(pData, avFrame->extended_data[0], numSamples * numChannels * bytesPerSample);
data = data.permute({1, 0});

frameOutput.data = data;
return frameOutput;

// TODO: we should fold preAllocatedOutputTensor into AVFrameStream.
if (streamInfo.videoStreamOptions.device.type() == torch::kCPU) {
convertAVFrameToFrameOutputOnCPU(
@@ -1310,6 +1489,7 @@ torch::Tensor allocateEmptyHWCTensor(
torch::Tensor VideoDecoder::maybePermuteHWC2CHW(
int streamIndex,
torch::Tensor& hwcTensor) {
return hwcTensor;
if (streamInfos_[streamIndex].videoStreamOptions.dimensionOrder == "NHWC") {
return hwcTensor;
}
1 change: 1 addition & 0 deletions src/torchcodec/decoders/_core/VideoDecoder.h
@@ -477,6 +477,7 @@ class VideoDecoder {
// ATTRIBUTES
// --------------------------------------------------------------------------

bool veryFirstCall_ = true;
SeekMode seekMode_;
ContainerMetadata containerMetadata_;
UniqueAVFormatContext formatContext_;