Skip to content

Commit cec7e9e

Browse files
committed
Define CpuDeviceInterface
Fixes: #619 * Commit defines `CpuDeviceInterface` and moves video `*OnCPU` methods from `SingleStreamDecoder` to it. * Audio `*OnCPU` methods left in `SingleStreamDecoder` * Constructor API of `DeviceInterface` was changed to allow passing `AVRational timeBase` required to initialize ffmpeg filter graph Signed-off-by: Dmitry Rogozhkin <[email protected]>
1 parent 619f1ce commit cec7e9e

9 files changed

+193
-140
lines changed

src/torchcodec/_core/CMakeLists.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,8 @@ function(make_torchcodec_libraries
6060
set(decoder_sources
6161
AVIOContextHolder.cpp
6262
FFMPEGCommon.cpp
63-
DeviceInterface.cpp
63+
DeviceInterface.cpp
64+
CpuDeviceInterface.cpp
6465
SingleStreamDecoder.cpp
6566
# TODO: lib name should probably not be "*_decoder*" now that it also
6667
# contains an encoder
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#include "src/torchcodec/_core/CpuDeviceInterface.h"
2+
3+
namespace facebook::torchcodec {
4+
namespace {
5+
6+
bool g_cpu = registerDeviceInterface(
7+
torch::kCPU,
8+
[](const torch::Device& device, const AVRational& timeBase) {
9+
return new CpuDeviceInterface(device, timeBase);
10+
});
11+
12+
} // namespace
13+
14+
CpuDeviceInterface::CpuDeviceInterface(
15+
const torch::Device& device,
16+
const AVRational& timeBase)
17+
: DeviceInterface(device, timeBase) {
18+
if (device_.type() != torch::kCPU) {
19+
throw std::runtime_error("Unsupported device: " + device_.str());
20+
}
21+
}
22+
23+
} // namespace facebook::torchcodec
+77
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the BSD-style license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#pragma once
8+
9+
#include "src/torchcodec/_core/DeviceInterface.h"
10+
11+
namespace facebook::torchcodec {
12+
13+
class CpuDeviceInterface : public DeviceInterface {
14+
public:
15+
CpuDeviceInterface(const torch::Device& device, const AVRational& timeBase);
16+
17+
virtual ~CpuDeviceInterface() {}
18+
19+
std::optional<const AVCodec*> findCodec(
20+
[[maybe_unused]] const AVCodecID& codecId) override {
21+
return std::nullopt;
22+
}
23+
24+
void initializeContext(
25+
[[maybe_unused]] AVCodecContext* codecContext) override {}
26+
27+
void convertAVFrameToFrameOutput(
28+
const VideoStreamOptions& videoStreamOptions,
29+
UniqueAVFrame& avFrame,
30+
FrameOutput& frameOutput,
31+
std::optional<torch::Tensor> preAllocatedOutputTensor =
32+
std::nullopt) override;
33+
34+
private:
35+
int convertAVFrameToTensorUsingSwsScale(
36+
const UniqueAVFrame& avFrame,
37+
torch::Tensor& outputTensor);
38+
39+
torch::Tensor convertAVFrameToTensorUsingFilterGraph(
40+
const UniqueAVFrame& avFrame);
41+
42+
struct FilterGraphContext {
43+
UniqueAVFilterGraph filterGraph;
44+
AVFilterContext* sourceContext = nullptr;
45+
AVFilterContext* sinkContext = nullptr;
46+
};
47+
48+
struct DecodedFrameContext {
49+
int decodedWidth;
50+
int decodedHeight;
51+
AVPixelFormat decodedFormat;
52+
AVRational decodedAspectRatio;
53+
int expectedWidth;
54+
int expectedHeight;
55+
bool operator==(const DecodedFrameContext&);
56+
bool operator!=(const DecodedFrameContext&);
57+
};
58+
59+
void createSwsContext(
60+
const DecodedFrameContext& frameContext,
61+
const enum AVColorSpace colorspace);
62+
63+
void createFilterGraph(
64+
const DecodedFrameContext& frameContext,
65+
const VideoStreamOptions& videoStreamOptions);
66+
67+
// color-conversion fields. Only one of FilterGraphContext and
68+
// UniqueSwsContext should be non-null.
69+
FilterGraphContext filterGraphContext_;
70+
UniqueSwsContext swsContext_;
71+
72+
// Used to know whether a new FilterGraphContext or UniqueSwsContext should
73+
// be created before decoding a new frame.
74+
DecodedFrameContext prevFrameContext_;
75+
};
76+
77+
} // namespace facebook::torchcodec

src/torchcodec/_core/CudaDeviceInterface.cpp

+8-5
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@ extern "C" {
1616
namespace facebook::torchcodec {
1717
namespace {
1818

19-
bool g_cuda =
20-
registerDeviceInterface(torch::kCUDA, [](const torch::Device& device) {
21-
return new CudaDeviceInterface(device);
19+
bool g_cuda = registerDeviceInterface(
20+
torch::kCUDA,
21+
[](const torch::Device& device, const AVRational& timeBase) {
22+
return new CudaDeviceInterface(device, timeBase);
2223
});
2324

2425
// We reuse cuda contexts across VideoDeoder instances. This is because
@@ -164,8 +165,10 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
164165
}
165166
} // namespace
166167

167-
CudaDeviceInterface::CudaDeviceInterface(const torch::Device& device)
168-
: DeviceInterface(device) {
168+
CudaDeviceInterface::CudaDeviceInterface(
169+
const torch::Device& device,
170+
const AVRational& timeBase)
171+
: DeviceInterface(device, timeBase) {
169172
if (device_.type() != torch::kCUDA) {
170173
throw std::runtime_error("Unsupported device: " + device_.str());
171174
}

src/torchcodec/_core/CudaDeviceInterface.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ namespace facebook::torchcodec {
1212

1313
class CudaDeviceInterface : public DeviceInterface {
1414
public:
15-
CudaDeviceInterface(const torch::Device& device);
15+
CudaDeviceInterface(const torch::Device& device, const AVRational& timeBase);
1616

1717
virtual ~CudaDeviceInterface();
1818

src/torchcodec/_core/DeviceInterface.cpp

+3-12
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,6 @@ bool registerDeviceInterface(
4545
}
4646

4747
torch::Device createTorchDevice(const std::string device) {
48-
// TODO: remove once DeviceInterface for CPU is implemented
49-
if (device == "cpu") {
50-
return torch::kCPU;
51-
}
52-
5348
std::scoped_lock lock(g_interface_mutex);
5449
std::string deviceType = getDeviceType(device);
5550
auto deviceInterface = std::find_if(
@@ -68,21 +63,17 @@ torch::Device createTorchDevice(const std::string device) {
6863
}
6964

7065
std::unique_ptr<DeviceInterface> createDeviceInterface(
71-
const torch::Device& device) {
66+
const torch::Device& device,
67+
const AVRational& timeBase) {
7268
auto deviceType = device.type();
73-
// TODO: remove once DeviceInterface for CPU is implemented
74-
if (deviceType == torch::kCPU) {
75-
return nullptr;
76-
}
77-
7869
std::scoped_lock lock(g_interface_mutex);
7970
TORCH_CHECK(
8071
g_interface_map->find(deviceType) != g_interface_map->end(),
8172
"Unsupported device: ",
8273
device);
8374

8475
return std::unique_ptr<DeviceInterface>(
85-
(*g_interface_map)[deviceType](device));
76+
(*g_interface_map)[deviceType](device, timeBase));
8677
}
8778

8879
} // namespace facebook::torchcodec

src/torchcodec/_core/DeviceInterface.h

+7-4
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ namespace facebook::torchcodec {
2727

2828
class DeviceInterface {
2929
public:
30-
DeviceInterface(const torch::Device& device) : device_(device) {}
30+
DeviceInterface(const torch::Device& device, const AVRational& timeBase)
31+
: device_(device), timeBase_(timeBase) {}
3132

3233
virtual ~DeviceInterface(){};
3334

@@ -49,10 +50,11 @@ class DeviceInterface {
4950

5051
protected:
5152
torch::Device device_;
53+
AVRational timeBase_;
5254
};
5355

54-
using CreateDeviceInterfaceFn =
55-
std::function<DeviceInterface*(const torch::Device& device)>;
56+
using CreateDeviceInterfaceFn = std::function<
57+
DeviceInterface*(const torch::Device& device, const AVRational& timeBase)>;
5658

5759
bool registerDeviceInterface(
5860
torch::DeviceType deviceType,
@@ -61,6 +63,7 @@ bool registerDeviceInterface(
6163
torch::Device createTorchDevice(const std::string device);
6264

6365
std::unique_ptr<DeviceInterface> createDeviceInterface(
64-
const torch::Device& device);
66+
const torch::Device& device,
67+
const AVRational& timeBase);
6568

6669
} // namespace facebook::torchcodec

0 commit comments

Comments
 (0)