-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathdatasets.py
38 lines (29 loc) · 1.02 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
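'''
@Author: Gordon Lee
@Date: 2019-08-09 14:22:50
@LastEditors: Gordon Lee
@LastEditTime: 2019-08-13 16:40:09
@Description: Dataset class that reads one preprocessed CSV split and returns (token-index tensor, sentiment label) pairs.
'''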
import ast

import pandas as pd
import torch
from torch.utils.data import Dataset
class SSTreebankDataset(Dataset):
    '''
    Map-style dataset backed by one preprocessed CSV split, intended to be
    wrapped in a DataLoader.

    The CSV is read once in the constructor and kept in memory. Each row must
    provide two columns:

    * ``token_idx``: the string form of a list of integer token indices
    * ``sentiment_label``: the label attached to that sentence
    '''

    def __init__(self, data_name, output_folder, split):
        '''
        :param data_name: file-name prefix of the CSV files
        :param output_folder: path of the folder holding the data files; it is
            concatenated verbatim with the file name, so it must already end
            with a path separator
        :param split: 'train', 'dev', or 'test'
        '''
        self.split = split
        assert self.split in {'train', 'dev', 'test'}, \
            "split must be 'train', 'dev', or 'test'"

        # Resulting path: <output_folder><data_name>_<split>.csv
        self.dataset = pd.read_csv(output_folder + data_name + '_' + split + '.csv')
        self.dataset_size = len(self.dataset)

    def __getitem__(self, i):
        '''
        Return the i-th sample.

        :param i: positional row index into the loaded CSV
        :return: tuple ``(sentence, sentence_label)`` where ``sentence`` is a
            LongTensor of token indices and ``sentence_label`` is the raw value
            of the ``sentiment_label`` column
        :raises ValueError, SyntaxError: if ``token_idx`` is not a plain Python
            literal
        '''
        # Fetch the row once instead of materialising it for every column.
        row = self.dataset.iloc[i]

        # SECURITY: the original code used eval() here, which executes any
        # expression stored in the data file. literal_eval only accepts
        # literals (lists, numbers, ...), which is all a token list needs.
        token_ids = ast.literal_eval(row['token_idx'])

        # Original author's note: sentence shape [max_len] (the padding itself
        # is presumably done when the CSV is generated; not enforced here).
        sentence = torch.LongTensor(token_ids)
        sentence_label = row['sentiment_label']

        return sentence, sentence_label

    def __len__(self):
        '''
        :return: number of rows in the loaded split
        '''
        return self.dataset_size