|
4 | 4 |
|
5 | 5 | #include "source_base/module_device/device.h" |
6 | 6 | #include "source_base/timer.h" |
| 7 | +#include "source_base/tool_quit.h" |
7 | 8 | #include "source_base/kernels/math_kernel_op.h" |
8 | 9 | #include "source_base/kernels/dsp/dsp_connector.h" |
9 | 10 | // #include "source_base/module_container/ATen/kernels/lapack.h" |
@@ -425,17 +426,54 @@ void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func, |
425 | 426 | nbase, |
426 | 427 | notconv, |
427 | 428 | psi_norm); |
| 429 | + |
| 430 | + // Check for zero norms (GPU path: copy norms from device to host) |
| 431 | + Real* psi_norm_host = nullptr; |
| 432 | + resmem_real_h_op()(psi_norm_host, notconv); |
| 433 | + syncmem_var_d2h_op()(psi_norm_host, psi_norm, notconv); |
| 434 | + for (int i = 0; i < notconv; i++) |
| 435 | + { |
| 436 | + if (psi_norm_host[i] <= 1.0e-12) |
| 437 | + { |
| 438 | + std::cout << "Diago_DavSubspace::cal_grad: psi_norm <= 0 for band " << i << std::endl; |
| 439 | + std::cout << "This may be due to npwx < nbands: the number of plane waves is less than" << std::endl; |
| 440 | + std::cout << "the number of bands, leading to a rank-deficient problem." << std::endl; |
| 441 | + std::cout << "Please increase ecutwfc or reduce nbands." << std::endl; |
| 442 | + delmem_real_h_op()(psi_norm_host); |
| 443 | + delmem_real_op()(psi_norm); |
| 444 | + ModuleBase::WARNING_QUIT("cal_grad", "psi_norm <= 0"); |
| 445 | + } |
| 446 | + } |
| 447 | + delmem_real_h_op()(psi_norm_host); |
428 | 448 | delmem_real_op()(psi_norm); |
429 | 449 | } |
430 | 450 | else |
431 | 451 | #endif |
432 | 452 | { |
433 | 453 | Real* psi_norm = nullptr; |
| 454 | + resmem_real_h_op()(psi_norm, notconv); |
| 455 | + setmem_real_h_op()(psi_norm, 0.0, notconv); |
| 456 | + |
434 | 457 | normalize_op<T, Device>()(this->dim, |
435 | 458 | psi_iter, |
436 | 459 | nbase, |
437 | 460 | notconv, |
438 | 461 | psi_norm); |
| 462 | + |
| 463 | + // Check for zero norms (CPU path) |
| 464 | + for (int i = 0; i < notconv; i++) |
| 465 | + { |
| 466 | + if (psi_norm[i] <= 1.0e-12) |
| 467 | + { |
| 468 | + std::cout << "Diago_DavSubspace::cal_grad: psi_norm <= 0 for band " << i << std::endl; |
| 469 | + std::cout << "This may be due to npwx < nbands: the number of plane waves is less than" << std::endl; |
| 470 | + std::cout << "the number of bands, leading to a rank-deficient problem." << std::endl; |
| 471 | + std::cout << "Please increase ecutwfc or reduce nbands." << std::endl; |
| 472 | + delmem_real_h_op()(psi_norm); |
| 473 | + ModuleBase::WARNING_QUIT("cal_grad", "psi_norm <= 0"); |
| 474 | + } |
| 475 | + } |
| 476 | + delmem_real_h_op()(psi_norm); |
439 | 477 | } |
440 | 478 |
|
441 | 479 | // update hpsi[:, nbase:nbase+notconv] |
|
0 commit comments