diff --git a/src/lemonade/tools/huggingface_load.py b/src/lemonade/tools/huggingface_load.py index 8cbec6e..617a4b3 100644 --- a/src/lemonade/tools/huggingface_load.py +++ b/src/lemonade/tools/huggingface_load.py @@ -6,7 +6,7 @@ from turnkeyml.state import State import turnkeyml.common.status as status from turnkeyml.tools import Tool, FirstTool -from lemonade.tools.adapter import ModelAdapter +from lemonade.tools.adapter import ModelAdapter, TokenizerAdapter from lemonade.cache import Keys # Command line interfaces for tools will use string inputs for data @@ -32,6 +32,26 @@ def make_example_inputs(state: State) -> Dict: return {"input_ids": inputs_ids} +class HuggingfaceTokenizerAdapter(TokenizerAdapter): + def __init__(self, tokenizer: transformers.AutoTokenizer, device: str): + super().__init__() + self.tokenizer = tokenizer + self.device = device + + def __call__(self, prompt, **kwargs): + return self.tokenizer(prompt, **kwargs).to(self.device) + + def decode(self, response, **kwargs): + return self.tokenizer.decode(response, **kwargs) + + def batch_decode(self, tokens, **kwargs): + return self.tokenizer.batch_decode(tokens, **kwargs) + + @property + def eos_token_id(self): + return self.tokenizer.eos_token_id + + class HuggingfaceLoad(FirstTool): """ Load an LLM as a torch.nn.Module using the Hugging Face transformers @@ -167,7 +187,7 @@ def run( # Pass the model and inputs into state state.model = model - state.tokenizer = tokenizer + state.tokenizer = HuggingfaceTokenizerAdapter(tokenizer, device) state.dtype = dtype state.checkpoint = checkpoint state.device = device