-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathvocabulary.py
64 lines (48 loc) · 2.22 KB
/
vocabulary.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import json
from dataset import SplitDataset
class Vocabulary:
    """Two-way mapping between string tokens and integer ids.

    Id ``0`` is reserved for padding (``Vocabulary.padding``); real tokens
    are numbered from 1, which is why ``len(vocab)`` is one larger than the
    number of entries in ``mapping``.

    Attributes:
        mapping: dict from token string to its integer id.
    """

    # Reserved padding id; it is never assigned to a token.
    padding = 0

    def __init__(self, dataset_location=None, vocab_location=None,
                 split: str = "train", sampling: float = 0.1) -> None:
        """Load the mapping from a vocab file, or build it from a dataset.

        Args:
            dataset_location: Forwarded to ``SplitDataset.load`` when no
                vocab file is given.
            vocab_location: Path of a JSON file as written by ``save``.
                Takes precedence over ``dataset_location`` when both are set.
            split: Key of the dataset split to collect tokens from.
            sampling: Fraction forwarded to the split's ``sample`` call.

        Raises:
            ValueError: If neither ``dataset_location`` nor
                ``vocab_location`` is provided.
        """
        # An explicit raise (rather than ``assert``) keeps the check active
        # when Python runs with -O.
        if dataset_location is None and vocab_location is None:
            raise ValueError(
                "either dataset_location or vocab_location must be provided"
            )
        if vocab_location is not None:
            with open(vocab_location, "r", encoding="utf-8") as f:
                self.mapping = json.load(f)
        else:
            self.mapping = self.collect_vocab(split, dataset_location, sampling)

    def __len__(self) -> int:
        """Return the vocabulary size, counting the reserved padding id."""
        return len(self.mapping) + 1

    @staticmethod
    def _build_mapping(tokens) -> dict[str, int]:
        """Assign ids (starting at 1) to a collection of unique tokens.

        Non-numeric tokens come first in reverse lexicographic order,
        followed by numeric tokens in ascending numeric order. Numeric
        tokens keep their original spelling, so "01" and "1" stay distinct
        entries and the ids remain contiguous.
        """
        unique = set(tokens)

        def is_number(token: str) -> bool:
            # str.isdigit() alone also accepts characters such as "²" that
            # int() cannot parse; restricting to ASCII keeps int() safe.
            return token.isascii() and token.isdigit()

        words = sorted((t for t in unique if not is_number(t)), reverse=True)
        # The spelling is a tie-breaker so equal values ("01" vs "1") get a
        # deterministic order independent of set iteration order.
        numbers = sorted(
            (t for t in unique if is_number(t)), key=lambda t: (int(t), t)
        )
        return {token: idx for idx, token in enumerate(words + numbers, start=1)}

    def collect_vocab(self, split, dataset_location, sampling) -> dict[str, int]:
        """Collect the token -> id mapping from one split of a dataset.

        Args:
            split: Key of the split inside the loaded dataset.
            dataset_location: Forwarded to ``SplitDataset.load``.
            sampling: Fraction forwarded to the split's ``sample`` call.

        Returns:
            Mapping from every input/target token seen to an id >= 1.
        """
        data: SplitDataset = SplitDataset.load(dataset_location)
        # NOTE(review): the return value of sample() is discarded. If
        # sample() returns a new object instead of sampling in place, the
        # `sampling` argument has no effect and the whole split is scanned.
        # Confirm against SplitDataset before relying on it.
        data[split].sample(frac=sampling)
        tokens = set()
        for sample in data[split]:
            tokens.update(sample.input.to_tokens())
            tokens.update(sample.target.to_tokens())
        return self._build_mapping(tokens)

    def to_vocab(self, tokens: list[str]) -> list[int]:
        """Convert a list of tokens to their integer ids.

        Raises:
            ValueError: If a token is not part of the vocabulary.
        """
        ids = []
        for token in tokens:
            try:
                ids.append(self.mapping[token])
            except KeyError:
                # f-string formatting keeps this a ValueError even when the
                # offending token is not a str.
                raise ValueError(f"Unable to read {token}") from None
        return ids

    def from_vocab(self, vocab: list[int]) -> list[str]:
        """Convert a list of integer ids back to their tokens.

        Raises:
            ValueError: If an id has no token (this includes the padding id).
        """
        # Rebuilt on each call so later changes to ``mapping`` are honoured.
        reverse = {idx: token for token, idx in self.mapping.items()}
        decoded = []
        for voc in vocab:
            try:
                decoded.append(reverse[voc])
            except KeyError:
                raise ValueError(f"Unable to read {voc}") from None
        return decoded

    def save(self, path: str) -> None:
        """Write the mapping to ``path`` as JSON for reproducibility."""
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.mapping, f)