Commit c2bea4b
Validation of allocated output tensor shapes (#339)

1 parent 284a2f0 commit c2bea4b
3 files changed, +108 -69 lines changed

src/torchcodec/decoders/_core/CudaDevice.cpp (+3)
```diff
@@ -224,6 +224,9 @@ void convertAVFrameToDecodedOutputOnCuda(
 
   auto start = std::chrono::high_resolution_clock::now();
 
+  // TODO height and width info of output tensor comes from the metadata, which
+  // may not be accurate. How do we make sure we won't corrupt memory if the
+  // allocated tensor is too short/large?
   NppStatus status = nppiNV12ToRGB_8u_P2C3R(
       input,
       src->linesize[0],
```
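The TODO above flags a real risk: on the CUDA path the output tensor's height and width come from stream metadata, which can disagree with the frame FFmpeg actually decoded. A minimal sketch of one possible mitigation, not part of this commit: validate the allocated tensor against the decoded AVFrame's dimensions before handing its memory to NPP. Here `dst` is a hypothetical name for the allocated HWC output tensor; `src` is the decoded frame, as in the surrounding code.

```cpp
// Hypothetical sketch (not in this commit): check the allocated tensor's
// shape against the decoded frame's actual dimensions before the NPP call,
// so a metadata mismatch fails loudly instead of corrupting memory.
auto shape = dst.sizes(); // dst: the allocated HWC uint8 output tensor
TORCH_CHECK(
    shape.size() == 3 && shape[0] == src->height && shape[1] == src->width &&
        shape[2] == 3,
    "Allocated tensor of shape ",
    shape,
    " does not match decoded frame dimensions ",
    src->height,
    "x",
    src->width,
    "x3");
```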

src/torchcodec/decoders/_core/VideoDecoder.cpp (+72 -45)
```diff
@@ -880,7 +880,7 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
 // speed-up when swscale is used. With swscale, we can tell ffmpeg to place the
 // decoded frame directly into `preAllocatedOutputTensor.data_ptr()`. We
 // haven't yet found a way to do that with filtergraph.
-// TODO: Figure out whether that's possilbe!
+// TODO: Figure out whether that's possible!
 // Dimension order of the preAllocatedOutputTensor must be HWC, regardless of
 // `dimension_order` parameter. It's up to callers to re-shape it if needed.
 void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
@@ -890,41 +890,68 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
   int streamIndex = rawOutput.streamIndex;
   AVFrame* frame = rawOutput.frame.get();
   auto& streamInfo = streams_[streamIndex];
-  torch::Tensor tensor;
+
+  auto frameDims =
+      getHeightAndWidthFromOptionsOrAVFrame(streamInfo.options, *frame);
+  int expectedOutputHeight = frameDims.height;
+  int expectedOutputWidth = frameDims.width;
+
+  if (preAllocatedOutputTensor.has_value()) {
+    auto shape = preAllocatedOutputTensor.value().sizes();
+    TORCH_CHECK(
+        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
+            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
+        "Expected pre-allocated tensor of shape ",
+        expectedOutputHeight,
+        "x",
+        expectedOutputWidth,
+        "x3, got ",
+        shape);
+  }
+
+  torch::Tensor outputTensor;
   if (output.streamType == AVMEDIA_TYPE_VIDEO) {
     if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
-      auto frameDims =
-          getHeightAndWidthFromOptionsOrAVFrame(streamInfo.options, *frame);
-      int height = frameDims.height;
-      int width = frameDims.width;
-      if (preAllocatedOutputTensor.has_value()) {
-        tensor = preAllocatedOutputTensor.value();
-        auto shape = tensor.sizes();
-        TORCH_CHECK(
-            (shape.size() == 3) && (shape[0] == height) &&
-                (shape[1] == width) && (shape[2] == 3),
-            "Expected tensor of shape ",
-            height,
-            "x",
-            width,
-            "x3, got ",
-            shape);
-      } else {
-        tensor = allocateEmptyHWCTensor(height, width, torch::kCPU);
-      }
-      rawOutput.data = tensor.data_ptr<uint8_t>();
-      convertFrameToBufferUsingSwsScale(rawOutput);
-
-      output.frame = tensor;
+      outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
+          expectedOutputHeight, expectedOutputWidth, torch::kCPU));
+
+      int resultHeight =
+          convertFrameToBufferUsingSwsScale(streamIndex, frame, outputTensor);
+      // If this check failed, it would mean that the frame wasn't reshaped to
+      // the expected height.
+      // TODO: Can we do the same check for width?
+      TORCH_CHECK(
+          resultHeight == expectedOutputHeight,
+          "resultHeight != expectedOutputHeight: ",
+          resultHeight,
+          " != ",
+          expectedOutputHeight);
+
+      output.frame = outputTensor;
     } else if (
         streamInfo.colorConversionLibrary ==
         ColorConversionLibrary::FILTERGRAPH) {
-      tensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
+      outputTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
+
+      // Similarly to above, if this check fails it means the frame wasn't
+      // reshaped to its expected dimensions by filtergraph.
+      auto shape = outputTensor.sizes();
+      TORCH_CHECK(
+          (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
+              (shape[1] == expectedOutputWidth) && (shape[2] == 3),
+          "Expected output tensor of shape ",
+          expectedOutputHeight,
+          "x",
+          expectedOutputWidth,
+          "x3, got ",
+          shape);
       if (preAllocatedOutputTensor.has_value()) {
-        preAllocatedOutputTensor.value().copy_(tensor);
+        // We have already validated that preAllocatedOutputTensor and
+        // outputTensor have the same shape.
+        preAllocatedOutputTensor.value().copy_(outputTensor);
         output.frame = preAllocatedOutputTensor.value();
       } else {
-        output.frame = tensor;
+        output.frame = outputTensor;
       }
     } else {
       throw std::runtime_error(
```
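`allocateEmptyHWCTensor` is not shown in this diff; judging from its name and call site, it is presumably a thin wrapper along the lines of the sketch below. One side note on the new code: `std::optional::value_or` evaluates its argument eagerly, so the empty tensor is allocated even when a pre-allocated tensor is supplied.

```cpp
// Sketch of what allocateEmptyHWCTensor presumably does; the real helper is
// defined elsewhere in the codebase and is not part of this diff.
torch::Tensor allocateEmptyHWCTensor(int height, int width, torch::Device device) {
  // Decoded frames are stored as uint8 in HWC order, hence the {H, W, 3} shape.
  return torch::empty(
      {height, width, 3},
      torch::TensorOptions().dtype(torch::kUInt8).device(device));
}
```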
```diff
@@ -1303,24 +1330,23 @@ double VideoDecoder::getPtsSecondsForFrame(
   return ptsToSeconds(stream.allFrames[frameIndex].pts, stream.timeBase);
 }
 
-void VideoDecoder::convertFrameToBufferUsingSwsScale(
-    RawDecodedOutput& rawOutput) {
-  AVFrame* frame = rawOutput.frame.get();
-  int streamIndex = rawOutput.streamIndex;
+int VideoDecoder::convertFrameToBufferUsingSwsScale(
+    int streamIndex,
+    const AVFrame* frame,
+    torch::Tensor& outputTensor) {
   enum AVPixelFormat frameFormat =
       static_cast<enum AVPixelFormat>(frame->format);
   StreamInfo& activeStream = streams_[streamIndex];
-  auto frameDims =
-      getHeightAndWidthFromOptionsOrAVFrame(activeStream.options, *frame);
-  int outputHeight = frameDims.height;
-  int outputWidth = frameDims.width;
+
+  int expectedOutputHeight = outputTensor.sizes()[0];
+  int expectedOutputWidth = outputTensor.sizes()[1];
   if (activeStream.swsContext.get() == nullptr) {
     SwsContext* swsContext = sws_getContext(
         frame->width,
         frame->height,
         frameFormat,
-        outputWidth,
-        outputHeight,
+        expectedOutputWidth,
+        expectedOutputHeight,
         AV_PIX_FMT_RGB24,
         SWS_BILINEAR,
         nullptr,
@@ -1352,8 +1378,8 @@ void VideoDecoder::convertFrameToBufferUsingSwsScale(
   }
   SwsContext* swsContext = activeStream.swsContext.get();
   uint8_t* pointers[4] = {
-      static_cast<uint8_t*>(rawOutput.data), nullptr, nullptr, nullptr};
-  int linesizes[4] = {outputWidth * 3, 0, 0, 0};
+      outputTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
+  int linesizes[4] = {expectedOutputWidth * 3, 0, 0, 0};
   int resultHeight = sws_scale(
       swsContext,
       frame->data,
@@ -1362,9 +1388,7 @@ void VideoDecoder::convertFrameToBufferUsingSwsScale(
       frame->height,
       pointers,
      linesizes);
-  TORCH_CHECK(
-      outputHeight == resultHeight,
-      "outputHeight(" + std::to_string(resultHeight) + ") != resultHeight");
+  return resultHeight;
 }
 
 torch::Tensor VideoDecoder::convertFrameToTensorUsingFilterGraph(
```
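Returning `resultHeight` works because `sws_scale` returns the height of the output slice it wrote; the `TORCH_CHECK` that used to live here moves to the call site, which knows the expected height. A self-contained illustration of that contract, with hypothetical names (this helper is not from the commit):

```cpp
// Hypothetical helper illustrating the sws_scale contract relied on above:
// the return value is the number of output rows written, so a destination
// buffer allocated with the wrong height is detectable after the call.
int scaleIntoRGB24Buffer(
    SwsContext* ctx,
    const AVFrame* src,
    uint8_t* dst,
    int dstWidth,
    int dstHeight) {
  uint8_t* dstPlanes[4] = {dst, nullptr, nullptr, nullptr};
  int dstLinesizes[4] = {dstWidth * 3, 0, 0, 0}; // packed RGB24, 3 bytes/pixel
  int rowsWritten = sws_scale(
      ctx, src->data, src->linesize, 0, src->height, dstPlanes, dstLinesizes);
  TORCH_CHECK(
      rowsWritten == dstHeight,
      "sws_scale wrote ",
      rowsWritten,
      " rows, expected ",
      dstHeight);
  return rowsWritten;
}
```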
```diff
@@ -1379,8 +1403,7 @@ torch::Tensor VideoDecoder::convertFrameToTensorUsingFilterGraph(
   ffmpegStatus =
       av_buffersink_get_frame(filterState.sinkContext, filteredFrame.get());
   TORCH_CHECK_EQ(filteredFrame->format, AV_PIX_FMT_RGB24);
-  auto frameDims = getHeightAndWidthFromOptionsOrAVFrame(
-      streams_[streamIndex].options, *filteredFrame.get());
+  auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredFrame.get());
   int height = frameDims.height;
   int width = frameDims.width;
   std::vector<int64_t> shape = {height, width, 3};
@@ -1406,6 +1429,10 @@ VideoDecoder::~VideoDecoder() {
   }
 }
 
+FrameDims getHeightAndWidthFromResizedAVFrame(const AVFrame& resizedAVFrame) {
+  return FrameDims(resizedAVFrame.height, resizedAVFrame.width);
+}
+
 FrameDims getHeightAndWidthFromOptionsOrMetadata(
     const VideoDecoder::VideoStreamDecoderOptions& options,
     const VideoDecoder::StreamMetadata& metadata) {
```

src/torchcodec/decoders/_core/VideoDecoder.h (+33 -24)
```diff
@@ -383,7 +383,10 @@ class VideoDecoder {
   torch::Tensor convertFrameToTensorUsingFilterGraph(
       int streamIndex,
       const AVFrame* frame);
-  void convertFrameToBufferUsingSwsScale(RawDecodedOutput& rawOutput);
+  int convertFrameToBufferUsingSwsScale(
+      int streamIndex,
+      const AVFrame* frame,
+      torch::Tensor& outputTensor);
   DecodedOutput convertAVFrameToDecodedOutput(
       RawDecodedOutput& rawOutput,
       std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
@@ -426,30 +429,32 @@ class VideoDecoder {
 // MaybePermuteHWC2CHW().
 //
 // Also, importantly, the way we figure out the height and width of the
-// output frame varies and depends on the decoding entry-point:
-// - In all cases, if the user requested specific height and width from the
-// options, we honor that. Otherwise we fall into one of the categories below.
-// - In Batch decoding APIs (e.g. getFramesAtIndices), we get height and width
-// from the stream metadata, which itself got its value from the CodecContext,
-// when the stream was added.
-// - In single frames APIs:
-//   - On CPU we get height and width from the AVFrame.
-//   - On GPU, we get height and width from the metadata (same as batch APIs)
-//
-// These 2 strategies are encapsulated within
-// getHeightAndWidthFromOptionsOrMetadata() and
-// getHeightAndWidthFromOptionsOrAVFrame(). The reason they exist is to make it
-// very obvious which logic is used in which place, and they allow for `git
-// grep`ing.
+// output frame tensor varies, and depends on the decoding entry-point. In
+// *decreasing order of accuracy*, we use the following sources for determining
+// height and width:
+// - getHeightAndWidthFromResizedAVFrame(). This is the height and width of the
+// AVFrame, *post*-resizing. This is only used for single-frame decoding APIs,
+// on CPU, with filtergraph.
+// - getHeightAndWidthFromOptionsOrAVFrame(). This is the height and width from
+// the user-specified options if they exist, or the height and width of the
+// AVFrame *before* it is resized. In theory, i.e. if there are no bugs within
+// our code or within FFmpeg code, this should be exactly the same as
+// getHeightAndWidthFromResizedAVFrame(). This is used by single-frame
+// decoding APIs, on CPU, with swscale.
+// - getHeightAndWidthFromOptionsOrMetadata(). This is the height and width from
+// the user-specified options if they exist, or the height and width from the
+// stream metadata, which itself got its value from the CodecContext when the
+// stream was added. This is used by batch decoding APIs, or by GPU APIs (both
+// batch and single-frame).
 //
-// The source of truth for height and width really is the AVFrame: it's the
-// decoded ouptut from FFmpeg. The info from the metadata (i.e. from the
-// CodecContext) may not be as accurate. However, the AVFrame is only available
-// late in the call stack, when the frame is decoded, while the CodecContext is
-// available early when a stream is added. This is why we use the CodecContext
-// for pre-allocating batched output tensors (we could pre-allocate those only
-// once we decode the first frame to get the info frame the AVFrame, but that's
-// a more complex logic).
+// The source of truth for height and width really is the (resized) AVFrame:
+// it's the decoded output from FFmpeg. The info from the metadata (i.e. from
+// the CodecContext) may not be as accurate. However, the AVFrame is only
+// available late in the call stack, when the frame is decoded, while the
+// CodecContext is available early, when a stream is added. This is why we use
+// the CodecContext for pre-allocating batched output tensors (we could
+// pre-allocate those only once we decode the first frame to get the info from
+// the AVFrame, but that's more complex logic).
 //
 // Because the sources for height and width may disagree, we may end up with
 // conflicts: e.g. if we pre-allocate a batch output tensor based on the
```
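The two options-based helpers described above are defined in VideoDecoder.cpp, outside this diff. Based on the comment's description, they presumably reduce to something like the sketch below; note that the `std::optional` types of the options and metadata fields are an assumption here, not visible in this commit.

```cpp
// Hypothetical sketches consistent with the comment above; the real
// definitions live in VideoDecoder.cpp and are not shown in this diff.
FrameDims getHeightAndWidthFromOptionsOrAVFrame(
    const VideoDecoder::VideoStreamDecoderOptions& options,
    const AVFrame& avFrame) {
  // User-requested dimensions win; otherwise use the (pre-resize) AVFrame.
  return FrameDims(
      options.height.value_or(avFrame.height),
      options.width.value_or(avFrame.width));
}

FrameDims getHeightAndWidthFromOptionsOrMetadata(
    const VideoDecoder::VideoStreamDecoderOptions& options,
    const VideoDecoder::StreamMetadata& metadata) {
  // User-requested dimensions win; otherwise use the CodecContext-derived
  // values recorded in the stream metadata when the stream was added
  // (assuming the metadata fields are std::optional<int>).
  return FrameDims(
      options.height.value_or(*metadata.height),
      options.width.value_or(*metadata.width));
}
```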
```diff
@@ -463,6 +468,10 @@ struct FrameDims {
   FrameDims(int h, int w) : height(h), width(w) {}
 };
 
+// There's nothing preventing you from calling this on a non-resized frame, but
+// please don't.
+FrameDims getHeightAndWidthFromResizedAVFrame(const AVFrame& resizedAVFrame);
+
 FrameDims getHeightAndWidthFromOptionsOrMetadata(
     const VideoDecoder::VideoStreamDecoderOptions& options,
     const VideoDecoder::StreamMetadata& metadata);
```
