Skip to content

Commit 120696d

Browse files
committed
init
0 parents  commit 120696d

20 files changed

+2185
-0
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
.idea

Constants.py

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Special-token constants shared across the project's vocabularies.

# Indices reserved at the start of every vocabulary.
PAD = 0  # padding token index
UNK = 1  # out-of-vocabulary token index
BOS = 2  # beginning-of-sentence marker index
EOS = 3  # end-of-sentence marker index

# Surface forms corresponding to the indices above.
PAD_WORD = '<blank>'
UNK_WORD = '<unk>'
BOS_WORD = '<s>'
EOS_WORD = '</s>'

LICENSE

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2017 Riddhiman Dasgupta
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.

README.md

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Tree-Structured Long Short-Term Memory Networks
2+
A [PyTorch](http://pytorch.org/) based implementation of Tree-LSTM from Kai Sheng Tai's paper
3+
[Improved Semantic Representations From Tree-Structured Long Short-Term Memory
4+
Networks](http://arxiv.org/abs/1503.00075).
5+
6+
### Requirements
7+
- [PyTorch](http://pytorch.org/) Deep learning library
8+
- [tqdm](https://github.com/tqdm/tqdm): display progress bar
9+
- [meowlogtool](https://pypi.python.org/pypi/meowlogtool): a logger that writes everything printed on the console to a file
10+
- Java >= 8 (for Stanford CoreNLP utilities)
11+
- Python >= 2.7
12+
13+
## Usage
14+
First run the script `./fetch_and_preprocess.sh`
15+
16+
This downloads the following data:
17+
- [Stanford Sentiment Treebank](http://nlp.stanford.edu/sentiment/index.html) (sentiment classification task)
18+
- [Glove word vectors](http://nlp.stanford.edu/projects/glove/) (Common Crawl 840B) -- **Warning:** this is a 2GB download!
19+
20+
and the following libraries:
21+
22+
- [Stanford Parser](http://nlp.stanford.edu/software/lex-parser.shtml)
23+
- [Stanford POS Tagger](http://nlp.stanford.edu/software/tagger.shtml)
24+
25+
### Sentiment classification
26+
27+
```
28+
python sentiment.py --name <name_of_log_file> --model_name <constituency|dependency> --epochs 10
29+
```
30+
We have not fully tested fine-grained classification yet. Binary classification accuracy for both models matches the original paper.
31+
32+
### Acknowledgements
33+
[Kai Sheng Tai](https://github.com/kaishengtai/) for the [original LuaTorch implementation](https://github.com/stanfordnlp/treelstm) <br>
34+
[Pytorch team](https://github.com/pytorch/pytorch#the-team) for the Python library<br>
35+
[Riddhiman Dasgupta](https://researchweb.iiit.ac.in/~riddhiman.dasgupta/) for his implementation of sentiment relatedness [https://github.com/dasguptar/treelstm.pytorch](https://github.com/dasguptar/treelstm.pytorch), which I used as starter code.
36+
37+
38+
39+
40+
41+
42+
### License
43+
MIT

config.py

+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import argparse
2+
3+
def parse_args(type=0):
    """Build and parse command-line options for the two tasks.

    Args:
        type: 0 selects the SICK sentence-similarity configuration
            (dependency trees); any other value selects the SST
            sentiment-classification configuration. The parameter name
            shadows the ``type`` builtin but is kept for backward
            compatibility with existing callers.

    Returns:
        argparse.Namespace with the parsed options (reads sys.argv).
    """
    if type == 0:
        parser = argparse.ArgumentParser(description='PyTorch TreeLSTM for Sentence Similarity on Dependency Trees')
        parser.add_argument('--data', default='data/sick/',
                            help='path to dataset')
        parser.add_argument('--glove', default='data/glove/',
                            help='directory with GLOVE embeddings')
        parser.add_argument('--batchsize', default=25, type=int,
                            help='batchsize for optimizer updates')
        parser.add_argument('--epochs', default=15, type=int,
                            help='number of total epochs to run')
        parser.add_argument('--lr', default=0.01, type=float,
                            metavar='LR', help='initial learning rate')
        parser.add_argument('--wd', default=1e-4, type=float,
                            help='weight decay (default: 1e-4)')
        parser.add_argument('--optim', default='adam',
                            help='optimizer (default: adam)')
        parser.add_argument('--seed', default=123, type=int,
                            help='random seed (default: 123)')
        # --cuda / --no-cuda toggle the single boolean 'cuda' flag.
        cuda_parser = parser.add_mutually_exclusive_group(required=False)
        cuda_parser.add_argument('--cuda', dest='cuda', action='store_true')
        cuda_parser.add_argument('--no-cuda', dest='cuda', action='store_false')
        parser.set_defaults(cuda=True)

        args = parser.parse_args()
        return args
    else:  # for sentiment classification on SST
        parser = argparse.ArgumentParser(description='PyTorch TreeLSTM for Sentiment Analysis Trees')
        parser.add_argument('--name', default='default_name',
                            help='name for log and saved models')
        parser.add_argument('--saved', default='saved_model',
                            help='name for log and saved models')
        parser.add_argument('--model_name', default='constituency',
                            help='model name constituency or dependency')
        parser.add_argument('--data', default='data/sst/',
                            help='path to dataset')
        parser.add_argument('--glove', default='data/glove/',
                            help='directory with GLOVE embeddings')
        parser.add_argument('--batchsize', default=25, type=int,
                            help='batchsize for optimizer updates')
        parser.add_argument('--epochs', default=10, type=int,
                            help='number of total epochs to run')
        parser.add_argument('--lr', default=0.05, type=float,
                            metavar='LR', help='initial learning rate')
        parser.add_argument('--emblr', default=0.1, type=float,
                            metavar='EMLR', help='initial embedding learning rate')
        parser.add_argument('--wd', default=1e-4, type=float,
                            help='weight decay (default: 1e-4)')
        parser.add_argument('--reg', default=1e-4, type=float,
                            help='l2 regularization (default: 1e-4)')
        parser.add_argument('--optim', default='adagrad',
                            help='optimizer (default: adagrad)')
        parser.add_argument('--seed', default=123, type=int,
                            help='random seed (default: 123)')
        parser.add_argument('--fine_grain', default=0, type=int,
                            help='fine grained (default 0 - binary mode)')
        # untested on fine_grain yet.
        cuda_parser = parser.add_mutually_exclusive_group(required=False)
        cuda_parser.add_argument('--cuda', dest='cuda', action='store_true')
        cuda_parser.add_argument('--no-cuda', dest='cuda', action='store_false')
        # BUG FIX: '--lower' used to be declared inside the cuda
        # mutually-exclusive group with dest='cuda', so passing --lower
        # silently forced cuda=True and conflicted with --no-cuda.
        # It is now an independent flag with its own destination.
        parser.add_argument('--lower', dest='lower', action='store_true')
        parser.set_defaults(cuda=True)
        parser.set_defaults(lower=True)

        args = parser.parse_args()
        return args

dataset.py

+214
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
import os
2+
from copy import deepcopy
3+
from tqdm import tqdm
4+
import torch
5+
import torch.utils.data as data
6+
from tree import Tree
7+
from vocab import Vocab
8+
import Constants
9+
import utils
10+
11+
# Dataset class for SICK dataset
12+
class SICKDataset(data.Dataset):
    """SICK sentence-similarity dataset of (ltree, lsent, rtree, rsent, label) samples.

    Loads token-index tensors for each sentence pair, dependency trees
    built from parent-pointer files, and gold similarity scores.
    All former Python-2-only idioms (``xrange``, bare ``map`` results
    used with ``len()``/indexing) are replaced with forms that work on
    both Python 2 and Python 3.
    """

    def __init__(self, path, vocab, num_classes):
        super(SICKDataset, self).__init__()
        self.vocab = vocab
        self.num_classes = num_classes

        # Left/right sentences as LongTensors of vocabulary indices.
        self.lsentences = self.read_sentences(os.path.join(path, 'a.toks'))
        self.rsentences = self.read_sentences(os.path.join(path, 'b.toks'))

        # Trees built from parent-pointer files.
        self.ltrees = self.read_trees(os.path.join(path, 'a.parents'))
        self.rtrees = self.read_trees(os.path.join(path, 'b.parents'))

        # Gold similarity scores, one float per pair.
        self.labels = self.read_labels(os.path.join(path, 'sim.txt'))

        self.size = self.labels.size(0)

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        # deepcopy so that downstream mutation of the returned objects
        # cannot corrupt the stored samples.
        ltree = deepcopy(self.ltrees[index])
        rtree = deepcopy(self.rtrees[index])
        lsent = deepcopy(self.lsentences[index])
        rsent = deepcopy(self.rsentences[index])
        label = deepcopy(self.labels[index])
        return (ltree, lsent, rtree, rsent, label)

    def read_sentences(self, filename):
        """Read one tokenized sentence per line into LongTensors."""
        with open(filename, 'r') as f:
            sentences = [self.read_sentence(line) for line in tqdm(f.readlines())]
        return sentences

    def read_sentence(self, line):
        # Unknown words map to the UNK index.
        indices = self.vocab.convertToIdx(line.split(), Constants.UNK_WORD)
        return torch.LongTensor(indices)

    def read_trees(self, filename):
        """Read one parent-pointer tree per line."""
        with open(filename, 'r') as f:
            trees = [self.read_tree(line) for line in tqdm(f.readlines())]
        return trees

    def read_tree(self, line):
        """Build a Tree from a whitespace-separated parent-pointer line.

        Parent pointers are 1-indexed in the file (0 marks the root,
        -1 marks a detached node). Nodes here are keyed by 0-based index
        (``tree.idx = i - 1``), unlike ``SSTDataset.read_tree``.
        """
        # list() so len() and repeated indexing work on Python 3.
        parents = list(map(int, line.split()))
        trees = dict()
        root = None
        for i in range(1, len(parents) + 1):
            if i - 1 not in trees and parents[i - 1] != -1:
                # Walk up from node i toward the root, creating nodes
                # until we hit an existing node or the root.
                idx = i
                prev = None
                while True:
                    parent = parents[idx - 1]
                    if parent == -1:
                        break
                    tree = Tree()
                    if prev is not None:
                        tree.add_child(prev)
                    trees[idx - 1] = tree
                    tree.idx = idx - 1
                    if parent - 1 in trees:
                        trees[parent - 1].add_child(tree)
                        break
                    elif parent == 0:
                        root = tree
                        break
                    else:
                        prev = tree
                        idx = parent
        return root

    def read_labels(self, filename):
        """Read one float similarity score per line into a Tensor."""
        with open(filename, 'r') as f:
            labels = [float(x) for x in f.readlines()]
        return torch.Tensor(labels)
88+
89+
# Dataset class for SICK dataset
90+
class SSTDataset(data.Dataset):
    """Stanford Sentiment Treebank dataset of (tree, sentence, label) samples.

    Each sample is a parse tree (constituency or dependency, chosen by
    ``model_name``), the token-index tensor of its sentence, and the
    sentiment label of the tree's root node. In binary mode
    (``fine_grain`` false) neutral samples are dropped.
    """

    def __init__(self, path, vocab, num_classes, fine_grain, model_name):
        super(SSTDataset, self).__init__()
        self.vocab = vocab
        self.num_classes = num_classes
        self.fine_grain = fine_grain
        self.model_name = model_name

        temp_sentences = self.read_sentences(os.path.join(path, 'sents.toks'))
        if model_name == "dependency":
            temp_trees = self.read_trees(os.path.join(path, 'dparents.txt'), os.path.join(path, 'dlabels.txt'))
        else:
            temp_trees = self.read_trees(os.path.join(path, 'parents.txt'), os.path.join(path, 'labels.txt'))

        if not self.fine_grain:
            # Binary mode: keep only pos/neg roots (0 neg, 1 neutral, 2 pos).
            new_trees = []
            new_sentences = []
            for i in range(len(temp_trees)):
                if temp_trees[i].gold_label != 1:
                    new_trees.append(temp_trees[i])
                    new_sentences.append(temp_sentences[i])
            self.trees = new_trees
            self.sentences = new_sentences
        else:
            self.trees = temp_trees
            self.sentences = temp_sentences

        # Root gold labels collected into a tensor.
        self.labels = torch.Tensor([t.gold_label for t in self.trees])
        self.size = len(self.trees)

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        # deepcopy so downstream mutation cannot corrupt stored samples.
        tree = deepcopy(self.trees[index])
        sent = deepcopy(self.sentences[index])
        label = deepcopy(self.labels[index])
        return (tree, sent, label)

    def read_sentences(self, filename):
        """Read one tokenized sentence per line into LongTensors."""
        with open(filename, 'r') as f:
            sentences = [self.read_sentence(line) for line in tqdm(f.readlines())]
        return sentences

    def read_sentence(self, line):
        # Unknown words map to the UNK index.
        indices = self.vocab.convertToIdx(line.split(), Constants.UNK_WORD)
        return torch.LongTensor(indices)

    def read_trees(self, filename_parents, filename_labels):
        """Read parallel parent-pointer and node-label files into trees.

        Context managers replace the original unclosed file handles.
        """
        with open(filename_parents, 'r') as pfile, open(filename_labels, 'r') as lfile:
            pl = list(zip(pfile.readlines(), lfile.readlines()))  # (parents, labels) pairs
        trees = [self.read_tree(p_line, l_line) for p_line, l_line in tqdm(pl)]
        return trees

    def parse_dlabel_token(self, x):
        """Map a raw label token to a class index.

        '#' (no label) -> None.
        fine_grain:  -2 -1 0 1 2 -> 0 1 2 3 4
        binary:      -2 -1 0 1 2 -> 0 (neg), 1 (neutral), 2 (pos)
        """
        if x == '#':
            return None
        if self.fine_grain:
            return int(x) + 2
        tmp = int(x)
        if tmp < 0:
            return 0
        elif tmp == 0:
            return 1
        else:
            return 2

    def read_tree(self, line, label_line):
        """Build a labeled Tree from parent-pointer and label lines.

        Parent pointers are 1-indexed in the file (0 marks the root,
        -1 a detached node). Per the original FIXED note, the trees
        dict and ``tree.idx`` use the 1-based index here, to prevent
        ``embs[tree.idx - 1]`` from becoming -1 when idx would be 0.
        """
        # list() so len() and repeated indexing work on Python 3.
        parents = list(map(int, line.split()))
        trees = dict()
        root = None
        labels = list(map(self.parse_dlabel_token, label_line.split()))
        for i in range(1, len(parents) + 1):
            if i not in trees and parents[i - 1] != -1:
                # Walk up from node i toward the root, creating nodes
                # until we hit an existing node or the root.
                idx = i
                prev = None
                while True:
                    parent = parents[idx - 1]
                    if parent == -1:
                        break
                    tree = Tree()
                    if prev is not None:
                        tree.add_child(prev)
                    trees[idx] = tree
                    tree.idx = idx  # 1-based node index
                    tree.gold_label = labels[idx - 1]  # node sentiment label
                    if parent in trees:
                        trees[parent].add_child(tree)
                        break
                    elif parent == 0:
                        root = tree
                        break
                    else:
                        prev = tree
                        idx = parent
        return root

    def read_labels(self, filename):
        """Read one float label per line into a Tensor. (Not in use.)"""
        with open(filename, 'r') as f:
            labels = [float(x) for x in f.readlines()]
        return torch.Tensor(labels)

fetch_and_preprocess.sh

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#!/bin/bash
# Download the datasets and libraries, then preprocess the SST data.
# Abort on the first failing command.
set -e

# Fetch data and tools (SST, Glove vectors, Stanford parser/tagger).
python2.7 scripts/download.py

# Compile the Java helpers against the Stanford parser jars.
# CLASSPATH is quoted at use sites to survive paths with spaces.
CLASSPATH="lib:lib/stanford-parser/stanford-parser.jar:lib/stanford-parser/stanford-parser-3.5.1-models.jar"
javac -cp "$CLASSPATH" lib/*.java

# Generate tokenized sentences, parse trees, and labels for SST.
python2.7 scripts/preprocess-sst.py
8+

0 commit comments

Comments
 (0)