-
Notifications
You must be signed in to change notification settings - Fork 54
/
Copy pathcfm.py
520 lines (409 loc) · 14.9 KB
/
cfm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
"""
Einops-style shape notation used in the array annotations of this module:
b  - batch
n  - sequence
nt - text sequence
nw - raw wave length
d  - dimension
"""
from __future__ import annotations
from pathlib import Path
from typing import Callable, Literal
import mlx.core as mx
import mlx.nn as nn
from einops.array_api import rearrange, repeat
from vocos_mlx import Vocos
from f5_tts_mlx.audio import MelSpec
from f5_tts_mlx.duration import DurationPredictor, DurationTransformer
from f5_tts_mlx.dit import DiT
from f5_tts_mlx.utils import (
exists,
fetch_from_hub,
default,
lens_to_mask,
list_str_to_idx,
list_str_to_tensor,
mask_from_frac_lengths,
pad_sequence,
)
# ode solvers
def odeint_euler(func, y0, t):
    """
    Integrate an ODE with the explicit (forward) Euler method.

    Parameters:
    - func: Callable ``func(t, y)`` returning the derivative dy/dt.
    - y0: Initial state, an MLX array of any shape.
    - t: Ordered time points, an MLX array; one Euler step is taken
      between each consecutive pair.

    Returns the states at every time point stacked along a new leading
    axis, with ``y0`` first.
    """
    states = [y0]
    y = y0
    for idx in range(1, len(t)):
        t_prev = t[idx - 1]
        step = t[idx] - t_prev
        slope = func(t_prev, y)
        y = y + step * slope
        states.append(y)
    return mx.stack(states)
def odeint_midpoint(func, y0, t):
    """
    Integrate an ODE with the explicit midpoint method (second order).

    Parameters:
    - func: Callable ``func(t, y)`` returning the derivative dy/dt.
    - y0: Initial state, an MLX array of any shape.
    - t: Ordered time points, an MLX array; one midpoint step is taken
      between each consecutive pair.

    Returns the states at every time point stacked along a new leading
    axis, with ``y0`` first.
    """
    states = [y0]
    y = y0
    for idx in range(1, len(t)):
        t_prev = t[idx - 1]
        step = t[idx] - t_prev
        # half step using the slope at the start of the interval
        y_half = y + 0.5 * step * func(t_prev, y)
        # full step using the slope evaluated at the interval midpoint
        y = y + step * func(t_prev + 0.5 * step, y_half)
        states.append(y)
    return mx.stack(states)
def odeint_rk4(func, y0, t):
    """
    Integrate an ODE with the classic fourth-order Runge-Kutta method.

    Parameters:
    - func: Callable ``func(t, y)`` returning the derivative dy/dt.
    - y0: Initial state, an MLX array of any shape.
    - t: Ordered time points, an MLX array; one RK4 step is taken
      between each consecutive pair.

    Returns the states at every time point stacked along a new leading
    axis, with ``y0`` first.
    """
    states = [y0]
    y = y0
    for idx in range(1, len(t)):
        t_prev = t[idx - 1]
        step = t[idx] - t_prev
        half = 0.5 * step
        # four slope evaluations: start, two midpoints, end
        k1 = func(t_prev, y)
        k2 = func(t_prev + half, y + half * k1)
        k3 = func(t_prev + half, y + half * k2)
        k4 = func(t_prev + step, y + step * k3)
        # weighted average of the slopes
        y = y + (step / 6) * (k1 + 2 * k2 + 2 * k3 + k4)
        states.append(y)
    return mx.stack(states)
# conditional flow matching
class F5TTS(nn.Module):
    """
    F5-TTS conditional flow-matching model.

    A transformer predicts the flow ``x1 - x0`` that carries Gaussian noise
    ``x0`` towards target mel frames ``x1``. ``__call__`` returns the masked
    flow-matching training loss; ``sample`` integrates the learned ODE to
    generate frames (passed through a vocoder when one is configured).
    """

    def __init__(
        self,
        transformer: nn.Module,
        audio_drop_prob=0.3,
        cond_drop_prob=0.2,
        num_channels=None,
        mel_spec_module: nn.Module | None = None,
        mel_spec_kwargs: dict | None = None,
        frac_lengths_mask: tuple[float, float] = (0.7, 1.0),
        vocab_char_map: dict[str, int] | None = None,
        vocoder: Callable[[mx.array["b d n"]], mx.array["b nw"]] | None = None,
        duration_predictor: DurationPredictor | None = None,
    ):
        """
        Parameters:
        - transformer: flow-prediction network; must expose a ``dim`` attribute.
        - audio_drop_prob: training probability of dropping the audio condition.
        - cond_drop_prob: training probability of dropping the text condition
          (which also drops the audio condition).
        - num_channels: mel channel count; defaults to the mel module's ``n_mels``.
        - mel_spec_module: mel-spectrogram module; a ``MelSpec`` is built from
          ``mel_spec_kwargs`` only when this is omitted.
        - mel_spec_kwargs: keyword arguments for the default ``MelSpec``
          (``None`` means no overrides).
        - frac_lengths_mask: (low, high) range for the fraction of each sequence
          that is masked out during training.
        - vocab_char_map: character -> token id map used for ``list[str]`` text.
        - vocoder: optional callable applied to the generated frames in ``sample``.
        - duration_predictor: optional model used when ``sample`` gets no duration.
        """
        super().__init__()

        self.frac_lengths_mask = frac_lengths_mask

        # mel spec: construct the default module lazily so a supplied module
        # does not cause a throwaway MelSpec to be built
        if not exists(mel_spec_module):
            mel_spec_module = MelSpec(**default(mel_spec_kwargs, {}))
        self._mel_spec = mel_spec_module
        self.num_channels = default(num_channels, self._mel_spec.n_mels)

        # classifier-free guidance
        self.audio_drop_prob = audio_drop_prob
        self.cond_drop_prob = cond_drop_prob

        # transformer
        self.transformer = transformer
        self.dim = transformer.dim

        # vocab map for tokenization
        self._vocab_char_map = vocab_char_map

        # vocoder (optional)
        self._vocoder = vocoder

        # duration predictor (optional)
        self._duration_predictor = duration_predictor

    def __call__(
        self,
        inp: mx.array["b n d"] | mx.array["b nw"],  # mel or raw wave
        text: mx.array["b nt"] | list[str],
        *,
        lens: mx.array["b"] | None = None,
    ) -> mx.array:
        """
        Compute the conditional flow-matching training loss for one batch.

        A random span of each sequence is hidden from the conditioning input;
        the MSE between predicted and true flow is averaged over that span only.

        Parameters:
        - inp: mel frames (b, n, d), or raw waveforms (b, nw) that are first
          converted with the mel-spectrogram module.
        - text: token ids (b, nt) or a list of strings to tokenize.
        - lens: optional per-item valid lengths in frames; defaults to full length.

        Returns a scalar loss array.
        """
        # handle raw wave
        if inp.ndim == 2:
            inp = self._mel_spec(inp)
            # NOTE(review): `sample` uses the mel output without this transpose;
            # confirm which layout MelSpec actually returns.
            inp = rearrange(inp, "b d n -> b n d")
            assert inp.shape[-1] == self.num_channels

        batch, seq_len, dtype = *inp.shape[:2], inp.dtype

        # handle text as string
        if isinstance(text, list):
            if exists(self._vocab_char_map):
                text = list_str_to_idx(text, self._vocab_char_map)
            else:
                text = list_str_to_tensor(text)
            assert text.shape[0] == batch

        # lens and mask
        if not exists(lens):
            lens = mx.full((batch,), seq_len)
        mask = lens_to_mask(lens, length=seq_len)

        # get a random span to mask out for training conditionally,
        # restricted to the valid (unpadded) part of each sequence
        frac_lengths = mx.random.uniform(*self.frac_lengths_mask, (batch,))
        rand_span_mask = mask_from_frac_lengths(lens, frac_lengths, max_length=seq_len)
        rand_span_mask = rand_span_mask & mask

        # mel is x1
        x1 = inp

        # x0 is gaussian noise
        x0 = mx.random.normal(x1.shape)

        # time step
        time = mx.random.uniform(0, 1, (batch,), dtype=dtype)

        # sample xt (φ_t(x) in the paper): linear interpolation from x0 to x1
        t = rearrange(time, "b -> b 1 1")
        φ = (1 - t) * x0 + t * x1
        flow = x1 - x0

        # only predict what is within the random mask span for infilling
        cond = mx.where(
            rand_span_mask[..., None],
            mx.zeros_like(x1),
            x1,
        )

        # transformer and cfg training with a drop rate; dropping the text
        # condition also drops the audio condition
        rand_audio_drop = mx.random.uniform(0, 1, (1,))
        rand_cond_drop = mx.random.uniform(0, 1, (1,))
        drop_audio_cond = rand_audio_drop < self.audio_drop_prob
        drop_text = rand_cond_drop < self.cond_drop_prob
        drop_audio_cond = drop_audio_cond | drop_text

        pred = self.transformer(
            x=φ,
            cond=cond,
            text=text,
            time=time,
            drop_audio_cond=drop_audio_cond,
            drop_text=drop_text,
        )

        # flow matching loss, averaged over masked positions only
        # (the 1e-6 floor guards against an empty mask)
        loss = nn.losses.mse_loss(pred, flow, reduction="none")
        rand_span_mask = repeat(rand_span_mask, "b n -> b n d", d=self.num_channels)
        masked_loss = mx.where(rand_span_mask, loss, mx.zeros_like(loss))
        loss = mx.sum(masked_loss) / mx.maximum(mx.sum(rand_span_mask), 1e-6)

        return loss.mean()

    def predict_duration(
        self,
        cond: mx.array["b n d"],
        text: mx.array["b nt"],
        speed: float = 1.0,
    ) -> mx.array:
        """
        Predict the total number of frames to generate.

        The duration predictor's output (seconds) is converted to frames using
        the exact frame rate ``sample_rate / hop_length`` and divided by
        ``speed``, so ``speed > 1`` yields a shorter result.

        Returns an int32 array of frame counts.
        """
        duration_in_sec = self._duration_predictor(cond, text)
        # true division: floor division would drop the fractional part of the
        # frame rate (e.g. 24000 / 256 = 93.75, not 93) and shorten every output
        frame_rate = self._mel_spec.sample_rate / self._mel_spec.hop_length
        duration = (duration_in_sec * frame_rate / speed).astype(mx.int32)
        return duration

    def sample(
        self,
        cond: mx.array["b n d"] | mx.array["b nw"],
        text: mx.array["b nt"] | list[str],
        duration: int | mx.array["b"] | None = None,
        *,
        lens: mx.array["b"] | None = None,
        steps=8,
        method: Literal["euler", "midpoint", "rk4"] = "rk4",
        cfg_strength=2.0,
        speed=1.0,
        sway_sampling_coef=-1.0,
        seed: int | None = None,
        max_duration=4096,
    ) -> tuple[mx.array, mx.array]:
        """
        Generate frames conditioned on reference audio and text.

        Parameters:
        - cond: reference mel frames (b, n, d), or a single raw waveform (1, nw).
        - text: token ids (b, nt) or a list of strings to tokenize.
        - duration: total output length in frames (int or per-item array); when
          None, the duration predictor is used.
        - lens: valid reference lengths in frames; defaults to the full length.
        - steps: number of time points handed to the ODE solver.
        - method: ODE solver, one of "euler", "midpoint", "rk4".
        - cfg_strength: classifier-free guidance weight; values below 1e-5 skip
          the unconditional transformer pass.
        - speed: divides the predicted duration (only used with the predictor).
        - sway_sampling_coef: sway-sampling coefficient for the time grid;
          None disables it.
        - seed: if set, the random seed is reset before drawing each item's noise.
        - max_duration: upper bound on the output length in frames.

        Returns ``(out, trajectory)``: the final frames with the reference region
        restored (passed through the vocoder when one is set), and the ODE
        states stacked for every time point.

        Raises ValueError when no duration can be determined, when ``method``
        is unknown, or when the target duration is shorter than the
        conditioning input.
        """
        self.eval()

        # raw wave: only a single waveform is accepted ("1 n -> n")
        if cond.ndim == 2:
            cond = rearrange(cond, "1 n -> n")
            cond = self._mel_spec(cond)
            assert cond.shape[-1] == self.num_channels

        batch, cond_seq_len, dtype = *cond.shape[:2], cond.dtype
        if not exists(lens):
            lens = mx.full((batch,), cond_seq_len, dtype=dtype)

        # text
        if isinstance(text, list):
            if exists(self._vocab_char_map):
                text = list_str_to_idx(text, self._vocab_char_map)
            else:
                text = list_str_to_tensor(text)
            assert text.shape[0] == batch

        # make lengths at least as long as the text (entries equal to -1 are
        # treated as padding)
        if exists(text):
            text_lens = (text != -1).sum(axis=-1)
            lens = mx.maximum(text_lens, lens)

        # duration
        if duration is None and self._duration_predictor is not None:
            duration = self.predict_duration(cond, text, speed)
        elif duration is None:
            raise ValueError("Duration must be provided or a duration predictor must be set.")

        cond_mask = lens_to_mask(lens)
        if isinstance(duration, int):
            duration = mx.full((batch,), duration, dtype=dtype)

        # generate at least one frame past the conditioning, capped at max_duration
        duration = mx.maximum(lens + 1, duration)
        duration = mx.clip(duration, 0, max_duration)
        max_duration = int(duration.max().item())

        # mx.pad rejects negative widths; report the real problem instead
        required_len = max(cond_seq_len, cond_mask.shape[-1])
        if max_duration < required_len:
            raise ValueError(
                f"Target duration ({max_duration} frames) is shorter than the "
                f"conditioning input ({required_len} frames); pass a longer "
                f"duration or raise max_duration."
            )

        cond = mx.pad(cond, [(0, 0), (0, max_duration - cond_seq_len), (0, 0)])
        cond_mask = mx.pad(
            cond_mask,
            [(0, 0), (0, max_duration - cond_mask.shape[-1])],
            constant_values=False,
        )
        cond_mask = rearrange(cond_mask, "... -> ... 1")

        # at each step, conditioning is fixed
        step_cond = mx.where(cond_mask, cond, mx.zeros_like(cond))

        # a padding mask is only built for batched inputs
        mask = lens_to_mask(duration) if batch > 1 else None

        # neural ode: flow prediction with optional classifier-free guidance
        def fn(t, x):
            pred = self.transformer(
                x=x,
                cond=step_cond,
                text=text,
                time=t,
                mask=mask,
                drop_audio_cond=False,
                drop_text=False,
            )
            if cfg_strength < 1e-5:
                return pred

            null_pred = self.transformer(
                x=x,
                cond=step_cond,
                text=text,
                time=t,
                mask=mask,
                drop_audio_cond=True,
                drop_text=True,
            )
            return pred + (pred - null_pred) * cfg_strength

        # noise input; re-seeding per item makes each item's noise independent
        # of the rest of the batch
        y0 = []
        for dur in duration:
            if exists(seed):
                mx.random.seed(seed)
            y0.append(mx.random.normal((self.num_channels, dur)))
        y0 = pad_sequence(y0, padding_value=0)
        y0 = rearrange(y0, "b d n -> b n d")

        t_start = 0
        t = mx.linspace(t_start, 1, steps)
        if exists(sway_sampling_coef):
            # sway sampling: a negative coefficient concentrates time points near t=0
            t = t + sway_sampling_coef * (mx.cos(mx.pi / 2 * t) - 1 + t)

        if method == "midpoint":
            ode_step_fn = odeint_midpoint
        elif method == "euler":
            ode_step_fn = odeint_euler
        elif method == "rk4":
            ode_step_fn = odeint_rk4
        else:
            raise ValueError(f"Unknown method: {method}")

        fn = mx.compile(fn)
        trajectory = ode_step_fn(fn, y0, t)
        sampled = trajectory[-1]

        # restore the reference frames in the conditioned region
        out = mx.where(cond_mask, cond, sampled)

        if exists(self._vocoder):
            out = self._vocoder(out)

        return out, trajectory

    @classmethod
    def from_pretrained(
        cls,
        hf_model_name_or_path: str,
        convert_weights=None,
        quantization_bits: int | None = None,
    ) -> F5TTS:
        """
        Load a pretrained F5TTS from the hub or a local path.

        Loads ``vocab.txt``, an optional ``duration_v2.safetensors`` duration
        predictor, the "lucasnewman/vocos-mel-24khz" Vocos vocoder, and the DiT
        weights (``model_v1.safetensors``, or ``model_v1_{bits}b.safetensors``
        when ``quantization_bits`` is set).

        Parameters:
        - hf_model_name_or_path: hub repo id or local path passed to fetch_from_hub.
        - convert_weights: rename/reshape checkpoint keys to this module's
          layout; defaults to True, and is forced to False for quantized models.
        - quantization_bits: if set, Linear layers whose input dim is divisible
          by 64 are quantized before the weights are loaded.

        Raises ValueError if the model or vocab cannot be found.
        """
        path = fetch_from_hub(hf_model_name_or_path, quantization_bits=quantization_bits)

        if path is None:
            raise ValueError(f"Could not find model {hf_model_name_or_path}")

        # vocab: one token per line, id = line index; explicit utf-8 because
        # the vocab contains non-ASCII characters
        vocab_path = path / "vocab.txt"
        vocab = {
            v: i
            for i, v in enumerate(Path(vocab_path).read_text(encoding="utf-8").split("\n"))
        }
        if len(vocab) == 0:
            raise ValueError(f"Could not load vocab from {vocab_path}")

        # duration predictor (optional)
        duration_model_path = path / "duration_v2.safetensors"
        duration_predictor = None

        if duration_model_path.exists():
            duration_predictor = DurationPredictor(
                transformer=DurationTransformer(
                    dim=512,
                    depth=8,
                    heads=8,
                    text_dim=512,
                    ff_mult=2,
                    conv_layers=2,
                    # presumably excludes the trailing "" entry produced by
                    # split("\n") on a newline-terminated file — confirm
                    text_num_embeds=len(vocab) - 1,
                ),
                vocab_char_map=vocab,
            )
            weights = mx.load(duration_model_path.as_posix(), format="safetensors")
            duration_predictor.load_weights(list(weights.items()))

        # vocoder
        vocos = Vocos.from_pretrained("lucasnewman/vocos-mel-24khz")

        # model; quantized checkpoints are loaded without key conversion
        model_filename = "model_v1.safetensors"
        if exists(quantization_bits):
            model_filename = f"model_v1_{quantization_bits}b.safetensors"
            convert_weights = False
        else:
            convert_weights = default(convert_weights, True)

        model_path = path / model_filename

        f5tts = F5TTS(
            transformer=DiT(
                dim=1024,
                depth=22,
                heads=16,
                ff_mult=2,
                text_dim=512,
                conv_layers=4,
                text_num_embeds=len(vocab) - 1,
                text_mask_padding=True,
            ),
            vocab_char_map=vocab,
            vocoder=vocos.decode,
            duration_predictor=duration_predictor,
        )

        weights = mx.load(model_path.as_posix(), format="safetensors")

        if convert_weights:
            new_weights = {}
            for k, v in weights.items():
                k = k.replace("ema_model.", "")

                # rename layers (only the first matching rule applies)
                if len(k) < 1 or "mel_spec." in k or k in ("initted", "step"):
                    continue
                elif ".to_out" in k:
                    k = k.replace(".to_out", ".to_out.layers")
                elif ".text_blocks" in k:
                    k = k.replace(".text_blocks", ".text_blocks.layers")
                elif ".ff.ff.0.0" in k:
                    k = k.replace(".ff.ff.0.0", ".ff.ff.layers.0.layers.0")
                elif ".ff.ff.2" in k:
                    k = k.replace(".ff.ff.2", ".ff.ff.layers.2")
                elif ".time_mlp" in k:
                    k = k.replace(".time_mlp", ".time_mlp.layers")
                elif ".conv1d" in k:
                    k = k.replace(".conv1d", ".conv1d.layers")

                # reshape conv weights: swap axes 1 and 2 (presumably a
                # PyTorch -> MLX conv kernel layout change — confirm)
                if any(
                    part in k
                    for part in (
                        ".dwconv.weight",
                        ".conv1d.layers.0.weight",
                        ".conv1d.layers.2.weight",
                    )
                ):
                    v = v.swapaxes(1, 2)

                new_weights[k] = v

            weights = new_weights

        # quantize before loading so the module's parameters match the
        # quantized checkpoint
        if quantization_bits is not None:
            nn.quantize(
                f5tts,
                bits=quantization_bits,
                class_predicate=lambda p, m: (isinstance(m, nn.Linear) and m.weight.shape[1] % 64 == 0),
            )

        f5tts.load_weights(list(weights.items()))
        mx.eval(f5tts.parameters())

        return f5tts