Skip to content

Commit cc4c6de

Browse files
committed
Fix gradient for NoInterp
1 parent b160f08 commit cc4c6de

File tree

4 files changed

+26
-9
lines changed

4 files changed

+26
-9
lines changed

src/b-splines/indexing.jl

+10-4
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,9 @@ function expand(coefs::AbstractArray{T,N}, vweights::Tuple{}, ixs::Tuple{}, iexp
174174
@inbounds coefs[iexpanded...] # @inbounds is safe because we checked in the original call
175175
end
176176

177+
const HasNoInterp{N} = NTuple{N,Tuple{Vararg{<:Union{Number,NoInterp}}}}
178+
expand(coefs::AbstractArray, vweights::HasNoInterp, ixs::Indexes, iexpanded::Vararg{Integer,M}) where {M} = NoInterp()
179+
177180
# _expand1 handles the expansion of a single dimension weight list (of length L)
178181
@inline _expand1(coefs, w1, ix1, wrest, ixrest, iexpanded) =
179182
w1[1] * expand(coefs, wrest, ixrest, iexpanded..., ix1[1]) +
@@ -182,14 +185,17 @@ end
182185
w1[1] * expand(coefs, wrest, ixrest, iexpanded..., ix1[1])
183186

184187
# Expansion of the gradient
185-
function expand(coefs, (vweights, gweights)::Tuple{Weights{N},Weights{N}}, ixs::Indexes{N}) where N
188+
function expand(coefs, (vweights, gweights)::Tuple{HasNoInterp{N},HasNoInterp{N}}, ixs::Indexes{N}) where N
186189
# We swap in one gradient dimension per call to expand
187-
SVector(ntuple(d->expand(coefs, substitute(vweights, d, gweights), ixs), Val(N)))
190+
SVector(skip_nointerp(ntuple(d->expand(coefs, substitute(vweights, d, gweights), ixs), Val(N))...))
188191
end
189-
function expand!(dest, coefs, (vweights, gweights)::Tuple{Weights{N},Weights{N}}, ixs::Indexes{N}) where N
192+
function expand!(dest, coefs, (vweights, gweights)::Tuple{HasNoInterp{N},HasNoInterp{N}}, ixs::Indexes{N}) where N
190193
# We swap in one gradient dimension per call to expand
194+
i = 0
191195
for d = 1:N
192-
dest[d] = expand(coefs, substitute(vweights, d, gweights), ixs)
196+
w = substitute(vweights, d, gweights)
197+
w isa Weights || continue # must have a NoInterp in it
198+
dest[i+=1] = expand(coefs, w, ixs)
193199
end
194200
dest
195201
end

src/nointerp/nointerp.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ base_rem(::NoInterp, bounds, x::Number) = Int(x), 0
1717
expand_index(::NoInterp, xi::Number, ax::AbstractUnitRange, δx) = (xi,)
1818

1919
value_weights(::NoInterp, δx) = (oneunit(δx),)
20-
gradient_weights(::NoInterp, δx) = (zero(δx),)
21-
hessian_weights(::NoInterp, δx) = (zero(δx),)
20+
gradient_weights(::NoInterp, δx) = (NoInterp(),)
21+
hessian_weights(::NoInterp, δx) = (NoInterp(),)
2222

2323
padded_axis(ax::AbstractUnitRange, ::NoInterp) = ax

src/utils.jl

+4
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,7 @@ end
2020
function substitute(default::NTuple{N,Any}, d::Integer, val) where N
2121
ntuple(i->ifelse(i==d, val, default[i]), Val(N))
2222
end
23+
24+
@inline skip_nointerp(x, rest...) = (x, skip_nointerp(rest...)...)
25+
@inline skip_nointerp(::NoInterp, rest...) = skip_nointerp(rest...)
26+
skip_nointerp() = ()

test/gradient.jl

+10-3
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ using Test, Interpolations, DualNumbers, LinearAlgebra
106106
@test g[2] 2 * (4 - 1.75) ^ 2 * (3 - 1.75)
107107

108108
A2 = rand(Float64, nx, nx) * 100
109+
gni = [1.0]
109110
for BC in (Flat,Line,Free,Periodic,Reflect,Natural), GT in (OnGrid, OnCell)
110111
itp_a = interpolate(A2, (BSpline(Linear()), BSpline(Quadratic(BC()))), GT())
111112
itp_b = interpolate(A2, (BSpline(Quadratic(BC())), BSpline(Linear())), GT())
@@ -127,11 +128,17 @@ using Test, Interpolations, DualNumbers, LinearAlgebra
127128
@test epsilon(itp_b(x,yd)) gtmp[2]
128129
ix, iy = round(Int, x), round(Int, y)
129130
gtmp = Interpolations.gradient(itp_c, ix, y)
130-
@test_broken length(gtmp) == 1
131-
@test_broken epsilon(itp_c(ix,yd)) gtmp[1]
131+
@test length(gtmp) == 1
132+
@test epsilon(itp_c(ix,yd)) gtmp[1]
133+
gni[1] = NaN
134+
Interpolations.gradient!(gni, itp_c, ix, y)
135+
@test gni[1] gtmp[1]
132136
gtmp = Interpolations.gradient(itp_d, x, iy)
133-
@test_broken length(gtmp) == 1
137+
@test length(gtmp) == 1
134138
@test epsilon(itp_d(xd,iy)) gtmp[1]
139+
gni[1] = NaN
140+
Interpolations.gradient!(gni, itp_d, x, iy)
141+
@test gni[1] gtmp[1]
135142
end
136143
end
137144

0 commit comments

Comments
 (0)