diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py index 66b55c277..153a5d4c0 100644 --- a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py +++ b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py @@ -9596,13 +9596,13 @@ def forward(self, x): [None, 1, 3], # channels [16, 32], # n_fft [5, 9], # num_frames - [None, 4, 5], # hop_length + [None, 5], # hop_length [None, 10, 8], # win_length [None, torch.hann_window], # window [False, True], # center [False, True], # normalized [None, False, True], # onesided - [None, 30, 40], # length + [None, "shorter", "larger"], # length [False, True], # return_complex ) ) @@ -9613,9 +9613,19 @@ def test_istft(self, compute_unit, backend, channels, n_fft, num_frames, hop_len if hop_length is None and win_length is not None: pytest.skip("If win_length is set then we must set hop_length and 0 < hop_length <= win_length") + # Compute input_shape to generate test case freq = n_fft//2+1 if onesided else n_fft input_shape = (channels, freq, num_frames) if channels else (freq, num_frames) + # If not set,c ompute hop_length for capturing errors + if hop_length is None: + hop_length = n_fft // 4 + + if length == "shorter": + length = n_fft//2 + hop_length * (num_frames - 1) + elif length == "larger": + length = n_fft*3//2 + hop_length * (num_frames - 1) + class ISTFTModel(torch.nn.Module): def forward(self, x): applied_window = window(win_length) if window and win_length else None @@ -9635,7 +9645,7 @@ def forward(self, x): else: return torch.real(x) - if win_length and center is False: + if (center is False and win_length) or (center and win_length and length): # For some reason Pytorch raises an error https://github.com/pytorch/audio/issues/427#issuecomment-1829593033 with pytest.raises(RuntimeError, match="istft\(.*\) window overlap add min: 1"): TorchBaseTest.run_compare_torch( @@ -9644,7 +9654,7 @@ def forward(self, x): backend=backend, compute_unit=compute_unit ) - elif length is not None and return_complex is True: + elif length and return_complex: with pytest.raises(ValueError, match="New var type `.tensor'>` not a subtype of existing var type `.tensor'>`"): TorchBaseTest.run_compare_torch( input_shape,