From 90df7d240c12c7ab380de16f0e97062ed113b6f0 Mon Sep 17 00:00:00 2001 From: fwcd Date: Tue, 6 Jun 2023 21:10:18 -0700 Subject: [PATCH 1/6] Update to coremltools 7.0b1 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d8383a5..34bacec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ description = "Spleeter implementation in PyTorch" # and fail during model conversions e.g. noting that BlobWriter is not available. requires-python = "<3.11" dependencies = [ - "coremltools >= 6.3, < 7", + "coremltools == 7.0b1", "numpy >= 1.24, < 2", "tensorflow >= 2.13.0rc0", "torch >= 2.0, < 3", From 67c99ad8fddf645f9d40234e4426844c951b1677 Mon Sep 17 00:00:00 2001 From: fwcd Date: Tue, 6 Jun 2023 21:12:40 -0700 Subject: [PATCH 2/6] Trace entire model --- convert-to-coreml | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/convert-to-coreml b/convert-to-coreml index 27bfa55..4305237 100755 --- a/convert-to-coreml +++ b/convert-to-coreml @@ -32,12 +32,8 @@ def main(): # Create sample 'audio' for tracing wav = torch.zeros(2, int(args.length * samplerate)) - # Reproduce the STFT step (which we cannot convert to Core ML, unfortunately) - _, stft_mag = estimator.compute_stft(wav) - print('==> Tracing model') - traced_model = torch.jit.trace(estimator.separator, stft_mag) - out = traced_model(stft_mag) + traced_model = torch.jit.trace(estimator, wav) print('==> Converting to Core ML') mlmodel = ct.convert( @@ -45,7 +41,7 @@ def main(): convert_to='mlprogram', # TODO: Investigate whether we'd want to make the input shape flexible # See https://coremltools.readme.io/docs/flexible-inputs - inputs=[ct.TensorType(shape=stft_mag.shape)] + inputs=[ct.TensorType(shape=wav.shape)] ) output_dir: Path = args.output From 5621ccdc4beabe02ebfd0ba7d08720cfc9e7f15d Mon Sep 17 00:00:00 2001 From: fwcd Date: Tue, 6 Jun 2023 21:25:34 -0700 Subject: [PATCH 3/6] Implement mag manually --- spleeter_pytorch/estimator.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/spleeter_pytorch/estimator.py b/spleeter_pytorch/estimator.py index a6c042b..ecb25a6 100644 --- a/spleeter_pytorch/estimator.py +++ b/spleeter_pytorch/estimator.py @@ -32,12 +32,14 @@ def compute_stft(self, wav): stft = torch.stft(wav, n_fft=self.win_length, hop_length=self.hop_length, window=self.win, center=True, return_complex=True, pad_mode='constant') + stft = torch.view_as_real(stft) # only keep freqs smaller than self.F - stft = stft[:, :self.F, :] - mag = stft.abs() + stft = stft[:, :self.F] - return torch.view_as_real(stft), mag + mag = torch.hypot(stft[:, :, :, 0], stft[:, :, :, 1]) + + return stft, mag def inverse_stft(self, stft): """Inverses stft to wave form""" From 92c7985f49437f85a57db4bf4dfb7ae0c24f74c7 Mon Sep 17 00:00:00 2001 From: fwcd Date: Tue, 6 Jun 2023 21:35:18 -0700 Subject: [PATCH 4/6] Implement torch.view_as_real and hypot manually --- spleeter_pytorch/estimator.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/spleeter_pytorch/estimator.py b/spleeter_pytorch/estimator.py index ecb25a6..b836193 100644 --- a/spleeter_pytorch/estimator.py +++ b/spleeter_pytorch/estimator.py @@ -32,12 +32,15 @@ def compute_stft(self, wav): stft = torch.stft(wav, n_fft=self.win_length, hop_length=self.hop_length, window=self.win, center=True, return_complex=True, pad_mode='constant') - stft = torch.view_as_real(stft) + + # implement torch.view_as_real(stft) manually since coremltools doesn't support it + stft = torch.stack((torch.real(stft), torch.imag(stft)), axis=-1) # only keep freqs smaller than self.F stft = stft[:, :self.F] - mag = torch.hypot(stft[:, :, :, 0], stft[:, :, :, 1]) + # implement torch.hypot manually since coremltools doesn't support it + mag = torch.sqrt(stft[:, :, :, 0] ** 2 + stft[:, :, :, 1] ** 2) return stft, mag From 13e3a2419c8c94bfd25b08f82116e1d4c295aaa7 Mon Sep 17 00:00:00 2001 From: fwcd Date: Tue, 6 Jun 2023 21:38:26 -0700 Subject: [PATCH 5/6] Implement torch.view_as_complex manually --- spleeter_pytorch/estimator.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/spleeter_pytorch/estimator.py b/spleeter_pytorch/estimator.py index b836193..7de8f52 100644 --- a/spleeter_pytorch/estimator.py +++ b/spleeter_pytorch/estimator.py @@ -40,7 +40,7 @@ def compute_stft(self, wav): stft = stft[:, :self.F] # implement torch.hypot manually since coremltools doesn't support it - mag = torch.sqrt(stft[:, :, :, 0] ** 2 + stft[:, :, :, 1] ** 2) + mag = torch.sqrt(stft[..., 0] ** 2 + stft[..., 1] ** 2) return stft, mag @@ -49,7 +49,10 @@ def inverse_stft(self, stft): pad = self.win_length // 2 + 1 - stft.size(1) stft = F.pad(stft, (0, 0, 0, 0, 0, pad)) - stft = torch.view_as_complex(stft) + + # implement torch.view_as_complex(stft) manually since coremltools doesn't support it + stft = stft[..., 0] + stft[..., 1] * 1j + wav = torch.istft(stft, self.win_length, hop_length=self.hop_length, center=True, window=self.win) return wav.detach() From 565a0878f3420c1dc40a18bd76cdacb6de8da017 Mon Sep 17 00:00:00 2001 From: fwcd Date: Tue, 6 Jun 2023 21:53:37 -0700 Subject: [PATCH 6/6] Implement torch.view_as_complex using torch.complex --- spleeter_pytorch/estimator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spleeter_pytorch/estimator.py b/spleeter_pytorch/estimator.py index 7de8f52..fcb302a 100644 --- a/spleeter_pytorch/estimator.py +++ b/spleeter_pytorch/estimator.py @@ -51,7 +51,7 @@ def inverse_stft(self, stft): stft = F.pad(stft, (0, 0, 0, 0, 0, pad)) # implement torch.view_as_complex(stft) manually since coremltools doesn't support it - stft = stft[..., 0] + stft[..., 1] * 1j + stft = torch.complex(stft[..., 0], stft[..., 1]) wav = torch.istft(stft, self.win_length, hop_length=self.hop_length, center=True, window=self.win)