|
| 1 | +from __future__ import print_function |
| 2 | +from operator import itemgetter |
| 3 | +from itertools import count |
| 4 | +from collections import Counter, defaultdict |
| 5 | +import random |
| 6 | +import dynet as dy |
| 7 | +import numpy as np |
| 8 | +import re |
| 9 | + |
# Integer ids of the three parser transitions. The ids must line up with the
# positions of 'SHIFT', 'REDUCE_L', 'REDUCE_R' in the action vocabulary list
# built further down, because oracle actions are read through that vocabulary
# and then compared against these constants.
SHIFT, REDUCE_L, REDUCE_R = 0, 1, 2
# Size of the action-scoring output layer.
NUM_ACTIONS = 3
| 15 | + |
class Vocab:
    """Two-way mapping between tokens and integer ids.

    ``w2i`` maps token -> id and ``i2w`` maps id -> token.
    """

    def __init__(self, w2i):
        """Build a vocabulary from an existing token -> id mapping.

        Args:
            w2i: Mapping from token to integer id. It is copied, so the
                caller's object is not shared with this vocabulary.
        """
        self.w2i = dict(w2i)
        # items() works on Python 2 and 3 alike; iteritems() is Python 2 only.
        self.i2w = {i: w for w, i in self.w2i.items()}

    @classmethod
    def from_list(cls, words):
        """Create a vocabulary whose ids are the positions of ``words``.

        If a token occurs more than once, its last position wins.
        """
        return cls({word: idx for idx, word in enumerate(words)})

    @classmethod
    def from_file(cls, vocab_fname):
        """Create a vocabulary from a text file of ``<token> <count>`` lines.

        Ids follow line order. Blank lines are skipped; any other line that
        does not consist of exactly two whitespace-separated fields raises
        ``ValueError``.
        """
        words = []
        with open(vocab_fname) as fh:
            for line in fh:
                fields = line.split()
                if not fields:
                    continue
                # The count column is read only to validate the line shape.
                word, _freq = fields
                words.append(word)
        return cls.from_list(words)

    def size(self):
        """Return the number of distinct tokens."""
        return len(self.w2i)
| 39 | + |
def read_oracle(fname, vw, va):
    """Yield ``(sentence, actions)`` pairs of integer ids from an oracle file.

    Each non-blank line has the form ``<words> ||| <actions>`` where both
    sides are whitespace-separated tokens.

    Args:
        fname: Path of the oracle file.
        vw: Word vocabulary; its ``w2i`` mapping converts words to ids.
        va: Action vocabulary; its ``w2i`` mapping converts actions to ids.

    Yields:
        Tuples ``(sent, acts)`` of two lists of ids.

    Raises:
        ValueError: If a non-blank line does not split into exactly two
            parts around ``' ||| '``.
        KeyError: If a token is missing from the corresponding vocabulary.
    """
    with open(fname) as fh:
        for line in fh:
            line = line.strip()
            if not line:
                continue
            ssent, sacts = line.split(' ||| ')
            sent = [vw.w2i[x] for x in ssent.split()]
            acts = [va.w2i[x] for x in sacts.split()]
            yield (sent, acts)
| 48 | + |
# Vector sizes used by the parser.
# NOTE: the stack LSTM is declared with WORD_DIM inputs but is fed both word
# embeddings (WORD_DIM) and composed head/modifier vectors (LSTM_DIM), so the
# two sizes have to stay equal.
WORD_DIM = LSTM_DIM = 64
# Not referenced anywhere else in the visible part of this file.
ACTION_DIM = 32
| 52 | + |
class TransitionParser:
    """Shift-reduce dependency parser with SHIFT / REDUCE_L / REDUCE_R actions.

    The unread words are summarised by ``buffRNN`` and the partially built
    trees on the stack by ``stackRNN``. At every step where more than one
    action is valid, both summaries are combined to score the actions.
    """

    def __init__(self, model, vocab):
        """Register the parser's parameters in ``model``.

        Args:
            model: DyNet parameter collection that will own the parameters.
            vocab: Word vocabulary; ``size()`` and ``i2w`` are used.
        """
        self.vocab = vocab
        # Composes a (head, modifier) pair of stack outputs into one vector.
        self.pW_comp = model.add_parameters((LSTM_DIM, LSTM_DIM * 2))
        self.pb_comp = model.add_parameters((LSTM_DIM, ))
        # Maps the concatenated (buffer, stack) summary to a hidden layer.
        self.pW_s2h = model.add_parameters((LSTM_DIM, LSTM_DIM * 2))
        self.pb_s2h = model.add_parameters((LSTM_DIM, ))
        # Maps the hidden layer to one score per action.
        self.pW_act = model.add_parameters((NUM_ACTIONS, LSTM_DIM))
        self.pb_act = model.add_parameters((NUM_ACTIONS, ))

        # layers, in-dim, out-dim, model
        self.buffRNN = dy.LSTMBuilder(1, WORD_DIM, LSTM_DIM, model)
        # NOTE: stackRNN receives word embeddings on SHIFT and composed
        # vectors of size LSTM_DIM on REDUCE, so WORD_DIM must equal LSTM_DIM.
        self.stackRNN = dy.LSTMBuilder(1, WORD_DIM, LSTM_DIM, model)
        # Stand-in for the buffer summary once the buffer is empty.
        self.pempty_buffer_emb = model.add_parameters((LSTM_DIM,))
        nwords = vocab.size()
        self.WORDS_LOOKUP = model.add_lookup_parameters((nwords, WORD_DIM))

    def parse(self, t, oracle_actions=None):
        """Parse one sentence and return the loss of the actions taken.

        Args:
            t: Sequence of word ids of the sentence.
            oracle_actions: Optional sequence of gold action ids. When given,
                these actions are executed in order. When ``None``, the
                best-scoring valid action is taken at each step and the
                resulting arcs are printed.

        Returns:
            A DyNet expression holding the negated sum of the log-probabilities
            of the executed actions, counting only steps where more than one
            action was valid. ``None`` if there was no such step or if the
            sentence is empty.
        """
        toks = list(t)
        if not toks:
            # No transition is valid for an empty sentence, so there is
            # nothing to build and no loss.
            return None

        dy.renew_cg()
        if oracle_actions is not None:
            # Reversed private copy: pop() then yields the actions in order
            # without touching the caller's sequence.
            oracle_actions = list(oracle_actions)
            oracle_actions.reverse()
        stack_top = self.stackRNN.initial_state()
        # The buffer is filled back to front so that buffer[-1] is the next
        # unread word and carries an encoding of everything after it.
        toks.reverse()
        stack = []
        cur = self.buffRNN.initial_state()
        buffer = []
        empty_buffer_emb = dy.parameter(self.pempty_buffer_emb)
        W_comp = dy.parameter(self.pW_comp)
        b_comp = dy.parameter(self.pb_comp)
        W_s2h = dy.parameter(self.pW_s2h)
        b_s2h = dy.parameter(self.pb_s2h)
        W_act = dy.parameter(self.pW_act)
        b_act = dy.parameter(self.pb_act)
        losses = []
        for tok in toks:
            tok_embedding = self.WORDS_LOOKUP[tok]
            cur = cur.add_input(tok_embedding)
            buffer.append((cur.output(), tok_embedding, self.vocab.i2w[tok]))

        # Done once the buffer is empty and a single tree is left on the stack.
        while not (len(stack) == 1 and len(buffer) == 0):
            # based on parser state, get valid actions
            valid_actions = []
            if len(buffer) > 0:  # can only shift if elements remain in buffer
                valid_actions += [SHIFT]
            if len(stack) >= 2:  # can only reduce if 2 elements on stack
                valid_actions += [REDUCE_L, REDUCE_R]

            # compute probability of each of the actions and choose an action
            # either from the oracle or if there is no oracle, based on the model
            action = valid_actions[0]
            log_probs = None
            if len(valid_actions) > 1:
                buffer_embedding = buffer[-1][0] if buffer else empty_buffer_emb
                stack_embedding = stack[-1][0].output()  # the stack has something here
                parser_state = dy.concatenate([buffer_embedding, stack_embedding])
                h = dy.tanh(W_s2h * parser_state + b_s2h)
                logits = W_act * h + b_act
                # Softmax restricted to the currently valid actions.
                log_probs = dy.log_softmax(logits, valid_actions)
                if oracle_actions is None:
                    action = max(enumerate(log_probs.vec_value()), key=itemgetter(1))[0]
            if oracle_actions is not None:
                action = oracle_actions.pop()
            if log_probs is not None:
                # append the action-specific loss
                losses.append(dy.pick(log_probs, action))

            # execute the action to update the parser state
            if action == SHIFT:
                _, tok_embedding, token = buffer.pop()
                stack_state, _ = stack[-1] if stack else (stack_top, '<TOP>')
                stack_state = stack_state.add_input(tok_embedding)
                stack.append((stack_state, token))
            else:  # one of the reduce actions
                right = stack.pop()
                left = stack.pop()
                # REDUCE_R keeps the left item as head, REDUCE_L the right one.
                head, modifier = (left, right) if action == REDUCE_R else (right, left)
                top_stack_state, _ = stack[-1] if stack else (stack_top, '<TOP>')
                head_rep, head_tok = head[0].output(), head[1]
                mod_rep, mod_tok = modifier[0].output(), modifier[1]
                composed_rep = dy.rectify(W_comp * dy.concatenate([head_rep, mod_rep]) + b_comp)
                # The composed subtree replaces both items on the stack.
                top_stack_state = top_stack_state.add_input(composed_rep)
                stack.append((top_stack_state, head_tok))
                if oracle_actions is None:
                    print('{0} --> {1}'.format(head_tok, mod_tok))

        # the head of the tree that remains at the top of the stack is now the root
        if oracle_actions is None:
            head = stack.pop()[1]
            print('ROOT --> {0}'.format(head))
        return -dy.esum(losses) if losses else None
| 147 | + |
# Action vocabulary. The list order defines the ids, so it must agree with the
# SHIFT / REDUCE_L / REDUCE_R constants the parser compares against.
acts = ['SHIFT', 'REDUCE_L', 'REDUCE_R']
vocab_acts = Vocab.from_list(acts)

vocab_words = Vocab.from_file('data/vocab.txt')
train = list(read_oracle('data/small-train.unk.txt', vocab_words, vocab_acts))
dev = list(read_oracle('data/small-dev.unk.txt', vocab_words, vocab_acts))

model = dy.Model()
trainer = dy.AdamTrainer(model)

tp = TransitionParser(model, vocab_words)

# i counts training sentences processed so far, across all epochs.
i = 0
for epoch in range(5):
    words = 0
    total_loss = 0.0
    for (s, a) in train:
        loss = tp.parse(s, a)
        words += len(s)
        if loss is not None:
            total_loss += loss.scalar_value()
            loss.backward()
            trainer.update()
        # Progress expressed as a fractional number of epochs.
        e = float(i) / len(train)
        if i % 50 == 0:
            # max(..., 1) avoids dividing by zero when no words were counted.
            print('epoch {}: per-word loss: {}'.format(e, total_loss / max(words, 1)))
            words = 0
            total_loss = 0.0
        if i % 500 == 0 and dev:
            # Print the predicted parse of one fixed dev sentence; the index
            # is clamped so that a short dev set does not raise IndexError.
            tp.parse(dev[min(209, len(dev) - 1)][0])
            dev_words = 0
            dev_loss = 0.0
            for (ds, da) in dev:
                loss = tp.parse(ds, da)
                dev_words += len(ds)
                if loss is not None:
                    dev_loss += loss.scalar_value()
            print('[validation] epoch {}: per-word loss: {}'.format(e, dev_loss / max(dev_words, 1)))
        i += 1
0 commit comments