1
+ struct SparseKernel1d{T,S}
2
+ k:: Int
3
+ conv_blk:: S
4
+ out_weight:: T
5
+ end
6
+
7
+ function SparseKernel1d (k:: Int , c:: Int = 1 ; init= Flux. glorot_uniform)
8
+ input_dim = c* k
9
+ emb_dim = 128
10
+ conv = Conv ((3 ,), input_dim=> emb_dim, relu; stride= 1 , pad= 1 , init= init)
11
+ W_out = Dense (emb_dim, input_dim; init= init)
12
+ return SparseKernel1d (k, conv, W_out)
13
+ end
14
+
15
+ function (l:: SparseKernel1d )(X:: AbstractArray )
16
+ X_ = l. conv_blk (batched_transpose (X))
17
+ Y = l. out_weight (batched_transpose (X_))
18
+ return Y
19
+ end
20
+
21
+
22
+ # class MWT_CZ1d(nn.Module):
23
+ # def __init__(self,
24
+ # k = 3, alpha = 5,
25
+ # L = 0, c = 1,
26
+ # base = 'legendre',
27
+ # initializer = None,
28
+ # **kwargs):
29
+ # super(MWT_CZ1d, self).__init__()
30
+
31
+ # self.k = k
32
+ # self.L = L
33
+ # H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k)
34
+ # H0r = H0@PHI0
35
+ # G0r = G0@PHI0
36
+ # H1r = H1@PHI1
37
+ # G1r = G1@PHI1
38
+
39
+ # H0r[np.abs(H0r)<1e-8]=0
40
+ # H1r[np.abs(H1r)<1e-8]=0
41
+ # G0r[np.abs(G0r)<1e-8]=0
42
+ # G1r[np.abs(G1r)<1e-8]=0
43
+
44
+ # self.A = sparseKernelFT1d(k, alpha, c)
45
+ # self.B = sparseKernelFT1d(k, alpha, c)
46
+ # self.C = sparseKernelFT1d(k, alpha, c)
47
+
48
+ # self.T0 = nn.Linear(k, k)
49
+
50
+ # self.register_buffer('ec_s', torch.Tensor(
51
+ # np.concatenate((H0.T, H1.T), axis=0)))
52
+ # self.register_buffer('ec_d', torch.Tensor(
53
+ # np.concatenate((G0.T, G1.T), axis=0)))
54
+
55
+ # self.register_buffer('rc_e', torch.Tensor(
56
+ # np.concatenate((H0r, G0r), axis=0)))
57
+ # self.register_buffer('rc_o', torch.Tensor(
58
+ # np.concatenate((H1r, G1r), axis=0)))
59
+
60
+
61
+ # def forward(self, x):
62
+
63
+ # B, N, c, ich = x.shape # (B, N, k)
64
+ # ns = math.floor(np.log2(N))
65
+
66
+ # Ud = torch.jit.annotate(List[Tensor], [])
67
+ # Us = torch.jit.annotate(List[Tensor], [])
68
+ # # decompose
69
+ # for i in range(ns-self.L):
70
+ # d, x = self.wavelet_transform(x)
71
+ # Ud += [self.A(d) + self.B(x)]
72
+ # Us += [self.C(d)]
73
+ # x = self.T0(x) # coarsest scale transform
74
+
75
+ # # reconstruct
76
+ # for i in range(ns-1-self.L,-1,-1):
77
+ # x = x + Us[i]
78
+ # x = torch.cat((x, Ud[i]), -1)
79
+ # x = self.evenOdd(x)
80
+ # return x
81
+
82
+
83
+ # def wavelet_transform(self, x):
84
+ # xa = torch.cat([x[:, ::2, :, :],
85
+ # x[:, 1::2, :, :],
86
+ # ], -1)
87
+ # d = torch.matmul(xa, self.ec_d)
88
+ # s = torch.matmul(xa, self.ec_s)
89
+ # return d, s
90
+
91
+
92
+ # def evenOdd(self, x):
93
+
94
+ # B, N, c, ich = x.shape # (B, N, c, k)
95
+ # assert ich == 2*self.k
96
+ # x_e = torch.matmul(x, self.rc_e)
97
+ # x_o = torch.matmul(x, self.rc_o)
98
+
99
+ # x = torch.zeros(B, N*2, c, self.k,
100
+ # device = x.device)
101
+ # x[..., ::2, :, :] = x_e
102
+ # x[..., 1::2, :, :] = x_o
103
+ # return x
0 commit comments