Skip to content

Implemented rfft!, irfft! and brfft! #222

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/FFTW.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ end

include("fft.jl")
include("dct.jl")
include("rfft!.jl")

include("precompile.jl")
_precompile_()
Expand Down
266 changes: 266 additions & 0 deletions src/rfft!.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
import Base: IndexStyle, getindex, setindex!, eltype, \, similar, copy, real, read!

export PaddedRFFTArray, plan_rfft!, rfft!, plan_irfft!, plan_brfft!, brfft!, irfft!


# This struct reinterprets the `data` array to a Complex or Float array, depending on `eltype(data)`
# It is used internally with the PaddedRFFTArray in place of `Base.ReinterpretArray`
# ReinterpretArray has some performance issues when reinterprreting a Complex array to Real
struct ComplexOrRealReinterpretArray{T<:fftwNumber,N,A<:DenseArray{<:fftwNumber,N},B<:Ptr} <: DenseArray{T,N}
data::A # Either a real or complex array
_unsafe_pointer::B # Pointer to the `data` array, but converted to a different type representation.

function ComplexOrRealReinterpretArray(rarray::DenseArray{T,N}) where {T<:fftwReal,N}
ptr = unsafe_convert(Ptr{Complex{T}}, pointer(rarray))
return new{Complex{T},N,typeof(rarray),typeof(ptr)}(rarray,ptr)
end

function ComplexOrRealReinterpretArray(carray::DenseArray{T,N}) where {T<:fftwComplex,N}
FT = T === ComplexF64 ? Float64 : Float32
ptr = unsafe_convert(Ptr{FT}, pointer(carray))
return new{FT,N,typeof(carray),typeof(ptr)}(carray,ptr)
end
end

const RealReinterpretArray{N} = ComplexOrRealReinterpretArray{<:fftwReal,N,<:DenseArray{<:fftwComplex,N}}
const ComplexReinterpretArray{N} = ComplexOrRealReinterpretArray{<:fftwComplex,N,<:DenseArray{<:fftwReal,N}}

@inline size_convertion(::RealReinterpretArray,i::Integer) = 2i
@inline size_convertion(::ComplexReinterpretArray,i::Integer) = i÷2

IndexStyle(::Type{T}) where {T<:ComplexOrRealReinterpretArray} = IndexLinear()

Base.size(a::ComplexOrRealReinterpretArray) =
ntuple(i->(i == 1 ? size_convertion(a,size(a.data)[i]) : size(a.data)[i]),Val(ndims(a.data)))

@inline function getindex(a::ComplexOrRealReinterpretArray,i::Integer)
data = a.data
@boundscheck checkbounds(a,i)
GC.@preserve data r = unsafe_load(a._unsafe_pointer, i)
return r
end

@inline function setindex!(a::ComplexOrRealReinterpretArray,v,i::Integer)
data = a.data
@boundscheck checkbounds(a,i)
GC.@preserve data unsafe_store!(a._unsafe_pointer,v, i)
return a
end

Base.unsafe_convert(p::Type{Ptr{T}}, a::ComplexOrRealReinterpretArray{T,N}) where {T,N} = Base.unsafe_convert(p,a.data)

Base.elsize(::Type{<:ComplexOrRealReinterpretArray{T,N}}) where {T,N} = sizeof(T)

complex_or_real_reinterpret(a::AbstractArray) = ComplexOrRealReinterpretArray(a)
complex_or_real_reinterpret(a::ComplexOrRealReinterpretArray) = a.data

# At the time this code was written the new `ReinterpretArray` in Base had some performace issues.
# Those issues were bypassed with the usage of our simplified version of ReinterpretArray above.
# Hopefully, once the performance issues with ReinterpretArray
# are solved we can just use Base.ReinterpretArray directly.

struct PaddedRFFTArray{T<:fftwReal,N,R,C,L,Nm1} <: DenseArray{Complex{T},N}
data::R
r::SubArray{T,N,R,Tuple{Base.OneTo{Int},Vararg{Base.Slice{Base.OneTo{Int}},Nm1}},L} # Real view skipping padding
c::C

function PaddedRFFTArray{T,N}(rr::DenseArray{T,N},nx::Int) where {T<:fftwReal,N}
fsize = size(rr)[1]
iseven(fsize) || throw(
ArgumentError("First dimension of allocated array must have even number of elements"))
(nx == fsize-2 || nx == fsize-1) || throw(
ArgumentError("Number of elements on the first dimension of array must be either 1 or 2 less than the number of elements on the first dimension of the allocated array"))
c = complex_or_real_reinterpret(rr)
r = view(rr, Base.OneTo(nx), ntuple(i->Colon(),Val(N-1))...)
return new{T, N, typeof(rr), typeof(c), N===1, N-1}(rr,r,c)
end # function

function PaddedRFFTArray{T,N}(c::DenseArray{Complex{T},N},nx::Int) where {T<:fftwReal,N}
rr = complex_or_real_reinterpret(c)
fsize = size(rr)[1]
(nx == fsize-2 || nx == fsize-1) || throw(
ArgumentError("Number of elements on the first dimension of array must be either 1 or 2 less than the number of elements on the first dimension of the allocated array"))
r = view(rr, Base.OneTo(nx), ntuple(i->Colon(),Val(N-1))...)
return new{T, N, typeof(rr), typeof(c), N===1, N-1}(rr,r,c)
end # function

end # struct

PaddedRFFTArray(a::DenseArray{<:Union{T,Complex{T}},N},nx::Int) where {T<:fftwReal,N} =
PaddedRFFTArray{T,N}(a,nx)

function PaddedRFFTArray{T}(ndims::Vararg{Integer,N}) where {T,N}
fsize = (ndims[1]÷2 + 1)*2
a = zeros(T,(fsize, ndims[2:end]...))
PaddedRFFTArray{T,N}(a, ndims[1])
end

PaddedRFFTArray{T}(ndims::NTuple{N,Integer}) where {T,N} =
PaddedRFFTArray{T}(ndims...)

PaddedRFFTArray(ndims::Vararg{Integer,N}) where N =
PaddedRFFTArray{Float64}(ndims...)

PaddedRFFTArray(ndims::NTuple{N,Integer}) where N =
PaddedRFFTArray{Float64}(ndims...)

function PaddedRFFTArray(a::AbstractArray{T,N}) where {T<:fftwReal,N}
t = PaddedRFFTArray{T}(size(a))
@inbounds copyto!(t.r, a)
return t
end

copy(S::PaddedRFFTArray) = PaddedRFFTArray(copy(S.data),size(S.r,1))

similar(f::PaddedRFFTArray,::Type{T},dims::Tuple{Vararg{Int,N}}) where {T, N} =
PaddedRFFTArray{T}(dims)

similar(f::PaddedRFFTArray{T,N,L},dims::NTuple{N2,Int}) where {T,N,L,N2} =
PaddedRFFTArray{T}(dims)

similar(f::PaddedRFFTArray,::Type{T}) where {T} =
PaddedRFFTArray{T}(size(f.r))

similar(f::PaddedRFFTArray{T,N}) where {T,N} =
PaddedRFFTArray{T,N}(similar(f.data), size(f.r,1))

size(S::PaddedRFFTArray) =
size(S.c)

IndexStyle(::Type{T}) where {T<:PaddedRFFTArray} =
IndexLinear()

Base.@propagate_inbounds getindex(A::PaddedRFFTArray,i::Integer) =
getindex(A.c,i)

Base.@propagate_inbounds setindex!(A::PaddedRFFTArray,x, i::Integer) =
setindex!(A.c,x,i)

Base.unsafe_convert(p::Type{Ptr{Complex{T}}}, a::PaddedRFFTArray{T,N}) where {T,N} = Base.unsafe_convert(p,a.c)

Base.elsize(::Type{<:PaddedRFFTArray{T,N}}) where {T,N} = sizeof(Complex{T})


function PaddedRFFTArray(stream::IO, dims)
field = PaddedRFFTArray(dims)
return read!(stream,field)
end

function PaddedRFFTArray{T}(stream::IO, dims) where T
field = PaddedRFFTArray{T}(dims)
return read!(stream,field)
end

function read!(file::AbstractString, field::PaddedRFFTArray)
open(file) do io
return read!(io,field)
end
end

# Read a binary file of an unpaded array directly to a PaddedRFFT array, without the need
# of the creation of a intermediary Array. If the data is already padded then the user
# should just use PaddedRFFTArray{T}(read("file",unpaddeddim),nx)
function read!(stream::IO, field::PaddedRFFTArray{T,N,L}) where {T,N,L}
rr = field.data
dims = size(field.r)
nx = dims[1]
nb = sizeof(T)*nx
npencils = prod(dims)÷nx
npad = iseven(nx) ? 2 : 1
for i=0:(npencils-1)
unsafe_read(stream,Ref(rr,Int((nx+npad)*i+1)),nb)
end
return field
end


###########################################################################################
# Foward plans

function plan_rfft!(X::PaddedRFFTArray{T,N}, region;
flags::Integer=ESTIMATE,
timelimit::Real=NO_TIMELIMIT) where {T<:fftwReal,N}

(1 in region) || throw(ArgumentError("The first dimension must always be transformed"))
return rFFTWPlan{T,FORWARD,true,N}(X.r, X.c, region, flags, timelimit)
end

plan_rfft!(f::PaddedRFFTArray;kws...) = plan_rfft!(f, 1:ndims(f); kws...)

*(p::rFFTWPlan{T,FORWARD,true,N},f::PaddedRFFTArray{T,N}) where {T<:fftwReal,N} =
(mul!(f.c, p, f.r); f)

rfft!(f::PaddedRFFTArray, region=1:ndims(f)) = plan_rfft!(f, region) * f

function rfft!(r::SubArray{<:fftwReal}, region=1:ndims(r))
f = PaddedRFFTArray(parent(r),size(r,1))
plan_rfft!(f, region) * f
end

function rfft!(r::DenseArray{<:fftwReal}, region=1:ndims(r))
f = PaddedRFFTArray(r)
plan_rfft!(f, region) * f
end

function \(p::rFFTWPlan{T,FORWARD,true,N},f::PaddedRFFTArray{T,N}) where {T<:fftwReal,N}
isdefined(p,:pinv) || (p.pinv = plan_irfft!(f,p.region))
return p.pinv * f
end


##########################################################################################
# Inverse plans

function plan_brfft!(X::PaddedRFFTArray{T,N}, region;
flags::Integer=ESTIMATE,
timelimit::Real=NO_TIMELIMIT) where {T<:fftwReal,N}
(1 in region) || throw(ArgumentError("The first dimension must always be transformed"))
return rFFTWPlan{Complex{T},BACKWARD,true,N}(X.c, X.r, region, flags,timelimit)
end

plan_brfft!(f::PaddedRFFTArray;kws...) = plan_brfft!(f,1:ndims(f);kws...)

*(p::rFFTWPlan{Complex{T},BACKWARD,true,N},f::PaddedRFFTArray{T,N}) where {T<:fftwReal,N} =
(mul!(f.r, p, f.c); f.r)

brfft!(f::PaddedRFFTArray, region=1:ndims(f)) = plan_brfft!(f, region) * f

function brfft!(f::PaddedRFFTArray, i::Integer)
if i == size(f.r,1) # Assume `i` is the same as `d` in the brfft!(c::DenseArray{<:fftComplex}, d::Integer, region) defined below
return brfft!(f,1:ndims(f))
else # Assume `i` is specifying the region. `plan_brfft!` will throw an error if i != 1
return brfft!(f,(i,))
end
end

function brfft!(c::DenseArray{<:fftwComplex}, d::Integer, region=1:ndims(c))
f = PaddedRFFTArray(c,d)
plan_brfft!(f, region) * f
end

function plan_irfft!(x::PaddedRFFTArray{T,N}, region; kws...) where {T,N}
ScaledPlan(plan_brfft!(x, region; kws...),normalization(T, size(x.r), region))
end

plan_irfft!(f::PaddedRFFTArray;kws...) = plan_irfft!(f,1:ndims(f);kws...)

*(p::ScaledPlan{Complex{T},rFFTWPlan{Complex{T},BACKWARD,true,N}},f::PaddedRFFTArray{T,N}) where {T,N} = begin
p.p * f
rmul!(f.data, p.scale)
f.r
end

irfft!(f::PaddedRFFTArray, region=1:ndims(f)) = plan_irfft!(f,region) * f

function irfft!(f::PaddedRFFTArray, i::Integer)
if i == size(f.r,1) # Assume `i` is the same as `d` in the irfft!(c::DenseArray{<:fftComplex}, d::Integer, region) defined below
return irfft!(f,1:ndims(f))
else # Assume `i` is specifying the region. `plan_irfft!` will throw an error if i != 1
return irfft!(f,(i,))
end
end

function irfft!(c::DenseArray{<:fftwComplex}, d::Integer, region=1:ndims(c))
f = PaddedRFFTArray(c,d)
plan_irfft!(f, region) * f
end
74 changes: 74 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -528,3 +528,77 @@ end
@test occursin("dft-thr", string(p2))
end
end

let a = rand(Float64,(8,4,4)), b = PaddedRFFTArray(a), c = copy(b)

@testset "PaddedRFFTArray creation" begin
@test a == b.r
@test c == b
@test c.r == b.r
@test typeof(similar(b)) === typeof(b)
@test size(similar(b,Float32)) === size(b)
@test size(similar(b,Float32).r) === size(b.r)
@test size(similar(b,(4,4,4)).r) === (4,4,4)
@test size(similar(b,Float32,(4,4,4)).r) === (4,4,4)
end

@testset "rfft! and irfft!" begin
@test rfft(a) ≈ rfft!(b)
@test a ≈ irfft!(b)
@test rfft(a,1:2) ≈ rfft!(b,1:2)
@test a ≈ irfft!(b,1:2)
@test rfft(a,(1,3)) ≈ rfft!(b,(1,3))
@test a ≈ irfft!(b,(1,3))

p = plan_rfft!(c)
@test p*c ≈ rfft!(b)
@test p\c ≈ irfft!(b)

aa = rand(Float64,(9,4,4))
bb = PaddedRFFTArray(aa)
@test aa == bb.r
@test rfft(aa) ≈ rfft!(bb)
@test aa ≈ irfft!(bb)
@test rfft(aa,1:2) ≈ rfft!(bb,1:2)
@test aa ≈ irfft!(bb,1:2)
@test rfft(aa,(1,3)) ≈ rfft!(bb,(1,3))
@test aa ≈ irfft!(bb,(1,3))
end

@testset "Read binary file to PaddedRFFTArray" begin
for s in ((8,4,4),(9,4,4),(8,),(9,))
aa = rand(Float64,s)
f = IOBuffer()
write(f,aa)
@test aa == (PaddedRFFTArray(seekstart(f),s)).r
aa = rand(Float32,s)
f = IOBuffer()
write(f,aa)
@test aa == PaddedRFFTArray{Float32}(seekstart(f),s).r
end
end

@testset "brfft!" begin
a = rand(Float64,(4,4))
b = PaddedRFFTArray(a)
rfft!(b)
@test (brfft!(b) ./ 16) ≈ a
end

@testset "FFTW MEASURE flag" begin
c = similar(b)
p = plan_rfft!(c,flags=FFTW.MEASURE)
p.pinv = plan_irfft!(c,flags=FFTW.MEASURE)
c .= b
@test c == b
@test p*c ≈ rfft!(b)
@test p\c ≈ irfft!(b)
end

@testset "irfft! and brfft! of complex Array and rfft! of SubArray" begin
r = rand(8,6)
@test brfft!(rfft!(irfft!(rfft(r),8)),8)./48 ≈ r
r2 = rand(Float32,9,3,2)
@test brfft!(rfft!(irfft!(rfft(r2,(1,3)),9,(1,3))),9)./54 ≈ r2
end
end #let block