Skip to content

Commit 998adca

Browse files
committed
Start adding some quantization.
1 parent 1395d2f commit 998adca

File tree

3 files changed

+44
-1
lines changed

3 files changed

+44
-1
lines changed

moshi_mlx/moshi_mlx/modules/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""Modules used for building the models."""
66

77
from .conv import Conv1d, ConvTranspose1d, StreamableConv1d, StreamableConvTranspose1d, NormConv1d, NormConvTranspose1d, ConvDownsample1d, ConvTrUpsample1d
8-
from .seanet import SeanetConfig, Seanet
8+
from .quantization import SplitResidualVectorQuantizer
9+
from .seanet import SeanetConfig, SeanetEncoder, SeanetDecoder
910
from .kv_cache import KVCache, RotatingKVCache
1011
from .transformer import Transformer, TransformerConfig
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright (c) Kyutai, all rights reserved.
2+
# This source code is licensed under the license found in the
3+
# LICENSE file in the root directory of this source tree.
4+
5+
import mlx.core as mx
6+
import mlx.nn as nn
7+
8+
class EuclideanCodebook(nn.Module):
9+
def __init__(self, dim: int, codebook_size: int):
10+
super().__init__()
11+
self._epsilon = 1e-5
12+
self._dim = dim
13+
14+
def __call__(self, xs: mx.array) -> mx.array:
15+
return xs
16+
17+
class VectorQuantization(nn.Module):
18+
def __init__(self):
19+
super().__init__()
20+
21+
def __call__(self, xs: mx.array) -> mx.array:
22+
return xs
23+
24+
class ResidualVectorQuantization(nn.Module):
25+
def __init__(self):
26+
super().__init__()
27+
28+
def __call__(self, xs: mx.array) -> mx.array:
29+
return xs
30+
31+
class SplitResidualVectorQuantizer(nn.Module):
32+
def __init__(self):
33+
super().__init__()
34+
35+
def __call__(self, xs: mx.array) -> mx.array:
36+
return xs

moshi_mlx/moshi_mlx/modules/seanet.py

+6
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class SeanetConfig:
2727

2828
class SeanetResnetBlock(nn.Module):
2929
def __init__(self, cfg: SeanetConfig, dim: int, ksizes_and_dilations: list):
30+
super().__init__()
3031
block = []
3132
hidden = dim // cfg.compress
3233
for i, (ksize, dilation) in enumerate(ksizes_and_dilations):
@@ -80,6 +81,7 @@ def __call__(self, xs: mx.array) -> mx.array:
8081

8182
class EncoderLayer(nn.Module):
8283
def __init__(self, cfg: SeanetConfig, ratio: int, mult: int):
84+
super().__init__()
8385
residuals = []
8486
dilation = 1
8587
for _ in range(cfg.nresidual_layers):
@@ -115,6 +117,7 @@ def __call__(self, xs: mx.array) -> mx.array:
115117

116118
class SeanetEncoder(nn.Module):
117119
def __init__(self, cfg: SeanetConfig):
120+
super().__init__()
118121
mult = 1
119122
self.init_conv1d = StreamableConv1d(
120123
in_channels=cfg.channels,
@@ -159,6 +162,7 @@ def __call__(self, xs: mx.array) -> mx.array:
159162

160163
class DecoderLayer(nn.Module):
161164
def __init__(self, cfg: SeanetConfig, ratio: int, mult: int):
165+
super().__init__()
162166
self.upsample = StreamableConvTranspose1d(
163167
in_channels=mult * cfg.nfilters,
164168
out_channels=mult * cfg.nfilters // 2,
@@ -183,6 +187,7 @@ def __call__(self, xs: mx.array) -> mx.array:
183187

184188
class SeanetDecoder(nn.Module):
185189
def __init__(self, cfg: SeanetConfig):
190+
super().__init__()
186191
mult = 1 << len(cfg.ratios)
187192
self.init_conv1d = StreamableConv1d(
188193
in_channels=cfg.dimension,
@@ -227,5 +232,6 @@ def __call__(self, xs: mx.array) -> mx.array:
227232

228233
class Seanet(nn.Module):
229234
def __init__(self, cfg: SeanetConfig):
235+
super().__init__()
230236
self.encoder = SeanetEncoder(cfg)
231237
self.decoder = SeanetDecoder(cfg)

0 commit comments

Comments
 (0)