Skip to content

Commit 44a98ab

Browse files
author
Alejandro Gaston Alvarez Franceschi
committed
More fixes
1 parent 1de46b2 commit 44a98ab

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

Diff for: coremltools/converters/mil/frontend/torch/test/test_torch_ops.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -9559,13 +9559,13 @@ def forward(self, x):
95599559
[None, 1, 3], # channels
95609560
[16, 32], # n_fft
95619561
[5, 9], # num_frames
9562-
[None, 4, 5], # hop_length
9562+
[None, 5], # hop_length
95639563
[None, 10, 8], # win_length
95649564
[None, torch.hann_window], # window
95659565
[False, True], # center
95669566
[False, True], # normalized
95679567
[None, False, True], # onesided
9568-
[None, 30, 40], # length
9568+
[None, "shorter", "larger"], # length
95699569
[False, True], # return_complex
95709570
)
95719571
)
@@ -9576,9 +9576,19 @@ def test_istft(self, compute_unit, backend, channels, n_fft, num_frames, hop_len
95769576
if hop_length is None and win_length is not None:
95779577
pytest.skip("If win_length is set then we must set hop_length and 0 < hop_length <= win_length")
95789578

9579+
# Compute input_shape to generate test case
95799580
freq = n_fft//2+1 if onesided else n_fft
95809581
input_shape = (channels, freq, num_frames) if channels else (freq, num_frames)
95819582

9583+
# If not set, compute hop_length for capturing errors
9584+
if hop_length is None:
9585+
hop_length = n_fft // 4
9586+
9587+
if length == "shorter":
9588+
length = n_fft//2 + hop_length * (num_frames - 1)
9589+
elif length == "larger":
9590+
length = n_fft*3//2 + hop_length * (num_frames - 1)
9591+
95829592
class ISTFTModel(torch.nn.Module):
95839593
def forward(self, x):
95849594
applied_window = window(win_length) if window and win_length else None
@@ -9598,7 +9608,7 @@ def forward(self, x):
95989608
else:
95999609
return torch.real(x)
96009610

9601-
if win_length and center is False:
9611+
if (center is False and win_length) or (center and win_length and length):
96029612
# For some reason Pytorch raises an error https://github.com/pytorch/audio/issues/427#issuecomment-1829593033
96039613
with pytest.raises(RuntimeError, match="istft\(.*\) window overlap add min: 1"):
96049614
TorchBaseTest.run_compare_torch(
@@ -9607,7 +9617,7 @@ def forward(self, x):
96079617
backend=backend,
96089618
compute_unit=compute_unit
96099619
)
9610-
elif length is not None and return_complex is True:
9620+
elif length and return_complex:
96119621
with pytest.raises(ValueError, match="New var type `<class 'coremltools.converters.mil.mil.types.type_tensor.tensor.<locals>.tensor'>` not a subtype of existing var type `<class 'coremltools.converters.mil.mil.types.type_tensor.tensor.<locals>.tensor'>`"):
96129622
TorchBaseTest.run_compare_torch(
96139623
input_shape,

0 commit comments

Comments
 (0)