diff --git a/.github/workflows/test_lemonade.yml b/.github/workflows/test_lemonade.yml
index 89b53cdc..de79d40c 100644
--- a/.github/workflows/test_lemonade.yml
+++ b/.github/workflows/test_lemonade.yml
@@ -32,6 +32,11 @@ jobs:
           conda install pylint
           python -m pip check
           pip install -e .[llm]
+      - name: Lint with Black
+        uses: psf/black@stable
+        with:
+          options: "--check --verbose"
+          src: "./src"
       - name: Lint with PyLint
         shell: bash -el {0}
         run: |
diff --git a/src/turnkeyml/common/build.py b/src/turnkeyml/common/build.py
index 57b8ddd4..b90feb8e 100644
--- a/src/turnkeyml/common/build.py
+++ b/src/turnkeyml/common/build.py
@@ -282,14 +282,16 @@ def get_wmic_info(command):
         try:
             output = subprocess.check_output(command, shell=True).decode()
             return output.split("\n")[1].strip()
-        except Exception as e: # pylint: disable=broad-except
+        except Exception as e:  # pylint: disable=broad-except
             return str(e)
 
     if os_type == "Windows":
         if shutil.which("wmic") is not None:
             info_dict["Processor"] = get_wmic_info("wmic cpu get name")
             info_dict["OEM System"] = get_wmic_info("wmic computersystem get model")
-            mem_info_bytes = get_wmic_info("wmic computersystem get TotalPhysicalMemory")
+            mem_info_bytes = get_wmic_info(
+                "wmic computersystem get TotalPhysicalMemory"
+            )
             try:
                 mem_info_gb = round(int(mem_info_bytes) / (1024**3), 2)
                 info_dict["Physical Memory"] = f"{mem_info_gb} GB"
diff --git a/src/turnkeyml/llm/cli.py b/src/turnkeyml/llm/cli.py
index e7ce69f3..3ab89c12 100644
--- a/src/turnkeyml/llm/cli.py
+++ b/src/turnkeyml/llm/cli.py
@@ -54,10 +54,6 @@ def main():
     except ModuleNotFoundError:
         pass
 
-
-
-
-
     # Define the argument parser
     parser = cli.CustomArgumentParser(
         description="Turnkey analysis and benchmarking of GenAI models. "
diff --git a/src/turnkeyml/llm/tools/chat.py b/src/turnkeyml/llm/tools/chat.py
index 21491038..8c8ee94f 100644
--- a/src/turnkeyml/llm/tools/chat.py
+++ b/src/turnkeyml/llm/tools/chat.py
@@ -1,15 +1,27 @@
 import argparse
-from threading import Thread
+import time
+import statistics
+from threading import Thread, Event
 import asyncio
 from fastapi import FastAPI, WebSocket
 from fastapi.responses import HTMLResponse
+from starlette.websockets import WebSocketDisconnect
 from pydantic import BaseModel
-from transformers import TextIteratorStreamer
+from transformers import TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList
 import uvicorn
 from turnkeyml.state import State
 from turnkeyml.tools import Tool
 from turnkeyml.llm.tools.adapter import ModelAdapter, TokenizerAdapter
 
+DEFAULT_GENERATE_PARAMS = {
+    "do_sample": True,
+    "top_k": 50,
+    "top_p": 0.95,
+    "temperature": 0.7,
+}
+
+DEFAULT_SERVER_PORT = 8000
+
 
 class LLMPrompt(Tool):
     """
@@ -61,7 +73,9 @@ def run(
         tokenizer: TokenizerAdapter = state.tokenizer
         input_ids = tokenizer(prompt, return_tensors="pt").input_ids
 
-        response = model.generate(input_ids, max_new_tokens=max_new_tokens)
+        response = model.generate(
+            input_ids, max_new_tokens=max_new_tokens, **DEFAULT_GENERATE_PARAMS
+        )
         response_text = tokenizer.decode(response[0], skip_special_tokens=True).strip()
 
         state.response = response_text
@@ -70,16 +84,32 @@ def run(
         return state
 
 
+# Custom huggingface-style stopping criteria to allow
+# us to halt streaming in-progress generations
+class StopOnEvent(StoppingCriteria):
+    def __init__(self, stop_event: Event):
+        super().__init__()
+        self.stop_event = stop_event
+
+    def __call__(self, input_ids, scores, **kwargs):
+        return self.stop_event.is_set()
+
+
 class Serve(Tool):
     """
     Open a web server that apps can use to communicate with the LLM.
 
-    There are two ways interact with the server:
+    There are two ways to perform generations with the server:
     - Send an http request to "http://localhost:8000/generate" and
         receive back a response with the complete prompt.
     - Open a WebSocket with "ws://localhost:8000" and receive a
        streaming response to the prompt.
 
+    The server also exposes these helpful endpoints:
+    - /health: check whether a model is loaded and ready to serve.
+    - /stats: performance statistics for the generation.
+    - /halt: stop an in-progress generation from making more tokens.
+
     The WebSocket functionality is demonstrated by the webpage served at
     http://localhost:8000, which you can visit with a web browser after
     opening the server.
@@ -89,6 +119,7 @@ class Serve(Tool):
         huggingface TextIteratorStreamer.
       - state.tokenizer: tokenizer instance used to generate inputs for the
         model. Must be compatible with the huggingface TextIteratorStreamer.
+      - state.checkpoint: name of the checkpoint used to load state.model.
 
     Output state produced: None
     """
@@ -102,6 +133,17 @@ def __init__(self):
             enable_logger=False,
         )
 
+        # Performance stats that are set during /ws and can be
+        # fetched in /stats
+        self.time_to_first_token = None
+        self.tokens_per_second = None
+        self.input_tokens = None
+        self.output_tokens = None
+        self.decode_token_times = None
+
+        # Flag that tells the LLM to stop generating text and end the response
+        self.stop_event = Event()
+
     @staticmethod
     def parser(add_help: bool = True) -> argparse.ArgumentParser:
         parser = __class__.helpful_parser(
@@ -151,10 +193,15 @@ class Message(BaseModel):
+
+
+

+

@@ -188,11 +265,8 @@ async def generate_response(message: Message):
             response = model.generate(
                 input_ids,
                 max_new_tokens=max_new_tokens,
-                do_sample=True,
-                top_k=50,
-                top_p=0.95,
-                temperature=0.7,
                 pad_token_id=tokenizer.eos_token_id,
+                **DEFAULT_GENERATE_PARAMS,
             )
 
             generated_text = tokenizer.decode(response[0], skip_special_tokens=True)
@@ -203,13 +277,23 @@
 
         @app.websocket("/ws")
         async def stream_response(websocket: WebSocket):
+            """
+            Receive a prompt string, and then stream the response back
+            over a websocket.
+            """
+
             await websocket.accept()
             while True:
-                message = await websocket.receive_text()
-
-                if message == "done":
+                try:
+                    message = await websocket.receive_text()
+                except WebSocketDisconnect:
+                    print("Client closed connection")
                     break
+
+                # Reset the early-exit flag before we start each generation
+                self.stop_event.clear()
+
                 input_ids = tokenizer(message, return_tensors="pt").input_ids
 
                 # Set up the generation parameters
@@ -219,39 +303,109 @@ async def stream_response(websocket: WebSocket):
 
                     streamer = oga.OrtGenaiStreamer(tokenizer)
 
+                    self.input_tokens = len(input_ids)
+
                 else:
                     # Huggingface-like models
                     streamer = TextIteratorStreamer(
                         tokenizer,
                         skip_prompt=True,
                     )
+
+                    self.input_tokens = len(input_ids[0])
+
+                # Enable sending a signal into the generator thread to stop
+                # the generation early
+                stopping_criteria = StoppingCriteriaList([StopOnEvent(self.stop_event)])
+
                 generation_kwargs = {
                     "input_ids": input_ids,
                     "streamer": streamer,
                     "max_new_tokens": max_new_tokens,
-                    "do_sample": True,
-                    "top_k": 50,
-                    "top_p": 0.95,
-                    "temperature": 0.7,
                     "pad_token_id": tokenizer.eos_token_id,
+                    "stopping_criteria": stopping_criteria,
+                    **DEFAULT_GENERATE_PARAMS,
                 }
 
+                # Initialize performance variables
+                generation_start_time = time.perf_counter()
+                first_token = True
+                self.decode_token_times = []
+                self.output_tokens = 0
+
+                # Begin generation
                 thread = Thread(target=model.generate, kwargs=generation_kwargs)
                 thread.start()
 
                 # Generate the response using streaming
                 for new_text in streamer:
+
+                    # Capture performance stats about this token
+                    self.output_tokens = self.output_tokens + 1
+                    if first_token:
+                        self.time_to_first_token = (
+                            time.perf_counter() - generation_start_time
+                        )
+                        first_token = False
+                    else:
+                        self.decode_token_times.append(
+                            time.perf_counter() - next_token_start_time
+                        )
+                    next_token_start_time = time.perf_counter()
+
+                    # Print the decoded value to the terminal for debugging purposes
                     print(new_text, end="", flush=True)
 
                     # Send the generated text to the client
-                    await asyncio.sleep(0.1)  # Add a small delay (adjust as needed)
+                    await asyncio.sleep(0.001)  # Add a small delay (adjust as needed)
                     await websocket.send_text(new_text)
 
+                    # Allow the user to finish the response early
+                    if self.stop_event.is_set():
+                        print("Stopping generation early.")
+                        break
+
+                self.tokens_per_second = 1 / statistics.mean(self.decode_token_times)
                 print("\n")
                 thread.join()
 
-            await websocket.close()
-
-        uvicorn.run(app, host="localhost", port=8000)
+        @app.get("/stats")
+        async def send_stats():
+            """
+            Send performance statistics to the client.
+            """
+            return {
+                "time_to_first_token": self.time_to_first_token,
+                "tokens_per_second": self.tokens_per_second,
+                "input_tokens": self.input_tokens,
+                "output_tokens": self.output_tokens,
+                "decode_token_times": self.decode_token_times,
+            }
+
+        @app.get("/halt")
+        async def halt_generation():
+            """
+            Allow the client to halt an in-progress generation.
+            """
+
+            self.stop_event.set()
+
+            return {
+                "terminated": True,
+            }
+
+        @app.get("/health")
+        async def health():
+            """
+            Report server health information to the client.
+            """
+
+            self.stop_event.set()
+
+            return {
+                "model_loaded": state.checkpoint,
+            }
+
+        uvicorn.run(app, host="localhost", port=DEFAULT_SERVER_PORT)
 
         return state
diff --git a/src/turnkeyml/llm/tools/huggingface_load.py b/src/turnkeyml/llm/tools/huggingface_load.py
index 789702d8..ba46f214 100644
--- a/src/turnkeyml/llm/tools/huggingface_load.py
+++ b/src/turnkeyml/llm/tools/huggingface_load.py
@@ -201,8 +201,15 @@ def __init__(self, model, dtype=torch.float32, device="cpu"):
         self.dtype = dtype
         self.device = device
 
-    def generate(self, input_ids, max_new_tokens=512, repetition_penalty=1.2,
-                 do_sample=True, temperature=0.1, **kwargs):
+    def generate(
+        self,
+        input_ids,
+        max_new_tokens=512,
+        repetition_penalty=1.2,
+        do_sample=True,
+        temperature=0.1,
+        **kwargs,
+    ):
         amp_enabled = (
             True
             if (self.dtype == torch.float16 or self.dtype == torch.bfloat16)
@@ -221,7 +228,7 @@ def generate(self, input_ids, max_new_tokens=512, repetition_penalty=1.2,
                 repetition_penalty=repetition_penalty,
                 do_sample=do_sample,
                 temperature=temperature,
-                **kwargs
+                **kwargs,
             )
 
diff --git a/src/turnkeyml/llm/tools/llamacpp.py b/src/turnkeyml/llm/tools/llamacpp.py
index 21a42c1d..2e9ca761 100644
--- a/src/turnkeyml/llm/tools/llamacpp.py
+++ b/src/turnkeyml/llm/tools/llamacpp.py
@@ -13,6 +13,7 @@ def llamacpp_dir(state: State):
     return os.path.join(build.output_dir(state.cache_dir, state.build_name), "llamacpp")
 
 
+
 class LlamaCppAdapter(ModelAdapter):
     unique_name = "llama-cpp-adapter"
 
@@ -45,7 +46,7 @@ def generate(self, input_ids: str, max_new_tokens: Optional[int] = None):
             "threads": self.threads,
             "model": self.model,
             "prompt": input_ids,
-            "temp": self.temp
+            "temp": self.temp,
         }
 
         for flag, value in optional_params.items():
@@ -61,11 +62,12 @@ def generate(self, input_ids: str, max_new_tokens: Optional[int] = None):
             universal_newlines=True,
         )
 
-        raw_output, raw_err= process.communicate()
+        raw_output, raw_err = process.communicate()
 
         if process.returncode != 0:
             raise subprocess.CalledProcessError(
-                process.returncode, process.args, raw_output, raw_err)
+                process.returncode, process.args, raw_output, raw_err
+            )
 
         prompt_found = False
         output_text = ""
@@ -82,6 +84,7 @@ def generate(self, input_ids: str, max_new_tokens: Optional[int] = None):
 
         return [output_text]
 
+
 class LoadLlamaCpp(FirstTool):
     unique_name = "load-llama-cpp"
 
@@ -156,7 +159,7 @@ def run(
         if executable is None:
             raise Exception(f"{self.__class__.unique_name} requires an executable")
 
-        if (input is not None and input != ""):
+        if input is not None and input != "":
             model_binary = input
 
         # Save execution parameters
@@ -171,7 +174,7 @@ def run(
         )
 
         state.model = LlamaCppAdapter(
-            executable = executable,
+            executable=executable,
             model=model_binary,
             tool_dir=llamacpp_dir(state),
             context_size=context_size,
diff --git a/src/turnkeyml/llm/tools/mmlu.py b/src/turnkeyml/llm/tools/mmlu.py
index 828e2527..26946d11 100644
--- a/src/turnkeyml/llm/tools/mmlu.py
+++ b/src/turnkeyml/llm/tools/mmlu.py
@@ -132,9 +132,11 @@ def run(
                     "Subject": subject,
                     "Accuracy": acc,
                     "Total Questions": len(test_df),
-                    "Evaluated Questions": (max_evals
-                                            if max_evals is not None and max_evals < len(test_df)
-                                            else len(test_df)),
+                    "Evaluated Questions": (
+                        max_evals
+                        if max_evals is not None and max_evals < len(test_df)
+                        else len(test_df)
+                    ),
                     "Correct Answers": correct_answers_count,
                 }
             )
@@ -281,7 +283,7 @@ def _eval_model(ntrain, max_evals, subject, model, tokenizer, dev_df, test_df):
                 "Correct": pred_label == label,
             }
         )
-        if (max_evals is not None and i >= max_evals -1):
+        if max_evals is not None and i >= max_evals - 1:
             break
 
     acc = np.mean([res["Correct"] for res in detailed_results])
diff --git a/src/turnkeyml/llm/tools/ort_genai/oga.py b/src/turnkeyml/llm/tools/ort_genai/oga.py
index b7a00695..510dfb85 100644
--- a/src/turnkeyml/llm/tools/ort_genai/oga.py
+++ b/src/turnkeyml/llm/tools/ort_genai/oga.py
@@ -20,6 +20,7 @@
 )
 from turnkeyml.llm.cache import Keys
 
+
 class OrtGenaiTokenizer(TokenizerAdapter):
     def __init__(self, model: og.Model):
         # Initialize the tokenizer and produce the initial tokens.
@@ -91,6 +92,7 @@ def generate(
         temperature=0.7,
         streamer: OrtGenaiStreamer = None,
         pad_token_id=None,
+        stopping_criteria=None,
     ):
         params = og.GeneratorParams(self.model)
 
@@ -164,7 +166,10 @@ def generate(
             return [generator.get_sequence(0)]
         else:
             tokenizer_stream = streamer.tokenizer.tokenizer.create_stream()
-            while not generator.is_done():
+
+            stop_early = False
+
+            while not generator.is_done() and not stop_early:
                 generator.compute_logits()
                 generator.generate_next_token()
 
@@ -173,6 +178,10 @@ def generate(
 
                 streamer.add_text(new_text)
 
+                if stopping_criteria is not None:
+                    if stopping_criteria[0].stop_event.is_set():
+                        stop_early = True
+
             streamer.add_text("")
             streamer.done()
 
diff --git a/src/turnkeyml/version.py b/src/turnkeyml/version.py
index c7a18d13..4b56dfc5 100644
--- a/src/turnkeyml/version.py
+++ b/src/turnkeyml/version.py
@@ -1 +1 @@
-__version__ = "4.0.3"
+__version__ = "4.0.4"
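
For reference, the endpoints that this patch adds to the Serve tool can be exercised with a small client script. The sketch below is illustrative only and is not part of the patch: it assumes the server is running on the default port (8000), uses the third-party websockets and requests packages on the client side, and treats a 10-second silence on the socket as the end of a stream, since the server does not send an explicit end-of-generation marker.

# Example client for the Serve tool above (illustration only, not part of the patch).
import asyncio

import requests
import websockets


async def stream_prompt(prompt: str, port: int = 8000) -> None:
    """Send a prompt over the /ws websocket and print the streamed tokens."""
    async with websockets.connect(f"ws://localhost:{port}/ws") as ws:
        await ws.send(prompt)
        while True:
            try:
                # The server streams text fragments as they are produced; assume
                # the generation is finished if nothing arrives for 10 seconds.
                token = await asyncio.wait_for(ws.recv(), timeout=10)
            except asyncio.TimeoutError:
                break
            print(token, end="", flush=True)
    print()


if __name__ == "__main__":
    asyncio.run(stream_prompt("Hello, my name is"))

    # Performance statistics recorded by the server during the stream
    print(requests.get("http://localhost:8000/stats", timeout=10).json())

    # Confirm that a model checkpoint is loaded and ready
    print(requests.get("http://localhost:8000/health", timeout=10).json())

An in-progress stream can also be cut short from another process with requests.get("http://localhost:8000/halt"), which sets the stop event checked by the StopOnEvent stopping criteria introduced in this patch.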