
Commit 849fde2

draft for SparseKernel1d

1 parent 678bd2b

3 files changed: +117 −0

src/NeuralOperators.jl (+1)
@@ -9,5 +9,6 @@ module NeuralOperators
 using ChainRulesCore

 include("fourier.jl")
+include("wavelet.jl")
 include("model.jl")
 end

src/wavelet.jl (new file, +103)

@@ -0,0 +1,103 @@
struct SparseKernel1d{T,S}
    k::Int
    conv_blk::S
    out_weight::T
end

function SparseKernel1d(k::Int, c::Int=1; init=Flux.glorot_uniform)
    input_dim = c*k
    emb_dim = 128
    # Size-3 convolution lifts the c*k input channels to a 128-dim embedding;
    # pad=1 with stride=1 preserves the number of spatial points.
    conv = Conv((3,), input_dim=>emb_dim, relu; stride=1, pad=1, init=init)
    # Dense layer projects the embedding back down to c*k channels.
    W_out = Dense(emb_dim, input_dim; init=init)
    return SparseKernel1d(k, conv, W_out)
end

function (l::SparseKernel1d)(X::AbstractArray)
    # Conv expects (points, channels, batch), so swap the first two dims.
    X_ = l.conv_blk(batched_transpose(X))
    Y = l.out_weight(batched_transpose(X_))
    return Y
end
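
# Expected shape flow (a sketch of the intended sizes; assumes `Flux` and
# NNlib's `batched_transpose` are in scope at the module level):
#   X                     :: (c*k, n_points, batch)
#   batched_transpose(X)  -> (n_points, c*k, batch)
#   conv_blk              -> (n_points, 128, batch)
#   batched_transpose     -> (128, n_points, batch)
#   out_weight            -> (c*k, n_points, batch)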

# PyTorch reference implementation (MWT_CZ1d) this file is drafted from,
# kept here while the port is in progress:
#
# class MWT_CZ1d(nn.Module):
#     def __init__(self,
#                  k = 3, alpha = 5,
#                  L = 0, c = 1,
#                  base = 'legendre',
#                  initializer = None,
#                  **kwargs):
#         super(MWT_CZ1d, self).__init__()
#
#         self.k = k
#         self.L = L
#         H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k)
#         H0r = H0 @ PHI0
#         G0r = G0 @ PHI0
#         H1r = H1 @ PHI1
#         G1r = G1 @ PHI1
#
#         # zero out near-zero filter entries for sparsity
#         H0r[np.abs(H0r) < 1e-8] = 0
#         H1r[np.abs(H1r) < 1e-8] = 0
#         G0r[np.abs(G0r) < 1e-8] = 0
#         G1r[np.abs(G1r) < 1e-8] = 0
#
#         self.A = sparseKernelFT1d(k, alpha, c)
#         self.B = sparseKernelFT1d(k, alpha, c)
#         self.C = sparseKernelFT1d(k, alpha, c)
#
#         self.T0 = nn.Linear(k, k)
#
#         self.register_buffer('ec_s', torch.Tensor(
#             np.concatenate((H0.T, H1.T), axis=0)))
#         self.register_buffer('ec_d', torch.Tensor(
#             np.concatenate((G0.T, G1.T), axis=0)))
#
#         self.register_buffer('rc_e', torch.Tensor(
#             np.concatenate((H0r, G0r), axis=0)))
#         self.register_buffer('rc_o', torch.Tensor(
#             np.concatenate((H1r, G1r), axis=0)))
#
#     def forward(self, x):
#         B, N, c, ich = x.shape  # (B, N, c, k)
#         ns = math.floor(np.log2(N))
#
#         Ud = torch.jit.annotate(List[Tensor], [])
#         Us = torch.jit.annotate(List[Tensor], [])
#         # decompose
#         for i in range(ns - self.L):
#             d, x = self.wavelet_transform(x)
#             Ud += [self.A(d) + self.B(x)]
#             Us += [self.C(d)]
#         x = self.T0(x)  # coarsest scale transform
#
#         # reconstruct
#         for i in range(ns - 1 - self.L, -1, -1):
#             x = x + Us[i]
#             x = torch.cat((x, Ud[i]), -1)
#             x = self.evenOdd(x)
#         return x
#
#     def wavelet_transform(self, x):
#         xa = torch.cat([x[:, ::2, :, :],
#                         x[:, 1::2, :, :],
#                         ], -1)
#         d = torch.matmul(xa, self.ec_d)
#         s = torch.matmul(xa, self.ec_s)
#         return d, s
#
#     def evenOdd(self, x):
#         B, N, c, ich = x.shape  # (B, N, c, k)
#         assert ich == 2 * self.k
#         x_e = torch.matmul(x, self.rc_e)
#         x_o = torch.matmul(x, self.rc_o)
#
#         x = torch.zeros(B, N * 2, c, self.k,
#                         device=x.device)
#         x[..., ::2, :, :] = x_e
#         x[..., 1::2, :, :] = x_o
#         return x
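
As a next step for the port, the wavelet_transform helper translates almost mechanically. A minimal Julia sketch, not part of this commit: it assumes the column-major layout (k, c, n_points, batch), i.e. the PyTorch (B, N, c, k) layout reversed, and decomposition filters ec_d, ec_s of size (2k, k) as constructed above; all names here are placeholders.

# Split even/odd points, stack along the basis axis, and contract with the
# decomposition filters to get detail (d) and smooth (s) coefficients.
function wavelet_transform(x::AbstractArray{T,4},
                           ec_d::AbstractMatrix{T},
                           ec_s::AbstractMatrix{T}) where T
    k, c, N, B = size(x)
    xa = vcat(x[:, :, 1:2:end, :], x[:, :, 2:2:end, :])     # (2k, c, N÷2, B)
    xa_flat = reshape(xa, 2k, :)                             # flatten trailing dims
    d = reshape(transpose(ec_d) * xa_flat, k, c, N ÷ 2, B)  # detail coefficients
    s = reshape(transpose(ec_s) * xa_flat, k, c, N ÷ 2, B)  # smooth coefficients
    return d, s
end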

test/wavelet.jl (new file, +13)

@@ -0,0 +1,13 @@
using NeuralOperators

T = Float32
k = 10
c = 1
in_chs = 20
batch_size = 32

l = NeuralOperators.SparseKernel1d(k, c)

# X :: (c*k, in_chs, batch_size); the layer treats dim 2 as the spatial axis.
X = rand(T, c*k, in_chs, batch_size)
Y = l(X)
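
The draft test only runs the forward pass without asserting anything. A natural check, sketched here with the Test stdlib (not part of the commit): the stride-1, pad-1 convolution preserves the point count and the Dense layer maps the 128-dim embedding back to c*k channels, so the output shape should equal the input shape.

using Test

@test size(Y) == (c*k, in_chs, batch_size)
@test eltype(Y) == T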
