Add streaming (chunk-by-chunk) encoding and decoding to the MLX Mimi model.

* models/mimi.py: reset_state() also resets the encoder/decoder transformer
  caches; encode_step()/decode_step() replace the ValueError("TODO") stubs.
* modules/conv.py: the former incremental bodies of the two __call__ methods
  become step(), and __call__ now pads/unpads a whole sequence in one shot.
  Two streaming bugs are fixed on the way: _left_pad_applied was never set
  to True, and the _prev_ys carry-over slice was missing its trailing colon.
  The thin conv/convtr wrappers gain a step() that forwards to the inner
  module.
* modules/seanet.py: new StreamingAdd keeps the surplus of whichever operand
  is longer so the residual sum stays aligned across steps; step() is added
  to the resnet block, the encoder/decoder layers, the encoder and the
  decoder.
* scripts/mimi_mlx.py: new --model-file and --streaming options. The
  streaming loop feeds chunks of 1920 along the last axis and skips only a
  trailing partial chunk.

Note: the index line of scripts/mimi_mlx.py still carries the blob ids of
the change as first recorded.

diff --git a/moshi_mlx/moshi_mlx/models/mimi.py b/moshi_mlx/moshi_mlx/models/mimi.py
index 6d9b704..0a9400d 100644
--- a/moshi_mlx/moshi_mlx/models/mimi.py
+++ b/moshi_mlx/moshi_mlx/models/mimi.py
@@ -126,6 +126,10 @@ def __init__(self, cfg: MimiConfig):
     def reset_state(self):
         self.encoder.reset_state()
         self.decoder.reset_state()
+        for c in self.decoder_cache:
+            c.reset()
+        for c in self.encoder_cache:
+            c.reset()
 
     def encode(self, xs: mx.array) -> mx.array:
         self.encoder.reset_state()
@@ -145,11 +149,19 @@ def decode(self, xs: mx.array) -> mx.array:
         xs = self.decoder_transformer(xs, cache=self.decoder_cache)[0]
         return self.decoder(xs)
 
-    def encode_step(self, _: mx.array) -> mx.array:
-        raise ValueError("TODO")
+    def encode_step(self, xs: mx.array) -> mx.array:
+        xs = self.encoder.step(xs)
+        xs = self.encoder_transformer(xs, cache=self.encoder_cache)[0]
+        xs = self.downsample.step(xs)
+        xs = self.quantizer.encode(xs)
+        return xs
 
-    def decode_step(self, _: mx.array) -> mx.array:
-        raise ValueError("TODO")
+    def decode_step(self, xs: mx.array) -> mx.array:
+        xs = self.quantizer.decode(xs)
+        xs = self.upsample.step(xs)
+        xs = self.decoder_transformer(xs, cache=self.decoder_cache)[0]
+        xs = self.decoder.step(xs)
+        return xs
 
     def warmup(self):
         pcm = mx.zeros((1, 1, 1920 * 4))
diff --git a/moshi_mlx/moshi_mlx/modules/conv.py b/moshi_mlx/moshi_mlx/modules/conv.py
index 47b22d4..7833160 100644
--- a/moshi_mlx/moshi_mlx/modules/conv.py
+++ b/moshi_mlx/moshi_mlx/modules/conv.py
@@ -226,6 +226,27 @@ def reset_state(self):
         self._left_pad_applied = False
 
     def __call__(self, xs: mx.array) -> mx.array:
+        ksize = self._ksize
+        ksize = (ksize - 1) * self.conv.conv._dilation + 1
+        padding_total = ksize - self.conv.conv._stride
+        extra_padding = get_extra_padding_for_conv1d(
+            xs,
+            ksize=ksize,
+            stride=self.conv.conv._stride,
+            padding_total=padding_total,
+        )
+        z = 0, 0
+        if self._causal:
+            padding_left = padding_total
+            padding_right = 0
+        else:
+            padding_right = padding_total // 2
+            padding_left = padding_total - padding_right
+        widths = [z, z, (padding_left, padding_right + extra_padding)]
+        pd = mx.pad(xs, pad_width=widths, mode=self._pad_mode)
+        return self.conv(pd)
+
+    def step(self, xs: mx.array) -> mx.array:
         b, _, len_ = xs.shape
         if len_ == 0:
             return mx.zeros((b, self._out_channels, 0))
@@ -233,7 +254,7 @@ def __call__(self, xs: mx.array) -> mx.array:
         dilation = self.conv.conv._dilation
         ksize = (self._ksize - 1) * dilation + 1
         if not self._left_pad_applied:
-            self._left_pad_applied
+            self._left_pad_applied = True
             padding_total = ksize - stride
             xs = mx.pad(
                 xs,
@@ -285,6 +306,18 @@ def reset_state(self):
         self._prev_ys = None
 
     def __call__(self, xs: mx.array) -> mx.array:
+        stride = self.convtr.convtr._stride
+        padding_total = max(self._ksize - stride, 0)
+        xs = self.convtr(xs)
+        if self._causal:
+            unpad_l = 0
+            unpad_r = padding_total
+        else:
+            unpad_r = padding_total // 2
+            unpad_l = padding_total - unpad_r
+        return unpad1d(xs, unpad_l=unpad_l, unpad_r=unpad_r)
+
+    def step(self, xs: mx.array) -> mx.array:
         b, _, len_ = xs.shape
         if len_ == 0:
             return mx.zeros((b, self._out_channels, 0))
@@ -299,7 +332,7 @@ def __call__(self, xs: mx.array) -> mx.array:
             ys1, ys2 = ys[..., :pt] + prev_ys, ys[..., pt:]
             ys = mx.concat([ys1, ys2], axis=-1)
         invalid_steps = self._ksize - stride
-        ys, self._prev_ys = ys[..., :ot - invalid_steps], ys[..., ot - invalid_steps]
+        ys, self._prev_ys = ys[..., :ot - invalid_steps], ys[..., ot - invalid_steps:]
         return ys
 
 
@@ -328,6 +361,9 @@ def reset_state(self):
     def __call__(self, xs: mx.array) -> mx.array:
         return self.conv(xs)
 
+    def step(self, xs: mx.array) -> mx.array:
+        return self.conv.step(xs)
+
 
 class ConvTrUpsample1d(nn.Module):
     def __init__(
@@ -352,3 +388,7 @@ def reset_state(self):
     def __call__(self, xs: mx.array) -> mx.array:
         xs = self.convtr(xs)
         return xs
+
+    def step(self, xs: mx.array) -> mx.array:
+        xs = self.convtr.step(xs)
+        return xs
diff --git a/moshi_mlx/moshi_mlx/modules/seanet.py b/moshi_mlx/moshi_mlx/modules/seanet.py
index 2a7f91b..9b71b97 100644
--- a/moshi_mlx/moshi_mlx/modules/seanet.py
+++ b/moshi_mlx/moshi_mlx/modules/seanet.py
@@ -26,6 +26,31 @@ class SeanetConfig:
     compress: int
 
 
+class StreamingAdd(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self._lhs = None
+        self._rhs = None
+
+    def step(self, lhs: mx.array, rhs: mx.array) -> mx.array:
+        if self._lhs is not None:
+            lhs = mx.concat([self._lhs, lhs], axis=-1)
+            self._lhs = None
+        if self._rhs is not None:
+            rhs = mx.concat([self._rhs, rhs], axis=-1)
+            self._rhs = None
+        lhs_l = lhs.shape[-1]
+        rhs_l = rhs.shape[-1]
+        if lhs_l == rhs_l:
+            return lhs + rhs
+        elif lhs_l < rhs_l:
+            self._rhs = rhs[..., lhs_l:]
+            return lhs + rhs[..., :lhs_l]
+        else:
+            self._lhs = lhs[..., rhs_l:]
+            return lhs[..., :rhs_l] + rhs
+
+
 class SeanetResnetBlock(nn.Module):
     def __init__(self, cfg: SeanetConfig, dim: int, ksizes_and_dilations: list):
         super().__init__()
@@ -47,6 +72,7 @@ def __init__(self, cfg: SeanetConfig, dim: int, ksizes_and_dilations: list):
             )
             block.append(c)
         self.block = block
+        self.streaming_add = StreamingAdd()
 
         if cfg.true_skip:
             self.shortcut = None
@@ -73,13 +99,22 @@ def __call__(self, xs: mx.array) -> mx.array:
         residual = xs
         for b in self.block:
             xs = b(nn.elu(xs, alpha=1.0))
-        # TODO(laurent): we might need some streaming additions below.
         if self.shortcut is None:
             xs = xs + residual
         else:
             xs = xs + self.shortcut(residual)
         return xs
 
+    def step(self, xs: mx.array) -> mx.array:
+        residual = xs
+        for b in self.block:
+            xs = b.step(nn.elu(xs, alpha=1.0))
+        if self.shortcut is None:
+            xs = self.streaming_add.step(xs, residual)
+        else:
+            xs = self.streaming_add.step(xs, self.shortcut.step(residual))
+        return xs
+
 
 class EncoderLayer(nn.Module):
     def __init__(self, cfg: SeanetConfig, ratio: int, mult: int):
@@ -117,6 +152,11 @@ def __call__(self, xs: mx.array) -> mx.array:
             xs = r(xs)
         return self.downsample(nn.elu(xs, alpha=1.0))
 
+    def step(self, xs: mx.array) -> mx.array:
+        for r in self.residuals:
+            xs = r.step(xs)
+        return self.downsample.step(nn.elu(xs, alpha=1.0))
+
 
 class SeanetEncoder(nn.Module):
     def __init__(self, cfg: SeanetConfig):
@@ -163,6 +203,13 @@ def __call__(self, xs: mx.array) -> mx.array:
         xs = nn.elu(xs, alpha=1.0)
         return self.final_conv1d(xs)
 
+    def step(self, xs: mx.array) -> mx.array:
+        xs = self.init_conv1d.step(xs)
+        for layer in self.layers:
+            xs = layer.step(xs)
+        xs = nn.elu(xs, alpha=1.0)
+        return self.final_conv1d.step(xs)
+
 
 class DecoderLayer(nn.Module):
     def __init__(self, cfg: SeanetConfig, ratio: int, mult: int):
@@ -199,6 +246,12 @@ def __call__(self, xs: mx.array) -> mx.array:
             xs = r(xs)
         return xs
 
+    def step(self, xs: mx.array) -> mx.array:
+        xs = self.upsample.step(nn.elu(xs, alpha=1.0))
+        for r in self.residuals:
+            xs = r.step(xs)
+        return xs
+
 
 class SeanetDecoder(nn.Module):
     def __init__(self, cfg: SeanetConfig):
@@ -245,6 +298,13 @@ def __call__(self, xs: mx.array) -> mx.array:
         xs = nn.elu(xs, alpha=1.0)
         return self.final_conv1d(xs)
 
+    def step(self, xs: mx.array) -> mx.array:
+        xs = self.init_conv1d.step(xs)
+        for layer in self.layers:
+            xs = layer.step(xs)
+        xs = nn.elu(xs, alpha=1.0)
+        return self.final_conv1d.step(xs)
+
 
 class Seanet(nn.Module):
     def __init__(self, cfg: SeanetConfig):
diff --git a/scripts/mimi_mlx.py b/scripts/mimi_mlx.py
index 5ebaaff..7edd388 100644
--- a/scripts/mimi_mlx.py
+++ b/scripts/mimi_mlx.py
@@ -13,24 +13,48 @@
 def run():
     parser = argparse.ArgumentParser()
     parser.add_argument("--input", type=str)
+    parser.add_argument("--model-file", type=str)
     parser.add_argument("--hf-repo", type=str, default="kyutai/moshiko-mlx-q4")
+    parser.add_argument("--streaming", action="store_true")
     args = parser.parse_args()
 
     pcm_in, _ = sphn.read(args.input, sample_rate=24000)
     pcm_in = mx.array(pcm_in[0])[None, None]
     print(pcm_in.shape)
 
-    model_file = hf_hub_download(args.hf_repo, "tokenizer-e351c8d8-checkpoint125.safetensors")
+    if args.model_file is None:
+        model_file = hf_hub_download(args.hf_repo, "tokenizer-e351c8d8-checkpoint125.safetensors")
+    else:
+        model_file = args.model_file
     cfg = moshi_mlx.models.mimi.mimi_202407(32)
+    print("building model", flush=True)
     model = moshi_mlx.models.mimi.Mimi(cfg)
-    print(f"loading weights {model_file}")
+    print(f"loading weights {model_file}", flush=True)
     model.load_pytorch_weights(model_file, strict=True)
     print("weights loaded")
 
-    codes = model.encode(pcm_in)
-    print(codes.shape)
-    pcm_out = model.decode(codes)
-    print(pcm_out.shape)
+    if args.streaming:
+        chunk_size = 1920
+        pcm_out = []
+        len_ = pcm_in.shape[-1]
+        print("starting streaming conversion")
+        for start_idx in range(0, len_, chunk_size):
+            end_idx = start_idx + chunk_size
+            if end_idx > len_:  # skip only a trailing partial chunk; an exactly-full last chunk is kept
+                break
+            _pcm_in = pcm_in[..., start_idx:end_idx]
+            codes = model.encode_step(_pcm_in)
+            _pcm_out = model.decode_step(codes)
+            pcm_out.append(_pcm_out)
+            pct = int(100 * start_idx / len_)
+            print(f"{pct}%", end="\r", flush=True)
+        print()
+        pcm_out = mx.concat(pcm_out, axis=-1)
+    else:
+        codes = model.encode(pcm_in)
+        print(codes.shape)
+        pcm_out = model.decode(codes)
+    print("writing output file with audio shape", pcm_out.shape)
     sphn.write_wav("out.wav", np.array(pcm_out[0]), sample_rate=24000)
 
 if __name__ == "__main__":