Skip to content

Commit c500428

Browse files
committed
Fix(PW): correct DFT+U force/stress spinor formula for npol=2
The vu array stores its coefficients in the order [H_↑↑, H_↓↑, H_↑↓, H_↓↓], which differs from the order used by deeq_nc, [H_↑↑, H_↑↓, H_↓↑, H_↓↓]. The force/stress formula must follow the vu ordering, so the off-diagonal terms are ps[2]*dbb1 + ps[1]*dbb2 (previously ps[1]*dbb1 + ps[2]*dbb2). This is fixed in the CPU, CUDA, and ROCm kernels for both force and stress.
1 parent f7ce7d0 commit c500428

6 files changed

Lines changed: 6 additions & 6 deletions

File tree

source/source_pw/module_pwdft/kernels/cuda/force_op.cu

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -375,7 +375,7 @@ __global__ void cal_force_onsite(int wg_nc,
375375
const thrust::complex<FPTYPE> dbb1 = conj(dbecp[inkb0]) * becp[inkb2 + nkb];
376376
const thrust::complex<FPTYPE> dbb2 = conj(dbecp[inkb0 + nkb]) * becp[inkb2];
377377
const thrust::complex<FPTYPE> dbb3 = conj(dbecp[inkb0 + nkb]) * becp[inkb2 + nkb];
378-
const FPTYPE tmp = -fac * (ps[0] * dbb0 + ps[1] * dbb1 + ps[2] * dbb2 + ps[3] * dbb3).real();
378+
const FPTYPE tmp = -fac * (ps[0] * dbb0 + ps[2] * dbb1 + ps[1] * dbb2 + ps[3] * dbb3).real();
379379
atomicAdd(force + iat * forcenl_nc + ipol, tmp);
380380
}
381381
}

source/source_pw/module_pwdft/kernels/cuda/stress_op.cu

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -977,7 +977,7 @@ __global__ void cal_stress_onsite(
977977
const thrust::complex<FPTYPE> dbb1 = conj(dbecp[inkb1]) * becp[inkb2 + nkb];
978978
const thrust::complex<FPTYPE> dbb2 = conj(dbecp[inkb1 + nkb]) * becp[inkb2];
979979
const thrust::complex<FPTYPE> dbb3 = conj(dbecp[inkb1 + nkb]) * becp[inkb2 + nkb];
980-
stress_var -= fac * (ps[0] * dbb0 + ps[1] * dbb1 + ps[2] * dbb2 + ps[3] * dbb3).real();
980+
stress_var -= fac * (ps[0] * dbb0 + ps[2] * dbb1 + ps[1] * dbb2 + ps[3] * dbb3).real();
981981
}
982982
++iat;
983983
sum+=nprojs;

source/source_pw/module_pwdft/kernels/force_op.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -352,7 +352,7 @@ struct cal_force_nl_op<FPTYPE, base_device::DEVICE_CPU>
352352
const std::complex<FPTYPE> dbb2 = conj(dbecp[index0 + nkb]) * becp[index1];
353353
const std::complex<FPTYPE> dbb3 = conj(dbecp[index0 + nkb]) * becp[index1 + nkb];
354354

355-
local_force[iforce] -= fac * (ps[0] * dbb0 + ps[1] * dbb1 + ps[2] * dbb2 + ps[3] * dbb3).real();
355+
local_force[iforce] -= fac * (ps[0] * dbb0 + ps[2] * dbb1 + ps[1] * dbb2 + ps[3] * dbb3).real();
356356
}
357357
}
358358
else if(npol == 1)

source/source_pw/module_pwdft/kernels/rocm/force_op.hip.cu

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -361,7 +361,7 @@ __global__ void cal_force_onsite(int wg_nc,
361361
const thrust::complex<FPTYPE> dbb1 = conj(dbecp[inkb0]) * becp[inkb2 + nkb];
362362
const thrust::complex<FPTYPE> dbb2 = conj(dbecp[inkb0 + nkb]) * becp[inkb2];
363363
const thrust::complex<FPTYPE> dbb3 = conj(dbecp[inkb0 + nkb]) * becp[inkb2 + nkb];
364-
const FPTYPE tmp = -fac * (ps[0] * dbb0 + ps[1] * dbb1 + ps[2] * dbb2 + ps[3] * dbb3).real();
364+
const FPTYPE tmp = -fac * (ps[0] * dbb0 + ps[2] * dbb1 + ps[1] * dbb2 + ps[3] * dbb3).real();
365365
atomicAdd(force + iat * forcenl_nc + ipol, tmp);
366366
}
367367
}

source/source_pw/module_pwdft/kernels/rocm/stress_op.hip.cu

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -966,7 +966,7 @@ __global__ void cal_stress_onsite(
966966
const thrust::complex<FPTYPE> dbb1 = conj(dbecp[inkb1]) * becp[inkb2 + nkb];
967967
const thrust::complex<FPTYPE> dbb2 = conj(dbecp[inkb1 + nkb]) * becp[inkb2];
968968
const thrust::complex<FPTYPE> dbb3 = conj(dbecp[inkb1 + nkb]) * becp[inkb2 + nkb];
969-
stress_var -= fac * (ps[0] * dbb0 + ps[1] * dbb1 + ps[2] * dbb2 + ps[3] * dbb3).real();
969+
stress_var -= fac * (ps[0] * dbb0 + ps[2] * dbb1 + ps[1] * dbb2 + ps[3] * dbb3).real();
970970
}
971971
++iat;
972972
sum+=nprojs;

source/source_pw/module_pwdft/kernels/stress_op.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -318,7 +318,7 @@ struct cal_stress_nl_op<FPTYPE, base_device::DEVICE_CPU>
318318
const std::complex<FPTYPE> dbb1 = conj(dbecp[inkb1]) * becp[nkb + inkb2];
319319
const std::complex<FPTYPE> dbb2 = conj(dbecp[nkb + inkb1]) * becp[inkb2];
320320
const std::complex<FPTYPE> dbb3 = conj(dbecp[nkb + inkb1]) * becp[nkb + inkb2];
321-
local_stress -= fac * (ps[0] * dbb0 + ps[1] * dbb1 + ps[2] * dbb2 + ps[3] * dbb3).real();
321+
local_stress -= fac * (ps[0] * dbb0 + ps[2] * dbb1 + ps[1] * dbb2 + ps[3] * dbb3).real();
322322
}
323323
} // end ip
324324
break;

0 commit comments

Comments
 (0)