Skip to content

Commit a218ee9

Browse files
author
Alejandro Gaston Alvarez Franceschi
committed
Fixes
1 parent 4966daf commit a218ee9

File tree

3 files changed

+47
-44
lines changed

3 files changed

+47
-44
lines changed

Diff for: coremltools/converters/mil/frontend/torch/test/test_torch_ops.py

+9-11
Original file line numberDiff line numberDiff line change
@@ -9504,9 +9504,8 @@ def test_stft(self, compute_unit, backend, input_shape, complex, n_fft, hop_leng
95049504
class STFTModel(torch.nn.Module):
95059505
def forward(self, x):
95069506
applied_window = window(win_length) if window and win_length else None
9507-
x = torch.complex(x, x) if complex else x
95089507
x = torch.stft(
9509-
x,
9508+
torch.complex(x, x) if complex else x,
95109509
n_fft=n_fft,
95119510
hop_length=hop_length,
95129511
win_length=win_length,
@@ -9534,28 +9533,26 @@ class TestISTFT(TorchBaseTest):
95349533
compute_units,
95359534
backends,
95369535
[(1, 32, 9), (32, 9), (3, 32, 9)], # input shape
9537-
[False, True], # complex
95389536
[16], # n_fft
95399537
[None, 4, 5], # hop_length
95409538
[None, 16, 9], # win_length
95419539
[None, torch.hann_window], # window
95429540
[None, False, True], # center
9543-
["constant", "reflect", "replicate"], # pad mode
95449541
[False, True], # normalized
95459542
[None, False, True], # onesided
95469543
[None, 60], # length
9544+
[False, True], # return_complex
95479545
)
95489546
)
9549-
def test_istft(self, compute_unit, backend, input_shape, complex, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided):
9550-
if complex and onesided:
9551-
pytest.skip("Onesided stft not possible for complex inputs")
9547+
def test_istft(self, compute_unit, backend, input_shape, n_fft, hop_length, win_length, window, center, normalized, onesided, length, return_complex):
9548+
if return_complex and onesided:
9549+
pytest.skip("Complex output is incompatible with onesided")
95529550

95539551
class ISTFTModel(torch.nn.Module):
95549552
def forward(self, x):
95559553
applied_window = window(win_length) if window and win_length else None
9556-
x = torch.complex(x, x)
95579554
x = torch.istft(
9558-
x,
9555+
torch.complex(x, x),
95599556
n_fft=n_fft,
95609557
hop_length=hop_length,
95619558
win_length=win_length,
@@ -9564,8 +9561,9 @@ def forward(self, x):
95649561
normalized=normalized,
95659562
onesided=onesided,
95669563
length=length,
9567-
return_complex=True)
9568-
x = torch.stack([torch.real(x), torch.imag(x)], dim=0)
9564+
return_complex=return_complex)
9565+
if return_complex:
9566+
x = torch.stack([torch.real(x), torch.imag(x)], dim=0)
95699567
return x
95709568

95719569
TorchBaseTest.run_compare_torch(

Diff for: coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -893,6 +893,7 @@ class complex_istft(Operation):
893893
894894
Attributes
895895
----------
896+
V: complex64
896897
T: fp32, complex64
897898
898899
References
@@ -901,7 +902,7 @@ class complex_istft(Operation):
901902
"""
902903

903904
input_spec = InputSpec(
904-
input=TensorInputType(type_domain="T"),
905+
input=TensorInputType(type_domain="V"),
905906
n_fft=TensorInputType(const=True, type_domain=types.int32),
906907
hop_length=TensorInputType(const=True, optional=True, type_domain=types.int32),
907908
win_length=TensorInputType(const=True, optional=True, type_domain=types.int32),
@@ -912,7 +913,7 @@ class complex_istft(Operation):
912913
)
913914

914915
type_domains = {
915-
"T": (types.fp32, types.complex64),
916+
"V": types.complex64,
916917
}
917918

918919
def default_inputs(self):
@@ -937,7 +938,6 @@ def type_inference(self):
937938
output_shape += [self.length]
938939
return types.tensor(output_type, tuple(output_shape))
939940

940-
941941
n_frames = self.input.shape[-1]
942942
output_shape = self.n_fft.val + self.hop_length.val * (n_frames - 1)
943943

Diff for: coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py

+35-30
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ def _stft(
325325
We can write STFT in terms of convolutions with a DFT kernel.
326326
At the end:
327327
* The real part output is: cos_base * input_real + sin_base * input_imag
328-
* The imaginary part output is: - (sin_base * input_real - cos_base * input_imag)
328+
* The imaginary part output is: cos_base * input_imag - sin_base * input_real
329329
Adapted from: https://github.com/adobe-research/convmelspec/blob/main/convmelspec/mil.py
330330
"""
331331
hop_length = hop_length or mb.floor_div(x=n_fft, y=4, before_op=before_op)
@@ -342,14 +342,13 @@ def _stft(
342342

343343
# create a window of centered 1s of the requested size
344344
if win_length:
345-
window = _get_window(win_length=win_length, n_fft=n_fft, before_op=before_op)
345+
window = _get_window(win_length=win_length, n_fft=n_fft, window=window, before_op=before_op)
346346

347347
# apply time window
348348
if window:
349349
cos_base = mb.mul(x=window, y=cos_base, before_op=before_op)
350350
sin_base = mb.mul(x=window, y=sin_base, before_op=before_op)
351351

352-
353352
# Expand
354353
cos_base = mb.expand_dims(x=cos_base, axes=(1,), before_op=before_op)
355354
sin_base = mb.expand_dims(x=sin_base, axes=(1,), before_op=before_op)
@@ -358,12 +357,13 @@ def _stft(
358357
if input_imaginary:
359358
signal_imaginary = mb.expand_dims(x=input_imaginary, axes=(1,), before_op=before_op)
360359

361-
# conv with DFT kernel across the input signal
362-
# The DFT matrix is obtained with the equation e^(2pi/N i) but the definition is:
363-
# DFT(x) => X[k] = Σx[n]*e^(-2kpi/N i)
364-
# If x is complex then x[n]=(a+i*b)
365-
# So the real part = Σ(a*cos(2kpi/N)+b*sin(2kpi/N))
366-
# So the imag part = Σ(b*cos(2kpi/N)-a*sin(2kpi/N))
360+
# Convolve the DFT kernel with the input signal
361+
# DFT(x[n]) --> X[k] = Σx[n]*e^(-2π*n/N*k), then if x is complex x[n]=(a[n]+i*b[n])
362+
# real(X[k]) = Σ(a[n]*cos(2π*n/N*k)+b[n]*sin(2π*n/N*k))
363+
# imag(X[k]) = Σ(b[n]*cos(2π*n/N*k)-a[n]*sin(2π*n/N*k))
364+
# But because our DFT matrix is obtained with the conjugate --> e^(2π*n/N*k):
365+
# real(X[k]) = Σ(a[n]*cos(2π*n/N*k)-b[n]*sin(2π*n/N*k))
366+
# imag(X[k]) = Σ(b[n]*cos(2π*n/N*k)+a[n]*sin(2π*n/N*k))
367367
cos_windows_real = mb.conv(x=signal_real, weight=cos_base, strides=hop_size, pad_type='valid', before_op=before_op)
368368
sin_windows_real = mb.conv(x=signal_real, weight=sin_base, strides=hop_size, pad_type='valid', before_op=before_op)
369369
if input_imaginary:
@@ -372,11 +372,11 @@ def _stft(
372372

373373
# add everything together
374374
if input_imaginary:
375-
real_result = mb.add(x=cos_windows_real, y=sin_windows_imag, before_op=before_op)
376-
imag_result = mb.sub(x=cos_windows_imag, y=sin_windows_real, before_op=before_op)
375+
real_result = mb.sub(x=cos_windows_real, y=sin_windows_imag, before_op=before_op)
376+
imag_result = mb.add(x=cos_windows_imag, y=sin_windows_real, before_op=before_op)
377377
else:
378378
real_result = cos_windows_real
379-
imag_result = mb.sub(x=0., y=sin_windows_real, before_op=before_op)
379+
imag_result = sin_windows_real
380380

381381
# reduce the rank of the output
382382
if should_increase_rank:
@@ -417,17 +417,18 @@ def _istft(
417417
# By default, use the entire frame
418418
win_length = win_length or n_fft
419419

420-
input_shape = mb.shape(x=x, before_op=before_op)
421-
n_frames = input_shape.val[-1]
422-
fft_size = input_shape.val[-2]
423-
# expected_output_signal_len = n_fft.val + hop_length.val * (n_frames - 1)
420+
input_shape = mb.shape(x=input_real, before_op=before_op)
421+
channels = input_shape.val[0]
422+
fft_size = input_shape.val[1]
423+
n_frames = input_shape.val[2]
424+
expected_output_signal_len = n_fft.val + hop_length.val * (n_frames - 1)
424425

425426
is_onesided = onesided.val if onesided else fft_size != n_fft
426427
cos_base, sin_base = _calculate_dft_matrix(n_fft, onesided=is_onesided, before_op=before_op)
427428

428429
# create a window of centered 1s of the requested size
429430
if win_length:
430-
window = _get_window(win_length=win_length, n_fft=n_fft, before_op=before_op)
431+
window = _get_window(win_length=win_length, n_fft=n_fft, window=window, before_op=before_op)
431432

432433
# apply time window
433434
if window:
@@ -447,14 +448,13 @@ def _istft(
447448
signal_real = mb.mul(x=signal_real, y=multiplier, before_op=before_op)
448449
signal_imaginary = mb.mul(x=signal_imaginary, y=multiplier, before_op=before_op)
449450

450-
# Conv with DFT kernel across the input signal
451-
# We can describe the IDFT in terms of DFT just by swapping the input and output
451+
# Convolve the DFT kernel with the input signal
452+
# We can describe the IDFT in terms of DFT just by swapping the input and output.
452453
# ref: https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Expressing_the_inverse_DFT_in_terms_of_the_DFT
453-
# So IDFT(x) = (1/N) * swap(DFT(swap(x)))
454-
# and DFT(x) = X[k] = Σx[n]*e^(-2kpi/N i) but we are using the conjugate e^(2kpi/N i)
455-
# If x is complex then x[n]=(a+i*b)
456-
# then real part = (1/N)*Σ(a*cos(2kpi/N)+b*sin(2kpi/N))
457-
# then imag part = (1/N)*Σ(b*cos(2kpi/N)-a*sin(2kpi/N))
454+
# IDFT(X[K]) --> x[n]=(1/N)*swap(DFT(swap(X[k]))), and K=N
455+
# So using the definition in stft function, we get:
456+
# real(x[n]) = Σ(a[k]*cos(2π*k/K*n)+b[k]*sin(2π*k/K*n))
457+
# imag(x[n]) = Σ(b[k]*cos(2π*k/K*n)-a[k]*sin(2π*k/K*n))
458458
cos_windows_real = mb.conv(x=signal_real, weight=cos_base, strides=hop_size, pad_type='valid', before_op=before_op)
459459
sin_windows_real = mb.conv(x=signal_real, weight=sin_base, strides=hop_size, pad_type='valid', before_op=before_op)
460460
cos_windows_imag = mb.conv(x=signal_imaginary, weight=cos_base, strides=hop_size, pad_type='valid', before_op=before_op)
@@ -519,6 +519,7 @@ def _overlap_add(
519519
def _get_window(
520520
win_length: Var,
521521
n_fft: Var,
522+
window: Optional[Var],
522523
before_op: Operation,
523524
) -> Var:
524525
n_left = (n_fft.val - win_length.val) // 2
@@ -750,17 +751,21 @@ def _lower_complex_istft(op: Operation):
750751
is_complex = types.is_complex(op.input.dtype)
751752

752753
# check parameters for validity
754+
if is_complex:
755+
raise ValueError("Only complex inputs are allowed")
753756
if op.win_length and op.win_length.val > op.n_fft.val:
754757
raise ValueError("Window length must be less than or equal to n_fft")
755-
if is_complex and op.onesided and op.onesided.val:
756-
raise ValueError("Onesided is only valid for real inputs")
758+
if op.return_complex and op.onesided and op.onesided.val:
759+
raise ValueError("Complex output is not compatible with onesided")
757760

758761
real, imag = _istft(
759-
op.input.real if is_complex else op.input,
760-
op.input.imag if is_complex else None,
761-
op.n_fft, op.hop_length, op.win_length, op.window, op.normalized, op.onesided, before_op=op)
762+
op.input.real, op.input.imag,
763+
op.n_fft, op.hop_length, op.win_length, op.window, op.normalized, op.onesided, op.length, before_op=op)
762764

763-
return _wrap_complex_output(op.outputs[0], real, imag)
765+
if op.return_complex:
766+
return _wrap_complex_output(op.outputs[0], real, imag)
767+
else
768+
return real
764769

765770

766771
@LowerComplex.register_lower_func(op_type="complex_shape")

0 commit comments

Comments
 (0)