Commit b80ecc4

Revert "Fix poision child process issue when call getAccelerator() (pytorch#144368)"
This reverts commit 2583d83. Reverted pytorch#144368 on behalf of https://github.com/clee2000 because it broke internal tests (D68023262), probably the same problem as noted in the issue mentioned in that PR ([comment](pytorch#144368 (comment)))
1 parent db2a309 commit b80ecc4

7 files changed: +29 −47 lines changed

aten/src/ATen/Context.h

+24-31
@@ -484,22 +484,8 @@ inline DeprecatedTypeProperties& MPS(ScalarType s) {
       Backend::MPS, s);
 }
 
-// Note [at::hasXXX() vs. at::globalContext().hasXXX()]
-// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-//
-// The purpose of `at::hasXXX()` is to check if device XXX is available at
-// runtime. In contrast, `at::globalContext().hasXXX()` determines whether
-// support for device XXX was included in the PyTorch build (enabled at compile
-// time) or if a device XXX extension has already been registered with PyTorch.
-//
-// `at::globalContext().hasXXX()` is often used in functions like
-// `getAccelerator()` instead of `at::hasXXX()` to avoid initializing the
-// runtime for device XXX (which can poison child processes while detecting the
-// current accelerator).
-
 inline bool hasCUDA() {
-  return globalContext().hasCUDA() &&
-      (detail::getCUDAHooks().deviceCount() > 0);
+  return globalContext().hasCUDA();
 }
 
 inline bool hasMTIA() {
@@ -527,7 +513,7 @@ inline bool hasMAIA() {
 }
 
 inline bool hasXPU() {
-  return globalContext().hasXPU() && (detail::getXPUHooks().deviceCount() > 0);
+  return globalContext().hasXPU();
 }
 
 inline bool hasHPU() {
@@ -585,24 +571,31 @@ inline void manual_seed(uint64_t seed) {
     std::lock_guard<std::mutex> lock(gen.mutex());
     gen.set_current_seed(seed);
   }
-
-  for (const auto i : c10::irange(detail::getCUDAHooks().deviceCount())) {
-    auto cuda_gen = globalContext().defaultGenerator(
-        Device(at::kCUDA, static_cast<c10::DeviceIndex>(i)));
-    {
-      // See Note [Acquire lock when using random generators]
-      std::lock_guard<std::mutex> lock(cuda_gen.mutex());
-      cuda_gen.set_current_seed(seed);
+  // NB: Sometimes we build with CUDA, but we don't have any GPUs
+  // available. In that case, we must not seed CUDA; it will fail!
+  const auto cuda_num_gpus = detail::getCUDAHooks().deviceCount();
+  if (hasCUDA() && cuda_num_gpus > 0) {
+    for (const auto i : c10::irange(cuda_num_gpus)) {
+      auto cuda_gen = globalContext().defaultGenerator(
+          Device(at::kCUDA, static_cast<c10::DeviceIndex>(i)));
+      {
+        // See Note [Acquire lock when using random generators]
+        std::lock_guard<std::mutex> lock(cuda_gen.mutex());
+        cuda_gen.set_current_seed(seed);
+      }
     }
   }
 
-  for (const auto i : c10::irange(detail::getXPUHooks().deviceCount())) {
-    auto xpu_gen = globalContext().defaultGenerator(
-        Device(at::kXPU, static_cast<c10::DeviceIndex>(i)));
-    {
-      // See Note [Acquire lock when using random generators]
-      std::lock_guard<std::mutex> lock(xpu_gen.mutex());
-      xpu_gen.set_current_seed(seed);
+  const auto xpu_num_gpus = detail::getXPUHooks().deviceCount();
+  if (hasXPU() && xpu_num_gpus) {
+    for (const auto i : c10::irange(xpu_num_gpus)) {
+      auto xpu_gen = globalContext().defaultGenerator(
+          Device(at::kXPU, static_cast<c10::DeviceIndex>(i)));
+      {
+        // See Note [Acquire lock when using random generators]
+        std::lock_guard<std::mutex> lock(xpu_gen.mutex());
+        xpu_gen.set_current_seed(seed);
+      }
     }
   }
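Reading note (not part of the diff): after this revert, at::hasCUDA() and at::hasXPU() once again report runtime availability (built with support and at least one device visible), while at::globalContext().hasCUDA() only reports whether CUDA support was compiled in or registered, which is the distinction the removed note described. A minimal C++ sketch of the difference, assuming a standard libtorch setup; output depends on the build and machine:

#include <ATen/Context.h>
#include <iostream>

int main() {
  // Build-time / registration check: true whenever PyTorch was built with CUDA,
  // even on a machine with no usable GPU.
  std::cout << "built with CUDA: " << at::globalContext().hasCUDA() << '\n';

  // Runtime check (post-revert semantics): additionally requires deviceCount() > 0,
  // which may initialize the CUDA runtime in this process.
  std::cout << "CUDA usable now: " << at::hasCUDA() << '\n';
  return 0;
}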

aten/src/ATen/DeviceAccelerator.cpp

+1-1
@@ -6,7 +6,7 @@ namespace at::accelerator {
 
 std::optional<c10::DeviceType> getAccelerator(bool checked) {
 #define DETECT_AND_ASSIGN_ACCELERATOR(device_name) \
-  if (at::globalContext().has##device_name()) { \
+  if (at::has##device_name()) { \
     device_type = k##device_name; \
     TORCH_CHECK( \
         !is_accelerator_detected, \
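For orientation (illustration only, not the actual source): with the post-revert definition, expanding DETECT_AND_ASSIGN_ACCELERATOR(CUDA) produces code roughly like the sketch below, so accelerator detection goes through at::hasCUDA() and therefore through a runtime device-count query. The function name and the check message are placeholders; the real getAccelerator() repeats this step for every supported backend and also honors the checked argument.

#include <ATen/Context.h>
#include <c10/core/DeviceType.h>
#include <c10/util/Exception.h>
#include <optional>

// Hypothetical, simplified expansion of a single detection step after this revert.
std::optional<c10::DeviceType> detect_cuda_sketch() {
  std::optional<c10::DeviceType> device_type;
  bool is_accelerator_detected = false;
  if (at::hasCUDA()) { // post-revert: runtime check, may initialize the CUDA runtime
    device_type = at::kCUDA;
    TORCH_CHECK(!is_accelerator_detected, "more than one accelerator detected"); // placeholder message
    is_accelerator_detected = true;
  }
  return device_type;
}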

aten/src/ATen/DeviceAccelerator.h

+1-1
@@ -18,7 +18,7 @@ namespace at::accelerator {
 // CUDA, MTIA, XPU, HIP, MPS, PrivateUse1
 
 // Ensures that only one accelerator is available (at
-// *compile time* if possible) and return it.
+// compile time if possible) and return it.
 // When checked is true, the returned optional always has a value.
 TORCH_API std::optional<c10::DeviceType> getAccelerator(bool checked = false);
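For orientation (not part of the diff), a typical call into this API looks roughly like the following usage sketch, based only on the declaration shown above:

#include <ATen/DeviceAccelerator.h>
#include <iostream>
#include <optional>

int main() {
  // checked=false: the result may be empty when no accelerator is present;
  // checked=true would instead assert that one was found.
  std::optional<c10::DeviceType> acc = at::accelerator::getAccelerator(/*checked=*/false);
  if (acc.has_value()) {
    std::cout << "detected accelerator: " << *acc << '\n';
  } else {
    std::cout << "no accelerator detected" << '\n';
  }
  return 0;
}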

aten/src/ATen/cuda/detail/CUDAHooks.cpp

+1-4
@@ -148,10 +148,7 @@ bool CUDAHooks::isPinnedPtr(const void* data) const {
 }
 
 bool CUDAHooks::hasCUDA() const {
-  // This function determines if CUDA is built into PyTorch. It helps avoid
-  // initializing the CUDA runtime (which can poison child processes) while
-  // detecting the current accelerator.
-  return true;
+  return at::cuda::is_available();
 }
 
 bool CUDAHooks::hasMAGMA() const {
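Background for the comment removed above (illustration only, not PyTorch test code): the "poison child processes" concern is that answering at::hasCUDA() via at::cuda::is_available() has to count devices, which can initialize the CUDA runtime in the calling process; a process forked afterwards inherits that state, and its own CUDA use is then unreliable. A minimal sketch of the failure mode, assuming a CUDA build and POSIX fork():

#include <ATen/Context.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <unistd.h>
#include <iostream>

int main() {
  // Post-revert, this call may initialize the CUDA runtime in the parent.
  std::cout << "parent: CUDA usable = " << at::hasCUDA() << '\n';

  pid_t pid = fork();
  if (pid == 0) {
    // The forked child inherits the parent's initialized CUDA state; CUDA calls
    // here typically fail (the poisoned-child-process issue the reverted PR addressed).
    std::cout << "child: CUDA usable = " << at::hasCUDA() << '\n';
    _exit(0);
  }
  waitpid(pid, nullptr, 0);
  return 0;
}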

test/cpp/jit/test_misc.cpp

+1-1
@@ -2392,7 +2392,7 @@ TEST(FuturesTest, Basic) {
 // Sparse CUDA tensor test
 TEST(FutureTest, SparseTensor) {
   // Skip test if CUDA is not available.
-  bool has_cuda = at::hasCUDA();
+  bool has_cuda = at::globalContext().hasCUDA();
   if (!has_cuda) {
     LOG(INFO) << "CUDA not available, skipping test";
   }

test/test_cuda.py

+1-8
@@ -3345,14 +3345,7 @@ def check_output(script: str) -> str:
         VISIBLE_DEVICES = (
             "HIP_VISIBLE_DEVICES" if TEST_WITH_ROCM else "CUDA_VISIBLE_DEVICES"
         )
-        test_script = f"""\
-import os
-import torch
-os.environ['{VISIBLE_DEVICES}']='32'
-
-torch.device(0) # see https://github.com/pytorch/pytorch/issues/144152
-print(torch.cuda.device_count())
-"""
+        test_script = f"import os; import torch;os.environ['{VISIBLE_DEVICES}']='32';print(torch.cuda.device_count())"
         rc = check_output(test_script)
         self.assertEqual(rc, "0")
         if not TEST_WITH_ROCM:

test/test_xpu.py

-1
@@ -192,7 +192,6 @@ def test_multi_process(model, input):
     torch.nn.ReLU(),
     torch.nn.MaxPool2d(2, 2),
 )
-torch.device(0) # see https://github.com/pytorch/pytorch/issues/144152
 test_multi_process(model, input)
 test_multi_process(model, input)
 print(torch.xpu.device_count())
