6
6
from turnkeyml .state import State
7
7
import turnkeyml .common .status as status
8
8
from turnkeyml .tools import Tool , FirstTool
9
- from lemonade .tools .adapter import ModelAdapter
9
+ from lemonade .tools .adapter import ModelAdapter , TokenizerAdapter
10
10
from lemonade .cache import Keys
11
11
12
12
# Command line interfaces for tools will use string inputs for data
@@ -32,6 +32,19 @@ def make_example_inputs(state: State) -> Dict:
32
32
return {"input_ids" : inputs_ids }
33
33
34
34
35
+ class HuggingfaceTokenizerAdapter (TokenizerAdapter ):
36
+ def __init__ (self , tokenizer : transformers .AutoTokenizer , device : str ):
37
+ super ().__init__ ()
38
+ self .tokenizer = tokenizer
39
+ self .device = device
40
+
41
+ def __call__ (self , prompt , ** kwargs ):
42
+ return self .tokenizer (prompt , ** kwargs ).to (self .device )
43
+
44
+ def decode (self , response , ** kwargs ):
45
+ return self .tokenizer .decode (response , ** kwargs )
46
+
47
+
35
48
class HuggingfaceLoad (FirstTool ):
36
49
"""
37
50
Load an LLM as a torch.nn.Module using the Hugging Face transformers
@@ -167,7 +180,7 @@ def run(
167
180
168
181
# Pass the model and inputs into state
169
182
state .model = model
170
- state .tokenizer = tokenizer
183
+ state .tokenizer = HuggingfaceTokenizerAdapter ( tokenizer , device )
171
184
state .dtype = dtype
172
185
state .checkpoint = checkpoint
173
186
state .device = device
0 commit comments