@@ -403,25 +403,25 @@ def _istft(
403
403
) -> Tuple [Var , Var ]:
404
404
"""
405
405
We can write ISTFT in terms of convolutions with a DFT kernel.
406
- At the end:
407
- * The real part output is: cos_base * input_real + sin_base * input_imag
408
- * The imaginary part output is: - (sin_base * input_real - cos_base * input_imag)
409
- Adapted from: https://github.com/adobe-research/convmelspec/blob/main/convmelspec/mil.py
406
+
407
+ The input has shape (channels, fft_size, n_frames)
408
+
409
+ References:
410
+ H. Zhivomirov, “On the Development of STFT-analysis and ISTFT-synthesis Routines and their Practical Implementation,” TEM Journal, vol. 8, no. 1, pp. 56–64, 2019.
411
+ https://en.wikipedia.org/wiki/Discrete_Fourier_transform
410
412
"""
411
413
# Set the default hop, if it's not already specified
412
414
hop_length = hop_length or mb .floor_div (x = n_fft , y = 4 , before_op = before_op )
413
415
414
416
# By default, use the entire frame
415
417
win_length = win_length or n_fft
416
418
417
- # input should always be 2D
418
- should_increase_rank = input_real .rank == 1
419
- if should_increase_rank :
420
- input_real = mb .expand_dims (x = input_real , axes = (0 ,), before_op = before_op )
421
- if input_imaginary :
422
- input_imaginary = mb .expand_dims (x = input_imaginary , axes = (0 ,), before_op = before_op )
419
+ input_shape = mb .shape (x = x , before_op = before_op )
420
+ n_frames = input_shape .val [- 1 ]
421
+ fft_size = input_shape .val [- 2 ]
422
+ expected_output_signal_len = n_fft .val + hop_length .val * (n_frames - 1 )
423
423
424
- is_onesided = onesided and onesided . val
424
+ is_onesided = onesided . val if onesided else fft_size != n_fft
425
425
cos_base , sin_base = _calculate_dft_matrix (n_fft , onesided = is_onesided , before_op = before_op )
426
426
427
427
# create a window of centered 1s of the requested size
@@ -433,32 +433,34 @@ def _istft(
433
433
cos_base = mb .mul (x = window , y = cos_base , before_op = before_op )
434
434
sin_base = mb .mul (x = window , y = sin_base , before_op = before_op )
435
435
436
- # The DFT matrix is obtained with the equation e^(2pi/N i), which is what we want but we actually need the conjuate => e^(-2pi/N i)
437
- # or in terms of cos and sin => cos+i*sin cos-i*sin
438
- sin_base = mb .sub (x = 0. , y = sin_base , before_op = before_op )
439
-
440
436
cos_base = mb .expand_dims (x = cos_base , axes = (1 ,), before_op = before_op )
441
437
sin_base = mb .expand_dims (x = sin_base , axes = (1 ,), before_op = before_op )
442
438
hop_size = mb .expand_dims (x = hop_length , axes = (0 ,), before_op = before_op )
443
439
444
440
signal_real = mb .expand_dims (x = input_real , axes = (1 ,), before_op = before_op )
445
441
signal_imaginary = mb .expand_dims (x = input_imaginary , axes = (1 ,), before_op = before_op )
446
442
443
+ # De-normalized signal before applying the IFT
444
+ if normalized and normalized .val :
445
+ multiplier = mb .sqrt (x = mb .cast (x = n_fft , dtype = "fp32" , before_op = before_op ), before_op = before_op )
446
+ signal_real = mb .mul (x = signal_real , y = multiplier , before_op = before_op )
447
+ signal_imaginary = mb .mul (x = signal_imaginary , y = multiplier , before_op = before_op )
448
+
447
449
# Conv with DFT kernel across the input signal
448
450
# We can describe the IDFT in terms of DFT just by swapping the input and output
449
451
# ref: https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Expressing_the_inverse_DFT_in_terms_of_the_DFT
450
452
# So IDFT(x) = (1/N) * swap(DFT(swap(x)))
451
- # DFT(x) => X[k] = Σx[n]*e^(-2kpi/N i)
453
+ # and DFT(x) = X[k] = Σx[n]*e^(-2kpi/N i) but we are using the conjugate e^( 2kpi/N i)
452
454
# If x is complex then x[n]=(a+i*b)
453
- # So the real part = (1/N)*Σ(a*cos(2kpi/N)- b*sin(2kpi/N))
454
- # So the imag part = (1/N)*Σ(b*cos(2kpi/N)+ a*sin(2kpi/N))
455
+ # then real part = (1/N)*Σ(a*cos(2kpi/N)+ b*sin(2kpi/N))
456
+ # then imag part = (1/N)*Σ(b*cos(2kpi/N)- a*sin(2kpi/N))
455
457
cos_windows_real = mb .conv (x = signal_real , weight = cos_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
456
458
sin_windows_real = mb .conv (x = signal_real , weight = sin_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
457
459
cos_windows_imag = mb .conv (x = signal_imaginary , weight = cos_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
458
460
sin_windows_imag = mb .conv (x = signal_imaginary , weight = sin_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
459
461
460
- real_result = mb .sub (x = cos_windows_real , y = sin_windows_imag , before_op = before_op )
461
- imag_result = mb .add (x = cos_windows_imag , y = sin_windows_real , before_op = before_op )
462
+ real_result = mb .add (x = cos_windows_real , y = sin_windows_imag , before_op = before_op )
463
+ imag_result = mb .sub (x = cos_windows_imag , y = sin_windows_real , before_op = before_op )
462
464
463
465
# Divide by N
464
466
real_result = mb .real_div (x = real_result , y = n_fft , before_op = before_op )
@@ -472,10 +474,9 @@ def _istft(
472
474
n_frames = mb .shape (x = real_result , before_op = before_op )[1 ]
473
475
window_square = mb .mul (x = window , y = window , before_op = before_op )
474
476
window_mtx = mb .stack (values = [window_square ] * n_frames , axis = 1 )
475
- normalization_factor = _overlap_add (x = window_mtx , n_fft = n_fft , hop_length = hop_length , before_op = before_op )
476
-
477
- real_result = mb .real_div (x = real_result , y = normalization_factor , before_op = before_op )
478
- imag_result = mb .real_div (x = imag_result , y = normalization_factor , before_op = before_op )
477
+ window_envelope = _overlap_add (x = window_mtx , n_fft = n_fft , hop_length = hop_length , before_op = before_op )
478
+ real_result = mb .real_div (x = real_result , y = window_envelope , before_op = before_op )
479
+ imag_result = mb .real_div (x = imag_result , y = window_envelope , before_op = before_op )
479
480
480
481
# reduce the rank of the output
481
482
if should_increase_rank :
@@ -490,13 +491,21 @@ def _overlap_add(
490
491
hop_length : Var ,
491
492
before_op : Operation ,
492
493
) -> Var :
493
- n_frames = mb .shape (x = x , before_op = before_op )[1 ]
494
- output = mb .fill (shape = (n_fft .val + hop_length .val * (n_frames - 1 )), value = 0. , before_op = before_op )
495
- signal_frames = mb .split (x = x , num_splits = n_frames , axis = 1 , before_op = before_op )
494
+ """
495
+ The input has shape (channels, fft_size, n_frames)
496
+ """
497
+ input_shape = mb .shape (x = x , before_op = before_op )
498
+ channels = input_shape .val [0 ]
499
+ n_frames = input_shape .val [2 ]
500
+
501
+ output = mb .fill (shape = (channels , n_fft .val + hop_length .val * (n_frames - 1 )), value = 0. , before_op = before_op )
502
+ signal_frames = mb .split (x = x , num_splits = n_frames , axis = 2 , before_op = before_op )
496
503
local_idx = mb .range_1d (start = 0 , end = n_fft , step = 1 , before_op = before_op )
497
504
498
505
for frame_num , frame in enumerate (signal_frames ):
499
506
global_idx = mb .add (x = local_idx , y = frame_num * hop_length .val , before_op = before_op )
507
+ global_idx = mb .expand_dims (x = global_idx , axes = (0 ,), before_op = before_op )
508
+ global_idx = mb .stack (values = [global_idx ] * channels , axis = 0 )
500
509
output = mb .scatter_nd (data = output , indices = global_idx , updates = frame , before_op = before_op )
501
510
502
511
return output
0 commit comments