@@ -880,7 +880,7 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
 // speed-up when swscale is used. With swscale, we can tell ffmpeg to place the
 // decoded frame directly into `preAllocatedTensor.data_ptr()`. We haven't yet
 // found a way to do that with filtergraph.
-// TODO: Figure out whether that's possilbe!
+// TODO: Figure out whether that's possible!
 // Dimension order of the preAllocatedOutputTensor must be HWC, regardless of
 // `dimension_order` parameter. It's up to callers to re-shape it if needed.
 void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
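The comment above is the crux of the change: `sws_scale()` writes wherever the caller points it, so the decoded frame can land directly in the output tensor's storage, whereas filtergraph always materializes its own AVFrame. A minimal sketch of that zero-copy idea (the helper name `scaleIntoTensor` is hypothetical, not part of this PR):

```cpp
extern "C" {
#include <libavutil/frame.h>
#include <libswscale/swscale.h>
}
#include <torch/torch.h>

// Hypothetical helper, for illustration only: sws_scale() takes destination
// plane pointers and strides from the caller, so it can write packed RGB24
// straight into a pre-allocated HWC uint8 tensor, with no intermediate AVFrame.
void scaleIntoTensor(SwsContext* ctx, const AVFrame* src, torch::Tensor& dst) {
  uint8_t* planes[4] = {dst.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
  int strides[4] = {static_cast<int>(dst.size(1)) * 3, 0, 0, 0};
  sws_scale(ctx, src->data, src->linesize, 0, src->height, planes, strides);
}
```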
@@ -890,41 +890,68 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
   int streamIndex = rawOutput.streamIndex;
   AVFrame* frame = rawOutput.frame.get();
   auto& streamInfo = streams_[streamIndex];
-  torch::Tensor tensor;
+
+  auto frameDims =
+      getHeightAndWidthFromOptionsOrAVFrame(streamInfo.options, *frame);
+  int expectedOutputHeight = frameDims.height;
+  int expectedOutputWidth = frameDims.width;
+
+  if (preAllocatedOutputTensor.has_value()) {
+    auto shape = preAllocatedOutputTensor.value().sizes();
+    TORCH_CHECK(
+        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
+            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
+        "Expected pre-allocated tensor of shape ",
+        expectedOutputHeight,
+        "x",
+        expectedOutputWidth,
+        "x3, got ",
+        shape);
+  }
+
+  torch::Tensor outputTensor;
   if (output.streamType == AVMEDIA_TYPE_VIDEO) {
     if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
-      auto frameDims =
-          getHeightAndWidthFromOptionsOrAVFrame(streamInfo.options, *frame);
-      int height = frameDims.height;
-      int width = frameDims.width;
-      if (preAllocatedOutputTensor.has_value()) {
-        tensor = preAllocatedOutputTensor.value();
-        auto shape = tensor.sizes();
-        TORCH_CHECK(
-            (shape.size() == 3) && (shape[0] == height) &&
-                (shape[1] == width) && (shape[2] == 3),
-            "Expected tensor of shape ",
-            height,
-            "x",
-            width,
-            "x3, got ",
-            shape);
-      } else {
-        tensor = allocateEmptyHWCTensor(height, width, torch::kCPU);
-      }
-      rawOutput.data = tensor.data_ptr<uint8_t>();
-      convertFrameToBufferUsingSwsScale(rawOutput);
-
-      output.frame = tensor;
+      outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
+          expectedOutputHeight, expectedOutputWidth, torch::kCPU));
+
+      int resultHeight =
+          convertFrameToBufferUsingSwsScale(streamIndex, frame, outputTensor);
+      // If this check failed, it would mean that the frame wasn't reshaped to
+      // the expected height.
+      // TODO: Can we do the same check for width?
+      TORCH_CHECK(
+          resultHeight == expectedOutputHeight,
+          "resultHeight != expectedOutputHeight: ",
+          resultHeight,
+          " != ",
+          expectedOutputHeight);
+
+      output.frame = outputTensor;
     } else if (
         streamInfo.colorConversionLibrary ==
         ColorConversionLibrary::FILTERGRAPH) {
-      tensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
+      outputTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
+
+      // Similarly to above, if this check fails it means the frame wasn't
+      // reshaped to its expected dimensions by filtergraph.
+      auto shape = outputTensor.sizes();
+      TORCH_CHECK(
+          (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
+              (shape[1] == expectedOutputWidth) && (shape[2] == 3),
+          "Expected output tensor of shape ",
+          expectedOutputHeight,
+          "x",
+          expectedOutputWidth,
+          "x3, got ",
+          shape);
       if (preAllocatedOutputTensor.has_value()) {
-        preAllocatedOutputTensor.value().copy_(tensor);
+        // We have already validated that preAllocatedOutputTensor and
+        // outputTensor have the same shape.
+        preAllocatedOutputTensor.value().copy_(outputTensor);
         output.frame = preAllocatedOutputTensor.value();
       } else {
-        output.frame = tensor;
+        output.frame = outputTensor;
       }
     } else {
       throw std::runtime_error(
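Hoisting the shape validation above both branches means the pre-allocated tensor is checked exactly once, before any conversion runs, which is what lets the FILTERGRAPH branch assert that the later `copy_` is safe. `allocateEmptyHWCTensor` is defined elsewhere in the file; a plausible sketch of it, assuming it simply wraps `torch::empty` (an assumption, its body is not in this diff):

```cpp
#include <torch/torch.h>

// Assumed implementation (not shown in this diff): an uninitialized
// height x width x 3 uint8 tensor, i.e. packed RGB in HWC layout.
torch::Tensor allocateEmptyHWCTensor(int height, int width, torch::Device device) {
  return torch::empty(
      {height, width, 3},
      torch::TensorOptions().dtype(torch::kUInt8).device(device));
}
```

One caveat worth noting about the SWSCALE branch: `std::optional::value_or` always evaluates its argument, so a throwaway empty tensor is allocated even when a pre-allocated one was passed in.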
@@ -1303,24 +1330,23 @@ double VideoDecoder::getPtsSecondsForFrame(
   return ptsToSeconds(stream.allFrames[frameIndex].pts, stream.timeBase);
 }
 
-void VideoDecoder::convertFrameToBufferUsingSwsScale(
-    RawDecodedOutput& rawOutput) {
-  AVFrame* frame = rawOutput.frame.get();
-  int streamIndex = rawOutput.streamIndex;
+int VideoDecoder::convertFrameToBufferUsingSwsScale(
+    int streamIndex,
+    const AVFrame* frame,
+    torch::Tensor& outputTensor) {
   enum AVPixelFormat frameFormat =
       static_cast<enum AVPixelFormat>(frame->format);
   StreamInfo& activeStream = streams_[streamIndex];
-  auto frameDims =
-      getHeightAndWidthFromOptionsOrAVFrame(activeStream.options, *frame);
-  int outputHeight = frameDims.height;
-  int outputWidth = frameDims.width;
+
+  int expectedOutputHeight = outputTensor.sizes()[0];
+  int expectedOutputWidth = outputTensor.sizes()[1];
   if (activeStream.swsContext.get() == nullptr) {
     SwsContext* swsContext = sws_getContext(
         frame->width,
         frame->height,
         frameFormat,
-        outputWidth,
-        outputHeight,
+        expectedOutputWidth,
+        expectedOutputHeight,
         AV_PIX_FMT_RGB24,
         SWS_BILINEAR,
         nullptr,
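The function now reads its target dimensions from the output tensor itself rather than recomputing them from options, so the tensor is the single source of truth. The context is still created once and cached on the stream; for reference, FFmpeg also ships its own caching helper, `sws_getCachedContext()`, which reuses the passed-in context when every parameter matches and recreates it otherwise. A sketch of that alternative (not what this PR does):

```cpp
extern "C" {
#include <libavutil/frame.h>
#include <libswscale/swscale.h>
}

// Sketch of an alternative to hand-rolled caching: FFmpeg reuses `existing`
// when all parameters are unchanged, otherwise frees it and builds a new one.
SwsContext* getOrRefreshSwsContext(
    SwsContext* existing, const AVFrame* frame, int dstWidth, int dstHeight) {
  return sws_getCachedContext(
      existing,
      frame->width,
      frame->height,
      static_cast<AVPixelFormat>(frame->format),
      dstWidth,
      dstHeight,
      AV_PIX_FMT_RGB24,
      SWS_BILINEAR,
      nullptr,
      nullptr,
      nullptr);
}
```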
@@ -1352,8 +1378,8 @@ void VideoDecoder::convertFrameToBufferUsingSwsScale(
   }
   SwsContext* swsContext = activeStream.swsContext.get();
   uint8_t* pointers[4] = {
-      static_cast<uint8_t*>(rawOutput.data), nullptr, nullptr, nullptr};
-  int linesizes[4] = {outputWidth * 3, 0, 0, 0};
+      outputTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
+  int linesizes[4] = {expectedOutputWidth * 3, 0, 0, 0};
   int resultHeight = sws_scale(
       swsContext,
       frame->data,
@@ -1362,9 +1388,7 @@ void VideoDecoder::convertFrameToBufferUsingSwsScale(
       frame->height,
       pointers,
       linesizes);
-  TORCH_CHECK(
-      outputHeight == resultHeight,
-      "outputHeight(" + std::to_string(resultHeight) + ") != resultHeight");
+  return resultHeight;
 }
 
 torch::Tensor VideoDecoder::convertFrameToTensorUsingFilterGraph(
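Moving the height check to the call site works because FFmpeg documents `sws_scale()` as returning the height of the output slice it wrote; the callee now just forwards that value. It also hints at an answer to the TODO above: the return value carries no width information, so an equivalent width check would have to come from elsewhere, and in practice the width is already pinned down by the destination stride (`expectedOutputWidth * 3`) that the caller computes itself.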
@@ -1379,8 +1403,7 @@ torch::Tensor VideoDecoder::convertFrameToTensorUsingFilterGraph(
   ffmpegStatus =
       av_buffersink_get_frame(filterState.sinkContext, filteredFrame.get());
   TORCH_CHECK_EQ(filteredFrame->format, AV_PIX_FMT_RGB24);
-  auto frameDims = getHeightAndWidthFromOptionsOrAVFrame(
-      streams_[streamIndex].options, *filteredFrame.get());
+  auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredFrame.get());
   int height = frameDims.height;
   int width = frameDims.width;
   std::vector<int64_t> shape = {height, width, 3};
@@ -1406,6 +1429,10 @@ VideoDecoder::~VideoDecoder() {
   }
 }
 
+FrameDims getHeightAndWidthFromResizedAVFrame(const AVFrame& resizedAVFrame) {
+  return FrameDims(resizedAVFrame.height, resizedAVFrame.width);
+}
+
 FrameDims getHeightAndWidthFromOptionsOrMetadata(
     const VideoDecoder::VideoStreamDecoderOptions& options,
     const VideoDecoder::StreamMetadata& metadata) {
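The new `getHeightAndWidthFromResizedAVFrame` is deliberately trivial: a frame coming out of filtergraph has already been resized, so its own dimensions are authoritative. Contrast that with `getHeightAndWidthFromOptionsOrAVFrame`, whose body is outside this diff; presumably (an assumption) it prefers user-requested dimensions and falls back to the frame's own:

```cpp
// Assumed shape of the existing helper (not part of this diff): options win,
// the frame's own dimensions are the fallback. This presumes options.height
// and options.width are std::optional<int>.
FrameDims getHeightAndWidthFromOptionsOrAVFrame(
    const VideoDecoder::VideoStreamDecoderOptions& options,
    const AVFrame& avFrame) {
  return FrameDims(
      options.height.value_or(avFrame.height),
      options.width.value_or(avFrame.width));
}
```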