@@ -27,6 +27,7 @@ class SeanetConfig:
27
27
28
28
class SeanetResnetBlock (nn .Module ):
29
29
def __init__ (self , cfg : SeanetConfig , dim : int , ksizes_and_dilations : list ):
30
+ super ().__init__ ()
30
31
block = []
31
32
hidden = dim // cfg .compress
32
33
for i , (ksize , dilation ) in enumerate (ksizes_and_dilations ):
@@ -80,6 +81,7 @@ def __call__(self, xs: mx.array) -> mx.array:
80
81
81
82
class EncoderLayer (nn .Module ):
82
83
def __init__ (self , cfg : SeanetConfig , ratio : int , mult : int ):
84
+ super ().__init__ ()
83
85
residuals = []
84
86
dilation = 1
85
87
for _ in range (cfg .nresidual_layers ):
@@ -115,6 +117,7 @@ def __call__(self, xs: mx.array) -> mx.array:
115
117
116
118
class SeanetEncoder (nn .Module ):
117
119
def __init__ (self , cfg : SeanetConfig ):
120
+ super ().__init__ ()
118
121
mult = 1
119
122
self .init_conv1d = StreamableConv1d (
120
123
in_channels = cfg .channels ,
@@ -159,6 +162,7 @@ def __call__(self, xs: mx.array) -> mx.array:
159
162
160
163
class DecoderLayer (nn .Module ):
161
164
def __init__ (self , cfg : SeanetConfig , ratio : int , mult : int ):
165
+ super ().__init__ ()
162
166
self .upsample = StreamableConvTranspose1d (
163
167
in_channels = mult * cfg .nfilters ,
164
168
out_channels = mult * cfg .nfilters // 2 ,
@@ -183,6 +187,7 @@ def __call__(self, xs: mx.array) -> mx.array:
183
187
184
188
class SeanetDecoder (nn .Module ):
185
189
def __init__ (self , cfg : SeanetConfig ):
190
+ super ().__init__ ()
186
191
mult = 1 << len (cfg .ratios )
187
192
self .init_conv1d = StreamableConv1d (
188
193
in_channels = cfg .dimension ,
@@ -227,5 +232,6 @@ def __call__(self, xs: mx.array) -> mx.array:
227
232
228
233
class Seanet (nn .Module ):
229
234
def __init__ (self , cfg : SeanetConfig ):
235
+ super ().__init__ ()
230
236
self .encoder = SeanetEncoder (cfg )
231
237
self .decoder = SeanetDecoder (cfg )
0 commit comments