diff --git a/src/codeflare_sdk/job/jobs.py b/src/codeflare_sdk/job/jobs.py index c3814971a..71161f132 100644 --- a/src/codeflare_sdk/job/jobs.py +++ b/src/codeflare_sdk/job/jobs.py @@ -22,6 +22,8 @@ from torchx.schedulers.ray_scheduler import RayScheduler from torchx.specs import AppHandle, parse_app_handle, AppDryRunInfo +from ..utils.generate_yaml import update_pip_requirements + if TYPE_CHECKING: from ..cluster.cluster import Cluster @@ -90,6 +92,12 @@ def __init__( ) self.image = image self.workspace = workspace + if "PIP_TRUSTED_HOST" in self.env or "PIP_INDEX_URL" in self.env: + update_pip_requirements(self) + else: + self.env.setdefault("PIP_TRUSTED_HOST", "pypi.org") + self.env.setdefault("PIP_INDEX_URL", "https://pypi.org/simple") + update_pip_requirements(self) def _dry_run(self, cluster: "Cluster"): j = f"{cluster.config.num_workers}x{max(cluster.config.num_gpus, 1)}" # # of proc. = # of gpus diff --git a/src/codeflare_sdk/utils/generate_yaml.py b/src/codeflare_sdk/utils/generate_yaml.py index a6aae3082..05d238917 100755 --- a/src/codeflare_sdk/utils/generate_yaml.py +++ b/src/codeflare_sdk/utils/generate_yaml.py @@ -21,6 +21,7 @@ import sys import os import argparse +from pathlib import Path import uuid from kubernetes import client, config from .kube_api_helpers import _kube_api_error_handling @@ -689,3 +690,38 @@ def generate_appwrapper( else: write_user_appwrapper(user_yaml, outfile) return outfile + + +def update_pip_requirements(self): + pip_trusted_host = self.env.get("PIP_TRUSTED_HOST") + pip_index_url = self.env.get("PIP_INDEX_URL") + requirements_path = Path("requirements.txt") + + if requirements_path.exists(): + with requirements_path.open("r") as file: + requirements = file.readlines() + + # Check and replace or add --trusted-host and --index-url + trusted_host = f"--trusted-host {pip_trusted_host}\n" + index_url = f"--index-url {pip_index_url}\n" + modified_requirements = [] + + for line in requirements: + if line.startswith("--trusted-host"): + modified_requirements.append(trusted_host) + trusted_host = None + elif line.startswith("--index-url"): + modified_requirements.append(index_url) + index_url = None + else: + modified_requirements.append(line) + + # Append the lines if they were not replaced + if trusted_host: + modified_requirements.insert(0, trusted_host) + if index_url: + modified_requirements.insert(0, index_url) + + # Write back the modified requirements + with requirements_path.open("w") as file: + file.writelines(modified_requirements) diff --git a/tests/unit_test.py b/tests/unit_test.py index 0d9403e69..094c227f9 100644 --- a/tests/unit_test.py +++ b/tests/unit_test.py @@ -2091,7 +2091,11 @@ def test_DDPJobDefinition_creation(): assert ddp.memMB == 1024 assert ddp.h == None assert ddp.j == "2x1" - assert ddp.env == {"test": "test"} + assert ddp.env == { + "PIP_TRUSTED_HOST": "pypi.org", + "PIP_INDEX_URL": "https://pypi.org/simple", + "test": "test", + } assert ddp.max_retries == 0 assert ddp.mounts == [] assert ddp.rdzv_port == 29500