diff --git a/coremltools/converters/mil/frontend/torch/ops.py b/coremltools/converters/mil/frontend/torch/ops.py index 0755d4d9d..5781e9689 100644 --- a/coremltools/converters/mil/frontend/torch/ops.py +++ b/coremltools/converters/mil/frontend/torch/ops.py @@ -6362,6 +6362,7 @@ def stft(context, node): Lowers torch.stft with the dialect op `complex_stft` from complex_dialect_ops.py """ input_data, n_fft, hop_length, win_length, window, normalized, onesided, _ = _get_inputs(context, node, min_expected=2) + if types.is_complex(input_data.dtype): onesided = False # pytorch defaults onesided to False for complex inputs stft_res = mb.complex_stft( @@ -6371,9 +6372,32 @@ def stft(context, node): win_length=win_length, window=window, normalized=normalized, - onesided=onesided) + onesided=onesided + ) context.add(stft_res, node.name) +@register_torch_op +def istft(context, node): + """ + Lowers torch.istft with the dialect op `complex_istft` from complex_dialect_ops.py + """ + input_data, n_fft, hop_length, win_length, window, center, normalized, onesided, length, _ = _get_inputs(context, node, min_expected=2) + + if types.is_complex(input_data.dtype): + onesided = False # pytorch defaults onesided to False for complex inputs + istft_res = mb.complex_istft( + input=input_data, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + center=center, + normalized=normalized, + onesided=onesided, + length=length, + ) + context.add(istft_res, node.name) + @register_torch_op(torch_alias=["torchvision::nms"]) def torchvision_nms(context, node): inputs = _get_inputs(context, node, expected=3) diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py index 242ec8740..279846c2a 100644 --- a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py +++ b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py @@ -9540,8 +9540,8 @@ def forward(self, x): (2, 3, 4), FftnModel(), backend=backend, compute_unit=compute_unit ) + class TestSTFT(TorchBaseTest): - @pytest.mark.slow @pytest.mark.parametrize( "compute_unit, backend, input_shape, complex, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided", itertools.product( @@ -9566,9 +9566,8 @@ def test_stft(self, compute_unit, backend, input_shape, complex, n_fft, hop_leng class STFTModel(torch.nn.Module): def forward(self, x): applied_window = window(win_length) if window and win_length else None - x = torch.complex(x, x) if complex else x x = torch.stft( - x, + torch.complex(x, x) if complex else x, n_fft=n_fft, hop_length=hop_length, win_length=win_length, @@ -9588,6 +9587,87 @@ def forward(self, x): compute_unit=compute_unit ) + @pytest.mark.parametrize( + "compute_unit, backend, channels, n_fft, num_frames, hop_length, win_length, window, center, normalized, onesided, length, return_complex", + itertools.product( + compute_units, + backends, + [None, 1, 3], # channels + [16, 32], # n_fft + [5, 9], # num_frames + [None, 5], # hop_length + [None, 10, 8], # win_length + [None, torch.hann_window], # window + [False, True], # center + [False, True], # normalized + [None, False, True], # onesided + [None, "shorter", "larger"], # length + [False, True], # return_complex + ) + ) + def test_istft(self, compute_unit, backend, channels, n_fft, num_frames, hop_length, win_length, window, center, normalized, onesided, length, return_complex): + if return_complex and onesided: + pytest.skip("Complex output is incompatible 
with onesided") + + if hop_length is None and win_length is not None: + pytest.skip("If win_length is set then we must set hop_length and 0 < hop_length <= win_length") + + # Compute input_shape to generate test case + freq = n_fft//2+1 if onesided else n_fft + input_shape = (channels, freq, num_frames) if channels else (freq, num_frames) + + # If not set,c ompute hop_length for capturing errors + if hop_length is None: + hop_length = n_fft // 4 + + if length == "shorter": + length = n_fft//2 + hop_length * (num_frames - 1) + elif length == "larger": + length = n_fft*3//2 + hop_length * (num_frames - 1) + + class ISTFTModel(torch.nn.Module): + def forward(self, x): + applied_window = window(win_length) if window and win_length else None + x = torch.istft( + torch.complex(x, x), + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=applied_window, + center=center, + normalized=normalized, + onesided=onesided, + length=length, + return_complex=return_complex) + if return_complex: + return torch.stack([torch.real(x), torch.imag(x)], dim=0) + else: + return torch.real(x) + + if (center is False and win_length) or (center and win_length and length): + # For some reason Pytorch raises an error https://github.com/pytorch/audio/issues/427#issuecomment-1829593033 + with pytest.raises(RuntimeError, match="istft\(.*\) window overlap add min: 1"): + TorchBaseTest.run_compare_torch( + input_shape, + ISTFTModel(), + backend=backend, + compute_unit=compute_unit + ) + elif length and return_complex: + with pytest.raises(ValueError, match="New var type `.tensor'>` not a subtype of existing var type `.tensor'>`"): + TorchBaseTest.run_compare_torch( + input_shape, + ISTFTModel(), + backend=backend, + compute_unit=compute_unit + ) + else: + TorchBaseTest.run_compare_torch( + input_shape, + ISTFTModel(), + backend=backend, + compute_unit=compute_unit + ) if _HAS_TORCH_AUDIO: diff --git a/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py b/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py index e19bf1757..75c935bc7 100644 --- a/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py +++ b/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py @@ -861,3 +861,80 @@ def type_inference(self): return types.tensor(output_type, tuple(output_shape)) +@register_op(namespace="complex") +class complex_istft(Operation): + """ + Dialect op for 1-D ISTFT. + + Parameters + ---------- + input: tensor<\*V, complex64> (Required) + * A complex tensor where real and imag parts have the same shape. + n_fft: const i32 (Required) + * Size of the fourier transform. + hop_length: const i32 (Optional) + * Stride between window frames of the input tensor. + win_length: const i32 (optional) + * The size of the window frame. + window: tensor<1, win_length> (optional) + * The window to apply to the input signal before performing the fourier transform. + normalized: const bool (optional, Default=``false``) + * Whether to normalize the results of the STFT + onesided: const bool (optional, Default=``true``) + * Whether the STFT was onesieded + length: const i32 (Required) + * Output fixed length, which will be zeropadded + + + Returns + ------- + tensor<\*D, T> + * The output tensor + + Attributes + ---------- + T: fp32, complex64 + + References + ---------- + See `torch.istft `_. 
diff --git a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py
index ed36d87f3..b3ab30bf6 100644
--- a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py
+++ b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py
@@ -325,7 +325,7 @@ def _stft(
     We can write STFT in terms of convolutions with a DFT kernel.
At the end: * The real part output is: cos_base * input_real + sin_base * input_imag - * The imaginary part output is: - (sin_base * input_real - cos_base * input_imag) + * The imaginary part output is: cos_base * input_imag - sin_base * input_real Adapted from: https://github.com/adobe-research/convmelspec/blob/main/convmelspec/mil.py """ hop_length = hop_length or mb.floor_div(x=n_fft, y=4, before_op=before_op) @@ -338,52 +338,42 @@ def _stft( input_imaginary = mb.expand_dims(x=input_imaginary, axes=(0,), before_op=before_op) is_onesided = onesided and onesided.val - cos_base, sin_base = _calculate_dft_matrix( - n_fft, - onesided=is_onesided, - before_op=before_op) + cos_base, sin_base = _calculate_dft_matrix(n_fft, onesided=is_onesided, before_op=before_op) # create a window of centered 1s of the requested size if win_length: - n_left = (n_fft.val - win_length.val) // 2 - n_right = n_fft.val - win_length.val - n_left - - left = mb.fill(shape=(n_left,), value=0., before_op=before_op) - if not window: - window = mb.fill(shape=(win_length.val,), value=1., before_op=before_op) - right = mb.fill(shape=(n_right,), value=0., before_op=before_op) - - # concatenate - window = mb.concat(values=(left, window, right), axis=0, before_op=before_op) + window = _get_window(win_length=win_length, n_fft=n_fft, window=window, before_op=before_op) # apply time window if window: cos_base = mb.mul(x=window, y=cos_base, before_op=before_op) sin_base = mb.mul(x=window, y=sin_base, before_op=before_op) - # conv with DFT kernel across the input signal - sin_base = mb.sub(x=0., y=sin_base, before_op=before_op) + # Expand cos_base = mb.expand_dims(x=cos_base, axes=(1,), before_op=before_op) sin_base = mb.expand_dims(x=sin_base, axes=(1,), before_op=before_op) hop_size = mb.expand_dims(x=hop_length, axes=(0,), before_op=before_op) - signal_real = mb.expand_dims(x=input_real, axes=(1,), before_op=before_op) + if input_imaginary: + signal_imaginary = mb.expand_dims(x=input_imaginary, axes=(1,), before_op=before_op) + + # Convolve the DFT kernel with the input signal + # DFT(x[n]) --> X[k] = Σx[n]*e^(-2π*n/N*k), then if x is complex x[n]=(a[n]+i*b[n]) + # real(X[k]) = Σ(a[n]*cos(2π*n/N*k)+b[n]*sin(2π*n/N*k)) + # imag(X[k]) = Σ(b[n]*cos(2π*n/N*k)-a[n]*sin(2π*n/N*k)) cos_windows_real = mb.conv(x=signal_real, weight=cos_base, strides=hop_size, pad_type='valid', before_op=before_op) sin_windows_real = mb.conv(x=signal_real, weight=sin_base, strides=hop_size, pad_type='valid', before_op=before_op) - if input_imaginary: - signal_imaginary = mb.expand_dims(x=input_imaginary, axes=(1,), before_op=before_op) cos_windows_imag = mb.conv(x=signal_imaginary, weight=cos_base, strides=hop_size, pad_type='valid', before_op=before_op) sin_windows_imag = mb.conv(x=signal_imaginary, weight=sin_base, strides=hop_size, pad_type='valid', before_op=before_op) # add everything together if input_imaginary: - # sin base is already negative so subtract - real_result = mb.sub(x=cos_windows_real, y=sin_windows_imag, before_op=before_op) - imag_result = mb.add(x=sin_windows_real, y=cos_windows_imag, before_op=before_op) + real_result = mb.add(x=cos_windows_real, y=sin_windows_imag, before_op=before_op) + imag_result = mb.sub(x=cos_windows_imag, y=sin_windows_real, before_op=before_op) else: real_result = cos_windows_real - imag_result = sin_windows_real + imag_result = mb.sub(x=0., y=sin_windows_real, before_op=before_op) # reduce the rank of the output if should_increase_rank: @@ -397,6 +387,168 @@ def _stft( return real_result, imag_result +def 
_istft(
+    input_real: Var,
+    input_imaginary: Var,
+    n_fft: Var,
+    hop_length: Optional[Var],
+    win_length: Optional[Var],
+    window: Optional[Var],
+    center: Optional[Var],
+    normalized: Optional[Var],
+    onesided: Optional[Var],
+    length: Optional[Var],
+    before_op: Operation,
+) -> Tuple[Var, Var]:
+    """
+    We can write ISTFT in terms of matrix multiplications with a DFT kernel, followed by overlap-add.
+
+    The input has shape (channels, fft_size, n_frames).
+
+    References:
+    H. Zhivomirov, “On the Development of STFT-analysis and ISTFT-synthesis Routines and their Practical Implementation,” TEM Journal, vol. 8, no. 1, pp. 56–64, 2019.
+    https://en.wikipedia.org/wiki/Discrete_Fourier_transform
+    """
+    # Set the default hop, if it's not already specified
+    hop_length = hop_length or mb.floor_div(x=n_fft, y=4, before_op=before_op)
+
+    # By default, use the entire frame
+    win_length = win_length or n_fft
+
+    input_shape = mb.shape(x=input_real, before_op=before_op)
+    if input_real.rank == 3:
+        channels, fft_size, n_frames = input_shape.val
+    else:
+        channels = None
+        fft_size, n_frames = input_shape.val
+
+    expected_output_signal_len = n_fft.val + hop_length.val * (n_frames - 1)
+
+    is_onesided = True if fft_size != n_fft.val else onesided and onesided.val
+    cos_base, sin_base = _calculate_dft_matrix(n_fft, onesided=is_onesided, before_op=before_op)
+
+    # Create an n_fft-long window (ones if no explicit window was given)
+    if win_length:
+        window = _get_window(win_length=win_length, n_fft=n_fft, window=window, before_op=before_op)
+
+    # Apply the time window to the DFT bases
+    if window:
+        cos_base = mb.mul(x=window, y=cos_base, before_op=before_op)
+        sin_base = mb.mul(x=window, y=sin_base, before_op=before_op)
+
+    hop_size = mb.expand_dims(x=hop_length, axes=(0,), before_op=before_op)
+
+    signal_real = input_real
+    signal_imaginary = input_imaginary
+
+    # De-normalize the signal before applying the inverse transform
+    if normalized and normalized.val:
+        multiplier = mb.sqrt(x=mb.cast(x=n_fft, dtype="fp32", before_op=before_op), before_op=before_op)
+        signal_real = mb.mul(x=signal_real, y=multiplier, before_op=before_op)
+        signal_imaginary = mb.mul(x=signal_imaginary, y=multiplier, before_op=before_op)
+
+    # Apply the DFT kernel to the input spectrum via matrix multiplication.
+    # The IDFT can be expressed through the DFT by swapping the real and imaginary parts of its input and output:
+    # ref: https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Expressing_the_inverse_DFT_in_terms_of_the_DFT
+    # IDFT(X[k]) --> x[n] = (1/N) * swap(DFT(swap(X[k])))
+    # Using the same DFT definition as in _stft, with X[k] = a[k] + i*b[k], this computes:
+    #   real(x[n]) = Σ(a[k]*cos(2π*k*n/N) + b[k]*sin(2π*k*n/N))
+    #   imag(x[n]) = Σ(b[k]*cos(2π*k*n/N) - a[k]*sin(2π*k*n/N))
+    cos_windows_real = mb.matmul(x=signal_real, y=cos_base, transpose_x=True, before_op=before_op)
+    sin_windows_real = mb.matmul(x=signal_real, y=sin_base, transpose_x=True, before_op=before_op)
+    cos_windows_imag = mb.matmul(x=signal_imaginary, y=cos_base, transpose_x=True, before_op=before_op)
+    sin_windows_imag = mb.matmul(x=signal_imaginary, y=sin_base, transpose_x=True, before_op=before_op)
+
+    real_result = mb.add(x=cos_windows_real, y=sin_windows_imag, before_op=before_op)
+    imag_result = mb.sub(x=cos_windows_imag, y=sin_windows_real, before_op=before_op)
+
+    # Divide by N
+    n_fft = mb.cast(x=n_fft, dtype="fp32", before_op=before_op)
+    real_result = mb.real_div(x=real_result, y=n_fft, before_op=before_op)
+    imag_result = mb.real_div(x=imag_result, y=n_fft, before_op=before_op)
+
+    # Overlap-add the reconstructed frames
+    real_result = _overlap_add(x=real_result, n_fft=n_fft, hop_length=hop_length, before_op=before_op)
+    imag_result = _overlap_add(x=imag_result, n_fft=n_fft, hop_length=hop_length, before_op=before_op)
+
+    # Normalize by the squared-window (overlap-add) envelope
+    window_square = mb.mul(x=window, y=window, before_op=before_op)
+    window_mtx = mb.stack(values=[window_square] * n_frames, axis=0, before_op=before_op)
+    window_mtx = mb.expand_dims(x=window_mtx, axes=(0,), before_op=before_op)
+    window_envelope = _overlap_add(x=window_mtx, n_fft=n_fft, hop_length=hop_length, before_op=before_op)
+
+    # Note: window_envelope has shape (1, L), so this division broadcasts and the result
+    # gains a leading channels dimension if it did not already have one.
+    real_result = mb.real_div(x=real_result, y=window_envelope, before_op=before_op)
+    imag_result = mb.real_div(x=imag_result, y=window_envelope, before_op=before_op)
+
+    # Trim or zero-pad the last dimension to the requested length
+    if length is not None:
+        if length.val > expected_output_signal_len:
+            real_result = mb.pad(x=real_result, pad=(0, length.val - expected_output_signal_len), before_op=before_op)
+            imag_result = mb.pad(x=imag_result, pad=(0, length.val - expected_output_signal_len), before_op=before_op)
+        elif length.val < expected_output_signal_len:
+            real_result = mb.slice_by_size(x=real_result, begin=[0, 0], size=[-1, length.val], before_op=before_op)
+            imag_result = mb.slice_by_size(x=imag_result, begin=[0, 0], size=[-1, length.val], before_op=before_op)
+
+    return real_result, imag_result
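The swap identity cited in the comments above is easy to verify numerically. A standalone NumPy check (illustrative only, not part of the patch):

    import numpy as np

    def swap(z):
        # Exchange the real and imaginary parts
        return z.imag + 1j * z.real

    X = np.random.randn(8) + 1j * np.random.randn(8)
    # IDFT(X) == (1/N) * swap(DFT(swap(X)))
    assert np.allclose(np.fft.ifft(X), swap(np.fft.fft(swap(X))) / len(X))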
+
+def _overlap_add(
+    x: Var,
+    n_fft: Var,
+    hop_length: Var,
+    before_op: Operation,
+) -> Var:
+    """
+    Overlap-add the frames of ``x``, which has shape (channels, n_frames, fft_size),
+    into a signal of length n_fft + hop_length * (n_frames - 1).
+    """
+    input_shape = mb.shape(x=x, before_op=before_op)
+
+    # Create an empty output with the final shape
+    if x.rank == 3:
+        channels, n_frames, _ = input_shape.val
+        output = mb.fill(shape=(channels, int(n_fft.val + hop_length.val * (n_frames - 1)),), value=0., before_op=before_op)
+    else:
+        channels = None
+        n_frames, _ = input_shape.val
+        output = mb.fill(shape=(int(n_fft.val + hop_length.val * (n_frames - 1)),), value=0., before_op=before_op)
+
+    # Index of a single frame, used below to place each frame inside the output
+    n_fft = mb.cast(x=n_fft, dtype="int32", before_op=before_op)
+    local_idx = mb.range_1d(start=0, end=n_fft, step=1, before_op=before_op)
+
+    # Split the data into frames and iterate
+    signal_frames = mb.split(x=x, num_splits=n_frames, axis=1 if channels else 0, before_op=before_op)
+
+    for frame_num, frame in enumerate(signal_frames):
+        frame = mb.squeeze(x=frame, axes=[1] if channels else [0], before_op=before_op)
+
+        # Offset the index to this frame's position in the output signal
+        global_idx = mb.add(x=local_idx, y=frame_num * hop_length.val, before_op=before_op)
+        if channels:
+            global_idx = mb.stack(values=[global_idx] * channels, axis=0, before_op=before_op)
+
+        # Accumulate the frame into the output
+        output = mb.scatter_along_axis(data=output, indices=global_idx, updates=frame, axis=1 if channels else 0, mode="add", before_op=before_op)
+
+    return output
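For reference, the scatter-based loop above performs a standard overlap-add. A plain NumPy equivalent for the single-channel case (illustrative only, not part of the patch) is:

    import numpy as np

    def overlap_add(frames, n_fft, hop_length):
        # frames: (n_frames, n_fft) -> signal of length n_fft + hop_length * (n_frames - 1)
        n_frames = frames.shape[0]
        out = np.zeros(n_fft + hop_length * (n_frames - 1))
        for i, frame in enumerate(frames):
            out[i * hop_length:i * hop_length + n_fft] += frame
        return out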
+
+def _get_window(
+    win_length: Var,
+    n_fft: Var,
+    window: Optional[Var],
+    before_op: Operation,
+) -> Var:
+    """
+    Build an n_fft-long window: the given ``window`` (or ones if none is given),
+    centered and zero-padded on both sides to n_fft samples.
+    """
+    n_left = (n_fft.val - win_length.val) // 2
+    n_right = n_fft.val - win_length.val - n_left
+
+    left = mb.fill(shape=(n_left,), value=0., before_op=before_op)
+    if not window:
+        window = mb.fill(shape=(win_length.val,), value=1., before_op=before_op)
+    right = mb.fill(shape=(n_right,), value=0., before_op=before_op)
+
+    # Concatenate into a single n_fft-long window
+    return mb.concat(values=(left, window, right), axis=0, before_op=before_op)
+
+
 def _wrap_complex_output(original_output: Var, real_data: Var, imag_data: Var) -> ComplexVar:
     return ComplexVar(
         name=original_output.name + "_lowered",
@@ -609,6 +761,27 @@ def _lower_complex_stft(op: Operation):
     return _wrap_complex_output(op.outputs[0], real, imag)
 
 
+@LowerComplex.register_lower_func(op_type="complex_istft")
+def _lower_complex_istft(op: Operation):
+
+    # Check the parameters for validity
+    if not types.is_complex(op.input.dtype):
+        raise TypeError("Input type must be complex")
+    if op.win_length and op.win_length.val > op.n_fft.val:
+        raise ValueError("Window length must be less than or equal to n_fft")
+    if op.return_complex and op.onesided and op.onesided.val:
+        raise ValueError("Complex output is not compatible with onesided")
+
+    real, imag = _istft(
+        op.input.real, op.input.imag,
+        op.n_fft, op.hop_length, op.win_length, op.window, op.center, op.normalized, op.onesided, op.length, before_op=op)
+
+    if op.return_complex:
+        return _wrap_complex_output(op.outputs[0], real, imag)
+    else:
+        return real
+
+
 @LowerComplex.register_lower_func(op_type="complex_shape")
 def _lower_complex_shape(op: Operation):
     return mb.shape(x=op.data.real, before_op=op)
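With this patch applied, the new conversion path can be exercised end to end roughly as follows. This is an illustrative sketch that mirrors the unit tests above (not an official example, and the parameter values are arbitrary); the complex spectrum is built inside the model because Core ML model inputs and outputs are real-valued.

    import torch
    import coremltools as ct

    class ISTFTModel(torch.nn.Module):
        def forward(self, x):
            spec = torch.complex(x, x)  # build the complex spectrum inside the model
            return torch.istft(spec, n_fft=16, hop_length=4, center=False)

    freq, frames = 9, 9  # onesided spectrum for n_fft=16
    example = torch.rand(freq, frames)
    traced = torch.jit.trace(ISTFTModel().eval(), example)

    mlmodel = ct.convert(
        traced,
        inputs=[ct.TensorType(name="x", shape=example.shape)],
        convert_to="mlprogram",
    )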