Skip to content

Commit 295a3a3

Browse files
committed
exclude nested duals
1 parent ed77d30 commit 295a3a3

File tree

2 files changed

+24
-10
lines changed

2 files changed

+24
-10
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,31 +7,45 @@ using ForwardDiff: Dual, Partials
77
using SciMLBase
88
using RecursiveArrayTools
99

10-
const DualLinearProblem = LinearProblem{
10+
11+
# Define type for non-nested dual numbers
12+
const SingleDual{T, V, P} = Dual{T, V, P} where {T, V <:Float64 , P}
13+
14+
# Define type for nested dual numbers
15+
const NestedDual{T, V, P} = Dual{T, V, P} where {T, V <: Dual, P}
16+
17+
const SingleDualLinearProblem = LinearProblem{
18+
<:Union{Number, <:AbstractArray, Nothing}, iip,
19+
<:Union{<:SingleDual, <:AbstractArray{<:SingleDual}},
20+
<:Union{<:SingleDual, <:AbstractArray{<:SingleDual}},
21+
<:Any
22+
} where {iip}
23+
24+
const NestedDualLinearProblem = LinearProblem{
1125
<:Union{Number, <:AbstractArray, Nothing}, iip,
12-
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}},
13-
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}},
26+
<:Union{<:NestedDual, <:AbstractArray{<:NestedDual}},
27+
<:Union{<:NestedDual, <:AbstractArray{<:NestedDual}},
1428
<:Any
15-
} where {iip, T, V, P}
29+
} where {iip}
1630

1731
const DualALinearProblem = LinearProblem{
1832
<:Union{Number, <:AbstractArray, Nothing},
1933
iip,
20-
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}},
34+
<:Union{<:SingleDual, <:AbstractArray{<:SingleDual}},
2135
<:Union{Number, <:AbstractArray},
2236
<:Any
23-
} where {iip, T, V, P}
37+
} where {iip}
2438

2539
const DualBLinearProblem = LinearProblem{
2640
<:Union{Number, <:AbstractArray, Nothing},
2741
iip,
2842
<:Union{Number, <:AbstractArray},
29-
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}},
43+
<:Union{<:SingleDual, <:AbstractArray{<:SingleDual}},
3044
<:Any
31-
} where {iip, T, V, P}
45+
} where {iip}
3246

3347
const DualAbstractLinearProblem = Union{
34-
DualLinearProblem, DualALinearProblem, DualBLinearProblem}
48+
SingleDualLinearProblem, DualALinearProblem, DualBLinearProblem}
3549

3650
LinearSolve.@concrete mutable struct DualLinearCache
3751
linear_cache

test/forwarddiff_overloads.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,4 +79,4 @@ cache.b = new_b
7979
x_p = solve!(cache)
8080
backslash_x_p = A \ new_b
8181

82-
@test (x_p, backslash_x_p, rtol = 1e-9)
82+
@test (x_p, backslash_x_p, rtol = 1e-9)

0 commit comments

Comments
 (0)