forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathDeviceAccelerator.cpp
120 lines (103 loc) · 3.92 KB
/
DeviceAccelerator.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#include <ATen/Context.h>
#include <ATen/DeviceAccelerator.h>
#include <c10/core/impl/VirtualGuardImpl.h>
namespace at::accelerator {
/// Detect which accelerator device type, if any, should be used.
///
/// Detection proceeds in three stages and stops at the first hit for the
/// first two stages:
///   1. a registered PrivateUse1 backend,
///   2. backends detected at runtime (currently MTIA),
///   3. backends reported as built by their hooks (CUDA, XPU, HIP, MPS, HPU);
///      at most one of these may be built, otherwise an error is raised.
///
/// @param checked  When true, raise an error instead of returning an empty
///                 optional if stage 3 finds no accelerator.
/// @return The detected device type, or std::nullopt when nothing was
///         detected and `checked` is false.
std::optional<c10::DeviceType> getAccelerator(bool checked) {
  // Stage 1: PrivateUse1 backends.
  // PrivateUse1 is deliberately allowed to coexist with another device because
  // that combination is used for testing; a registered PrivateUse1 device
  // always wins. This only checks hook registration, so it does NOT initialize
  // the device or poison fork.
  if (is_privateuse1_backend_registered()) {
    return kPrivateUse1;
  }

  // Stage 2: runtime backends.
  // Temporary state: these runtime checks should become compile-time checks
  // once the backends provide the new isBuilt API and we are sure they never
  // share a binary with another accelerator.
  if (at::hasMTIA()) {
    return kMTIA;
  }

  // Stage 3: compile-time backends.
  // Each expansion records the backend if its hooks report it as built, and
  // fails if a different backend was already recorded.
  std::optional<c10::DeviceType> device_type = std::nullopt;
#define RECORD_BUILT_ACCELERATOR(backend)            \
  if (at::detail::get##backend##Hooks().isBuilt()) { \
    TORCH_CHECK(                                     \
        !device_type.has_value(),                    \
        "Cannot have both " #backend " and ",        \
        device_type.value(),                         \
        ".");                                        \
    device_type = k##backend;                        \
  }
  RECORD_BUILT_ACCELERATOR(CUDA)
  RECORD_BUILT_ACCELERATOR(XPU)
  RECORD_BUILT_ACCELERATOR(HIP)
  RECORD_BUILT_ACCELERATOR(MPS)
  RECORD_BUILT_ACCELERATOR(HPU)
#undef RECORD_BUILT_ACCELERATOR

  if (checked) {
    TORCH_CHECK(
        device_type, "Cannot access accelerator device when none is available.");
  }
  return device_type;
}
/// Report whether `device_type` is one of the device types treated as an
/// accelerator here: CUDA, MTIA, XPU, HIP, MPS, HPU or PrivateUse1.
///
/// @param device_type  The device type to classify.
/// @return true for the listed accelerator types, false for everything else.
bool isAccelerator(c10::DeviceType device_type) {
  return device_type == at::kCUDA || device_type == at::kMTIA ||
      device_type == at::kXPU || device_type == at::kHIP ||
      device_type == at::kMPS || device_type == at::kHPU ||
      device_type == at::kPrivateUse1;
}
// NOLINTBEGIN(bugprone-unchecked-optional-access)
/// Return the device count reported by the guard implementation of the
/// detected accelerator.
///
/// The accelerator is looked up in unchecked mode, so a missing accelerator
/// is not an error here.
///
/// @return The guard implementation's device count, or 0 when no accelerator
///         is detected.
c10::DeviceIndex deviceCount() {
  const std::optional<c10::DeviceType> accelerator = getAccelerator(false);
  if (accelerator.has_value()) {
    c10::impl::VirtualGuardImpl guard(accelerator.value());
    return static_cast<c10::DeviceIndex>(guard.deviceCount());
  }
  // No accelerator detected: report zero devices.
  return static_cast<c10::DeviceIndex>(0);
}
/// Forward `device_index` to the guard implementation's setDevice() for the
/// detected accelerator.
///
/// The accelerator is looked up in checked mode, so this raises an error when
/// no accelerator is available.
///
/// @param device_index  Index passed to setDevice() together with the
///                      accelerator's device type.
void setDeviceIndex(c10::DeviceIndex device_index) {
  const c10::DeviceType accelerator = getAccelerator(true).value();
  c10::impl::VirtualGuardImpl guard(accelerator);
  guard.setDevice(c10::Device(accelerator, device_index));
}
/// Return the index of the device reported by the guard implementation's
/// getDevice() for the detected accelerator.
///
/// The accelerator is looked up in checked mode, so this raises an error when
/// no accelerator is available.
///
/// @return The index of the device returned by getDevice().
c10::DeviceIndex getDeviceIndex() {
  const c10::DeviceType accelerator = getAccelerator(true).value();
  c10::impl::VirtualGuardImpl guard(accelerator);
  const auto current_device = guard.getDevice();
  return static_cast<c10::DeviceIndex>(current_device.index());
}
/// Hand `stream` to the guard implementation's exchangeStream() for the
/// detected accelerator; the value returned by exchangeStream() is discarded.
///
/// The accelerator is looked up in checked mode, so this raises an error when
/// no accelerator is available. It also raises an error when the stream's
/// device type differs from the detected accelerator's device type.
///
/// @param stream  The stream to install; its device type must match the
///                current accelerator.
void setCurrentStream(c10::Stream stream) {
  const auto device_type = getAccelerator(true).value();

  // Reject streams that belong to a different kind of device.
  TORCH_CHECK(
      device_type == stream.device_type(),
      "stream's device type ",
      c10::DeviceTypeName(stream.device_type()),
      " doesn't match the current accelerator ",
      c10::DeviceTypeName(device_type));

  c10::impl::VirtualGuardImpl guard(device_type);
  guard.exchangeStream(stream);
}
/// Return the stream reported by the guard implementation's getStream() for
/// the given device index of the detected accelerator.
///
/// The accelerator is looked up in checked mode, so this raises an error when
/// no accelerator is available.
///
/// @param device_index  Index of the device to query, combined with the
///                      accelerator's device type.
/// @return The stream returned by getStream() for that device.
c10::Stream getCurrentStream(c10::DeviceIndex device_index) {
  const c10::DeviceType accelerator = getAccelerator(true).value();
  c10::impl::VirtualGuardImpl guard(accelerator);
  return guard.getStream(c10::Device(accelerator, device_index));
}
/// Forward `device_index` to the guard implementation's synchronizeDevice()
/// for the detected accelerator.
///
/// The accelerator is looked up in checked mode, so this raises an error when
/// no accelerator is available.
///
/// @param device_index  Index of the device to synchronize.
void synchronizeDevice(c10::DeviceIndex device_index) {
  const c10::DeviceType accelerator = getAccelerator(true).value();
  c10::impl::VirtualGuardImpl guard(accelerator);
  // The guard's synchronizeDevice is expected to be safe to call regardless of
  // which device is currently active.
  guard.synchronizeDevice(device_index);
}
// NOLINTEND(bugprone-unchecked-optional-access)
} // namespace at::accelerator