import math
import warnings

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data as tdata
from torch.nn.functional import one_hot
from h5py import File

import librosa
import soundfile as sf

# Silence library warnings (librosa is fairly noisy during batch loading).
warnings.filterwarnings('ignore')


class AudioDataset(tdata.Dataset):
    """
    Raw-audio dataset indexed by a plain-text file list.

    Each line of `data_list_path` holds a wav path and an integer class
    index separated by whitespace. Waveforms are loaded, resampled and
    converted to log-mel spectrograms on the fly, which saves storage when
    oversampling is needed (pretty likely).
    """
    def __init__(self, data_list_path, num_classes=10, transform=None,
                 n_mels=64, n_fft=2048, hop_length=380, win_length=512,
                 sr=16000, EPS=np.spacing(1), fre_len=100, mode='train'):
        super().__init__()
        self.MEL_ARGS = {
            'n_mels': n_mels,
            'n_fft': n_fft,
            'hop_length': hop_length,
            'win_length': win_length
        }
        self.dataset_path = []
        self.labels = []
        self.num_classes = num_classes
        # Each line of the list file: "<wav_path> <integer_label>".
        with open(data_list_path) as f:
            for line in f:
                path, label = line.split()[:2]
                self.dataset_path.append(path)
                # Parse the class index as an integer and one-hot encode it.
                self.labels.append(one_hot(torch.tensor(int(label)),
                                           num_classes=self.num_classes))
        self.sr = sr
        self.EPS = EPS
        self.fre_len = fre_len
        self._transform = transform
    def __len__(self):
        return len(self.dataset_path)

    def __getitem__(self, index):
        wavpath = self.dataset_path[index]
        y, sr = sf.read(wavpath, dtype='float32')
        if y.ndim > 1:
            # Down-mix multi-channel audio to mono.
            y = y.mean(1)
        if sr != self.sr:
            # librosa >= 0.10 requires keyword arguments here.
            y = librosa.resample(y, orig_sr=sr, target_sr=self.sr)
        label = self.labels[index]
        # Log-mel spectrogram; EPS keeps log() away from zero.
        mel = librosa.feature.melspectrogram(y=y, sr=self.sr, **self.MEL_ARGS)
        data = torch.tensor(np.log(mel + self.EPS)).unsqueeze(dim=0)
        if data.shape[2] < self.fre_len:
            # Tile short clips along time until they cover fre_len frames.
            num = self.fre_len // data.shape[2] + 1
            data = data.repeat(1, 1, num)
        # Crop to exactly fre_len frames so every item has a fixed width.
        data = data[:, :, :self.fre_len]
        if self._transform:
            data = self._transform(data)
        return data, label, wavpath
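
# A minimal usage sketch (not from the original code): because every item is
# tiled/cropped to exactly fre_len frames, AudioDataset plugs straight into a
# standard DataLoader. The list-file path below is only an example.
#
#   loader = tdata.DataLoader(AudioDataset(data_list_path="dataset/train_list.txt"),
#                             batch_size=32, shuffle=True)
#   for data, label, path in loader:
#       # data: (batch, 1, n_mels, fre_len); label: (batch, num_classes)
#       ...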


class AudioDataset_test(tdata.Dataset):
    """
    Single-file inference dataset.

    Loads one wav file (`audio_path`), converts it to a log-mel spectrogram
    and zero-pads the time axis up to a multiple of `fre_len`, so the clip
    can be split into fixed-size chunks for evaluation.
    """
    def __init__(self, data_list_path, num_classes=10, transform=None,
                 n_mels=64, n_fft=2048, hop_length=380, win_length=512,
                 sr=16000, EPS=np.spacing(1), fre_len=100,
                 audio_path="example_audio/7383-3-0-1.wav"):
        super().__init__()
        # data_list_path is unused here; it is kept for API symmetry with
        # AudioDataset.
        self.MEL_ARGS = {
            'n_mels': n_mels,
            'n_fft': n_fft,
            'hop_length': hop_length,
            'win_length': win_length
        }
        self.dataset_path = [audio_path]
        self.num_classes = num_classes
        self.sr = sr
        self.EPS = EPS
        self.fre_len = fre_len
        self._transform = transform
    def __len__(self):
        return len(self.dataset_path)

    def __getitem__(self, index):
        wavpath = self.dataset_path[index]
        y, sr = sf.read(wavpath, dtype='float32')
        if y.ndim > 1:
            y = y.mean(1)
        if sr != self.sr:
            y = librosa.resample(y, orig_sr=sr, target_sr=self.sr)
        mel = librosa.feature.melspectrogram(y=y, sr=self.sr, **self.MEL_ARGS)
        data = torch.tensor(np.log(mel + self.EPS)).unsqueeze(dim=0)
        # Number of fre_len-sized chunks, computed unconditionally so a clip
        # whose length is an exact multiple of fre_len does not report 0.
        num = math.ceil(data.shape[2] / self.fre_len)
        if data.shape[2] % self.fre_len != 0:
            # Zero-pad the time axis on the right up to num * fre_len frames.
            padding = self.fre_len * num - data.shape[2]
            data = F.pad(data, (0, padding))
        if self._transform:
            data = self._transform(data)
        return data, wavpath, num
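
# A hedged sketch of consuming the padded output: after padding, the time axis
# is an exact multiple of fre_len, so the clip splits cleanly into `num`
# chunks that can be scored one at a time.
#
#   dataset = AudioDataset_test(data_list_path=None)  # uses the default audio_path
#   data, wavpath, num = dataset[0]                   # data: (1, n_mels, num * fre_len)
#   chunks = data.split(dataset.fre_len, dim=2)       # `num` chunks of (1, n_mels, fre_len)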


class AudioDataset_Feature(tdata.Dataset):
    """
    HDF5-backed dataset of precomputed features, indexed by a file list.

    Features are extracted once and read back from a single HDF5 file, which
    saves storage when clips have to be oversampled (pretty likely).
    """
    def __init__(self, data_list_path, feature_map_list, num_classes=10,
                 transform=None, mode='train'):
        super().__init__()
        self.dataset_path = []
        self.num_classes = num_classes
        self._h5file = feature_map_list
        with open(data_list_path) as f:
            for line in f:
                self.dataset_path.append(line.split()[0])
        self.dataset = File(self._h5file, 'r', libver='latest')
        self._transform = transform
    def __len__(self):
        return len(self.dataset_path)

    def __getitem__(self, index):
        wavpath = self.dataset_path[index]
        # HDF5 keys are the last four path components joined by underscores,
        # with "_data" / "_label" suffixes.
        path_name = "_".join(wavpath.split("/")[-4:])
        data = self.dataset[path_name + "_data"][()]
        label = self.dataset[path_name + "_label"][()]
        if self._transform:
            data = self._transform(data)
        return torch.tensor(data), torch.tensor(label), wavpath
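
# The "<key>_data" / "<key>_label" layout assumed above is not produced in
# this file. A minimal writer sketch under that assumption, using AudioDataset
# to extract the features:
#
#   with File("features/Urbansound8K_train.h5", "w") as h5:
#       for data, label, wavpath in AudioDataset("dataset/train_list.txt"):
#           key = "_".join(wavpath.split("/")[-4:])
#           h5[key + "_data"] = data.numpy()
#           h5[key + "_label"] = label.numpy()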


if __name__ == "__main__":
    dataset = AudioDataset(data_list_path="dataset/test_list.txt")
    for data, label, path in dataset:
        print(data.shape)
        print(label)
        print(type(data))
        print(type(label))
        print(path)
        break
    dataset = AudioDataset_Feature(data_list_path="dataset/train_list.txt",
                                   feature_map_list="features/Urbansound8K_train.h5")
    for data, label, path in dataset:
        print(data)
        print(label)
        print(type(data))
        print(type(label))
        print(path)
        break