Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit f374e0a

Browse files
committed
add WaveletTransform
1 parent 3f0864f commit f374e0a

File tree

5 files changed

+247
-171
lines changed

5 files changed

+247
-171
lines changed

src/Transform/wavelet_transform.jl

+28-125
Original file line numberDiff line numberDiff line change
@@ -1,130 +1,33 @@
1-
export
2-
SparseKernel,
3-
SparseKernel1D,
4-
SparseKernel2D,
5-
SparseKernel3D
6-
7-
8-
struct SparseKernel{N,T,S}
9-
conv_blk::T
10-
out_weight::S
11-
end
12-
13-
function SparseKernel(filter::NTuple{N,T}, ch::Pair{S, S}; init=Flux.glorot_uniform) where {N,T,S}
14-
input_dim, emb_dim = ch
15-
conv = Conv(filter, input_dim=>emb_dim, relu; stride=1, pad=1, init=init)
16-
W_out = Dense(emb_dim, input_dim; init=init)
17-
return SparseKernel{N,typeof(conv),typeof(W_out)}(conv, W_out)
18-
end
19-
20-
function SparseKernel1D(k::Int, α, c::Int=1; init=Flux.glorot_uniform)
21-
input_dim = c*k
22-
emb_dim = 128
23-
return SparseKernel((3, ), input_dim=>emb_dim; init=init)
24-
end
25-
26-
function SparseKernel2D(k::Int, α, c::Int=1; init=Flux.glorot_uniform)
27-
input_dim = c*k^2
28-
emb_dim = α*k^2
29-
return SparseKernel((3, 3), input_dim=>emb_dim; init=init)
30-
end
31-
32-
function SparseKernel3D(k::Int, α, c::Int=1; init=Flux.glorot_uniform)
33-
input_dim = c*k^2
34-
emb_dim = α*k^2
35-
conv = Conv((3, 3, 3), emb_dim=>emb_dim, relu; stride=1, pad=1, init=init)
36-
W_out = Dense(emb_dim, input_dim; init=init)
37-
return SparseKernel{3,typeof(conv),typeof(W_out)}(conv, W_out)
38-
end
39-
40-
Flux.@functor SparseKernel
41-
42-
function (l::SparseKernel)(X::AbstractArray)
43-
bch_sz, _, dims_r... = reverse(size(X))
44-
dims = reverse(dims_r)
45-
46-
X_ = l.conv_blk(X) # (dims..., emb_dims, B)
47-
X_ = reshape(X_, prod(dims), :, bch_sz) # (prod(dims), emb_dims, B)
48-
Y = l.out_weight(batched_transpose(X_)) # (in_dims, prod(dims), B)
49-
Y = reshape(batched_transpose(Y), dims..., :, bch_sz) # (dims..., in_dims, B)
50-
return collect(Y)
51-
end
52-
53-
54-
struct MWT_CZ1d{T,S,R,Q,P}
55-
k::Int
56-
L::Int
57-
A::T
58-
B::S
59-
C::R
60-
T0::Q
61-
ec_s::P
62-
ec_d::P
63-
rc_e::P
64-
rc_o::P
65-
end
66-
67-
function MWT_CZ1d(k::Int=3, α::Int=5, L::Int=0, c::Int=1; base::Symbol=:legendre, init=Flux.glorot_uniform)
68-
H0, H1, G0, G1, Φ0, Φ1 = get_filter(base, k)
69-
H0r = zero_out!(H0 * Φ0)
70-
G0r = zero_out!(G0 * Φ0)
71-
H1r = zero_out!(H1 * Φ1)
72-
G1r = zero_out!(G1 * Φ1)
73-
74-
dim = c*k
75-
A = SpectralConv(dim=>dim, (α,); init=init)
76-
B = SpectralConv(dim=>dim, (α,); init=init)
77-
C = SpectralConv(dim=>dim, (α,); init=init)
78-
T0 = Dense(k, k)
79-
80-
ec_s = vcat(H0', H1')
81-
ec_d = vcat(G0', G1')
82-
rc_e = vcat(H0r, G0r)
83-
rc_o = vcat(H1r, G1r)
84-
return MWT_CZ1d(k, L, A, B, C, T0, ec_s, ec_d, rc_e, rc_o)
85-
end
86-
87-
function wavelet_transform(l::MWT_CZ1d, X::AbstractArray{T,4}) where {T}
88-
N = size(X, 3)
89-
Xa = vcat(view(X, :, :, 1:2:N, :), view(X, :, :, 2:2:N, :))
90-
d = NNlib.batched_mul(Xa, l.ec_d)
91-
s = NNlib.batched_mul(Xa, l.ec_s)
1+
export WaveletTransform
2+
3+
struct WaveletTransform{N, S}<:AbstractTransform
4+
ec_d
5+
ec_s
6+
modes::NTuple{N, S} # N == ndims(x)
7+
end
8+
9+
Base.ndims(::WaveletTransform{N}) where {N} = N
10+
11+
function transform(wt::WaveletTransform, 𝐱::AbstractArray)
12+
N = size(X, ndims(wt)-1)
13+
# 1d
14+
Xa = vcat(view(𝐱, :, :, 1:2:N, :), view(𝐱, :, :, 2:2:N, :))
15+
# 2d
16+
# Xa = vcat(
17+
# view(𝐱, :, :, 1:2:N, 1:2:N, :),
18+
# view(𝐱, :, :, 1:2:N, 2:2:N, :),
19+
# view(𝐱, :, :, 2:2:N, 1:2:N, :),
20+
# view(𝐱, :, :, 2:2:N, 2:2:N, :),
21+
# )
22+
d = NNlib.batched_mul(Xa, wt.ec_d)
23+
s = NNlib.batched_mul(Xa, wt.ec_s)
9224
return d, s
9325
end
9426

95-
function even_odd(l::MWT_CZ1d, X::AbstractArray{T,4}) where {T}
96-
bch_sz, N, dims_r... = reverse(size(X))
97-
dims = reverse(dims_r)
98-
@assert dims[1] == 2*l.k
99-
Xₑ = NNlib.batched_mul(X, l.rc_e)
100-
Xₒ = NNlib.batched_mul(X, l.rc_o)
101-
# x = torch.zeros(B, N*2, c, self.k,
102-
# device = x.device)
103-
# x[..., ::2, :, :] = x_e
104-
# x[..., 1::2, :, :] = x_o
105-
return X
27+
function inverse(wt::WaveletTransform, 𝐱_fwt::AbstractArray)
28+
10629
end
10730

108-
function (l::MWT_CZ1d)(X::T) where {T<:AbstractArray}
109-
bch_sz, N, dims_r... = reverse(size(X))
110-
ns = floor(log2(N))
111-
stop = ns - l.L
112-
113-
# decompose
114-
Ud = T[]
115-
Us = T[]
116-
for i in 1:stop
117-
d, X = wavelet_transform(l, X)
118-
push!(Ud, l.A(d)+l.B(d))
119-
push!(Us, l.C(d))
120-
end
121-
X = l.T0(X)
122-
123-
# reconstruct
124-
for i in stop:-1:1
125-
X += Us[i]
126-
X = vcat(X, Ud[i]) # x = torch.cat((x, Ud[i]), -1)
127-
X = even_odd(l, X)
128-
end
129-
return X
130-
end
31+
# function truncate_modes(wt::WaveletTransform, 𝐱_fft::AbstractArray)
32+
# return view(𝐱_fft, map(d->1:d, wt.modes)..., :, :) # [ft.modes..., in_chs, batch]
33+
# end

src/operator_kernel.jl

+145-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
export
2-
OperatorConv,
3-
SpectralConv,
4-
OperatorKernel
2+
OperatorConv,
3+
SpectralConv,
4+
OperatorKernel,
5+
SparseKernel,
6+
SparseKernel1D,
7+
SparseKernel2D,
8+
SparseKernel3D,
9+
MWT_CZ1d
510

611
struct OperatorConv{P, T, S, TT}
712
weight::T
@@ -180,6 +185,143 @@ function (m::OperatorKernel)(𝐱)
180185
return m.σ.(m.linear(𝐱) + m.conv(𝐱))
181186
end
182187

188+
"""
189+
SparseKernel(κ, ch, σ=identity)
190+
191+
Sparse kernel layer.
192+
193+
## Arguments
194+
195+
* `κ`: A neural network layer for approximation, e.g. a `Dense` layer or a MLP.
196+
* `ch`: Channel size for linear transform, e.g. `32`.
197+
* `σ`: Activation function.
198+
"""
199+
struct SparseKernel{N,T,S}
200+
conv_blk::T
201+
out_weight::S
202+
end
203+
204+
function SparseKernel(filter::NTuple{N,T}, ch::Pair{S, S}; init=Flux.glorot_uniform) where {N,T,S}
205+
input_dim, emb_dim = ch
206+
conv = Conv(filter, input_dim=>emb_dim, relu; stride=1, pad=1, init=init)
207+
W_out = Dense(emb_dim, input_dim; init=init)
208+
return SparseKernel{N,typeof(conv),typeof(W_out)}(conv, W_out)
209+
end
210+
211+
function SparseKernel1D(k::Int, α, c::Int=1; init=Flux.glorot_uniform)
212+
input_dim = c*k
213+
emb_dim = 128
214+
return SparseKernel((3, ), input_dim=>emb_dim; init=init)
215+
end
216+
217+
function SparseKernel2D(k::Int, α, c::Int=1; init=Flux.glorot_uniform)
218+
input_dim = c*k^2
219+
emb_dim = α*k^2
220+
return SparseKernel((3, 3), input_dim=>emb_dim; init=init)
221+
end
222+
223+
function SparseKernel3D(k::Int, α, c::Int=1; init=Flux.glorot_uniform)
224+
input_dim = c*k^2
225+
emb_dim = α*k^2
226+
conv = Conv((3, 3, 3), emb_dim=>emb_dim, relu; stride=1, pad=1, init=init)
227+
W_out = Dense(emb_dim, input_dim; init=init)
228+
return SparseKernel{3,typeof(conv),typeof(W_out)}(conv, W_out)
229+
end
230+
231+
Flux.@functor SparseKernel
232+
233+
function (l::SparseKernel)(X::AbstractArray)
234+
bch_sz, _, dims_r... = reverse(size(X))
235+
dims = reverse(dims_r)
236+
237+
X_ = l.conv_blk(X) # (dims..., emb_dims, B)
238+
X_ = reshape(X_, prod(dims), :, bch_sz) # (prod(dims), emb_dims, B)
239+
Y = l.out_weight(batched_transpose(X_)) # (in_dims, prod(dims), B)
240+
Y = reshape(batched_transpose(Y), dims..., :, bch_sz) # (dims..., in_dims, B)
241+
return collect(Y)
242+
end
243+
244+
245+
struct MWT_CZ1d{T,S,R,Q,P}
246+
k::Int
247+
L::Int
248+
A::T
249+
B::S
250+
C::R
251+
T0::Q
252+
ec_s::P
253+
ec_d::P
254+
rc_e::P
255+
rc_o::P
256+
end
257+
258+
function MWT_CZ1d(k::Int=3, α::Int=5, L::Int=0, c::Int=1; base::Symbol=:legendre, init=Flux.glorot_uniform)
259+
H0, H1, G0, G1, Φ0, Φ1 = get_filter(base, k)
260+
H0r = zero_out!(H0 * Φ0)
261+
G0r = zero_out!(G0 * Φ0)
262+
H1r = zero_out!(H1 * Φ1)
263+
G1r = zero_out!(G1 * Φ1)
264+
265+
dim = c*k
266+
A = SpectralConv(dim=>dim, (α,); init=init)
267+
B = SpectralConv(dim=>dim, (α,); init=init)
268+
C = SpectralConv(dim=>dim, (α,); init=init)
269+
T0 = Dense(k, k)
270+
271+
ec_s = vcat(H0', H1')
272+
ec_d = vcat(G0', G1')
273+
rc_e = vcat(H0r, G0r)
274+
rc_o = vcat(H1r, G1r)
275+
return MWT_CZ1d(k, L, A, B, C, T0, ec_s, ec_d, rc_e, rc_o)
276+
end
277+
278+
function wavelet_transform(l::MWT_CZ1d, X::AbstractArray{T,4}) where {T}
279+
N = size(X, 3)
280+
Xa = vcat(view(X, :, :, 1:2:N, :), view(X, :, :, 2:2:N, :))
281+
d = NNlib.batched_mul(Xa, l.ec_d)
282+
s = NNlib.batched_mul(Xa, l.ec_s)
283+
return d, s
284+
end
285+
286+
function even_odd(l::MWT_CZ1d, X::AbstractArray{T,4}) where {T}
287+
bch_sz, N, dims_r... = reverse(size(X))
288+
dims = reverse(dims_r)
289+
@assert dims[1] == 2*l.k
290+
Y = similar(X, bch_sz, 2N, l.c, l.k)
291+
view(Y, :, :, 1:2:N, :) .= NNlib.batched_mul(X, l.rc_e)
292+
view(Y, :, :, 2:2:N, :) .= NNlib.batched_mul(X, l.rc_o)
293+
return Y
294+
end
295+
296+
function (l::MWT_CZ1d)(X::T) where {T<:AbstractArray}
297+
bch_sz, N, dims_r... = reverse(size(X))
298+
ns = floor(log2(N))
299+
stop = ns - l.L
300+
301+
# decompose
302+
Ud = T[]
303+
Us = T[]
304+
for i in 1:stop
305+
d, X = wavelet_transform(l, X)
306+
push!(Ud, l.A(d)+l.B(d))
307+
push!(Us, l.C(d))
308+
end
309+
X = l.T0(X)
310+
311+
# reconstruct
312+
for i in stop:-1:1
313+
X += Us[i]
314+
X = vcat(X, Ud[i]) # x = torch.cat((x, Ud[i]), -1)
315+
X = even_odd(l, X)
316+
end
317+
return X
318+
end
319+
320+
# function Base.show(io::IO, l::MWT_CZ1d)
321+
# print(io, "MWT_CZ($(l.in_channel) => $(l.out_channel), $(l.transform.modes), $(nameof(typeof(l.transform))), permuted=$P)")
322+
# end
323+
324+
183325
#########
184326
# utils #
185327
#########
File renamed without changes.

test/Transform/wavelet_transform.jl

+20-43
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,30 @@
1-
@testset "SparseKernel" begin
1+
@testset "wavelet transform" begin
2+
𝐱 = rand(30, 40, 50, 6, 7) # where ch == 6 and batch == 7
3+
4+
wt = WaveletTransform((3, 4, 5))
5+
6+
@test size(transform(wt, 𝐱)) == (30, 40, 50, 6, 7)
7+
@test size(truncate_modes(wt, transform(wt, 𝐱))) == (3, 4, 5, 6, 7)
8+
@test size(inverse(wt, truncate_modes(wt, transform(wt, 𝐱)))) == (3, 4, 5, 6, 7)
9+
end
10+
11+
@testset "MWT_CZ" begin
212
T = Float32
313
k = 3
414
batch_size = 32
515

6-
@testset "1D SparseKernel" begin
7-
α = 4
8-
c = 1
9-
in_chs = 20
10-
X = rand(T, in_chs, c*k, batch_size)
16+
@testset "MWT_CZ1d" begin
17+
mwt = MWT_CZ1d()
1118

12-
l1 = SparseKernel1D(k, α, c)
13-
Y = l1(X)
14-
@test l1 isa SparseKernel{1}
15-
@test size(Y) == size(X)
19+
# base functions
20+
wavelet_transform(mwt, )
21+
even_odd(mwt, )
1622

17-
gs = gradient(()->sum(l1(X)), Flux.params(l1))
18-
@test length(gs.grads) == 4
19-
end
23+
# forward
24+
Y = mwt(X)
2025

21-
@testset "2D SparseKernel" begin
22-
α = 4
23-
c = 3
24-
Nx = 5
25-
Ny = 7
26-
X = rand(T, Nx, Ny, c*k^2, batch_size)
27-
28-
l2 = SparseKernel2D(k, α, c)
29-
Y = l2(X)
30-
@test l2 isa SparseKernel{2}
31-
@test size(Y) == size(X)
32-
33-
gs = gradient(()->sum(l2(X)), Flux.params(l2))
34-
@test length(gs.grads) == 4
26+
# backward
27+
g = gradient()
3528
end
3629

37-
@testset "3D SparseKernel" begin
38-
α = 4
39-
c = 3
40-
Nx = 5
41-
Ny = 7
42-
Nz = 13
43-
X = rand(T, Nx, Ny, Nz, α*k^2, batch_size)
44-
45-
l3 = SparseKernel3D(k, α, c)
46-
Y = l3(X)
47-
@test l3 isa SparseKernel{3}
48-
@test size(Y) == (Nx, Ny, Nz, c*k^2, batch_size)
49-
50-
gs = gradient(()->sum(l3(X)), Flux.params(l3))
51-
@test length(gs.grads) == 4
52-
end
5330
end

0 commit comments

Comments
 (0)