Skip to content

Commit 4b847b6

Browse files
[Fix] Eliminate unnecessary vkb CPU allocation on GPU path (#7296)
* [Fix] Skip unnecessary vkb CPU allocation on GPU path in pseudopot_cell_vnl

  On GPU path, vkb.create(nkb, npwx) allocates CPU ComplexMatrix memory that is never used — getvnl() writes directly to GPU buffers (c_vkb/z_vkb). The only consumer of vkb.nc metadata is the leading dimension in gemm/gemv. This wastes nkb*npwx*16 bytes of CPU memory (~3.2 GB for large systems).

  Changes:
  - Add vkbnc member to store column dimension independently
  - Guard vkb.create() behind !use_gpu_ in init()
  - Replace all ppcell->vkb.nc with ppcell->vkbnc (op_pw_nl.cpp, hamilt_pw.cpp)
  - Add lazy-allocation guard in getgradq_vnl() for GPU Velocity path

  Tested: GPU build + 28/28 kernel UTs + 38/40 GPU integration tests (2 pre-existing failures: scf_bpcg, scf_out_wf)

* Update source/source_pw/module_pwdft/vnl_pw_grad.cpp

  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update source/source_pw/module_pwdft/vnl_pw_grad.cpp

  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Revert "Update source/source_pw/module_pwdft/vnl_pw_grad.cpp"

  This reverts commit b71e3fe.

* Revert "Update source/source_pw/module_pwdft/vnl_pw_grad.cpp"

  This reverts commit a0f43dc.

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent e882c33 commit 4b847b6

5 files changed

Lines changed: 26 additions & 10 deletions

File tree

source/source_pw/module_pwdft/hamilt_pw.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ void HamiltPW<T, Device>::sPsi(const T* psi_in, // psi
272272
this->ppcell->nkb,
273273
&one,
274274
this->vkb,
275-
this->ppcell->vkb.nc,
275+
this->ppcell->vkbnc,
276276
psi_in,
277277
inc,
278278
&zero,
@@ -288,7 +288,7 @@ void HamiltPW<T, Device>::sPsi(const T* psi_in, // psi
288288
npw,
289289
&one,
290290
this->vkb,
291-
this->ppcell->vkb.nc,
291+
this->ppcell->vkbnc,
292292
psi_in,
293293
nrow,
294294
&zero,
@@ -360,7 +360,7 @@ void HamiltPW<T, Device>::sPsi(const T* psi_in, // psi
360360
this->ppcell->nkb,
361361
&one,
362362
this->vkb,
363-
this->ppcell->vkb.nc,
363+
this->ppcell->vkbnc,
364364
ps,
365365
inc,
366366
&one,
@@ -376,7 +376,7 @@ void HamiltPW<T, Device>::sPsi(const T* psi_in, // psi
376376
this->ppcell->nkb,
377377
&one,
378378
this->vkb,
379-
this->ppcell->vkb.nc,
379+
this->ppcell->vkbnc,
380380
ps,
381381
this->ppcell->nkb,
382382
&one,

source/source_pw/module_pwdft/op_pw_nl.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ void Nonlocal<OperatorPW<T, Device>>::add_nonlocal_pp(T *hpsi_in, const T *becp,
172172
this->ppcell->nkb,
173173
&this->one,
174174
this->vkb,
175-
this->ppcell->vkb.nc,
175+
this->ppcell->vkbnc,
176176
this->ps,
177177
inc,
178178
&this->one,
@@ -197,7 +197,7 @@ void Nonlocal<OperatorPW<T, Device>>::add_nonlocal_pp(T *hpsi_in, const T *becp,
197197
this->ppcell->nkb,
198198
&this->one,
199199
this->vkb,
200-
this->ppcell->vkb.nc,
200+
this->ppcell->vkbnc,
201201
this->ps,
202202
npm,
203203
&this->one,
@@ -251,7 +251,7 @@ void Nonlocal<OperatorPW<T, Device>>::act(
251251
nkb,
252252
&this->one,
253253
this->vkb,
254-
this->ppcell->vkb.nc,
254+
this->ppcell->vkbnc,
255255
tmpsi_in,
256256
inc,
257257
&this->zero,
@@ -276,7 +276,7 @@ void Nonlocal<OperatorPW<T, Device>>::act(
276276
this->npw,
277277
&this->one,
278278
this->vkb,
279-
this->ppcell->vkb.nc,
279+
this->ppcell->vkbnc,
280280
tmpsi_in,
281281
max_npw,
282282
&this->zero,

source/source_pw/module_pwdft/vnl_pw.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,10 +214,17 @@ void pseudopot_cell_vnl::init(const UnitCell& ucell,
214214
// dq+4)*cell_factor;
215215
this->lmaxq = 2 * this->lmaxkb + 1;
216216
int npwx = this->wfcpw->npwk_max;
217+
this->vkbnc = npwx;
217218
if (nkb > 0 && allocate_vkb)
218219
{
219-
vkb.create(nkb, npwx);
220-
ModuleBase::Memory::record("VNL::vkb", nkb * npwx * sizeof(std::complex<double>));
220+
if (!this->use_gpu_)
221+
{
222+
vkb.create(nkb, npwx);
223+
ModuleBase::Memory::record("VNL::vkb", nkb * npwx * sizeof(std::complex<double>));
224+
}
225+
// GPU path: vkb ComplexMatrix is not allocated.
226+
// Column dimension is stored in vkbnc for gemm/gemv leading dimension.
227+
// Actual GPU buffers (c_vkb/z_vkb) are allocated below.
221228
}
222229

223230
// this->nqx = 10000; // calculted in allocate_nlpot.f90

source/source_pw/module_pwdft/vnl_pw.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,10 @@ class pseudopot_cell_vnl
108108
std::complex<double>*** vkb_alpha;
109109
Structure_Factor* psf = nullptr;
110110

111+
// Column dimension of vkb matrix (= npwx), used as leading dimension in gemm/gemv.
112+
// On GPU path vkb ComplexMatrix is not allocated to save CPU memory; this stores the dimension.
113+
int vkbnc = 0;
114+
111115
// other variables
112116
std::complex<double> Cal_C(int alpha, int lu, int mu, int L, int M);
113117

source/source_pw/module_pwdft/vnl_pw_grad.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,11 @@ void pseudopot_cell_vnl::getgradq_vnl(const UnitCell& ucell,
9191

9292
ModuleBase::YlmReal::grad_Ylm_Real(x1, npw, gk, ylm, dylm[0], dylm[1], dylm[2]);
9393

94+
// GPU path skips vkb allocation in init(); allocate now if needed
95+
if (this->vkb.nc == 0 && this->nkb > 0 && this->vkbnc > 0) {
96+
this->vkb.create(this->nkb, this->vkbnc);
97+
}
98+
9499
int jkb = 0;
95100
for(int it = 0;it < ucell.ntype;it++)
96101
{

0 commit comments

Comments
 (0)