Skip to content

Commit 73462db

Browse files
authored
Add guard for hsolver rank-deficient cases (deepmodeling#7284)
* Add guard for hsolver rank-deficient cases * Fix dimension of pw
1 parent 5271ce2 commit 73462db

5 files changed

Lines changed: 83 additions & 0 deletions

File tree

source/source_hsolver/diago_cg.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,9 @@ void DiagoCG<T, Device>::schmit_orth(const int& m, const ct::Tensor& psi, const
542542
<< std::endl;
543543
}
544544
std::cout << " in DiagoCG, psi norm = " << psi_norm << std::endl;
545+
std::cout << " This may be due to npwx < nbands: the number of plane waves is less than" << std::endl;
546+
std::cout << " the number of bands, leading to a rank-deficient problem." << std::endl;
547+
std::cout << " Please increase ecutwfc or reduce nbands." << std::endl;
545548
std::cout << " If you use GNU compiler, it may due to the zdotc is unavailable." << std::endl;
546549
ModuleBase::WARNING_QUIT("schmit_orth", "psi_norm <= 0.0");
547550
}

source/source_hsolver/diago_dav_subspace.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
#include "source_base/module_device/device.h"
66
#include "source_base/timer.h"
7+
#include "source_base/tool_quit.h"
78
#include "source_base/kernels/math_kernel_op.h"
89
#include "source_base/kernels/dsp/dsp_connector.h"
910
// #include "source_base/module_container/ATen/kernels/lapack.h"
@@ -425,17 +426,54 @@ void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
425426
nbase,
426427
notconv,
427428
psi_norm);
429+
430+
// Check for zero norms (GPU path: copy norms from device to host)
431+
Real* psi_norm_host = nullptr;
432+
resmem_real_h_op()(psi_norm_host, notconv);
433+
syncmem_var_d2h_op()(psi_norm_host, psi_norm, notconv);
434+
for (int i = 0; i < notconv; i++)
435+
{
436+
if (psi_norm_host[i] <= 1.0e-12)
437+
{
438+
std::cout << "Diago_DavSubspace::cal_grad: psi_norm <= 0 for band " << i << std::endl;
439+
std::cout << "This may be due to npwx < nbands: the number of plane waves is less than" << std::endl;
440+
std::cout << "the number of bands, leading to a rank-deficient problem." << std::endl;
441+
std::cout << "Please increase ecutwfc or reduce nbands." << std::endl;
442+
delmem_real_h_op()(psi_norm_host);
443+
delmem_real_op()(psi_norm);
444+
ModuleBase::WARNING_QUIT("cal_grad", "psi_norm <= 0");
445+
}
446+
}
447+
delmem_real_h_op()(psi_norm_host);
428448
delmem_real_op()(psi_norm);
429449
}
430450
else
431451
#endif
432452
{
433453
Real* psi_norm = nullptr;
454+
resmem_real_h_op()(psi_norm, notconv);
455+
setmem_real_h_op()(psi_norm, 0.0, notconv);
456+
434457
normalize_op<T, Device>()(this->dim,
435458
psi_iter,
436459
nbase,
437460
notconv,
438461
psi_norm);
462+
463+
// Check for zero norms (CPU path)
464+
for (int i = 0; i < notconv; i++)
465+
{
466+
if (psi_norm[i] <= 1.0e-12)
467+
{
468+
std::cout << "Diago_DavSubspace::cal_grad: psi_norm <= 0 for band " << i << std::endl;
469+
std::cout << "This may be due to npwx < nbands: the number of plane waves is less than" << std::endl;
470+
std::cout << "the number of bands, leading to a rank-deficient problem." << std::endl;
471+
std::cout << "Please increase ecutwfc or reduce nbands." << std::endl;
472+
delmem_real_h_op()(psi_norm);
473+
ModuleBase::WARNING_QUIT("cal_grad", "psi_norm <= 0");
474+
}
475+
}
476+
delmem_real_h_op()(psi_norm);
439477
}
440478

441479
// update hpsi[:, nbase:nbase+notconv]

source/source_hsolver/diago_david.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -932,6 +932,9 @@ void DiagoDavid<T, Device>::SchmidtOrth(const int& dim,
932932
if (psi_norm < 1.0e-12)
933933
{
934934
std::cout << "DiagoDavid::SchmidtOrth:aborted for psi_norm <1.0e-12" << std::endl;
935+
std::cout << "This may be due to npwx < nbands: the number of plane waves is less than" << std::endl;
936+
std::cout << "the number of bands, leading to a rank-deficient problem." << std::endl;
937+
std::cout << "Please increase ecutwfc or reduce nbands." << std::endl;
935938
std::cout << "nband = " << nband << std::endl;
936939
std::cout << "m = " << m << std::endl;
937940
exit(0);

source/source_hsolver/hsolver_pw.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,29 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
241241

242242
const int cur_nbasis = psi.get_current_nbas();
243243

244+
// Check for rank deficiency: the total number of plane waves (summed across
245+
// all MPI processes) must be >= nbands. When npw_total < nbands, the basis is
246+
// rank-deficient, leading to psi_norm <= 0 during Schmidt orthogonalization.
247+
// Note: we sum cur_nbasis (local npw for this k-point) across the pool because
248+
// psi.get_nbasis() gives the local storage dimension, not the total.
249+
const int nbands = psi.get_nbands();
250+
int npw_total = cur_nbasis;
251+
#ifdef __MPI
252+
if (this->nproc_in_pool > 1)
253+
{
254+
MPI_Allreduce(&cur_nbasis, &npw_total, 1, MPI_INT, MPI_SUM, POOL_WORLD);
255+
}
256+
#endif
257+
if (npw_total < nbands)
258+
{
259+
const std::string msg = "npw_total < nbands (" + std::to_string(npw_total) + " < " + std::to_string(nbands)
260+
+ "): the total number of plane waves across all MPI processes "
261+
+ "is less than the number of bands, "
262+
+ "which leads to a rank-deficient problem. "
263+
+ "Please increase ecutwfc or reduce nbands.";
264+
ModuleBase::WARNING_QUIT("HSolverPW::hamiltSolvePsiK", msg);
265+
}
266+
244267
// Shared matrix-blockvector operators used by all iterative solvers.
245268
auto hpsi_func = [hm, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
246269
auto psi_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, cur_nbasis);

source/source_hsolver/test/test_hsolver_pw.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,3 +390,19 @@ TEST_F(TestHSolverPW, SolveLcaoInPW) {
390390
EXPECT_DOUBLE_EQ(elecstate_test.ekb.c[0], 0.0);
391391
EXPECT_DOUBLE_EQ(elecstate_test.ekb.c[1], 0.0);
392392
}
393+
394+
// Death test: hamiltSolvePsiK must quit when the total plane-wave count is below nbands,
395+
// because such a rank-deficient basis would give psi_norm <= 0 during diagonalization.
396+
TEST_F(TestHSolverPW, NpwxLessThanNbandsDeath)
397+
{
398+
// Resize psi to 5 bands with only 3 basis functions, so npw = 3 < nbands = 5.
399+
psi_test_cd.resize(1, 5, 3);
400+
std::vector<double> precond(3, 0.0);
401+
std::vector<double> eigenvalues(5, 0.0);
402+
// Expect WARNING_QUIT to end the process with exit code 1; ".*" accepts any stderr text, so only the exit status is checked.
403+
EXPECT_EXIT(
404+
hs_d.hamiltSolvePsiK(&hamilt_test_d, psi_test_cd, precond, eigenvalues.data(), 1),
405+
::testing::ExitedWithCode(1),
406+
".*"
407+
);
408+
}

0 commit comments

Comments
 (0)