Skip to content

Commit 5481e82

Browse files
committed
feat: add ImmutableNonlinearProblem
1 parent a12ddbe commit 5481e82

File tree

4 files changed

+77
-1
lines changed

4 files changed

+77
-1
lines changed

lib/NonlinearSolveBase/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
1616
UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b"
1717

1818
[weakdeps]
19+
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
1920
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2021
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2122

2223
[extensions]
24+
NonlinearSolveBaseDiffEqBaseExt = "DiffEqBase"
2325
NonlinearSolveBaseForwardDiffExt = "ForwardDiff"
2426
NonlinearSolveBaseSparseArraysExt = "SparseArrays"
2527

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
module NonlinearSolveBaseDiffEqBaseExt
2+
3+
using DiffEqBase: DiffEqBase
4+
using NonlinearSolveBase: ImmutableNonlinearProblem
5+
6+
function DiffEqBase.get_concrete_problem(
7+
prob::ImmutableNonlinearProblem, isadapt; kwargs...)
8+
u0 = DiffEqBase.get_concrete_u0(prob, isadapt, nothing, kwargs)
9+
u0 = DiffEqBase.promote_u0(u0, prob.p, nothing)
10+
p = DiffEqBase.get_concrete_p(prob, kwargs)
11+
return SciMLBase.remake(prob; u0, p)
12+
end
13+
14+
end

lib/NonlinearSolveBase/src/NonlinearSolveBase.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ using FastClosures: @closure
77
using LinearAlgebra: norm
88
using Markdown: @doc_str
99
using RecursiveArrayTools: AbstractVectorOfArray, ArrayPartition
10-
using SciMLBase: SciMLBase, ReturnCode, AbstractODEIntegrator
10+
using SciMLBase: SciMLBase, ReturnCode, AbstractODEIntegrator, AbstractNonlinearProblem,
11+
AbstractNonlinearFunction, @add_kwonly, StandardNonlinearProblem,
12+
NullParameters, NonlinearProblem, isinplace
1113
using StaticArraysCore: StaticArray
1214

1315
include("public.jl")
@@ -25,4 +27,6 @@ export RelTerminationMode, AbsTerminationMode, NormTerminationMode, RelNormTermi
2527
AbsNormTerminationMode, RelNormSafeTerminationMode, AbsNormSafeTerminationMode,
2628
RelNormSafeNormTerminationMode, AbsNormSafeNormTerminationMode
2729

30+
export ImmutableNonlinearProblem
31+
2832
end
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,57 @@
1+
struct ImmutableNonlinearProblem{uType, iip, P, F, K, PT} <:
2+
AbstractNonlinearProblem{uType, iip}
3+
f::F
4+
u0::uType
5+
p::P
6+
problem_type::PT
7+
kwargs::K
18

9+
@add_kwonly function ImmutableNonlinearProblem{iip}(
10+
f::AbstractNonlinearFunction{iip}, u0, p = NullParameters(),
11+
problem_type = StandardNonlinearProblem(); kwargs...) where {iip}
12+
if haskey(kwargs, :p)
13+
error("`p` specified as a keyword argument `p = $(kwargs[:p])` to \
14+
`NonlinearProblem`. This is not supported.")
15+
end
16+
warn_paramtype(p)
17+
return new{
18+
typeof(u0), iip, typeof(p), typeof(f), typeof(kwargs), typeof(problem_type)}(
19+
f, u0, p, problem_type, kwargs)
20+
end
21+
22+
"""
23+
Define a steady state problem using the given function.
24+
`isinplace` optionally sets whether the function is inplace or not.
25+
This is determined automatically, but not inferred.
26+
"""
27+
function ImmutableNonlinearProblem{iip}(
28+
f, u0, p = NullParameters(); kwargs...) where {iip}
29+
return ImmutableNonlinearProblem{iip}(NonlinearFunction{iip}(f), u0, p; kwargs...)
30+
end
31+
end
32+
33+
"""
34+
Define a nonlinear problem using an instance of
35+
[`AbstractNonlinearFunction`](@ref AbstractNonlinearFunction).
36+
"""
37+
function ImmutableNonlinearProblem(
38+
f::AbstractNonlinearFunction, u0, p = NullParameters(); kwargs...)
39+
return ImmutableNonlinearProblem{isinplace(f)}(f, u0, p; kwargs...)
40+
end
41+
42+
function ImmutableNonlinearProblem(f, u0, p = NullParameters(); kwargs...)
43+
return ImmutableNonlinearProblem(NonlinearFunction(f), u0, p; kwargs...)
44+
end
45+
46+
"""
47+
Define a ImmutableNonlinearProblem problem from SteadyStateProblem
48+
"""
49+
function ImmutableNonlinearProblem(prob::AbstractNonlinearProblem)
50+
return ImmutableNonlinearProblem{isinplace(prob)}(prob.f, prob.u0, prob.p)
51+
end
52+
53+
function Base.convert(
54+
::Type{ImmutableNonlinearProblem}, prob::T) where {T <: NonlinearProblem}
55+
return ImmutableNonlinearProblem{isinplace(prob)}(
56+
prob.f, prob.u0, prob.p, prob.problem_type; prob.kwargs...)
57+
end

0 commit comments

Comments
 (0)