# model.py
import warnings
import torch
import torch.nn as nn
import torchinfo
import yaml


# ! CNN
class CNN(nn.Module):
    def __init__(
        self,
        n_in_channel,
        activation="Relu",
        conv_dropout=0,
        kernel_size=[3, 3, 3],
        padding=[1, 1, 1],
        stride=[1, 1, 1],
        nb_filters=[64, 64, 64],
        pooling=[(1, 4), (1, 4), (1, 4)],
        normalization="batch",
        **kwargs,
    ):
        """
        Initialization of the CNN network.

        Args:
            n_in_channel: int, number of input channels
            activation: str, activation function ("relu" or "leakyrelu", case-insensitive)
            conv_dropout: float, dropout probability applied after each convolutional block
            kernel_size: list of int, kernel size of each convolutional layer
            padding: list of int, padding of each convolutional layer
            stride: list of int, stride of each convolutional layer
            nb_filters: list of int, number of filters of each convolutional layer
            pooling: list of tuples, (time, frequency) pooling applied after each block
            normalization: str, "batch" for BatchNorm2d or "layer" for layer normalization (GroupNorm with a single group).
        """
        super(CNN, self).__init__()
        self.nb_filters = nb_filters
        cnn = nn.Sequential()

        def conv(i, normalization="batch", dropout=None, activ="relu"):
            # Append the i-th conv -> normalization -> activation -> dropout block.
            nIn = n_in_channel if i == 0 else nb_filters[i - 1]
            nOut = nb_filters[i]
            cnn.add_module(
                "conv{0}".format(i),
                nn.Conv2d(nIn, nOut, kernel_size[i], stride[i], padding[i]),
            )
            if normalization == "batch":
                cnn.add_module(
                    "batchnorm{0}".format(i),
                    nn.BatchNorm2d(nOut, eps=0.001, momentum=0.99),
                )
            elif normalization == "layer":
                cnn.add_module("layernorm{0}".format(i), nn.GroupNorm(1, nOut))
            if activ.lower() == "leakyrelu":
                cnn.add_module("relu{0}".format(i), nn.LeakyReLU(0.2))
            elif activ.lower() == "relu":
                cnn.add_module("relu{0}".format(i), nn.ReLU())
            if dropout is not None:
                cnn.add_module("dropout{0}".format(i), nn.Dropout(dropout))

        for i in range(len(nb_filters)):
            conv(i, normalization=normalization, dropout=conv_dropout, activ=activation)
            cnn.add_module(
                "pooling{0}".format(i), nn.AvgPool2d(pooling[i])
            )  # bs x tframe x mels
        self.cnn = cnn

    def forward(self, x):
        """
        Forward step of the CNN module.

        Args:
            x (Tensor): input batch of size (batch_size, n_channels, n_frames, n_freq)

        Returns:
            Tensor: embedded batch
        """
        # conv features
        x = self.cnn(x)
        return x
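

# Example (a minimal sketch, not part of the original file): with the default
# filters and the (1, 4) frequency pooling applied three times, a 128-bin mel
# input keeps its frame count and ends up with 128 / 4 / 4 / 4 = 2 frequency bins:
#   cnn = CNN(n_in_channel=1)
#   out = cnn(torch.randn(8, 1, 618, 128))  # -> torch.Size([8, 64, 618, 2])
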
# ! RNN
class BidirectionalGRU(nn.Module):
    def __init__(self, n_in, n_hidden, dropout=0, num_layers=1):
        """
        Initialization of a BidirectionalGRU instance.

        Args:
            n_in: int, number of input features
            n_hidden: int, number of hidden units per direction
            dropout: float, dropout applied between GRU layers
            num_layers: int, number of stacked GRU layers
        """
        super(BidirectionalGRU, self).__init__()
        self.rnn = nn.GRU(
            n_in,
            n_hidden,
            bidirectional=True,
            dropout=dropout,
            batch_first=True,
            num_layers=num_layers,
        )

    def forward(self, input_feat):
        recurrent, _ = self.rnn(input_feat)
        return recurrent
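

# Example (a minimal sketch, not part of the original file): the forward and
# backward hidden states are concatenated, so the feature dimension doubles:
#   rnn = BidirectionalGRU(n_in=64, n_hidden=128)
#   out = rnn(torch.randn(8, 618, 64))  # -> torch.Size([8, 618, 256])
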
# ! CRNN
class CRNN(nn.Module):
    def __init__(
        self,
        n_in_channel=1,
        nclass=10,
        activation="Relu",
        dropout=0.5,
        rnn_type="BGRU",
        n_RNN_cell=128,
        n_layers_RNN=2,
        dropout_recurrent=0,
        attention=True,
        **kwargs,
    ):
"""
Initialization of CRNN model
Args:
n_in_channel: int, number of input channel
n_class: int, number of classes
activation: str, activation function
dropout: float, dropout
train_cnn: bool, training cnn layers
rnn_type: str, rnn type
n_RNN_cell: int, RNN nodes
n_layer_RNN: int, number of RNN layers
dropout_recurrent: float, recurrent layers dropout
cnn_integration: bool, integration of cnn
freeze_bn:
**kwargs: keywords arguments for CNN.
"""
        super(CRNN, self).__init__()
        self.n_in_channel = n_in_channel
        n_in_cnn = n_in_channel
        self.cnn = CNN(
            n_in_channel=n_in_cnn,
            activation=activation,
            conv_dropout=int(dropout),
            **kwargs,
        )
        self.attention = attention
        # CNN keyword arguments forwarded via **kwargs (defaults for reference):
        # n_in_channel,
        # activation="Relu",
        # conv_dropout=0,
        # kernel_size=[3, 3, 3],
        # padding=[1, 1, 1],
        # stride=[1, 1, 1],
        # nb_filters=[64, 64, 64],
        # pooling=[(1, 4), (1, 4), (1, 4)],
        # normalization="batch"
        if rnn_type == "BGRU":
            nb_in = self.cnn.nb_filters[-1]
            self.rnn = BidirectionalGRU(
                n_in=nb_in,
                n_hidden=n_RNN_cell,
                dropout=dropout_recurrent,
                num_layers=n_layers_RNN,
            )
        else:
            raise NotImplementedError("Only BGRU supported for CRNN for now")
        self.dropout = nn.Dropout(dropout)
        self.dense = nn.Linear(n_RNN_cell * 2, nclass)
        self.sigmoid = nn.Sigmoid()
        if self.attention:
            self.dense_softmax = nn.Linear(n_RNN_cell * 2, nclass)
            self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, pad_mask=None, embeddings=None):
        # input: [batch_size, n_freq, n_frames]
        x = x.transpose(1, 2).unsqueeze(1)
        # [batch_size, n_channels, n_frames, n_freq]
        x = self.cnn(x)
        bs, chan, frames, freq = x.size()
        if freq != 1:
            warnings.warn(
                f"Output shape is: {(bs, frames, chan * freq)}, "
                f"flattening the remaining {freq} frequency bins into the channel dimension"
            )
            x = x.permute(0, 2, 1, 3)
            x = x.contiguous().view(bs, frames, chan * freq)
        else:
            x = x.squeeze(-1)  # [batch_size, n_channels, n_frames]
            x = x.permute(0, 2, 1)  # [batch_size, n_frames, n_channels]
        x = self.rnn(x)  # [batch_size, n_frames, 2 * n_RNN_cell]
        x = self.dropout(x)
        strong = self.dense(x)  # [bs, frames, nclass]
        strong = self.sigmoid(strong)
        if self.attention:
            # Attention pooling: clip-level predictions are a weighted average of the
            # frame-level probabilities over time.
            sof = self.dense_softmax(x)  # [bs, frames, nclass]
            sof = self.softmax(sof)
            sof = torch.clamp(sof, min=1e-7, max=1)
            weak = (strong * sof).sum(1) / sof.sum(1)  # [bs, nclass]
        else:
            weak = strong.mean(1)
        return strong.transpose(1, 2), weak
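

# Example (a minimal sketch, not part of the original file): with the default CNN
# pooling, a 64-bin input collapses to a single frequency bin, so the RNN input size
# matches nb_filters[-1]. The model returns frame-level ("strong") and clip-level
# ("weak") probabilities:
#   model = CRNN(n_in_channel=1, nclass=10)
#   strong, weak = model(torch.randn(8, 64, 618))  # (batch_size, n_freq, n_frames)
#   # strong: torch.Size([8, 10, 618]), weak: torch.Size([8, 10])
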

if __name__ == "__main__":
    with open("params.yaml", "r") as f:
        configs = yaml.safe_load(f)
    model = CRNN(**configs["net"])
    torchinfo.summary(
        model,
        (64, 128, 618),
        verbose=1,
        col_names=["input_size", "output_size", "num_params", "kernel_size"],
    )
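
# A hypothetical params.yaml "net" section (an assumption for illustration; the real
# file is not shown here). The pooling values are chosen so that the 128 mel bins used
# in the summary call above collapse to a single frequency bin:
#   net:
#     n_in_channel: 1
#     nclass: 10
#     activation: "Relu"
#     dropout: 0.5
#     rnn_type: "BGRU"
#     n_RNN_cell: 128
#     n_layers_RNN: 2
#     dropout_recurrent: 0
#     attention: true
#     nb_filters: [64, 64, 64]
#     pooling: [[1, 4], [1, 4], [1, 8]]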