Skip to content

Commit bd707dc

Browse files
committed
Integrate DofManager
1 parent cd16ffa commit bd707dc

5 files changed

+275
-89
lines changed

src/GeometryOptimization.jl

+6-4
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
11
module GeometryOptimization
22

3+
using AtomsBase
4+
using AtomsCalculators
35
using DocStringExtensions
46
using LinearAlgebra
5-
using StaticArrays
67
using Optimization
7-
using AtomsBase
8-
using AtomsCalculators
8+
using StaticArrays
99
using Unitful
1010
using UnitfulAtomic
1111

1212
# Make sure Optim is always available
1313
using OptimizationOptimJL
1414
using LineSearches
1515

16+
# Useful shortcuts
17+
using AtomsCalculators: Energy, Forces, Virial, calculate
1618
AC = AtomsCalculators
1719

1820
@template METHODS =
@@ -22,7 +24,7 @@ $(TYPEDSIGNATURES)
2224
$(DOCSTRING)
2325
"""
2426

25-
include("clamping_updating_positions.jl")
27+
include("dof_management.jl")
2628
include("optimization.jl")
2729
include("callbacks.jl")
2830

src/clamping_updating_positions.jl

-55
This file was deleted.

src/dof_management.jl

+215
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
"""
2+
`DofManager`:
3+
4+
Constructor:
5+
```julia
6+
DofManager(sys; variablecell = false, r0 =..., free=..., clamp=..., mask=...)
7+
```
8+
* `variablecell` determines whether the cell is fixed or allowed to change
9+
during optimization
10+
* `r0` is a reference length-scale, default set to one (in the unit of sys),
11+
this is used to non-dimensionalize the degrees of freedom.
12+
13+
In addition set at most one of the kwargs:
14+
* no kwarg: all atoms are free
15+
* `free` : list of free atom indices (not dof indices)
16+
* `clamp` : list of clamped atom indices (not dof indices)
17+
* `mask` : 3 x N Bool array to specify individual coordinates to be clamped
18+
19+
### Meaning of dofs
20+
21+
On call to the constructor, `DofManager` stores positions and cell
22+
`X0, C0`, dofs are understood *relative* to this initial configuration.
23+
`get_dofs(sys, dm::DofManager)` returns a vector that represents the
24+
non-dimensional displacement and a deformation matrix `(U, F)`. The new configuration extracted from a dof vector
25+
is understood as
26+
* The new cell: `C = F * C0`
27+
* The new positions: `𝐫[i] = F * (X0[i] + U[i] * r0)`
28+
One aspect of this definition is that clamped atom positions still change via
29+
the deformation `F`. This is natural in the context of optimizing the
30+
cell shape.
31+
"""
32+
mutable struct DofManager{D,T}
33+
variablecell::Bool
34+
ifree::Vector{Int} # extract the free position dofs
35+
r0::T
36+
X0::Vector{SVector{D,T}} # reference positions
37+
C0::NTuple{D, SVector{D, T}} # reference cell
38+
end
39+
40+
# NOTES:
41+
# - length units are implicitly given by the units in X0, C0, r0.
42+
# there should be no explicit length-unit stripping but this should be
43+
# implicitly through the reference length-scale r0
44+
# - at the moment energy-nondimensionalization is achieved simply by
45+
# stripping. A better approach would be to enforce this to happen in
46+
# the preconditioner, which could simply be a rescaling operation.
47+
48+
# ========================================================================
49+
# Constructors
50+
51+
function DofManager(sys::AbstractSystem{D};
52+
variablecell=false,
53+
r0=_auto_r0(sys),
54+
free=nothing,
55+
clamp=nothing,
56+
mask=nothing) where {D}
57+
if D != 3
58+
error("this package assumes D = 3; please file an issue if you neeed a different use case.")
59+
end
60+
X0 = copy(position(sys, :))
61+
C0 = cell_vectors(sys)
62+
ifree = analyze_mask(sys, free, clamp, mask)
63+
DofManager(variablecell, ifree, r0, X0, C0)
64+
end
65+
function _auto_r0(sys)
66+
r = position(sys, 1)[1]
67+
one(ustrip(r)) * unit(r)
68+
end
69+
70+
71+
"""
72+
`analyze_mask` : helper function to generate list of dof indices from
73+
lists of atom indices indicating free and clamped atoms
74+
"""
75+
function analyze_mask(sys, free, clamp, mask)
76+
if sum(!isnothing, (free, clamp, mask)) > 1
77+
error("DofManager: At most one of `free`, `clamp`, `mask` may be provided")
78+
end
79+
if all(isnothing, (free, clamp, mask))
80+
# in this case (default) all atoms are free
81+
return collect(1:3length(sys))
82+
end
83+
84+
# determine free dof indices
85+
n_atom = length(sys)
86+
if !isnothing(clamp) # revert to setting free
87+
free = setdiff(1:n_atom, clamp)
88+
end
89+
if !isnothing(free) # revert to setting mask
90+
mask = fill(false, 3, n_atom)
91+
if !isempty(free)
92+
mask[:, free] .= true
93+
end
94+
end
95+
return findall(mask[:])
96+
end
97+
98+
# ========================================================================
99+
# DOF Conversions
100+
101+
length_unit(dm::DofManager) = unit(dm.r0)
102+
length_unit(sys::AbstractSystem) = unit(position(sys, 1)[1])
103+
function check_length_units(sys, dm::DofManager)
104+
if length_unit(dm) != length_unit(sys)
105+
error("System `sys` and DofManager have inconsistent units.")
106+
end
107+
if length(sys) != length(dm.X0)
108+
error("System `sys` and DofManager have inconsistent size.")
109+
end
110+
end
111+
112+
variablecell(dofmgr::DofManager) = dofmgr.variablecell
113+
fixedcell(dofmgr::DofManager) = !variablecell(dofmgr)
114+
115+
# there is a type-instability here!!
116+
_posdofs(x, dofmgr::DofManager) = dofmgr.variablecell ? (@view x[1:end-9]) : x
117+
118+
function _pos2dofs(U::AbstractVector{SVector{3, T}}, dofmgr) where {T}
119+
@view(reinterpret(T, U)[dofmgr.ifree])
120+
end
121+
122+
function _dofs2pos(x::AbstractVector{T}, dofmgr) where {T}
123+
u = zeros(T, 3 * length(dofmgr.X0))
124+
u[dofmgr.ifree] .= _posdofs(x, dofmgr)
125+
return reinterpret(SVector{3, T}, u)
126+
end
127+
_defm2dofs(F, dofmgr) = Matrix(F)[:]
128+
129+
function _dofs2defm(x::AbstractVector{T}, dofmgr) where {T}
130+
if dofmgr.variablecell
131+
SMatrix{3, 3, T}(x[end-8:end])
132+
else
133+
SMatrix{3, 3, T}([1 0 0; 0 1 0; 0 0 1])
134+
end
135+
end
136+
137+
function get_dofs(sys::AbstractSystem, dofmgr::DofManager)
138+
check_length_units(sys, dofmgr)
139+
140+
# obtain the positions and their underlying floating point type
141+
X = position(sys, :)
142+
if fixedcell(dofmgr)
143+
# there are allocations here that could maybe be avoided
144+
return collect(_pos2dofs((X - dofmgr.X0)/dofmgr.r0, dofmgr))
145+
else
146+
# variable cell case: note we already checked units and can strip
147+
# (otherwise we would have problems inverting)
148+
bb = cell_vectors(sys)
149+
F = ustrip.(hcat(bb...)) / ustrip.(hcat(dofmgr.C0...))
150+
# Xi = F * (X0i + Ui * r0) => Ui = (F \ Xi - X0i) / r0
151+
U = [ (F \ X[i] - dofmgr.X0[i]) / dofmgr.r0 for i = 1:length(X) ]
152+
return [ _pos2dofs(U, dofmgr);
153+
_defm2dofs(F, dofmgr) ]
154+
end
155+
end
156+
157+
158+
function set_dofs(system::AbstractSystem, dofmgr::DofManager,
159+
x::AbstractVector{T} ) where {T <: AbstractFloat}
160+
check_length_units(system, dofmgr)
161+
162+
# get the displacement from the dof vector
163+
U = _dofs2pos(x, dofmgr)
164+
F = _dofs2defm(x, dofmgr)
165+
166+
# convert the displacements to positions
167+
positions = [ F * (dofmgr.X0[i] + U[i] * dofmgr.r0) for i = 1:length(U) ]
168+
bb_old = dofmgr.C0
169+
bb_new = ntuple(i -> F * bb_old[i], 3)
170+
171+
# and update the system
172+
particles = [Atom(atom; position) for (atom, position) in zip(system, positions)]
173+
AbstractSystem(system; particles, cell_vectors=bb_new)
174+
end
175+
176+
177+
178+
# ========================================================================
179+
# Compute the gradient with respect to dofs
180+
# from forces and virials
181+
182+
function energy_dofs(system, calculator, dofmgr, x::AbstractVector, ps, state)
183+
res = calculate(Energy(), set_dofs(system, dofmgr, x), calculator, ps, state)
184+
(; energy_unitless=ustrip(res.energy), res...)
185+
end
186+
187+
function gradient_dofs(system, calculator, dofmgr, x::AbstractVector{T}, ps, state) where {T}
188+
# Compute and transform forces and virial into a gradient w.r.t. x
189+
if fixedcell(dofmgr)
190+
# fixed cell version
191+
# fi = - ∇_𝐫i E [eV/A]
192+
# 𝐫i = X0[i] + r0 * U[i]
193+
# g_iα = - fiα * r0 [eV] => same unit as E so can strip
194+
195+
res = calculate(Forces(), set_dofs(system, dofmgr, x), calculator, ps, state)
196+
g_pos = [ ustrip( - dofmgr.r0 * f ) for f in res.forces ]
197+
grad = collect(_pos2dofs(g_pos, dofmgr))::Vector{T}
198+
else
199+
# variable cell version
200+
# fi = - ∇_𝐫i E [eV/A] 𝐫i = F * (X0[i] + r0 * U[i])
201+
# ∇_𝐮i' = - fi' * ∂𝐫i/∂𝐮i = - fi' * (r0 * F) => ∇_𝐮i = - F' * r0 * fi
202+
# ∂F E |_{F = I} = - virial => ∂F E = - virial / F'
203+
204+
res = calculate((Forces(), Virial()), set_dofs(system, dofmgr, x),
205+
calculator, ps, state)
206+
207+
F = _dofs2defm(x, dofmgr)
208+
g_pos = [ - ustrip(dofmgr.r0 * F' * f) for f in res.forces ]
209+
210+
grad = [ _pos2dofs(g_pos, dofmgr);
211+
( - ustrip.(res.virial) / F' )[:] ]::Vector{T}
212+
end
213+
214+
(; grad, res...)
215+
end

0 commit comments

Comments
 (0)