Skip to content

Commit

Permalink
added connection timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Feb 20, 2024
1 parent 8a36542 commit f4f4986
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions py_tgi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from .utils import is_rocm_system

basicConfig(level=INFO)

CONNECTION_TIMEOUT = 60
LOGGER = getLogger("tgi")
HF_CACHE_DIR = f"{os.path.expanduser('~')}/.cache/huggingface/hub"

Expand Down Expand Up @@ -65,7 +67,7 @@ def __init__(
self.disable_custom_kernels = disable_custom_kernels

if is_rocm_system() and "-rocm" not in self.image:
LOGGER.warning("ROCm system detected, but the image does not contain 'rocm' in its name. Adding it.")
LOGGER.warning("ROCm system detected, but the image does not contain '-rocm'. Adding it.")
self.image = self.image + "-rocm"

LOGGER.info("\t+ Starting Docker client")
Expand Down Expand Up @@ -130,32 +132,38 @@ def __init__(
shm_size=self.shm_size,
devices=self.devices,
environment=env,
auto_remove=True, # this is so cool
detach=True,
)

LOGGER.info("\t+ Waiting for TGI server to be ready")
for line in self.docker_container.logs(stream=True):
tgi_log = line.decode("utf-8").strip()
if "Connected" in tgi_log:
LOGGER.info(f"\t {tgi_log}")
break
elif "Error" in tgi_log:
raise Exception(f"TGI server failed to start: {tgi_log}")

LOGGER.info(f"\t {tgi_log}")
LOGGER.info(f"\t {tgi_log}")
raise Exception("TGI server failed to start")
else:
LOGGER.info(f"\t {tgi_log}")

LOGGER.info("\t+ Conecting to TGI server")
self.url = f"http://{self.address}:{self.port}"

while True:
start_time = time.time()
while time.time() - start_time < CONNECTION_TIMEOUT:
try:
self.tgi_client = InferenceClient(model=self.url)
self.tgi_client.text_generation("Hello world!")
LOGGER.info("\t+ Connected to TGI server successfully")
break
return
except Exception:
LOGGER.info("\t+ TGI server not ready, retrying in 1 second")
LOGGER.info("\t+ TGI server is not ready yet, waiting 1 second")
time.sleep(1)

raise Exception("TGI server took too long to start (60 seconds)")

def generate(
self, prompt: Union[str, List[str]], **kwargs
) -> Union[TextGenerationResponse, List[TextGenerationResponse]]:
Expand Down

0 comments on commit f4f4986

Please sign in to comment.