@@ -325,7 +325,7 @@ def _stft(
325
325
We can write STFT in terms of convolutions with a DFT kernel.
326
326
At the end:
327
327
* The real part output is: cos_base * input_real - sin_base * input_imag
328
- * The imaginary part output is: - (sin_base * input_real - cos_base * input_imag)
328
+ * The imaginary part output is: cos_base * input_imag + sin_base * input_real
329
329
Adapted from: https://github.com/adobe-research/convmelspec/blob/main/convmelspec/mil.py
330
330
"""
331
331
hop_length = hop_length or mb .floor_div (x = n_fft , y = 4 , before_op = before_op )
@@ -342,14 +342,13 @@ def _stft(
342
342
343
343
# create a window of centered 1s of the requested size
344
344
if win_length :
345
- window = _get_window (win_length = win_length , n_fft = n_fft , before_op = before_op )
345
+ window = _get_window (win_length = win_length , n_fft = n_fft , window = window , before_op = before_op )
346
346
347
347
# apply time window
348
348
if window :
349
349
cos_base = mb .mul (x = window , y = cos_base , before_op = before_op )
350
350
sin_base = mb .mul (x = window , y = sin_base , before_op = before_op )
351
351
352
-
353
352
# Expand
354
353
cos_base = mb .expand_dims (x = cos_base , axes = (1 ,), before_op = before_op )
355
354
sin_base = mb .expand_dims (x = sin_base , axes = (1 ,), before_op = before_op )
@@ -358,12 +357,13 @@ def _stft(
358
357
if input_imaginary :
359
358
signal_imaginary = mb .expand_dims (x = input_imaginary , axes = (1 ,), before_op = before_op )
360
359
361
- # conv with DFT kernel across the input signal
362
- # The DFT matrix is obtained with the equation e^(2pi/N i) but the definition is:
363
- # DFT(x) => X[k] = Σx[n]*e^(-2kpi/N i)
364
- # If x is complex then x[n]=(a+i*b)
365
- # So the real part = Σ(a*cos(2kpi/N)+b*sin(2kpi/N))
366
- # So the imag part = Σ(b*cos(2kpi/N)-a*sin(2kpi/N))
360
+ # Convolve the DFT kernel with the input signal
361
+ # DFT(x[n]) --> X[k] = Σx[n]*e^(-i*2π*n/N*k), then if x is complex x[n]=(a[n]+i*b[n])
362
+ # real(X[k]) = Σ(a[n]*cos(2π*n/N*k)+b[n]*sin(2π*n/N*k))
363
+ # imag(X[k]) = Σ(b[n]*cos(2π*n/N*k)-a[n]*sin(2π*n/N*k))
364
+ # But because our DFT matrix is obtained with the conjugate --> e^(i*2π*n/N*k):
365
+ # real(X[k]) = Σ(a[n]*cos(2π*n/N*k)-b[n]*sin(2π*n/N*k))
366
+ # imag(X[k]) = Σ(b[n]*cos(2π*n/N*k)+a[n]*sin(2π*n/N*k))
367
367
cos_windows_real = mb .conv (x = signal_real , weight = cos_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
368
368
sin_windows_real = mb .conv (x = signal_real , weight = sin_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
369
369
if input_imaginary :
@@ -372,11 +372,11 @@ def _stft(
372
372
373
373
# add everything together
374
374
if input_imaginary :
375
- real_result = mb .add (x = cos_windows_real , y = sin_windows_imag , before_op = before_op )
376
- imag_result = mb .sub (x = cos_windows_imag , y = sin_windows_real , before_op = before_op )
375
+ real_result = mb .sub (x = cos_windows_real , y = sin_windows_imag , before_op = before_op )
376
+ imag_result = mb .add (x = cos_windows_imag , y = sin_windows_real , before_op = before_op )
377
377
else :
378
378
real_result = cos_windows_real
379
- imag_result = mb . sub ( x = 0. , y = sin_windows_real , before_op = before_op )
379
+ imag_result = sin_windows_real
380
380
381
381
# reduce the rank of the output
382
382
if should_increase_rank :
@@ -417,17 +417,18 @@ def _istft(
417
417
# By default, use the entire frame
418
418
win_length = win_length or n_fft
419
419
420
- input_shape = mb .shape (x = x , before_op = before_op )
421
- n_frames = input_shape .val [- 1 ]
422
- fft_size = input_shape .val [- 2 ]
423
- # expected_output_signal_len = n_fft.val + hop_length.val * (n_frames - 1)
420
+ input_shape = mb .shape (x = input_real , before_op = before_op )
421
+ channels = input_shape .val [0 ]
422
+ fft_size = input_shape .val [1 ]
423
+ n_frames = input_shape .val [2 ]
424
+ expected_output_signal_len = n_fft .val + hop_length .val * (n_frames - 1 )
424
425
425
426
is_onesided = onesided .val if onesided else fft_size != n_fft
426
427
cos_base , sin_base = _calculate_dft_matrix (n_fft , onesided = is_onesided , before_op = before_op )
427
428
428
429
# create a window of centered 1s of the requested size
429
430
if win_length :
430
- window = _get_window (win_length = win_length , n_fft = n_fft , before_op = before_op )
431
+ window = _get_window (win_length = win_length , n_fft = n_fft , window = window , before_op = before_op )
431
432
432
433
# apply time window
433
434
if window :
@@ -447,14 +448,13 @@ def _istft(
447
448
signal_real = mb .mul (x = signal_real , y = multiplier , before_op = before_op )
448
449
signal_imaginary = mb .mul (x = signal_imaginary , y = multiplier , before_op = before_op )
449
450
450
- # Conv with DFT kernel across the input signal
451
- # We can describe the IDFT in terms of DFT just by swapping the input and output
451
+ # Convolve the DFT kernel with the input signal
452
+ # We can describe the IDFT in terms of DFT just by swapping the input and output.
452
453
# ref: https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Expressing_the_inverse_DFT_in_terms_of_the_DFT
453
- # So IDFT(x) = (1/N) * swap(DFT(swap(x)))
454
- # and DFT(x) = X[k] = Σx[n]*e^(-2kpi/N i) but we are using the conjugate e^(2kpi/N i)
455
- # If x is complex then x[n]=(a+i*b)
456
- # then real part = (1/N)*Σ(a*cos(2kpi/N)+b*sin(2kpi/N))
457
- # then imag part = (1/N)*Σ(b*cos(2kpi/N)-a*sin(2kpi/N))
454
+ # IDFT(X[k]) --> x[n]=(1/N)*swap(DFT(swap(X[k]))), and K=N
455
+ # So using the definition in stft function, we get:
456
+ # real(x[n]) = Σ(a[k]*cos(2π*k/K*n)+b[k]*sin(2π*k/K*n))
457
+ # imag(x[n]) = Σ(b[k]*cos(2π*k/K*n)-a[k]*sin(2π*k/K*n))
458
458
cos_windows_real = mb .conv (x = signal_real , weight = cos_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
459
459
sin_windows_real = mb .conv (x = signal_real , weight = sin_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
460
460
cos_windows_imag = mb .conv (x = signal_imaginary , weight = cos_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
@@ -519,6 +519,7 @@ def _overlap_add(
519
519
def _get_window (
520
520
win_length : Var ,
521
521
n_fft : Var ,
522
+ window : Optional [Var ],
522
523
before_op : Operation ,
523
524
) -> Var :
524
525
n_left = (n_fft .val - win_length .val ) // 2
@@ -750,17 +751,21 @@ def _lower_complex_istft(op: Operation):
750
751
is_complex = types .is_complex (op .input .dtype )
751
752
752
753
# check parameters for validity
754
+ if is_complex :
755
+ raise ValueError ("Only complex inputs are allowed" )
753
756
if op .win_length and op .win_length .val > op .n_fft .val :
754
757
raise ValueError ("Window length must be less than or equal to n_fft" )
755
- if is_complex and op .onesided and op .onesided .val :
756
- raise ValueError ("Onesided is only valid for real inputs " )
758
+ if op . return_complex and op .onesided and op .onesided .val :
759
+ raise ValueError ("Complex output is not compatible with onesided " )
757
760
758
761
real , imag = _istft (
759
- op .input .real if is_complex else op .input ,
760
- op .input .imag if is_complex else None ,
761
- op .n_fft , op .hop_length , op .win_length , op .window , op .normalized , op .onesided , before_op = op )
762
+ op .input .real , op .input .imag ,
763
+ op .n_fft , op .hop_length , op .win_length , op .window , op .normalized , op .onesided , op .length , before_op = op )
762
764
763
- return _wrap_complex_output (op .outputs [0 ], real , imag )
765
+ if op .return_complex :
766
+ return _wrap_complex_output (op .outputs [0 ], real , imag )
767
+ else
768
+ return real
764
769
765
770
766
771
@LowerComplex .register_lower_func (op_type = "complex_shape" )
0 commit comments