Skip to content

Commit 5f16205

Browse files
authored
Julia: Update memory management (#123)
1 parent 871e587 commit 5f16205

File tree

7 files changed

+143
-90
lines changed

7 files changed

+143
-90
lines changed

LICENSE.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2023 sphericart contributors
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.

julia/Project.toml

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,23 @@
11
name = "SpheriCart"
22
uuid = "5caf2b29-02d9-47a3-9434-5931c85ba645"
33
authors = ["Christoph Ortner <[email protected]> and contributors"]
4-
version = "0.0.3"
4+
version = "0.1.1"
55

66
[deps]
7+
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
78
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
89
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
9-
ObjectPools = "658cac36-ff0f-48ad-967c-110375d98c9d"
1010
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
1111
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
12+
StrideArrays = "d1fa6d79-ef01-42a6-86c9-f7c551f8593b"
1213

13-
[compat]
14-
ForwardDiff = "0.10"
15-
ObjectPools = "0.3.0"
16-
OffsetArrays = "1.13"
17-
StaticArrays = "1.9"
14+
[compat]
15+
Bumper = "0.6.0"
16+
StrideArrays = "0.1.28"
17+
ForwardDiff = "0.10"
1818
LinearAlgebra = "1.8"
19+
OffsetArrays = "1.13"
20+
StaticArrays = "1.9"
1921
julia = "1.8.0, 1.9.0, 1.10.0"
2022

2123
[extras]

julia/src/SpheriCart.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module SpheriCart
22

3-
using StaticArrays, OffsetArrays, ObjectPools
3+
using StaticArrays, OffsetArrays, Bumper, StrideArrays
44

55
export SolidHarmonics,
66
compute,

julia/src/api.jl

Lines changed: 35 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
12
"""
23
`struct SolidHarmonics` : datatype representing a solid harmonics basis.
34
@@ -36,7 +37,6 @@ See documentation for more details.
3637
"""
3738
struct SolidHarmonics{L, NORM, STATIC, T1}
3839
Flm::OffsetMatrix{T1, Matrix{T1}}
39-
cache::TSafe{ArrayPool{FlexArrayCache}}
4040
end
4141

4242
function SolidHarmonics(L::Integer;
@@ -45,7 +45,7 @@ function SolidHarmonics(L::Integer;
4545
T = Float64)
4646
Flm = generate_Flms(L; normalisation = normalisation, T = T)
4747
@assert eltype(Flm) == T
48-
SolidHarmonics{L, normalisation, static, T}(Flm, TSafe(ArrayPool(FlexArrayCache)))
48+
SolidHarmonics{L, normalisation, static, T}(Flm)
4949
end
5050

5151
@inline (basis::SolidHarmonics)(args...) = compute(basis, args...)
@@ -82,28 +82,23 @@ function compute!(Z::AbstractMatrix,
8282
nX = length(Rs)
8383
T = promote_type(T1, T2)
8484

85-
# allocate temporary arrays from an array cache
86-
temps = (x = acquire!(basis.cache, :x, (nX, ), T),
87-
y = acquire!(basis.cache, :y, (nX, ), T),
88-
z = acquire!(basis.cache, :z, (nX, ), T),
89-
r² = acquire!(basis.cache, :r2, (nX, ), T),
90-
s = acquire!(basis.cache, :s, (nX, L+1), T),
91-
c = acquire!(basis.cache, :c, (nX, L+1), T),
92-
Q = acquire!(basis.cache, :Q, (nX, sizeY(L)), T),
93-
Flm = basis.Flm )
94-
95-
# the actual evaluation kernel
96-
solid_harmonics!(Z, Val{L}(), Rs, temps)
97-
98-
# release the temporary arrays back into the cache
99-
# (don't release Flm!!)
100-
release!(temps.x)
101-
release!(temps.y)
102-
release!(temps.z)
103-
release!(temps.r²)
104-
release!(temps.s)
105-
release!(temps.c)
106-
release!(temps.Q)
85+
@no_escape begin
86+
87+
# allocate temporary arrays from an array cache
88+
temps = (x = @alloc(T, nX),
89+
y = @alloc(T, nX),
90+
z = @alloc(T, nX),
91+
r² = @alloc(T, nX),
92+
s = @alloc(T, nX, L+1),
93+
c = @alloc(T, nX, L+1),
94+
Q = @alloc(T, nX, sizeY(L)),
95+
Flm = basis.Flm )
96+
97+
# the actual evaluation kernel
98+
solid_harmonics!(Z, Val{L}(), Rs, temps)
99+
100+
nothing
101+
end # @no_escape
107102

108103
return Z
109104
end
@@ -152,28 +147,22 @@ function compute_with_gradients!(
152147
nX = length(Rs)
153148
T = promote_type(T1, T2)
154149

155-
# allocate temporary arrays from an array cache
156-
temps = (x = acquire!(basis.cache, :x, (nX, ), T),
157-
y = acquire!(basis.cache, :y, (nX, ), T),
158-
z = acquire!(basis.cache, :z, (nX, ), T),
159-
r² = acquire!(basis.cache, :r2, (nX, ), T),
160-
s = acquire!(basis.cache, :s, (nX, L+1), T),
161-
c = acquire!(basis.cache, :c, (nX, L+1), T),
162-
Q = acquire!(basis.cache, :Q, (nX, sizeY(L)), T),
163-
Flm = basis.Flm )
164-
165-
# the actual evaluation kernel
166-
solid_harmonics_with_grad!(Z, dZ, Val{L}(), Rs, temps)
167-
168-
# release the temporary arrays back into the cache
169-
# (don't release Flm!!)
170-
release!(temps.x)
171-
release!(temps.y)
172-
release!(temps.z)
173-
release!(temps.r²)
174-
release!(temps.s)
175-
release!(temps.c)
176-
release!(temps.Q)
150+
@no_escape begin
151+
152+
# allocate temporary arrays from an array cache
153+
temps = (x = @alloc(T, nX),
154+
y = @alloc(T, nX),
155+
z = @alloc(T, nX),
156+
r² = @alloc(T, nX),
157+
s = @alloc(T, nX, L+1),
158+
c = @alloc(T, nX, L+1),
159+
Q = @alloc(T, nX, sizeY(L)),
160+
Flm = basis.Flm )
161+
162+
# the actual evaluation kernel
163+
solid_harmonics_with_grad!(Z, dZ, Val{L}(), Rs, temps)
164+
nothing
165+
end
177166

178167
return Z
179168
end

julia/src/normalisations.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33

44
# Generates the `F[l, m]` values exactly as described in the
5-
# `sphericart` publication.
5+
# `sphericart` publication. This gives L2-orthonormality.
66
function _generate_Flms(L::Integer, ::Union{Val{:sphericart}, Val{:L2}}, T=Float64)
77
Flm = OffsetMatrix(zeros(L+1, L+1), (-1, -1))
88
for l = 0:L
@@ -14,6 +14,15 @@ function _generate_Flms(L::Integer, ::Union{Val{:sphericart}, Val{:L2}}, T=Float
1414
return Flm
1515
end
1616

17+
function _generate_Flms(L::Integer, ::Val{:racah}, T=Float64)
18+
Flm = _generate_Flms(L, Val(:L2), T)
19+
for l = 0:L
20+
for m = 0:l
21+
Flm[l, m] = Flm[l, m] * sqrt(4*pi/(2*l+1))
22+
end
23+
end
24+
return Flm
25+
end
1726

1827

1928
"""
@@ -23,8 +32,8 @@ generate_Flms(L; normalisation = :L2, T = Float64)
2332
generate the `F[l,m]` prefactors in the definitions of the solid harmonics;
2433
see `sphericart` publication for details. The default normalisation generates
2534
a basis that is L2-orthonormal on the unit sphere. Other normalisations:
26-
- `:sphericart` the same as `:L2`
27-
- `:p4ml` same normalisation as used in `Polynomials4ML.jl`
35+
- `:sphericart` the same as `:L2`, gives L2-orthonormality, i.e. ∫ |Ylm|² = 1
36+
- `:racah` gives Racah normalization for which ∫ |Ylm|² = 4π/(2l+1).
2837
"""
2938
generate_Flms(L::Integer; normalisation = :L2, T = Float64) =
3039
_generate_Flms(L, Val(normalisation), T)

julia/src/spherical.jl

Lines changed: 46 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,10 @@ See documentation for more details.
4242
struct SphericalHarmonics{L, NORM, STATIC, T1}
4343
solids::SolidHarmonics{L, NORM, STATIC, T1}
4444
# Flm::OffsetMatrix{T1, Matrix{T1}}
45-
cache::TSafe{ArrayPool{FlexArrayCache}}
4645
end
4746

4847
SphericalHarmonics(L::Integer; kwargs...) =
49-
SphericalHarmonics(SolidHarmonics(L; kwargs...),
50-
TSafe(ArrayPool(FlexArrayCache)))
48+
SphericalHarmonics(SolidHarmonics(L; kwargs...))
5149

5250

5351
@inline (basis::SphericalHarmonics)(args...) = compute(basis, args...)
@@ -75,16 +73,16 @@ end
7573
# ---------------------
7674
# batched api
7775

78-
function _normalise_Rs!(basis::SphericalHarmonics,
76+
function _normalise_Rs!(rs, Rs_norm,
77+
basis::SphericalHarmonics,
7978
Rs::AbstractVector{SVector{3, T1}}) where {T1}
8079
nX = length(Rs)
81-
rs = acquire!(basis.cache, :rs, (nX, ), T1)
82-
Rs_norm = acquire!(basis.cache, :Rs_norm, (nX, ), SVector{3, T1})
80+
@assert length(rs) == length(Rs_norm) == nX
8381
@inbounds @simd ivdep for i = 1:nX
8482
rs[i] = norm(Rs[i])
8583
Rs_norm[i] = Rs[i] / rs[i]
8684
end
87-
return rs, Rs_norm
85+
return nothing
8886
end
8987

9088
function _rescale_∇Z2∇Y!(∇Z::AbstractMatrix, Rs_norm, rs)
@@ -99,44 +97,59 @@ function _rescale_∇Z2∇Y!(∇Z::AbstractMatrix, Rs_norm, rs)
9997
end
10098

10199
function compute(basis::SphericalHarmonics,
102-
Rs::AbstractVector{<: SVector{3}})
103-
rs, Rs_norm = _normalise_Rs!(basis, Rs)
104-
Y = compute(basis.solids, Rs_norm)
105-
release!(Rs_norm)
106-
release!(rs)
100+
Rs::AbstractVector{<: SVector{3, T1}}
101+
) where {T1}
102+
@no_escape begin
103+
nX = length(Rs)
104+
rs = @alloc(T1, nX)
105+
Rs_norm = @alloc(SVector{3, T1}, nX)
106+
_normalise_Rs!(rs, Rs_norm, basis, Rs)
107+
Y = compute(basis.solids, Rs_norm)
108+
end
107109
return Y
108110
end
109111

110112
function compute!(Y, basis::SphericalHarmonics,
111-
Rs::AbstractVector{<: SVector{3}})
112-
rs, Rs_norm = _normalise_Rs!(basis, Rs)
113-
compute!(Y, basis.solids, Rs_norm)
114-
release!(Rs_norm)
115-
release!(rs)
113+
Rs::AbstractVector{<: SVector{3, T1}}
114+
) where {T1}
115+
@no_escape begin
116+
nX = length(Rs)
117+
rs = @alloc(T1, nX)
118+
Rs_norm = @alloc(SVector{3, T1}, nX)
119+
_normalise_Rs!(rs, Rs_norm, basis, Rs)
120+
compute!(Y, basis.solids, Rs_norm)
121+
nothing
122+
end
116123
return Y
117124
end
118125

119126

120127
function compute_with_gradients(basis::SphericalHarmonics,
121-
Rs::AbstractVector{<: SVector{3}})
122-
rs, Rs_norm = _normalise_Rs!(basis, Rs)
123-
Y, ∇Z = compute_with_gradients(basis.solids, Rs_norm)
124-
_rescale_∇Z2∇Y!(∇Z, Rs_norm, rs)
125-
126-
release!(Rs_norm)
127-
release!(rs)
128-
128+
Rs::AbstractVector{<: SVector{3, T1}}
129+
) where {T1}
130+
@no_escape begin
131+
nX = length(Rs)
132+
rs = @alloc(T1, nX)
133+
Rs_norm = @alloc(SVector{3, T1}, nX)
134+
_normalise_Rs!(rs, Rs_norm, basis, Rs)
135+
Y, ∇Z = compute_with_gradients(basis.solids, Rs_norm)
136+
_rescale_∇Z2∇Y!(∇Z, Rs_norm, rs)
137+
nothing
138+
end
129139
return Y, ∇Z
130140
end
131141

132142
function compute_with_gradients!(Y, ∇Y, basis::SphericalHarmonics,
133-
Rs::AbstractVector{<: SVector{3}})
134-
rs, Rs_norm = _normalise_Rs!(basis, Rs)
135-
compute_with_gradients!(Y, ∇Y, basis.solids, Rs_norm)
136-
_rescale_∇Z2∇Y!(∇Y, Rs_norm, rs)
137-
138-
release!(Rs_norm)
139-
release!(rs)
140-
143+
Rs::AbstractVector{<: SVector{3, T1}}
144+
) where {T1}
145+
@no_escape begin
146+
nX = length(Rs)
147+
rs = @alloc(T1, nX)
148+
Rs_norm = @alloc(SVector{3, T1}, nX)
149+
_normalise_Rs!(rs, Rs_norm, basis, Rs)
150+
compute_with_gradients!(Y, ∇Y, basis.solids, Rs_norm)
151+
_rescale_∇Z2∇Y!(∇Y, Rs_norm, rs)
152+
nothing
153+
end
141154
return Y, ∇Y
142155
end

julia/test/test_solidharmonics.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,25 @@ for ntest = 1:10
9999
@test cond(G) < 1.5
100100
end
101101

102+
##
103+
104+
@info("check Racah scaling")
105+
106+
L = 3
107+
racah = SolidHarmonics(L; normalisation=:racah)
108+
109+
for ntest = 1:10
110+
local Z
111+
Rs = [ rand_sphere() for _ = 1:10_000 ]
112+
Z = compute(racah, Rs)
113+
G = (Z' * Z) / length(Rs) * 4 * π
114+
dd = [ 4*π/(2*l+1) for l = 0:L for m = -l:l ]
115+
D = Diagonal(dd)
116+
Drtinv = Diagonal(dd.^(-0.5))
117+
@test norm(G - D) < 1.0
118+
@test cond( Drtinv * G * Drtinv ) < 1.5
119+
end
120+
102121

103122
##
104123

0 commit comments

Comments
 (0)