Skip to content

Commit 8cb63ed

Browse files
author
Alejandro Gaston Alvarez Franceschi
committed
More fixes
1 parent 186110c commit 8cb63ed

File tree

2 files changed

+13
-7
lines changed

2 files changed

+13
-7
lines changed

coremltools/converters/mil/frontend/torch/test/test_torch_ops.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9602,7 +9602,7 @@ def forward(self, x):
96029602
[False, True], # center
96039603
[False, True], # normalized
96049604
[None, False, True], # onesided
9605-
[None, 60], # length
9605+
[None, 30, 40], # length
96069606
[False, True], # return_complex
96079607
)
96089608
)
@@ -9644,6 +9644,14 @@ def forward(self, x):
96449644
backend=backend,
96459645
compute_unit=compute_unit
96469646
)
9647+
elif length is not None and return_complex is True:
9648+
with pytest.raises(ValueError, match="New var type `<class 'coremltools.converters.mil.mil.types.type_tensor.tensor.<locals>.tensor'>` not a subtype of existing var type `<class 'coremltools.converters.mil.mil.types.type_tensor.tensor.<locals>.tensor'>`"):
9649+
TorchBaseTest.run_compare_torch(
9650+
input_shape,
9651+
ISTFTModel(),
9652+
backend=backend,
9653+
compute_unit=compute_unit
9654+
)
96479655
else:
96489656
TorchBaseTest.run_compare_torch(
96499657
input_shape,

coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,8 @@ def _istft(
479479
window_mtx = mb.stack(values=[window_square] * n_frames, axis=0, before_op=before_op)
480480
window_mtx = mb.expand_dims(x=window_mtx, axes=(0,), before_op=before_op)
481481
window_envelope = _overlap_add(x=window_mtx, n_fft=n_fft, hop_length=hop_length, before_op=before_op)
482+
483+
# After this operation if it didn't have any channels dimention it adds one
482484
real_result = mb.real_div(x=real_result, y=window_envelope, before_op=before_op)
483485
imag_result = mb.real_div(x=imag_result, y=window_envelope, before_op=before_op)
484486
# We need to adapt last dimension
@@ -487,12 +489,8 @@ def _istft(
487489
real_result = mb.pad(x=real_result, pad=(0, length.val - expected_output_signal_len), before_op=before_op)
488490
imag_result = mb.pad(x=imag_result, pad=(0, length.val - expected_output_signal_len), before_op=before_op)
489491
elif length.val < expected_output_signal_len:
490-
if channels:
491-
real_result = mb.slice_by_size(x=real_result, begin=[0,0], size=[-1, length.val], before_op=before_op)
492-
imag_result = mb.slice_by_size(x=imag_result, begin=[0,0], size=[-1, length.val], before_op=before_op)
493-
else:
494-
real_result = mb.slice_by_size(x=real_result, begin=[0], size=[length.val], before_op=before_op)
495-
imag_result = mb.slice_by_size(x=imag_result, begin=[0], size=[length.val], before_op=before_op)
492+
real_result = mb.slice_by_size(x=real_result, begin=[0,0], size=[-1, length.val], before_op=before_op)
493+
imag_result = mb.slice_by_size(x=imag_result, begin=[0,0], size=[-1, length.val], before_op=before_op)
496494

497495
return real_result, imag_result
498496

0 commit comments

Comments
 (0)