# sentencepiece.py
import regex as re
import collections
import math
import random
from typing import List, Dict


class SentencePieceTokenizer:
    def __init__(
        self,
        vocab_size: int = 1000,
        special_tokens: List[str] = ["<s>", "</s>", "<unk>"],
        alpha: float = 0.1,  # subword regularization parameter
    ):
        """
        Initialize the tokenizer

        Args:
            vocab_size: Target vocabulary size
            special_tokens: List of special tokens to add to vocabulary
            alpha: Smoothing parameter for subword regularization
        """
        self.vocab_size = vocab_size
        self.special_tokens = special_tokens
        self.alpha = alpha
        # Initialize vocabulary with special tokens
        self.vocab = {token: idx for idx, token in enumerate(special_tokens)}
        self.inv_vocab = {idx: token for token, idx in self.vocab.items()}
        # Initialize merges dictionary for BPE
        self.merges = {}
        # Regex for basic preprocessing
        # Handles whitespace, punctuation, and keeps numbers together
        self.pre_tokenize_pat = re.compile(
            r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
        )
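
    # Note: with the default special tokens the initial vocabulary is
    # {"<s>": 0, "</s>": 1, "<unk>": 2}; ids for characters and merged
    # subwords learned in train() are appended after these.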

    def pre_tokenize(self, text: str) -> List[str]:
        """
        Initial tokenization of raw text into word-level tokens
        """
        return [t for t in re.findall(self.pre_tokenize_pat, text) if t.strip()]
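
    # Illustrative example: the pattern above is close to the GPT-2-style
    # pre-tokenizer, so something like
    #   self.pre_tokenize("Hello world, it's 2024!")
    # should yield roughly
    #   ["Hello", " world", ",", " it", "'s", " 2024", "!"]
    # with leading spaces kept on the pre-tokens.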

    def compute_token_scores(self, token_freqs: Dict[str, int]) -> Dict[str, float]:
        """
        Compute language model scores for tokens using unigram probability
        """
        total_count = sum(token_freqs.values())
        scores = {}
        for token, freq in token_freqs.items():
            # Basic unigram probability with smoothing
            prob = (freq + self.alpha) / (total_count + self.alpha * len(token_freqs))
            scores[token] = math.log(prob)
        return scores
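
    # Worked example of the smoothing above (values rounded): with alpha = 0.1
    # and token_freqs = {"a": 3, "b": 1}, total_count = 4, so
    #   prob("a") = (3 + 0.1) / (4 + 0.1 * 2) ≈ 0.738  ->  score ≈ -0.303
    #   prob("b") = (1 + 0.1) / (4 + 0.1 * 2) ≈ 0.262  ->  score ≈ -1.340
    # This helper is not called by train()/encode(); a small usage sketch
    # appears at the end of the file.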

    def train(self, text: str, min_freq: int = 2):
        """
        Train the tokenizer on input text

        Args:
            text: Input text for training
            min_freq: Minimum frequency for considering a merge
        """
        # Pre-tokenize text into words using the regex
        words = self.pre_tokenize(text)
        # Initialize character-level vocabulary
        char_freqs = collections.Counter("".join(words))
        base_vocab = {
            c: i + len(self.special_tokens)
            for i, (c, _) in enumerate(char_freqs.items())
        }
        self.vocab.update(base_vocab)
        self.inv_vocab.update({i: c for c, i in base_vocab.items()})
        # Convert words to character sequences
        sequences = [[c for c in word] for word in words]
        # Track current vocabulary size
        curr_vocab_size = len(self.vocab)
        while curr_vocab_size < self.vocab_size:
            # Find most frequent pairs
            pair_freqs = collections.defaultdict(int)
            for seq in sequences:
                if len(seq) < 2:
                    continue
                for i in range(len(seq) - 1):
                    pair = (seq[i], seq[i + 1])
                    pair_freqs[pair] += 1
            # Find best pair to merge -- the pair that occurs most often
            if not pair_freqs:
                break
            best_pair = max(pair_freqs.items(), key=lambda x: x[1])
            if best_pair[1] < min_freq:
                break
            # Create new token and add to vocabulary
            new_token = "".join(best_pair[0])
            self.vocab[new_token] = curr_vocab_size
            self.inv_vocab[curr_vocab_size] = new_token
            self.merges[best_pair[0]] = curr_vocab_size
            # Update sequences with merged pairs
            new_sequences = []
            for seq in sequences:
                new_seq = []
                i = 0
                while i < len(seq):
                    if i < len(seq) - 1 and (seq[i], seq[i + 1]) == best_pair[0]:
                        new_seq.append(new_token)
                        i += 2
                    else:
                        new_seq.append(seq[i])
                        i += 1
                new_sequences.append(new_seq)
            sequences = new_sequences
            curr_vocab_size += 1
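
    # Rough trace of the merge loop above, assuming the tiny corpus "low low lower":
    # pre-tokens ["low", " low", " lower"] become the char sequences
    # ['l','o','w'], [' ','l','o','w'], [' ','l','o','w','e','r'].
    # Pair counts: ('l','o'): 3, ('o','w'): 3, (' ','l'): 2, ('w','e'): 1, ('e','r'): 1.
    # max() keeps the first pair with the highest count, so ('l','o') is merged
    # into "lo", then ('lo','w') into "low" on the next pass, and so on until
    # vocab_size is reached or no remaining pair meets min_freq.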

    def encode(self, text: str, sample: bool = False) -> List[int]:
        """
        Encode text to token ids

        Args:
            text: Text to encode
            sample: Whether to use subword regularization
        Returns:
            List of token ids
        """
        if not text:
            return []
        # Pre-tokenize
        words = self.pre_tokenize(text)
        # Initialize with character-level tokens
        sequences = [[c for c in word] for word in words]
        # Apply merges
        for seq in sequences:
            i = 0
            while i < len(seq) - 1:
                current_pair = (seq[i], seq[i + 1])
                if current_pair in self.merges:
                    # If sampling enabled, probabilistically skip some merges
                    if sample and random.random() < self.alpha:
                        i += 1
                        continue
                    new_token = "".join(current_pair)
                    seq[i : i + 2] = [new_token]
                else:
                    i += 1
        # Flatten and convert to ids
        tokens = []
        for seq in sequences:
            for token in seq:
                if token in self.vocab:
                    tokens.append(self.vocab[token])
                else:
                    tokens.append(self.vocab["<unk>"])
        return tokens
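
    # Note on the loop above: merges are applied greedily left to right in a
    # single pass, rather than replayed in the priority order they were learned
    # (a simplification relative to standard BPE encoding). With sample=True,
    # each applicable merge is skipped with probability alpha, so repeated
    # encodings of the same text can differ -- similar in spirit to
    # BPE-dropout-style subword regularization.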

    def decode(self, ids: List[int]) -> str:
        """
        Decode token ids back to text

        Args:
            ids: List of token ids
        Returns:
            Decoded text
        """
        tokens = []
        for idx in ids:
            if idx in self.inv_vocab:
                tokens.append(self.inv_vocab[idx])
            else:
                tokens.append(self.inv_vocab[self.vocab["<unk>"]])
        return "".join(tokens)


# Usage Example
if __name__ == "__main__":
    # Sample text for testing
    # text = """SentencePiece is an unsupervised text tokenizer and detokenizer.
    # It implements subword units like BPE and unigram language model with the
    # extension of direct training from raw sentences."""

    # Read the tokenizer training corpus as a single string
    # (train()/encode() expect a str, not the list returned by readlines())
    with open("input_text.txt", "r") as ip_file:
        text = ip_file.read()

    # Initialize and train tokenizer
    tokenizer = SentencePieceTokenizer(vocab_size=100)
    tokenizer.train(text)

    # Test encoding and decoding
    encoded = tokenizer.encode(text)
    decoded = tokenizer.decode(encoded)
    print(f"Original text: {text}")
    print(f"Encoded: {encoded[:10]}...")
    print(f"Decoded text: {decoded}")

    # Test with subword regularization
    encoded_sampled = tokenizer.encode(text, sample=True)
    print(f"Encoded (with sampling): {encoded_sampled[:10]}...")