From d1980c7e842b0c6d4ddc2e730b0e6654a9ce59ac Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Mon, 16 Dec 2024 23:56:13 +0100 Subject: [PATCH] [sharktank] Remove 'torch' from deps and warn instead (#706) Instead of enforcing the installation for 'torch' as a dependency, error if 'torch' cannot be imported and point the user to how to install. Co-authored-by: Scott Todd --- docs/user_guide.md | 13 ++++++++++--- sharktank/requirements.txt | 4 ---- sharktank/sharktank/__init__.py | 10 ++++++++++ 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/docs/user_guide.md b/docs/user_guide.md index a0415eb63..d3ef192e0 100644 --- a/docs/user_guide.md +++ b/docs/user_guide.md @@ -34,13 +34,20 @@ Setup your Python environment with the following commands: # Set up a virtual environment to isolate packages from other envs. python3.11 -m venv 3.11.venv source 3.11.venv/bin/activate +``` + +## Install SHARK and its dependencies + +First install a torch version that fulfills your needs: -# Optional: faster installation of torch with just CPU support. -# See other options at https://pytorch.org/get-started/locally/ +```bash +# Fast installation of torch with just CPU support. pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu ``` -## Install SHARK and its dependencies +For other options, see https://pytorch.org/get-started/locally/. + +Next install shark-ai: ```bash pip install shark-ai[apps] diff --git a/sharktank/requirements.txt b/sharktank/requirements.txt index 8a2a8ea3b..70780c346 100644 --- a/sharktank/requirements.txt +++ b/sharktank/requirements.txt @@ -9,10 +9,6 @@ huggingface-hub==0.22.2 transformers==4.40.0 datasets -# It is expected that you have installed a PyTorch version/variant specific -# to your needs, so we only include a minimum version spec. -torch>=2.3.0 - # Serving deps. fastapi>=0.112.2 uvicorn>=0.30.6 diff --git a/sharktank/sharktank/__init__.py b/sharktank/sharktank/__init__.py index a85ba359d..c0eb89810 100644 --- a/sharktank/sharktank/__init__.py +++ b/sharktank/sharktank/__init__.py @@ -3,3 +3,13 @@ # Licensed under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import importlib.util + +msg = """No module named 'torch'. Follow https://pytorch.org/get-started/locally/#start-locally to install 'torch'. +For example, on Linux to install with CPU support run: + pip3 install torch --index-url https://download.pytorch.org/whl/cpu +""" + +if spec := importlib.util.find_spec("torch") is None: + raise ModuleNotFoundError(msg)