@@ -484,22 +484,8 @@ inline DeprecatedTypeProperties& MPS(ScalarType s) {
       Backend::MPS, s);
 }
 
-// Note [at::hasXXX() vs. at::globalContext().hasXXX()]
-// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-//
-// The purpose of `at::hasXXX()` is to check if device XXX is available at
-// runtime. In contrast, `at::globalContext().hasXXX()` determines whether
-// support for device XXX was included in the PyTorch build (enabled at compile
-// time) or if a device XXX extension has already been registered with PyTorch.
-//
-// `at::globalContext().hasXXX()` is often used in functions like
-// `getAccelerator()` instead of `at::hasXXX()` to avoid initializing the
-// runtime for device XXX (which can poison child processes while detecting the
-// current accelerator).
-
 inline bool hasCUDA() {
-  return globalContext().hasCUDA() &&
-      (detail::getCUDAHooks().deviceCount() > 0);
+  return globalContext().hasCUDA();
 }
 
 inline bool hasMTIA() {
@@ -527,7 +513,7 @@ inline bool hasMAIA() {
 }
 
 inline bool hasXPU() {
-  return globalContext().hasXPU() && (detail::getXPUHooks().deviceCount() > 0);
+  return globalContext().hasXPU();
 }
 
 inline bool hasHPU() {
@@ -585,24 +571,31 @@ inline void manual_seed(uint64_t seed) {
     std::lock_guard<std::mutex> lock(gen.mutex());
     gen.set_current_seed(seed);
   }
-
-  for (const auto i : c10::irange(detail::getCUDAHooks().deviceCount())) {
-    auto cuda_gen = globalContext().defaultGenerator(
-        Device(at::kCUDA, static_cast<c10::DeviceIndex>(i)));
-    {
-      // See Note [Acquire lock when using random generators]
-      std::lock_guard<std::mutex> lock(cuda_gen.mutex());
-      cuda_gen.set_current_seed(seed);
+  // NB: Sometimes we build with CUDA, but we don't have any GPUs
+  // available. In that case, we must not seed CUDA; it will fail!
+  const auto cuda_num_gpus = detail::getCUDAHooks().deviceCount();
+  if (hasCUDA() && cuda_num_gpus > 0) {
+    for (const auto i : c10::irange(cuda_num_gpus)) {
+      auto cuda_gen = globalContext().defaultGenerator(
+          Device(at::kCUDA, static_cast<c10::DeviceIndex>(i)));
+      {
+        // See Note [Acquire lock when using random generators]
+        std::lock_guard<std::mutex> lock(cuda_gen.mutex());
+        cuda_gen.set_current_seed(seed);
+      }
     }
   }
 
-  for (const auto i : c10::irange(detail::getXPUHooks().deviceCount())) {
-    auto xpu_gen = globalContext().defaultGenerator(
-        Device(at::kXPU, static_cast<c10::DeviceIndex>(i)));
-    {
-      // See Note [Acquire lock when using random generators]
-      std::lock_guard<std::mutex> lock(xpu_gen.mutex());
-      xpu_gen.set_current_seed(seed);
+  const auto xpu_num_gpus = detail::getXPUHooks().deviceCount();
+  if (hasXPU() && xpu_num_gpus) {
+    for (const auto i : c10::irange(xpu_num_gpus)) {
+      auto xpu_gen = globalContext().defaultGenerator(
+          Device(at::kXPU, static_cast<c10::DeviceIndex>(i)));
+      {
+        // See Note [Acquire lock when using random generators]
+        std::lock_guard<std::mutex> lock(xpu_gen.mutex());
+        xpu_gen.set_current_seed(seed);
+      }
     }
   }
 
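For reference, a minimal caller-side sketch of the semantics after this change. It is illustrative only and not part of the commit; the includes and main() are scaffolding, and it assumes a CUDA-enabled build of ATen.

#include <ATen/Context.h>
#include <iostream>

int main() {
  // After this change, at::hasCUDA() no longer also requires
  // deviceCount() > 0, so a true result does not guarantee that a GPU is
  // actually visible; query the hook's deviceCount() for that.
  if (at::hasCUDA()) {
    const int n = at::detail::getCUDAHooks().deviceCount();
    std::cout << "CUDA support available, " << n << " visible device(s)\n";
  }
  // manual_seed() now skips the CUDA/XPU generators when no devices are
  // present, so calling it on a CUDA build running on a CPU-only host is safe.
  at::manual_seed(42);
  return 0;
}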