import os

import numpy as np
import torch
from torch.utils import data

from data_preprocess import serialized_test_folder, serialized_train_folder


def emphasis(signal_batch, emph_coeff=0.95, pre=True):
    """
    Pre-emphasis or de-emphasis of higher frequencies given a batch of signals.

    Args:
        signal_batch: batch of signals, represented as numpy arrays
        emph_coeff: emphasis coefficient
        pre: pre-emphasis if True, de-emphasis if False

    Returns:
        result: pre-emphasized or de-emphasized signal batch
    """
    result = np.zeros(signal_batch.shape)
    for sample_idx, sample in enumerate(signal_batch):
        for ch, channel_data in enumerate(sample):
            if pre:
                result[sample_idx][ch] = np.append(channel_data[0], channel_data[1:] - emph_coeff * channel_data[:-1])
            else:
                result[sample_idx][ch] = np.append(channel_data[0], channel_data[1:] + emph_coeff * channel_data[:-1])
    return result

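
# A minimal usage sketch (an illustrative addition, not part of the original module):
# applying pre-emphasis to a toy batch shaped (batch, channels, samples). The shape,
# length and coefficient shown are arbitrary assumptions for demonstration.
#
#     batch = np.random.randn(4, 2, 16384)           # hypothetical batch of signals
#     emphasized = emphasis(batch, emph_coeff=0.95)   # boost higher frequencies
#     restored = emphasis(emphasized, pre=False)      # de-emphasis as implemented above
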
class AudioDataset(data.Dataset):
    """
    Audio sample reader.
    """

    def __init__(self, data_type):
        if data_type == 'train':
            data_path = serialized_train_folder
        else:
            data_path = serialized_test_folder
        if not os.path.exists(data_path):
            raise FileNotFoundError('The {} data folder does not exist!'.format(data_type))

        self.data_type = data_type
        self.file_names = [os.path.join(data_path, filename) for filename in os.listdir(data_path)]

    def reference_batch(self, batch_size):
        """
        Randomly selects a reference batch from the dataset.
        The reference batch is used for calculating statistics for the virtual batch normalization operation.

        Args:
            batch_size(int): batch size

        Returns:
            ref_batch: reference batch
        """
        ref_file_names = np.random.choice(self.file_names, batch_size)
        ref_batch = np.stack([np.load(f) for f in ref_file_names])

        ref_batch = emphasis(ref_batch, emph_coeff=0.95)
        return torch.from_numpy(ref_batch).type(torch.FloatTensor)

    def __getitem__(self, idx):
        pair = np.load(self.file_names[idx])
        pair = emphasis(pair[np.newaxis, :, :], emph_coeff=0.95).reshape(2, -1)
        noisy = pair[1].reshape(1, -1)
        if self.data_type == 'train':
            clean = pair[0].reshape(1, -1)
            return torch.from_numpy(pair).type(torch.FloatTensor), torch.from_numpy(clean).type(
                torch.FloatTensor), torch.from_numpy(noisy).type(torch.FloatTensor)
        else:
            return os.path.basename(self.file_names[idx]), torch.from_numpy(noisy).type(torch.FloatTensor)

    def __len__(self):
        return len(self.file_names)
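

# A minimal usage sketch (an illustrative addition, not part of the original module):
# wiring AudioDataset into a torch DataLoader for training. The batch size below is an
# arbitrary assumption, and this only runs if the serialized folders produced by
# data_preprocess.py already exist.
if __name__ == '__main__':
    train_dataset = AudioDataset(data_type='train')
    train_loader = data.DataLoader(train_dataset, batch_size=50, shuffle=True)

    # Reference batch used to compute statistics for virtual batch normalization.
    ref_batch = train_dataset.reference_batch(batch_size=50)
    print('reference batch shape:', ref_batch.shape)

    # Each training item is (pair, clean, noisy): per-sample FloatTensors of shape
    # (2, signal_length), (1, signal_length) and (1, signal_length), stacked along
    # the batch dimension by the DataLoader.
    for pair, clean, noisy in train_loader:
        print(pair.shape, clean.shape, noisy.shape)
        break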