Skip to content

Commit 7719e98

Browse files
authored
Merge pull request #348 from getzze/regularize
Add regularization for BSpline
2 parents 13da9c4 + 955d75e commit 7719e98

File tree

4 files changed

+156
-3
lines changed

4 files changed

+156
-3
lines changed

src/b-splines/b-splines.jl

+42
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,11 @@ function interpolate(::Type{TWeights}, ::Type{TC}, A, it::IT) where {TWeights,TC
161161
BSplineInterpolation(TWeights, Apad, it, axes(A))
162162
end
163163

164+
function interpolate(::Type{TWeights}, ::Type{TC}, A, it::IT, λ::Real, k::Int) where {TWeights,TC,IT<:DimSpec{BSpline}}
165+
Apad = prefilter(TWeights, TC, A, it, λ, k)
166+
BSplineInterpolation(TWeights, Apad, it, axes(A))
167+
end
168+
164169
"""
165170
itp = interpolate(A, interpmode)
166171
@@ -179,6 +184,33 @@ function interpolate(A::AbstractArray, it::IT) where {IT<:DimSpec{BSpline}}
179184
interpolate(tweight(A), tcoef(A), A, it)
180185
end
181186

187+
"""
188+
itp = interpolate(A, interpmode, gridstyle, λ, k)
189+
190+
Interpolate an array `A` in the mode determined by `interpmode` and `gridstyle`
191+
with regularization following [1], of order `k` and constant `λ`.
192+
`interpmode` may be one of
193+
194+
- `BSpline(NoInterp())`
195+
- `BSpline(Linear())`
196+
- `BSpline(Quadratic(BC()))` (see [`BoundaryCondition`](@ref))
197+
- `BSpline(Cubic(BC()))`
198+
199+
It may also be a tuple of such values, if you want to use different interpolation schemes along each axis.
200+
201+
`gridstyle` should be one of `OnGrid()` or `OnCell()`.
202+
203+
`k` corresponds to the derivative to penalize. In the limit λ->∞, the interpolation function is
204+
a polynomial of order `k-1`. A value of 2 is the most common.
205+
206+
`λ` is non-negative. If its value is zero, it falls back to non-regularized interpolation.
207+
208+
[1] https://projecteuclid.org/euclid.ss/1038425655.
209+
"""
210+
function interpolate(A::AbstractArray, it::IT, λ::Real, k::Int) where {IT<:DimSpec{BSpline}}
211+
interpolate(tweight(A), tcoef(A), A, it, λ, k)
212+
end
213+
182214
# We can't just return a tuple-of-types due to julia #12500
183215
tweight(A::AbstractArray) = Float64
184216
tweight(A::AbstractArray{T}) where T<:AbstractFloat = T
@@ -200,6 +232,16 @@ function interpolate!(A::AbstractArray, it::IT) where {IT<:DimSpec{BSpline}}
200232
interpolate!(tweight(A), A, it)
201233
end
202234

235+
function interpolate!(::Type{TWeights}, A::AbstractArray, it::IT, λ::Real, k::Int) where {TWeights,IT<:DimSpec{BSpline}}
236+
# Set the bounds of the interpolant inward, if necessary
237+
axsA = axes(A)
238+
axspad = padded_axes(axsA, it)
239+
BSplineInterpolation(TWeights, prefilter!(TWeights, A, it, λ, k), it, fix_axis.(padinset.(axsA, axspad)))
240+
end
241+
function interpolate!(A::AbstractArray, it::IT, λ::Real, k::Int) where {IT<:DimSpec{BSpline}}
242+
interpolate!(tweight(A), A, it, λ, k)
243+
end
244+
203245
lut!(dl, d, du) = lu!(Tridiagonal(dl, d, du), Val(false))
204246

205247
include("constant.jl")

src/b-splines/prefiltering.jl

+56-3
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,18 @@ function prefilter(
4040
prefilter!(TWeights, ret, it)
4141
end
4242

43+
function prefilter(
44+
::Type{TWeights}, ::Type{TC}, A::AbstractArray,
45+
it::Union{BSpline,Tuple{Vararg{Union{BSpline,NoInterp}}}},
46+
λ::Real, k::Int
47+
) where {TWeights,TC}
48+
ret = copy_with_padding(TC, A, it)
49+
prefilter!(TWeights, ret, it, λ, k)
50+
end
51+
4352
function prefilter!(
4453
::Type{TWeights}, ret::TCoefs, it::BSpline
4554
) where {TWeights,TCoefs<:AbstractArray}
46-
local buf, shape, retrs
4755
sz = size(ret)
4856
first = true
4957
for dim in 1:ndims(ret)
@@ -53,12 +61,57 @@ function prefilter!(
5361
ret
5462
end
5563

64+
function diffop(::Type{TWeights}, n::Int, k::Int) where {TWeights}
65+
D = spdiagm(0 => ones(TWeights,n))
66+
for i in 1:k
67+
D = diff(D; dims=1)
68+
end
69+
## TODO: Normalize by n?
70+
D' * D
71+
end
72+
diffop(n::Int, k::Int) = diffop(Float64, n, k)
73+
### TODO: add compiled constructor for most common operators of order k=1,2
74+
75+
76+
function prefilter!(
77+
::Type{TWeights}, ret::TCoefs, it::BSpline, λ::Real, k::Int
78+
) where {TWeights,TCoefs<:AbstractArray}
79+
80+
sz = size(ret)
81+
82+
# Test if regularizing is allowed
83+
fallback = if ndims(ret) > 1
84+
@warn "Smooth BSpline only available for Vectors, fallback to non-smoothed"
85+
true
86+
elseif λ < 0
87+
@warn "Smooth BSpline require a non-negative λ, fallback to non-smoothed: $(λ)"
88+
true
89+
elseif λ == 0
90+
true
91+
else
92+
false
93+
end
94+
95+
if fallback
96+
return prefilter!(TWeights, ret, it)
97+
end
98+
99+
M, b = prefiltering_system(TWeights, eltype(TCoefs), sz[1], degree(it))
100+
## Solve with regularization
101+
# Convert to dense Matrix because `*(Woodbury{Adjoint}, Woodbury)` is not defined
102+
Mt = Matrix(M)'
103+
Q = diffop(TWeights, sz[1], k)
104+
K = cholesky(Mt * Matrix(M) + λ * Q)
105+
B = Mt * popwrapper(ret)
106+
ldiv!(popwrapper(ret), K, B)
107+
108+
ret
109+
end
110+
56111
function prefilter!(
57112
::Type{TWeights}, ret::TCoefs, its::Tuple{Vararg{Union{BSpline,NoInterp}}}
58113
) where {TWeights,TCoefs<:AbstractArray}
59-
local buf, shape, retrs
60114
sz = size(ret)
61-
first = true
62115
for dim in 1:ndims(ret)
63116
it = iextract(its, dim)
64117
if it != NoInterp

test/b-splines/regularization.jl

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
@testset "Regularization" begin
2+
for (constructor, copier) in ((interpolate, identity), (interpolate!, copy))
3+
isinplace = constructor == interpolate!
4+
f0(x) = sin((x-3)*2pi/9 - 1)
5+
f1(x) = 1.0 + 0.1*x + 0.01*x^2 + 0.001*x^3
6+
7+
xmax = 10
8+
xi = range(0, stop=xmax, length=11)
9+
xfine = range(0, stop=xmax, length=101)
10+
11+
A0 = f0.(xi)
12+
A1 = f1.(xi)
13+
14+
for (yi, f) in ((A0, f0), (A1, f1)), λ in (0.01, 100), k in (1, 2, 3, 4)
15+
it = @inferred(interpolate(yi, BSpline(Cubic(Line(OnCell()))), λ, k))
16+
itt = @inferred(scale(it, xi))
17+
18+
if f === f0
19+
if λ == 0.01
20+
# test that interpolation of the interpolating points is close to expected data
21+
for x in [2, 3, 6, 7]
22+
@test f(x) itt(x) atol=0.01
23+
end
24+
elseif λ == 100
25+
# test that interpolation of the interpolating points is far from expected data
26+
for x in [2, 3, 6, 7]
27+
@test !(isapprox(f(x), itt(x); atol=0.1))
28+
end
29+
end
30+
31+
elseif f === f1
32+
if k==4
33+
for x in xfine
34+
# test that interpolation is close to expected data when
35+
# the smoothing order is higher than the polynomial order
36+
@test f(x) itt(x) atol=0.1
37+
end
38+
end
39+
end
40+
end
41+
42+
# Check fallback to no regularization for λ=0
43+
itp1 = @inferred(constructor(copier(A0), BSpline(Cubic(Line(OnGrid()))), 0, 1))
44+
itp1p = @inferred(constructor(copier(A0), BSpline(Cubic(Line(OnGrid())))))
45+
@test itp1 == itp1p
46+
47+
48+
f2(x, y) = sin(x/10)*cos(y/6)
49+
xmax2, ymax2 = 30, 10
50+
A2 = Float64[f2(x, y) for x in 1:xmax2, y in 1:ymax2]
51+
52+
# Check fallback to no regularization for non Vectors
53+
itp2 = @inferred(constructor(copier(A2), BSpline(Cubic(Line(OnGrid()))), 1, 1))
54+
itp2p = @inferred(constructor(copier(A2), BSpline(Cubic(Line(OnGrid())))))
55+
@test itp2 == itp2p
56+
end
57+
end

test/b-splines/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
include("mixed.jl")
77
include("multivalued.jl")
88
include("non1.jl")
9+
include("regularization.jl")
910

1011
@test eltype(@inferred(interpolate(rand(Float16, 3, 3), BSpline(Linear())))) == Float16 # issue #308
1112
end

0 commit comments

Comments
 (0)