Skip to content

Commit 709bbba

Browse files
authored
optimize nsf inference (#2387)
1 parent 1376ce7 commit 709bbba

File tree

1 file changed

+21
-40
lines changed

1 file changed

+21
-40
lines changed

infer/lib/infer_pack/models.py

+21-40
Original file line number | Diff line number | Diff line change
@@ -349,7 +349,25 @@ def _f02uv(self, f0):
349349
if uv.device.type == "privateuseone": # for DirectML
350350
uv = uv.float()
351351
return uv
352-
352+
353+
def _f02sine(self, f0, upp):
    """Turn frame-level F0 into sample-level sine waves.

    Args:
        f0: Tensor of shape (batchsize, length, 1) holding one F0 value
            per frame (``forward`` passes ``f0.unsqueeze(-1)``).
        upp: Number of output samples generated for every input frame.

    Returns:
        Tensor of shape (batchsize, length * upp, self.dim). Channel 0 is
        the fundamental; channel k is the (k + 1)-th harmonic with a random
        initial phase.
    """
    # Phase (in cycles) of each of the `upp` samples of a frame, counted
    # from the start of that frame: (batchsize, length, upp).
    a = torch.arange(1, upp + 1, dtype=f0.dtype, device=f0.device)
    rad = f0 / self.sampling_rate * a

    # Phase reached at the end of every frame except the last. It is wrapped
    # around zero and accumulated in float32 so that the running sum stays
    # small even when f0 is a lower-precision dtype.
    rad2 = torch.fmod(rad[:, :-1, -1:].float() + 0.5, 1.0) - 0.5
    rad_acc = rad2.cumsum(dim=1).fmod(1.0).to(f0)

    # Frame i starts where frame i - 1 ended; the first frame starts at 0.
    rad += F.pad(rad_acc, (0, 0, 1, 0), mode="constant")

    # Flatten frames into one sample axis: (batchsize, length * upp, 1).
    rad = rad.reshape(f0.shape[0], -1, 1)

    # Harmonic multipliers 1..dim. The product broadcasts the trailing axis
    # from 1 to self.dim, so it must NOT be done in place: an in-place
    # `rad *= b` raises a RuntimeError whenever self.dim > 1.
    b = torch.arange(
        1, self.dim + 1, dtype=f0.dtype, device=f0.device
    ).reshape(1, 1, -1)
    rad = rad * b

    # Random initial phase for every overtone; the fundamental keeps 0.
    rand_ini = torch.rand(1, 1, self.dim, device=f0.device)
    rand_ini[..., 0] = 0
    # Cast so the result keeps rad's dtype, as the in-place add used to.
    rad = rad + rand_ini.to(rad.dtype)

    return torch.sin(2 * np.pi * rad)
370+
353371
def forward(self, f0: torch.Tensor, upp: int):
354372
"""sine_tensor, uv = forward(f0)
355373
input F0: tensor(batchsize=1, length, dim=1)
@@ -358,45 +376,8 @@ def forward(self, f0: torch.Tensor, upp: int):
358376
output uv: tensor(batchsize=1, length, 1)
359377
"""
360378
with torch.no_grad():
361-
f0 = f0[:, None].transpose(1, 2)
362-
f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
363-
# fundamental component
364-
f0_buf[:, :, 0] = f0[:, :, 0]
365-
for idx in range(self.harmonic_num):
366-
f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (
367-
idx + 2
368-
) # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
369-
rad_values = (
370-
f0_buf / self.sampling_rate
371-
) % 1 ###%1意味着n_har的乘积无法后处理优化
372-
rand_ini = torch.rand(
373-
f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device
374-
)
375-
rand_ini[:, 0] = 0
376-
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
377-
tmp_over_one = torch.cumsum(
378-
rad_values, 1
379-
) # % 1 #####%1意味着后面的cumsum无法再优化
380-
tmp_over_one *= upp
381-
tmp_over_one = F.interpolate(
382-
tmp_over_one.transpose(2, 1),
383-
scale_factor=float(upp),
384-
mode="linear",
385-
align_corners=True,
386-
).transpose(2, 1)
387-
rad_values = F.interpolate(
388-
rad_values.transpose(2, 1), scale_factor=float(upp), mode="nearest"
389-
).transpose(
390-
2, 1
391-
) #######
392-
tmp_over_one %= 1
393-
tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
394-
cumsum_shift = torch.zeros_like(rad_values)
395-
cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
396-
sine_waves = torch.sin(
397-
torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * torch.pi
398-
)
399-
sine_waves = sine_waves * self.sine_amp
379+
f0 = f0.unsqueeze(-1)
380+
sine_waves = self._f02sine(f0, upp) * self.sine_amp
400381
uv = self._f02uv(f0)
401382
uv = F.interpolate(
402383
uv.transpose(2, 1), scale_factor=float(upp), mode="nearest"

0 commit comments

Comments
 (0)