From 8d015e751037856d2e19f41b0910f210dfdb5091 Mon Sep 17 00:00:00 2001 From: getzze Date: Fri, 10 Jan 2020 17:53:59 +0000 Subject: [PATCH 1/4] Add regularization for BSpline --- src/b-splines/b-splines.jl | 32 ++++++++++++++++++++ src/b-splines/prefiltering.jl | 56 +++++++++++++++++++++++++++++++++++ 2 files changed, 88 insertions(+) diff --git a/src/b-splines/b-splines.jl b/src/b-splines/b-splines.jl index e066cce4..834a5d01 100644 --- a/src/b-splines/b-splines.jl +++ b/src/b-splines/b-splines.jl @@ -133,6 +133,11 @@ function interpolate(::Type{TWeights}, ::Type{TC}, A, it::IT) where {TWeights,TC BSplineInterpolation(TWeights, Apad, it, axes(A)) end +function interpolate(::Type{TWeights}, ::Type{TC}, A, it::IT, λ::Real, k::Int) where {TWeights,TC,IT<:DimSpec{BSpline}} + Apad = prefilter(TWeights, TC, A, it, λ, k) + BSplineInterpolation(TWeights, Apad, it, axes(A)) +end + """ itp = interpolate(A, interpmode, gridstyle) @@ -152,6 +157,33 @@ function interpolate(A::AbstractArray, it::IT) where {IT<:DimSpec{BSpline}} interpolate(tweight(A), tcoef(A), A, it) end +""" + itp = interpolate(A, interpmode, gridstyle, λ, k) + +Interpolate an array `A` in the mode determined by `interpmode` and `gridstyle` +with regularization following [1], of order `k` and constant `λ`. +`interpmode` may be one of + +- `BSpline(NoInterp())` +- `BSpline(Linear())` +- `BSpline(Quadratic(BC()))` (see [`BoundaryCondition`](@ref)) +- `BSpline(Cubic(BC()))` + +It may also be a tuple of such values, if you want to use different interpolation schemes along each axis. + +`gridstyle` should be one of `OnGrid()` or `OnCell()`. + +`k` corresponds to the derivative to penalize. It also corresponds to the order of the polynomial +least squares estimate that arises from the limit λ->∞. A value of 2 is the most common. + +`λ` is non-negative. If its value is zero, it falls back to non-regularized interpolation. + +[1] https://projecteuclid.org/euclid.ss/1038425655. +""" +function interpolate(A::AbstractArray, it::IT, λ::Real, k::Int) where {IT<:DimSpec{BSpline}} + interpolate(tweight(A), tcoef(A), A, it, λ, k) +end + # We can't just return a tuple-of-types due to julia #12500 tweight(A::AbstractArray) = Float64 tweight(A::AbstractArray{T}) where T<:AbstractFloat = T diff --git a/src/b-splines/prefiltering.jl b/src/b-splines/prefiltering.jl index cd6ad9e0..cecbc8d1 100644 --- a/src/b-splines/prefiltering.jl +++ b/src/b-splines/prefiltering.jl @@ -40,6 +40,15 @@ function prefilter( prefilter!(TWeights, ret, it) end +function prefilter( + ::Type{TWeights}, ::Type{TC}, A::AbstractArray, + it::Union{BSpline,Tuple{Vararg{Union{BSpline,NoInterp}}}}, + λ::Real, k::Int + ) where {TWeights,TC} + ret = copy_with_padding(TC, A, it) + prefilter!(TWeights, ret, it, λ, k) +end + function prefilter!( ::Type{TWeights}, ret::TCoefs, it::BSpline ) where {TWeights,TCoefs<:AbstractArray} @@ -53,6 +62,53 @@ function prefilter!( ret end +function diffop(::Type{TWeights}, n::Int, k::Int) where {TWeights} + D = spdiagm(0 => ones(TWeights,n)) + for i in 1:k + D = diff(D; dims=1) + end + ## TODO: Normalize by n? + D' * D +end +diffop(n::Int, k::Int) = diffop(Float64, n, k) +### TODO: add compiled constructor for most common operators of order k=1,2 +# +#function diffop(::Type{TWeights}, n::Int, k::Int) where {TWeights} +# D1 = spdiagm(0 => -ones(TWeights,n-1), 1 => ones(TWeights,n-1))[1:end-1, :] +# D = D1 +# for i in 1:k-1 +# D = diff(D; dims=1) +## D = D1[1+i:end, 1+i:end] * D +## lmul!(D1[1+i:end, 1+i:end], D) +# end +# +# D' * D +#end + +function prefilter!( + ::Type{TWeights}, ret::TCoefs, it::BSpline, λ::Real, k::Int + ) where {TWeights,TCoefs<:AbstractArray} + local buf, shape, retrs + sz = size(ret) + first = true + if ndims(ret) > 1 + @warn "Smooth BSpline only available for Vectors, fallback to non-smooth" + prefilter!(TWeights, ret, it) + end + if λ <= 0 + prefilter!(TWeights, ret, it) + end + M, b = prefiltering_system(TWeights, eltype(TCoefs), sz[1], degree(it)) + ### TEST REGULARIZATION + n = sz[1] + Q = Matrix(diffop(TWeights, n, k)) + K = cholesky(Matrix(M)' * Matrix(M) + λ * Q) + B = Matrix(M)' * popwrapper(ret) + ldiv!(popwrapper(ret), K, B) + + ret +end + function prefilter!( ::Type{TWeights}, ret::TCoefs, its::Tuple{Vararg{Union{BSpline,NoInterp}}} ) where {TWeights,TCoefs<:AbstractArray} From 4227f901a87a94e00848ba4d2f38365d6315621d Mon Sep 17 00:00:00 2001 From: getzze Date: Fri, 6 Mar 2020 16:37:50 +0000 Subject: [PATCH 2/4] add tests --- src/b-splines/b-splines.jl | 14 ++++++-- src/b-splines/prefiltering.jl | 25 +++++++++----- test/b-splines/regularization.jl | 57 ++++++++++++++++++++++++++++++++ test/b-splines/runtests.jl | 1 + 4 files changed, 87 insertions(+), 10 deletions(-) create mode 100644 test/b-splines/regularization.jl diff --git a/src/b-splines/b-splines.jl b/src/b-splines/b-splines.jl index 834a5d01..21211013 100644 --- a/src/b-splines/b-splines.jl +++ b/src/b-splines/b-splines.jl @@ -173,8 +173,8 @@ It may also be a tuple of such values, if you want to use different interpolatio `gridstyle` should be one of `OnGrid()` or `OnCell()`. -`k` corresponds to the derivative to penalize. It also corresponds to the order of the polynomial -least squares estimate that arises from the limit λ->∞. A value of 2 is the most common. +`k` corresponds to the derivative to penalize. In the limit λ->∞, the interpolation function is +a polynomial of order `k-1`. A value of 2 is the most common. `λ` is non-negative. If its value is zero, it falls back to non-regularized interpolation. @@ -205,6 +205,16 @@ function interpolate!(A::AbstractArray, it::IT) where {IT<:DimSpec{BSpline}} interpolate!(tweight(A), A, it) end +function interpolate!(::Type{TWeights}, A::AbstractArray, it::IT, λ::Real, k::Int) where {TWeights,IT<:DimSpec{BSpline}} + # Set the bounds of the interpolant inward, if necessary + axsA = axes(A) + axspad = padded_axes(axsA, it) + BSplineInterpolation(TWeights, prefilter!(TWeights, A, it, λ, k), it, fix_axis.(padinset.(axsA, axspad))) +end +function interpolate!(A::AbstractArray, it::IT, λ::Real, k::Int) where {IT<:DimSpec{BSpline}} + interpolate!(tweight(A), A, it, λ, k) +end + lut!(dl, d, du) = lu!(Tridiagonal(dl, d, du), Val(false)) include("constant.jl") diff --git a/src/b-splines/prefiltering.jl b/src/b-splines/prefiltering.jl index cecbc8d1..77262a41 100644 --- a/src/b-splines/prefiltering.jl +++ b/src/b-splines/prefiltering.jl @@ -88,16 +88,26 @@ diffop(n::Int, k::Int) = diffop(Float64, n, k) function prefilter!( ::Type{TWeights}, ret::TCoefs, it::BSpline, λ::Real, k::Int ) where {TWeights,TCoefs<:AbstractArray} - local buf, shape, retrs + sz = size(ret) - first = true - if ndims(ret) > 1 - @warn "Smooth BSpline only available for Vectors, fallback to non-smooth" - prefilter!(TWeights, ret, it) + + # Test if regularizing is allowed + fallback = if ndims(ret) > 1 + @warn "Smooth BSpline only available for Vectors, fallback to non-smoothed" + true + elseif λ < 0 + @warn "Smooth BSpline require a non-negative λ, fallback to non-smoothed: $(λ)" + true + elseif λ == 0 + true + else + false end - if λ <= 0 - prefilter!(TWeights, ret, it) + + if fallback + return prefilter!(TWeights, ret, it) end + M, b = prefiltering_system(TWeights, eltype(TCoefs), sz[1], degree(it)) ### TEST REGULARIZATION n = sz[1] @@ -114,7 +124,6 @@ function prefilter!( ) where {TWeights,TCoefs<:AbstractArray} local buf, shape, retrs sz = size(ret) - first = true for dim in 1:ndims(ret) it = iextract(its, dim) if it != NoInterp diff --git a/test/b-splines/regularization.jl b/test/b-splines/regularization.jl new file mode 100644 index 00000000..560c5896 --- /dev/null +++ b/test/b-splines/regularization.jl @@ -0,0 +1,57 @@ +@testset "Regularization" begin + for (constructor, copier) in ((interpolate, identity), (interpolate!, copy)) + isinplace = constructor == interpolate! + f0(x) = sin((x-3)*2pi/9 - 1) + f1(x) = 1.0 + 0.1*x + 0.01*x^2 + 0.001*x^3 + + xmax = 10 + xi = range(0, xmax, length=11) + xfine = range(0, xmax, length=101) + + A0 = f0.(xi) + A1 = f1.(xi) + + for (yi, f) in ((A0, f0), (A1, f1)), λ in (0.01, 100), k in (1, 2, 3, 4) + it = @inferred(interpolate(yi, BSpline(Cubic(Line(OnCell()))), λ, k)) + itt = @inferred(scale(it, xi)) + + if f === f0 + if λ == 0.01 + # test that interpolation of the interpolating points is close to expected data + for x in [2, 3, 6, 7] + @test f(x) ≈ itt(x) atol=0.01 + end + elseif λ == 100 + # test that interpolation of the interpolating points is far from expected data + for x in [2, 3, 6, 7] + @test !(isapprox(f(x), itt(x); atol=0.1)) + end + end + + elseif f === f1 + if k==4 + for x in xfine + # test that interpolation is close to expected data when + # the smoothing order is higher than the polynomial order + @test f(x) ≈ itt(x) atol=0.1 + end + end + end + end + + # Check fallback to no regularization for λ=0 + itp1 = @inferred(constructor(copier(A0), BSpline(Cubic(Line(OnGrid()))), 0, 1)) + itp1p = @inferred(constructor(copier(A0), BSpline(Cubic(Line(OnGrid()))))) + @test itp1 == itp1p + + + f2(x, y) = sin(x/10)*cos(y/6) + xmax2, ymax2 = 30, 10 + A2 = Float64[f2(x, y) for x in 1:xmax2, y in 1:ymax2] + + # Check fallback to no regularization for non Vectors + itp2 = @inferred(constructor(copier(A2), BSpline(Cubic(Line(OnGrid()))), 1, 1)) + itp2p = @inferred(constructor(copier(A2), BSpline(Cubic(Line(OnGrid()))))) + @test itp2 == itp2p + end +end diff --git a/test/b-splines/runtests.jl b/test/b-splines/runtests.jl index 13316d7a..f7c8a3ff 100644 --- a/test/b-splines/runtests.jl +++ b/test/b-splines/runtests.jl @@ -6,6 +6,7 @@ include("mixed.jl") include("multivalued.jl") include("non1.jl") + include("regularization.jl") @test eltype(@inferred(interpolate(rand(Float16, 3, 3), BSpline(Linear())))) == Float16 # issue #308 end From 24d44683a20ed3432cba9baf300d55e1adda1a4f Mon Sep 17 00:00:00 2001 From: getzze Date: Mon, 23 Mar 2020 22:18:39 +0000 Subject: [PATCH 3/4] Fix, compat julia 1.0 --- test/b-splines/regularization.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/b-splines/regularization.jl b/test/b-splines/regularization.jl index 560c5896..d8fe4e17 100644 --- a/test/b-splines/regularization.jl +++ b/test/b-splines/regularization.jl @@ -5,8 +5,8 @@ f1(x) = 1.0 + 0.1*x + 0.01*x^2 + 0.001*x^3 xmax = 10 - xi = range(0, xmax, length=11) - xfine = range(0, xmax, length=101) + xi = range(0, stop=xmax, length=11) + xfine = range(0, stop=xmax, length=101) A0 = f0.(xi) A1 = f1.(xi) From 37b0ad90b48e0c421f87f57f662855ae7604510b Mon Sep 17 00:00:00 2001 From: getzze Date: Mon, 19 Oct 2020 12:04:55 +0100 Subject: [PATCH 4/4] do not convert Q to dense Matrix --- src/b-splines/prefiltering.jl | 26 +++++++------------------- 1 file changed, 7 insertions(+), 19 deletions(-) diff --git a/src/b-splines/prefiltering.jl b/src/b-splines/prefiltering.jl index 77262a41..cf0c2927 100644 --- a/src/b-splines/prefiltering.jl +++ b/src/b-splines/prefiltering.jl @@ -52,7 +52,6 @@ end function prefilter!( ::Type{TWeights}, ret::TCoefs, it::BSpline ) where {TWeights,TCoefs<:AbstractArray} - local buf, shape, retrs sz = size(ret) first = true for dim in 1:ndims(ret) @@ -72,18 +71,7 @@ function diffop(::Type{TWeights}, n::Int, k::Int) where {TWeights} end diffop(n::Int, k::Int) = diffop(Float64, n, k) ### TODO: add compiled constructor for most common operators of order k=1,2 -# -#function diffop(::Type{TWeights}, n::Int, k::Int) where {TWeights} -# D1 = spdiagm(0 => -ones(TWeights,n-1), 1 => ones(TWeights,n-1))[1:end-1, :] -# D = D1 -# for i in 1:k-1 -# D = diff(D; dims=1) -## D = D1[1+i:end, 1+i:end] * D -## lmul!(D1[1+i:end, 1+i:end], D) -# end -# -# D' * D -#end + function prefilter!( ::Type{TWeights}, ret::TCoefs, it::BSpline, λ::Real, k::Int @@ -109,11 +97,12 @@ function prefilter!( end M, b = prefiltering_system(TWeights, eltype(TCoefs), sz[1], degree(it)) - ### TEST REGULARIZATION - n = sz[1] - Q = Matrix(diffop(TWeights, n, k)) - K = cholesky(Matrix(M)' * Matrix(M) + λ * Q) - B = Matrix(M)' * popwrapper(ret) + ## Solve with regularization + # Convert to dense Matrix because `*(Woodbury{Adjoint}, Woodbury)` is not defined + Mt = Matrix(M)' + Q = diffop(TWeights, sz[1], k) + K = cholesky(Mt * Matrix(M) + λ * Q) + B = Mt * popwrapper(ret) ldiv!(popwrapper(ret), K, B) ret @@ -122,7 +111,6 @@ end function prefilter!( ::Type{TWeights}, ret::TCoefs, its::Tuple{Vararg{Union{BSpline,NoInterp}}} ) where {TWeights,TCoefs<:AbstractArray} - local buf, shape, retrs sz = size(ret) for dim in 1:ndims(ret) it = iextract(its, dim)