-
-
Notifications
You must be signed in to change notification settings - Fork 48
/
Copy pathforward_diff.jl
79 lines (68 loc) · 2.83 KB
/
forward_diff.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# Not part of public API but helps reduce code duplication
import SimpleNonlinearSolve: __nlsolve_ad, __nlsolve_dual_soln, __nlsolve_∂f_∂p,
__nlsolve_∂f_∂u
# Generate `solve`/`init` overloads for every `NonlinearProblem` whose `u0` or `p`
# (or both) is a ForwardDiff `Dual` or an array of `Dual`s. Problems where neither
# carries Duals do not match any of these signatures.
for (uType, pType) in (
    (Union{<:Number, <:AbstractArray}, Union{<:Dual, <:AbstractArray{<:Dual}}),
    (Union{<:Dual, <:AbstractArray{<:Dual}}, Union{<:Dual, <:AbstractArray{<:Dual}}),
    (Union{<:Dual, <:AbstractArray{<:Dual}}, Union{<:Number, <:AbstractArray})
)
    @eval begin
        # One-shot entry point: `__nlsolve_ad` (from SimpleNonlinearSolve) hands back
        # the primal solution together with its partials; those are packed back into
        # Dual numbers for the solution returned to the caller.
        function SciMLBase.solve(
                prob::NonlinearProblem{<:$(uType), iip, <:$(pType)},
                alg::Union{Nothing, AbstractNonlinearAlgorithm},
                args...; kwargs...) where {iip}
            primal_sol, sol_partials = __nlsolve_ad(prob, alg, args...; kwargs...)
            u_dual = __nlsolve_dual_soln(primal_sol.u, sol_partials, prob.p)
            return SciMLBase.build_solution(prob, alg, u_dual, primal_sol.resid;
                primal_sol.retcode, primal_sol.stats, primal_sol.original)
        end

        # Cache entry point: the inner cache is built on a Dual-free copy of the
        # problem, while the original (possibly Dual) parameters are kept on the
        # wrapper so `solve!` can reattach derivative information afterwards.
        function SciMLBase.init(
                prob::NonlinearProblem{<:$(uType), iip, <:$(pType)},
                alg::Union{Nothing, AbstractNonlinearAlgorithm},
                args...; kwargs...) where {iip}
            p_values = __value(prob.p)
            u0_values = __value(prob.u0)
            primal_prob = NonlinearProblem(prob.f, u0_values, p_values; prob.kwargs...)
            inner = init(primal_prob, alg, args...; kwargs...)
            p_partials = ForwardDiff.partials(prob.p)
            return NonlinearSolveForwardDiffCache(
                inner, primal_prob, alg, prob.p, p_values, p_partials)
        end
    end
end
# Cache returned by the Dual-aware `init` overloads. The wrapped solver cache only
# ever sees primal (Dual-free) values; derivative information is reattached to the
# solution in `solve!`.
# NOTE(review): `@concrete` presumably (ConcreteStructs) gives each untyped field its
# own type parameter — confirm; if so, reassignments must preserve each field's type.
@concrete mutable struct NonlinearSolveForwardDiffCache
    # inner solver cache, built by `init` on the Dual-free problem
    cache
    # Dual-free `NonlinearProblem` (primal values of `u0` and `p`)
    prob
    # algorithm passed to `init` (may be `nothing`)
    alg
    # original parameters, possibly containing `Dual`s
    p
    # primal values of `p`, i.e. `__value(p)`
    values_p
    # `ForwardDiff.partials(p)`; not read by the `solve!` method for this type
    partials_p
end
# Register `cache` as this wrapper's inner cache with the `@internal_caches` machinery
# (presumably forwards internal-cache queries to that field — NOTE(review): confirm).
@internal_caches NonlinearSolveForwardDiffCache :cache
# Reinitialize the wrapper for new parameters and/or a new initial guess.
# The inner cache is reset with Dual-free copies of `p` and `u0`; the original `p`
# (possibly Dual) and the values/partials derived from it are stored on the wrapper
# for the next `solve!`. Returns the (mutated) wrapper.
# NOTE(review): `cache.prob` is not refreshed here, so the solution built by
# `solve!` still references the problem captured at `init` time — confirm intended.
function reinit_cache!(cache::NonlinearSolveForwardDiffCache;
        p = cache.p, u0 = get_u(cache.cache), kwargs...)
    stripped_p = __value(p)
    stripped_u0 = __value(u0)
    refreshed = reinit_cache!(cache.cache; p = stripped_p, u0 = stripped_u0, kwargs...)
    cache.cache = refreshed

    # Record the original parameters and what `solve!` derives from them.
    cache.p = p
    cache.values_p = __value(p)
    cache.partials_p = ForwardDiff.partials(p)
    return cache
end
# Solve the Dual-free inner problem, then attach derivatives with respect to the
# (possibly Dual) parameters `cache.p` via the implicit function theorem:
# for f(u*(p), p) = 0, du*/dp = -(∂f/∂u) \ (∂f/∂p), both Jacobians evaluated at the
# converged primal solution. The returned solution references `cache.prob`, the
# Dual-free problem, and carries Dual-valued `u`.
function SciMLBase.solve!(cache::NonlinearSolveForwardDiffCache)
    sol = solve!(cache.cache)
    prob = cache.prob

    uu = sol.u
    f_p = __nlsolve_∂f_∂p(prob, prob.f, uu, cache.values_p)
    f_x = __nlsolve_∂f_∂u(prob, prob.f, uu, cache.values_p)
    z_arr = -f_x \ f_p

    # Chain rule: scale each sensitivity entry by the partials carried by the
    # corresponding Dual parameter and accumulate over parameters.
    sumfun = ((z, p),) -> map(zᵢ -> zᵢ * ForwardDiff.partials(p), z)
    if uu isa Number
        # Scalar state: `z_arr` holds one scalar sensitivity per parameter. Pair them
        # elementwise so the sum is a single `Partials`; going through `eachcol` here
        # would yield 1-element vectors (or drop parameters), which cannot be turned
        # back into a scalar `Dual`.
        partials = sum(sumfun, zip(z_arr, cache.p))
    elseif cache.p isa Number
        partials = sumfun((z_arr, cache.p))
    else
        # Column j of `z_arr` is du*/dp_j.
        partials = sum(sumfun, zip(eachcol(z_arr), cache.p))
    end

    dual_soln = __nlsolve_dual_soln(sol.u, partials, cache.p)
    return SciMLBase.build_solution(
        prob, cache.alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
end
# Strip ForwardDiff Dual numbers down to their primal values (one Dual level only).
# Anything that is neither a `Dual` nor an array of `Dual`s is returned as-is.
@inline function __value(x)
    return x
end
@inline function __value(d::Dual)
    return ForwardDiff.value(d)
end
@inline function __value(xs::AbstractArray{<:Dual})
    return map(__value, xs)
end