Skip to content

Add regularization for BSpline #348

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Oct 19, 2020
42 changes: 42 additions & 0 deletions src/b-splines/b-splines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,11 @@ function interpolate(::Type{TWeights}, ::Type{TC}, A, it::IT) where {TWeights,TC
BSplineInterpolation(TWeights, Apad, it, axes(A))
end

function interpolate(::Type{TWeights}, ::Type{TC}, A, it::IT, λ::Real, k::Int) where {TWeights,TC,IT<:DimSpec{BSpline}}
Apad = prefilter(TWeights, TC, A, it, λ, k)
BSplineInterpolation(TWeights, Apad, it, axes(A))
end

"""
itp = interpolate(A, interpmode)

Expand All @@ -179,6 +184,33 @@ function interpolate(A::AbstractArray, it::IT) where {IT<:DimSpec{BSpline}}
interpolate(tweight(A), tcoef(A), A, it)
end

"""
itp = interpolate(A, interpmode, gridstyle, λ, k)

Interpolate an array `A` in the mode determined by `interpmode` and `gridstyle`
with regularization following [1], of order `k` and constant `λ`.
`interpmode` may be one of

- `BSpline(NoInterp())`
- `BSpline(Linear())`
- `BSpline(Quadratic(BC()))` (see [`BoundaryCondition`](@ref))
- `BSpline(Cubic(BC()))`

It may also be a tuple of such values, if you want to use different interpolation schemes along each axis.

`gridstyle` should be one of `OnGrid()` or `OnCell()`.

`k` corresponds to the derivative to penalize. In the limit λ->∞, the interpolation function is
a polynomial of order `k-1`. A value of 2 is the most common.

`λ` is non-negative. If its value is zero, it falls back to non-regularized interpolation.

[1] https://projecteuclid.org/euclid.ss/1038425655.
"""
function interpolate(A::AbstractArray, it::IT, λ::Real, k::Int) where {IT<:DimSpec{BSpline}}
interpolate(tweight(A), tcoef(A), A, it, λ, k)
end

# We can't just return a tuple-of-types due to julia #12500
tweight(A::AbstractArray) = Float64
tweight(A::AbstractArray{T}) where T<:AbstractFloat = T
Expand All @@ -200,6 +232,16 @@ function interpolate!(A::AbstractArray, it::IT) where {IT<:DimSpec{BSpline}}
interpolate!(tweight(A), A, it)
end

function interpolate!(::Type{TWeights}, A::AbstractArray, it::IT, λ::Real, k::Int) where {TWeights,IT<:DimSpec{BSpline}}
# Set the bounds of the interpolant inward, if necessary
axsA = axes(A)
axspad = padded_axes(axsA, it)
BSplineInterpolation(TWeights, prefilter!(TWeights, A, it, λ, k), it, fix_axis.(padinset.(axsA, axspad)))
end
function interpolate!(A::AbstractArray, it::IT, λ::Real, k::Int) where {IT<:DimSpec{BSpline}}
interpolate!(tweight(A), A, it, λ, k)
end

lut!(dl, d, du) = lu!(Tridiagonal(dl, d, du), Val(false))

include("constant.jl")
Expand Down
59 changes: 56 additions & 3 deletions src/b-splines/prefiltering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,18 @@ function prefilter(
prefilter!(TWeights, ret, it)
end

function prefilter(
::Type{TWeights}, ::Type{TC}, A::AbstractArray,
it::Union{BSpline,Tuple{Vararg{Union{BSpline,NoInterp}}}},
λ::Real, k::Int
) where {TWeights,TC}
ret = copy_with_padding(TC, A, it)
prefilter!(TWeights, ret, it, λ, k)
end

function prefilter!(
::Type{TWeights}, ret::TCoefs, it::BSpline
) where {TWeights,TCoefs<:AbstractArray}
local buf, shape, retrs
sz = size(ret)
first = true
for dim in 1:ndims(ret)
Expand All @@ -53,12 +61,57 @@ function prefilter!(
ret
end

function diffop(::Type{TWeights}, n::Int, k::Int) where {TWeights}
D = spdiagm(0 => ones(TWeights,n))
for i in 1:k
D = diff(D; dims=1)
end
## TODO: Normalize by n?
D' * D
end
diffop(n::Int, k::Int) = diffop(Float64, n, k)
### TODO: add compiled constructor for most common operators of order k=1,2


function prefilter!(
::Type{TWeights}, ret::TCoefs, it::BSpline, λ::Real, k::Int
) where {TWeights,TCoefs<:AbstractArray}

sz = size(ret)

# Test if regularizing is allowed
fallback = if ndims(ret) > 1
@warn "Smooth BSpline only available for Vectors, fallback to non-smoothed"
true
elseif λ < 0
@warn "Smooth BSpline require a non-negative λ, fallback to non-smoothed: $(λ)"
true
elseif λ == 0
true
else
false
end

if fallback
return prefilter!(TWeights, ret, it)
end

M, b = prefiltering_system(TWeights, eltype(TCoefs), sz[1], degree(it))
## Solve with regularization
# Convert to dense Matrix because `*(Woodbury{Adjoint}, Woodbury)` is not defined
Mt = Matrix(M)'
Q = diffop(TWeights, sz[1], k)
K = cholesky(Mt * Matrix(M) + λ * Q)
B = Mt * popwrapper(ret)
ldiv!(popwrapper(ret), K, B)

ret
end

function prefilter!(
::Type{TWeights}, ret::TCoefs, its::Tuple{Vararg{Union{BSpline,NoInterp}}}
) where {TWeights,TCoefs<:AbstractArray}
local buf, shape, retrs
sz = size(ret)
first = true
for dim in 1:ndims(ret)
it = iextract(its, dim)
if it != NoInterp
Expand Down
57 changes: 57 additions & 0 deletions test/b-splines/regularization.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
@testset "Regularization" begin
for (constructor, copier) in ((interpolate, identity), (interpolate!, copy))
isinplace = constructor == interpolate!
f0(x) = sin((x-3)*2pi/9 - 1)
f1(x) = 1.0 + 0.1*x + 0.01*x^2 + 0.001*x^3

xmax = 10
xi = range(0, stop=xmax, length=11)
xfine = range(0, stop=xmax, length=101)

A0 = f0.(xi)
A1 = f1.(xi)

for (yi, f) in ((A0, f0), (A1, f1)), λ in (0.01, 100), k in (1, 2, 3, 4)
it = @inferred(interpolate(yi, BSpline(Cubic(Line(OnCell()))), λ, k))
itt = @inferred(scale(it, xi))

if f === f0
if λ == 0.01
# test that interpolation of the interpolating points is close to expected data
for x in [2, 3, 6, 7]
@test f(x) ≈ itt(x) atol=0.01
end
elseif λ == 100
# test that interpolation of the interpolating points is far from expected data
for x in [2, 3, 6, 7]
@test !(isapprox(f(x), itt(x); atol=0.1))
end
end

elseif f === f1
if k==4
for x in xfine
# test that interpolation is close to expected data when
# the smoothing order is higher than the polynomial order
@test f(x) ≈ itt(x) atol=0.1
end
end
end
end

# Check fallback to no regularization for λ=0
itp1 = @inferred(constructor(copier(A0), BSpline(Cubic(Line(OnGrid()))), 0, 1))
itp1p = @inferred(constructor(copier(A0), BSpline(Cubic(Line(OnGrid())))))
@test itp1 == itp1p


f2(x, y) = sin(x/10)*cos(y/6)
xmax2, ymax2 = 30, 10
A2 = Float64[f2(x, y) for x in 1:xmax2, y in 1:ymax2]

# Check fallback to no regularization for non Vectors
itp2 = @inferred(constructor(copier(A2), BSpline(Cubic(Line(OnGrid()))), 1, 1))
itp2p = @inferred(constructor(copier(A2), BSpline(Cubic(Line(OnGrid())))))
@test itp2 == itp2p
end
end
1 change: 1 addition & 0 deletions test/b-splines/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
include("mixed.jl")
include("multivalued.jl")
include("non1.jl")
include("regularization.jl")

@test eltype(@inferred(interpolate(rand(Float16, 3, 3), BSpline(Linear())))) == Float16 # issue #308
end