Skip to content

Commit aaf35c8

Browse files
committed
Tidy utils.jl
1 parent 6908101 commit aaf35c8

File tree

2 files changed

+119
-54
lines changed

2 files changed

+119
-54
lines changed

src/Nonlinear/ReverseAD/forward_over_reverse.jl

+13-46
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,7 @@ function _eval_hessian_inner(
105105
@inbounds input_ϵ[idx] = zero(T)
106106
end
107107
end
108-
# TODO(odow): consider reverting to a view.
109-
output_slice = _VectorView(nzcount, length(ex.hess_I), pointer(H))
108+
output_slice = _UnsafeVectorView(nzcount, length(ex.hess_I), pointer(H))
110109
Coloring.recover_from_matmat!(
111110
output_slice,
112111
ex.seed_matrix,
@@ -177,41 +176,6 @@ function _hessian_slice_inner(d, ex, input_ϵ, output_ϵ, ::Type{T}) where {T}
177176
return
178177
end
179178

180-
struct HessianView <: AbstractMatrix{Float64}
181-
x::Vector{Float64}
182-
N::Int
183-
function HessianView(x, N)
184-
z = div(N * (N + 1), 2)
185-
if length(x) < z
186-
resize!(x, z)
187-
end
188-
for i in 1:z
189-
x[i] = 0.0
190-
end
191-
return new(x, N)
192-
end
193-
end
194-
195-
Base.size(x::HessianView) = (x.N, x.N)
196-
_linear_index(i, j) = i < j ? _linear_index(j, i) : div((i - 1) * i, 2) + j
197-
Base.getindex(x::HessianView, i, j) = x.x[_linear_index(i, j)]
198-
Base.setindex!(x::HessianView, v, i, j) = (x.x[_linear_index(i, j)] = v)
199-
200-
struct VectorView <: AbstractVector{Float64}
201-
x::Vector{Float64}
202-
N::Int
203-
function VectorView(x, N)
204-
if length(x) < N
205-
resize!(x, N)
206-
end
207-
return new(x, N)
208-
end
209-
end
210-
211-
Base.size(x::VectorView) = (x.N,)
212-
Base.getindex(x::VectorView, i) = x.x[i]
213-
Base.setindex!(x::VectorView, v, i) = (x.x[i] = v)
214-
215179
"""
216180
_forward_eval_ϵ(
217181
d,
@@ -273,29 +237,32 @@ function _forward_eval_ϵ(
273237
end
274238
if node.type == Nonlinear.NODE_CALL_MULTIVARIATE
275239
nn = length(children_indices)
276-
f_input = VectorView(d.jac_storage, nn)
240+
f_input = _UnsafeVectorView(d.jac_storage, nn)
277241
for (i, c) in enumerate(children_indices)
278242
f_input[i] = ex.forward_storage[children_arr[c]]
279243
end
280-
H = HessianView(d.user_output_buffer, nn)
244+
H = _UnsafeHessianView(d.user_output_buffer, nn)
281245
has_hessian = Nonlinear.eval_multivariate_hessian(
282246
user_operators,
283247
user_operators.multivariate_operators[node.index],
284248
LinearAlgebra.UpperTriangular(H),
285249
f_input,
286250
)
287251
if has_hessian
288-
for (row, c) in enumerate(children_indices)
289-
ix = children_arr[c]
290-
dual = ntuple(N) do j
252+
for (col, c) in enumerate(children_indices)
253+
dual = ntuple(Val(N)) do j
291254
y = 0.0
292-
for (col, ck) in enumerate(children_indices)
293-
ε = storage_ϵ[children_arr[ck]]
294-
y += H[row, col] * ε[j]
255+
for (row, ck) in enumerate(children_indices)
256+
h = H[row, col]
257+
if !iszero(h)
258+
ε = storage_ϵ[children_arr[ck]]
259+
y += h * ε[j]
260+
end
295261
end
296262
return y
297263
end
298-
partials_storage_ϵ[ix] = ForwardDiff.Partials(dual)
264+
ix = children_arr[c]
265+
partials_storage_ϵ[ix] = ForwardDiff.Partials{N,T}(dual)
299266
end
300267
end
301268
elseif node.type == Nonlinear.NODE_CALL_UNIVARIATE

src/Nonlinear/ReverseAD/utils.jl

+106-8
Original file line numberDiff line numberDiff line change
@@ -3,27 +3,125 @@
33
# License, v. 2.0. If a copy of the MPL was not distributed with this
44
# file, You can obtain one at https://mozilla.org/MPL/2.0/.
55

6-
# Lightweight unsafe view for vectors. It seems that the only way to avoid
7-
# triggering allocations is to have only bitstype fields, so we store a pointer.
8-
struct _VectorView{T} <: DenseVector{T}
6+
"""
    _UnsafeVectorView(offset, len, ptr)
    _UnsafeVectorView(x::Vector, N::Int)

Lightweight unsafe view for vectors.

Element `i` of the view is read from and written to `ptr` at position
`i + offset`. The two-argument method grows `x` to at least `N` elements (if
needed) and returns a view of its first `N` elements.

## Motivation

`_UnsafeVectorView` is needed as an allocation-free equivalent of `view`. Other
alternatives, like `view(x, 1:len)` or a struct like
```
struct _SafeView{T}
    x::Vector{T}
    len::Int
end
```
will allocate so that `x` can be tracked by Julia's GC.

`_UnsafeVectorView` relies on the fact that the use-cases of `_UnsafeVectorView`
only temporarily wrap a long-lived vector like `d.jac_storage` so that we don't
have to worry about the GC removing `d.jac_storage` while `_UnsafeVectorView`
exists. This lets us use a `Ptr{T}` and create a struct that is `isbitstype` and
therefore does not allocate.

## Example

Instead of `view(x, 1:10)`, use `_UnsafeVectorView(0, 10, pointer(x))`.

## Unsafe behavior

`_UnsafeVectorView` is unsafe because it assumes that the vector `x` that the
pointer `ptr` refers to remains valid during the usage of `_UnsafeVectorView`.
Indexing goes straight to `unsafe_load` and `unsafe_store!`, so no bounds
checking is performed.
"""
struct _UnsafeVectorView{T} <: DenseVector{T}
    offset::Int
    len::Int
    ptr::Ptr{T}
end

# Scalar indexing only: restricting `i` to `Integer` lets any other index type
# fall through to Base's generic `AbstractArray` methods instead of reaching the
# pointer arithmetic below.
function Base.getindex(x::_UnsafeVectorView, i::Integer)
    return unsafe_load(x.ptr, i + x.offset)
end

function Base.setindex!(x::_UnsafeVectorView, value, i::Integer)
    unsafe_store!(x.ptr, value, i + x.offset)
    return value
end

# `size` is the required method of the `AbstractArray` interface; `axes`,
# `eachindex`, iteration and splatting are all derived from it.
Base.size(x::_UnsafeVectorView) = (x.len,)

Base.length(x::_UnsafeVectorView) = x.len

function _UnsafeVectorView(x::Vector, N::Int)
    if length(x) < N
        resize!(x, N)
    end
    # Take the pointer only after any `resize!`, which may move the data.
    return _UnsafeVectorView(0, N, pointer(x))
end
59+
60+
"""
    _UnsafeHessianView(x, N)

Lightweight unsafe view that exposes a vector `x` as a symmetric `N`-by-`N`
matrix. Only one triangle is stored, in packed form, so `H[i, j]` and `H[j, i]`
refer to the same element of `x` and the view needs `div(N * (N + 1), 2)`
elements.

## Motivation

`_UnsafeHessianView` is needed as an allocation-free equivalent of `view`. Other
alternatives, like `reshape(view(x, 1:N^2), N, N)` or a struct like
```
struct _SafeHessianView
    x::Vector{Float64}
    N::Int
end
```
will allocate so that `x` can be tracked by Julia's GC.

`_UnsafeHessianView` relies on the fact that the use-cases of
`_UnsafeHessianView` only temporarily wrap a long-lived vector like
`d.user_output_buffer` so that we don't have to worry about the GC removing
`d.user_output_buffer` while `_UnsafeHessianView` exists. This lets us store a
pointer-backed `_UnsafeVectorView{Float64}` and create a struct that is
`isbitstype` and therefore does not allocate.

## Unsafe behavior

`_UnsafeHessianView` is unsafe because it assumes that the vector `x` remains
valid during the usage of `_UnsafeHessianView`.
"""
struct _UnsafeHessianView <: AbstractMatrix{Float64}
    N::Int
    # The element type is fixed by the `AbstractMatrix{Float64}` supertype; the
    # struct has no type parameter, so the field type must be concrete.
    x::_UnsafeVectorView{Float64}
end
93+
94+
# The view always presents itself as a square `N`-by-`N` matrix.
function Base.size(x::_UnsafeHessianView)
    n = x.N
    return (n, n)
end
95+
96+
# Map the matrix position `(i, j)` to its slot in the packed triangular storage.
# The position is first mirrored so that the row is never smaller than the
# column, which makes `(i, j)` and `(j, i)` share the same slot.
function _linear_index(i, j)
    row, col = i < j ? (j, i) : (i, j)
    return div((row - 1) * row, 2) + col
end
97+
98+
# Read entry `(i, j)` from the packed storage; `(i, j)` and `(j, i)` resolve to
# the same stored element.
function Base.getindex(x::_UnsafeHessianView, i, j)
    k = _linear_index(i, j)
    return x.x[k]
end
99+
100+
# Write `value` to entry `(i, j)` of the packed storage and return `value`.
function Base.setindex!(x::_UnsafeHessianView, value, i, j)
    # Store into the underlying vector view `x.x`, not into `x` itself:
    # single-index assignment on the matrix is routed by Base back to this
    # two-index method and would never reach the storage.
    x.x[_linear_index(i, j)] = value
    return value
end
103+
104+
# Wrap `x` as an `N`-by-`N` packed symmetric matrix. `x` is grown if it is too
# short, and the first `div(N * (N + 1), 2)` entries are reset to zero before the
# view is created.
function _UnsafeHessianView(x::Vector, N::Int)
    packed_len = div(N * (N + 1), 2)
    length(x) >= packed_len || resize!(x, packed_len)
    for k in Base.OneTo(packed_len)
        x[k] = 0.0
    end
    # Take the pointer only after any `resize!`, which may move the data.
    storage = _UnsafeVectorView(0, packed_len, pointer(x))
    return _UnsafeHessianView(N, storage)
end
114+
115+
"""
    _reinterpret_unsafe(::Type{T}, x::Vector{R}) where {T,R}

Return an `_UnsafeVectorView` over the memory of `x`, re-interpreted as having
the element type `T`.

The length of the view is the number of whole `T`s that fit into the bytes
occupied by `x`. Both `T` and `R` must be `isbitstype`.
"""
function _reinterpret_unsafe(::Type{T}, x::Vector{R}) where {T,R}
    @assert isbitstype(T) && isbitstype(R)
    # Count how many whole `T`s fit into the bytes occupied by `x`.
    n_elements = div(length(x) * sizeof(R), sizeof(T))
    typed_ptr = reinterpret(Ptr{T}, pointer(x))
    return _UnsafeVectorView(0, n_elements, typed_ptr)
end

0 commit comments

Comments
 (0)