From da3d12773fc515bd5c9e73f0654a92e80880f7f6 Mon Sep 17 00:00:00 2001 From: Laurent Date: Thu, 20 Feb 2025 10:58:04 +0100 Subject: [PATCH 01/22] Start adding mimi to the mlx codebase. --- moshi_mlx/moshi_mlx/modules/__init__.py | 1 + moshi_mlx/moshi_mlx/modules/conv.py | 219 ++++++++++++++++++++++++ 2 files changed, 220 insertions(+) create mode 100644 moshi_mlx/moshi_mlx/modules/conv.py diff --git a/moshi_mlx/moshi_mlx/modules/__init__.py b/moshi_mlx/moshi_mlx/modules/__init__.py index 10a4218e..e005b182 100644 --- a/moshi_mlx/moshi_mlx/modules/__init__.py +++ b/moshi_mlx/moshi_mlx/modules/__init__.py @@ -4,5 +4,6 @@ # flake8: noqa """Modules used for building the models.""" +from .conv import Conv1d, ConvTranspose1d from .kv_cache import KVCache, RotatingKVCache from .transformer import Transformer, TransformerConfig diff --git a/moshi_mlx/moshi_mlx/modules/conv.py b/moshi_mlx/moshi_mlx/modules/conv.py new file mode 100644 index 00000000..904abfe3 --- /dev/null +++ b/moshi_mlx/moshi_mlx/modules/conv.py @@ -0,0 +1,219 @@ +# Copyright (c) Kyutai, all rights reserved. +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +import mlx.core as mx +import mlx.nn as nn + +class Conv1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + ksize: int, + stride: int = 1, + padding: int = 0, + groups: int = 1, + dilation: int = 1, + bias: bool = True + ): + super().__init__() + nn.Conv1d + scale = 1 / (in_channels * ksize) + self.weight = mx.random.uniform( + low=-scale, + high=scale, + shape=(out_channels, ksize, in_channels // groups), + ) + self.bias = None + if bias: + self.bias = mx.zeros(out_channels) + self._padding = padding + self._groups = groups + self._stride = stride + self._dilation = dilation + + def __call__(self, xs: mx.array) -> mx.array: + # MLX uses NLC whereas pytorch/candle use NCL + y = mx.conv1d( + xs.swapaxes(-1, -2), + self.weight, + stride=self._stride, + padding=self._padding, + dilation=self._dilation, + groups=self._groups + ) + if self.bias is not None: + y = y + self.bias + return y.swapaxes(-1, -2) + +class ConvTranspose1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + ksize: int, + stride: int = 1, + padding: int = 0, + groups: int = 1, + bias: bool = True + ): + super().__init__() + nn.Conv1d + scale = 1 / (in_channels * ksize) + self.weight = mx.random.uniform( + low=-scale, + high=scale, + shape=(out_channels // groups, ksize, in_channels), + ) + self.bias = None + if bias: + self.bias = mx.zeros(out_channels) + self._padding = padding + self._groups = groups + self._stride = stride + + def __call__(self, xs: mx.array) -> mx.array: + y = mx.conv_transpose1d( + xs.swapaxes(-1, -2), + self.weight, + stride=self._stride, + padding=self._padding, + groups=self._groups, + ) + if self.bias is not None: + y = y + self.bias + return y + +class NormConv1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + ksize: int, + stride: int = 1, + padding: int = 0, + groups: int = 1, + dilation: int = 1, + bias: bool = True, + ): + self.conv = Conv1d( + in_channels, + out_channels, + ksize, + stride=stride, + padding=padding, + groups=groups, + dilation=dilation, + bias=bias + ) + + def __call__(self, xs: mx.array) -> mx.array: + return self.conv(xs) + +class NormConvTranspose1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + ksize: int, + stride: int = 1, + padding: int = 0, + groups: 
int = 1, + bias: bool = True, + ): + self.convtr = ConvTranspose1d( + in_channels, + out_channels, + ksize, + stride=stride, + padding=padding, + groups=groups, + bias=bias + ) + + def __call__(self, xs: mx.array) -> mx.array: + return self.convtr(xs) + +def get_extra_padding_for_conv1d( + xs: mx.array, + ksize: int, + stride: int, + padding_total: int, +) -> int: + l = xs.shape[-1] + nframes = max(l + padding_total - ksize, 0) / stride + 1.0 + ideal_len = (int(math.ceil(nframes)) - 1) * stride + ksize - padding_total + return max(0, ideal_len - l) + +def unpad1d(xs: mx.array, unpad_l: int, unpad_r: int) -> mx.array: + left = unpad_l + right = xs.shape[-1] - unpad_r + return xs[..., left:right] + +# TODO(laurent): add a streaming module abstract class? +class StreamableConv1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + ksize: int, + stride: int, + dilation: int, + groups: int, + bias: bool, + causal: bool, + pad_mode: str, + ): + self._causal = causal + self._pad_mode = pad_mode + self._ksize = ksize + self.conv = NormConv1d( + in_channels, + out_channels, + ksize, + stride=stride, + groups=groups, + dilation=dilation, + bias=bias, + ) + self._prev_xs = None + self._left_pad_applied = False + self._out_channels = out_channels + + def reset(self): + self._prev_xs = None + self._left_pad_applied = False + + def __call__(self, xs: mx.array) -> mx.array: + b, _, l = xs.shape + if l == 0: + return mx.zeros((b, self._out_channels, 0)) + stride = self.conv.conv._stride + dilation = self.conv.conv._dilation + ksize = (self._ksize - 1) * dilation + 1 + if not self._left_pad_applied: + self._left_pad_applied + padding_total = ksize - stride + xs = mx.pad( + xs, + pad_width=((0, 0), (0, 0), (padding_total, 0)), + mode=self._pad_mode + ) + if self._prev_xs is not None: + xs = mx.concat([self._prev_xs, xs], axis=-1) + l = xs.shape[-1] + nframes = max(l + stride - ksize, 0) // stride + if nframes > 0: + offset = nframes * stride + self._prev_xs = xs[..., offset:] + in_l = (nframes - 1) * stride + ksize + if in_l > 0: + xs = xs[..., 0:in_l] + return self.conv(xs) + else: + return mx.zeros((b, self._out_channels, 0)) + else: + self._prev_xs = xs + return mx.zeros((b, self._out_channels, 0)) From a0971b64be7209881697c2bda73d864ebecfd2a5 Mon Sep 17 00:00:00 2001 From: Laurent Date: Thu, 20 Feb 2025 11:17:32 +0100 Subject: [PATCH 02/22] Finish the conv module. 
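
Note: the Conv1d wrapper from the previous patch keeps the PyTorch/candle NCL layout at its boundary while MLX's mx.conv1d works on NLC inputs with (out_channels, ksize, in_channels) weights. A minimal sketch of that round-trip, outside the patch and with made-up sizes:

    import mlx.core as mx

    xs = mx.zeros((2, 4, 16))                  # NCL, as callers provide it
    weight = mx.zeros((8, 3, 4))               # (out_channels, ksize, in_channels)
    ys = mx.conv1d(xs.swapaxes(-1, -2), weight, stride=1, padding=1)  # NLC internally
    ys = ys.swapaxes(-1, -2)                   # back to NCL for the caller
    print(ys.shape)                            # (2, 8, 16)
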
--- moshi_mlx/moshi_mlx/modules/conv.py | 94 ++++++++++++++++++++++++++++- 1 file changed, 93 insertions(+), 1 deletion(-) diff --git a/moshi_mlx/moshi_mlx/modules/conv.py b/moshi_mlx/moshi_mlx/modules/conv.py index 904abfe3..7d93adf8 100644 --- a/moshi_mlx/moshi_mlx/modules/conv.py +++ b/moshi_mlx/moshi_mlx/modules/conv.py @@ -182,7 +182,7 @@ def __init__( self._left_pad_applied = False self._out_channels = out_channels - def reset(self): + def reset_state(self): self._prev_xs = None self._left_pad_applied = False @@ -217,3 +217,95 @@ def __call__(self, xs: mx.array) -> mx.array: else: self._prev_xs = xs return mx.zeros((b, self._out_channels, 0)) + +class StreamableConvTranspose1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + ksize: int, + stride: int, + groups: int, + bias: bool, + causal: bool, + ): + self._causal = causal + self._ksize = ksize + self.convtr = NormConvTranspose1d( + in_channels, + out_channels, + ksize, + stride=stride, + groups=groups, + bias=bias, + ) + self._prev_ys = None + + def reset_state(self): + self._prev_ys = None + + def __call__(self, xs: mx.array) -> mx.array: + b, _, l = xs.shape + if l == 0: + return mx.zeros((b, self._out_channels, 0)) + stride = self.convtr.convtr._stride + ys = self.convtr(xs) + ot = ys.shape[-1] + if self._prev_ys is not None: + prev_ys = self._prev_ys + pt = prev_ys.shape[-1] + if self.convtr.convtr.bias is not None: + prev_ys = prev_ys - self.convtr.convtr.bias[None, :, None] + ys1, ys2 = ys[..., :pt] + prev_ys, ys[..., pt:] + ys = mx.concat([ys1, ys2], axis=-1) + invalid_steps = self._ksize - stride + ys, self._prev_ys = ys[..., :ot-invalid_steps], ys[..., ot-invalid_steps] + return ys + +class ConvDownsample1d(nn.Module): + def __init__( + self, + stride: int, + dim: int, + causal: bool + ): + self.conv = StreamableConv1d( + in_channels=dim, + out_channels=dim, + ksize=2*stride, + stride=stride, + dilation=1, + groups=1, + bias=False, + causal=causal, + pad_mode="edge", + ) + + def reset_state(self): + self.conv.reset_state() + + def __call__(self, xs: mx.array) -> mx.array: + return self.conv(xs) + +class ConvTrUpsample1d(nn.Module): + def __init__( + self, + stride: int, + dim: int, + causal: bool + ): + self.convtr = StreamableConvTranspose1d( + in_channels=dim, + out_channels=dim, + ksize=2*stride, + stride=stride, + groups=dim, # TODO: hopefully someday this will be fixed. + bias=False, + causal=causal, + ) + + def reset_state(self): + self.convtr.reset_state() + + def __call__(self, xs: mx.array) -> mx.array: + return self.convtr(xs) From 09fdef1b1e919f5fffe022f9c8a6b855363317bc Mon Sep 17 00:00:00 2001 From: Laurent Date: Thu, 20 Feb 2025 11:41:53 +0100 Subject: [PATCH 03/22] Start adding the seanet module. 
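
Note: the streaming convolutions above carve complete output frames out of a growing sample buffer and keep the unconsumed tail for the next chunk. A plain-Python sketch of that arithmetic (illustrative only; the helper name is mine, not from the patch):

    def streaming_split(avail: int, ksize: int, stride: int) -> tuple[int, int, int]:
        nframes = max(avail + stride - ksize, 0) // stride   # complete output frames
        consumed = nframes * stride                          # samples dropped from the buffer
        read = (nframes - 1) * stride + ksize if nframes > 0 else 0  # samples actually convolved
        return nframes, read, avail - consumed               # frames, used length, carry-over

    # kernel 7, stride 1: 10 buffered samples -> 4 frames, convolve the first 10, keep the last 6
    print(streaming_split(10, ksize=7, stride=1))  # (4, 10, 6)
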
--- moshi_mlx/moshi_mlx/modules/__init__.py | 3 +- moshi_mlx/moshi_mlx/modules/seanet.py | 156 ++++++++++++++++++++++++ 2 files changed, 158 insertions(+), 1 deletion(-) create mode 100644 moshi_mlx/moshi_mlx/modules/seanet.py diff --git a/moshi_mlx/moshi_mlx/modules/__init__.py b/moshi_mlx/moshi_mlx/modules/__init__.py index e005b182..3bd00ec7 100644 --- a/moshi_mlx/moshi_mlx/modules/__init__.py +++ b/moshi_mlx/moshi_mlx/modules/__init__.py @@ -4,6 +4,7 @@ # flake8: noqa """Modules used for building the models.""" -from .conv import Conv1d, ConvTranspose1d +from .conv import Conv1d, ConvTranspose1d, StreamableConv1d, StreamableConvTranspose1d, NormConv1d, NormConvTranspose1d, ConvDownsample1d, ConvTrUpsample1d +from .seanet import SeanetConfig, Seanet from .kv_cache import KVCache, RotatingKVCache from .transformer import Transformer, TransformerConfig diff --git a/moshi_mlx/moshi_mlx/modules/seanet.py b/moshi_mlx/moshi_mlx/modules/seanet.py new file mode 100644 index 00000000..7fbb8f45 --- /dev/null +++ b/moshi_mlx/moshi_mlx/modules/seanet.py @@ -0,0 +1,156 @@ +# Copyright (c) Kyutai, all rights reserved. +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from .conv import StreamableConv1d, StreamableConvTranspose1d + +import mlx.core as mx +import mlx.nn as nn + + +@dataclass +class SeanetConfig: + dimension: int + channels: int + causal: bool + nfilters: int + nresidual_layers: int + ratios: list[int] + ksize: int + residual_ksize: int + last_ksize: int + dilation_base: int + pad_mode: str + true_skip: bool + compress: int + +class SeanetResnetBlock(nn.Module): + pass + + def reset_state(self): + pass + +class EncoderLayer(nn.Module): + def __init__(self, cfg: SeanetConfig, ratio: int, mult: int): + pass + + def reset_state(self): + pass + +class SeanetEncoder(nn.Module): + def __init__(self, cfg: SeanetConfig): + mult = 1 + self.init_conv1d = StreamableConv1d( + in_channels=cfg.channels, + out_channels=mult * cfg.nfilters, + ksize=cfg.ksize, + stride=1, + dilation=1, + groups=1, + bias=True, + causal=cfg.causal, + pad_mode=cfg.pad_mode, + ) + layers = [] + for ratio in reversed(cfg.ratios): + layers.append(EncoderLayer(cfg, ratio=ratio, mult=mult)) + mult *= 2 + self.layers = layers + self.final_conv1d = StreamableConv1d( + in_channels=mult * cfg.nfilters, + out_channels=cfg.dimension, + ksize=cfg.last_ksize, + stride=1, + dilation=1, + groups=1, + bias=True, + causal=cfg.causal, + pad_mode=cfg.pad_mode, + ) + + def reset_state(self): + self.init_conv1d.reset_state() + self.final_conv1d.reset_state() + for layer in self.layers: + layer.reset_state() + + def __call__(self, xs: mx.array) -> mx.array: + xs = self.init_conv1d(xs) + for layer in self.layers: + xs = layer(xs) + xs = nn.elu(xs, alpha=1.0) + return self.final_conv1d(xs) + +class DecoderLayer(nn.Module): + def __init__(self, cfg: SeanetConfig, ratio: int, mult: int): + self.upsample = StreamableConvTranspose1d( + in_channels=mult * cfg.nfilters, + out_channels=mult * cfg.nfilters // 2, + ksize=ratio * 2, + stride=ratio, + groups=1, + bias=True, + causal=cfg.causal, + ) + self.residuals = [] + + def reset_state(self): + self.upsample.reset_state() + for r in self.residuals: + r.reset_state() + + def __call__(self, xs: mx.array) -> mx.array: + xs = self.upsample(nn.elu(xs, alpha=1.0)) + for r in self.residuals: + xs = r(xs) + return xs + +class SeanetDecoder(nn.Module): + def __init__(self, cfg: SeanetConfig): + mult = 1 
<< len(cfg.ratios) + self.init_conv1d = StreamableConv1d( + in_channels=cfg.dimension, + out_channels=mult * cfg.nfilters, + ksize=cfg.ksize, + stride=1, + dilation=1, + groups=1, + bias=True, + causal=cfg.causal, + pad_mode=cfg.pad_mode, + ) + layers = [] + for ratio in cfg.ratios: + layers.append(DecoderLayer(cfg, ratio=ratio, mult=mult)) + mult //= 2 + self.layers = layers + self.final_conv1d = StreamableConv1d( + in_channels=cfg.nfilters, + out_channels=cfg.channels, + ksize=cfg.last_ksize, + stride=1, + dilation=1, + groups=1, + bias=True, + causal=cfg.causal, + pad_mode=cfg.pad_mode, + ) + + def reset_state(self): + self.init_conv1d.reset_state() + self.final_conv1d.reset_state() + for layer in self.layers: + layer.reset_state() + + def __call__(self, xs: mx.array) -> mx.array: + xs = self.init_conv1d(xs) + for layer in self.layers: + xs = layer(xs) + xs = nn.elu(xs, alpha=1.0) + return self.final_conv1d(xs) + +class Seanet(nn.Module): + def __init__(self, cfg: SeanetConfig): + self.encoder = SeanetEncoder(cfg) + self.decoder = SeanetDecoder(cfg) From fc7c728972328ec249e24f02a6c430fdefd43d16 Mon Sep 17 00:00:00 2001 From: Laurent Date: Thu, 20 Feb 2025 13:00:29 +0100 Subject: [PATCH 04/22] More seanet. --- moshi_mlx/moshi_mlx/modules/seanet.py | 54 ++++++++++++++++++++++++++- 1 file changed, 52 insertions(+), 2 deletions(-) diff --git a/moshi_mlx/moshi_mlx/modules/seanet.py b/moshi_mlx/moshi_mlx/modules/seanet.py index 7fbb8f45..dac749c6 100644 --- a/moshi_mlx/moshi_mlx/modules/seanet.py +++ b/moshi_mlx/moshi_mlx/modules/seanet.py @@ -26,10 +26,57 @@ class SeanetConfig: compress: int class SeanetResnetBlock(nn.Module): - pass + def __init__(self, cfg: SeanetConfig, dim: int, ksizes_and_dilations: list): + block = [] + hidden = dim // cfg.compress + for i, (ksize, dilation) in enumerate(ksizes_and_dilations): + in_channels = dim if i == 0 else hidden + out_channels = dim if i == len(ksizes_and_dilations) - 1 else hidden + c = StreamableConv1d( + in_channels=in_channels, + out_channels=out_channels, + ksize=ksize, + stride=1, + dilation=dilation, + groups=1, + bias=True, + causal=cfg.causal, + pad_mode=cfg.pad_mode, + ) + block.append(c) + self.block = block + + if cfg.true_skip: + self.shortcut = None + else: + self.shortcut = StreamableConv1d( + in_channels=dim, + out_channels=dim, + ksize=1, + stride=1, + dilation=1, + groups=1, + bias=True, + causal=cfg.causal, + pad_mode=cfg.pad_mode, + ) def reset_state(self): - pass + if self.shortcut is not None: + self.shortcut.reset_state() + for b in self.block: + b.reset_state() + + def __call__(self, xs: mx.array) -> mx.array: + residual = xs + for b in self.block: + xs = b(nn.elu(xs, alpha=1.0)) + # TODO(laurent): we might need some streaming additions below. + if self.shortcut is None: + xs = xs + residual + else: + xs = xs + self.shortcut(residual) + return xs class EncoderLayer(nn.Module): def __init__(self, cfg: SeanetConfig, ratio: int, mult: int): @@ -38,6 +85,9 @@ def __init__(self, cfg: SeanetConfig, ratio: int, mult: int): def reset_state(self): pass + def __call__(self, xs: mx.array) -> mx.array: + return xs + class SeanetEncoder(nn.Module): def __init__(self, cfg: SeanetConfig): mult = 1 From 1395d2f00e321bb0892b107a39214565ff2fa5de Mon Sep 17 00:00:00 2001 From: Laurent Date: Thu, 20 Feb 2025 14:40:58 +0100 Subject: [PATCH 05/22] Add to seanet. 
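
Note: to make the channel bookkeeping in the encoder concrete, here is a small illustrative walk through the layer widths, using the filter count and ratios that the later mimi config picks (those values are assumed from that config, not from this patch):

    nfilters = 64            # cfg.nfilters in the later mimi config
    ratios = [8, 6, 5, 4]    # cfg.ratios in the later mimi config
    mult = 1
    for ratio in reversed(ratios):
        print(f"encoder layer: {mult * nfilters} -> {2 * mult * nfilters} channels, stride {ratio}")
        mult *= 2
    print("final_conv1d input channels:", mult * nfilters)  # 1024, projected to cfg.dimension
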
--- moshi_mlx/moshi_mlx/modules/seanet.py | 31 ++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/moshi_mlx/moshi_mlx/modules/seanet.py b/moshi_mlx/moshi_mlx/modules/seanet.py index dac749c6..74532881 100644 --- a/moshi_mlx/moshi_mlx/modules/seanet.py +++ b/moshi_mlx/moshi_mlx/modules/seanet.py @@ -80,13 +80,38 @@ def __call__(self, xs: mx.array) -> mx.array: class EncoderLayer(nn.Module): def __init__(self, cfg: SeanetConfig, ratio: int, mult: int): - pass + residuals = [] + dilation = 1 + for _ in range(cfg.nresidual_layers): + b = SeanetResnetBlock( + cfg, + dim=mult * cfg.nfilters, + ksizes_and_dilations=[(cfg.residual_ksize, dilation), (1, 1)], + ) + residuals.append(b) + dilation *= cfg.dilation_base + self.residuals = residuals + self.downsample = StreamableConv1d( + in_channels=mult * cfg.nfilters, + out_channels=mult * cfg.nfilters * 2, + ksize=ratio * 2, + stride=ratio, + dilation=1, + groups=1, + bias=True, + causal=True, + pad_mode=cfg.pad_mode, + ) def reset_state(self): - pass + self.downsample.reset_state() + for r in self.residuals: + r.reset_state() def __call__(self, xs: mx.array) -> mx.array: - return xs + for r in self.residuals: + xs = r(xs) + return self.downsample(nn.elu(xs, alpha=1.0)) class SeanetEncoder(nn.Module): def __init__(self, cfg: SeanetConfig): From 998adca5c2746943f178491bf81878bddcdced3b Mon Sep 17 00:00:00 2001 From: Laurent Date: Thu, 20 Feb 2025 14:46:10 +0100 Subject: [PATCH 06/22] Start adding some quantization. --- moshi_mlx/moshi_mlx/modules/__init__.py | 3 +- moshi_mlx/moshi_mlx/modules/quantization.py | 36 +++++++++++++++++++++ moshi_mlx/moshi_mlx/modules/seanet.py | 6 ++++ 3 files changed, 44 insertions(+), 1 deletion(-) create mode 100644 moshi_mlx/moshi_mlx/modules/quantization.py diff --git a/moshi_mlx/moshi_mlx/modules/__init__.py b/moshi_mlx/moshi_mlx/modules/__init__.py index 3bd00ec7..275e5564 100644 --- a/moshi_mlx/moshi_mlx/modules/__init__.py +++ b/moshi_mlx/moshi_mlx/modules/__init__.py @@ -5,6 +5,7 @@ """Modules used for building the models.""" from .conv import Conv1d, ConvTranspose1d, StreamableConv1d, StreamableConvTranspose1d, NormConv1d, NormConvTranspose1d, ConvDownsample1d, ConvTrUpsample1d -from .seanet import SeanetConfig, Seanet +from .quantization import SplitResidualVectorQuantizer +from .seanet import SeanetConfig, SeanetEncoder, SeanetDecoder from .kv_cache import KVCache, RotatingKVCache from .transformer import Transformer, TransformerConfig diff --git a/moshi_mlx/moshi_mlx/modules/quantization.py b/moshi_mlx/moshi_mlx/modules/quantization.py new file mode 100644 index 00000000..29b25cbd --- /dev/null +++ b/moshi_mlx/moshi_mlx/modules/quantization.py @@ -0,0 +1,36 @@ +# Copyright (c) Kyutai, all rights reserved. +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import mlx.core as mx +import mlx.nn as nn + +class EuclideanCodebook(nn.Module): + def __init__(self, dim: int, codebook_size: int): + super().__init__() + self._epsilon = 1e-5 + self._dim = dim + + def __call__(self, xs: mx.array) -> mx.array: + return xs + +class VectorQuantization(nn.Module): + def __init__(self): + super().__init__() + + def __call__(self, xs: mx.array) -> mx.array: + return xs + +class ResidualVectorQuantization(nn.Module): + def __init__(self): + super().__init__() + + def __call__(self, xs: mx.array) -> mx.array: + return xs + +class SplitResidualVectorQuantizer(nn.Module): + def __init__(self): + super().__init__() + + def __call__(self, xs: mx.array) -> mx.array: + return xs diff --git a/moshi_mlx/moshi_mlx/modules/seanet.py b/moshi_mlx/moshi_mlx/modules/seanet.py index 74532881..6356c01c 100644 --- a/moshi_mlx/moshi_mlx/modules/seanet.py +++ b/moshi_mlx/moshi_mlx/modules/seanet.py @@ -27,6 +27,7 @@ class SeanetConfig: class SeanetResnetBlock(nn.Module): def __init__(self, cfg: SeanetConfig, dim: int, ksizes_and_dilations: list): + super().__init__() block = [] hidden = dim // cfg.compress for i, (ksize, dilation) in enumerate(ksizes_and_dilations): @@ -80,6 +81,7 @@ def __call__(self, xs: mx.array) -> mx.array: class EncoderLayer(nn.Module): def __init__(self, cfg: SeanetConfig, ratio: int, mult: int): + super().__init__() residuals = [] dilation = 1 for _ in range(cfg.nresidual_layers): @@ -115,6 +117,7 @@ def __call__(self, xs: mx.array) -> mx.array: class SeanetEncoder(nn.Module): def __init__(self, cfg: SeanetConfig): + super().__init__() mult = 1 self.init_conv1d = StreamableConv1d( in_channels=cfg.channels, @@ -159,6 +162,7 @@ def __call__(self, xs: mx.array) -> mx.array: class DecoderLayer(nn.Module): def __init__(self, cfg: SeanetConfig, ratio: int, mult: int): + super().__init__() self.upsample = StreamableConvTranspose1d( in_channels=mult * cfg.nfilters, out_channels=mult * cfg.nfilters // 2, @@ -183,6 +187,7 @@ def __call__(self, xs: mx.array) -> mx.array: class SeanetDecoder(nn.Module): def __init__(self, cfg: SeanetConfig): + super().__init__() mult = 1 << len(cfg.ratios) self.init_conv1d = StreamableConv1d( in_channels=cfg.dimension, @@ -227,5 +232,6 @@ def __call__(self, xs: mx.array) -> mx.array: class Seanet(nn.Module): def __init__(self, cfg: SeanetConfig): + super().__init__() self.encoder = SeanetEncoder(cfg) self.decoder = SeanetDecoder(cfg) From d0659a0d54dc53ff9565f93e07282ecb92a61cef Mon Sep 17 00:00:00 2001 From: Laurent Date: Thu, 20 Feb 2025 23:59:16 +0100 Subject: [PATCH 07/22] More quantization support. --- moshi_mlx/moshi_mlx/modules/quantization.py | 158 ++++++++++++++++++-- 1 file changed, 147 insertions(+), 11 deletions(-) diff --git a/moshi_mlx/moshi_mlx/modules/quantization.py b/moshi_mlx/moshi_mlx/modules/quantization.py index 29b25cbd..5fa433db 100644 --- a/moshi_mlx/moshi_mlx/modules/quantization.py +++ b/moshi_mlx/moshi_mlx/modules/quantization.py @@ -2,6 +2,8 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
+from .conv import Conv1d + import mlx.core as mx import mlx.nn as nn @@ -10,27 +12,161 @@ def __init__(self, dim: int, codebook_size: int): super().__init__() self._epsilon = 1e-5 self._dim = dim + self._initialized = mx.zeros([1], dtype=mx.float32) + self.embedding_sum = mx.zeros([codebook_size, dim], dtype=mx.float32) + self.cluster_usage = mx.zeros([codebook_size], dtype=mx.float32) + cluster_usage = mx.maximum(self.cluster_usage, self._epsilon)[:, None] + self.embedding = self.embedding_sum / cluster_usage + self.c2 = self.embedding.square().sum(axis=-1) / 2 + + def update(self, parameters: dict) -> nn.Module: + super().update(parameters) + cluster_usage = mx.maximum(self.cluster_usage, self._epsilon)[:, None] + self.embedding = self.embedding_sum / cluster_usage + self.c2 = self.embedding.square().sum(axis=-1) / 2 + return self + + def encode(self, xs: mx.array) -> mx.array: + target_shape = xs.shape[:-1] + xs = xs.flatten(start_axis=-2) + dot_prod = xs @ self.embedding.swapaxes(-1, -2) + return (self.c2 - dot_prod).min(axis=-1).reshape(target_shape) - def __call__(self, xs: mx.array) -> mx.array: - return xs + def decode(self, xs: mx.array) -> mx.array: + target_shape = list(xs.shape) + [self._dim] + return mx.take(self.embedding, xs.flatten()).reshape(target_shape) class VectorQuantization(nn.Module): - def __init__(self): + def __init__(self, dim: int, codebook_size: int, codebook_dim: int | None): super().__init__() + codebook_dim = dim if codebook_dim is None else codebook_dim + if dim == codebook_dim: + self.project_in = None + self.project_out = None + else: + self.project_in = nn.Linear(dim, codebook_dim) + self.project_out = nn.Linear(codebook_dim, dim) + self._codebook = EuclideanCodebook(dim=codebook_dim, codebook_size=codebook_size) - def __call__(self, xs: mx.array) -> mx.array: - return xs + def encode(self, xs: mx.array) -> mx.array: + xs = xs.swapaxes(-1, -2) + if self.project_in is not None: + xs = self.project_in(xs) + return self._codebook.encode(xs) + + def decode(self, xs: mx.array) -> mx.array: + xs = self._codebook.decode(xs) + if self.project_out is not None: + xs = self.project_out(xs) + return xs.swapaxes(-1, -2) class ResidualVectorQuantization(nn.Module): - def __init__(self): + def __init__(self, nq: int, dim: int, codebook_size: int, codebook_dim: int | None): + super().__init__() + layers = [] + for _ in range(nq): + vq = VectorQuantization( + dim=dim, + codebook_size=codebook_size, + codebook_dim=codebook_dim, + ) + layers.append(vq) + self.layers = layers + + def encode(self, xs: mx.array) -> mx.array: + codes = [] + residual = xs + for layer in self.layers: + indices = layer.encode(residual) + quantized = layer.decode(indices) + residual = residual - quantized + codes.append(indices) + return mx.concat(codes, axis=0) + + def decode(self, xs: mx.array) -> mx.array: + seq_len = xs.shape[0] + quantized = self.layers[0].decode(xs[0]) + for i in range(1, seq_len): + quantized = quantized + self.layers[i].decode(xs[i]) + return quantized + +class ResidualVectorQuantizer(nn.Module): + def __init__( + self, + dim: int, + input_dim: int | None, + output_dim: int | None, + nq: int, + bins: int, + force_projection: bool, + ): super().__init__() + input_dim = dim if input_dim is None else input_dim + output_dim = dim if output_dim is None else output_dim + if input_dim == dim and not force_projection: + self.input_proj = None + else: + self.input_proj = Conv1d(input_dim, dim, 1, bias=False) + if output_dim == dim and not force_projection: + self.output_proj = 
None + else: + self.output_proj = Conv1d(dim, output_dim, 1, bias=False) + self.vq = ResidualVectorQuantization( + nq=nq, + dim=dim, + codebook_size=bins, + codebook_dim=None, + ) - def __call__(self, xs: mx.array) -> mx.array: - return xs + def encode(self, xs: mx.array) -> mx.array: + if self.input_proj is not None: + xs = self.input_proj(xs) + return self.vq.encode(xs).transpose(0, 1) + + def decode(self, xs: mx.array) -> mx.array: + xs = xs.swapaxes(0, 1) + quantized = self.vq.decode(xs) + if self.output_proj is not None: + quantized = self.output_proj(quantized) + return quantized class SplitResidualVectorQuantizer(nn.Module): - def __init__(self): + def __init__( + self, + dim: int, + input_dim: int | None, + output_dim: int | None, + nq: int, + bins: int + ): super().__init__() + self._nq = nq + self.rvq_first = ResidualVectorQuantizer( + dim=dim, + input_dim=input_dim, + output_dim=output_dim, + nq=1, + bins=bins, + force_projection=True, + ) + self.rvq_rest = ResidualVectorQuantizer( + dim=dim, + input_dim=input_dim, + output_dim=output_dim, + nq=nq-1, + bins=bins, + force_projection=True + ) + + def encode(self, xs: mx.array) -> mx.array: + codes = self.rvq_first.encode(xs) + if self._nq > 1: + rest_codes = self.rvq_rest(xs) + codes = mx.concat([codes, rest_codes], axis=1) + return codes - def __call__(self, xs: mx.array) -> mx.array: - return xs + def decode(self, xs: mx.array) -> mx.array: + quantized = self.rvq_first.decode(xs[:, :1]) + if self._nq > 1: + quantized = quantized + self.rvq_rest.decode(xs[:, 1:]) + return quantized From a8ef5308da2b28fb125800f42d688346a86ba2e8 Mon Sep 17 00:00:00 2001 From: Laurent Date: Fri, 21 Feb 2025 10:13:44 +0100 Subject: [PATCH 08/22] Formatting. --- moshi_mlx/moshi_mlx/modules/conv.py | 21 ++++++++++++++++----- moshi_mlx/moshi_mlx/modules/quantization.py | 15 ++++++++++----- moshi_mlx/moshi_mlx/modules/seanet.py | 6 ++++++ 3 files changed, 32 insertions(+), 10 deletions(-) diff --git a/moshi_mlx/moshi_mlx/modules/conv.py b/moshi_mlx/moshi_mlx/modules/conv.py index 7d93adf8..e29ba812 100644 --- a/moshi_mlx/moshi_mlx/modules/conv.py +++ b/moshi_mlx/moshi_mlx/modules/conv.py @@ -6,6 +6,7 @@ import mlx.core as mx import mlx.nn as nn + class Conv1d(nn.Module): def __init__( self, @@ -48,6 +49,7 @@ def __call__(self, xs: mx.array) -> mx.array: y = y + self.bias return y.swapaxes(-1, -2) + class ConvTranspose1d(nn.Module): def __init__( self, @@ -86,6 +88,7 @@ def __call__(self, xs: mx.array) -> mx.array: y = y + self.bias return y + class NormConv1d(nn.Module): def __init__( self, @@ -112,6 +115,7 @@ def __init__( def __call__(self, xs: mx.array) -> mx.array: return self.conv(xs) + class NormConvTranspose1d(nn.Module): def __init__( self, @@ -136,10 +140,11 @@ def __init__( def __call__(self, xs: mx.array) -> mx.array: return self.convtr(xs) + def get_extra_padding_for_conv1d( xs: mx.array, ksize: int, - stride: int, + stride: int, padding_total: int, ) -> int: l = xs.shape[-1] @@ -147,12 +152,15 @@ def get_extra_padding_for_conv1d( ideal_len = (int(math.ceil(nframes)) - 1) * stride + ksize - padding_total return max(0, ideal_len - l) + def unpad1d(xs: mx.array, unpad_l: int, unpad_r: int) -> mx.array: left = unpad_l right = xs.shape[-1] - unpad_r return xs[..., left:right] # TODO(laurent): add a streaming module abstract class? 
+ + class StreamableConv1d(nn.Module): def __init__( self, @@ -218,6 +226,7 @@ def __call__(self, xs: mx.array) -> mx.array: self._prev_xs = xs return mx.zeros((b, self._out_channels, 0)) + class StreamableConvTranspose1d(nn.Module): def __init__( self, @@ -259,9 +268,10 @@ def __call__(self, xs: mx.array) -> mx.array: ys1, ys2 = ys[..., :pt] + prev_ys, ys[..., pt:] ys = mx.concat([ys1, ys2], axis=-1) invalid_steps = self._ksize - stride - ys, self._prev_ys = ys[..., :ot-invalid_steps], ys[..., ot-invalid_steps] + ys, self._prev_ys = ys[..., :ot - invalid_steps], ys[..., ot - invalid_steps] return ys + class ConvDownsample1d(nn.Module): def __init__( self, @@ -272,7 +282,7 @@ def __init__( self.conv = StreamableConv1d( in_channels=dim, out_channels=dim, - ksize=2*stride, + ksize=2 * stride, stride=stride, dilation=1, groups=1, @@ -287,6 +297,7 @@ def reset_state(self): def __call__(self, xs: mx.array) -> mx.array: return self.conv(xs) + class ConvTrUpsample1d(nn.Module): def __init__( self, @@ -297,9 +308,9 @@ def __init__( self.convtr = StreamableConvTranspose1d( in_channels=dim, out_channels=dim, - ksize=2*stride, + ksize=2 * stride, stride=stride, - groups=dim, # TODO: hopefully someday this will be fixed. + groups=dim, # TODO: hopefully someday this will be fixed. bias=False, causal=causal, ) diff --git a/moshi_mlx/moshi_mlx/modules/quantization.py b/moshi_mlx/moshi_mlx/modules/quantization.py index 5fa433db..a040fe75 100644 --- a/moshi_mlx/moshi_mlx/modules/quantization.py +++ b/moshi_mlx/moshi_mlx/modules/quantization.py @@ -7,6 +7,7 @@ import mlx.core as mx import mlx.nn as nn + class EuclideanCodebook(nn.Module): def __init__(self, dim: int, codebook_size: int): super().__init__() @@ -36,6 +37,7 @@ def decode(self, xs: mx.array) -> mx.array: target_shape = list(xs.shape) + [self._dim] return mx.take(self.embedding, xs.flatten()).reshape(target_shape) + class VectorQuantization(nn.Module): def __init__(self, dim: int, codebook_size: int, codebook_dim: int | None): super().__init__() @@ -60,6 +62,7 @@ def decode(self, xs: mx.array) -> mx.array: xs = self.project_out(xs) return xs.swapaxes(-1, -2) + class ResidualVectorQuantization(nn.Module): def __init__(self, nq: int, dim: int, codebook_size: int, codebook_dim: int | None): super().__init__() @@ -90,6 +93,7 @@ def decode(self, xs: mx.array) -> mx.array: quantized = quantized + self.layers[i].decode(xs[i]) return quantized + class ResidualVectorQuantizer(nn.Module): def __init__( self, @@ -112,10 +116,10 @@ def __init__( else: self.output_proj = Conv1d(dim, output_dim, 1, bias=False) self.vq = ResidualVectorQuantization( - nq=nq, - dim=dim, - codebook_size=bins, - codebook_dim=None, + nq=nq, + dim=dim, + codebook_size=bins, + codebook_dim=None, ) def encode(self, xs: mx.array) -> mx.array: @@ -130,6 +134,7 @@ def decode(self, xs: mx.array) -> mx.array: quantized = self.output_proj(quantized) return quantized + class SplitResidualVectorQuantizer(nn.Module): def __init__( self, @@ -153,7 +158,7 @@ def __init__( dim=dim, input_dim=input_dim, output_dim=output_dim, - nq=nq-1, + nq=nq - 1, bins=bins, force_projection=True ) diff --git a/moshi_mlx/moshi_mlx/modules/seanet.py b/moshi_mlx/moshi_mlx/modules/seanet.py index 6356c01c..3f7781c2 100644 --- a/moshi_mlx/moshi_mlx/modules/seanet.py +++ b/moshi_mlx/moshi_mlx/modules/seanet.py @@ -25,6 +25,7 @@ class SeanetConfig: true_skip: bool compress: int + class SeanetResnetBlock(nn.Module): def __init__(self, cfg: SeanetConfig, dim: int, ksizes_and_dilations: list): super().__init__() @@ -79,6 
+80,7 @@ def __call__(self, xs: mx.array) -> mx.array: xs = xs + self.shortcut(residual) return xs + class EncoderLayer(nn.Module): def __init__(self, cfg: SeanetConfig, ratio: int, mult: int): super().__init__() @@ -115,6 +117,7 @@ def __call__(self, xs: mx.array) -> mx.array: xs = r(xs) return self.downsample(nn.elu(xs, alpha=1.0)) + class SeanetEncoder(nn.Module): def __init__(self, cfg: SeanetConfig): super().__init__() @@ -160,6 +163,7 @@ def __call__(self, xs: mx.array) -> mx.array: xs = nn.elu(xs, alpha=1.0) return self.final_conv1d(xs) + class DecoderLayer(nn.Module): def __init__(self, cfg: SeanetConfig, ratio: int, mult: int): super().__init__() @@ -185,6 +189,7 @@ def __call__(self, xs: mx.array) -> mx.array: xs = r(xs) return xs + class SeanetDecoder(nn.Module): def __init__(self, cfg: SeanetConfig): super().__init__() @@ -230,6 +235,7 @@ def __call__(self, xs: mx.array) -> mx.array: xs = nn.elu(xs, alpha=1.0) return self.final_conv1d(xs) + class Seanet(nn.Module): def __init__(self, cfg: SeanetConfig): super().__init__() From 565ba717f213ada34daa48dec0c5cee6ed39a41e Mon Sep 17 00:00:00 2001 From: Laurent Date: Fri, 21 Feb 2025 10:17:30 +0100 Subject: [PATCH 09/22] Make pyright happy. --- moshi_mlx/moshi_mlx/modules/conv.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/moshi_mlx/moshi_mlx/modules/conv.py b/moshi_mlx/moshi_mlx/modules/conv.py index e29ba812..2afb55a0 100644 --- a/moshi_mlx/moshi_mlx/modules/conv.py +++ b/moshi_mlx/moshi_mlx/modules/conv.py @@ -147,10 +147,10 @@ def get_extra_padding_for_conv1d( stride: int, padding_total: int, ) -> int: - l = xs.shape[-1] - nframes = max(l + padding_total - ksize, 0) / stride + 1.0 + len_ = xs.shape[-1] + nframes = max(len_ + padding_total - ksize, 0) / stride + 1.0 ideal_len = (int(math.ceil(nframes)) - 1) * stride + ksize - padding_total - return max(0, ideal_len - l) + return max(0, ideal_len - len_) def unpad1d(xs: mx.array, unpad_l: int, unpad_r: int) -> mx.array: @@ -195,8 +195,8 @@ def reset_state(self): self._left_pad_applied = False def __call__(self, xs: mx.array) -> mx.array: - b, _, l = xs.shape - if l == 0: + b, _, len_ = xs.shape + if len_ == 0: return mx.zeros((b, self._out_channels, 0)) stride = self.conv.conv._stride dilation = self.conv.conv._dilation @@ -211,8 +211,8 @@ def __call__(self, xs: mx.array) -> mx.array: ) if self._prev_xs is not None: xs = mx.concat([self._prev_xs, xs], axis=-1) - l = xs.shape[-1] - nframes = max(l + stride - ksize, 0) // stride + len_ = xs.shape[-1] + nframes = max(len_ + stride - ksize, 0) // stride if nframes > 0: offset = nframes * stride self._prev_xs = xs[..., offset:] @@ -254,8 +254,8 @@ def reset_state(self): self._prev_ys = None def __call__(self, xs: mx.array) -> mx.array: - b, _, l = xs.shape - if l == 0: + b, _, len_ = xs.shape + if len_ == 0: return mx.zeros((b, self._out_channels, 0)) stride = self.convtr.convtr._stride ys = self.convtr(xs) From 06ec20213f4f2de02fc71e7073f78b074e44e2c6 Mon Sep 17 00:00:00 2001 From: Laurent Date: Fri, 21 Feb 2025 16:41:59 +0100 Subject: [PATCH 10/22] Add the mimi config. 
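
Note: the config below pins the sample rate to 24 kHz and the frame rate to 12.5 Hz; with the SEANet ratios it uses, the stride of the extra ConvDownsample1d works out as follows (a quick check, not part of the patch):

    import math

    ratios = [8, 6, 5, 4]
    sample_rate, frame_rate = 24000.0, 12.5
    encoder_frame_rate = sample_rate / math.prod(ratios)
    print(encoder_frame_rate)                    # 25.0 Hz after the SEANet encoder
    print(int(encoder_frame_rate / frame_rate))  # downsample stride: 2
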
--- moshi_mlx/moshi_mlx/models/__init__.py | 1 + moshi_mlx/moshi_mlx/models/mimi.py | 71 ++++++++++++++++++++++++++ 2 files changed, 72 insertions(+) create mode 100644 moshi_mlx/moshi_mlx/models/mimi.py diff --git a/moshi_mlx/moshi_mlx/models/__init__.py b/moshi_mlx/moshi_mlx/models/__init__.py index db8d7105..5d24bc0a 100644 --- a/moshi_mlx/moshi_mlx/models/__init__.py +++ b/moshi_mlx/moshi_mlx/models/__init__.py @@ -15,3 +15,4 @@ config_helium_1_preview_2b, ) from .generate import LmGen +from .mimi import mimi_202407, MimiConfig diff --git a/moshi_mlx/moshi_mlx/models/mimi.py b/moshi_mlx/moshi_mlx/models/mimi.py new file mode 100644 index 00000000..4d5bdd5a --- /dev/null +++ b/moshi_mlx/moshi_mlx/models/mimi.py @@ -0,0 +1,71 @@ +# Copyright (c) Kyutai, all rights reserved. +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from ..modules import SeanetConfig, TransformerConfig + + +@dataclass +class MimiConfig: + channels: int + sample_rate: float + frame_rate: float + renormalize: bool + seanet: SeanetConfig + transformer: TransformerConfig + quantizer_nq: int + quantizer_bins: int + quantizer_dim: int + + +def mimi_202407(num_codebooks: int) -> MimiConfig: + seanet = SeanetConfig( + dimension=512, + channels=1, + causal=True, + nfilters=64, + nresidual_layers=1, + ratios=[8, 6, 5, 4], + ksize=7, + residual_ksize=3, + last_ksize=3, + dilation_base=2, + pad_mode="constant", + true_skip=True, + compress=2, + ) + transformer = TransformerConfig( + d_model=seanet.dimension, + num_heads=8, + num_layers=8, + causal=True, + norm_first=True, + bias_ff=False, + bias_attn=False, + layer_scale=0.01, + positional_embedding="rope", + use_conv_bias=True, + gating=False, + norm="layer_norm", + context=250, + max_period=10000, + max_seq_len=8192, + kv_repeat=1, + dim_feedforward=2048, + conv_layout=True, + use_conv_block=False, + cross_attention=False, + conv_kernel_size=3, + ) + return MimiConfig( + channels=1, + sample_rate=24000, + frame_rate=12.5, + renormalize=True, + seanet=seanet, + transformer=transformer, + quantizer_nq=num_codebooks, + quantizer_bins=2048, + quantizer_dim=256, + ) From 9a1b25fcfb17c30c9d328c7ad2063af92923bb71 Mon Sep 17 00:00:00 2001 From: Laurent Date: Fri, 21 Feb 2025 17:16:02 +0100 Subject: [PATCH 11/22] More mimi implementation. --- moshi_mlx/moshi_mlx/models/mimi.py | 77 +++++++++++++++++++++- moshi_mlx/moshi_mlx/modules/__init__.py | 2 +- moshi_mlx/moshi_mlx/modules/transformer.py | 48 ++++++++++++++ 3 files changed, 125 insertions(+), 2 deletions(-) diff --git a/moshi_mlx/moshi_mlx/models/mimi.py b/moshi_mlx/moshi_mlx/models/mimi.py index 4d5bdd5a..5c18ed9f 100644 --- a/moshi_mlx/moshi_mlx/models/mimi.py +++ b/moshi_mlx/moshi_mlx/models/mimi.py @@ -3,7 +3,20 @@ # LICENSE file in the root directory of this source tree. 
from dataclasses import dataclass -from ..modules import SeanetConfig, TransformerConfig +from ..modules import ( + SeanetConfig, + TransformerConfig, + SeanetEncoder, + SeanetDecoder, + SplitResidualVectorQuantizer, + ProjectedTransformer, + ConvDownsample1d, + ConvTrUpsample1d, +) +import math + +import mlx.core as mx +import mlx.nn as nn @dataclass @@ -69,3 +82,65 @@ def mimi_202407(num_codebooks: int) -> MimiConfig: quantizer_bins=2048, quantizer_dim=256, ) + + +class Mimi(nn.Module): + def __init__(self, cfg: MimiConfig): + super().__init__() + dim = cfg.seanet.dimension + self.cfg = cfg + encoder_frame_rate = cfg.sample_rate / math.prod(cfg.seanet.ratios) + downsample_stride = int(encoder_frame_rate / cfg.frame_rate) + self.encoder = SeanetEncoder(cfg.seanet) + self.decoder = SeanetDecoder(cfg.seanet) + self.quantizer = SplitResidualVectorQuantizer( + dim=cfg.quantizer_dim, + input_dim=dim, + output_dim=dim, + nq=cfg.quantizer_nq, + bins=cfg.quantizer_bins, + ) + self.encoder_transformer = ProjectedTransformer( + cfg.transformer, + input_dim=dim, + output_dims=[dim], + ) + self.decoder_transformer = ProjectedTransformer( + cfg.transformer, + input_dim=dim, + output_dims=[dim], + ) + self.downsample = ConvDownsample1d( + stride=downsample_stride, + dim=dim, + causal=True, + ) + self.upsample = ConvTrUpsample1d( + stride=downsample_stride, + dim=dim, + causal=True, + ) + self.encoder_cache = self.encoder_transformer.make_cache() + self.decoder_cache = self.decoder_transformer.make_cache() + + def reset_state(self): + self.encoder.reset_state() + self.decoder.reset_state() + + def encode(self, xs: mx.array) -> mx.array: + return xs + + def decode(self, xs: mx.array) -> mx.array: + return xs + + def encode_step(self, xs: mx.array) -> mx.array: + return xs + + def decode_step(self, xs: mx.array) -> mx.array: + return xs + + def warmup(self): + pcm = mx.zeros((1, 1, 1920 * 4)) + codes = self.encode(pcm) + pcm_out = self.decode(codes) + mx.eval(pcm_out) diff --git a/moshi_mlx/moshi_mlx/modules/__init__.py b/moshi_mlx/moshi_mlx/modules/__init__.py index 275e5564..157de6f1 100644 --- a/moshi_mlx/moshi_mlx/modules/__init__.py +++ b/moshi_mlx/moshi_mlx/modules/__init__.py @@ -8,4 +8,4 @@ from .quantization import SplitResidualVectorQuantizer from .seanet import SeanetConfig, SeanetEncoder, SeanetDecoder from .kv_cache import KVCache, RotatingKVCache -from .transformer import Transformer, TransformerConfig +from .transformer import Transformer, TransformerConfig, ProjectedTransformer diff --git a/moshi_mlx/moshi_mlx/modules/transformer.py b/moshi_mlx/moshi_mlx/modules/transformer.py index 31c1b060..5abc76e7 100644 --- a/moshi_mlx/moshi_mlx/modules/transformer.py +++ b/moshi_mlx/moshi_mlx/modules/transformer.py @@ -203,3 +203,51 @@ def make_rot_cache(self) -> list[RotatingKVCache]: ) for _ in self.layers ] + + +class ProjectedTransformer(nn.Module): + def __init__(self, cfg: TransformerConfig, input_dim: int, output_dims: list[int]): + super().__init__() + + self.conv_layout = cfg.conv_layout + self.transformer = Transformer(cfg) + if input_dim == cfg.d_model: + self.input_proj = None + else: + self.input_proj = nn.Linear(input_dim, cfg.d_model, bias=False) + + output_projs = [] + for output_dim in output_dims: + if output_dim == cfg.d_model: + p = None + else: + p = nn.Linear(cfg.d_model, output_dim, bias=False) + output_projs.append(p) + self.output_projs = output_projs + + def __call__( + self, + xs: mx.array, + cache: list[KVCache] | list[RotatingKVCache], + ) -> mx.array: + if 
self.conv_layout: + xs = xs.swapaxes(1, 2) + if self.input_proj is not None: + xs = self.input_proj(xs) + xs = self.transformer(xs, cache=cache) + outs = [] + for output_proj in self.output_projs: + if output_proj is None: + out = xs + else: + out = output_proj(xs) + if self.conv_layout: + out = out.swapaxes(1, 2) + outs.append(out) + return xs + + def make_cache(self) -> list[KVCache]: + return self.transformer.make_cache() + + def make_rot_cache(self) -> list[RotatingKVCache]: + return self.transformer.make_rot_cache() From 9e9a17e612142533a3326801b98a5f11f341dbde Mon Sep 17 00:00:00 2001 From: Laurent Date: Fri, 21 Feb 2025 17:40:02 +0100 Subject: [PATCH 12/22] Add a test script. --- moshi_mlx/moshi_mlx/models/mimi.py | 24 ++++++++++++++++++------ scripts/mimi_mlx.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 6 deletions(-) create mode 100644 scripts/mimi_mlx.py diff --git a/moshi_mlx/moshi_mlx/models/mimi.py b/moshi_mlx/moshi_mlx/models/mimi.py index 5c18ed9f..485c49ae 100644 --- a/moshi_mlx/moshi_mlx/models/mimi.py +++ b/moshi_mlx/moshi_mlx/models/mimi.py @@ -128,16 +128,28 @@ def reset_state(self): self.decoder.reset_state() def encode(self, xs: mx.array) -> mx.array: - return xs + self.encoder.reset_state() + for c in self.encoder_cache: + c.reset() + xs = self.encoder(xs) + xs = self.encoder_transformer(xs, cache=self.encoder_cache)[0] + xs = self.downsample(xs) + return self.quantizer.encode(xs) def decode(self, xs: mx.array) -> mx.array: - return xs + self.decoder.reset_state() + for c in self.decoder_cache: + c.reset() + xs = self.quantizer.decode(xs) + xs = self.upsample(xs) + xs = self.decoder_transformer(xs, cache=self.decoder_cache)[0] + return self.decoder(xs[0]) - def encode_step(self, xs: mx.array) -> mx.array: - return xs + def encode_step(self, _: mx.array) -> mx.array: + raise ValueError("TODO") - def decode_step(self, xs: mx.array) -> mx.array: - return xs + def decode_step(self, _: mx.array) -> mx.array: + raise ValueError("TODO") def warmup(self): pcm = mx.zeros((1, 1, 1920 * 4)) diff --git a/scripts/mimi_mlx.py b/scripts/mimi_mlx.py new file mode 100644 index 00000000..80f0e795 --- /dev/null +++ b/scripts/mimi_mlx.py @@ -0,0 +1,28 @@ +# Copyright (c) Kyutai, all rights reserved. +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +from huggingface_hub import hf_hub_download +import mlx.core as mx +import sphn +import moshi_mlx + + +def run(): + parser = argparse.ArgumentParser() + parser.add_argument("--input", type=str) + parser.add_argument("--hf-repo", type=str, default="kyutai/moshiko-mlx-q4") + args = parser.parse_args() + + pcm_in, _ = sphn.read(args.input, sample_rate=24000) + pcm_in = mx.array(pcm_in[0]) + print(pcm_in.shape) + + weight_file = hf_hub_download(args.hf_repo, "tokenizer-e351c8d8-checkpoint125.safetensors") + cfg = moshi_mlx.models.mimi.mimi_202407(16) + mimi = moshi_mlx.models.mimi.Mimi(cfg) + mimi.encode(pcm_in) + +if __name__ == "__main__": + run() From 2664857be21d68729ea7a7e4550cedeca7ebf38a Mon Sep 17 00:00:00 2001 From: Laurent Date: Fri, 21 Feb 2025 17:43:36 +0100 Subject: [PATCH 13/22] Shape bugfixes. 
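
Note: for reference, a rough expectation of the shapes involved in the encode path (assuming the warmup sizes and 16 codebooks, and ignoring edge padding):

    samples = 1920 * 4                      # the warmup() input length
    sample_rate, frame_rate, nq = 24000, 12.5, 16
    frames = int(samples / sample_rate * frame_rate)
    print(frames)                           # 4 latent frames
    print((1, nq, frames))                  # expected Mimi.encode(mx.zeros((1, 1, samples))) shape
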
--- moshi_mlx/moshi_mlx/modules/transformer.py | 4 ++-- scripts/mimi_mlx.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/moshi_mlx/moshi_mlx/modules/transformer.py b/moshi_mlx/moshi_mlx/modules/transformer.py index 5abc76e7..b1e3957e 100644 --- a/moshi_mlx/moshi_mlx/modules/transformer.py +++ b/moshi_mlx/moshi_mlx/modules/transformer.py @@ -229,7 +229,7 @@ def __call__( self, xs: mx.array, cache: list[KVCache] | list[RotatingKVCache], - ) -> mx.array: + ) -> list[mx.array]: if self.conv_layout: xs = xs.swapaxes(1, 2) if self.input_proj is not None: @@ -244,7 +244,7 @@ def __call__( if self.conv_layout: out = out.swapaxes(1, 2) outs.append(out) - return xs + return outs def make_cache(self) -> list[KVCache]: return self.transformer.make_cache() diff --git a/scripts/mimi_mlx.py b/scripts/mimi_mlx.py index 80f0e795..c8b25b39 100644 --- a/scripts/mimi_mlx.py +++ b/scripts/mimi_mlx.py @@ -16,7 +16,7 @@ def run(): args = parser.parse_args() pcm_in, _ = sphn.read(args.input, sample_rate=24000) - pcm_in = mx.array(pcm_in[0]) + pcm_in = mx.array(pcm_in[0])[None, None] print(pcm_in.shape) weight_file = hf_hub_download(args.hf_repo, "tokenizer-e351c8d8-checkpoint125.safetensors") From bfcb20d8ed01d9d06731175c931db79bd6e673be Mon Sep 17 00:00:00 2001 From: Laurent Date: Fri, 21 Feb 2025 17:53:57 +0100 Subject: [PATCH 14/22] Bugfixes... --- moshi_mlx/moshi_mlx/modules/quantization.py | 12 ++++++------ scripts/mimi_mlx.py | 3 ++- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/moshi_mlx/moshi_mlx/modules/quantization.py b/moshi_mlx/moshi_mlx/modules/quantization.py index a040fe75..1dcac59f 100644 --- a/moshi_mlx/moshi_mlx/modules/quantization.py +++ b/moshi_mlx/moshi_mlx/modules/quantization.py @@ -29,13 +29,13 @@ def update(self, parameters: dict) -> nn.Module: def encode(self, xs: mx.array) -> mx.array: target_shape = xs.shape[:-1] - xs = xs.flatten(start_axis=-2) + xs = xs.flatten(end_axis=-2) dot_prod = xs @ self.embedding.swapaxes(-1, -2) - return (self.c2 - dot_prod).min(axis=-1).reshape(target_shape) + return (self.c2 - dot_prod).argmin(axis=-1).reshape(target_shape) def decode(self, xs: mx.array) -> mx.array: target_shape = list(xs.shape) + [self._dim] - return mx.take(self.embedding, xs.flatten()).reshape(target_shape) + return mx.take(self.embedding, xs.flatten(), axis=0).reshape(target_shape) class VectorQuantization(nn.Module): @@ -84,7 +84,7 @@ def encode(self, xs: mx.array) -> mx.array: quantized = layer.decode(indices) residual = residual - quantized codes.append(indices) - return mx.concat(codes, axis=0) + return mx.stack(codes, axis=0) def decode(self, xs: mx.array) -> mx.array: seq_len = xs.shape[0] @@ -125,7 +125,7 @@ def __init__( def encode(self, xs: mx.array) -> mx.array: if self.input_proj is not None: xs = self.input_proj(xs) - return self.vq.encode(xs).transpose(0, 1) + return self.vq.encode(xs).swapaxes(0, 1) def decode(self, xs: mx.array) -> mx.array: xs = xs.swapaxes(0, 1) @@ -166,7 +166,7 @@ def __init__( def encode(self, xs: mx.array) -> mx.array: codes = self.rvq_first.encode(xs) if self._nq > 1: - rest_codes = self.rvq_rest(xs) + rest_codes = self.rvq_rest.encode(xs) codes = mx.concat([codes, rest_codes], axis=1) return codes diff --git a/scripts/mimi_mlx.py b/scripts/mimi_mlx.py index c8b25b39..8825c2ec 100644 --- a/scripts/mimi_mlx.py +++ b/scripts/mimi_mlx.py @@ -22,7 +22,8 @@ def run(): weight_file = hf_hub_download(args.hf_repo, "tokenizer-e351c8d8-checkpoint125.safetensors") cfg = moshi_mlx.models.mimi.mimi_202407(16) mimi = 
moshi_mlx.models.mimi.Mimi(cfg) - mimi.encode(pcm_in) + codes = mimi.encode(pcm_in) + print(codes.shape) if __name__ == "__main__": run() From 61d36df973bfe04a9fa8a47f8db80e23785f24db Mon Sep 17 00:00:00 2001 From: Laurent Date: Fri, 21 Feb 2025 23:07:28 +0100 Subject: [PATCH 15/22] Again more mimi. --- moshi_mlx/moshi_mlx/modules/conv.py | 35 +++++++++++++++++++++++++++-- scripts/mimi_mlx.py | 2 ++ 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/moshi_mlx/moshi_mlx/modules/conv.py b/moshi_mlx/moshi_mlx/modules/conv.py index 2afb55a0..f23b70ec 100644 --- a/moshi_mlx/moshi_mlx/modules/conv.py +++ b/moshi_mlx/moshi_mlx/modules/conv.py @@ -75,14 +75,45 @@ def __init__( self._padding = padding self._groups = groups self._stride = stride + self._ksize = ksize + self._in_channels = in_channels + self._out_channels = out_channels + if groups == in_channels and groups == out_channels: + eye = mx.eye(out_channels).astype(self.weight.dtype).reshape((out_channels, 1, out_channels)) + eye = mx.repeat(eye, repeats=ksize, axis=1) + self.expanded_weight = mx.repeat(self.weight, repeats=groups, axis=0) * eye + self.expanded_groups = 1 + elif groups > 1: + raise ValueError("groups are not supported in ConvTranspose1d") + else: + self.expanded_weight = self.weight + self.expanded_groups = groups + + def update(self, parameters: dict) -> nn.Module: + super().update(parameters) + groups = self._groups + in_channels = self._in_channels + out_channels = self._out_channels + ksize = self._ksize + if groups == in_channels and groups == out_channels: + eye = mx.eye(out_channels).astype(self.weight.dtype).reshape((out_channels, 1, out_channels)) + eye = mx.repeat(eye, repeats=ksize, axis=1) + self.expanded_weight = mx.repeat(self.weight, repeats=groups, axis=0) * eye + self.expanded_groups = 1 + elif groups > 1: + raise ValueError("groups are not supported in ConvTranspose1d") + else: + self.expanded_weight = self.weight + self.expanded_groups = groups + return self def __call__(self, xs: mx.array) -> mx.array: y = mx.conv_transpose1d( xs.swapaxes(-1, -2), - self.weight, + self.expanded_weight, stride=self._stride, padding=self._padding, - groups=self._groups, + groups=self.expanded_groups, ) if self.bias is not None: y = y + self.bias diff --git a/scripts/mimi_mlx.py b/scripts/mimi_mlx.py index 8825c2ec..a86944ae 100644 --- a/scripts/mimi_mlx.py +++ b/scripts/mimi_mlx.py @@ -24,6 +24,8 @@ def run(): mimi = moshi_mlx.models.mimi.Mimi(cfg) codes = mimi.encode(pcm_in) print(codes.shape) + pcm_out = mimi.decode(codes) + print(pcm_out.shape) if __name__ == "__main__": run() From 9db07c2c9fe6f54dd14a73e66b2500ec86964f23 Mon Sep 17 00:00:00 2001 From: Laurent Date: Fri, 21 Feb 2025 23:45:27 +0100 Subject: [PATCH 16/22] Get some roundtrip to work. 
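
Note: the grouped-transposed-conv workaround added just above expands the depthwise weight into a dense one whose cross-channel entries are zeroed, so the op can run with groups=1. A standalone sketch of that construction, with illustrative sizes:

    import mlx.core as mx

    C, ksize = 4, 3
    w = mx.random.uniform(shape=(1, ksize, C))        # one filter per channel (depthwise)
    eye = mx.eye(C).reshape((C, 1, C))                # channel -> channel mask
    eye = mx.repeat(eye, repeats=ksize, axis=1)       # (C, ksize, C)
    expanded = mx.repeat(w, repeats=C, axis=0) * eye  # dense weight, zero off the channel diagonal
    print(expanded.shape)                             # (4, 3, 4), usable with groups=1
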
--- moshi_mlx/moshi_mlx/models/mimi.py | 2 +- moshi_mlx/moshi_mlx/modules/conv.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/moshi_mlx/moshi_mlx/models/mimi.py b/moshi_mlx/moshi_mlx/models/mimi.py index 485c49ae..88cb9214 100644 --- a/moshi_mlx/moshi_mlx/models/mimi.py +++ b/moshi_mlx/moshi_mlx/models/mimi.py @@ -143,7 +143,7 @@ def decode(self, xs: mx.array) -> mx.array: xs = self.quantizer.decode(xs) xs = self.upsample(xs) xs = self.decoder_transformer(xs, cache=self.decoder_cache)[0] - return self.decoder(xs[0]) + return self.decoder(xs) def encode_step(self, _: mx.array) -> mx.array: raise ValueError("TODO") diff --git a/moshi_mlx/moshi_mlx/modules/conv.py b/moshi_mlx/moshi_mlx/modules/conv.py index f23b70ec..d14c26d1 100644 --- a/moshi_mlx/moshi_mlx/modules/conv.py +++ b/moshi_mlx/moshi_mlx/modules/conv.py @@ -117,7 +117,7 @@ def __call__(self, xs: mx.array) -> mx.array: ) if self.bias is not None: y = y + self.bias - return y + return y.swapaxes(-1, -2) class NormConv1d(nn.Module): @@ -350,4 +350,5 @@ def reset_state(self): self.convtr.reset_state() def __call__(self, xs: mx.array) -> mx.array: - return self.convtr(xs) + xs = self.convtr(xs) + return xs From 29d7fe8590a7ad2ad015535dd801f70c4deaf206 Mon Sep 17 00:00:00 2001 From: Laurent Date: Sat, 22 Feb 2025 08:39:48 +0100 Subject: [PATCH 17/22] Improve the weight loading. --- moshi_mlx/moshi_mlx/models/mimi.py | 9 +++++++++ moshi_mlx/moshi_mlx/modules/quantization.py | 22 ++++++++++----------- scripts/mimi_mlx.py | 12 +++++++---- 3 files changed, 28 insertions(+), 15 deletions(-) diff --git a/moshi_mlx/moshi_mlx/models/mimi.py b/moshi_mlx/moshi_mlx/models/mimi.py index 88cb9214..8b6c1b23 100644 --- a/moshi_mlx/moshi_mlx/models/mimi.py +++ b/moshi_mlx/moshi_mlx/models/mimi.py @@ -156,3 +156,12 @@ def warmup(self): codes = self.encode(pcm) pcm_out = self.decode(codes) mx.eval(pcm_out) + + def load_weights(self, model_file: str, strict: bool) -> nn.Module: + weights = {} + for k, v in mx.load(model_file).items(): + clean_k = '.'.join([s.removeprefix('_') for s in k.split('.')]) + weights[clean_k] = v + return super().load_weights(weights, strict=strict) + + diff --git a/moshi_mlx/moshi_mlx/modules/quantization.py b/moshi_mlx/moshi_mlx/modules/quantization.py index 1dcac59f..597bf872 100644 --- a/moshi_mlx/moshi_mlx/modules/quantization.py +++ b/moshi_mlx/moshi_mlx/modules/quantization.py @@ -13,29 +13,29 @@ def __init__(self, dim: int, codebook_size: int): super().__init__() self._epsilon = 1e-5 self._dim = dim - self._initialized = mx.zeros([1], dtype=mx.float32) + self.initialized = mx.zeros([1], dtype=mx.float32) self.embedding_sum = mx.zeros([codebook_size, dim], dtype=mx.float32) self.cluster_usage = mx.zeros([codebook_size], dtype=mx.float32) cluster_usage = mx.maximum(self.cluster_usage, self._epsilon)[:, None] - self.embedding = self.embedding_sum / cluster_usage - self.c2 = self.embedding.square().sum(axis=-1) / 2 + self._embedding = self.embedding_sum / cluster_usage + self._c2 = self._embedding.square().sum(axis=-1) / 2 def update(self, parameters: dict) -> nn.Module: super().update(parameters) cluster_usage = mx.maximum(self.cluster_usage, self._epsilon)[:, None] - self.embedding = self.embedding_sum / cluster_usage - self.c2 = self.embedding.square().sum(axis=-1) / 2 + self._embedding = self.embedding_sum / cluster_usage + self._c2 = self._embedding.square().sum(axis=-1) / 2 return self def encode(self, xs: mx.array) -> mx.array: target_shape = xs.shape[:-1] xs = xs.flatten(end_axis=-2) - 
dot_prod = xs @ self.embedding.swapaxes(-1, -2) - return (self.c2 - dot_prod).argmin(axis=-1).reshape(target_shape) + dot_prod = xs @ self._embedding.swapaxes(-1, -2) + return (self._c2 - dot_prod).argmin(axis=-1).reshape(target_shape) def decode(self, xs: mx.array) -> mx.array: target_shape = list(xs.shape) + [self._dim] - return mx.take(self.embedding, xs.flatten(), axis=0).reshape(target_shape) + return mx.take(self._embedding, xs.flatten(), axis=0).reshape(target_shape) class VectorQuantization(nn.Module): @@ -48,16 +48,16 @@ def __init__(self, dim: int, codebook_size: int, codebook_dim: int | None): else: self.project_in = nn.Linear(dim, codebook_dim) self.project_out = nn.Linear(codebook_dim, dim) - self._codebook = EuclideanCodebook(dim=codebook_dim, codebook_size=codebook_size) + self.codebook = EuclideanCodebook(dim=codebook_dim, codebook_size=codebook_size) def encode(self, xs: mx.array) -> mx.array: xs = xs.swapaxes(-1, -2) if self.project_in is not None: xs = self.project_in(xs) - return self._codebook.encode(xs) + return self.codebook.encode(xs) def decode(self, xs: mx.array) -> mx.array: - xs = self._codebook.decode(xs) + xs = self.codebook.decode(xs) if self.project_out is not None: xs = self.project_out(xs) return xs.swapaxes(-1, -2) diff --git a/scripts/mimi_mlx.py b/scripts/mimi_mlx.py index a86944ae..292a9a94 100644 --- a/scripts/mimi_mlx.py +++ b/scripts/mimi_mlx.py @@ -19,12 +19,16 @@ def run(): pcm_in = mx.array(pcm_in[0])[None, None] print(pcm_in.shape) - weight_file = hf_hub_download(args.hf_repo, "tokenizer-e351c8d8-checkpoint125.safetensors") + model_file = hf_hub_download(args.hf_repo, "tokenizer-e351c8d8-checkpoint125.safetensors") cfg = moshi_mlx.models.mimi.mimi_202407(16) - mimi = moshi_mlx.models.mimi.Mimi(cfg) - codes = mimi.encode(pcm_in) + model = moshi_mlx.models.mimi.Mimi(cfg) + print(f"loading weights {model_file}") + model.load_weights(model_file, strict=True) + print("weights loaded") + + codes = model.encode(pcm_in) print(codes.shape) - pcm_out = mimi.decode(codes) + pcm_out = model.decode(codes) print(pcm_out.shape) if __name__ == "__main__": From 84fff57b3cf6b8e2c314f56a62bfcec32648e08e Mon Sep 17 00:00:00 2001 From: Laurent Date: Sat, 22 Feb 2025 09:10:58 +0100 Subject: [PATCH 18/22] Hacky import. 
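
Note: the checkpoint import below also has to transpose convolution weights between the PyTorch and MLX layouts; a small sketch of just that part, with illustrative shapes:

    import mlx.core as mx

    conv_w = mx.zeros((512, 64, 7))           # PyTorch Conv1d layout: (out, in, ksize)
    print(conv_w.swapaxes(-1, -2).shape)      # (512, 7, 64): MLX (out, ksize, in)

    convtr_w = mx.zeros((512, 256, 4))        # PyTorch ConvTranspose1d layout: (in, out, ksize)
    print(convtr_w.transpose(1, 2, 0).shape)  # (256, 4, 512): MLX (out, ksize, in)
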
From 84fff57b3cf6b8e2c314f56a62bfcec32648e08e Mon Sep 17 00:00:00 2001
From: Laurent
Date: Sat, 22 Feb 2025 09:10:58 +0100
Subject: [PATCH 18/22] Hacky import.

---
 moshi_mlx/moshi_mlx/models/mimi.py | 56 ++++++++++++++++++++++++----
 scripts/mimi_mlx.py                |  2 +-
 2 files changed, 50 insertions(+), 8 deletions(-)

diff --git a/moshi_mlx/moshi_mlx/models/mimi.py b/moshi_mlx/moshi_mlx/models/mimi.py
index 8b6c1b23..7c849548 100644
--- a/moshi_mlx/moshi_mlx/models/mimi.py
+++ b/moshi_mlx/moshi_mlx/models/mimi.py
@@ -14,6 +14,7 @@
     ConvTrUpsample1d,
 )
 import math
+import typing as tp
 import mlx.core as mx
 import mlx.nn as nn
 
@@ -157,11 +158,52 @@ def warmup(self):
         pcm_out = self.decode(codes)
         mx.eval(pcm_out)
 
-    def load_weights(self, model_file: str, strict: bool) -> nn.Module:
-        weights = {}
-        for k, v in mx.load(model_file).items():
-            clean_k = '.'.join([s.removeprefix('_') for s in k.split('.')])
-            weights[clean_k] = v
+    def load_weights(
+        self,
+        file_or_weights: tp.Union[str, tp.List[tp.Tuple[str, mx.array]]],
+        strict: bool = True,
+    ) -> nn.Module:
+        if isinstance(file_or_weights, str):
+            weights = []
+            for k, v in mx.load(file_or_weights).items():
+                v: mx.array = v
+                k: str = '.'.join([s.removeprefix('_') for s in k.split('.')])
+                if k.startswith("encoder.model."):
+                    k = k.replace("encoder.model.", "encoder.")
+                if k.startswith("decoder.model."):
+                    k = k.replace("decoder.model.", "decoder.")
+                if k.endswith(".in_proj_weight"):
+                    k = k.replace(".in_proj_weight", ".in_proj.weight")
+                if k.endswith(".linear1.weight"):
+                    k = k.replace(".linear1.weight", ".gating.linear1.weight")
+                if k.endswith(".linear2.weight"):
+                    k = k.replace(".linear2.weight", ".gating.linear2.weight")
+                # Awfully hardcoded matching between the pytorch layers and their mlx equivalent :(
+                if k.startswith("decoder.6"):
+                    print(k)
+                for layerIdx, decoderIdx in enumerate([2, 5, 8, 11]):
+                    k = k.replace(f"decoder.{decoderIdx}.", f"decoder.layers.{layerIdx}.upsample.")
+                    k = k.replace(
+                        f"decoder.{decoderIdx + 1}.", f"decoder.layers.{layerIdx}.residuals.0.")
+                for (layerIdx, encoderIdx) in enumerate([1, 4, 7, 10]):
+                    k = k.replace(f"encoder.{encoderIdx}.", f"encoder.layers.{layerIdx}.residuals.0.")
+                    k = k.replace(
+                        f"encoder.{encoderIdx + 2}.", f"encoder.layers.{layerIdx}.downsample.")
+
+                k.replace("decoder.0.", "decoder.init_conv1d.")
+                k.replace("decoder.14.", "decoder.final_conv1d.")
+                k.replace("encoder.0.", "encoder.init_conv1d.")
+                k.replace("encoder.14.", "encoder.final_conv1d.")
+                k.replace(".block.1.", ".block.0.")
+                k.replace(".block.3.", ".block.1.")
+
+                # PyTorch layout for conv weights is outC, inC, kSize, for MLX it's outC, kSize, inC
+                if k.endswith(".conv.weight") or k.endswith(".output_proj.weight") or k.endswith(".input_proj.weight"):
+                    v = v.swapaxes(-1, -2)
+                # PyTorch layout for conv-transposed weights is inC, outC, kSize, for MLX it's outC, kSize, inC
+                if k.endswith(".convtr.weight"):
+                    v = v.transpose(1, 2, 0)
+                weights.append((k, v))
+        else:
+            weights = file_or_weights
         return super().load_weights(weights, strict=strict)
-
-
diff --git a/scripts/mimi_mlx.py b/scripts/mimi_mlx.py
index 292a9a94..ebb0e7ae 100644
--- a/scripts/mimi_mlx.py
+++ b/scripts/mimi_mlx.py
@@ -20,7 +20,7 @@ def run():
     print(pcm_in.shape)
 
     model_file = hf_hub_download(args.hf_repo, "tokenizer-e351c8d8-checkpoint125.safetensors")
-    cfg = moshi_mlx.models.mimi.mimi_202407(16)
+    cfg = moshi_mlx.models.mimi.mimi_202407(32)
     model = moshi_mlx.models.mimi.Mimi(cfg)
     print(f"loading weights {model_file}")
     model.load_weights(model_file, strict=True)
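The two layout comments above are the heart of the conversion: PyTorch stores Conv1d weights as (out_channels, in_channels, ksize) and ConvTranspose1d weights as (in_channels, out_channels, ksize), while the MLX modules in this port keep both as (out_channels, ksize, in_channels). A shape-only sketch of the two conversions, with made-up channel counts:

    import mlx.core as mx

    w_conv = mx.zeros((512, 128, 7))          # PyTorch Conv1d weight: (outC, inC, kSize)
    print(w_conv.swapaxes(-1, -2).shape)      # (512, 7, 128), i.e. (outC, kSize, inC)

    w_convtr = mx.zeros((128, 512, 7))        # PyTorch ConvTranspose1d weight: (inC, outC, kSize)
    print(w_convtr.transpose(1, 2, 0).shape)  # (512, 7, 128), i.e. (outC, kSize, inC)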
From d93739079f6222b5b6bc71f098261f78071d2368 Mon Sep 17 00:00:00 2001
From: Laurent
Date: Sat, 22 Feb 2025 09:20:51 +0100
Subject: [PATCH 19/22] Get the weight loading to work.

---
 moshi_mlx/moshi_mlx/models/mimi.py    | 14 ++++++--------
 moshi_mlx/moshi_mlx/modules/conv.py   | 20 ++++++++++----------
 moshi_mlx/moshi_mlx/modules/seanet.py | 12 +++++++++++-
 3 files changed, 27 insertions(+), 19 deletions(-)

diff --git a/moshi_mlx/moshi_mlx/models/mimi.py b/moshi_mlx/moshi_mlx/models/mimi.py
index 7c849548..715331e1 100644
--- a/moshi_mlx/moshi_mlx/models/mimi.py
+++ b/moshi_mlx/moshi_mlx/models/mimi.py
@@ -179,8 +179,6 @@ def load_weights(
                 if k.endswith(".linear2.weight"):
                     k = k.replace(".linear2.weight", ".gating.linear2.weight")
                 # Awfully hardcoded matching between the pytorch layers and their mlx equivalent :(
-                if k.startswith("decoder.6"):
-                    print(k)
                 for layerIdx, decoderIdx in enumerate([2, 5, 8, 11]):
                     k = k.replace(f"decoder.{decoderIdx}.", f"decoder.layers.{layerIdx}.upsample.")
                     k = k.replace(
@@ -190,12 +188,12 @@ def load_weights(
                     k = k.replace(
                         f"encoder.{encoderIdx + 2}.", f"encoder.layers.{layerIdx}.downsample.")
 
-                k.replace("decoder.0.", "decoder.init_conv1d.")
-                k.replace("decoder.14.", "decoder.final_conv1d.")
-                k.replace("encoder.0.", "encoder.init_conv1d.")
-                k.replace("encoder.14.", "encoder.final_conv1d.")
-                k.replace(".block.1.", ".block.0.")
-                k.replace(".block.3.", ".block.1.")
+                k = k.replace("decoder.0.", "decoder.init_conv1d.")
+                k = k.replace("decoder.14.", "decoder.final_conv1d.")
+                k = k.replace("encoder.0.", "encoder.init_conv1d.")
+                k = k.replace("encoder.14.", "encoder.final_conv1d.")
+                k = k.replace(".block.1.", ".block.0.")
+                k = k.replace(".block.3.", ".block.1.")
 
                 # PyTorch layout for conv weights is outC, inC, kSize, for MLX it's outC, kSize, inC
                 if k.endswith(".conv.weight") or k.endswith(".output_proj.weight") or k.endswith(".input_proj.weight"):
diff --git a/moshi_mlx/moshi_mlx/modules/conv.py b/moshi_mlx/moshi_mlx/modules/conv.py
index d14c26d1..47b22d4b 100644
--- a/moshi_mlx/moshi_mlx/modules/conv.py
+++ b/moshi_mlx/moshi_mlx/modules/conv.py
@@ -81,13 +81,13 @@ def __init__(
         if groups == in_channels and groups == out_channels:
             eye = mx.eye(out_channels).astype(self.weight.dtype).reshape((out_channels, 1, out_channels))
             eye = mx.repeat(eye, repeats=ksize, axis=1)
-            self.expanded_weight = mx.repeat(self.weight, repeats=groups, axis=0) * eye
-            self.expanded_groups = 1
+            self._expanded_weight = mx.repeat(self.weight, repeats=groups, axis=0) * eye
+            self._expanded_groups = 1
         elif groups > 1:
             raise ValueError("groups are not supported in ConvTranspose1d")
         else:
-            self.expanded_weight = self.weight
-            self.expanded_groups = groups
+            self._expanded_weight = self.weight
+            self._expanded_groups = groups
 
     def update(self, parameters: dict) -> nn.Module:
         super().update(parameters)
@@ -98,22 +98,22 @@ def update(self, parameters: dict) -> nn.Module:
         if groups == in_channels and groups == out_channels:
             eye = mx.eye(out_channels).astype(self.weight.dtype).reshape((out_channels, 1, out_channels))
             eye = mx.repeat(eye, repeats=ksize, axis=1)
-            self.expanded_weight = mx.repeat(self.weight, repeats=groups, axis=0) * eye
-            self.expanded_groups = 1
+            self._expanded_weight = mx.repeat(self.weight, repeats=groups, axis=0) * eye
+            self._expanded_groups = 1
         elif groups > 1:
             raise ValueError("groups are not supported in ConvTranspose1d")
         else:
-            self.expanded_weight = self.weight
-            self.expanded_groups = groups
+            self._expanded_weight = self.weight
+            self._expanded_groups = groups
         return self
 
     def __call__(self, xs: mx.array) -> mx.array:
         y = mx.conv_transpose1d(
             xs.swapaxes(-1, -2),
-            self.expanded_weight,
+            self._expanded_weight,
             stride=self._stride,
             padding=self._padding,
-            groups=self.expanded_groups,
+            groups=self._expanded_groups,
         )
         if self.bias is not None:
             y = y + self.bias
diff --git a/moshi_mlx/moshi_mlx/modules/seanet.py b/moshi_mlx/moshi_mlx/modules/seanet.py
index 3f7781c2..2a7f91b3 100644
--- a/moshi_mlx/moshi_mlx/modules/seanet.py
+++ b/moshi_mlx/moshi_mlx/modules/seanet.py
@@ -176,7 +176,17 @@ def __init__(self, cfg: SeanetConfig, ratio: int, mult: int):
             bias=True,
             causal=cfg.causal,
         )
-        self.residuals = []
+        residuals = []
+        dilation = 1
+        for _ in range(cfg.nresidual_layers):
+            r = SeanetResnetBlock(
+                cfg,
+                dim=mult * cfg.nfilters // 2,
+                ksizes_and_dilations=[(cfg.residual_ksize, dilation), (1, 1)],
+            )
+            residuals.append(r)
+            dilation *= cfg.dilation_base
+        self.residuals = residuals
 
     def reset_state(self):
         self.upsample.reset_state()
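The mimi.py hunks above mostly undo a classic slip from the previous commit: str.replace returns a new string and leaves the original untouched, so the bare k.replace(...) calls were no-ops and keys such as decoder.0.* or .block.1.* never got remapped. A two-line illustration:

    k = "decoder.0.conv.weight"
    k.replace("decoder.0.", "decoder.init_conv1d.")      # result discarded, k is unchanged
    k = k.replace("decoder.0.", "decoder.init_conv1d.")  # k == "decoder.init_conv1d.conv.weight"

The conv.py hunks apply the same underscore convention as before: expanded_weight and expanded_groups are derived from the real weight, so prefixing them keeps them out of the loadable parameters.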
From 746cc3d1a49a8e715b1a2bf39227d2bffab9529e Mon Sep 17 00:00:00 2001
From: Laurent
Date: Sat, 22 Feb 2025 09:26:50 +0100
Subject: [PATCH 20/22] Write the generated file.

---
 scripts/mimi_mlx.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/scripts/mimi_mlx.py b/scripts/mimi_mlx.py
index ebb0e7ae..7a58cd95 100644
--- a/scripts/mimi_mlx.py
+++ b/scripts/mimi_mlx.py
@@ -4,6 +4,7 @@
 
 import argparse
 from huggingface_hub import hf_hub_download
+import numpy as np
 import mlx.core as mx
 import sphn
 import moshi_mlx
@@ -30,6 +31,7 @@ def run():
     print(codes.shape)
     pcm_out = model.decode(codes)
     print(pcm_out.shape)
+    sphn.write_wav("out.wav", np.array(pcm_out[0]), sample_rate=24000)
 
 if __name__ == "__main__":
     run()
From e9cfc7622c675a2a34c877af32c7a6c8c010e016 Mon Sep 17 00:00:00 2001
From: Laurent
Date: Sat, 22 Feb 2025 09:34:58 +0100
Subject: [PATCH 21/22] Rename the hacky weight loader.

---
 moshi_mlx/moshi_mlx/models/mimi.py | 85 ++++++++++++++----------------
 scripts/mimi_mlx.py                |  2 +-
 2 files changed, 42 insertions(+), 45 deletions(-)

diff --git a/moshi_mlx/moshi_mlx/models/mimi.py b/moshi_mlx/moshi_mlx/models/mimi.py
index 715331e1..c7b3f2a9 100644
--- a/moshi_mlx/moshi_mlx/models/mimi.py
+++ b/moshi_mlx/moshi_mlx/models/mimi.py
@@ -158,50 +158,47 @@ def warmup(self):
         pcm_out = self.decode(codes)
         mx.eval(pcm_out)
 
-    def load_weights(
+    def load_pytorch_weights(
         self,
-        file_or_weights: tp.Union[str, tp.List[tp.Tuple[str, mx.array]]],
+        file: str,
         strict: bool = True,
     ) -> nn.Module:
-        if isinstance(file_or_weights, str):
-            weights = []
-            for k, v in mx.load(file_or_weights).items():
-                v: mx.array = v
-                k: str = '.'.join([s.removeprefix('_') for s in k.split('.')])
-                if k.startswith("encoder.model."):
-                    k = k.replace("encoder.model.", "encoder.")
-                if k.startswith("decoder.model."):
-                    k = k.replace("decoder.model.", "decoder.")
-                if k.endswith(".in_proj_weight"):
-                    k = k.replace(".in_proj_weight", ".in_proj.weight")
-                if k.endswith(".linear1.weight"):
-                    k = k.replace(".linear1.weight", ".gating.linear1.weight")
-                if k.endswith(".linear2.weight"):
-                    k = k.replace(".linear2.weight", ".gating.linear2.weight")
-                # Awfully hardcoded matching between the pytorch layers and their mlx equivalent :(
-                for layerIdx, decoderIdx in enumerate([2, 5, 8, 11]):
-                    k = k.replace(f"decoder.{decoderIdx}.", f"decoder.layers.{layerIdx}.upsample.")
-                    k = k.replace(
-                        f"decoder.{decoderIdx + 1}.", f"decoder.layers.{layerIdx}.residuals.0.")
-                for (layerIdx, encoderIdx) in enumerate([1, 4, 7, 10]):
-                    k = k.replace(f"encoder.{encoderIdx}.", f"encoder.layers.{layerIdx}.residuals.0.")
-                    k = k.replace(
-                        f"encoder.{encoderIdx + 2}.", f"encoder.layers.{layerIdx}.downsample.")
-
-                k = k.replace("decoder.0.", "decoder.init_conv1d.")
-                k = k.replace("decoder.14.", "decoder.final_conv1d.")
-                k = k.replace("encoder.0.", "encoder.init_conv1d.")
-                k = k.replace("encoder.14.", "encoder.final_conv1d.")
-                k = k.replace(".block.1.", ".block.0.")
-                k = k.replace(".block.3.", ".block.1.")
-
-                # PyTorch layout for conv weights is outC, inC, kSize, for MLX it's outC, kSize, inC
-                if k.endswith(".conv.weight") or k.endswith(".output_proj.weight") or k.endswith(".input_proj.weight"):
-                    v = v.swapaxes(-1, -2)
-                # PyTorch layout for conv-transposed weights is inC, outC, kSize, for MLX it's outC, kSize, inC
-                if k.endswith(".convtr.weight"):
-                    v = v.transpose(1, 2, 0)
-                weights.append((k, v))
-        else:
-            weights = file_or_weights
-        return super().load_weights(weights, strict=strict)
+        weights = []
+        for k, v in mx.load(file).items():
+            v: mx.array = v
+            k: str = '.'.join([s.removeprefix('_') for s in k.split('.')])
+            if k.startswith("encoder.model."):
+                k = k.replace("encoder.model.", "encoder.")
+            if k.startswith("decoder.model."):
+                k = k.replace("decoder.model.", "decoder.")
+            if k.endswith(".in_proj_weight"):
+                k = k.replace(".in_proj_weight", ".in_proj.weight")
+            if k.endswith(".linear1.weight"):
+                k = k.replace(".linear1.weight", ".gating.linear1.weight")
+            if k.endswith(".linear2.weight"):
+                k = k.replace(".linear2.weight", ".gating.linear2.weight")
+            # Awfully hardcoded matching between the pytorch layers and their mlx equivalent :(
+            for layerIdx, decoderIdx in enumerate([2, 5, 8, 11]):
+                k = k.replace(f"decoder.{decoderIdx}.", f"decoder.layers.{layerIdx}.upsample.")
+                k = k.replace(
+                    f"decoder.{decoderIdx + 1}.", f"decoder.layers.{layerIdx}.residuals.0.")
+            for (layerIdx, encoderIdx) in enumerate([1, 4, 7, 10]):
+                k = k.replace(f"encoder.{encoderIdx}.", f"encoder.layers.{layerIdx}.residuals.0.")
+                k = k.replace(
+                    f"encoder.{encoderIdx + 2}.", f"encoder.layers.{layerIdx}.downsample.")
+
+            k = k.replace("decoder.0.", "decoder.init_conv1d.")
+            k = k.replace("decoder.14.", "decoder.final_conv1d.")
+            k = k.replace("encoder.0.", "encoder.init_conv1d.")
+            k = k.replace("encoder.14.", "encoder.final_conv1d.")
+            k = k.replace(".block.1.", ".block.0.")
+            k = k.replace(".block.3.", ".block.1.")
+
+            # PyTorch layout for conv weights is outC, inC, kSize, for MLX it's outC, kSize, inC
+            if k.endswith(".conv.weight") or k.endswith(".output_proj.weight") or k.endswith(".input_proj.weight"):
+                v = v.swapaxes(-1, -2)
+            # PyTorch layout for conv-transposed weights is inC, outC, kSize, for MLX it's outC, kSize, inC
+            if k.endswith(".convtr.weight"):
+                v = v.transpose(1, 2, 0)
+            weights.append((k, v))
+        return self.load_weights(weights, strict=strict)
diff --git a/scripts/mimi_mlx.py b/scripts/mimi_mlx.py
index 7a58cd95..5ebaaff6 100644
--- a/scripts/mimi_mlx.py
+++ b/scripts/mimi_mlx.py
@@ -24,7 +24,7 @@ def run():
     cfg = moshi_mlx.models.mimi.mimi_202407(32)
     model = moshi_mlx.models.mimi.Mimi(cfg)
     print(f"loading weights {model_file}")
-    model.load_weights(model_file, strict=True)
+    model.load_pytorch_weights(model_file, strict=True)
     print("weights loaded")
 
     codes = model.encode(pcm_in)
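Renaming the converter keeps nn.Module.load_weights available with its stock signature for weights that are already in MLX layout, while load_pytorch_weights owns the checkpoint-specific remapping and simply delegates to it at the end. Roughly how the two entry points are meant to be used; the second file name is hypothetical:

    import moshi_mlx

    model = moshi_mlx.models.mimi.Mimi(moshi_mlx.models.mimi.mimi_202407(32))
    model.load_pytorch_weights("tokenizer-e351c8d8-checkpoint125.safetensors", strict=True)  # original PyTorch checkpoint
    # or: model.load_weights("mimi-mlx.safetensors")  # plain nn.Module loader, already-converted weights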
From 5d938b6dad6766d5e97e226b0e490934e78eee38 Mon Sep 17 00:00:00 2001
From: Laurent
Date: Sat, 22 Feb 2025 09:37:33 +0100
Subject: [PATCH 22/22] Fix the pre-commit issue.

---
 moshi_mlx/moshi_mlx/models/mimi.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/moshi_mlx/moshi_mlx/models/mimi.py b/moshi_mlx/moshi_mlx/models/mimi.py
index c7b3f2a9..6d9b7043 100644
--- a/moshi_mlx/moshi_mlx/models/mimi.py
+++ b/moshi_mlx/moshi_mlx/models/mimi.py
@@ -14,7 +14,6 @@
     ConvTrUpsample1d,
 )
 import math
-import typing as tp
 import mlx.core as mx
 import mlx.nn as nn
 