diff --git a/shortfin/python/shortfin_apps/sd/server.py b/shortfin/python/shortfin_apps/sd/server.py index 9b9da3a73..51e8becbe 100644 --- a/shortfin/python/shortfin_apps/sd/server.py +++ b/shortfin/python/shortfin_apps/sd/server.py @@ -275,6 +275,16 @@ def get_modules(args, model_config, flagfile, td_spec): return vmfbs, params +def is_port_valid(port): + max_port = 65535 + if port < 1 or port > max_port: + print( + f"Error: Invalid port specified ({port}), expected a value between 1 and {max_port}" + ) + return False + return True + + def main(argv, log_config=UVICORN_LOG_CONFIG): parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default=None) @@ -403,6 +413,9 @@ def main(argv, log_config=UVICORN_LOG_CONFIG): help="Use tunings for attention and matmul ops. 0 to disable.", ) args = parser.parse_args(argv) + if not is_port_valid(args.port): + exit(3) + if not args.artifacts_dir: home = Path.home() artdir = home / ".cache" / "shark" diff --git a/shortfin/python/shortfin_apps/sd/simple_client.py b/shortfin/python/shortfin_apps/sd/simple_client.py index 0d88a59c7..46b204067 100644 --- a/shortfin/python/shortfin_apps/sd/simple_client.py +++ b/shortfin/python/shortfin_apps/sd/simple_client.py @@ -177,7 +177,15 @@ async def async_range(count): await asyncio.sleep(0.0) -def check_health(url): +def is_connected(host, port): + max_port = 65535 + if port < 1 or port > max_port: + print( + f"Error: Invalid port specified ({port}), expected a value between 1 and {max_port}" + ) + return False + + url = f"{host}:{port}" ready = False print("Waiting for server.", end=None) while not ready: @@ -192,6 +200,8 @@ def check_health(url): time.sleep(2) print(".", end=None) + return True + def main(): p = argparse.ArgumentParser() @@ -222,7 +232,7 @@ def main(): p.add_argument( "--host", type=str, default="http://0.0.0.0", help="Server host address." ) - p.add_argument("--port", type=str, default="8000", help="Server port") + p.add_argument("--port", type=int, default=8000, help="Server port") p.add_argument( "--steps", type=int, @@ -235,7 +245,10 @@ def main(): help="Start as an example CLI client instead of sending static requests.", ) args = p.parse_args() - check_health(f"{args.host}:{args.port}") + + if not is_connected(args.host, args.port): + exit(3) + if args.interactive: asyncio.run(interactive(args)) else: