forked from deepmodeling/abacus-develop
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path hsolver_pw.h
More file actions
114 lines (90 loc) · 3.77 KB
/
hsolver_pw.h
File metadata and controls
114 lines (90 loc) · 3.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#ifndef HSOLVERPW_H
#define HSOLVERPW_H
#include "source_estate/elecstate.h"
#include "source_hamilt/hamilt.h"
#include "source_base/macros.h"
#include "source_basis/module_pw/pw_basis_k.h"
#include "source_base/memory.h"

#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
namespace hsolver
{

/// @brief Plane-wave (PW) driver for the iterative diagonalization step.
///
/// The solver configuration (basis, calculation/basis type, diagonalization
/// method, iteration limit, threshold, subspace / k-continuity switches) is
/// fixed at construction; solve() is the public entry point. Method bodies
/// other than the constructor are defined outside this header.
///
/// Because several configuration members are `const`, objects of this class
/// can be copy/move-constructed but not copy- or move-assigned.
///
/// @tparam T      element type of the wave functions (real or complex)
/// @tparam Device device tag, defaulting to base_device::DEVICE_CPU
template <typename T, typename Device = base_device::DEVICE_CPU>
class HSolverPW
{
  protected:
    // GetTypeReal<T>::type is T itself when T is a real type (float, double),
    // and the underlying real type when T is complex
    // (std::complex<float> -> float, std::complex<double> -> double).
    using Real = typename GetTypeReal<T>::type;
    using resmem_complex_op = base_device::memory::resize_memory_op<T, Device>;
    using delmem_complex_op = base_device::memory::delete_memory_op<T, Device>;
    using setmem_complex_op = base_device::memory::set_memory_op<T, Device>;

  public:
    /// @brief Store the solver configuration; no other work is done here.
    /// @param wfc_basis_in        plane-wave basis for the k-points. Stored as a raw
    ///                            pointer; this class declares no destructor and
    ///                            does not free it.
    /// @param calculation_type_in calculation type string
    /// @param basis_type_in       basis type string
    /// @param method_in           diagonalization method name (presumably values
    ///                            such as "cg", "dav", "dav_subspace" — confirm)
    /// @param use_uspp_in         ultrasoft-pseudopotential switch (presumably — confirm)
    /// @param nspin_in            spin setting, stored as-is
    /// @param scf_iter_in         current SCF iteration; counting starts from 1
    /// @param diag_iter_max_in    maximum number of iterations for diagonalization
    /// @param diag_thr_in         convergence threshold for diagonalization
    /// @param need_subspace_in    subspace switch used by cg or dav_subspace
    /// @param use_k_continuity_in enable the k-point continuity members
    ///                            (build_k_neighbors / propagate_psi); default false
    ///
    /// The string parameters are sinks: taken by value and moved into the members
    /// so an rvalue argument is never copied twice.
    HSolverPW(ModulePW::PW_Basis_K* wfc_basis_in,
              std::string calculation_type_in,
              std::string basis_type_in,
              std::string method_in,
              const bool use_uspp_in,
              const int nspin_in,
              const int scf_iter_in,
              const int diag_iter_max_in,
              const double diag_thr_in,
              const bool need_subspace_in,
              const bool use_k_continuity_in = false)
        : wfc_basis(wfc_basis_in),
          calculation_type(std::move(calculation_type_in)),
          basis_type(std::move(basis_type_in)),
          method(std::move(method_in)),
          use_uspp(use_uspp_in),
          nspin(nspin_in),
          scf_iter(scf_iter_in),
          diag_iter_max(diag_iter_max_in),
          diag_thr(diag_thr_in),
          need_subspace(need_subspace_in),
          use_k_continuity(use_k_continuity_in)
    {
    }

    /// @brief Solve step for the plane-wave basis.
    /// @param pHamilt          interface to the Hamiltonian
    /// @param psi              wave functions; taken by non-const reference, so they
    ///                         may be updated in place
    /// @param pes              interface to the electronic state
    /// @param out_eigenvalues  destination buffer for eigenvalues (required size is
    ///                         not visible here — TODO confirm against the .cpp)
    /// @param rank_in_pool_in  rank of this process within its pool (presumably
    ///                         stored into rank_in_pool)
    /// @param nproc_in_pool_in number of processes in the pool (presumably stored
    ///                         into nproc_in_pool)
    /// @param skip_charge      presumably skips the charge-density update — confirm
    /// @param tpiba            presumably 2*pi / lattice constant — confirm
    /// @param nat              presumably the number of atoms — confirm
    void solve(hamilt::Hamilt<T, Device>* pHamilt,
               psi::Psi<T, Device>& psi,
               elecstate::ElecState* pes,
               double* out_eigenvalues,
               const int rank_in_pool_in,
               const int nproc_in_pool_in,
               const bool skip_charge,
               const double tpiba,
               const int nat);

  protected:
    /// @brief Diagonalization caller: invoke the configured solver on `psi` using
    ///        `hm` and the preconditioner, writing results to `eigenvalue`.
    /// NOTE(review): exact meaning of `nk_nums` is not visible here — confirm.
    void hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
                         psi::Psi<T, Device>& psi,
                         std::vector<Real>& pre_condition,
                         Real* eigenvalue,
                         const int& nk_nums);

    /// @brief Compute the preconditioner array `h_diag` for diagonalization at
    ///        k-point `ik` with `npw` plane waves.
    /// NOTE(review): `vl_of_0` is presumably the G=0 local-potential term — confirm.
    void update_precondition(std::vector<Real>& h_diag, const int ik, const int npw, const Real vl_of_0);

    /// @brief Report diagonalization iteration information (presumably to the log).
    void output_iterInfo();

    ModulePW::PW_Basis_K* wfc_basis = nullptr; // not owned by this class
    const std::string calculation_type;
    const std::string basis_type;
    const std::string method;
    const bool use_uspp;
    const int nspin;
    const int scf_iter;       // Start from 1
    const int diag_iter_max;  // max iter times for diagonalization
    const double diag_thr;    // threshold for diagonalization
    const bool need_subspace; // for cg or dav_subspace
    const bool use_k_continuity;

  protected:
    Device* ctx = nullptr;
    int rank_in_pool = 0;
    int nproc_in_pool = 1;
    std::vector<double> ethr_band;

  private:
    /// @brief Calculate the iterative-diagonalization threshold for each band,
    ///        writing the per-band results into `ethrs`.
    /// NOTE(review): `wk`/`wg` are presumably k-point and band weights — confirm.
    void cal_smooth_ethr(const double& wk, const double* wg, const double& ethr, std::vector<double>& ethrs);

    // ---- K-point continuity related members ----
    std::vector<int> k_order;
    std::unordered_map<int, int> k_parent;
    std::vector<ModuleBase::Vector3<double>> kvecs_c;

    /// @brief Build the k-point neighbor data (k_order / k_parent) used for continuity.
    void build_k_neighbors();

    /// @brief Propagate wave functions from k-point `from_ik` to `to_ik`.
    void propagate_psi(psi::Psi<T, Device>& psi, const int from_ik, const int to_ik);
};

} // namespace hsolver
#endif