diff --git a/.github/workflows/test_lemonade.yml b/.github/workflows/test_lemonade.yml
index 89b53cdc..de79d40c 100644
--- a/.github/workflows/test_lemonade.yml
+++ b/.github/workflows/test_lemonade.yml
@@ -32,6 +32,11 @@ jobs:
           conda install pylint
           python -m pip check
           pip install -e .[llm]
+      - name: Lint with Black
+        uses: psf/black@stable
+        with:
+          options: "--check --verbose"
+          src: "./src"
       - name: Lint with PyLint
         shell: bash -el {0}
         run: |
diff --git a/src/turnkeyml/common/build.py b/src/turnkeyml/common/build.py
index 57b8ddd4..b90feb8e 100644
--- a/src/turnkeyml/common/build.py
+++ b/src/turnkeyml/common/build.py
@@ -282,14 +282,16 @@ def get_wmic_info(command):
         try:
             output = subprocess.check_output(command, shell=True).decode()
             return output.split("\n")[1].strip()
-        except Exception as e: # pylint: disable=broad-except
+        except Exception as e:  # pylint: disable=broad-except
             return str(e)
 
     if os_type == "Windows":
         if shutil.which("wmic") is not None:
             info_dict["Processor"] = get_wmic_info("wmic cpu get name")
             info_dict["OEM System"] = get_wmic_info("wmic computersystem get model")
-            mem_info_bytes = get_wmic_info("wmic computersystem get TotalPhysicalMemory")
+            mem_info_bytes = get_wmic_info(
+                "wmic computersystem get TotalPhysicalMemory"
+            )
             try:
                 mem_info_gb = round(int(mem_info_bytes) / (1024**3), 2)
                 info_dict["Physical Memory"] = f"{mem_info_gb} GB"
diff --git a/src/turnkeyml/llm/cli.py b/src/turnkeyml/llm/cli.py
index e7ce69f3..3ab89c12 100644
--- a/src/turnkeyml/llm/cli.py
+++ b/src/turnkeyml/llm/cli.py
@@ -54,10 +54,6 @@ def main():
     except ModuleNotFoundError:
         pass
 
-
-
-
-
     # Define the argument parser
     parser = cli.CustomArgumentParser(
         description="Turnkey analysis and benchmarking of GenAI models. "
diff --git a/src/turnkeyml/llm/tools/chat.py b/src/turnkeyml/llm/tools/chat.py
index 21491038..8c8ee94f 100644
--- a/src/turnkeyml/llm/tools/chat.py
+++ b/src/turnkeyml/llm/tools/chat.py
@@ -1,15 +1,27 @@
 import argparse
-from threading import Thread
+import time
+import statistics
+from threading import Thread, Event
 import asyncio
 from fastapi import FastAPI, WebSocket
 from fastapi.responses import HTMLResponse
+from starlette.websockets import WebSocketDisconnect
 from pydantic import BaseModel
-from transformers import TextIteratorStreamer
+from transformers import TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList
 import uvicorn
 from turnkeyml.state import State
 from turnkeyml.tools import Tool
 from turnkeyml.llm.tools.adapter import ModelAdapter, TokenizerAdapter
 
+DEFAULT_GENERATE_PARAMS = {
+    "do_sample": True,
+    "top_k": 50,
+    "top_p": 0.95,
+    "temperature": 0.7,
+}
+
+DEFAULT_SERVER_PORT = 8000
+
 
 class LLMPrompt(Tool):
     """
@@ -61,7 +73,9 @@ def run(
         tokenizer: TokenizerAdapter = state.tokenizer
         input_ids = tokenizer(prompt, return_tensors="pt").input_ids
 
-        response = model.generate(input_ids, max_new_tokens=max_new_tokens)
+        response = model.generate(
+            input_ids, max_new_tokens=max_new_tokens, **DEFAULT_GENERATE_PARAMS
+        )
         response_text = tokenizer.decode(response[0], skip_special_tokens=True).strip()
 
         state.response = response_text
@@ -70,16 +84,32 @@
         return state
 
 
+# Custom huggingface-style stopping criteria to allow
+# us to halt streaming in-progress generations
+class StopOnEvent(StoppingCriteria):
+    def __init__(self, stop_event: Event):
+        super().__init__()
+        self.stop_event = stop_event
+
+    def __call__(self, input_ids, scores, **kwargs):
+        return self.stop_event.is_set()
+
+
 class Serve(Tool):
     """
     Open a web server that apps can use to communicate with the LLM.
 
-    There are two ways interact with the server:
+    There are two ways to perform generations with the server:
     - Send an http request to "http://localhost:8000/generate" and
         receive back a response with the complete prompt.
     - Open a WebSocket with "ws://localhost:8000" and receive a
         streaming response to the prompt.
 
+    The server also exposes these helpful endpoints:
+    - /health: check whether a model is loaded and ready to serve.
+    - /stats: performance statistics for the generation.
+    - /halt: stop an in-progress generation from making more tokens.
+
     The WebSocket functionality is demonstrated by the webpage served at
     http://localhost:8000, which you can visit with a web browser after
     opening the server.
@@ -89,6 +119,7 @@ class Serve(Tool):
             huggingface TextIteratorStreamer.
         - state.tokenizer: tokenizer instance used to generate inputs for the
             model. Must be compatible with the huggingface TextIteratorStreamer.
+        - state.checkpoint: name of the checkpoint used to load state.model.
 
     Output state produced: None
     """
@@ -102,6 +133,17 @@ def __init__(self):
             enable_logger=False,
         )
 
+        # Performance stats that are set during /ws and can be
+        # fetched in /stats
+        self.time_to_first_token = None
+        self.tokens_per_second = None
+        self.input_tokens = None
+        self.output_tokens = None
+        self.decode_token_times = None
+
+        # Flag that tells the LLM to stop generating text and end the response
+        self.stop_event = Event()
+
     @staticmethod
     def parser(add_help: bool = True) -> argparse.ArgumentParser:
         parser = __class__.helpful_parser(
@@ -151,10 +193,15 @@ class Message(BaseModel):
+
+
+
+
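
Editor's note, not part of the patch: the chat.py changes above introduce three cooperating pieces (DEFAULT_GENERATE_PARAMS, the StopOnEvent stopping criteria, and the existing TextIteratorStreamer import) that the Serve tool wires into its FastAPI endpoints. The following is a minimal standalone sketch of how those pieces fit together to stream tokens and halt a generation early, assuming an ordinary Hugging Face model and tokenizer; the checkpoint name and prompt are placeholders, and the real tool uses state.model/state.tokenizer rather than loading a model directly.

# Sketch only: mirrors the additions in chat.py with a plain console loop.
from threading import Event, Thread

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

# Same defaults the patch adds to chat.py
DEFAULT_GENERATE_PARAMS = {
    "do_sample": True,
    "top_k": 50,
    "top_p": 0.95,
    "temperature": 0.7,
}


class StopOnEvent(StoppingCriteria):
    """Halts generation as soon as the shared event is set (e.g. by /halt)."""

    def __init__(self, stop_event: Event):
        super().__init__()
        self.stop_event = stop_event

    def __call__(self, input_ids, scores, **kwargs):
        return self.stop_event.is_set()


# Placeholder checkpoint; any causal LM works for the sketch
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

stop_event = Event()
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids

# Run generate() on a worker thread so the caller can consume tokens as they arrive,
# the same pattern the server uses for its WebSocket responses.
generation = Thread(
    target=model.generate,
    kwargs=dict(
        inputs=input_ids,
        streamer=streamer,
        max_new_tokens=64,
        stopping_criteria=StoppingCriteriaList([StopOnEvent(stop_event)]),
        **DEFAULT_GENERATE_PARAMS,
    ),
)
generation.start()

for new_text in streamer:
    print(new_text, end="", flush=True)
    # Calling stop_event.set() here (or from another thread, as /halt does)
    # ends the generation after the current token.

generation.join()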
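
Editor's note, not part of the patch: the updated Serve docstring documents /generate, /health, /stats, and /halt, but the Message request model is truncated in this diff, so the exact request schema is not visible here. The sketch below is a hypothetical client under two stated assumptions: the utility endpoints respond to GET requests, and /generate accepts a JSON body with a single "text" field; adjust both to match the actual routes in chat.py.

# Sketch only: exercising the documented endpoints with the requests library.
# Assumed: GET for /health, /stats, /halt; JSON {"text": ...} for /generate.
import requests

BASE = "http://localhost:8000"

print(requests.get(f"{BASE}/health").json())   # is a model loaded and ready?
reply = requests.post(f"{BASE}/generate", json={"text": "Hello, my name is"})
print(reply.json())                            # completed prompt
print(requests.get(f"{BASE}/stats").json())    # time-to-first-token, tokens/s, ...
requests.get(f"{BASE}/halt")                   # stop an in-progress generation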