@@ -579,29 +579,39 @@ inline bool hasMKLDNN() {
579579}
580580
581581inline void manual_seed (uint64_t seed) {
582+ auto gen = globalContext ().defaultGenerator (c10::DeviceType::CPU);
582583 {
583- auto gen = globalContext ().defaultGenerator (c10::DeviceType::CPU);
584584 // See Note [Acquire lock when using random generators]
585585 std::lock_guard<std::mutex> lock (gen.mutex ());
586586 gen.set_current_seed (seed);
587587 }
588588
589- const auto opt_device_type = at::getAccelerator ();
590- if (!opt_device_type.has_value ()) {
591- return ;
589+ for (const auto i : c10::irange (detail::getCUDAHooks ().deviceCount ())) {
590+ auto cuda_gen = globalContext ().defaultGenerator (
591+ Device (at::kCUDA , static_cast <c10::DeviceIndex>(i)));
592+ {
593+ // See Note [Acquire lock when using random generators]
594+ std::lock_guard<std::mutex> lock (cuda_gen.mutex ());
595+ cuda_gen.set_current_seed (seed);
596+ }
592597 }
593- const auto num_gpus = globalContext ()
594- .getAcceleratorHooksInterface (opt_device_type)
595- .deviceCount ();
596- for (const auto i : c10::irange (num_gpus)) {
597- auto gen = globalContext ().defaultGenerator (
598- Device (opt_device_type.value (), static_cast <c10::DeviceIndex>(i)));
598+
599+ for (const auto i : c10::irange (detail::getXPUHooks ().deviceCount ())) {
600+ auto xpu_gen = globalContext ().defaultGenerator (
601+ Device (at::kXPU , static_cast <c10::DeviceIndex>(i)));
599602 {
600603 // See Note [Acquire lock when using random generators]
601- std::lock_guard<std::mutex> lock (gen .mutex ());
602- gen .set_current_seed (seed);
604+ std::lock_guard<std::mutex> lock (xpu_gen .mutex ());
605+ xpu_gen .set_current_seed (seed);
603606 }
604607 }
608+
609+ if (hasMPS ()) {
610+ auto mps_gen = globalContext ().defaultGenerator (c10::DeviceType::MPS);
611+ // See Note [Acquire lock when using random generators]
612+ std::lock_guard<std::mutex> lock (mps_gen.mutex ());
613+ mps_gen.set_current_seed (seed);
614+ }
605615}
606616
607617// When the global flag `allow_tf32` is set to true, cuBLAS handles are
0 commit comments