Skip to content

Commit a8413b3

Browse files
authored
Feature: Moving spatial gauge for RT-TDDFT Ehrenfest dynamics (#7300)
1 parent 4b847b6 commit a8413b3

11 files changed

Lines changed: 591 additions & 99 deletions

source/source_esolver/esolver_ks_lcao_tddft.cpp

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,12 @@ ESolver_KS_LCAO_TDDFT<TR, Device>::~ESolver_KS_LCAO_TDDFT()
5454
delete td_p;
5555
}
5656
TD_info::td_vel_op = nullptr;
57+
58+
if (td_mg_ != nullptr)
59+
{
60+
delete td_mg_;
61+
td_mg_ = nullptr;
62+
}
5763
}
5864

5965
template <typename TR, typename Device>
@@ -94,6 +100,16 @@ void ESolver_KS_LCAO_TDDFT<TR, Device>::runner(UnitCell& ucell, const int istep)
94100
// 1) before_scf (electronic iteration loops)
95101
//----------------------------------------------------------------
96102
this->before_scf(ucell, istep); // From ESolver_KS_LCAO
103+
104+
// Initialize the moving spatial gauge
105+
if (use_td_moving_gauge && this->td_mg_ == nullptr)
106+
{
107+
this->td_mg_ = new module_rt::TD_MovingGauge();
108+
auto* hamilt_lcao = dynamic_cast<hamilt::HamiltLCAO<std::complex<double>, TR>*>(this->p_hamilt);
109+
const hamilt::HContainer<TR>* sR_template = hamilt_lcao->getSR();
110+
this->td_mg_->init_DR(sR_template, &ucell, &this->pv, this->two_center_bundle_.overlap_orb.get());
111+
}
112+
97113
if (PARAM.inp.td_stype == 2)
98114
{
99115
this->dmat.dm->cal_DMR_td(ucell, TD_info::cart_At);
@@ -242,6 +258,14 @@ void ESolver_KS_LCAO_TDDFT<TR, Device>::hamilt2rho_single(UnitCell& ucell,
242258
const int iter,
243259
const double ethr)
244260
{
261+
// Update the moving spatial gauge
262+
if (use_td_moving_gauge)
263+
{
264+
auto* hamilt_lcao = dynamic_cast<hamilt::HamiltLCAO<std::complex<double>, TR>*>(this->p_hamilt);
265+
const hamilt::HContainer<TR>* sR_template = hamilt_lcao->getSR();
266+
this->td_mg_->update_DR(sR_template, &ucell, &this->pv, this->two_center_bundle_.overlap_orb.get());
267+
}
268+
245269
if (PARAM.inp.init_wfc == "file")
246270
{
247271
if (istep >= TD_info::estep_shift + 1)
@@ -261,7 +285,11 @@ void ESolver_KS_LCAO_TDDFT<TR, Device>::hamilt2rho_single(UnitCell& ucell,
261285
GlobalV::ofs_running,
262286
PARAM.inp.propagator,
263287
use_tensor,
264-
use_lapack);
288+
use_lapack,
289+
this->td_mg_,
290+
&ucell,
291+
this->kv.kvec_d,
292+
use_td_moving_gauge);
265293
}
266294
this->weight_dm_rho(ucell);
267295
}
@@ -281,7 +309,11 @@ void ESolver_KS_LCAO_TDDFT<TR, Device>::hamilt2rho_single(UnitCell& ucell,
281309
GlobalV::ofs_running,
282310
PARAM.inp.propagator,
283311
use_tensor,
284-
use_lapack);
312+
use_lapack,
313+
this->td_mg_,
314+
&ucell,
315+
this->kv.kvec_d,
316+
use_td_moving_gauge);
285317
this->weight_dm_rho(ucell);
286318
}
287319
else

source/source_esolver/esolver_ks_lcao_tddft.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "source_lcao/module_rt/gather_mat.h" // MPI gathering and distributing functions
88
#include "source_lcao/module_rt/kernels/cublasmp_context.h"
99
#include "source_lcao/module_rt/td_info.h"
10+
#include "source_lcao/module_rt/td_moving_gauge.h"
1011
#include "source_lcao/module_rt/velocity_op.h"
1112

1213
namespace ModuleESolver
@@ -66,6 +67,10 @@ class ESolver_KS_LCAO_TDDFT : public ESolver_KS_LCAO<std::complex<double>, TR>
6667

6768
TD_info* td_p = nullptr;
6869

70+
//! Moving spatial gauge for Ehrenfest dynamics, to calculate the correction term arising from the movement of basis
71+
bool use_td_moving_gauge = false;
72+
module_rt::TD_MovingGauge* td_mg_ = nullptr;
73+
6974
//! Restart flag
7075
bool restart_done = false;
7176

source/source_lcao/module_rt/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ if(ENABLE_LCAO)
1616
td_folding.cpp
1717
solve_propagation.cpp
1818
boundary_fix.cpp
19+
td_moving_gauge.cpp
1920
)
2021

2122
if(USE_CUDA)

source/source_lcao/module_rt/evolve_elec.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
namespace module_rt
1111
{
1212
template <typename Device>
13-
Evolve_elec<Device>::Evolve_elec(){};
13+
Evolve_elec<Device>::Evolve_elec() {};
1414
template <typename Device>
15-
Evolve_elec<Device>::~Evolve_elec(){};
15+
Evolve_elec<Device>::~Evolve_elec() {};
1616

1717
template <typename Device>
1818
ct::DeviceType Evolve_elec<Device>::ct_device_type = ct::DeviceTypeToEnum<Device>::value;
@@ -33,7 +33,11 @@ void Evolve_elec<Device>::solve_psi(const int& istep,
3333
std::ofstream& ofs_running,
3434
const int propagator,
3535
const bool use_tensor,
36-
const bool use_lapack)
36+
const bool use_lapack,
37+
module_rt::TD_MovingGauge* td_mg,
38+
const UnitCell* ucell,
39+
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
40+
const bool use_td_moving_gauge)
3741
{
3842
ModuleBase::TITLE("Evolve_elec", "solve_psi");
3943
ModuleBase::timer::start("Evolve_elec", "solve_psi");
@@ -57,6 +61,13 @@ void Evolve_elec<Device>::solve_psi(const int& istep,
5761

5862
if (!use_tensor)
5963
{
64+
// Construct the local P_k matrix for moving spatial gauge, CPU only for now
65+
std::vector<std::complex<double>> P_k_local(para_orb.nloc, {0.0, 0.0});
66+
if (use_td_moving_gauge && td_mg != nullptr)
67+
{
68+
td_mg->get_P_k(ucell, kvec_d[ik], P_k_local.data(), para_orb.nloc, para_orb.ncol);
69+
}
70+
6071
const int len_HS_laststep = use_lapack ? nlocal * nlocal : para_orb.nloc;
6172
evolve_psi(nband,
6273
nlocal,
@@ -66,6 +77,8 @@ void Evolve_elec<Device>::solve_psi(const int& istep,
6677
psi_laststep[0].get_pointer(),
6778
Hk_laststep.data<std::complex<double>>() + ik * len_HS_laststep,
6879
Sk_laststep.data<std::complex<double>>() + ik * len_HS_laststep,
80+
P_k_local.data(),
81+
use_td_moving_gauge,
6982
&(ekb(ik, 0)),
7083
propagator,
7184
ofs_running,

source/source_lcao/module_rt/evolve_elec.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "source_lcao/hamilt_lcao.h"
1414
#include "source_lcao/module_rt/gather_mat.h" // MPI gathering and distributing functions
1515
#include "source_lcao/module_rt/kernels/cublasmp_context.h"
16+
#include "source_lcao/module_rt/td_moving_gauge.h"
1617
#include "source_psi/psi.h"
1718

1819
//-----------------------------------------------------------
@@ -158,7 +159,11 @@ class Evolve_elec
158159
std::ofstream& ofs_running,
159160
const int propagator,
160161
const bool use_tensor,
161-
const bool use_lapack);
162+
const bool use_lapack,
163+
module_rt::TD_MovingGauge* td_mg,
164+
const UnitCell* ucell,
165+
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
166+
const bool use_td_moving_gauge);
162167

163168
// ct_device_type = ct::DeviceType::CpuDevice or ct::DeviceType::GpuDevice
164169
static ct::DeviceType ct_device_type;

source/source_lcao/module_rt/evolve_psi.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ void evolve_psi(const int nband,
2424
std::complex<double>* psi_k_laststep,
2525
std::complex<double>* H_laststep,
2626
std::complex<double>* S_laststep,
27+
std::complex<double>* P_k,
28+
const bool use_td_moving_gauge,
2729
double* ekb,
2830
int propagator,
2931
std::ofstream& ofs_running,
@@ -85,8 +87,15 @@ void evolve_psi(const int nband,
8587
{
8688
/// @brief solve the propagation equation
8789
/// @input Stmp, Htmp, psi_k_laststep
88-
/// @output psi_k
89-
solve_propagation(pv, nband, nlocal, PARAM.inp.td_dt, Stmp, Htmp, psi_k_laststep, psi_k);
90+
/// @output psi_k
91+
if (use_td_moving_gauge)
92+
{
93+
solve_propagation(pv, nband, nlocal, PARAM.inp.td_dt, Stmp, Htmp, P_k, psi_k_laststep, psi_k);
94+
}
95+
else
96+
{
97+
solve_propagation(pv, nband, nlocal, PARAM.inp.td_dt, Stmp, Htmp, psi_k_laststep, psi_k);
98+
}
9099
}
91100

92101
// (4)->>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>

source/source_lcao/module_rt/evolve_psi.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ void evolve_psi(const int nband,
2323
std::complex<double>* psi_k_laststep,
2424
std::complex<double>* H_laststep,
2525
std::complex<double>* S_laststep,
26+
std::complex<double>* P_k,
27+
const bool use_td_moving_gauge,
2628
double* ekb,
2729
int propagator,
2830
std::ofstream& ofs_running,

0 commit comments

Comments
 (0)