diff --git a/README.md b/README.md
index aa4d7db..ed8938e 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
 # Py-TGI
 
-Py-TGI is a Python wrapper around [TGI](https://github.com/huggingface/text-generation-inference) to enable creating and running TGI servers and clients in a similar style to vLLM
+Py-TGI is a Python wrapper around [TGI](https://github.com/huggingface/text-generation-inference) to enable creating and running TGI servers in a similar style to vLLM.
 
 ## Installation
 
@@ -10,24 +10,22 @@ pip install py-tgi
 
 ## Usage
 
-Running a TGI server with a batched inference client:
+Py-TGI is designed to be used in a similar way to vLLM. Here's an example of how to use it:
 
 ```python
-from logging import basicConfig, INFO
-basicConfig(level=INFO) # to stream tgi container logs to stdout
-
 from py_tgi import TGI
-
-llm = TGI(model="TheBloke/Mistral-7B-Instruct-v0.1-AWQ", quantize="awq")
-
-try:
-    output = llm.generate(["Hi, I'm a language model", "I'm fine, how are you?"])
-    print(output)
-except Exception as e:
-    print(e)
-finally:
-    # make sure to close the server
-    llm.close()
+from py_tgi.utils import is_nvidia_system, is_rocm_system
+
+llm = TGI(
+    model="TheBloke/Llama-2-7B-AWQ",  # awq model checkpoint
+    devices=["/dev/kfd", "/dev/dri"] if is_rocm_system() else None,  # custom devices (ROCm)
+    gpus="all" if is_nvidia_system() else None,  # all gpus (NVIDIA)
+    quantize="gptq",  # use exllama kernels (rocm compatible)
+)
+output = llm.generate(["Hi, I'm a language model", "I'm fine, how are you?"])
+print(output)
 ```
 Output:
 ```[" and I'm here to help you with any questions you have. What can I help you with", "\nUser 0: I'm doing well, thanks for asking. I'm just a"]```
+
+That's it! Now you can write your Python scripts using the power of TGI.
diff --git a/example.py b/example.py
new file mode 100644
index 0000000..a282d53
--- /dev/null
+++ b/example.py
@@ -0,0 +1,15 @@
+from py_tgi import TGI
+from py_tgi.utils import is_nvidia_system, is_rocm_system
+
+if is_rocm_system() or is_nvidia_system():
+    llm = TGI(
+        model="TheBloke/Llama-2-7B-AWQ",  # awq model checkpoint
+        devices=["/dev/kfd", "/dev/dri"] if is_rocm_system() else None,  # custom devices (ROCm)
+        gpus="all" if is_nvidia_system() else None,  # all gpus (NVIDIA)
+        quantize="gptq",  # use exllama kernels (rocm compatible)
+    )
+else:
+    llm = TGI(model="gpt2", sharded=False)
+
+output = llm.generate(["Hi, I'm a language model", "I'm fine, how are you?"])
+print(output)
diff --git a/py_tgi/__init__.py b/py_tgi/__init__.py
index 35d5109..0a0ff46 100644
--- a/py_tgi/__init__.py
+++ b/py_tgi/__init__.py
@@ -1,7 +1,9 @@
 import os
+import re
 import time
 from concurrent.futures import ThreadPoolExecutor
-from logging import getLogger
+from logging import INFO, basicConfig, getLogger
+from pathlib import Path
 from typing import List, Literal, Optional, Union
 
 import docker
@@ -10,13 +12,17 @@
 from huggingface_hub import InferenceClient
 from huggingface_hub.inference._text_generation import TextGenerationResponse
 
-from .utils import get_nvidia_gpu_devices, timeout
+from .utils import is_rocm_system
 
+basicConfig(level=INFO)
+
+CONNECTION_TIMEOUT = 60
 LOGGER = getLogger("tgi")
 HF_CACHE_DIR = f"{os.path.expanduser('~')}/.cache/huggingface/hub"
 
-Quantization_Literal = Literal["bitsandbytes-nf4", "bitsandbytes-fp4", "gptq"]
-Torch_Dtype_Literal = Literal["float32", "float16", "bfloat16"]
+
+Dtype_Literal = Literal["float32", "float16", "bfloat16"]
+Quantize_Literal = Literal["bitsandbytes-nf4", "bitsandbytes-fp4", "gptq"]
 
 
 class TGI:
@@ -26,18 +32,19 @@ def __init__(
         model: str,
         revision: str = "main",
         # image options
-        image: str = "ghcr.io/huggingface/text-generation-inference",
-        version: str = "latest",
+        image: str = "ghcr.io/huggingface/text-generation-inference:latest",
         # docker options
-        volume: str = HF_CACHE_DIR,
+        port: int = 1111,
         shm_size: str = "1g",
         address: str = "127.0.0.1",
-        port: int = 1111,
+        volume: str = HF_CACHE_DIR,  # automatically connects local hf cache to the container's /data
+        devices: Optional[List[str]] = None,  # ["/dev/kfd", "/dev/dri"] or None for custom devices (ROCm)
+        gpus: Optional[Union[str, int]] = None,  # "all" or "0,1,2,3" or 3 or None for NVIDIA
         # tgi launcher options
         sharded: Optional[bool] = None,
         num_shard: Optional[int] = None,
-        torch_dtype: Optional[Torch_Dtype_Literal] = None,
-        quantize: Optional[Quantization_Literal] = None,
+        dtype: Optional[Dtype_Literal] = None,
+        quantize: Optional[Quantize_Literal] = None,
         trust_remote_code: Optional[bool] = False,
         disable_custom_kernels: Optional[bool] = False,
     ) -> None:
@@ -46,29 +53,33 @@ def __init__(
         self.revision = revision
         # image options
         self.image = image
-        self.version = version
         # docker options
         self.port = port
         self.volume = volume
         self.address = address
         self.shm_size = shm_size
         # tgi launcher options
+        self.dtype = dtype
         self.sharded = sharded
         self.num_shard = num_shard
-        self.torch_dtype = torch_dtype
         self.quantize = quantize
         self.trust_remote_code = trust_remote_code
         self.disable_custom_kernels = disable_custom_kernels
 
+        if is_rocm_system() and "-rocm" not in self.image:
+            LOGGER.warning("ROCm system detected, but the image does not contain '-rocm'. Adding it.")
+            self.image = self.image + "-rocm"
+
         LOGGER.info("\t+ Starting Docker client")
         self.docker_client = docker.from_env()
 
         try:
-            LOGGER.info("\t+ Checking if TGI image exists")
-            self.docker_client.images.get(f"{self.image}:{self.version}")
+            LOGGER.info("\t+ Checking if TGI image is available locally")
+            self.docker_client.images.get(self.image)
+            LOGGER.info("\t+ TGI image found locally")
         except docker.errors.ImageNotFound:
-            LOGGER.info("\t+ TGI image not found, downloading it (this may take a while)")
-            self.docker_client.images.pull(f"{self.image}:{self.version}")
+            LOGGER.info("\t+ TGI image not found locally, pulling from Docker Hub")
+            self.docker_client.images.pull(self.image)
 
         env = {}
         if os.environ.get("HUGGING_FACE_HUB_TOKEN", None) is not None:
@@ -83,79 +94,75 @@
             self.command.extend(["--num-shard", str(self.num_shard)])
         if self.quantize is not None:
             self.command.extend(["--quantize", self.quantize])
-        if self.torch_dtype is not None:
-            self.command.extend(["--dtype", self.torch_dtype])
+        if self.dtype is not None:
+            self.command.extend(["--dtype", self.dtype])
         if self.trust_remote_code:
             self.command.append("--trust-remote-code")
         if self.disable_custom_kernels:
             self.command.append("--disable-custom-kernels")
 
-        try:
-            LOGGER.info("\t+ Checking if GPU is available")
-            if os.environ.get("CUDA_VISIBLE_DEVICES") is not None:
-                LOGGER.info("\t+ Using specified `CUDA_VISIBLE_DEVICES` to set GPU(s)")
-                device_ids = os.environ.get("CUDA_VISIBLE_DEVICES")
-            else:
-                LOGGER.info("\t+ Using nvidia-smi to get available GPU(s) (if any)")
-                device_ids = get_nvidia_gpu_devices()
-
-            LOGGER.info(f"\t+ Using GPU(s): {device_ids}")
-            self.device_requests = [docker.types.DeviceRequest(device_ids=[device_ids], capabilities=[["gpu"]])]
-        except Exception:
-            LOGGER.info("\t+ No GPU detected")
+        if gpus is not None and isinstance(gpus, str) and gpus == "all":
LOGGER.info("\t+ Using all GPU(s)") + self.device_requests = [docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])] + elif gpus is not None and isinstance(gpus, int): + LOGGER.info(f"\t+ Using {gpus} GPU(s)") + self.device_requests = [docker.types.DeviceRequest(count=gpus, capabilities=[["gpu"]])] + elif gpus is not None and isinstance(gpus, str) and re.match(r"^\d+(,\d+)*$", gpus): + LOGGER.info(f"\t+ Using GPU(s) {gpus}") + self.device_requests = [docker.types.DeviceRequest(device_ids=[gpus], capabilities=[["gpu"]])] + else: + LOGGER.info("\t+ Not using any GPU(s)") self.device_requests = None - self.tgi_container = self.docker_client.containers.run( + if devices is not None and isinstance(devices, list) and all(Path(d).exists() for d in devices): + LOGGER.info(f"\t+ Using custom device(s) {devices}") + self.devices = devices + else: + LOGGER.info("\t+ Not using any custom device(s)") + self.devices = None + + self.closed = False + self.docker_container = self.docker_client.containers.run( + image=self.image, command=self.command, - image=f"{self.image}:{self.version}", volumes={self.volume: {"bind": "/data", "mode": "rw"}}, ports={"80/tcp": (self.address, self.port)}, device_requests=self.device_requests, shm_size=self.shm_size, + devices=self.devices, environment=env, + auto_remove=True, # this is so cool detach=True, ) LOGGER.info("\t+ Waiting for TGI server to be ready") - with timeout(60): - for line in self.tgi_container.logs(stream=True): - tgi_log = line.decode("utf-8").strip() - if "Connected" in tgi_log: - break - elif "Error" in tgi_log: - raise Exception(f"\t {tgi_log}") - + for line in self.docker_container.logs(stream=True): + tgi_log = line.decode("utf-8").strip() + if "Connected" in tgi_log: + LOGGER.info(f"\t {tgi_log}") + break + elif "Error" in tgi_log: + LOGGER.info(f"\t {tgi_log}") + raise Exception("TGI server failed to start") + else: LOGGER.info(f"\t {tgi_log}") LOGGER.info("\t+ Conecting to TGI server") self.url = f"http://{self.address}:{self.port}" - with timeout(60): - while True: - try: - self.tgi_client = InferenceClient(model=self.url) - self.tgi_client.text_generation("Hello world!") - LOGGER.info(f"\t+ Connected to TGI server at {self.url}") - break - except Exception: - LOGGER.info("\t+ TGI server not ready, retrying in 1 second") - time.sleep(1) - def close(self) -> None: - if hasattr(self, "tgi_container"): - LOGGER.info("\t+ Stoping TGI container") - self.tgi_container.stop() - LOGGER.info("\t+ Waiting for TGI container to stop") - self.tgi_container.wait() + start_time = time.time() + while time.time() - start_time < CONNECTION_TIMEOUT: + try: + self.tgi_client = InferenceClient(model=self.url) + self.tgi_client.text_generation("Hello world!") + LOGGER.info("\t+ Connected to TGI server successfully") + return + except Exception: + LOGGER.info("\t+ TGI server is not ready yet, waiting 1 second") + time.sleep(1) - if hasattr(self, "docker_client"): - LOGGER.info("\t+ Closing docker client") - self.docker_client.close() - - def __call__( - self, prompt: Union[str, List[str]], **kwargs - ) -> Union[TextGenerationResponse, List[TextGenerationResponse]]: - return self.generate(prompt, **kwargs) + raise Exception("TGI server took too long to start (60 seconds)") def generate( self, prompt: Union[str, List[str]], **kwargs @@ -174,3 +181,25 @@ def generate( for i in range(len(prompt)): output.append(futures[i].result()) return output + + def close(self) -> None: + if not self.closed: + if hasattr(self, "docker_container"): + LOGGER.info("\t+ Stoping 
docker container") + self.docker_container.stop() + self.docker_container.wait() + + if hasattr(self, "docker_client"): + LOGGER.info("\t+ Closing docker client") + self.docker_client.close() + + self.closed = True + + def __call__( + self, prompt: Union[str, List[str]], **kwargs + ) -> Union[TextGenerationResponse, List[TextGenerationResponse]]: + return self.generate(prompt, **kwargs) + + def __del__(self) -> None: + if not self.closed: + self.close() diff --git a/py_tgi/utils.py b/py_tgi/utils.py index 3143cba..d8e566b 100644 --- a/py_tgi/utils.py +++ b/py_tgi/utils.py @@ -1,48 +1,17 @@ -import signal import subprocess -from contextlib import contextmanager -def get_nvidia_gpu_devices() -> str: - nvidia_smi = ( - subprocess.check_output( - [ - "nvidia-smi", - "--query-gpu=index,gpu_name,compute_cap", - "--format=csv", - ], - ) - .decode("utf-8") - .strip() - .split("\n")[1:] - ) - device = [ - { - "id": int(gpu.split(", ")[0]), - "name": gpu.split(", ")[1], - "compute_cap": gpu.split(", ")[2], - } - for gpu in nvidia_smi - ] - device_ids = [gpu["id"] for gpu in device if "Display" not in gpu["name"]] - device_ids = ",".join([str(device_id) for device_id in device_ids]) - - return device_ids - - -@contextmanager -def timeout(time: int): - """ - Timeout context manager. Raises TimeoutError if the code inside the context manager takes longer than `time` seconds to execute. - """ - - def signal_handler(signum, frame): - raise TimeoutError("Timed out") +def is_rocm_system() -> bool: + try: + subprocess.check_output(["rocm-smi"]) + return True + except FileNotFoundError: + return False - signal.signal(signal.SIGALRM, signal_handler) - signal.alarm(time) +def is_nvidia_system() -> bool: try: - yield - finally: - signal.alarm(0) + subprocess.check_output(["nvidia-smi"]) + return True + except FileNotFoundError: + return False diff --git a/tests/test.py b/tests/test.py index 164f7f6..89d0d15 100644 --- a/tests/test.py +++ b/tests/test.py @@ -1,21 +1,21 @@ -from logging import INFO, basicConfig - from py_tgi import TGI +from py_tgi.utils import is_nvidia_system, is_rocm_system -basicConfig(level=INFO) - -llm = TGI("gpt2", sharded=False) - -try: +if is_nvidia_system() or is_rocm_system(): + llm = TGI( + model="TheBloke/Llama-2-7B-AWQ", # awq model checkpoint + devices=["/dev/kfd", "/dev/dri"] if is_rocm_system() else None, # custom devices (ROCm) + gpus="all" if is_nvidia_system() else None, # all gpus (NVIDIA) + quantize="gptq", # use exllama kernels (rocm compatible) + ) output = llm.generate("Hi, I'm a sanity test") assert isinstance(output, str) - output = llm.generate(["Hi, I'm a sanity test", "I'm a second sentence"]) - assert isinstance(output, list) + assert isinstance(output, list) and all(isinstance(x, str) for x in output) - llm.close() - -# catch Exception and InterruptedError -except (Exception, InterruptedError, KeyboardInterrupt) as e: - llm.close() - raise e +else: + llm = TGI(model="gpt2", sharded=False) + output = llm.generate("Hi, I'm a sanity test") + assert isinstance(output, str) + output = llm.generate(["Hi, I'm a sanity test", "I'm a second sentence"]) + assert isinstance(output, list) and all(isinstance(x, str) for x in output)