Skip to content

Commit

Permalink
remove tenacity dependency, implement simple retry logic instead
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexCheema committed Feb 3, 2025
1 parent 72329ba commit 5157d80
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 10 deletions.
31 changes: 22 additions & 9 deletions exo/download/new_shard_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import shutil
import tempfile
import hashlib
from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_not_exception_type

def exo_home() -> Path:
return Path(os.environ.get("EXO_HOME", Path.home()/".cache"/"exo"))
Expand Down Expand Up @@ -74,13 +73,20 @@ async def fetch_file_list_with_cache(repo_id: str, revision: str = "main") -> Li
cache_file = (await ensure_exo_tmp())/f"{repo_id.replace('/', '--')}--{revision}--file_list.json"
if await aios.path.exists(cache_file):
async with aiofiles.open(cache_file, 'r') as f: return json.loads(await f.read())
file_list = await fetch_file_list(repo_id, revision)
file_list = await fetch_file_list_with_retry(repo_id, revision)
await aios.makedirs(cache_file.parent, exist_ok=True)
async with aiofiles.open(cache_file, 'w') as f: await f.write(json.dumps(file_list))
return file_list

@retry(stop=stop_after_attempt(30), wait=wait_fixed(1))
async def fetch_file_list(repo_id: str, revision: str = "main", path: str = "") -> List[Dict[str, Union[str, int]]]:
async def fetch_file_list_with_retry(repo_id: str, revision: str = "main", path: str = "") -> List[Dict[str, Union[str, int]]]:
n_attempts = 30
for attempt in range(n_attempts):
try: return await _fetch_file_list(repo_id, revision, path)
except Exception as e:
if attempt == n_attempts - 1: raise e
await asyncio.sleep(min(8, 0.1 * (2 ** attempt)))

async def _fetch_file_list(repo_id: str, revision: str = "main", path: str = "") -> List[Dict[str, Union[str, int]]]:
api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
url = f"{api_url}/{path}" if path else api_url

Expand All @@ -94,7 +100,7 @@ async def fetch_file_list(repo_id: str, revision: str = "main", path: str = "")
if item["type"] == "file":
files.append({"path": item["path"], "size": item["size"]})
elif item["type"] == "directory":
subfiles = await fetch_file_list(repo_id, revision, item["path"])
subfiles = await _fetch_file_list(repo_id, revision, item["path"])
files.extend(subfiles)
return files
else:
Expand Down Expand Up @@ -122,8 +128,15 @@ async def file_meta(repo_id: str, revision: str, path: str) -> Tuple[int, str]:
if (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"): etag = etag[1:-1]
return content_length, etag

@retry(stop=stop_after_attempt(30), wait=wait_fixed(1), retry=retry_if_not_exception_type(FileNotFoundError))
async def download_file(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
async def download_file_with_retry(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
n_attempts = 30
for attempt in range(n_attempts):
try: return await _download_file(repo_id, revision, path, target_dir, on_progress)
except Exception as e:
if isinstance(e, FileNotFoundError) or attempt == n_attempts - 1: raise e
await asyncio.sleep(min(8, 0.1 * (2 ** attempt)))

async def _download_file(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
if await aios.path.exists(target_dir/path): return target_dir/path
await aios.makedirs((target_dir/path).parent, exist_ok=True)
length, remote_hash = await file_meta(repo_id, revision, path)
Expand Down Expand Up @@ -163,7 +176,7 @@ def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_prog

async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]:
target_dir = (await ensure_exo_tmp())/repo_id.replace("/", "--")
index_file = await download_file(repo_id, revision, "model.safetensors.index.json", target_dir)
index_file = await download_file_with_retry(repo_id, revision, "model.safetensors.index.json", target_dir)
async with aiofiles.open(index_file, 'r') as f: index_data = json.loads(await f.read())
return index_data.get("weight_map")

Expand Down Expand Up @@ -214,7 +227,7 @@ def on_progress_wrapper(file: dict, curr_bytes: int, total_bytes: int):
semaphore = asyncio.Semaphore(max_parallel_downloads)
async def download_with_semaphore(file):
async with semaphore:
await download_file(repo_id, revision, file["path"], target_dir, lambda curr_bytes, total_bytes: on_progress_wrapper(file, curr_bytes, total_bytes))
await download_file_with_retry(repo_id, revision, file["path"], target_dir, lambda curr_bytes, total_bytes: on_progress_wrapper(file, curr_bytes, total_bytes))
if not skip_download: await asyncio.gather(*[download_with_semaphore(file) for file in filtered_file_list])
final_repo_progress = calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time)
on_progress.trigger_all(shard, final_repo_progress)
Expand Down
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
"requests==2.32.3",
"rich==13.7.1",
"scapy==2.6.1",
"tenacity==9.0.0",
"tqdm==4.66.4",
"transformers==4.46.3",
"uuid==1.30",
Expand Down

0 comments on commit 5157d80

Please sign in to comment.