@@ -399,6 +399,7 @@ def _istft(
399
399
window : Optional [Var ],
400
400
normalized : Optional [Var ],
401
401
onesided : Optional [Var ],
402
+ length : Optional [Var ],
402
403
before_op : Operation ,
403
404
) -> Tuple [Var , Var ]:
404
405
"""
@@ -419,7 +420,7 @@ def _istft(
419
420
input_shape = mb .shape (x = x , before_op = before_op )
420
421
n_frames = input_shape .val [- 1 ]
421
422
fft_size = input_shape .val [- 2 ]
422
- expected_output_signal_len = n_fft .val + hop_length .val * (n_frames - 1 )
423
+ # expected_output_signal_len = n_fft.val + hop_length.val * (n_frames - 1)
423
424
424
425
is_onesided = onesided .val if onesided else fft_size != n_fft
425
426
cos_base , sin_base = _calculate_dft_matrix (n_fft , onesided = is_onesided , before_op = before_op )
@@ -478,10 +479,14 @@ def _istft(
478
479
real_result = mb .real_div (x = real_result , y = window_envelope , before_op = before_op )
479
480
imag_result = mb .real_div (x = imag_result , y = window_envelope , before_op = before_op )
480
481
481
- # reduce the rank of the output
482
- if should_increase_rank :
483
- real_result = mb .squeeze (x = real_result , axes = (0 ,), before_op = before_op )
484
- imag_result = mb .squeeze (x = imag_result , axes = (0 ,), before_op = before_op )
482
+ # We need to adapt last dimension
483
+ if length is not None :
484
+ if length > expected_output_signal_len :
485
+ real_result = mb .pad (x = real_result , pad = , mode = "constant" , constant_val = 0 , before_op = before_op )
486
+ imag_result = mb .pad (x = imag_result , pad = , mode = "constant" , constant_val = 0 , before_op = before_op )
487
+ elif length < expected_output_signal_len :
488
+ real_result = mb .slice_by_size (x = real_result , begin = [0 ], size = [length ], before_op = before_op )
489
+ imag_result = mb .slice_by_size (x = imag_result , begin = [0 ], size = [length ], before_op = before_op )
485
490
486
491
return real_result , imag_result
487
492
0 commit comments