Merge pull request #1 from IlyasMoutawwakil/rocm-support
ROCm and custom devices support
IlyasMoutawwakil authored Feb 20, 2024
2 parents 8c3e23d + f4f4986 commit 840eefe
Showing 5 changed files with 149 additions and 138 deletions.
30 changes: 14 additions & 16 deletions README.md
@@ -1,6 +1,6 @@
# Py-TGI

Py-TGI is a Python wrapper around [TGI](https://github.com/huggingface/text-generation-inference) to enable creating and running TGI servers and clients in a similar style to vLLM
Py-TGI is a Python wrapper around [TGI](https://github.com/huggingface/text-generation-inference) to enable creating and running TGI servers in a similar style to vLLM.

## Installation

@@ -10,24 +10,22 @@ pip install py-tgi

## Usage

Running a TGI server with a batched inference client:
Py-TGI is designed to be used in a similar way to vLLM. Here's an example of how to use it:

```python
from logging import basicConfig, INFO
basicConfig(level=INFO) # to stream tgi container logs to stdout

from py_tgi import TGI

llm = TGI(model="TheBloke/Mistral-7B-Instruct-v0.1-AWQ", quantize="awq")

try:
output = llm.generate(["Hi, I'm a language model", "I'm fine, how are you?"])
print(output)
except Exception as e:
print(e)
finally:
# make sure to close the server
llm.close()
from py_tgi.utils import is_nvidia_system, is_rocm_system

llm = TGI(
model="TheBloke/Llama-2-7B-AWQ", # awq model checkpoint
devices=["/dev/kfd", "/dev/dri"] if is_rocm_system() else None, # custom devices (ROCm)
gpus="all" if is_nvidia_system() else None, # all gpus (NVIDIA)
quantize="gptq", # use exllama kernels (rocm compatible)
)
output = llm.generate(["Hi, I'm a language model", "I'm fine, how are you?"])
print(output)
```

Output: ```[" and I'm here to help you with any questions you have. What can I help you with", "\nUser 0: I'm doing well, thanks for asking. I'm just a"]```

That's it! Now you can write your Python scripts using the power of TGI.
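
For machines without a supported GPU, here is a minimal CPU-only sketch of the same flow (the `gpt2` checkpoint and `sharded=False` are borrowed from `example.py` below; the explicit `close()` is optional now that the class also stops the container when the object is garbage-collected):

```python
from py_tgi import TGI

llm = TGI(model="gpt2", sharded=False)  # small model, no GPU required

try:
    print(llm.generate(["Hi, I'm a language model"]))
finally:
    llm.close()  # optional: __del__ also stops the container
```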
15 changes: 15 additions & 0 deletions example.py
@@ -0,0 +1,15 @@
from py_tgi import TGI
from py_tgi.utils import is_nvidia_system, is_rocm_system

if is_rocm_system() or is_nvidia_system():
llm = TGI(
model="TheBloke/Llama-2-7B-AWQ", # awq model checkpoint
devices=["/dev/kfd", "/dev/dri"] if is_rocm_system() else None, # custom devices (ROCm)
gpus="all" if is_nvidia_system() else None, # all gpus (NVIDIA)
quantize="gptq", # use exllama kernels (rocm compatible)
)
else:
llm = TGI(model="gpt2", sharded=False)

output = llm.generate(["Hi, I'm a language model", "I'm fine, how are you?"])
print(output)
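
The CPU fallback is what keeps this example runnable on machines without a supported accelerator: `gpt2` is small enough to serve unquantized, and `sharded=False` skips multi-GPU sharding.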
159 changes: 94 additions & 65 deletions py_tgi/__init__.py
@@ -1,7 +1,9 @@
import os
import re
import time
from concurrent.futures import ThreadPoolExecutor
from logging import getLogger
from logging import INFO, basicConfig, getLogger
from pathlib import Path
from typing import List, Literal, Optional, Union

import docker
@@ -10,13 +12,17 @@
from huggingface_hub import InferenceClient
from huggingface_hub.inference._text_generation import TextGenerationResponse

from .utils import get_nvidia_gpu_devices, timeout
from .utils import is_rocm_system

basicConfig(level=INFO)

CONNECTION_TIMEOUT = 60
LOGGER = getLogger("tgi")
HF_CACHE_DIR = f"{os.path.expanduser('~')}/.cache/huggingface/hub"

Quantization_Literal = Literal["bitsandbytes-nf4", "bitsandbytes-fp4", "gptq"]
Torch_Dtype_Literal = Literal["float32", "float16", "bfloat16"]

Dtype_Literal = Literal["float32", "float16", "bfloat16"]
Quantize_Literal = Literal["bitsandbytes-nf4", "bitsandbytes-fp4", "gptq"]


class TGI:
@@ -26,18 +32,19 @@ def __init__(
model: str,
revision: str = "main",
# image options
image: str = "ghcr.io/huggingface/text-generation-inference",
version: str = "latest",
image: str = "ghcr.io/huggingface/text-generation-inference:latest",
# docker options
volume: str = HF_CACHE_DIR,
port: int = 1111,
shm_size: str = "1g",
address: str = "127.0.0.1",
port: int = 1111,
volume: str = HF_CACHE_DIR, # automatically connects local hf cache to the container's /data
devices: Optional[List[str]] = None, # ["/dev/kfd", "/dev/dri"] or None for custom devices (ROCm)
gpus: Optional[Union[str, int]] = None, # "all" or "0,1,2,3" or 3 or None for NVIDIA
# tgi launcher options
sharded: Optional[bool] = None,
num_shard: Optional[int] = None,
torch_dtype: Optional[Torch_Dtype_Literal] = None,
quantize: Optional[Quantization_Literal] = None,
dtype: Optional[Dtype_Literal] = None,
quantize: Optional[Quantize_Literal] = None,
trust_remote_code: Optional[bool] = False,
disable_custom_kernels: Optional[bool] = False,
) -> None:
@@ -46,29 +53,33 @@ def __init__(
self.revision = revision
# image options
self.image = image
self.version = version
# docker options
self.port = port
self.volume = volume
self.address = address
self.shm_size = shm_size
# tgi launcher options
self.dtype = dtype
self.sharded = sharded
self.num_shard = num_shard
self.torch_dtype = torch_dtype
self.quantize = quantize
self.trust_remote_code = trust_remote_code
self.disable_custom_kernels = disable_custom_kernels

if is_rocm_system() and "-rocm" not in self.image:
LOGGER.warning("ROCm system detected, but the image does not contain '-rocm'. Adding it.")
self.image = self.image + "-rocm"

LOGGER.info("\t+ Starting Docker client")
self.docker_client = docker.from_env()

try:
LOGGER.info("\t+ Checking if TGI image exists")
self.docker_client.images.get(f"{self.image}:{self.version}")
LOGGER.info("\t+ Checking if TGI image is available locally")
self.docker_client.images.get(self.image)
LOGGER.info("\t+ TGI image found locally")
except docker.errors.ImageNotFound:
LOGGER.info("\t+ TGI image not found, downloading it (this may take a while)")
self.docker_client.images.pull(f"{self.image}:{self.version}")
LOGGER.info("\t+ TGI image not found locally, pulling from Docker Hub")
self.docker_client.images.pull(self.image)

env = {}
if os.environ.get("HUGGING_FACE_HUB_TOKEN", None) is not None:
@@ -83,79 +94,75 @@
self.command.extend(["--num-shard", str(self.num_shard)])
if self.quantize is not None:
self.command.extend(["--quantize", self.quantize])
if self.torch_dtype is not None:
self.command.extend(["--dtype", self.torch_dtype])
if self.dtype is not None:
self.command.extend(["--dtype", self.dtype])

if self.trust_remote_code:
self.command.append("--trust-remote-code")
if self.disable_custom_kernels:
self.command.append("--disable-custom-kernels")

try:
LOGGER.info("\t+ Checking if GPU is available")
if os.environ.get("CUDA_VISIBLE_DEVICES") is not None:
LOGGER.info("\t+ Using specified `CUDA_VISIBLE_DEVICES` to set GPU(s)")
device_ids = os.environ.get("CUDA_VISIBLE_DEVICES")
else:
LOGGER.info("\t+ Using nvidia-smi to get available GPU(s) (if any)")
device_ids = get_nvidia_gpu_devices()

LOGGER.info(f"\t+ Using GPU(s): {device_ids}")
self.device_requests = [docker.types.DeviceRequest(device_ids=[device_ids], capabilities=[["gpu"]])]
except Exception:
LOGGER.info("\t+ No GPU detected")
if gpus is not None and isinstance(gpus, str) and gpus == "all":
LOGGER.info("\t+ Using all GPU(s)")
self.device_requests = [docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])]
elif gpus is not None and isinstance(gpus, int):
LOGGER.info(f"\t+ Using {gpus} GPU(s)")
self.device_requests = [docker.types.DeviceRequest(count=gpus, capabilities=[["gpu"]])]
elif gpus is not None and isinstance(gpus, str) and re.match(r"^\d+(,\d+)*$", gpus):
LOGGER.info(f"\t+ Using GPU(s) {gpus}")
self.device_requests = [docker.types.DeviceRequest(device_ids=[gpus], capabilities=[["gpu"]])]
else:
LOGGER.info("\t+ Not using any GPU(s)")
self.device_requests = None

self.tgi_container = self.docker_client.containers.run(
if devices is not None and isinstance(devices, list) and all(Path(d).exists() for d in devices):
LOGGER.info(f"\t+ Using custom device(s) {devices}")
self.devices = devices
else:
LOGGER.info("\t+ Not using any custom device(s)")
self.devices = None

self.closed = False
self.docker_container = self.docker_client.containers.run(
image=self.image,
command=self.command,
image=f"{self.image}:{self.version}",
volumes={self.volume: {"bind": "/data", "mode": "rw"}},
ports={"80/tcp": (self.address, self.port)},
device_requests=self.device_requests,
shm_size=self.shm_size,
devices=self.devices,
environment=env,
auto_remove=True, # remove the container automatically once it stops
detach=True,
)

LOGGER.info("\t+ Waiting for TGI server to be ready")
with timeout(60):
for line in self.tgi_container.logs(stream=True):
tgi_log = line.decode("utf-8").strip()
if "Connected" in tgi_log:
break
elif "Error" in tgi_log:
raise Exception(f"\t {tgi_log}")

for line in self.docker_container.logs(stream=True):
tgi_log = line.decode("utf-8").strip()
if "Connected" in tgi_log:
LOGGER.info(f"\t {tgi_log}")
break
elif "Error" in tgi_log:
LOGGER.info(f"\t {tgi_log}")
raise Exception("TGI server failed to start")
else:
LOGGER.info(f"\t {tgi_log}")

LOGGER.info("\t+ Conecting to TGI server")
self.url = f"http://{self.address}:{self.port}"
with timeout(60):
while True:
try:
self.tgi_client = InferenceClient(model=self.url)
self.tgi_client.text_generation("Hello world!")
LOGGER.info(f"\t+ Connected to TGI server at {self.url}")
break
except Exception:
LOGGER.info("\t+ TGI server not ready, retrying in 1 second")
time.sleep(1)

def close(self) -> None:
if hasattr(self, "tgi_container"):
LOGGER.info("\t+ Stoping TGI container")
self.tgi_container.stop()
LOGGER.info("\t+ Waiting for TGI container to stop")
self.tgi_container.wait()
start_time = time.time()
while time.time() - start_time < CONNECTION_TIMEOUT:
try:
self.tgi_client = InferenceClient(model=self.url)
self.tgi_client.text_generation("Hello world!")
LOGGER.info("\t+ Connected to TGI server successfully")
return
except Exception:
LOGGER.info("\t+ TGI server is not ready yet, waiting 1 second")
time.sleep(1)

if hasattr(self, "docker_client"):
LOGGER.info("\t+ Closing docker client")
self.docker_client.close()

def __call__(
self, prompt: Union[str, List[str]], **kwargs
) -> Union[TextGenerationResponse, List[TextGenerationResponse]]:
return self.generate(prompt, **kwargs)
raise Exception("TGI server took too long to start (60 seconds)")

def generate(
self, prompt: Union[str, List[str]], **kwargs
@@ -174,3 +181,25 @@ def generate(
for i in range(len(prompt)):
output.append(futures[i].result())
return output

def close(self) -> None:
if not self.closed:
if hasattr(self, "docker_container"):
LOGGER.info("\t+ Stoping docker container")
self.docker_container.stop()
self.docker_container.wait()

if hasattr(self, "docker_client"):
LOGGER.info("\t+ Closing docker client")
self.docker_client.close()

self.closed = True

def __call__(
self, prompt: Union[str, List[str]], **kwargs
) -> Union[TextGenerationResponse, List[TextGenerationResponse]]:
return self.generate(prompt, **kwargs)

def __del__(self) -> None:
if not self.closed:
self.close()
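
For reference, here is a condensed sketch of how the new `gpus` argument maps onto docker-py device requests, mirroring the constructor's branching above (illustrative only, not part of the diff; it assumes docker-py is installed, and `to_device_requests` is a hypothetical helper name):

```python
import re
from typing import Optional, Union

import docker


def to_device_requests(gpus: Optional[Union[str, int]]) -> Optional[list]:
    # hypothetical helper mirroring the branching in TGI.__init__ above
    if gpus == "all":
        # expose every available GPU to the container
        return [docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])]
    if isinstance(gpus, int):
        # expose the first `gpus` GPUs
        return [docker.types.DeviceRequest(count=gpus, capabilities=[["gpu"]])]
    if isinstance(gpus, str) and re.match(r"^\d+(,\d+)*$", gpus):
        # expose an explicit comma-separated list of device ids
        return [docker.types.DeviceRequest(device_ids=[gpus], capabilities=[["gpu"]])]
    return None  # CPU-only, or custom devices passed via `devices=...`
```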
53 changes: 11 additions & 42 deletions py_tgi/utils.py
@@ -1,48 +1,17 @@
import signal
import subprocess
from contextlib import contextmanager


def get_nvidia_gpu_devices() -> str:
nvidia_smi = (
subprocess.check_output(
[
"nvidia-smi",
"--query-gpu=index,gpu_name,compute_cap",
"--format=csv",
],
)
.decode("utf-8")
.strip()
.split("\n")[1:]
)
device = [
{
"id": int(gpu.split(", ")[0]),
"name": gpu.split(", ")[1],
"compute_cap": gpu.split(", ")[2],
}
for gpu in nvidia_smi
]
device_ids = [gpu["id"] for gpu in device if "Display" not in gpu["name"]]
device_ids = ",".join([str(device_id) for device_id in device_ids])

return device_ids


@contextmanager
def timeout(time: int):
"""
Timeout context manager. Raises TimeoutError if the code inside the context manager takes longer than `time` seconds to execute.
"""

def signal_handler(signum, frame):
raise TimeoutError("Timed out")
def is_rocm_system() -> bool:
try:
subprocess.check_output(["rocm-smi"])
return True
except FileNotFoundError:
return False

signal.signal(signal.SIGALRM, signal_handler)
signal.alarm(time)

def is_nvidia_system() -> bool:
try:
yield
finally:
signal.alarm(0)
subprocess.check_output(["nvidia-smi"])
return True
except FileNotFoundError:
return False
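
The two helpers can be called directly to branch on the host's accelerator stack:

```python
from py_tgi.utils import is_nvidia_system, is_rocm_system

print("NVIDIA:", is_nvidia_system())  # True if `nvidia-smi` runs successfully
print("ROCm:", is_rocm_system())      # True if `rocm-smi` runs successfully
```

One caveat worth noting: only `FileNotFoundError` is caught, so a binary that is present but exits non-zero makes `subprocess.check_output` raise `CalledProcessError`, which would propagate instead of returning `False`.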
