@@ -9559,13 +9559,13 @@ def forward(self, x):
             [None, 1, 3],  # channels
             [16, 32],  # n_fft
             [5, 9],  # num_frames
-            [None, 4, 5],  # hop_length
+            [None, 5],  # hop_length
             [None, 10, 8],  # win_length
             [None, torch.hann_window],  # window
             [False, True],  # center
             [False, True],  # normalized
             [None, False, True],  # onesided
-            [None, 30, 40],  # length
+            [None, "shorter", "larger"],  # length
             [False, True],  # return_complex
         )
     )
@@ -9576,9 +9576,19 @@ def test_istft(self, compute_unit, backend, channels, n_fft, num_frames, hop_len
         if hop_length is None and win_length is not None:
             pytest.skip("If win_length is set then we must set hop_length and 0 < hop_length <= win_length")

+        # Compute input_shape to generate the test case
         freq = n_fft // 2 + 1 if onesided else n_fft
         input_shape = (channels, freq, num_frames) if channels else (freq, num_frames)

+        # If not set, compute the default hop_length so the expected errors below can be predicted
+        if hop_length is None:
+            hop_length = n_fft // 4
+
+        if length == "shorter":
+            length = n_fft // 2 + hop_length * (num_frames - 1)
+        elif length == "larger":
+            length = n_fft * 3 // 2 + hop_length * (num_frames - 1)
+
         class ISTFTModel(torch.nn.Module):
             def forward(self, x):
                 applied_window = window(win_length) if window and win_length else None
@@ -9598,7 +9608,7 @@ def forward(self, x):
                 else:
                     return torch.real(x)

-        if win_length and center is False:
+        if (center is False and win_length) or (center and win_length and length):
             # For some reason Pytorch raises an error https://github.com/pytorch/audio/issues/427#issuecomment-1829593033
             with pytest.raises(RuntimeError, match="istft\(.*\) window overlap add min: 1"):
                 TorchBaseTest.run_compare_torch(
@@ -9607,7 +9617,7 @@ def forward(self, x):
                     backend=backend,
                     compute_unit=compute_unit
                 )
-        elif length is not None and return_complex is True:
+        elif length and return_complex:
             with pytest.raises(ValueError, match="New var type `<class 'coremltools.converters.mil.mil.types.type_tensor.tensor.<locals>.tensor'>` not a subtype of existing var type `<class 'coremltools.converters.mil.mil.types.type_tensor.tensor.<locals>.tensor'>`"):
                 TorchBaseTest.run_compare_torch(
                     input_shape,
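For reference, the standalone sketch below (not part of the change) illustrates why the fixed lengths 30/40 are replaced by the "shorter"/"larger" markers: useful length values depend on n_fft, hop_length and num_frames, so they must be derived per parameter combination. The concrete values and the Hann-window call are illustrative assumptions, not taken from the test.

# Illustrative sketch only -- values are assumed, not taken from the parametrization above.
import torch

n_fft, num_frames = 16, 5
hop_length = n_fft // 4                 # same fallback the test uses
freq = n_fft // 2 + 1                   # onesided spectrogram bins
spec = torch.randn(freq, num_frames, dtype=torch.complex64)

# With center=True and length=None, torch.istft trims n_fft // 2 samples from
# each end of the n_fft + hop_length * (num_frames - 1) overlap-add span,
# so the natural output has hop_length * (num_frames - 1) samples.
print(torch.istft(spec, n_fft, hop_length=hop_length).shape)   # expected: 16

# "shorter" keeps the overlap-add span up to its last sample; "larger" asks
# for more samples than the frames can produce, which is why the updated test
# treats a non-None length as an error case when a Hann window is applied.
shorter = n_fft // 2 + hop_length * (num_frames - 1)            # 24
larger = n_fft * 3 // 2 + hop_length * (num_frames - 1)         # 40
print(torch.istft(spec, n_fft, hop_length=hop_length, length=shorter).shape)

# The PyTorch behaviour referenced in the test: when the window envelope is
# ~0 somewhere in the requested output range (e.g. a Hann window with
# center=False), istft rejects the call.
try:
    torch.istft(spec, n_fft, hop_length=hop_length,
                window=torch.hann_window(n_fft), center=False)
except RuntimeError as err:
    print(err)                          # istft(...) window overlap add min: 1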