Skip to content

Commit 89a8403

Browse files
ringohoffman authored and jainapurva committed
Refactor tiktoken import bare except (#1024)
* Refactor tiktoken import bare except

We shouldn't suppress the import error if tiktoken is not installed. Instead of importing it at the top of the file, we can import it inside the only functions that use these imports; we can consider a different import structure if we end up needing to access these modules in more places. This lazy importing should also decrease the load time of these modules. Also do the same thing for sentencepiece.

* Remove __future__.annotations import
1 parent 047e656 commit 89a8403

File tree

1 file changed

+7
-9
lines changed

1 file changed

+7
-9
lines changed

torchao/_models/llama/tokenizer.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,7 @@
11
# copied from https://github.com/pytorch-labs/gpt-fast/blob/main/tokenizer.py
22

33
import os
4-
import sentencepiece as spm
5-
try:
6-
import tiktoken
7-
from tiktoken.load import load_tiktoken_bpe
8-
except:
9-
pass
104
from pathlib import Path
11-
from typing import Dict
125

136
class TokenizerInterface:
147
def __init__(self, model_path):
@@ -28,6 +21,8 @@ def eos_id(self):
2821

2922
class SentencePieceWrapper(TokenizerInterface):
3023
def __init__(self, model_path):
24+
import sentencepiece as spm
25+
3126
super().__init__(model_path)
3227
self.processor = spm.SentencePieceProcessor(str(model_path))
3328
self.bos_token_id = self.bos_id()
@@ -50,16 +45,19 @@ class TiktokenWrapper(TokenizerInterface):
5045
Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
5146
"""
5247

53-
special_tokens: Dict[str, int]
48+
special_tokens: dict[str, int]
5449

5550
num_reserved_special_tokens = 256
5651

5752
pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501
5853

5954
def __init__(self, model_path):
55+
import tiktoken
56+
import tiktoken.load
57+
6058
super().__init__(model_path)
6159
assert os.path.isfile(model_path), str(model_path)
62-
mergeable_ranks = load_tiktoken_bpe(str(model_path))
60+
mergeable_ranks = tiktoken.load.load_tiktoken_bpe(str(model_path))
6361
num_base_tokens = len(mergeable_ranks)
6462
special_tokens = [
6563
"<|begin_of_text|>",

0 commit comments

Comments (0)