Skip to content

Commit f9e4ee3

Browse files
author
Alejandro Gaston Alvarez Franceschi
committed
Try adding length
1 parent 3858b63 commit f9e4ee3

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,7 @@ def _istft(
399399
window: Optional[Var],
400400
normalized: Optional[Var],
401401
onesided: Optional[Var],
402+
length: Optional[Var],
402403
before_op: Operation,
403404
) -> Tuple[Var, Var]:
404405
"""
@@ -419,7 +420,7 @@ def _istft(
419420
input_shape = mb.shape(x=x, before_op=before_op)
420421
n_frames = input_shape.val[-1]
421422
fft_size = input_shape.val[-2]
422-
expected_output_signal_len = n_fft.val + hop_length.val * (n_frames - 1)
423+
# expected_output_signal_len = n_fft.val + hop_length.val * (n_frames - 1)
423424

424425
is_onesided = onesided.val if onesided else fft_size != n_fft
425426
cos_base, sin_base = _calculate_dft_matrix(n_fft, onesided=is_onesided, before_op=before_op)
@@ -478,10 +479,14 @@ def _istft(
478479
real_result = mb.real_div(x=real_result, y=window_envelope, before_op=before_op)
479480
imag_result = mb.real_div(x=imag_result, y=window_envelope, before_op=before_op)
480481

481-
# reduce the rank of the output
482-
if should_increase_rank:
483-
real_result = mb.squeeze(x=real_result, axes=(0,), before_op=before_op)
484-
imag_result = mb.squeeze(x=imag_result, axes=(0,), before_op=before_op)
482+
# We need to adapt last dimension
483+
if length is not None:
484+
if length > expected_output_signal_len:
485+
real_result = mb.pad(x=real_result, pad=, mode="constant", constant_val=0, before_op=before_op)
486+
imag_result = mb.pad(x=imag_result, pad=, mode="constant", constant_val=0, before_op=before_op)
487+
elif length < expected_output_signal_len:
488+
real_result = mb.slice_by_size(x=real_result, begin=[0], size=[length], before_op=before_op)
489+
imag_result = mb.slice_by_size(x=imag_result, begin=[0], size=[length], before_op=before_op)
485490

486491
return real_result, imag_result
487492

0 commit comments

Comments
 (0)