Skip to content

Commit 3858b63

Browse files
author
Alejandro Gaston Alvarez Franceschi
committed
updates
1 parent c956bdb commit 3858b63

File tree

1 file changed

+36
-27
lines changed

1 file changed

+36
-27
lines changed

coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py

Lines changed: 36 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -403,25 +403,25 @@ def _istft(
403403
) -> Tuple[Var, Var]:
404404
"""
405405
We can write ISTFT in terms of convolutions with a DFT kernel.
406-
At the end:
407-
* The real part output is: cos_base * input_real + sin_base * input_imag
408-
* The imaginary part output is: - (sin_base * input_real - cos_base * input_imag)
409-
Adapted from: https://github.com/adobe-research/convmelspec/blob/main/convmelspec/mil.py
406+
407+
The input has shape (channels, fft_size, n_frames)
408+
409+
References:
410+
H. Zhivomirov, “On the Development of STFT-analysis and ISTFT-synthesis Routines and their Practical Implementation,” TEM Journal, vol. 8, no. 1, pp. 56–64, 2019.
411+
https://en.wikipedia.org/wiki/Discrete_Fourier_transform
410412
"""
411413
# Set the default hop, if it's not already specified
412414
hop_length = hop_length or mb.floor_div(x=n_fft, y=4, before_op=before_op)
413415

414416
# By default, use the entire frame
415417
win_length = win_length or n_fft
416418

417-
# input should always be 2D
418-
should_increase_rank = input_real.rank == 1
419-
if should_increase_rank:
420-
input_real = mb.expand_dims(x=input_real, axes=(0,), before_op=before_op)
421-
if input_imaginary:
422-
input_imaginary = mb.expand_dims(x=input_imaginary, axes=(0,), before_op=before_op)
419+
input_shape = mb.shape(x=x, before_op=before_op)
420+
n_frames = input_shape.val[-1]
421+
fft_size = input_shape.val[-2]
422+
expected_output_signal_len = n_fft.val + hop_length.val * (n_frames - 1)
423423

424-
is_onesided = onesided and onesided.val
424+
is_onesided = onesided.val if onesided else fft_size != n_fft
425425
cos_base, sin_base = _calculate_dft_matrix(n_fft, onesided=is_onesided, before_op=before_op)
426426

427427
# create a window of centered 1s of the requested size
@@ -433,32 +433,34 @@ def _istft(
433433
cos_base = mb.mul(x=window, y=cos_base, before_op=before_op)
434434
sin_base = mb.mul(x=window, y=sin_base, before_op=before_op)
435435

436-
# The DFT matrix is obtained with the equation e^(2pi/N i), which is what we want but we actually need the conjugate => e^(-2pi/N i)
437-
# or in terms of cos and sin: cos+i*sin => cos-i*sin
438-
sin_base = mb.sub(x=0., y=sin_base, before_op=before_op)
439-
440436
cos_base = mb.expand_dims(x=cos_base, axes=(1,), before_op=before_op)
441437
sin_base = mb.expand_dims(x=sin_base, axes=(1,), before_op=before_op)
442438
hop_size = mb.expand_dims(x=hop_length, axes=(0,), before_op=before_op)
443439

444440
signal_real = mb.expand_dims(x=input_real, axes=(1,), before_op=before_op)
445441
signal_imaginary = mb.expand_dims(x=input_imaginary, axes=(1,), before_op=before_op)
446442

443+
# De-normalize the signal before applying the IDFT
444+
if normalized and normalized.val:
445+
multiplier = mb.sqrt(x=mb.cast(x=n_fft, dtype="fp32", before_op=before_op), before_op=before_op)
446+
signal_real = mb.mul(x=signal_real, y=multiplier, before_op=before_op)
447+
signal_imaginary = mb.mul(x=signal_imaginary, y=multiplier, before_op=before_op)
448+
447449
# Conv with DFT kernel across the input signal
448450
# We can describe the IDFT in terms of the DFT just by swapping the real and imaginary parts of the input and of the output
449451
# ref: https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Expressing_the_inverse_DFT_in_terms_of_the_DFT
450452
# So IDFT(x) = (1/N) * swap(DFT(swap(x)))
451-
# DFT(x) => X[k] = Σx[n]*e^(-2kpi/N i)
453+
# and DFT(x) = X[k] = Σx[n]*e^(-2kpi/N i) but we are using the conjugate e^(2kpi/N i)
452454
# If x is complex then x[n]=(a+i*b)
453-
# So the real part = (1/N)*Σ(a*cos(2kpi/N)-b*sin(2kpi/N))
454-
# So the imag part = (1/N)*Σ(b*cos(2kpi/N)+a*sin(2kpi/N))
455+
# then real part = (1/N)*Σ(a*cos(2kpi/N)+b*sin(2kpi/N))
456+
# then imag part = (1/N)*Σ(b*cos(2kpi/N)-a*sin(2kpi/N))
455457
cos_windows_real = mb.conv(x=signal_real, weight=cos_base, strides=hop_size, pad_type='valid', before_op=before_op)
456458
sin_windows_real = mb.conv(x=signal_real, weight=sin_base, strides=hop_size, pad_type='valid', before_op=before_op)
457459
cos_windows_imag = mb.conv(x=signal_imaginary, weight=cos_base, strides=hop_size, pad_type='valid', before_op=before_op)
458460
sin_windows_imag = mb.conv(x=signal_imaginary, weight=sin_base, strides=hop_size, pad_type='valid', before_op=before_op)
459461

460-
real_result = mb.sub(x=cos_windows_real, y=sin_windows_imag, before_op=before_op)
461-
imag_result = mb.add(x=cos_windows_imag, y=sin_windows_real, before_op=before_op)
462+
real_result = mb.add(x=cos_windows_real, y=sin_windows_imag, before_op=before_op)
463+
imag_result = mb.sub(x=cos_windows_imag, y=sin_windows_real, before_op=before_op)
462464

463465
# Divide by N
464466
real_result = mb.real_div(x=real_result, y=n_fft, before_op=before_op)
@@ -472,10 +474,9 @@ def _istft(
472474
n_frames = mb.shape(x=real_result, before_op=before_op)[1]
473475
window_square = mb.mul(x=window, y=window, before_op=before_op)
474476
window_mtx = mb.stack(values=[window_square] * n_frames, axis=1)
475-
normalization_factor = _overlap_add(x=window_mtx, n_fft=n_fft, hop_length=hop_length, before_op=before_op)
476-
477-
real_result = mb.real_div(x=real_result, y=normalization_factor, before_op=before_op)
478-
imag_result = mb.real_div(x=imag_result, y=normalization_factor, before_op=before_op)
477+
window_envelope = _overlap_add(x=window_mtx, n_fft=n_fft, hop_length=hop_length, before_op=before_op)
478+
real_result = mb.real_div(x=real_result, y=window_envelope, before_op=before_op)
479+
imag_result = mb.real_div(x=imag_result, y=window_envelope, before_op=before_op)
479480

480481
# reduce the rank of the output
481482
if should_increase_rank:
@@ -490,13 +491,21 @@ def _overlap_add(
490491
hop_length: Var,
491492
before_op: Operation,
492493
) -> Var:
493-
n_frames = mb.shape(x=x, before_op=before_op)[1]
494-
output = mb.fill(shape=(n_fft.val + hop_length.val * (n_frames - 1)), value=0., before_op=before_op)
495-
signal_frames = mb.split(x=x, num_splits=n_frames, axis=1, before_op=before_op)
494+
"""
495+
The input has shape (channels, fft_size, n_frames)
496+
"""
497+
input_shape = mb.shape(x=x, before_op=before_op)
498+
channels = input_shape.val[0]
499+
n_frames = input_shape.val[2]
500+
501+
output = mb.fill(shape=(channels, n_fft.val + hop_length.val * (n_frames - 1)), value=0., before_op=before_op)
502+
signal_frames = mb.split(x=x, num_splits=n_frames, axis=2, before_op=before_op)
496503
local_idx = mb.range_1d(start=0, end=n_fft, step=1, before_op=before_op)
497504

498505
for frame_num, frame in enumerate(signal_frames):
499506
global_idx = mb.add(x=local_idx , y=frame_num*hop_length.val, before_op=before_op)
507+
global_idx = mb.expand_dims(x=global_idx, axes=(0,), before_op=before_op)
508+
global_idx = mb.stack(values=[global_idx] * channels, axis=0)
500509
output = mb.scatter_nd(data=output, indices=global_idx, updates=frame, before_op=before_op)
501510

502511
return output

0 commit comments

Comments
 (0)