
Commit 9c63bcc

add rmvpe support
1 parent 9b78902 commit 9c63bcc

File tree

  infer-web.py
  rmvpe.py
  vc_infer_pipeline.py

3 files changed: +355 −3


infer-web.py

Lines changed: 2 additions & 2 deletions
@@ -1340,7 +1340,7 @@ def export_onnx(ModelPath, ExportedPath):
             label=i18n(
                 "选择音高提取算法,输入歌声可用pm提速,harvest低音好但巨慢无比,crepe效果好但吃GPU"
             ),
-            choices=["pm", "harvest", "crepe"],
+            choices=["pm", "harvest", "crepe", "rmvpe"],
             value="pm",
             interactive=True,
         )
@@ -1442,7 +1442,7 @@ def export_onnx(ModelPath, ExportedPath):
             label=i18n(
                 "选择音高提取算法,输入歌声可用pm提速,harvest低音好但巨慢无比,crepe效果好但吃GPU"
             ),
-            choices=["pm", "harvest", "crepe"],
+            choices=["pm", "harvest", "crepe", "rmvpe"],
             value="pm",
             interactive=True,
         )
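
Both hunks add "rmvpe" to the same pitch-method selector in the WebUI; the i18n label (kept in Chinese in the source because it doubles as the locale lookup key) translates as "Select the pitch extraction algorithm: 'pm' is fast for sung input, 'harvest' handles low pitch well but is extremely slow, 'crepe' is accurate but GPU-hungry." For context, a sketch of how the full component would look in Gradio — the gr.Radio constructor and the variable name are assumptions; only the arguments shown in the diff come from the commit:

    f0method = gr.Radio(
        label=i18n(
            "选择音高提取算法,输入歌声可用pm提速,harvest低音好但巨慢无比,crepe效果好但吃GPU"
        ),
        choices=["pm", "harvest", "crepe", "rmvpe"],
        value="pm",
        interactive=True,
    )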

rmvpe.py

Lines changed: 344 additions & 0 deletions
import sys, torch, numpy as np, traceback, pdb
import torch.nn as nn
from time import time as ttime
import torch.nn.functional as F

# Constants from the original RMVPE recipe; only the no-GRU branch of E2E
# below references them (the checkpoint shipped with RVC uses the GRU branch).
N_MELS = 128
N_CLASS = 360


class BiGRU(nn.Module):
    def __init__(self, input_features, hidden_features, num_layers):
        super(BiGRU, self).__init__()
        self.gru = nn.GRU(
            input_features,
            hidden_features,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, x):
        return self.gru(x)[0]


class ConvBlockRes(nn.Module):
    # Two 3x3 conv + BN + ReLU layers with a residual connection; a 1x1 conv
    # projects the shortcut whenever the channel count changes.
    def __init__(self, in_channels, out_channels, momentum=0.01):
        super(ConvBlockRes, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=(1, 1),
                bias=False,
            ),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=(1, 1),
                bias=False,
            ),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU(),
        )
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))
            self.is_shortcut = True
        else:
            self.is_shortcut = False

    def forward(self, x):
        if self.is_shortcut:
            return self.conv(x) + self.shortcut(x)
        else:
            return self.conv(x) + x


class Encoder(nn.Module):
    def __init__(self, in_channels, in_size, n_encoders, kernel_size, n_blocks, out_channels=16, momentum=0.01):
        super(Encoder, self).__init__()
        self.n_encoders = n_encoders
        self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
        self.layers = nn.ModuleList()
        self.latent_channels = []
        for i in range(self.n_encoders):
            self.layers.append(
                ResEncoderBlock(in_channels, out_channels, kernel_size, n_blocks, momentum=momentum)
            )
            self.latent_channels.append([out_channels, in_size])
            in_channels = out_channels
            out_channels *= 2
            in_size //= 2
        self.out_size = in_size
        self.out_channel = out_channels

    def forward(self, x):
        concat_tensors = []  # pre-pool activations, reused as U-Net skip connections
        x = self.bn(x)
        for i in range(self.n_encoders):
            skip, x = self.layers[i](x)
            concat_tensors.append(skip)
        return x, concat_tensors


class ResEncoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01):
        super(ResEncoderBlock, self).__init__()
        self.n_blocks = n_blocks
        self.conv = nn.ModuleList()
        self.conv.append(ConvBlockRes(in_channels, out_channels, momentum))
        for i in range(n_blocks - 1):
            self.conv.append(ConvBlockRes(out_channels, out_channels, momentum))
        self.kernel_size = kernel_size
        if self.kernel_size is not None:
            self.pool = nn.AvgPool2d(kernel_size=kernel_size)

    def forward(self, x):
        for i in range(self.n_blocks):
            x = self.conv[i](x)
        if self.kernel_size is not None:
            # Return full-resolution features (for the skip) alongside the pooled ones.
            return x, self.pool(x)
        else:
            return x


class Intermediate(nn.Module):
    def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
        super(Intermediate, self).__init__()
        self.n_inters = n_inters
        self.layers = nn.ModuleList()
        self.layers.append(ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum))
        for i in range(self.n_inters - 1):
            self.layers.append(ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum))

    def forward(self, x):
        for i in range(self.n_inters):
            x = self.layers[i](x)
        return x


class ResDecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
        super(ResDecoderBlock, self).__init__()
        out_padding = (0, 1) if stride == (1, 2) else (1, 1)
        self.n_blocks = n_blocks
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(3, 3),
                stride=stride,
                padding=(1, 1),
                output_padding=out_padding,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU(),
        )
        self.conv2 = nn.ModuleList()
        self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
        for i in range(n_blocks - 1):
            self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))

    def forward(self, x, concat_tensor):
        x = self.conv1(x)
        x = torch.cat((x, concat_tensor), dim=1)  # merge the encoder skip connection
        for i in range(self.n_blocks):
            x = self.conv2[i](x)
        return x


class Decoder(nn.Module):
    def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList()
        self.n_decoders = n_decoders
        for i in range(self.n_decoders):
            out_channels = in_channels // 2
            self.layers.append(ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum))
            in_channels = out_channels

    def forward(self, x, concat_tensors):
        for i in range(self.n_decoders):
            x = self.layers[i](x, concat_tensors[-1 - i])
        return x

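# Assembly note: with kernel_size=(2, 2) and en_de_layers=5, the encoder halves
# the (time, mel) plane five times (128 mel bins -> 4), the Intermediate stack
# works at 1/32 resolution, and the decoder mirrors the five halvings back up —
# which is why mel2hidden below pads the frame count to a multiple of 32.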

class DeepUnet(nn.Module):
    def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
        super(DeepUnet, self).__init__()
        self.encoder = Encoder(in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels)
        self.intermediate = Intermediate(
            self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks
        )
        # The encoder's pooling kernel doubles as the decoder's upsampling stride.
        self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks)

    def forward(self, x):
        x, concat_tensors = self.encoder(x)
        x = self.intermediate(x)
        x = self.decoder(x, concat_tensors)
        return x


class E2E(nn.Module):
    def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4,
                 in_channels=1, en_out_channels=16):
        super(E2E, self).__init__()
        self.unet = DeepUnet(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels)
        self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
        if n_gru:
            self.fc = nn.Sequential(
                BiGRU(3 * 128, 256, n_gru),
                nn.Linear(512, 360),
                nn.Dropout(0.25),
                nn.Sigmoid(),
            )
        else:
            self.fc = nn.Sequential(
                nn.Linear(3 * N_MELS, N_CLASS),
                nn.Dropout(0.25),
                nn.Sigmoid(),
            )

    def forward(self, mel):
        # (B, n_mels, T) -> (B, 1, T, n_mels): treat the mel spectrogram as an image.
        mel = mel.transpose(-1, -2).unsqueeze(1)
        x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
        x = self.fc(x)
        return x

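# Shape walkthrough (for the E2E(4, 1, (2, 2)) configuration used below):
#   mel (B, 128, T) -> unsqueeze -> (B, 1, T, 128)
#   DeepUnet preserves the spatial size, so the cnn output is (B, 3, T, 128)
#   transpose + flatten -> (B, T, 3*128=384) -> BiGRU -> (B, T, 512) -> Linear -> (B, T, 360)
# i.e. one 360-bin pitch salience vector per mel frame.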
182+
from librosa.filters import mel
183+
class MelSpectrogram(torch.nn.Module):
184+
def __init__(
185+
self,
186+
is_half,
187+
n_mel_channels,
188+
sampling_rate,
189+
win_length,
190+
hop_length,
191+
n_fft=None,
192+
mel_fmin=0,
193+
mel_fmax=None,
194+
clamp=1e-5
195+
):
196+
super().__init__()
197+
n_fft = win_length if n_fft is None else n_fft
198+
self.hann_window = {}
199+
mel_basis = mel(
200+
sr=sampling_rate,
201+
n_fft=n_fft,
202+
n_mels=n_mel_channels,
203+
fmin=mel_fmin,
204+
fmax=mel_fmax,
205+
htk=True)
206+
mel_basis = torch.from_numpy(mel_basis).float()
207+
self.register_buffer("mel_basis", mel_basis)
208+
self.n_fft = win_length if n_fft is None else n_fft
209+
self.hop_length = hop_length
210+
self.win_length = win_length
211+
self.sampling_rate = sampling_rate
212+
self.n_mel_channels = n_mel_channels
213+
self.clamp = clamp
214+
self.is_half=is_half
215+
216+
def forward(self, audio, keyshift=0, speed=1, center=True):
217+
factor = 2 ** (keyshift / 12)
218+
n_fft_new = int(np.round(self.n_fft * factor))
219+
win_length_new = int(np.round(self.win_length * factor))
220+
hop_length_new = int(np.round(self.hop_length * speed))
221+
keyshift_key = str(keyshift) + '_' + str(audio.device)
222+
if keyshift_key not in self.hann_window:
223+
self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(audio.device)
224+
fft = torch.stft(
225+
audio,
226+
n_fft=n_fft_new,
227+
hop_length=hop_length_new,
228+
win_length=win_length_new,
229+
window=self.hann_window[keyshift_key],
230+
center=center,
231+
return_complex=True)
232+
magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
233+
if keyshift != 0:
234+
size = self.n_fft // 2 + 1
235+
resize = magnitude.size(1)
236+
if resize < size:
237+
magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
238+
magnitude = magnitude[:, :size, :]* self.win_length / win_length_new
239+
mel_output = torch.matmul(self.mel_basis, magnitude)
240+
if(self.is_half==True):mel_output=mel_output.half()
241+
log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp))
242+
return log_mel_spec
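# Example invocation (hypothetical tensor; the parameters match the 16 kHz front
# end constructed in RMVPE.__init__ below):
#   mel_extractor = MelSpectrogram(False, 128, 16000, 1024, 160, None, 30, 8000)
#   mel = mel_extractor(torch.randn(1, 16000))  # -> (1, 128, 101) log-mel frames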

class RMVPE:
    def __init__(self, model_path, is_half, device=None):
        self.resample_kernel = {}
        model = E2E(4, 1, (2, 2))
        ckpt = torch.load(model_path, map_location="cpu")
        model.load_state_dict(ckpt)
        model.eval()
        if is_half:
            model = model.half()
        self.model = model
        self.is_half = is_half
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        # 16 kHz input, 1024-sample window, 160-sample hop, 128 mel bins, 30-8000 Hz.
        self.mel_extractor = MelSpectrogram(is_half, 128, 16000, 1024, 160, None, 30, 8000).to(device)
        self.model = self.model.to(device)
        # Bin i corresponds to 1997.3794 + 20*i cents above 10 Hz (~31.7 Hz up to
        # ~2 kHz in 20-cent steps); padded by 4 on each side so the 9-bin window
        # around any peak stays in bounds.
        cents_mapping = 20 * np.arange(360) + 1997.3794084376191
        self.cents_mapping = np.pad(cents_mapping, (4, 4))  # 368

    def mel2hidden(self, mel):
        with torch.no_grad():
            n_frames = mel.shape[-1]
            # Pad the time axis to a multiple of 32 so the five 2x poolings in the
            # U-Net divide evenly, then crop the output back to n_frames.
            mel = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode="reflect")
            hidden = self.model(mel)
            return hidden[:, :n_frames]

    def decode(self, hidden, thred=0.03):
        cents_pred = self.to_local_average_cents(hidden, thred=thred)
        f0 = 10 * (2 ** (cents_pred / 1200))  # cents above 10 Hz -> Hz
        f0[f0 == 10] = 0  # unvoiced frames decode to exactly 10 Hz; zero them out
        return f0

    def infer_from_audio(self, audio, thred=0.03):
        audio = torch.from_numpy(audio).float().to(self.device).unsqueeze(0)
        mel = self.mel_extractor(audio, center=True)
        hidden = self.mel2hidden(mel)
        hidden = hidden.squeeze(0).cpu().numpy()
        if self.is_half:
            hidden = hidden.astype("float32")
        f0 = self.decode(hidden, thred=thred)
        return f0

    def to_local_average_cents(self, salience, thred=0.05):
        # salience: (n_frames, 360). Take the argmax bin per frame, then refine it
        # with a salience-weighted average of the cents values in the 9-bin window
        # around the peak.
        center = np.argmax(salience, axis=1)  # (n_frames,) peak bin per frame
        salience = np.pad(salience, ((0, 0), (4, 4)))  # (n_frames, 368)
        center += 4
        todo_salience = []
        todo_cents_mapping = []
        starts = center - 4
        ends = center + 5
        for idx in range(salience.shape[0]):
            todo_salience.append(salience[idx, starts[idx]:ends[idx]])
            todo_cents_mapping.append(self.cents_mapping[starts[idx]:ends[idx]])
        todo_salience = np.array(todo_salience)  # (n_frames, 9)
        todo_cents_mapping = np.array(todo_cents_mapping)  # (n_frames, 9)
        product_sum = np.sum(todo_salience * todo_cents_mapping, 1)
        weight_sum = np.sum(todo_salience, 1)  # (n_frames,)
        divided = product_sum / weight_sum  # (n_frames,) weighted-average cents
        maxx = np.max(salience, axis=1)  # (n_frames,)
        divided[maxx <= thred] = 0  # frames whose peak salience is too low are unvoiced
        return divided

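# Worked example of the refinement (illustrative numbers): if a frame peaks at
# bin 100, the 9-bin window covers bins 96..104, i.e. cents 3917.4 .. 4077.4 in
# 20-cent steps. A salience profile symmetric around the peak averages to
# 3997.4 cents, so f0 = 10 * 2**(3997.4 / 1200) ≈ 100.6 Hz; any asymmetry shifts
# the estimate by a fraction of the 20-cent bin width.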

# Offline smoke test kept from development (commented out):
# if __name__ == '__main__':
#     audio, sampling_rate = sf.read("卢本伟语录~1.wav")
#     if len(audio.shape) > 1:
#         audio = librosa.to_mono(audio.transpose(1, 0))
#     audio_bak = audio.copy()
#     if sampling_rate != 16000:
#         audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000)
#     model_path = "/bili-coeus/jupyter/jupyterhub-liujing04/vits_ch/test-RMVPE/weights/rmvpe_llc_half.pt"
#     thred = 0.03  # 0.01
#     device = 'cuda' if torch.cuda.is_available() else 'cpu'
#     rmvpe = RMVPE(model_path, is_half=False, device=device)
#     t0 = ttime()
#     f0 = rmvpe.infer_from_audio(audio, thred=thred)
#     f0 = rmvpe.infer_from_audio(audio, thred=thred)
#     f0 = rmvpe.infer_from_audio(audio, thred=thred)
#     f0 = rmvpe.infer_from_audio(audio, thred=thred)
#     f0 = rmvpe.infer_from_audio(audio, thred=thred)
#     t1 = ttime()
#     print(f0.shape, t1 - t0)
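
A minimal usage sketch for the class above (assumptions: a trained checkpoint
saved as rmvpe.pt in the working directory, and "input.wav" as a stand-in for
your audio; the model expects 16 kHz mono input, which librosa handles here):

    import librosa
    from rmvpe import RMVPE

    audio, sr = librosa.load("input.wav", sr=16000, mono=True)
    rmvpe = RMVPE("rmvpe.pt", is_half=False, device=None)  # device auto-selected
    f0 = rmvpe.infer_from_audio(audio, thred=0.03)  # Hz per 10 ms frame; 0 = unvoiced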

vc_infer_pipeline.py

Lines changed: 9 additions & 1 deletion
@@ -1,10 +1,12 @@
-import numpy as np, parselmouth, torch, pdb
+import numpy as np, parselmouth, torch, pdb, sys, os
 from time import time as ttime
 import torch.nn.functional as F
 import scipy.signal as signal
 import pyworld, os, traceback, faiss, librosa, torchcrepe
 from scipy import signal
 from functools import lru_cache
+now_dir = os.getcwd()
+sys.path.append(now_dir)

 bh, ah = signal.butter(N=5, Wn=48, btype="high", fs=16000)

@@ -124,6 +126,12 @@ def get_f0(
             f0 = torchcrepe.filter.mean(f0, 3)
             f0[pd < 0.1] = 0
             f0 = f0[0].cpu().numpy()
+        elif f0_method == "rmvpe":
+            if not hasattr(self, "model_rmvpe"):
+                from rmvpe import RMVPE
+                print("loading rmvpe model")
+                self.model_rmvpe = RMVPE("rmvpe.pt", is_half=self.is_half, device=self.device)
+            f0 = self.model_rmvpe.infer_from_audio(x, thred=0.03)
         f0 *= pow(2, f0_up_key / 12)
         # with open("test.txt","w")as f:f.write("\n".join([str(i)for i in f0.tolist()]))
         tf0 = self.sr // self.window  # f0 points per second
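
Two details worth noting in this hunk. First, the model is constructed lazily on
first use and cached on self, so the rmvpe.pt checkpoint (resolved relative to
the working directory) is loaded only once per session; the sys.path.append in
the first hunk is what lets `from rmvpe import RMVPE` find the new module at the
repo root. Second, the transposition line that follows the new branch applies to
every f0 method's output. A quick illustration of that formula (values are
illustrative):

    f0_up_key = 12                        # +12 semitones = one octave up
    f0 = 220.0 * pow(2, f0_up_key / 12)   # 220 Hz (A3) -> 440 Hz (A4)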
