|
| 1 | +############################## Warning! ############################## |
| 2 | +# # |
| 3 | +# ONNX Export Does Not Support Any Non-Torch Types, # |
| 4 | +# Including Python Built-in Types! # |
| 5 | +# If You Want To Change This File, # |
| 6 | +# Do Not Use Any Non-Torch Types! # |
| 7 | +# # |
| 8 | +############################## Warning! ############################## |
| 9 | + |
1 | 10 | import math
|
2 | 11 | import logging
|
3 | 12 |
|
@@ -316,58 +325,41 @@ def _f02uv(self, f0):
|
316 | 325 | # generate uv signal
|
317 | 326 | uv = torch.ones_like(f0)
|
318 | 327 | uv = uv * (f0 > self.voiced_threshold)
|
| 328 | + if uv.device.type == "privateuseone": # for DirectML |
| 329 | + uv = uv.float() |
319 | 330 | return uv
|
320 |
| - |
321 |
| - def forward(self, f0, upp): |
| 331 | + |
def _f02sine(self, f0, upp):
    """Turn frame-level F0 into sample-level sine waves for every harmonic.

    Args:
        f0: F0 tensor laid out as (batchsize, length, 1); ``forward`` passes
            ``f0.unsqueeze(-1)``. The last axis must broadcast against the
            ``upp`` per-frame sample offsets.
        upp: number of output samples generated per input frame.

    Returns:
        Tensor of shape (batchsize, length * upp, self.dim) holding
        ``sin(2 * pi * phase)``; channel ``k`` carries the (k + 1)-th
        harmonic. Channel 0 (the fundamental) has zero initial phase, the
        other channels get a random initial phase.
    """
    # Phase advance (in cycles) at each of the `upp` samples inside a frame.
    sample_idx = torch.arange(1, upp + 1, dtype=f0.dtype, device=f0.device)
    rad = f0 / self.sampling_rate * sample_idx

    # Phase reached at the end of every frame but the last, wrapped into
    # [-0.5, 0.5) and accumulated in float32 so the running sum stays small.
    frame_end = torch.fmod(rad[:, :-1, -1:].float() + 0.5, 1.0) - 0.5
    carried = frame_end.cumsum(dim=1).fmod(1.0).to(f0)

    # Frame i starts from the phase accumulated over frames 0..i-1; the
    # first frame starts from zero, hence the single leading pad entry.
    rad = rad + F.pad(carried, (0, 0, 1, 0), mode='constant')
    rad = rad.reshape(f0.shape[0], -1, 1)

    # Scale the fundamental phase to each harmonic. This must be out-of-place:
    # an in-place `*=` cannot broadcast (B, N, 1) up to (B, N, dim) and
    # raises RuntimeError as soon as self.dim > 1.
    harmonics = torch.arange(
        1, self.dim + 1, dtype=f0.dtype, device=f0.device
    ).reshape(1, 1, -1)
    rad = rad * harmonics

    # Random initial phase for the overtones only; keep the working dtype.
    rand_ini = torch.rand(1, 1, self.dim, device=f0.device)
    rand_ini[..., 0] = 0
    rad = rad + rand_ini.to(rad.dtype)

    sines = torch.sin(2 * np.pi * rad)
    return sines
| 349 | + |
| 350 | + def forward(self, f0: torch.Tensor, upp: int): |
322 | 351 | """sine_tensor, uv = forward(f0)
|
323 | 352 | input F0: tensor(batchsize=1, length, dim=1)
|
324 | 353 | f0 for unvoiced steps should be 0
|
325 | 354 | output sine_tensor: tensor(batchsize=1, length, dim)
|
326 | 355 | output uv: tensor(batchsize=1, length, 1)
|
327 | 356 | """
|
328 | 357 | with torch.no_grad():
|
329 |
| - f0 = f0[:, None].transpose(1, 2) |
330 |
| - f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device) |
331 |
| - # fundamental component |
332 |
| - f0_buf[:, :, 0] = f0[:, :, 0] |
333 |
| - for idx in np.arange(self.harmonic_num): |
334 |
| - f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * ( |
335 |
| - idx + 2 |
336 |
| - ) # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic |
337 |
| - rad_values = ( |
338 |
| - f0_buf / self.sampling_rate |
339 |
| - ) % 1 ###%1意味着n_har的乘积无法后处理优化 |
340 |
| - rand_ini = torch.rand( |
341 |
| - f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device |
342 |
| - ) |
343 |
| - rand_ini[:, 0] = 0 |
344 |
| - rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini |
345 |
| - tmp_over_one = torch.cumsum( |
346 |
| - rad_values, 1 |
347 |
| - ) # % 1 #####%1意味着后面的cumsum无法再优化 |
348 |
| - tmp_over_one *= upp |
349 |
| - tmp_over_one = F.interpolate( |
350 |
| - tmp_over_one.transpose(2, 1), |
351 |
| - scale_factor=upp, |
352 |
| - mode="linear", |
353 |
| - align_corners=True, |
354 |
| - ).transpose(2, 1) |
355 |
| - rad_values = F.interpolate( |
356 |
| - rad_values.transpose(2, 1), scale_factor=upp, mode="nearest" |
357 |
| - ).transpose( |
358 |
| - 2, 1 |
359 |
| - ) ####### |
360 |
| - tmp_over_one %= 1 |
361 |
| - tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0 |
362 |
| - cumsum_shift = torch.zeros_like(rad_values) |
363 |
| - cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 |
364 |
| - sine_waves = torch.sin( |
365 |
| - torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi |
366 |
| - ) |
367 |
| - sine_waves = sine_waves * self.sine_amp |
| 358 | + f0 = f0.unsqueeze(-1) |
| 359 | + sine_waves = self._f02sine(f0, upp) * self.sine_amp |
368 | 360 | uv = self._f02uv(f0)
|
369 | 361 | uv = F.interpolate(
|
370 |
| - uv.transpose(2, 1), scale_factor=upp, mode="nearest" |
| 362 | + uv.transpose(2, 1), scale_factor=float(upp), mode="nearest" |
371 | 363 | ).transpose(2, 1)
|
372 | 364 | noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
|
373 | 365 | noise = noise_amp * torch.randn_like(sine_waves)
|
|
0 commit comments