Streaming implementation for mimi on mlx #231

Merged · 5 commits · Feb 22, 2025
20 changes: 16 additions & 4 deletions moshi_mlx/moshi_mlx/models/mimi.py
@@ -126,6 +126,10 @@ def __init__(self, cfg: MimiConfig):
def reset_state(self):
self.encoder.reset_state()
self.decoder.reset_state()
for c in self.decoder_cache:
c.reset()
for c in self.encoder_cache:
c.reset()

def encode(self, xs: mx.array) -> mx.array:
self.encoder.reset_state()
@@ -145,11 +149,19 @@ def decode(self, xs: mx.array) -> mx.array:
xs = self.decoder_transformer(xs, cache=self.decoder_cache)[0]
return self.decoder(xs)

def encode_step(self, _: mx.array) -> mx.array:
raise ValueError("TODO")
def encode_step(self, xs: mx.array) -> mx.array:
xs = self.encoder.step(xs)
xs = self.encoder_transformer(xs, cache=self.encoder_cache)[0]
xs = self.downsample.step(xs)
xs = self.quantizer.encode(xs)
return xs

def decode_step(self, _: mx.array) -> mx.array:
raise ValueError("TODO")
def decode_step(self, xs: mx.array) -> mx.array:
xs = self.quantizer.decode(xs)
xs = self.upsample.step(xs)
xs = self.decoder_transformer(xs, cache=self.decoder_cache)[0]
xs = self.decoder.step(xs)
return xs

def warmup(self):
pcm = mx.zeros((1, 1, 1920 * 4))
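For reference, a minimal sketch of how these step methods can be driven from user code, assuming model is a loaded Mimi instance and pcm is a (1, 1, T) waveform at 24 kHz. This is illustrative only, not part of the PR; it mirrors the streaming loop added to scripts/mimi_mlx.py below:

# Hedged sketch: stream a waveform through Mimi frame by frame.
frame_size = 1920  # same chunk size as warmup() and scripts/mimi_mlx.py
model.reset_state()
frames = []
for start in range(0, pcm.shape[-1] - frame_size + 1, frame_size):
    codes = model.encode_step(pcm[..., start:start + frame_size])
    frames.append(model.decode_step(codes))
pcm_out = mx.concat(frames, axis=-1)
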
44 changes: 42 additions & 2 deletions moshi_mlx/moshi_mlx/modules/conv.py
@@ -226,14 +226,35 @@ def reset_state(self):
self._left_pad_applied = False

def __call__(self, xs: mx.array) -> mx.array:
ksize = self._ksize
ksize = (ksize - 1) * self.conv.conv._dilation + 1
padding_total = ksize - self.conv.conv._stride
extra_padding = get_extra_padding_for_conv1d(
xs,
ksize=ksize,
stride=self.conv.conv._stride,
padding_total=padding_total,
)
z = 0, 0
if self._causal:
padding_left = padding_total
padding_right = 0
else:
padding_right = padding_total // 2
padding_left = padding_total - padding_right
widths = [z, z, (padding_left, padding_right + extra_padding)]
pd = mx.pad(xs, pad_width=widths, mode=self._pad_mode)
return self.conv(pd)
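For intuition, get_extra_padding_for_conv1d (whose body is not part of this diff) computes how much extra right padding is needed so the strided convolution ends on a complete window. A minimal sketch of the usual EnCodec-style formula, which the actual helper is assumed to follow:

import math

# Hypothetical standalone version of the helper used above; the real implementation may differ.
def extra_padding_for_conv1d(length: int, ksize: int, stride: int, padding_total: int) -> int:
    n_frames = (length - ksize + padding_total) / stride + 1
    ideal_length = (math.ceil(n_frames) - 1) * stride + ksize - padding_total
    return ideal_length - length  # e.g. extra_padding_for_conv1d(101, 7, 2, 5) == 1
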

def step(self, xs: mx.array) -> mx.array:
b, _, len_ = xs.shape
if len_ == 0:
return mx.zeros((b, self._out_channels, 0))
stride = self.conv.conv._stride
dilation = self.conv.conv._dilation
ksize = (self._ksize - 1) * dilation + 1
if not self._left_pad_applied:
self._left_pad_applied = True
padding_total = ksize - stride
xs = mx.pad(
xs,
@@ -285,6 +306,18 @@ def reset_state(self):
self._prev_ys = None

def __call__(self, xs: mx.array) -> mx.array:
stride = self.convtr.convtr._stride
padding_total = max(self._ksize - stride, 0)
xs = self.convtr(xs)
if self._causal:
unpad_l = 0
unpad_r = padding_total
else:
unpad_r = padding_total // 2
unpad_l = padding_total - unpad_r
return unpad1d(xs, unpad_l=unpad_l, unpad_r=unpad_r)

def step(self, xs: mx.array) -> mx.array:
b, _, len_ = xs.shape
if len_ == 0:
return mx.zeros((b, self._out_channels, 0))
@@ -299,7 +332,7 @@ def __call__(self, xs: mx.array) -> mx.array:
ys1, ys2 = ys[..., :pt] + prev_ys, ys[..., pt:]
ys = mx.concat([ys1, ys2], axis=-1)
invalid_steps = self._ksize - stride
ys, self._prev_ys = ys[..., :ot - invalid_steps], ys[..., ot - invalid_steps]
ys, self._prev_ys = ys[..., :ot - invalid_steps], ys[..., ot - invalid_steps:]
return ys


@@ -328,6 +361,9 @@ def reset_state(self):
def __call__(self, xs: mx.array) -> mx.array:
return self.conv(xs)

def step(self, xs: mx.array) -> mx.array:
return self.conv.step(xs)


class ConvTrUpsample1d(nn.Module):
def __init__(
@@ -352,3 +388,7 @@ def reset_state(self):
def __call__(self, xs: mx.array) -> mx.array:
xs = self.convtr(xs)
return xs

def step(self, xs: mx.array) -> mx.array:
xs = self.convtr.step(xs)
return xs
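The subtle part of this file is the carry-over in StreamableConvTranspose1d.step: the last ksize - stride output samples of a chunk are incomplete until the next chunk's contribution is added, so they are stashed in _prev_ys and summed onto the start of the following step's output (the slicing fix above adds the missing colon for exactly that stash). A small illustrative calculation, with made-up sizes:

# Hypothetical length bookkeeping for a streaming transposed conv, independent of the classes above.
ksize, stride = 4, 2

def step_lengths(chunk_frames: int) -> tuple[int, int]:
    raw = (chunk_frames - 1) * stride + ksize  # full transposed-conv output for this chunk
    carry = ksize - stride                     # trailing samples still waiting for the next chunk's overlap
    return raw - carry, carry                  # (samples emitted now, samples kept as _prev_ys)

print(step_lengths(8))  # (16, 2): 8 frames in, 16 samples out, matching the stride-2 upsampling
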
62 changes: 61 additions & 1 deletion moshi_mlx/moshi_mlx/modules/seanet.py
@@ -26,6 +26,31 @@ class SeanetConfig:
compress: int


class StreamingAdd(nn.Module):
def __init__(self):
super().__init__()
self._lhs = None
self._rhs = None

def step(self, lhs: mx.array, rhs: mx.array) -> mx.array:
if self._lhs is not None:
lhs = mx.concat([self._lhs, lhs], axis=-1)
self._lhs = None
if self._rhs is not None:
rhs = mx.concat([self._rhs, rhs], axis=-1)
self._rhs = None
lhs_l = lhs.shape[-1]
rhs_l = rhs.shape[-1]
if lhs_l == rhs_l:
return lhs + rhs
elif lhs_l < rhs_l:
self._rhs = rhs[..., lhs_l:]
return lhs + rhs[..., :lhs_l]
else:
self._lhs = lhs[..., rhs_l:]
return lhs[..., :rhs_l] + rhs
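
A quick usage sketch of StreamingAdd (shapes invented for illustration; assumes import mlx.core as mx): when the two operands arrive with different chunk lengths, the longer one's tail is buffered and prepended on the next call so the running sums stay aligned.

add = StreamingAdd()
a = add.step(mx.zeros((1, 2, 3)), mx.zeros((1, 2, 5)))  # lhs shorter: emit 3 samples, buffer the 2 extra rhs samples
b = add.step(mx.zeros((1, 2, 4)), mx.zeros((1, 2, 2)))  # buffered rhs is prepended, lengths now match
print(a.shape, b.shape)  # (1, 2, 3) (1, 2, 4)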


class SeanetResnetBlock(nn.Module):
def __init__(self, cfg: SeanetConfig, dim: int, ksizes_and_dilations: list):
super().__init__()
@@ -47,6 +72,7 @@ def __init__(self, cfg: SeanetConfig, dim: int, ksizes_and_dilations: list):
)
block.append(c)
self.block = block
self.streaming_add = StreamingAdd()

if cfg.true_skip:
self.shortcut = None
@@ -73,13 +99,22 @@ def __call__(self, xs: mx.array) -> mx.array:
residual = xs
for b in self.block:
xs = b(nn.elu(xs, alpha=1.0))
# TODO(laurent): we might need some streaming additions below.
if self.shortcut is None:
xs = xs + residual
else:
xs = xs + self.shortcut(residual)
return xs

def step(self, xs: mx.array) -> mx.array:
residual = xs
for b in self.block:
xs = b.step(nn.elu(xs, alpha=1.0))
if self.shortcut is None:
xs = self.streaming_add.step(xs, residual)
else:
xs = self.streaming_add.step(xs, self.shortcut.step(residual))
return xs


class EncoderLayer(nn.Module):
def __init__(self, cfg: SeanetConfig, ratio: int, mult: int):
@@ -117,6 +152,11 @@ def __call__(self, xs: mx.array) -> mx.array:
xs = r(xs)
return self.downsample(nn.elu(xs, alpha=1.0))

def step(self, xs: mx.array) -> mx.array:
for r in self.residuals:
xs = r.step(xs)
return self.downsample.step(nn.elu(xs, alpha=1.0))


class SeanetEncoder(nn.Module):
def __init__(self, cfg: SeanetConfig):
@@ -163,6 +203,13 @@ def __call__(self, xs: mx.array) -> mx.array:
xs = nn.elu(xs, alpha=1.0)
return self.final_conv1d(xs)

def step(self, xs: mx.array) -> mx.array:
xs = self.init_conv1d.step(xs)
for layer in self.layers:
xs = layer.step(xs)
xs = nn.elu(xs, alpha=1.0)
return self.final_conv1d.step(xs)


class DecoderLayer(nn.Module):
def __init__(self, cfg: SeanetConfig, ratio: int, mult: int):
@@ -199,6 +246,12 @@ def __call__(self, xs: mx.array) -> mx.array:
xs = r(xs)
return xs

def step(self, xs: mx.array) -> mx.array:
xs = self.upsample.step(nn.elu(xs, alpha=1.0))
for r in self.residuals:
xs = r.step(xs)
return xs


class SeanetDecoder(nn.Module):
def __init__(self, cfg: SeanetConfig):
@@ -245,6 +298,13 @@ def __call__(self, xs: mx.array) -> mx.array:
xs = nn.elu(xs, alpha=1.0)
return self.final_conv1d(xs)

def step(self, xs: mx.array) -> mx.array:
xs = self.init_conv1d.step(xs)
for layer in self.layers:
xs = layer.step(xs)
xs = nn.elu(xs, alpha=1.0)
return self.final_conv1d.step(xs)


class Seanet(nn.Module):
def __init__(self, cfg: SeanetConfig):
36 changes: 30 additions & 6 deletions scripts/mimi_mlx.py
@@ -13,24 +13,48 @@
def run():
parser = argparse.ArgumentParser()
parser.add_argument("--input", type=str)
parser.add_argument("--model-file", type=str)
parser.add_argument("--hf-repo", type=str, default="kyutai/moshiko-mlx-q4")
parser.add_argument("--streaming", action="store_true")
args = parser.parse_args()

pcm_in, _ = sphn.read(args.input, sample_rate=24000)
pcm_in = mx.array(pcm_in[0])[None, None]
print(pcm_in.shape)

model_file = hf_hub_download(args.hf_repo, "tokenizer-e351c8d8-checkpoint125.safetensors")
if args.model_file is None:
model_file = hf_hub_download(args.hf_repo, "tokenizer-e351c8d8-checkpoint125.safetensors")
else:
model_file = args.model_file
cfg = moshi_mlx.models.mimi.mimi_202407(32)
print("building model", flush=True)
model = moshi_mlx.models.mimi.Mimi(cfg)
print(f"loading weights {model_file}")
print(f"loading weights {model_file}", flush=True)
model.load_pytorch_weights(model_file, strict=True)
print("weights loaded")

codes = model.encode(pcm_in)
print(codes.shape)
pcm_out = model.decode(codes)
print(pcm_out.shape)
if args.streaming:
chunk_size = 1920
pcm_out = []
len_ = pcm_in.shape[-1]
print("starting streaming conversion")
for start_idx in range(0, len_, chunk_size):
end_idx = start_idx + chunk_size
if end_idx >= len_:
break
_pcm_in = pcm_in[..., start_idx:end_idx]
codes = model.encode_step(_pcm_in)
_pcm_out = model.decode_step(codes)
pcm_out.append(_pcm_out)
pct = int(100 * start_idx / len_)
print(f"{pct}%", end="\r", flush=True)
print()
pcm_out = mx.concat(pcm_out, axis=-1)
else:
codes = model.encode(pcm_in)
print(codes.shape)
pcm_out = model.decode(codes)
print("writing output file with audio shape", pcm_out.shape)
sphn.write_wav("out.wav", np.array(pcm_out[0]), sample_rate=24000)

if __name__ == "__main__":
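With these changes the script exercises either path and writes out.wav at 24 kHz. For example, python scripts/mimi_mlx.py --input sample.wav --streaming (sample.wav being a placeholder path) round-trips the audio through encode_step/decode_step in 1920-sample chunks, while omitting --streaming keeps the original one-shot encode/decode; --hf-repo or --model-file selects where the Mimi weights come from.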