vocab.py
# -*- coding:utf8 -*-
# ==============================================================================
# Copyright 2017 Baidu.com, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This module implements the Vocab class for converting strings to ids and back
"""
import codecs

import numpy as np


class Vocab(object):
    """
    Implements a vocabulary to store the tokens in the data, with their corresponding embeddings.
    """

    def __init__(self, filename=None, initial_tokens=None, lower=False):
        """
        Args:
            filename: if given, the vocab is loaded from this file (one token per line)
            initial_tokens: an optional list of tokens to add before anything else
            lower: whether to lowercase every token
        """
        self.id2token = {}
        self.token2id = {}
        self.token_cnt = {}
        self.lower = lower

        self.embed_dim = None
        self.embeddings = None

        self.pad_token = '<blank>'
        self.unk_token = '<unk>'

        self.initial_tokens = initial_tokens if initial_tokens is not None else []
        self.initial_tokens.extend([self.pad_token, self.unk_token])
        for token in self.initial_tokens:
            self.add(token)

        if filename is not None:
            self.load_from_file(filename)

    def size(self):
        """
        get the size of vocabulary
        Returns:
            an integer indicating the size
        """
        return len(self.id2token)

    def load_from_file(self, file_path):
        """
        loads the vocab from file_path
        Args:
            file_path: a file with one token per line
        """
        with codecs.open(file_path, 'r', 'utf-8') as fin:
            for line in fin:
                token = line.rstrip('\n')
                self.add(token)

    def get_id(self, token):
        """
        gets the id of a token, returns the id of unk token if token is not in vocab
        Args:
            token: a string indicating the word
        Returns:
            an integer
        """
        token = token.lower() if self.lower else token
        try:
            return self.token2id[token]
        except KeyError:
            return self.token2id[self.unk_token]

    def get_token(self, idx):
        """
        gets the token corresponding to idx, returns unk token if idx is not in vocab
        Args:
            idx: an integer
        Returns:
            a token string
        """
        try:
            return self.id2token[idx]
        except KeyError:
            return self.unk_token

    def add(self, token, cnt=1):
        """
        adds the token to vocab
        Args:
            token: a string
            cnt: the count to add for this token, default is 1
        Returns:
            the id of the token
        """
        token = token.lower() if self.lower else token
        if token in self.token2id:
            idx = self.token2id[token]
        else:
            idx = len(self.id2token)
            self.id2token[idx] = token
            self.token2id[token] = idx
        if cnt > 0:
            if token in self.token_cnt:
                self.token_cnt[token] += cnt
            else:
                self.token_cnt[token] = cnt
        return idx

    def filter_tokens_by_cnt(self, min_cnt):
        """
        filters the tokens in vocab by their count
        Args:
            min_cnt: tokens with a count less than min_cnt are filtered out
        """
        filtered_tokens = [token for token in self.token2id if self.token_cnt[token] >= min_cnt]
        # rebuild the token x id map
        self.token2id = {}
        self.id2token = {}
        for token in self.initial_tokens:
            self.add(token, cnt=0)
        for token in filtered_tokens:
            self.add(token, cnt=0)

    def randomly_init_embeddings(self, embed_dim):
        """
        randomly initializes the embeddings for each token
        Args:
            embed_dim: the size of the embedding for each token
        """
        self.embed_dim = embed_dim
        self.embeddings = np.random.rand(self.size(), embed_dim)
        for token in [self.pad_token, self.unk_token]:
            self.embeddings[self.get_id(token)] = np.zeros([self.embed_dim])

    def load_pretrained_embeddings(self, embedding_path):
        """
        loads pretrained embeddings from embedding_path and overwrites the current
        embeddings of the tokens found there; tokens that do not appear in the
        pretrained file keep their current embeddings.
        Note: self.embeddings must already be initialized, e.g. by calling
        randomly_init_embeddings() first.
        Args:
            embedding_path: the path of the pretrained embedding file, with one token
                followed by its space-separated embedding values per line
        """
        matched_num = 0
        for line in codecs.open(embedding_path, 'r', 'utf-8'):
            contents = line.strip().split(" ")
            token = contents[0]
            if token not in self.token2id:
                continue
            try:
                self.embeddings[self.get_id(token)] = list(map(float, contents[1:]))
            except ValueError:
                # skip malformed lines, e.g. a word2vec-style header line
                continue
            matched_num += 1
        print("matched pretrained word num is ", matched_num)

    def convert_to_ids(self, tokens):
        """
        Convert a list of tokens to ids, use unk_token if the token is not in vocab.
        Args:
            tokens: a list of tokens
        Returns:
            a list of ids
        """
        vec = [self.get_id(token) for token in tokens]
        return vec

    def recover_from_ids(self, ids, stop_id=None):
        """
        Convert a list of ids to tokens, stop converting if the stop_id is encountered
        Args:
            ids: a list of ids to convert
            stop_id: the stop id, default is None
        Returns:
            a list of tokens
        """
        tokens = []
        for i in ids:
            tokens.append(self.get_token(i))
            if stop_id is not None and i == stop_id:
                break
        return tokens
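

# The block below is not part of the original module: it is a minimal usage
# sketch showing how the Vocab class above can be driven end to end (building a
# vocab, pruning rare tokens, initializing embeddings and mapping tokens to ids
# and back). The sample sentences, the embed_dim value and the commented-out
# 'pretrained_embeddings.txt' path are illustrative assumptions only.
if __name__ == '__main__':
    vocab = Vocab(lower=True)
    for sentence in [['Hello', 'world'], ['hello', 'vocab']]:
        for word in sentence:
            vocab.add(word)
    # drop tokens seen fewer than 2 times ('world' and 'vocab' are removed)
    vocab.filter_tokens_by_cnt(min_cnt=2)
    vocab.randomly_init_embeddings(embed_dim=8)
    # vocab.load_pretrained_embeddings('pretrained_embeddings.txt')  # optional
    ids = vocab.convert_to_ids(['hello', 'unseen_word'])
    print(ids)                          # unseen tokens map to the id of <unk>
    print(vocab.recover_from_ids(ids))  # e.g. ['hello', '<unk>']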