@@ -579,29 +579,39 @@ inline bool hasMKLDNN() {
579
579
}
580
580
581
581
inline void manual_seed (uint64_t seed) {
582
+ auto gen = globalContext ().defaultGenerator (c10::DeviceType::CPU);
582
583
{
583
- auto gen = globalContext ().defaultGenerator (c10::DeviceType::CPU);
584
584
// See Note [Acquire lock when using random generators]
585
585
std::lock_guard<std::mutex> lock (gen.mutex ());
586
586
gen.set_current_seed (seed);
587
587
}
588
588
589
- const auto opt_device_type = at::getAccelerator ();
590
- if (!opt_device_type.has_value ()) {
591
- return ;
589
+ for (const auto i : c10::irange (detail::getCUDAHooks ().deviceCount ())) {
590
+ auto cuda_gen = globalContext ().defaultGenerator (
591
+ Device (at::kCUDA , static_cast <c10::DeviceIndex>(i)));
592
+ {
593
+ // See Note [Acquire lock when using random generators]
594
+ std::lock_guard<std::mutex> lock (cuda_gen.mutex ());
595
+ cuda_gen.set_current_seed (seed);
596
+ }
592
597
}
593
- const auto num_gpus = globalContext ()
594
- .getAcceleratorHooksInterface (opt_device_type)
595
- .deviceCount ();
596
- for (const auto i : c10::irange (num_gpus)) {
597
- auto gen = globalContext ().defaultGenerator (
598
- Device (opt_device_type.value (), static_cast <c10::DeviceIndex>(i)));
598
+
599
+ for (const auto i : c10::irange (detail::getXPUHooks ().deviceCount ())) {
600
+ auto xpu_gen = globalContext ().defaultGenerator (
601
+ Device (at::kXPU , static_cast <c10::DeviceIndex>(i)));
599
602
{
600
603
// See Note [Acquire lock when using random generators]
601
- std::lock_guard<std::mutex> lock (gen .mutex ());
602
- gen .set_current_seed (seed);
604
+ std::lock_guard<std::mutex> lock (xpu_gen .mutex ());
605
+ xpu_gen .set_current_seed (seed);
603
606
}
604
607
}
608
+
609
+ if (hasMPS ()) {
610
+ auto mps_gen = globalContext ().defaultGenerator (c10::DeviceType::MPS);
611
+ // See Note [Acquire lock when using random generators]
612
+ std::lock_guard<std::mutex> lock (mps_gen.mutex ());
613
+ mps_gen.set_current_seed (seed);
614
+ }
605
615
}
606
616
607
617
// When the global flag `allow_tf32` is set to true, cuBLAS handles are
0 commit comments