Skip to content

Commit db2a309

Browse files
Revert "Generalize at::manual_seed for all accelerators (pytorch#144370)"
This reverts commit eeb5739. Reverted pytorch#144370 on behalf of https://github.com/clee2000 due to broke internal tests D68023262, probably the same problem as noted in the issue this PR is mentioned above ([comment](pytorch#144368 (comment)))
1 parent 9ec8ece commit db2a309

File tree

1 file changed

+22
-12
lines changed

1 file changed

+22
-12
lines changed

Diff for: aten/src/ATen/Context.h

+22-12
Original file line numberDiff line numberDiff line change
@@ -579,29 +579,39 @@ inline bool hasMKLDNN() {
579579
}
580580

581581
inline void manual_seed(uint64_t seed) {
582+
auto gen = globalContext().defaultGenerator(c10::DeviceType::CPU);
582583
{
583-
auto gen = globalContext().defaultGenerator(c10::DeviceType::CPU);
584584
// See Note [Acquire lock when using random generators]
585585
std::lock_guard<std::mutex> lock(gen.mutex());
586586
gen.set_current_seed(seed);
587587
}
588588

589-
const auto opt_device_type = at::getAccelerator();
590-
if (!opt_device_type.has_value()) {
591-
return;
589+
for (const auto i : c10::irange(detail::getCUDAHooks().deviceCount())) {
590+
auto cuda_gen = globalContext().defaultGenerator(
591+
Device(at::kCUDA, static_cast<c10::DeviceIndex>(i)));
592+
{
593+
// See Note [Acquire lock when using random generators]
594+
std::lock_guard<std::mutex> lock(cuda_gen.mutex());
595+
cuda_gen.set_current_seed(seed);
596+
}
592597
}
593-
const auto num_gpus = globalContext()
594-
.getAcceleratorHooksInterface(opt_device_type)
595-
.deviceCount();
596-
for (const auto i : c10::irange(num_gpus)) {
597-
auto gen = globalContext().defaultGenerator(
598-
Device(opt_device_type.value(), static_cast<c10::DeviceIndex>(i)));
598+
599+
for (const auto i : c10::irange(detail::getXPUHooks().deviceCount())) {
600+
auto xpu_gen = globalContext().defaultGenerator(
601+
Device(at::kXPU, static_cast<c10::DeviceIndex>(i)));
599602
{
600603
// See Note [Acquire lock when using random generators]
601-
std::lock_guard<std::mutex> lock(gen.mutex());
602-
gen.set_current_seed(seed);
604+
std::lock_guard<std::mutex> lock(xpu_gen.mutex());
605+
xpu_gen.set_current_seed(seed);
603606
}
604607
}
608+
609+
if (hasMPS()) {
610+
auto mps_gen = globalContext().defaultGenerator(c10::DeviceType::MPS);
611+
// See Note [Acquire lock when using random generators]
612+
std::lock_guard<std::mutex> lock(mps_gen.mutex());
613+
mps_gen.set_current_seed(seed);
614+
}
605615
}
606616

607617
// When the global flag `allow_tf32` is set to true, cuBLAS handles are

0 commit comments

Comments
 (0)