Skip to content

Commit 067ac62

Browse files
committed
Always move HF tokenizer encodings to the target device
1 parent c1463b0 commit 067ac62

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

src/lemonade/tools/huggingface_load.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from turnkeyml.state import State
77
import turnkeyml.common.status as status
88
from turnkeyml.tools import Tool, FirstTool
9-
from lemonade.tools.adapter import ModelAdapter
9+
from lemonade.tools.adapter import ModelAdapter, TokenizerAdapter
1010
from lemonade.cache import Keys
1111

1212
# Command line interfaces for tools will use string inputs for data
@@ -32,6 +32,19 @@ def make_example_inputs(state: State) -> Dict:
3232
return {"input_ids": inputs_ids}
3333

3434

35+
class HuggingfaceTokenizerAdapter(TokenizerAdapter):
    """
    Adapter around a Hugging Face tokenizer that moves every encoding it
    produces to a fixed target device before handing it back to the caller.
    """

    def __init__(self, tokenizer: transformers.AutoTokenizer, device: str):
        """
        Store the wrapped tokenizer and the device that encodings are moved to.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.device = device

    def __call__(self, prompt, **kwargs):
        """
        Tokenize `prompt`, forwarding any keyword arguments to the wrapped
        tokenizer, and return the result after calling `.to(self.device)` on it.
        """
        encoding = self.tokenizer(prompt, **kwargs)
        return encoding.to(self.device)

    def decode(self, response, **kwargs):
        """
        Delegate decoding of `response` to the wrapped tokenizer, forwarding
        any keyword arguments unchanged.
        """
        decoded = self.tokenizer.decode(response, **kwargs)
        return decoded
46+
47+
3548
class HuggingfaceLoad(FirstTool):
3649
"""
3750
Load an LLM as a torch.nn.Module using the Hugging Face transformers
@@ -167,7 +180,7 @@ def run(
167180

168181
# Pass the model and inputs into state
169182
state.model = model
170-
state.tokenizer = tokenizer
183+
state.tokenizer = HuggingfaceTokenizerAdapter(tokenizer, device)
171184
state.dtype = dtype
172185
state.checkpoint = checkpoint
173186
state.device = device

0 commit comments

Comments
 (0)