Commit a8ef530: Formatting.
Parent: d0659a0

File tree: 3 files changed, +32 -10 lines

    moshi_mlx/moshi_mlx/modules/conv.py
    moshi_mlx/moshi_mlx/modules/quantization.py
    moshi_mlx/moshi_mlx/modules/seanet.py

moshi_mlx/moshi_mlx/modules/conv.py (16 additions, 5 deletions)
@@ -6,6 +6,7 @@
 import mlx.core as mx
 import mlx.nn as nn
 
+
 class Conv1d(nn.Module):
     def __init__(
         self,
@@ -48,6 +49,7 @@ def __call__(self, xs: mx.array) -> mx.array:
         y = y + self.bias
         return y.swapaxes(-1, -2)
 
+
 class ConvTranspose1d(nn.Module):
     def __init__(
         self,
@@ -86,6 +88,7 @@ def __call__(self, xs: mx.array) -> mx.array:
         y = y + self.bias
         return y
 
+
 class NormConv1d(nn.Module):
     def __init__(
         self,
@@ -112,6 +115,7 @@ def __init__(
     def __call__(self, xs: mx.array) -> mx.array:
         return self.conv(xs)
 
+
 class NormConvTranspose1d(nn.Module):
     def __init__(
         self,
@@ -136,23 +140,27 @@ def __init__(
     def __call__(self, xs: mx.array) -> mx.array:
         return self.convtr(xs)
 
+
 def get_extra_padding_for_conv1d(
     xs: mx.array,
     ksize: int,
-    stride: int,
+    stride: int,
     padding_total: int,
 ) -> int:
     l = xs.shape[-1]
     nframes = max(l + padding_total - ksize, 0) / stride + 1.0
     ideal_len = (int(math.ceil(nframes)) - 1) * stride + ksize - padding_total
     return max(0, ideal_len - l)
 
+
 def unpad1d(xs: mx.array, unpad_l: int, unpad_r: int) -> mx.array:
     left = unpad_l
     right = xs.shape[-1] - unpad_r
     return xs[..., left:right]
 
 # TODO(laurent): add a streaming module abstract class?
+
+
 class StreamableConv1d(nn.Module):
     def __init__(
         self,
@@ -218,6 +226,7 @@ def __call__(self, xs: mx.array) -> mx.array:
             self._prev_xs = xs
             return mx.zeros((b, self._out_channels, 0))
 
+
 class StreamableConvTranspose1d(nn.Module):
     def __init__(
         self,
@@ -259,9 +268,10 @@ def __call__(self, xs: mx.array) -> mx.array:
             ys1, ys2 = ys[..., :pt] + prev_ys, ys[..., pt:]
             ys = mx.concat([ys1, ys2], axis=-1)
         invalid_steps = self._ksize - stride
-        ys, self._prev_ys = ys[..., :ot-invalid_steps], ys[..., ot-invalid_steps:]
+        ys, self._prev_ys = ys[..., :ot - invalid_steps], ys[..., ot - invalid_steps:]
         return ys
 
+
 class ConvDownsample1d(nn.Module):
     def __init__(
         self,
@@ -272,7 +282,7 @@ def __init__(
         self.conv = StreamableConv1d(
             in_channels=dim,
             out_channels=dim,
-            ksize=2*stride,
+            ksize=2 * stride,
             stride=stride,
             dilation=1,
             groups=1,
@@ -287,6 +297,7 @@ def reset_state(self):
     def __call__(self, xs: mx.array) -> mx.array:
         return self.conv(xs)
 
+
 class ConvTrUpsample1d(nn.Module):
     def __init__(
         self,
@@ -297,9 +308,9 @@ def __init__(
         self.convtr = StreamableConvTranspose1d(
             in_channels=dim,
             out_channels=dim,
-            ksize=2*stride,
+            ksize=2 * stride,
             stride=stride,
-            groups=dim, # TODO: hopefully someday this will be fixed.
+            groups=dim,  # TODO: hopefully someday this will be fixed.
             bias=False,
             causal=causal,
         )
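
For context: the helper touched above, get_extra_padding_for_conv1d, computes how many trailing samples must be appended so that a strided convolution consumes its whole input. A minimal standalone sketch of the same arithmetic, with the mx.array argument replaced by a plain length for illustration:

    import math

    def extra_padding_for_conv1d(l: int, ksize: int, stride: int, padding_total: int) -> int:
        # Fractional number of frames the convolution would produce for length l.
        nframes = max(l + padding_total - ksize, 0) / stride + 1.0
        # Smallest length >= l that yields a whole number of frames.
        ideal_len = (int(math.ceil(nframes)) - 1) * stride + ksize - padding_total
        return max(0, ideal_len - l)

    # Worked example with ksize=4, stride=2, padding_total=2:
    # l=9 gives nframes = 7 / 2 + 1.0 = 4.5 -> ceil 5, ideal_len = 4 * 2 + 4 - 2 = 10.
    assert extra_padding_for_conv1d(9, 4, 2, 2) == 1   # one sample of padding needed
    assert extra_padding_for_conv1d(10, 4, 2, 2) == 0  # already frame-aligned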

moshi_mlx/moshi_mlx/modules/quantization.py (10 additions, 5 deletions)
@@ -7,6 +7,7 @@
 import mlx.core as mx
 import mlx.nn as nn
 
+
 class EuclideanCodebook(nn.Module):
     def __init__(self, dim: int, codebook_size: int):
         super().__init__()
@@ -36,6 +37,7 @@ def decode(self, xs: mx.array) -> mx.array:
         target_shape = list(xs.shape) + [self._dim]
         return mx.take(self.embedding, xs.flatten()).reshape(target_shape)
 
+
 class VectorQuantization(nn.Module):
     def __init__(self, dim: int, codebook_size: int, codebook_dim: int | None):
         super().__init__()
@@ -60,6 +62,7 @@ def decode(self, xs: mx.array) -> mx.array:
         xs = self.project_out(xs)
         return xs.swapaxes(-1, -2)
 
+
 class ResidualVectorQuantization(nn.Module):
     def __init__(self, nq: int, dim: int, codebook_size: int, codebook_dim: int | None):
         super().__init__()
@@ -90,6 +93,7 @@ def decode(self, xs: mx.array) -> mx.array:
             quantized = quantized + self.layers[i].decode(xs[i])
         return quantized
 
+
 class ResidualVectorQuantizer(nn.Module):
     def __init__(
         self,
@@ -112,10 +116,10 @@ def __init__(
         else:
             self.output_proj = Conv1d(dim, output_dim, 1, bias=False)
         self.vq = ResidualVectorQuantization(
-            nq=nq,
-            dim=dim,
-            codebook_size=bins,
-            codebook_dim=None,
+            nq=nq,
+            dim=dim,
+            codebook_size=bins,
+            codebook_dim=None,
         )
 
     def encode(self, xs: mx.array) -> mx.array:
@@ -130,6 +134,7 @@ def decode(self, xs: mx.array) -> mx.array:
         quantized = self.output_proj(quantized)
         return quantized
 
+
 class SplitResidualVectorQuantizer(nn.Module):
     def __init__(
         self,
@@ -153,7 +158,7 @@ def __init__(
             dim=dim,
             input_dim=input_dim,
             output_dim=output_dim,
-            nq=nq-1,
+            nq=nq - 1,
             bins=bins,
             force_projection=True
         )
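
The quantization hunks are whitespace-only, but they pass through the codec's residual vector quantization stack. For context, a minimal MLX sketch of the Euclidean codebook lookup: the decode mirrors EuclideanCodebook.decode above (with an explicit axis=0 on mx.take to make the row gather unambiguous), while the nearest-neighbour encode is an assumption, since only decode appears in this diff:

    import mlx.core as mx

    dim, codebook_size = 4, 8  # hypothetical sizes, for illustration only
    embedding = mx.random.normal((codebook_size, dim))

    def encode(xs: mx.array) -> mx.array:
        # Index of the closest codebook row under squared Euclidean distance.
        diffs = mx.expand_dims(xs, 1) - mx.expand_dims(embedding, 0)
        return mx.argmin((diffs ** 2).sum(axis=-1), axis=-1)

    def decode(idx: mx.array) -> mx.array:
        # Gather codebook rows, then restore the input shape plus a dim axis.
        target_shape = list(idx.shape) + [dim]
        return mx.take(embedding, idx.flatten(), axis=0).reshape(target_shape)

    xs = mx.random.normal((3, dim))
    assert decode(encode(xs)).shape == (3, dim)

In ResidualVectorQuantization above, nq such codebooks are chained, each quantizing the residual left by the previous one, and decode simply sums the per-layer decodes.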

moshi_mlx/moshi_mlx/modules/seanet.py (6 additions, 0 deletions)
@@ -25,6 +25,7 @@ class SeanetConfig:
     true_skip: bool
     compress: int
 
+
 class SeanetResnetBlock(nn.Module):
     def __init__(self, cfg: SeanetConfig, dim: int, ksizes_and_dilations: list):
         super().__init__()
@@ -79,6 +80,7 @@ def __call__(self, xs: mx.array) -> mx.array:
         xs = xs + self.shortcut(residual)
         return xs
 
+
 class EncoderLayer(nn.Module):
     def __init__(self, cfg: SeanetConfig, ratio: int, mult: int):
         super().__init__()
@@ -115,6 +117,7 @@ def __call__(self, xs: mx.array) -> mx.array:
             xs = r(xs)
         return self.downsample(nn.elu(xs, alpha=1.0))
 
+
 class SeanetEncoder(nn.Module):
     def __init__(self, cfg: SeanetConfig):
         super().__init__()
@@ -160,6 +163,7 @@ def __call__(self, xs: mx.array) -> mx.array:
         xs = nn.elu(xs, alpha=1.0)
         return self.final_conv1d(xs)
 
+
 class DecoderLayer(nn.Module):
     def __init__(self, cfg: SeanetConfig, ratio: int, mult: int):
         super().__init__()
@@ -185,6 +189,7 @@ def __call__(self, xs: mx.array) -> mx.array:
             xs = r(xs)
         return xs
 
+
 class SeanetDecoder(nn.Module):
     def __init__(self, cfg: SeanetConfig):
         super().__init__()
@@ -230,6 +235,7 @@ def __call__(self, xs: mx.array) -> mx.array:
         xs = nn.elu(xs, alpha=1.0)
         return self.final_conv1d(xs)
 
+
 class Seanet(nn.Module):
     def __init__(self, cfg: SeanetConfig):
         super().__init__()
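
The seanet.py changes add blank lines only, but the hunks show the layer pattern: a few residual blocks, then an ELU followed by a strided resampling convolution. A trimmed, hypothetical stand-in for EncoderLayer, using plain mlx.nn.Conv1d instead of the streamable variants (MLX convolutions take (batch, length, channels) inputs), with the residual add folded in from SeanetResnetBlock:

    import mlx.core as mx
    import mlx.nn as nn

    class TinyEncoderLayer(nn.Module):
        def __init__(self, dim: int, stride: int):
            super().__init__()
            # Stand-in residual blocks; kernel_size=2 * stride on the
            # downsampling conv matches ConvDownsample1d in conv.py above.
            self.residuals = [nn.Conv1d(dim, dim, kernel_size=3, padding=1)]
            self.downsample = nn.Conv1d(dim, dim, kernel_size=2 * stride, stride=stride)

        def __call__(self, xs: mx.array) -> mx.array:
            for r in self.residuals:
                xs = xs + r(nn.elu(xs, alpha=1.0))  # residual connection
            return self.downsample(nn.elu(xs, alpha=1.0))

    layer = TinyEncoderLayer(dim=8, stride=2)
    ys = layer(mx.zeros((1, 16, 8)))  # length shrinks to (16 - 4) // 2 + 1 = 7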
