|
| 1 | +""" |
| 2 | +`DofManager`: |
| 3 | +
|
| 4 | +Constructor: |
| 5 | +```julia |
| 6 | +DofManager(sys; variablecell = false, r0 =..., free=..., clamp=..., mask=...) |
| 7 | +``` |
| 8 | +* `variablecell` determines whether the cell is fixed or allowed to change |
| 9 | + during optimization |
| 10 | +* `r0` is a reference length-scale, default set to one (in the unit of sys), |
| 11 | + this is used to non-dimensionalize the degrees of freedom. |
| 12 | +
|
| 13 | +In addition set at most one of the kwargs: |
| 14 | +* no kwarg: all atoms are free |
| 15 | +* `free` : list of free atom indices (not dof indices) |
| 16 | +* `clamp` : list of clamped atom indices (not dof indices) |
| 17 | +* `mask` : 3 x N Bool array to specify individual coordinates to be clamped |
| 18 | +
|
| 19 | +### Meaning of dofs |
| 20 | +
|
| 21 | +On call to the constructor, `DofManager` stores positions and cell |
| 22 | +`X0, C0`, dofs are understood *relative* to this initial configuration. |
| 23 | +`get_dofs(sys, dm::DofManager)` returns a vector that represents the |
| 24 | +non-dimensional displacement and a deformation matrix `(U, F)`. The new configuration extracted from a dof vector |
| 25 | +is understood as |
| 26 | +* The new cell: `C = F * C0` |
| 27 | +* The new positions: `𝐫[i] = F * (X0[i] + U[i] * r0)` |
| 28 | +One aspect of this definition is that clamped atom positions still change via |
| 29 | +the deformation `F`. This is natural in the context of optimizing the |
| 30 | +cell shape. |
| 31 | +""" |
| 32 | +mutable struct DofManager{D,T} |
| 33 | + variablecell::Bool |
| 34 | + ifree::Vector{Int} # extract the free position dofs |
| 35 | + r0::T |
| 36 | + X0::Vector{SVector{D,T}} # reference positions |
| 37 | + C0::NTuple{D, SVector{D, T}} # reference cell |
| 38 | +end |
| 39 | + |
| 40 | +# NOTES: |
| 41 | +# - length units are implicitly given by the units in X0, C0, r0. |
| 42 | +# there should be no explicit length-unit stripping but this should be |
| 43 | +# implicitly through the reference length-scale r0 |
| 44 | +# - at the moment energy-nondimensionalization is achieved simply by |
| 45 | +# stripping. A better approach would be to enforce this to happen in |
| 46 | +# the preconditioner, which could simply be a rescaling operation. |
| 47 | + |
| 48 | +# ======================================================================== |
| 49 | +# Constructors |
| 50 | + |
| 51 | +function DofManager(sys::AbstractSystem{D}; |
| 52 | + variablecell=false, |
| 53 | + r0=_auto_r0(sys), |
| 54 | + free=nothing, |
| 55 | + clamp=nothing, |
| 56 | + mask=nothing) where {D} |
| 57 | + if D != 3 |
| 58 | + error("this package assumes D = 3; please file an issue if you neeed a different use case.") |
| 59 | + end |
| 60 | + X0 = copy(position(sys, :)) |
| 61 | + C0 = cell_vectors(sys) |
| 62 | + ifree = analyze_mask(sys, free, clamp, mask) |
| 63 | + DofManager(variablecell, ifree, r0, X0, C0) |
| 64 | +end |
| 65 | +function _auto_r0(sys) |
| 66 | + r = position(sys, 1)[1] |
| 67 | + one(ustrip(r)) * unit(r) |
| 68 | +end |
| 69 | + |
| 70 | + |
| 71 | +""" |
| 72 | +`analyze_mask` : helper function to generate list of dof indices from |
| 73 | +lists of atom indices indicating free and clamped atoms |
| 74 | +""" |
| 75 | +function analyze_mask(sys, free, clamp, mask) |
| 76 | + if sum(!isnothing, (free, clamp, mask)) > 1 |
| 77 | + error("DofManager: At most one of `free`, `clamp`, `mask` may be provided") |
| 78 | + end |
| 79 | + if all(isnothing, (free, clamp, mask)) |
| 80 | + # in this case (default) all atoms are free |
| 81 | + return collect(1:3length(sys)) |
| 82 | + end |
| 83 | + |
| 84 | + # determine free dof indices |
| 85 | + n_atom = length(sys) |
| 86 | + if !isnothing(clamp) # revert to setting free |
| 87 | + free = setdiff(1:n_atom, clamp) |
| 88 | + end |
| 89 | + if !isnothing(free) # revert to setting mask |
| 90 | + mask = fill(false, 3, n_atom) |
| 91 | + if !isempty(free) |
| 92 | + mask[:, free] .= true |
| 93 | + end |
| 94 | + end |
| 95 | + return findall(mask[:]) |
| 96 | +end |
| 97 | + |
| 98 | +# ======================================================================== |
| 99 | +# DOF Conversions |
| 100 | + |
| 101 | +length_unit(dm::DofManager) = unit(dm.r0) |
| 102 | +length_unit(sys::AbstractSystem) = unit(position(sys, 1)[1]) |
| 103 | +function check_length_units(sys, dm::DofManager) |
| 104 | + if length_unit(dm) != length_unit(sys) |
| 105 | + error("System `sys` and DofManager have inconsistent units.") |
| 106 | + end |
| 107 | + if length(sys) != length(dm.X0) |
| 108 | + error("System `sys` and DofManager have inconsistent size.") |
| 109 | + end |
| 110 | +end |
| 111 | + |
| 112 | +variablecell(dofmgr::DofManager) = dofmgr.variablecell |
| 113 | +fixedcell(dofmgr::DofManager) = !variablecell(dofmgr) |
| 114 | + |
| 115 | +# there is a type-instability here!! |
| 116 | +_posdofs(x, dofmgr::DofManager) = dofmgr.variablecell ? (@view x[1:end-9]) : x |
| 117 | + |
| 118 | +function _pos2dofs(U::AbstractVector{SVector{3, T}}, dofmgr) where {T} |
| 119 | + @view(reinterpret(T, U)[dofmgr.ifree]) |
| 120 | +end |
| 121 | + |
| 122 | +function _dofs2pos(x::AbstractVector{T}, dofmgr) where {T} |
| 123 | + u = zeros(T, 3 * length(dofmgr.X0)) |
| 124 | + u[dofmgr.ifree] .= _posdofs(x, dofmgr) |
| 125 | + return reinterpret(SVector{3, T}, u) |
| 126 | +end |
| 127 | +_defm2dofs(F, dofmgr) = Matrix(F)[:] |
| 128 | + |
| 129 | +function _dofs2defm(x::AbstractVector{T}, dofmgr) where {T} |
| 130 | + if dofmgr.variablecell |
| 131 | + SMatrix{3, 3, T}(x[end-8:end]) |
| 132 | + else |
| 133 | + SMatrix{3, 3, T}([1 0 0; 0 1 0; 0 0 1]) |
| 134 | + end |
| 135 | +end |
| 136 | + |
| 137 | +function get_dofs(sys::AbstractSystem, dofmgr::DofManager) |
| 138 | + check_length_units(sys, dofmgr) |
| 139 | + |
| 140 | + # obtain the positions and their underlying floating point type |
| 141 | + X = position(sys, :) |
| 142 | + if fixedcell(dofmgr) |
| 143 | + # there are allocations here that could maybe be avoided |
| 144 | + return collect(_pos2dofs((X - dofmgr.X0)/dofmgr.r0, dofmgr)) |
| 145 | + else |
| 146 | + # variable cell case: note we already checked units and can strip |
| 147 | + # (otherwise we would have problems inverting) |
| 148 | + bb = cell_vectors(sys) |
| 149 | + F = ustrip.(hcat(bb...)) / ustrip.(hcat(dofmgr.C0...)) |
| 150 | + # Xi = F * (X0i + Ui * r0) => Ui = (F \ Xi - X0i) / r0 |
| 151 | + U = [ (F \ X[i] - dofmgr.X0[i]) / dofmgr.r0 for i = 1:length(X) ] |
| 152 | + return [ _pos2dofs(U, dofmgr); |
| 153 | + _defm2dofs(F, dofmgr) ] |
| 154 | + end |
| 155 | +end |
| 156 | + |
| 157 | + |
| 158 | +function set_dofs(system::AbstractSystem, dofmgr::DofManager, |
| 159 | + x::AbstractVector{T} ) where {T <: AbstractFloat} |
| 160 | + check_length_units(system, dofmgr) |
| 161 | + |
| 162 | + # get the displacement from the dof vector |
| 163 | + U = _dofs2pos(x, dofmgr) |
| 164 | + F = _dofs2defm(x, dofmgr) |
| 165 | + |
| 166 | + # convert the displacements to positions |
| 167 | + positions = [ F * (dofmgr.X0[i] + U[i] * dofmgr.r0) for i = 1:length(U) ] |
| 168 | + bb_old = dofmgr.C0 |
| 169 | + bb_new = ntuple(i -> F * bb_old[i], 3) |
| 170 | + |
| 171 | + # and update the system |
| 172 | + particles = [Atom(atom; position) for (atom, position) in zip(system, positions)] |
| 173 | + AbstractSystem(system; particles, cell_vectors=bb_new) |
| 174 | +end |
| 175 | + |
| 176 | + |
| 177 | + |
| 178 | +# ======================================================================== |
| 179 | +# Compute the gradient with respect to dofs |
| 180 | +# from forces and virials |
| 181 | + |
| 182 | +function energy_dofs(system, calculator, dofmgr, x::AbstractVector, ps, state) |
| 183 | + res = calculate(Energy(), set_dofs(system, dofmgr, x), calculator, ps, state) |
| 184 | + (; energy_unitless=ustrip(res.energy), res...) |
| 185 | +end |
| 186 | + |
| 187 | +function gradient_dofs(system, calculator, dofmgr, x::AbstractVector{T}, ps, state) where {T} |
| 188 | + # Compute and transform forces and virial into a gradient w.r.t. x |
| 189 | + if fixedcell(dofmgr) |
| 190 | + # fixed cell version |
| 191 | + # fi = - ∇_𝐫i E [eV/A] |
| 192 | + # 𝐫i = X0[i] + r0 * U[i] |
| 193 | + # g_iα = - fiα * r0 [eV] => same unit as E so can strip |
| 194 | + |
| 195 | + res = calculate(Forces(), set_dofs(system, dofmgr, x), calculator, ps, state) |
| 196 | + g_pos = [ ustrip( - dofmgr.r0 * f ) for f in res.forces ] |
| 197 | + grad = collect(_pos2dofs(g_pos, dofmgr))::Vector{T} |
| 198 | + else |
| 199 | + # variable cell version |
| 200 | + # fi = - ∇_𝐫i E [eV/A] 𝐫i = F * (X0[i] + r0 * U[i]) |
| 201 | + # ∇_𝐮i' = - fi' * ∂𝐫i/∂𝐮i = - fi' * (r0 * F) => ∇_𝐮i = - F' * r0 * fi |
| 202 | + # ∂F E |_{F = I} = - virial => ∂F E = - virial / F' |
| 203 | + |
| 204 | + res = calculate((Forces(), Virial()), set_dofs(system, dofmgr, x), |
| 205 | + calculator, ps, state) |
| 206 | + |
| 207 | + F = _dofs2defm(x, dofmgr) |
| 208 | + g_pos = [ - ustrip(dofmgr.r0 * F' * f) for f in res.forces ] |
| 209 | + |
| 210 | + grad = [ _pos2dofs(g_pos, dofmgr); |
| 211 | + ( - ustrip.(res.virial) / F' )[:] ]::Vector{T} |
| 212 | + end |
| 213 | + |
| 214 | + (; grad, res...) |
| 215 | +end |
0 commit comments