@@ -427,7 +427,7 @@ def _istft(
427
427
428
428
expected_output_signal_len = n_fft .val + hop_length .val * (n_frames - 1 )
429
429
430
- is_onesided = onesided . val if onesided else fft_size != n_fft
430
+ is_onesided = True if fft_size != n_fft . val else onesided and onesided . val
431
431
cos_base , sin_base = _calculate_dft_matrix (n_fft , onesided = is_onesided , before_op = before_op )
432
432
433
433
# create a window of centered 1s of the requested size
@@ -481,20 +481,18 @@ def _istft(
481
481
window_envelope = _overlap_add (x = window_mtx , n_fft = n_fft , hop_length = hop_length , before_op = before_op )
482
482
real_result = mb .real_div (x = real_result , y = window_envelope , before_op = before_op )
483
483
imag_result = mb .real_div (x = imag_result , y = window_envelope , before_op = before_op )
484
-
485
484
# We need to adapt last dimension
486
485
if length is not None :
487
486
if length .val > expected_output_signal_len :
487
+ real_result = mb .pad (x = real_result , pad = (0 , length .val - expected_output_signal_len ), before_op = before_op )
488
+ imag_result = mb .pad (x = imag_result , pad = (0 , length .val - expected_output_signal_len ), before_op = before_op )
489
+ elif length .val < expected_output_signal_len :
488
490
if channels :
489
- right_pad = mb .fill (shape = (channels , length .val - expected_output_signal_len ), value = 0. , before_op = before_op )
491
+ real_result = mb .slice_by_size (x = real_result , begin = [0 ,0 ], size = [- 1 , length .val ], before_op = before_op )
492
+ imag_result = mb .slice_by_size (x = imag_result , begin = [0 ,0 ], size = [- 1 , length .val ], before_op = before_op )
490
493
else :
491
- right_pad = mb .fill (shape = (length .val - expected_output_signal_len ,), value = 0. , before_op = before_op )
492
-
493
- real_result = mb .stack (values = (real_result , right_pad ), axis = 1 , before_op = before_op )
494
- imag_result = mb .stack (values = (imag_result , right_pad ), axis = 1 , before_op = before_op )
495
- elif length .val < expected_output_signal_len :
496
- real_result = mb .slice_by_size (x = real_result , begin = [0 ], size = [length .val ], before_op = before_op )
497
- imag_result = mb .slice_by_size (x = imag_result , begin = [0 ], size = [length .val ], before_op = before_op )
494
+ real_result = mb .slice_by_size (x = real_result , begin = [0 ], size = [length .val ], before_op = before_op )
495
+ imag_result = mb .slice_by_size (x = imag_result , begin = [0 ], size = [length .val ], before_op = before_op )
498
496
499
497
return real_result , imag_result
500
498
0 commit comments