|
1 | 1 | module NonlinearSolveBaseForwardDiffExt
|
2 | 2 |
|
| 3 | +using CommonSolve: solve |
| 4 | +using FastClosures: @closure |
3 | 5 | using ForwardDiff: ForwardDiff, Dual
|
4 |
| -using NonlinearSolveBase: Utils |
| 6 | +using SciMLBase: SciMLBase, IntervalNonlinearProblem, NonlinearProblem, |
| 7 | + NonlinearLeastSquaresProblem, remake |
| 8 | + |
| 9 | +using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, Utils |
5 | 10 |
|
6 | 11 | Utils.value(::Type{Dual{T, V, N}}) where {T, V, N} = V
|
7 | 12 | Utils.value(x::Dual) = Utils.value(ForwardDiff.value(x))
|
8 | 13 |
|
| 14 | +function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve( |
| 15 | + prob::Union{IntervalNonlinearProblem, NonlinearProblem, ImmutableNonlinearProblem}, |
| 16 | + alg, args...; kwargs...) |
| 17 | + p = Utils.value(prob.p) |
| 18 | + if prob isa IntervalNonlinearProblem |
| 19 | + tspan = Utils.value.(prob.tspan) |
| 20 | + newprob = IntervalNonlinearProblem(prob.f, tspan, p; prob.kwargs...) |
| 21 | + else |
| 22 | + newprob = remake(prob; p, u0 = Utils.value(prob.u0)) |
| 23 | + end |
| 24 | + |
| 25 | + sol = solve(newprob, alg, args...; kwargs...) |
| 26 | + |
| 27 | + uu = sol.u |
| 28 | + Jₚ = nonlinearsolve_∂f_∂p(prob, prob.f, uu, p) |
| 29 | + Jᵤ = nonlinearsolve_∂f_∂u(prob, prob.f, uu, p) |
| 30 | + z = -Jᵤ \ Jₚ |
| 31 | + pp = prob.p |
| 32 | + sumfun = ((z, p),) -> map(Base.Fix2(*, ForwardDiff.partials(p)), z) |
| 33 | + |
| 34 | + if uu isa Number |
| 35 | + partials = sum(sumfun, zip(z, pp)) |
| 36 | + elseif p isa Number |
| 37 | + partials = sumfun((z, pp)) |
| 38 | + else |
| 39 | + partials = sum(sumfun, zip(eachcol(z), pp)) |
| 40 | + end |
| 41 | + |
| 42 | + return sol, partials |
| 43 | +end |
| 44 | + |
| 45 | +function nonlinearsolve_∂f_∂p(prob, f::F, u, p) where {F} |
| 46 | + if isinplace(prob) |
| 47 | + f = @closure p -> begin |
| 48 | + du = Utils.safe_similar(u, promote_type(eltype(u), eltype(p))) |
| 49 | + f(du, u, p) |
| 50 | + return du |
| 51 | + end |
| 52 | + else |
| 53 | + f = Base.Fix1(f, u) |
| 54 | + end |
| 55 | + if p isa Number |
| 56 | + return Utils.safe_reshape(ForwardDiff.derivative(f, p), :, 1) |
| 57 | + elseif u isa Number |
| 58 | + return Utils.safe_reshape(ForwardDiff.gradient(f, p), 1, :) |
| 59 | + else |
| 60 | + return ForwardDiff.jacobian(f, p) |
| 61 | + end |
| 62 | +end |
| 63 | + |
| 64 | +function nonlinearsolve_∂f_∂u(prob, f::F, u, p) where {F} |
| 65 | + if isinplace(prob) |
| 66 | + return ForwardDiff.jacobian( |
| 67 | + @closure((du, u)->f(du, u, p)), Utils.safe_similar(u), u) |
| 68 | + end |
| 69 | + return ForwardDiff.jacobian(Base.Fix2(f, p), u) |
| 70 | +end |
| 71 | + |
| 72 | +function NonlinearSolveBase.nonlinearsolve_dual_solution(u::Number, partials, |
| 73 | + ::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P} |
| 74 | + return Dual{T, V, P}(u, partials) |
| 75 | +end |
| 76 | + |
| 77 | +function NonlinearSolveBase.nonlinearsolve_dual_solution(u::AbstractArray, partials, |
| 78 | + ::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P} |
| 79 | + return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, Utils.restructure(u, partials))) |
| 80 | +end |
| 81 | + |
9 | 82 | end
|
0 commit comments