Commit 633bd07

leizhenyuan authored and pytorchmergebot committed

Integrate xpu into torch.Generator and torch.seed (pytorch#109866)

Integrate torch.xpu.Generator into torch.Generator.
Integrate torch.xpu.seed into torch.seed.

Pull Request resolved: pytorch#109866
Approved by: https://github.com/ezyang

1 parent: 0511df0
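
As a quick orientation, the user-visible effect is that the global seeding entry points now cover XPU. A minimal sketch, assuming a build where the optional `torch.xpu` module is available (e.g. via Intel Extension for PyTorch):

```python
import torch

# torch.manual_seed / torch.seed now also seed XPU generators,
# alongside the existing CPU, CUDA, and MPS paths.
torch.manual_seed(42)
new_seed = torch.seed()  # picks a fresh seed and propagates it to XPU too

# torch.Generator can now be constructed directly on an XPU device.
g = torch.Generator(device="xpu")
g.manual_seed(42)
```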

File tree

4 files changed: +44 -6 lines

- aten/src/ATen/Context.h
- aten/src/ATen/detail/XPUHooksInterface.h
- torch/csrc/Generator.cpp
- torch/random.py

aten/src/ATen/Context.h (+18 -3)

```diff
@@ -44,6 +44,8 @@ class TORCH_API Context {
       return at::detail::getCUDAHooks().getDefaultCUDAGenerator(device.index());
     } else if (device_type == at::kMPS) {
       return at::detail::getMPSHooks().getDefaultMPSGenerator();
+    } else if (device_type == at::kXPU) {
+      return at::detail::getXPUHooks().getDefaultXPUGenerator(device.index());
     } else if (device_type == at::kPrivateUse1) {
       return at::GetPrivateUse1HooksInterface()->getDefaultGenerator(
           device.index());
@@ -445,9 +447,9 @@ static inline void manual_seed(uint64_t seed) {
   }
   // NB: Sometimes we build with CUDA, but we don't have any GPUs
   // available. In that case, we must not seed CUDA; it will fail!
-  const auto num_gpus = detail::getCUDAHooks().getNumGPUs();
-  if (hasCUDA() && num_gpus > 0) {
-    for (const auto i : c10::irange(num_gpus)) {
+  const auto cuda_num_gpus = detail::getCUDAHooks().getNumGPUs();
+  if (hasCUDA() && cuda_num_gpus > 0) {
+    for (const auto i : c10::irange(cuda_num_gpus)) {
       auto cuda_gen = globalContext().defaultGenerator(
           Device(at::kCUDA, static_cast<c10::DeviceIndex>(i)));
       {
@@ -458,6 +460,19 @@ static inline void manual_seed(uint64_t seed) {
     }
   }
 
+  const auto xpu_num_gpus = detail::getXPUHooks().getNumGPUs();
+  if (hasXPU() && xpu_num_gpus > 0) {
+    for (const auto i : c10::irange(xpu_num_gpus)) {
+      auto xpu_gen = globalContext().defaultGenerator(
+          Device(at::kXPU, static_cast<c10::DeviceIndex>(i)));
+      {
+        // See Note [Acquire lock when using random generators]
+        std::lock_guard<std::mutex> lock(xpu_gen.mutex());
+        xpu_gen.set_current_seed(seed);
+      }
+    }
+  }
+
   if (hasMPS()) {
     auto mps_gen = globalContext().defaultGenerator(c10::DeviceType::MPS);
     // See Note [Acquire lock when using random generators]
```
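
The new XPU block mirrors the CUDA loop above it: query the device count, skip seeding when no device is present, and seed each device's default generator while holding its mutex. Below is a rough Python analogue of that per-device loop, for illustration only; `torch.xpu.device_count` is assumed here by analogy with `torch.cuda.device_count`, and the real code path runs in C++ against the default generators rather than freshly constructed ones.

```python
import torch

def seed_all_xpu(seed: int) -> None:
    # Illustrative sketch of the Context.h loop: do nothing when the
    # XPU backend is absent or reports zero devices, otherwise seed a
    # generator for every visible device.
    if not hasattr(torch, "xpu") or torch.xpu.device_count() == 0:
        return
    for i in range(torch.xpu.device_count()):
        gen = torch.Generator(device=f"xpu:{i}")
        gen.manual_seed(seed)
```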

aten/src/ATen/detail/XPUHooksInterface.h (+16 -2)

```diff
@@ -2,7 +2,7 @@
 
 #include <c10/core/Device.h>
 #include <c10/util/Exception.h>
-
+#include <ATen/core/Generator.h>
 #include <c10/util/Registry.h>
 
 #include <cstddef>
@@ -55,7 +55,7 @@ struct TORCH_API XPUHooksInterface {
         false,
         "Cannot get XPU device without Intel Extension for Pytorch. ",
         XPU_HELP);
-  }
+  };
 
   virtual DLDevice_& getDLPackDeviceFromATenDevice(
       DLDevice_& dl_device,
@@ -65,6 +65,20 @@ struct TORCH_API XPUHooksInterface {
         false,
         "Cannot get XPU DL device without Intel Extension for Pytorch. ",
         XPU_HELP);
+  };
+
+  virtual Generator getXPUGenerator(DeviceIndex device_index = -1) const {
+    (void)device_index; // Suppress unused variable warning
+    TORCH_CHECK(false, "Cannot get XPU generator without Intel Extension for Pytorch. ", XPU_HELP);
+  }
+
+  virtual const Generator& getDefaultXPUGenerator(DeviceIndex device_index = -1) const {
+    (void)device_index; // Suppress unused variable warning
+    TORCH_CHECK(false, "Cannot get default XPU generator without Intel Extension for Pytorch. ", XPU_HELP);
+  }
+
+  virtual int getNumGPUs() const {
+    return 0;
   }
 };
 
```
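
These hook methods are deliberately failing stubs: they let a build without the XPU backend compile and link, and only raise when an XPU generator is actually requested. At the Python level a `TORCH_CHECK(false, ...)` surfaces as a `RuntimeError`, so the behavior on a non-XPU build can be sketched as:

```python
import torch

try:
    torch.Generator(device="xpu")
except RuntimeError as err:
    # Without Intel Extension for PyTorch installed, the stub fires:
    # "Cannot get XPU generator without Intel Extension for Pytorch. ..."
    print(err)
```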

torch/csrc/Generator.cpp (+4 -1)

```diff
@@ -5,6 +5,7 @@
 #include <structmember.h>
 
 #include <ATen/core/GeneratorForPrivateuseone.h>
+#include <ATen/detail/XPUHooksInterface.h>
 #include <torch/csrc/Device.h>
 #include <torch/csrc/Exceptions.h>
 #include <torch/csrc/THP.h>
@@ -71,7 +72,9 @@ static PyObject* THPGenerator_pynew(
     self->cdata = make_generator<MPSGeneratorImpl>();
   }
 #endif
-  else if (device.type() == at::kPrivateUse1) {
+  else if (device.type() == at::kXPU) {
+    self->cdata = at::detail::getXPUHooks().getXPUGenerator(device.index());
+  } else if (device.type() == at::kPrivateUse1) {
     self->cdata = at::GetGeneratorForPrivateuse1(device.index());
   } else {
     AT_ERROR(
```
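
With this dispatch branch in place, constructing a generator on an XPU device routes through the registered hooks instead of falling into the `AT_ERROR` branch. A usage sketch, assuming an XPU device is actually available:

```python
import torch

g = torch.Generator(device="xpu")
g.manual_seed(1234)

# The generator can then be passed to sampling ops like any other
# device generator.
x = torch.randn(4, device="xpu", generator=g)
```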

torch/random.py (+6 -0)

```diff
@@ -43,6 +43,9 @@ def manual_seed(seed) -> torch._C.Generator:
     if not torch.mps._is_in_bad_fork():
         torch.mps.manual_seed(seed)
 
+    if hasattr(torch, 'xpu') and not torch.xpu._is_in_bad_fork():
+        torch.xpu.manual_seed(seed)
+
     _seed_custom_device(seed)
 
     return default_generator.manual_seed(seed)
@@ -62,6 +65,9 @@ def seed() -> int:
     if not torch.mps._is_in_bad_fork():
         torch.mps.manual_seed(seed)
 
+    if hasattr(torch, 'xpu') and not torch.xpu._is_in_bad_fork():
+        torch.xpu.manual_seed(seed)
+
     _seed_custom_device(seed)
 
     return seed
```
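
The guard added in both functions degrades gracefully: `hasattr(torch, 'xpu')` makes the call a no-op on builds without the XPU module, and `_is_in_bad_fork()` avoids touching the device in a process forked after the backend was initialized. The same pattern, generalized for any optional backend (a hypothetical helper, not part of this commit):

```python
import torch

def _seed_optional_backend(name: str, seed: int) -> None:
    # Seed an optional device backend only if its module exists and we
    # are not in a forked child where device state would be invalid.
    backend = getattr(torch, name, None)
    if backend is not None and not backend._is_in_bad_fork():
        backend.manual_seed(seed)

# e.g. _seed_optional_backend("xpu", 42)
```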
