1
+ import os
2
+ from copy import deepcopy
3
+ from tqdm import tqdm
4
+ import torch
5
+ import torch .utils .data as data
6
+ from tree import Tree
7
+ from vocab import Vocab
8
+ import Constants
9
+ import utils
10
+
11
+ # Dataset class for SICK dataset
12
class SICKDataset(data.Dataset):
    """Sentence-pair dataset loaded from a directory of SICK-format files.

    The directory ``path`` must contain:

    * ``a.toks`` / ``b.toks``       -- one tokenised sentence per line,
    * ``a.parents`` / ``b.parents`` -- one parent-pointer tree per line,
    * ``sim.txt``                   -- one float label per line.

    Each item is the tuple ``(ltree, lsent, rtree, rsent, label)``.
    """

    def __init__(self, path, vocab, num_classes):
        """Read every file under ``path`` eagerly into memory.

        Args:
            path: directory holding the files listed in the class docstring.
            vocab: object exposing ``convertToIdx(tokens, unk_word)``.
            num_classes: stored on the instance; not used by this class.
        """
        super(SICKDataset, self).__init__()
        self.vocab = vocab
        self.num_classes = num_classes

        self.lsentences = self.read_sentences(os.path.join(path, 'a.toks'))
        self.rsentences = self.read_sentences(os.path.join(path, 'b.toks'))

        self.ltrees = self.read_trees(os.path.join(path, 'a.parents'))
        self.rtrees = self.read_trees(os.path.join(path, 'b.parents'))

        self.labels = self.read_labels(os.path.join(path, 'sim.txt'))

        # The number of labels defines the dataset length.
        self.size = self.labels.size(0)

    def __len__(self):
        """Return the number of labelled pairs."""
        return self.size

    def __getitem__(self, index):
        """Return deep copies of the ``index``-th pair and its label.

        Copies are handed out so callers can mutate the trees/tensors
        without touching the data cached on the dataset.
        """
        ltree = deepcopy(self.ltrees[index])
        rtree = deepcopy(self.rtrees[index])
        lsent = deepcopy(self.lsentences[index])
        rsent = deepcopy(self.rsentences[index])
        label = deepcopy(self.labels[index])
        return (ltree, lsent, rtree, rsent, label)

    def read_sentences(self, filename):
        """Return one ``LongTensor`` of vocabulary indices per line of ``filename``."""
        with open(filename, 'r') as f:
            sentences = [self.read_sentence(line) for line in tqdm(f.readlines())]
        return sentences

    def read_sentence(self, line):
        """Convert a whitespace-tokenised line into a ``LongTensor`` of indices.

        ``Constants.UNK_WORD`` is handed to the vocabulary as the unknown-word
        argument of ``convertToIdx``.
        """
        indices = self.vocab.convertToIdx(line.split(), Constants.UNK_WORD)
        return torch.LongTensor(indices)

    def read_trees(self, filename):
        """Return one tree root per line of ``filename`` (see ``read_tree``)."""
        with open(filename, 'r') as f:
            trees = [self.read_tree(line) for line in tqdm(f.readlines())]
        return trees

    def read_tree(self, line):
        """Build a ``Tree`` from a line of whitespace-separated parent pointers.

        Entry ``k`` (1-based) holds the 1-based position of node ``k``'s
        parent; ``0`` marks the root and ``-1`` marks a position that gets
        no node.  Nodes are stored with a 0-based ``idx``.

        Returns:
            The root ``Tree``, or ``None`` when no entry has parent ``0``.
        """
        # A real list is required: the values are indexed and measured below,
        # which a Python 3 ``map`` iterator does not support.
        parents = [int(token) for token in line.split()]
        trees = {}
        root = None
        for i in range(1, len(parents) + 1):
            # Skip nodes already created while climbing from an earlier
            # position, and positions explicitly marked as absent.
            if i - 1 in trees or parents[i - 1] == -1:
                continue
            idx = i
            prev = None
            # Climb towards the root, creating nodes until reaching either an
            # existing ancestor or the root itself.
            while True:
                parent = parents[idx - 1]
                if parent == -1:
                    break
                tree = Tree()
                if prev is not None:
                    tree.add_child(prev)
                trees[idx - 1] = tree
                tree.idx = idx - 1
                if parent - 1 in trees:
                    # Ancestor already built: attach and stop climbing.
                    trees[parent - 1].add_child(tree)
                    break
                elif parent == 0:
                    root = tree
                    break
                else:
                    prev = tree
                    idx = parent
        return root

    def read_labels(self, filename):
        """Return the labels in ``filename`` (one float per line) as a tensor."""
        with open(filename, 'r') as f:
            # Materialise a list; ``torch.Tensor`` cannot consume the lazy
            # ``map`` object that Python 3 would produce here.
            labels = [float(line) for line in f]
        return torch.Tensor(labels)
88
+
89
+ # Dataset class for SST dataset
90
class SSTDataset(data.Dataset):
    """Single-sentence dataset with a label on every tree node.

    The directory ``path`` must contain ``sents.toks`` plus a parents/labels
    file pair: ``dparents.txt`` / ``dlabels.txt`` when ``model_name`` is
    ``"dependency"``, otherwise ``parents.txt`` / ``labels.txt``.

    Each item is the tuple ``(tree, sent, label)`` where ``label`` is the
    ``gold_label`` of the tree's root.
    """

    def __init__(self, path, vocab, num_classes, fine_grain, model_name):
        """Read sentences and trees under ``path`` eagerly into memory.

        Args:
            path: directory holding the files listed in the class docstring.
            vocab: object exposing ``convertToIdx(tokens, unk_word)``.
            num_classes: stored on the instance; not used by this class.
            fine_grain: truthy keeps the five-way labels and every example;
                falsy collapses labels to three classes and drops examples
                whose root label is neutral (``1``).
            model_name: ``"dependency"`` selects the ``d``-prefixed files.
        """
        super(SSTDataset, self).__init__()
        self.vocab = vocab
        self.num_classes = num_classes
        self.fine_grain = fine_grain
        self.model_name = model_name

        temp_sentences = self.read_sentences(os.path.join(path, 'sents.toks'))
        if model_name == "dependency":
            temp_trees = self.read_trees(os.path.join(path, 'dparents.txt'),
                                         os.path.join(path, 'dlabels.txt'))
        else:
            temp_trees = self.read_trees(os.path.join(path, 'parents.txt'),
                                         os.path.join(path, 'labels.txt'))

        if not self.fine_grain:
            # Keep only positive/negative examples: 0 neg, 1 neutral, 2 pos.
            self.trees = []
            self.sentences = []
            for i, tree in enumerate(temp_trees):
                if tree.gold_label != 1:
                    self.trees.append(tree)
                    self.sentences.append(temp_sentences[i])
        else:
            self.trees = temp_trees
            self.sentences = temp_sentences

        # The example label is the root node's label, stored as a tensor.
        self.labels = torch.Tensor([tree.gold_label for tree in self.trees])
        self.size = len(self.trees)

    def __len__(self):
        """Return the number of retained examples."""
        return self.size

    def __getitem__(self, index):
        """Return deep copies of the ``index``-th tree, sentence and label.

        Copies are handed out so callers can mutate them without touching
        the data cached on the dataset.
        """
        tree = deepcopy(self.trees[index])
        sent = deepcopy(self.sentences[index])
        label = deepcopy(self.labels[index])
        return (tree, sent, label)

    def read_sentences(self, filename):
        """Return one ``LongTensor`` of vocabulary indices per line of ``filename``."""
        with open(filename, 'r') as f:
            sentences = [self.read_sentence(line) for line in tqdm(f.readlines())]
        return sentences

    def read_sentence(self, line):
        """Convert a whitespace-tokenised line into a ``LongTensor`` of indices.

        ``Constants.UNK_WORD`` is handed to the vocabulary as the unknown-word
        argument of ``convertToIdx``.
        """
        indices = self.vocab.convertToIdx(line.split(), Constants.UNK_WORD)
        return torch.LongTensor(indices)

    def read_trees(self, filename_parents, filename_labels):
        """Build one tree per pair of corresponding lines in the two files.

        Lines are paired positionally; surplus lines in the longer file are
        ignored (``zip`` semantics).
        """
        # ``with`` guarantees both handles are closed even if reading fails.
        with open(filename_parents, 'r') as pfile, open(filename_labels, 'r') as lfile:
            # Materialised so tqdm knows the total on Python 3 as well.
            pairs = list(zip(pfile.readlines(), lfile.readlines()))
        return [self.read_tree(p_line, l_line) for p_line, l_line in tqdm(pairs)]

    def parse_dlabel_token(self, x):
        """Map one label token to a class index.

        ``'#'`` means "no label" and yields ``None``.  With ``fine_grain``
        the integer is shifted by two (-2..2 => 0..4); otherwise only its
        sign is kept (negative => 0, zero => 1, positive => 2).
        """
        if x == '#':
            return None
        value = int(x)
        if self.fine_grain:
            return value + 2
        if value < 0:
            return 0
        if value == 0:
            return 1
        return 2

    def read_tree(self, line, label_line):
        """Build a labelled ``Tree`` from a parents line and a labels line.

        Entry ``k`` (1-based) of ``line`` holds the 1-based position of node
        ``k``'s parent; ``0`` marks the root and ``-1`` marks a position that
        gets no node.  Entry ``k`` of ``label_line`` is node ``k``'s label
        token (see ``parse_dlabel_token``).

        Node ``idx`` values and the internal dict are 1-based, as in the
        data files; the ``parents`` and ``labels`` lists are 0-based, hence
        the ``idx - 1`` look-ups.

        Returns:
            The root ``Tree``, or ``None`` when no entry has parent ``0``.
        """
        # Real lists are required: both are indexed below, which a Python 3
        # ``map`` iterator does not support.
        parents = [int(token) for token in line.split()]
        labels = [self.parse_dlabel_token(token) for token in label_line.split()]
        trees = {}
        root = None
        for i in range(1, len(parents) + 1):
            # Skip nodes already created while climbing from an earlier
            # position, and positions explicitly marked as absent.
            if i in trees or parents[i - 1] == -1:
                continue
            idx = i
            prev = None
            # Climb towards the root, creating nodes until reaching either an
            # existing ancestor or the root itself.
            while True:
                parent = parents[idx - 1]
                if parent == -1:
                    break
                tree = Tree()
                if prev is not None:
                    tree.add_child(prev)
                trees[idx] = tree
                # Kept 1-based on purpose: the original author notes that
                # consumers index with ``tree.idx - 1``.
                tree.idx = idx
                tree.gold_label = labels[idx - 1]
                if parent in trees:
                    # Ancestor already built: attach and stop climbing.
                    trees[parent].add_child(tree)
                    break
                elif parent == 0:
                    root = tree
                    break
                else:
                    prev = tree
                    idx = parent
        return root

    def read_labels(self, filename):
        """Return the labels in ``filename`` (one float per line) as a tensor.

        Not called anywhere in this class; labels come from the tree roots.
        """
        with open(filename, 'r') as f:
            # Materialise a list; ``torch.Tensor`` cannot consume the lazy
            # ``map`` object that Python 3 would produce here.
            labels = [float(line) for line in f]
        return torch.Tensor(labels)
0 commit comments