Skip to content

Commit 186110c

Browse files
author
Alejandro Gaston Alvarez Franceschi
committed
Fixes
1 parent 866d61e commit 186110c

File tree

2 files changed

+17
-27
lines changed

2 files changed

+17
-27
lines changed

coremltools/converters/mil/frontend/torch/test/test_torch_ops.py

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9597,7 +9597,7 @@ def forward(self, x):
95979597
[16, 32], # n_fft
95989598
[5, 9], # num_frames
95999599
[None, 4, 5], # hop_length
9600-
[None, 16, 9], # win_length
9600+
[None, 10, 8], # win_length
96019601
[None, torch.hann_window], # window
96029602
[False, True], # center
96039603
[False, True], # normalized
@@ -9610,6 +9610,9 @@ def test_istft(self, compute_unit, backend, channels, n_fft, num_frames, hop_len
96109610
if return_complex and onesided:
96119611
pytest.skip("Complex output is incompatible with onesided")
96129612

9613+
if hop_length is None and win_length is not None:
9614+
pytest.skip("If win_length is set then we must set hop_length and 0 < hop_length <= win_length")
9615+
96139616
freq = n_fft//2+1 if onesided else n_fft
96149617
input_shape = (channels, freq, num_frames) if channels else (freq, num_frames)
96159618

@@ -9628,24 +9631,13 @@ def forward(self, x):
96289631
length=length,
96299632
return_complex=return_complex)
96309633
if return_complex:
9631-
x = torch.stack([torch.real(x), torch.imag(x)], dim=0)
9632-
return x
9634+
return torch.stack([torch.real(x), torch.imag(x)], dim=0)
9635+
else:
9636+
return torch.real(x)
96339637

9634-
if length is not None or center is False:
9638+
if win_length and center is False:
96359639
# For some reason Pytorch raises an error https://github.com/pytorch/audio/issues/427#issuecomment-1829593033
9636-
with pytest.raises(
9637-
RuntimeError, match="istft\(.*\) window overlap add min: 1"
9638-
):
9639-
TorchBaseTest.run_compare_torch(
9640-
input_shape,
9641-
ISTFTModel(),
9642-
backend=backend,
9643-
compute_unit=compute_unit
9644-
)
9645-
elif return_complex is False:
9646-
with pytest.raises(
9647-
ValueError, match="MIL doesn't support complex data as model's output"
9648-
):
9640+
with pytest.raises(RuntimeError, match="istft\(.*\) window overlap add min: 1"):
96499641
TorchBaseTest.run_compare_torch(
96509642
input_shape,
96519643
ISTFTModel(),

coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,7 @@ def _istft(
427427

428428
expected_output_signal_len = n_fft.val + hop_length.val * (n_frames - 1)
429429

430-
is_onesided = onesided.val if onesided else fft_size != n_fft
430+
is_onesided = True if fft_size != n_fft.val else onesided and onesided.val
431431
cos_base, sin_base = _calculate_dft_matrix(n_fft, onesided=is_onesided, before_op=before_op)
432432

433433
# create a window of centered 1s of the requested size
@@ -481,20 +481,18 @@ def _istft(
481481
window_envelope = _overlap_add(x=window_mtx, n_fft=n_fft, hop_length=hop_length, before_op=before_op)
482482
real_result = mb.real_div(x=real_result, y=window_envelope, before_op=before_op)
483483
imag_result = mb.real_div(x=imag_result, y=window_envelope, before_op=before_op)
484-
485484
# We need to adapt last dimension
486485
if length is not None:
487486
if length.val > expected_output_signal_len:
487+
real_result = mb.pad(x=real_result, pad=(0, length.val - expected_output_signal_len), before_op=before_op)
488+
imag_result = mb.pad(x=imag_result, pad=(0, length.val - expected_output_signal_len), before_op=before_op)
489+
elif length.val < expected_output_signal_len:
488490
if channels:
489-
right_pad = mb.fill(shape=(channels, length.val - expected_output_signal_len ), value=0., before_op=before_op)
491+
real_result = mb.slice_by_size(x=real_result, begin=[0,0], size=[-1, length.val], before_op=before_op)
492+
imag_result = mb.slice_by_size(x=imag_result, begin=[0,0], size=[-1, length.val], before_op=before_op)
490493
else:
491-
right_pad = mb.fill(shape=(length.val - expected_output_signal_len,), value=0., before_op=before_op)
492-
493-
real_result = mb.stack(values=(real_result, right_pad), axis=1, before_op=before_op)
494-
imag_result = mb.stack(values=(imag_result, right_pad), axis=1, before_op=before_op)
495-
elif length.val < expected_output_signal_len:
496-
real_result = mb.slice_by_size(x=real_result, begin=[0], size=[length.val], before_op=before_op)
497-
imag_result = mb.slice_by_size(x=imag_result, begin=[0], size=[length.val], before_op=before_op)
494+
real_result = mb.slice_by_size(x=real_result, begin=[0], size=[length.val], before_op=before_op)
495+
imag_result = mb.slice_by_size(x=imag_result, begin=[0], size=[length.val], before_op=before_op)
498496

499497
return real_result, imag_result
500498

0 commit comments

Comments
 (0)