HyperPartDataLoader.py
import numpy as np
import torch
from torch.utils.data import Dataset

MAX_SENTENCE_LENGTH = 19650


class HyperPartGroupDataset(Dataset):
    """
    Class that represents a train/validation/test dataset that's readable for PyTorch.
    Note that this class inherits torch.utils.data.Dataset.
    """

    def __init__(self, data_list, target_list):
        """
        @param data_list: list of newsgroup tokens
        @param target_list: list of newsgroup targets
        """
        self.data_list = data_list
        self.target_list = target_list
        assert len(self.data_list) == len(self.target_list)

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, key):
        """
        Triggered when you call dataset[i]
        """
        token_idx = self.data_list[key][:MAX_SENTENCE_LENGTH]
        label = self.target_list[key]
        return [token_idx, len(token_idx), label]


def hype_collate_func(batch):
    """
    Customized function for DataLoader that pads the batch so that all
    data have the same length (MAX_SENTENCE_LENGTH).
    """
    data_list = []
    label_list = []
    length_list = []
    for datum in batch:
        label_list.append(datum[2])
        length_list.append(datum[1])
    # Pad each token sequence with zeros up to MAX_SENTENCE_LENGTH.
    for datum in batch:
        padded_vec = np.pad(np.array(datum[0]),
                            pad_width=(0, MAX_SENTENCE_LENGTH - datum[1]),
                            mode="constant", constant_values=0)
        data_list.append(padded_vec)
    return [torch.from_numpy(np.array(data_list)),
            torch.LongTensor(length_list),
            torch.LongTensor(label_list)]
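A minimal usage sketch, not part of the original file: it wires HyperPartGroupDataset and hype_collate_func into a standard torch.utils.data.DataLoader. The toy token lists, labels, and batch size below are hypothetical placeholders, assuming the dataset receives lists of integer token indices and integer class targets.

# Hypothetical usage sketch (assumed data, not from the original repo).
from torch.utils.data import DataLoader

train_tokens = [[4, 8, 15], [16, 23]]   # toy token-index lists (hypothetical)
train_targets = [0, 1]                  # toy labels (hypothetical)

train_dataset = HyperPartGroupDataset(train_tokens, train_targets)
train_loader = DataLoader(train_dataset,
                          batch_size=2,               # hypothetical batch size
                          shuffle=True,
                          collate_fn=hype_collate_func)

for data_batch, lengths, labels in train_loader:
    # data_batch: (batch_size, MAX_SENTENCE_LENGTH) zero-padded token indices
    # lengths:    original (unpadded) sequence lengths
    # labels:     newsgroup targets
    print(data_batch.shape, lengths, labels)
    break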