Skip to content

Commit e3f38c1

Browse files
committed
transition parser example
1 parent bd36370 commit e3f38c1

8 files changed

+30937
-0
lines changed

data/small-dev.txt

+400
Large diffs are not rendered by default.

data/small-dev.unk.txt

+400
Large diffs are not rendered by default.

data/small-test.txt

+400
Large diffs are not rendered by default.

data/small-test.unk.txt

+400
Large diffs are not rendered by default.

data/small-train.txt

+10,000
Large diffs are not rendered by default.

data/small-train.unk.txt

+10,000
Large diffs are not rendered by default.

data/vocab.txt

+9,151
Large diffs are not rendered by default.

tutorial_transition_parser.py

+186
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
from __future__ import print_function
2+
from operator import itemgetter
3+
from itertools import count
4+
from collections import Counter, defaultdict
5+
import random
6+
import dynet as dy
7+
import numpy as np
8+
import re
9+
10+
# Integer ids of the three transitions the parser can take.
# NUM_ACTIONS is the size of that inventory (used to size the action
# classifier's output layer).
SHIFT, REDUCE_L, REDUCE_R = range(3)
NUM_ACTIONS = 3
15+
16+
class Vocab:
    """Bidirectional mapping between symbols (words/actions) and integer ids.

    Attributes:
        w2i: dict mapping each symbol to its integer id.
        i2w: dict mapping each integer id back to its symbol.
    """

    def __init__(self, w2i):
        """Build the vocabulary from a symbol -> id mapping (copied)."""
        self.w2i = dict(w2i)
        # .items() works on both Python 2 and 3 (iteritems is Python 2 only).
        self.i2w = {i: w for w, i in self.w2i.items()}

    @classmethod
    def from_list(cls, words):
        """Create a vocabulary assigning ids by position in ``words``.

        If a symbol occurs more than once, it keeps the id of its last
        occurrence (positions are still counted for every element).
        """
        w2i = {word: idx for idx, word in enumerate(words)}
        # Use cls so subclasses get instances of their own type.
        return cls(w2i)

    @classmethod
    def from_file(cls, vocab_fname):
        """Create a vocabulary from a text file.

        Each non-blank line must hold exactly two whitespace-separated
        fields; the first is the symbol, the second is ignored. Ids are
        assigned in file order.

        Raises:
            ValueError: if a non-blank line does not have exactly two fields.
        """
        words = []
        with open(vocab_fname) as fh:
            for line in fh:
                line = line.strip()
                if not line:
                    # tolerate blank lines (e.g. a trailing newline)
                    continue
                word, _freq = line.split()
                words.append(word)
        return cls.from_list(words)

    def size(self):
        """Return the number of distinct symbols."""
        return len(self.w2i)
39+
40+
def read_oracle(fname, vw, va):
    """Yield (sentence, actions) id-list pairs from an oracle file.

    Each non-blank line has the form ``<words> ||| <actions>`` where both
    sides are whitespace-separated symbols.

    Args:
        fname: path of the oracle file.
        vw: vocabulary for words; only its ``w2i`` mapping is used.
        va: vocabulary for actions; only its ``w2i`` mapping is used.

    Yields:
        Tuples ``(sent, acts)`` of integer-id lists.

    Raises:
        ValueError: if a non-blank line does not contain exactly one
            ``' ||| '`` separator.
        KeyError: if a word or action is missing from its vocabulary.
    """
    with open(fname) as fh:
        for line in fh:
            line = line.strip()
            if not line:
                # tolerate blank lines (e.g. a trailing newline)
                continue
            ssent, sacts = line.split(' ||| ')
            sent = [vw.w2i[x] for x in ssent.split()]
            acts = [va.w2i[x] for x in sacts.split()]
            yield (sent, acts)
48+
49+
# Model dimensions.
# WORD_DIM: width of the word-embedding lookup table and LSTM input size.
# LSTM_DIM: hidden size of the LSTMs and of the composition/state layers.
# ACTION_DIM: defined here but not referenced in the visible code.
WORD_DIM, LSTM_DIM, ACTION_DIM = 64, 64, 32
52+
53+
class TransitionParser:
    """Transition-based parser with SHIFT / REDUCE_L / REDUCE_R actions.

    The buffer is encoded by one LSTM (``buffRNN``) and the stack by another
    (``stackRNN``); the action distribution is computed from the concatenation
    of the buffer summary and the stack summary.
    """

    def __init__(self, model, vocab):
        """Register all parameters with ``model``.

        Args:
            model: DyNet parameter collection the parameters are added to.
            vocab: word vocabulary; ``size()`` and ``i2w`` are used.
        """
        self.vocab = vocab
        # composition of a (head, modifier) pair into a single vector
        self.pW_comp = model.add_parameters((LSTM_DIM, LSTM_DIM * 2))
        self.pb_comp = model.add_parameters((LSTM_DIM, ))
        # parser state (buffer summary + stack summary) -> hidden layer
        self.pW_s2h = model.add_parameters((LSTM_DIM, LSTM_DIM * 2))
        self.pb_s2h = model.add_parameters((LSTM_DIM, ))
        # hidden layer -> action scores
        self.pW_act = model.add_parameters((NUM_ACTIONS, LSTM_DIM))
        self.pb_act = model.add_parameters((NUM_ACTIONS, ))

        # layers, in-dim, out-dim, model
        self.buffRNN = dy.LSTMBuilder(1, WORD_DIM, LSTM_DIM, model)
        # NOTE: stackRNN is fed both word embeddings (WORD_DIM) and composed
        # representations (LSTM_DIM), so this only works while
        # WORD_DIM == LSTM_DIM.
        self.stackRNN = dy.LSTMBuilder(1, WORD_DIM, LSTM_DIM, model)
        # stands in for the buffer summary when the buffer is empty
        self.pempty_buffer_emb = model.add_parameters((LSTM_DIM,))
        nwords = vocab.size()
        self.WORDS_LOOKUP = model.add_lookup_parameters((nwords, WORD_DIM))

    # returns an expression of the loss for the sequence of actions
    # (that is, the oracle_actions if present or the predicted sequence otherwise)
    def parse(self, t, oracle_actions=None):
        """Parse the token-id sequence ``t`` and return its loss expression.

        Args:
            t: iterable of word ids (indices into ``WORDS_LOOKUP``).
            oracle_actions: optional sequence of gold action ids. When given,
                these actions are executed; when ``None``, the highest-scoring
                valid action is taken at each step and the resulting arcs are
                printed.

        Returns:
            The negated sum of the log-probabilities of the actions taken at
            every step where more than one action was valid, or ``None`` when
            there was no such step (including an empty ``t``).
        """
        toks = list(t)
        if not toks:
            # Nothing to parse: with an empty stack and an empty buffer no
            # action is valid, so there is neither a tree nor a loss.
            return None

        dy.renew_cg()
        if oracle_actions:
            # reversed so that pop() yields the actions in their given order
            oracle_actions = list(oracle_actions)
            oracle_actions.reverse()
        stack_top = self.stackRNN.initial_state()
        # reversed so that the first word ends up last, i.e. is popped first
        toks.reverse()
        stack = []
        cur = self.buffRNN.initial_state()
        buf = []
        empty_buffer_emb = dy.parameter(self.pempty_buffer_emb)
        W_comp = dy.parameter(self.pW_comp)
        b_comp = dy.parameter(self.pb_comp)
        W_s2h = dy.parameter(self.pW_s2h)
        b_s2h = dy.parameter(self.pb_s2h)
        W_act = dy.parameter(self.pW_act)
        b_act = dy.parameter(self.pb_act)
        losses = []
        for tok in toks:
            tok_embedding = self.WORDS_LOOKUP[tok]
            cur = cur.add_input(tok_embedding)
            # (LSTM summary up to this token, raw embedding, surface form)
            buf.append((cur.output(), tok_embedding, self.vocab.i2w[tok]))

        while not (len(stack) == 1 and len(buf) == 0):
            # based on parser state, get valid actions
            valid_actions = []
            if len(buf) > 0:  # can only shift if elements in buffer
                valid_actions += [SHIFT]
            if len(stack) >= 2:  # can only reduce if 2 elements on stack
                valid_actions += [REDUCE_L, REDUCE_R]

            # compute probability of each of the actions and choose an action
            # either from the oracle or if there is no oracle, based on the model
            action = valid_actions[0]
            log_probs = None
            if len(valid_actions) > 1:
                buffer_embedding = buf[-1][0] if buf else empty_buffer_emb
                stack_embedding = stack[-1][0].output()  # the stack has something here
                parser_state = dy.concatenate([buffer_embedding, stack_embedding])
                h = dy.tanh(W_s2h * parser_state + b_s2h)
                logits = W_act * h + b_act
                # softmax restricted to the currently valid actions
                log_probs = dy.log_softmax(logits, valid_actions)
                if oracle_actions is None:
                    action = max(enumerate(log_probs.vec_value()), key=itemgetter(1))[0]
            if oracle_actions is not None:
                action = oracle_actions.pop()
            if log_probs is not None:
                # append the action-specific loss
                losses.append(dy.pick(log_probs, action))

            # execute the action to update the parser state
            if action == SHIFT:
                _, tok_embedding, token = buf.pop()
                stack_state, _ = stack[-1] if stack else (stack_top, '<TOP>')
                stack_state = stack_state.add_input(tok_embedding)
                stack.append((stack_state, token))
            else:  # one of the reduce actions
                right = stack.pop()
                left = stack.pop()
                head, modifier = (left, right) if action == REDUCE_R else (right, left)
                # continue the stack LSTM from the state below the two popped items
                top_stack_state, _ = stack[-1] if stack else (stack_top, '<TOP>')
                head_rep, head_tok = head[0].output(), head[1]
                mod_rep, mod_tok = modifier[0].output(), modifier[1]
                composed_rep = dy.rectify(W_comp * dy.concatenate([head_rep, mod_rep]) + b_comp)
                top_stack_state = top_stack_state.add_input(composed_rep)
                stack.append((top_stack_state, head_tok))
                if oracle_actions is None:
                    print('{0} --> {1}'.format(head_tok, mod_tok))

        # the head of the tree that remains at the top of the stack is now the root
        if oracle_actions is None:
            head = stack.pop()[1]
            print('ROOT --> {0}'.format(head))
        return -dy.esum(losses) if losses else None
147+
148+
# --- Training script: load data, build the model, train for 5 epochs. ---

# The order here fixes the action ids used in the oracle files
# (SHIFT=0, REDUCE_L=1, REDUCE_R=2, matching the module constants).
acts = ['SHIFT', 'REDUCE_L', 'REDUCE_R']
vocab_acts = Vocab.from_list(acts)

vocab_words = Vocab.from_file('data/vocab.txt')
train = list(read_oracle('data/small-train.unk.txt', vocab_words, vocab_acts))
dev = list(read_oracle('data/small-dev.unk.txt', vocab_words, vocab_acts))

model = dy.Model()
trainer = dy.AdamTrainer(model)

tp = TransitionParser(model, vocab_words)

# index of the dev sentence whose predicted tree is printed periodically
sample_idx = 209

i = 0  # number of training sentences processed so far, across epochs
for epoch in range(5):
    words = 0
    total_loss = 0.0
    for (s, a) in train:
        loss = tp.parse(s, a)
        words += len(s)
        if loss is not None:
            # read the value before backward/update
            total_loss += loss.scalar_value()
            loss.backward()
            trainer.update()
        e = float(i) / len(train)  # fractional epoch, for reporting
        if i % 50 == 0:
            if words:  # avoid dividing by zero when no tokens were seen
                print('epoch {}: per-word loss: {}'.format(e, total_loss / words))
            words = 0
            total_loss = 0.0
        if i % 500 == 0:
            # Show the predicted tree for one fixed dev sentence (parse()
            # prints the arcs when no oracle is given); skip it if the dev
            # set is too small to contain that sentence.
            if len(dev) > sample_idx:
                tp.parse(dev[sample_idx][0])
            dev_words = 0
            dev_loss = 0.0
            for (ds, da) in dev:
                loss = tp.parse(ds, da)
                dev_words += len(ds)
                if loss is not None:
                    dev_loss += loss.scalar_value()
            if dev_words:  # avoid dividing by zero on an empty dev set
                print('[validation] epoch {}: per-word loss: {}'.format(e, dev_loss / dev_words))
        i += 1

0 commit comments

Comments
 (0)