From 5157d80a461eab0d93afa3958b1431174913af72 Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Mon, 3 Feb 2025 21:56:38 +0000 Subject: [PATCH] remove tenacity dependency, implement simple retry logic instead --- exo/download/new_shard_download.py | 31 +++++++++++++++++++++--------- setup.py | 1 - 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/exo/download/new_shard_download.py b/exo/download/new_shard_download.py index 5807747ce..21b3f0464 100644 --- a/exo/download/new_shard_download.py +++ b/exo/download/new_shard_download.py @@ -20,7 +20,6 @@ import shutil import tempfile import hashlib -from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_not_exception_type def exo_home() -> Path: return Path(os.environ.get("EXO_HOME", Path.home()/".cache"/"exo")) @@ -74,13 +73,20 @@ async def fetch_file_list_with_cache(repo_id: str, revision: str = "main") -> Li cache_file = (await ensure_exo_tmp())/f"{repo_id.replace('/', '--')}--{revision}--file_list.json" if await aios.path.exists(cache_file): async with aiofiles.open(cache_file, 'r') as f: return json.loads(await f.read()) - file_list = await fetch_file_list(repo_id, revision) + file_list = await fetch_file_list_with_retry(repo_id, revision) await aios.makedirs(cache_file.parent, exist_ok=True) async with aiofiles.open(cache_file, 'w') as f: await f.write(json.dumps(file_list)) return file_list -@retry(stop=stop_after_attempt(30), wait=wait_fixed(1)) -async def fetch_file_list(repo_id: str, revision: str = "main", path: str = "") -> List[Dict[str, Union[str, int]]]: +async def fetch_file_list_with_retry(repo_id: str, revision: str = "main", path: str = "") -> List[Dict[str, Union[str, int]]]: + n_attempts = 30 + for attempt in range(n_attempts): + try: return await _fetch_file_list(repo_id, revision, path) + except Exception as e: + if attempt == n_attempts - 1: raise e + await asyncio.sleep(min(8, 0.1 * (2 ** attempt))) + +async def _fetch_file_list(repo_id: str, revision: str = "main", path: str = "") -> List[Dict[str, Union[str, int]]]: api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}" url = f"{api_url}/{path}" if path else api_url @@ -94,7 +100,7 @@ async def fetch_file_list(repo_id: str, revision: str = "main", path: str = "") if item["type"] == "file": files.append({"path": item["path"], "size": item["size"]}) elif item["type"] == "directory": - subfiles = await fetch_file_list(repo_id, revision, item["path"]) + subfiles = await _fetch_file_list(repo_id, revision, item["path"]) files.extend(subfiles) return files else: @@ -122,8 +128,15 @@ async def file_meta(repo_id: str, revision: str, path: str) -> Tuple[int, str]: if (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"): etag = etag[1:-1] return content_length, etag -@retry(stop=stop_after_attempt(30), wait=wait_fixed(1), retry=retry_if_not_exception_type(FileNotFoundError)) -async def download_file(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path: +async def download_file_with_retry(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path: + n_attempts = 30 + for attempt in range(n_attempts): + try: return await _download_file(repo_id, revision, path, target_dir, on_progress) + except Exception as e: + if isinstance(e, FileNotFoundError) or attempt == n_attempts - 1: raise e + await asyncio.sleep(min(8, 0.1 * (2 ** attempt))) + +async def _download_file(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path: if await aios.path.exists(target_dir/path): return target_dir/path await aios.makedirs((target_dir/path).parent, exist_ok=True) length, remote_hash = await file_meta(repo_id, revision, path) @@ -163,7 +176,7 @@ def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_prog async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]: target_dir = (await ensure_exo_tmp())/repo_id.replace("/", "--") - index_file = await download_file(repo_id, revision, "model.safetensors.index.json", target_dir) + index_file = await download_file_with_retry(repo_id, revision, "model.safetensors.index.json", target_dir) async with aiofiles.open(index_file, 'r') as f: index_data = json.loads(await f.read()) return index_data.get("weight_map") @@ -214,7 +227,7 @@ def on_progress_wrapper(file: dict, curr_bytes: int, total_bytes: int): semaphore = asyncio.Semaphore(max_parallel_downloads) async def download_with_semaphore(file): async with semaphore: - await download_file(repo_id, revision, file["path"], target_dir, lambda curr_bytes, total_bytes: on_progress_wrapper(file, curr_bytes, total_bytes)) + await download_file_with_retry(repo_id, revision, file["path"], target_dir, lambda curr_bytes, total_bytes: on_progress_wrapper(file, curr_bytes, total_bytes)) if not skip_download: await asyncio.gather(*[download_with_semaphore(file) for file in filtered_file_list]) final_repo_progress = calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time) on_progress.trigger_all(shard, final_repo_progress) diff --git a/setup.py b/setup.py index a158d4430..de242f544 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,6 @@ "requests==2.32.3", "rich==13.7.1", "scapy==2.6.1", - "tenacity==9.0.0", "tqdm==4.66.4", "transformers==4.46.3", "uuid==1.30",