Skip to content

Commit 619f1ce

Browse files
authored
[torchcodec] delay device map init to runtime and fix targets
Differential Revision: D72722867 Pull Request resolved: #631
1 parent 82750c8 commit 619f1ce

File tree

1 file changed

+19
-8
lines changed

1 file changed

+19
-8
lines changed

src/torchcodec/_core/DeviceInterface.cpp

+19-8
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@
1111
namespace facebook::torchcodec {
1212

1313
namespace {
14+
using DeviceInterfaceMap = std::map<torch::DeviceType, CreateDeviceInterfaceFn>;
1415
std::mutex g_interface_mutex;
15-
std::map<torch::DeviceType, CreateDeviceInterfaceFn> g_interface_map;
16+
std::unique_ptr<DeviceInterfaceMap> g_interface_map;
1617

1718
std::string getDeviceType(const std::string& device) {
1819
size_t pos = device.find(':');
@@ -28,11 +29,18 @@ bool registerDeviceInterface(
2829
torch::DeviceType deviceType,
2930
CreateDeviceInterfaceFn createInterface) {
3031
std::scoped_lock lock(g_interface_mutex);
32+
if (!g_interface_map) {
33+
// We delay this initialization until runtime to avoid the Static
34+
// Initialization Order Fiasco:
35+
//
36+
// https://en.cppreference.com/w/cpp/language/siof
37+
g_interface_map = std::make_unique<DeviceInterfaceMap>();
38+
}
3139
TORCH_CHECK(
32-
g_interface_map.find(deviceType) == g_interface_map.end(),
40+
g_interface_map->find(deviceType) == g_interface_map->end(),
3341
"Device interface already registered for ",
3442
deviceType);
35-
g_interface_map.insert({deviceType, createInterface});
43+
g_interface_map->insert({deviceType, createInterface});
3644
return true;
3745
}
3846

@@ -45,14 +53,16 @@ torch::Device createTorchDevice(const std::string device) {
4553
std::scoped_lock lock(g_interface_mutex);
4654
std::string deviceType = getDeviceType(device);
4755
auto deviceInterface = std::find_if(
48-
g_interface_map.begin(),
49-
g_interface_map.end(),
56+
g_interface_map->begin(),
57+
g_interface_map->end(),
5058
[&](const std::pair<torch::DeviceType, CreateDeviceInterfaceFn>& arg) {
5159
return device.rfind(
5260
torch::DeviceTypeName(arg.first, /*lcase*/ true), 0) == 0;
5361
});
5462
TORCH_CHECK(
55-
deviceInterface != g_interface_map.end(), "Unsupported device: ", device);
63+
deviceInterface != g_interface_map->end(),
64+
"Unsupported device: ",
65+
device);
5666

5767
return torch::Device(device);
5868
}
@@ -67,11 +77,12 @@ std::unique_ptr<DeviceInterface> createDeviceInterface(
6777

6878
std::scoped_lock lock(g_interface_mutex);
6979
TORCH_CHECK(
70-
g_interface_map.find(deviceType) != g_interface_map.end(),
80+
g_interface_map->find(deviceType) != g_interface_map->end(),
7181
"Unsupported device: ",
7282
device);
7383

74-
return std::unique_ptr<DeviceInterface>(g_interface_map[deviceType](device));
84+
return std::unique_ptr<DeviceInterface>(
85+
(*g_interface_map)[deviceType](device));
7586
}
7687

7788
} // namespace facebook::torchcodec

0 commit comments

Comments
 (0)