Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,26 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
WoodburyMatrices = "efce3f68-66dc-5838-9240-27a6d6f5f9b6"

[weakdeps]
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"

[extensions]
InterpolationsUnitfulExt = "Unitful"
InterpolationsForwardDiffExt = "ForwardDiff"

[compat]
Adapt = "2, 3, 4.0"
AxisAlgorithms = "0.3, 1"
ChainRulesCore = "0.10, 1.0, 1.2, 1.3"
ForwardDiff = "0.10, 1.0"
OffsetArrays = "0.10, 0.11, 1.0.1"
Ratios = "0.3, 0.4"
Requires = "1.1"
StaticArrays = "0.12, 1"
Unitful = "1"
WoodburyMatrices = "0.4, 0.5, 1.0"
julia = "1.6"

[extensions]
InterpolationsUnitfulExt = "Unitful"
julia = "1.9"

[extras]
ColorVectorSpace = "c3611d14-8923-5661-9e6a-0046d554d3a4"
Expand All @@ -46,6 +52,3 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["OffsetArrays", "Unitful", "SharedArrays", "ForwardDiff", "LinearAlgebra", "DualNumbers", "Random", "Pkg", "Test", "Zygote", "ColorVectorSpace"]

[weakdeps]
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
23 changes: 23 additions & 0 deletions ext/InterpolationsForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
module InterpolationsForwardDiffExt

import Interpolations
using ForwardDiff

# this strips arbitrary layers of ForwardDiff.Dual, returning the innermost value
Interpolations.just_dual_value(x::ForwardDiff.Dual) = Interpolations.just_dual_value(ForwardDiff.value(x))

function Interpolations.maybe_clamp(::Interpolations.NeedsCheck, itp, xs::Tuple{Vararg{ForwardDiff.Dual}})
xs_values = Interpolations.just_dual_value.(xs)
clamped_vals = Interpolations.maybe_clamp(Interpolations.NeedsCheck(), itp, xs_values)
apply_partials.(xs, clamped_vals)
end

# apply partials from arbitrarily nested ForwardDiff.Dual to a value
# used in maybe_clamp, above
function apply_partials(x_dual::D, val::Number) where D <: ForwardDiff.Dual
∂s = ForwardDiff.partials(x_dual)
apply_partials(ForwardDiff.value(x_dual), D(val, ∂s))
end
apply_partials(_::Number, val::Number) = val

end
4 changes: 4 additions & 0 deletions src/Interpolations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,10 @@ maybe_clamp(itp, xs) = maybe_clamp(BoundsCheckStyle(itp), itp, xs)
maybe_clamp(::NeedsCheck, itp, xs) = map(clamp, xs, lbounds(itp), ubounds(itp))
maybe_clamp(::CheckWillPass, itp, xs) = xs

# this strips arbitrary layers of ForwardDiff.Dual, returning the innermost value
# it's other methods are defined in InterpolationsForwardDiffExt.jl
just_dual_value(x::Number) = x

Base.hash(x::AbstractInterpolation, h::UInt) = Base.hash_uint(3h - objectid(x))
Base.hash(x::AbstractExtrapolation, h::UInt) = Base.hash_uint(3h - objectid(x))

Expand Down
18 changes: 12 additions & 6 deletions src/b-splines/constant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,30 +67,36 @@ to `A[ceil(Int,x)]` without scaling.
Constant

function positions(c::Constant{Previous}, ax, x) # discontinuity occurs at integer locations
xm = floorbounds(x, ax)
x_value = just_dual_value.(x)
xm = floorbounds(x_value, ax)
δx = x - xm
fast_trunc(Int, xm), δx
end
function positions(c::Constant{Next}, ax, x) # discontinuity occurs at integer locations
xm = ceilbounds(x, ax)
x_value = just_dual_value.(x)
xm = ceilbounds(x_value, ax)
δx = x - xm
fast_trunc(Int, xm), δx
end
function positions(c::Constant{Nearest}, ax, x) # discontinuity occurs at half-integer locations
xm = roundbounds(x, ax)
x_value = just_dual_value.(x)
xm = roundbounds(x_value, ax)
δx = x - xm
fast_trunc(Int, xm), δx
i = fast_trunc(Int, xm)
i, δx
end

function positions(c::Constant{Previous,Periodic{OnCell}}, ax, x)
x_value = just_dual_value.(x)
# We do not use floorbounds because we do not want to add a half at
# the lowerbound to round up.
xm = floor(x)
xm = floor(x_value)
δx = x - xm
modrange(fast_trunc(Int, xm), ax), δx
end
function positions(c::Constant{Next,Periodic{OnCell}}, ax, x) # discontinuity occurs at integer locations
xm = ceilbounds(x, ax)
x_value = just_dual_value.(x)
xm = ceilbounds(x_value, ax)
δx = x - xm
modrange(fast_trunc(Int, xm), ax), δx
end
Expand Down
2 changes: 1 addition & 1 deletion src/b-splines/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
itpinfo(itp) = (tcollect(itpflag, itp), axes(itp))

@inline function (itp::BSplineInterpolation{T,N})(x::Vararg{Number,N}) where {T,N}
@boundscheck (checkbounds(Bool, itp, x...) || Base.throw_boundserror(itp, x))
@boundscheck (checkbounds(Bool, itp, just_dual_value.(x)...) || Base.throw_boundserror(itp, x))
wis = weightedindexes((value_weights,), itpinfo(itp)..., x)
InterpGetindex(itp)[wis...]
end
Expand Down
8 changes: 5 additions & 3 deletions src/b-splines/linear.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,13 @@ a piecewise linear function connecting each pair of neighboring data points.
Linear

function positions(deg::Linear, ax::AbstractUnitRange{<:Integer}, x)
f = floor(x)
x_value = just_dual_value.(x)
f = floor(x_value)
# When x == last(ax) we want to use the x-1, x pair
f = ifelse(x == last(ax), f - oneunit(f), f)
f = ifelse(x_value == last(ax), f - oneunit(f), f)
fi = fast_trunc(Int, f)
expand_index(deg, fi, ax), x-f

expand_index(deg, fi, ax), x - f # for this δ, we want x, not x_value
end
expand_index(::Linear{Throw{OnGrid}}, fi::Number, ax::AbstractUnitRange) = fi
expand_index(::Linear{Periodic{OnCell}}, fi::Number, ax::AbstractUnitRange) =
Expand Down
5 changes: 3 additions & 2 deletions src/monotonic/monotonic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,9 @@ function interpolate(
end

function (itp::MonotonicInterpolation)(x::Number)
@boundscheck (checkbounds(Bool, itp, x) || Base.throw_boundserror(itp, (x,)))
k = searchsortedfirst(itp.knots, x)
x_value = just_dual_value.(x)
@boundscheck (checkbounds(Bool, itp, x_value) || Base.throw_boundserror(itp, (x_value,)))
k = searchsortedfirst(itp.knots, x_value)
if k > 1
k -= 1
end
Expand Down
3 changes: 2 additions & 1 deletion src/scaling/scaling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ ubound(ax::AbstractRange, ::DegreeBC, ::OnGrid) = last(ax)

# For (), we scale the evaluation point
@propagate_inbounds function (sitp::ScaledInterpolation{T,N})(xs::Vararg{Number,N}) where {T,N}
@boundscheck (checkbounds(Bool, sitp, xs...) || Base.throw_boundserror(sitp, xs))
xs_values = just_dual_value.(xs)
@boundscheck (checkbounds(Bool, sitp, xs_values...) || Base.throw_boundserror(sitp, xs_values))
xl = maybe_clamp(sitp.itp, coordslookup(itpflag(sitp.itp), sitp.ranges, xs))
@inbounds sitp.itp(xl...)
end
Expand Down
15 changes: 9 additions & 6 deletions test/gradient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@ using Test, Interpolations, DualNumbers, LinearAlgebra, ColorVectorSpace
using ColorVectorSpace: RGB, Gray, N0f8, Colorant

@testset "Gradients" begin
# array of values of the function f1 and vector to store gradient
nx = 10
f1(x) = sin((x-3)*2pi/(nx-1) - 1)
g1gt(x) = 2pi/(nx-1) * cos((x-3)*2pi/(nx-1) - 1)
f1(x) = sin((x - 3) * 2pi / (nx - 1) - 1)
g1gt(x) = 2pi / (nx - 1) * cos((x - 3) * 2pi / (nx - 1) - 1) # analytic gradient of f1
A1 = Float64[f1(x) for x in 1:nx]
g1 = Array{Float64}(undef, 1)
A2 = rand(Float64, nx, nx) * 100

# random array and vector to store gradient
A2 = rand(Float64, 3, 3) * 100
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this hard coded at 3 now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to make the test less noisy if it fails. There's no particular reason that it had to be nx and I don't there there are any conditions that are hit with 10 points that aren't with 3. That said, I'm happy to change back. It is not an important change.

g2 = Array{Float64}(undef, 2)

for (A, g) in ((A1, g1), (A2, g2))
# Gradient of Constant should always be 0
for (A, g) in [(A1, g1)]#((A1, g1), (A2, g2))
# Gradient of Constant interpolation should always be 0
itp = interpolate(A, BSpline(Constant()))
for x in InterpolationTestUtils.thirds(axes(A))
@test all(iszero, @inferred(Interpolations.gradient(itp, x...)))
Expand All @@ -23,7 +26,7 @@ using ColorVectorSpace: RGB, Gray, N0f8, Colorant
i = first(eachindex(itp))
@test Interpolations.gradient(itp, i) == Interpolations.gradient(itp, Tuple(i)...)

for BC in (Flat,Line,Free,Periodic,Reflect,Natural), GT in (OnGrid, OnCell)
for BC in (Flat, Line, Free, Periodic, Reflect, Natural), GT in (OnGrid, OnCell)
itp = interpolate(A, BSpline(Quadratic(BC(GT()))))
check_gradient(itp, g)
i = first(eachindex(itp))
Expand Down
4 changes: 3 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ using Interpolations
const isci = get(ENV, "CI", "") in ("true", "True")

@testset "Interpolations" begin
@test isempty(detect_ambiguities(Interpolations))
@testset "method ambiguities" begin
@test isempty(detect_ambiguities(Interpolations))
end

include("core.jl")
# Hermite interpolation tests
Expand Down
Loading