Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit 827ce6f

Browse files
committed
migrate
1 parent d259ce3 commit 827ce6f

File tree

5 files changed

+207
-178
lines changed

5 files changed

+207
-178
lines changed

src/Transform/wavelet_transform.jl

+13-124
Original file line numberDiff line numberDiff line change
@@ -1,130 +1,19 @@
1-
export
2-
SparseKernel,
3-
SparseKernel1D,
4-
SparseKernel2D,
5-
SparseKernel3D
1+
export WaveletTransform
62

7-
8-
struct SparseKernel{N,T,S}
9-
conv_blk::T
10-
out_weight::S
11-
end
12-
13-
function SparseKernel(filter::NTuple{N,T}, ch::Pair{S, S}; init=Flux.glorot_uniform) where {N,T,S}
14-
input_dim, emb_dim = ch
15-
conv = Conv(filter, input_dim=>emb_dim, relu; stride=1, pad=1, init=init)
16-
W_out = Dense(emb_dim, input_dim; init=init)
17-
return SparseKernel{N,typeof(conv),typeof(W_out)}(conv, W_out)
18-
end
19-
20-
function SparseKernel1D(k::Int, α, c::Int=1; init=Flux.glorot_uniform)
21-
input_dim = c*k
22-
emb_dim = 128
23-
return SparseKernel((3, ), input_dim=>emb_dim; init=init)
24-
end
25-
26-
function SparseKernel2D(k::Int, α, c::Int=1; init=Flux.glorot_uniform)
27-
input_dim = c*k^2
28-
emb_dim = α*k^2
29-
return SparseKernel((3, 3), input_dim=>emb_dim; init=init)
30-
end
31-
32-
function SparseKernel3D(k::Int, α, c::Int=1; init=Flux.glorot_uniform)
33-
input_dim = c*k^2
34-
emb_dim = α*k^2
35-
conv = Conv((3, 3, 3), emb_dim=>emb_dim, relu; stride=1, pad=1, init=init)
36-
W_out = Dense(emb_dim, input_dim; init=init)
37-
return SparseKernel{3,typeof(conv),typeof(W_out)}(conv, W_out)
38-
end
39-
40-
Flux.@functor SparseKernel
41-
42-
function (l::SparseKernel)(X::AbstractArray)
43-
bch_sz, _, dims_r... = reverse(size(X))
44-
dims = reverse(dims_r)
45-
46-
X_ = l.conv_blk(X) # (dims..., emb_dims, B)
47-
X_ = reshape(X_, prod(dims), :, bch_sz) # (prod(dims), emb_dims, B)
48-
Y = l.out_weight(batched_transpose(X_)) # (in_dims, prod(dims), B)
49-
Y = reshape(batched_transpose(Y), dims..., :, bch_sz) # (dims..., in_dims, B)
50-
return collect(Y)
51-
end
52-
53-
54-
struct MWT_CZ1d{T,S,R,Q,P}
55-
k::Int
56-
L::Int
57-
A::T
58-
B::S
59-
C::R
60-
T0::Q
61-
ec_s::P
62-
ec_d::P
63-
rc_e::P
64-
rc_o::P
65-
end
66-
67-
function MWT_CZ1d(k::Int=3, α::Int=5, L::Int=0, c::Int=1; base::Symbol=:legendre, init=Flux.glorot_uniform)
68-
H0, H1, G0, G1, Φ0, Φ1 = get_filter(base, k)
69-
H0r = zero_out!(H0 * Φ0)
70-
G0r = zero_out!(G0 * Φ0)
71-
H1r = zero_out!(H1 * Φ1)
72-
G1r = zero_out!(G1 * Φ1)
73-
74-
dim = c*k
75-
A = SpectralConv(dim=>dim, (α,); init=init)
76-
B = SpectralConv(dim=>dim, (α,); init=init)
77-
C = SpectralConv(dim=>dim, (α,); init=init)
78-
T0 = Dense(k, k)
79-
80-
ec_s = vcat(H0', H1')
81-
ec_d = vcat(G0', G1')
82-
rc_e = vcat(H0r, G0r)
83-
rc_o = vcat(H1r, G1r)
84-
return MWT_CZ1d(k, L, A, B, C, T0, ec_s, ec_d, rc_e, rc_o)
85-
end
86-
87-
function wavelet_transform(l::MWT_CZ1d, X::AbstractArray{T,4}) where {T}
88-
N = size(X, 3)
89-
Xa = vcat(view(X, :, :, 1:2:N, :), view(X, :, :, 2:2:N, :))
90-
d = NNlib.batched_mul(Xa, l.ec_d)
91-
s = NNlib.batched_mul(Xa, l.ec_s)
92-
return d, s
3+
struct WaveletTransform{N, S}<:AbstractTransform
4+
modes::NTuple{N, S} # N == ndims(x)
935
end
946

95-
function even_odd(l::MWT_CZ1d, X::AbstractArray{T,4}) where {T}
96-
bch_sz, N, dims_r... = reverse(size(X))
97-
dims = reverse(dims_r)
98-
@assert dims[1] == 2*l.k
99-
Xₑ = NNlib.batched_mul(X, l.rc_e)
100-
Xₒ = NNlib.batched_mul(X, l.rc_o)
101-
# x = torch.zeros(B, N*2, c, self.k,
102-
# device = x.device)
103-
# x[..., ::2, :, :] = x_e
104-
# x[..., 1::2, :, :] = x_o
105-
return X
106-
end
7+
Base.ndims(::WaveletTransform{N}) where {N} = N
1078

108-
function (l::MWT_CZ1d)(X::T) where {T<:AbstractArray}
109-
bch_sz, N, dims_r... = reverse(size(X))
110-
ns = floor(log2(N))
111-
stop = ns - l.L
9+
# function transform(wt::WaveletTransform, 𝐱::AbstractArray)
10+
# return fft(Zygote.hook(real, 𝐱), 1:ndims(wt)) # [size(x)..., in_chs, batch]
11+
# end
11212

113-
# decompose
114-
Ud = T[]
115-
Us = T[]
116-
for i in 1:stop
117-
d, X = wavelet_transform(l, X)
118-
push!(Ud, l.A(d)+l.B(d))
119-
push!(Us, l.C(d))
120-
end
121-
X = l.T0(X)
13+
# function truncate_modes(wt::WaveletTransform, 𝐱_fft::AbstractArray)
14+
# return view(𝐱_fft, map(d->1:d, wt.modes)..., :, :) # [ft.modes..., in_chs, batch]
15+
# end
12216

123-
# reconstruct
124-
for i in stop:-1:1
125-
X += Us[i]
126-
X = vcat(X, Ud[i]) # x = torch.cat((x, Ud[i]), -1)
127-
X = even_odd(l, X)
128-
end
129-
return X
130-
end
17+
# function inverse(wt::WaveletTransform, 𝐱_fft::AbstractArray)
18+
# return real(ifft(𝐱_fft, 1:ndims(wt))) # [size(x_fft)..., out_chs, batch]
19+
# end

src/operator_kernel.jl

+140-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@ export
22
OperatorConv,
33
SpectralConv,
44
OperatorKernel,
5-
GraphKernel
5+
GraphKernel,
6+
SparseKernel,
7+
SparseKernel1D,
8+
SparseKernel2D,
9+
SparseKernel3D
610

711
struct OperatorConv{P, T, S, TT}
812
weight::T
@@ -216,6 +220,141 @@ function Base.show(io::IO, l::GraphKernel)
216220
print(io, ")")
217221
end
218222

223+
"""
224+
SparseKernel(κ, ch, σ=identity)
225+
226+
Sparse kernel layer.
227+
228+
## Arguments
229+
230+
* `κ`: A neural network layer for approximation, e.g. a `Dense` layer or a MLP.
231+
* `ch`: Channel size for linear transform, e.g. `32`.
232+
* `σ`: Activation function.
233+
"""
234+
struct SparseKernel{N,T,S}
235+
conv_blk::T
236+
out_weight::S
237+
end
238+
239+
function SparseKernel(filter::NTuple{N,T}, ch::Pair{S, S}; init=Flux.glorot_uniform) where {N,T,S}
240+
input_dim, emb_dim = ch
241+
conv = Conv(filter, input_dim=>emb_dim, relu; stride=1, pad=1, init=init)
242+
W_out = Dense(emb_dim, input_dim; init=init)
243+
return SparseKernel{N,typeof(conv),typeof(W_out)}(conv, W_out)
244+
end
245+
246+
function SparseKernel1D(k::Int, α, c::Int=1; init=Flux.glorot_uniform)
247+
input_dim = c*k
248+
emb_dim = 128
249+
return SparseKernel((3, ), input_dim=>emb_dim; init=init)
250+
end
251+
252+
function SparseKernel2D(k::Int, α, c::Int=1; init=Flux.glorot_uniform)
253+
input_dim = c*k^2
254+
emb_dim = α*k^2
255+
return SparseKernel((3, 3), input_dim=>emb_dim; init=init)
256+
end
257+
258+
function SparseKernel3D(k::Int, α, c::Int=1; init=Flux.glorot_uniform)
259+
input_dim = c*k^2
260+
emb_dim = α*k^2
261+
conv = Conv((3, 3, 3), emb_dim=>emb_dim, relu; stride=1, pad=1, init=init)
262+
W_out = Dense(emb_dim, input_dim; init=init)
263+
return SparseKernel{3,typeof(conv),typeof(W_out)}(conv, W_out)
264+
end
265+
266+
Flux.@functor SparseKernel
267+
268+
function (l::SparseKernel)(X::AbstractArray)
269+
bch_sz, _, dims_r... = reverse(size(X))
270+
dims = reverse(dims_r)
271+
272+
X_ = l.conv_blk(X) # (dims..., emb_dims, B)
273+
X_ = reshape(X_, prod(dims), :, bch_sz) # (prod(dims), emb_dims, B)
274+
Y = l.out_weight(batched_transpose(X_)) # (in_dims, prod(dims), B)
275+
Y = reshape(batched_transpose(Y), dims..., :, bch_sz) # (dims..., in_dims, B)
276+
return collect(Y)
277+
end
278+
279+
280+
struct MWT_CZ1d{T,S,R,Q,P}
281+
k::Int
282+
L::Int
283+
A::T
284+
B::S
285+
C::R
286+
T0::Q
287+
ec_s::P
288+
ec_d::P
289+
rc_e::P
290+
rc_o::P
291+
end
292+
293+
function MWT_CZ1d(k::Int=3, α::Int=5, L::Int=0, c::Int=1; base::Symbol=:legendre, init=Flux.glorot_uniform)
294+
H0, H1, G0, G1, Φ0, Φ1 = get_filter(base, k)
295+
H0r = zero_out!(H0 * Φ0)
296+
G0r = zero_out!(G0 * Φ0)
297+
H1r = zero_out!(H1 * Φ1)
298+
G1r = zero_out!(G1 * Φ1)
299+
300+
dim = c*k
301+
A = SpectralConv(dim=>dim, (α,); init=init)
302+
B = SpectralConv(dim=>dim, (α,); init=init)
303+
C = SpectralConv(dim=>dim, (α,); init=init)
304+
T0 = Dense(k, k)
305+
306+
ec_s = vcat(H0', H1')
307+
ec_d = vcat(G0', G1')
308+
rc_e = vcat(H0r, G0r)
309+
rc_o = vcat(H1r, G1r)
310+
return MWT_CZ1d(k, L, A, B, C, T0, ec_s, ec_d, rc_e, rc_o)
311+
end
312+
313+
function wavelet_transform(l::MWT_CZ1d, X::AbstractArray{T,4}) where {T}
314+
N = size(X, 3)
315+
Xa = vcat(view(X, :, :, 1:2:N, :), view(X, :, :, 2:2:N, :))
316+
d = NNlib.batched_mul(Xa, l.ec_d)
317+
s = NNlib.batched_mul(Xa, l.ec_s)
318+
return d, s
319+
end
320+
321+
function even_odd(l::MWT_CZ1d, X::AbstractArray{T,4}) where {T}
322+
bch_sz, N, dims_r... = reverse(size(X))
323+
dims = reverse(dims_r)
324+
@assert dims[1] == 2*l.k
325+
Xₑ = NNlib.batched_mul(X, l.rc_e)
326+
Xₒ = NNlib.batched_mul(X, l.rc_o)
327+
# x = torch.zeros(B, N*2, c, self.k,
328+
# device = x.device)
329+
# x[..., ::2, :, :] = x_e
330+
# x[..., 1::2, :, :] = x_o
331+
return X
332+
end
333+
334+
function (l::MWT_CZ1d)(X::T) where {T<:AbstractArray}
335+
bch_sz, N, dims_r... = reverse(size(X))
336+
ns = floor(log2(N))
337+
stop = ns - l.L
338+
339+
# decompose
340+
Ud = T[]
341+
Us = T[]
342+
for i in 1:stop
343+
d, X = wavelet_transform(l, X)
344+
push!(Ud, l.A(d)+l.B(d))
345+
push!(Us, l.C(d))
346+
end
347+
X = l.T0(X)
348+
349+
# reconstruct
350+
for i in stop:-1:1
351+
X += Us[i]
352+
X = vcat(X, Ud[i]) # x = torch.cat((x, Ud[i]), -1)
353+
X = even_odd(l, X)
354+
end
355+
return X
356+
end
357+
219358

220359
#########
221360
# utils #
File renamed without changes.

test/Transform/wavelet_transform.jl

-53
Original file line numberDiff line numberDiff line change
@@ -1,53 +0,0 @@
1-
@testset "SparseKernel" begin
2-
T = Float32
3-
k = 3
4-
batch_size = 32
5-
6-
@testset "1D SparseKernel" begin
7-
α = 4
8-
c = 1
9-
in_chs = 20
10-
X = rand(T, in_chs, c*k, batch_size)
11-
12-
l1 = SparseKernel1D(k, α, c)
13-
Y = l1(X)
14-
@test l1 isa SparseKernel{1}
15-
@test size(Y) == size(X)
16-
17-
gs = gradient(()->sum(l1(X)), Flux.params(l1))
18-
@test length(gs.grads) == 4
19-
end
20-
21-
@testset "2D SparseKernel" begin
22-
α = 4
23-
c = 3
24-
Nx = 5
25-
Ny = 7
26-
X = rand(T, Nx, Ny, c*k^2, batch_size)
27-
28-
l2 = SparseKernel2D(k, α, c)
29-
Y = l2(X)
30-
@test l2 isa SparseKernel{2}
31-
@test size(Y) == size(X)
32-
33-
gs = gradient(()->sum(l2(X)), Flux.params(l2))
34-
@test length(gs.grads) == 4
35-
end
36-
37-
@testset "3D SparseKernel" begin
38-
α = 4
39-
c = 3
40-
Nx = 5
41-
Ny = 7
42-
Nz = 13
43-
X = rand(T, Nx, Ny, Nz, α*k^2, batch_size)
44-
45-
l3 = SparseKernel3D(k, α, c)
46-
Y = l3(X)
47-
@test l3 isa SparseKernel{3}
48-
@test size(Y) == (Nx, Ny, Nz, c*k^2, batch_size)
49-
50-
gs = gradient(()->sum(l3(X)), Flux.params(l3))
51-
@test length(gs.grads) == 4
52-
end
53-
end

0 commit comments

Comments
 (0)