Merge pull request #664 from exo-explore/resumedownload
resumable downloads with integrity checks
AlexCheema authored Feb 1, 2025
2 parents 7a75fb0 + 2c0d17c commit 9a1f0a8
Showing 4 changed files with 91 additions and 46 deletions.
5 changes: 2 additions & 3 deletions exo/api/chatgpt_api.py
@@ -276,9 +276,8 @@ async def handle_model_support(self, request):
try:
response = web.StreamResponse(status=200, reason='OK', headers={ 'Content-Type': 'text/event-stream', 'Cache-Control': 'no-cache', 'Connection': 'keep-alive' })
await response.prepare(request)
downloads = await self.node.shard_downloader.get_shard_download_status(self.inference_engine_classname)
for (path, d) in downloads:
model_data = { d.shard.model_id: { "downloaded": d.downloaded_bytes == d.total_bytes, "download_percentage": 100 if d.downloaded_bytes == d.total_bytes else 100 * float(d.downloaded_bytes) / float(d.total_bytes), "total_size": d.total_bytes, "total_downloaded": d.downloaded_bytes } }
async for path, s in self.node.shard_downloader.get_shard_download_status(self.inference_engine_classname):
model_data = { s.shard.model_id: { "downloaded": s.downloaded_bytes == s.total_bytes, "download_percentage": 100 if s.downloaded_bytes == s.total_bytes else 100 * float(s.downloaded_bytes) / float(s.total_bytes), "total_size": s.total_bytes, "total_downloaded": s.downloaded_bytes } }
await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
await response.write(b"data: [DONE]\n\n")
return response
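For context, a minimal client-side sketch of consuming this SSE endpoint. The base URL and the `/modelpool` route are assumptions for illustration only; use whatever path `handle_model_support` is actually registered under.

```python
import asyncio
import json

import aiohttp

async def watch_model_status(base_url: str = "http://localhost:52415") -> None:
    # hypothetical route; substitute the real one from the API's route table
    async with aiohttp.ClientSession() as session:
        async with session.get(f"{base_url}/modelpool") as resp:
            async for raw in resp.content:  # aiohttp's StreamReader yields one line per iteration
                line = raw.decode().strip()
                if not line.startswith("data: "):
                    continue
                payload = line[len("data: "):]
                if payload == "[DONE]":
                    break
                print(json.loads(payload))  # {model_id: {"downloaded": ..., "download_percentage": ..., ...}}

if __name__ == "__main__":
    asyncio.run(watch_model_status())
```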
117 changes: 82 additions & 35 deletions exo/download/new_shard_download.py
@@ -11,15 +11,16 @@
import aiohttp
import aiofiles
from urllib.parse import urljoin
from typing import Callable, Union, Tuple, Dict, List
from typing import Callable, Union, Tuple, Dict, List, Optional, Literal, AsyncIterator
import time
from datetime import timedelta
import asyncio
import json
import traceback
import shutil
import tempfile
from tenacity import retry, stop_after_attempt, wait_exponential
import hashlib
from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_not_exception_type

def exo_home() -> Path:
return Path(os.environ.get("EXO_HOME", Path.home()/".cache"/"exo"))
@@ -69,8 +70,17 @@ async def seed_models(seed_dir: Union[str, Path]):
print(f"Error seeding model {path} to {dest_path}")
traceback.print_exc()

@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=0.5))
async def fetch_file_list(repo_id, revision, path=""):
async def fetch_file_list_with_cache(repo_id: str, revision: str = "main") -> List[Dict[str, Union[str, int]]]:
cache_file = (await ensure_exo_tmp())/f"{repo_id.replace('/', '--')}--{revision}--file_list.json"
if await aios.path.exists(cache_file):
async with aiofiles.open(cache_file, 'r') as f: return json.loads(await f.read())
file_list = await fetch_file_list(repo_id, revision)
await aios.makedirs(cache_file.parent, exist_ok=True)
async with aiofiles.open(cache_file, 'w') as f: await f.write(json.dumps(file_list))
return file_list

@retry(stop=stop_after_attempt(30), wait=wait_fixed(1))
async def fetch_file_list(repo_id: str, revision: str = "main", path: str = "") -> List[Dict[str, Union[str, int]]]:
api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
url = f"{api_url}/{path}" if path else api_url

@@ -90,29 +100,55 @@ async def fetch_file_list(repo_id, revision, path=""):
else:
raise Exception(f"Failed to fetch file list: {response.status}")

@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=0.5))
async def calc_hash(path: Path, type: Literal["sha1", "sha256"] = "sha1") -> str:
hash = hashlib.sha1() if type == "sha1" else hashlib.sha256()
if type == "sha1":
header = f"blob {(await aios.stat(path)).st_size}\0".encode()
hash.update(header)
async with aiofiles.open(path, 'rb') as f:
while chunk := await f.read(1024 * 1024):
hash.update(chunk)
return hash.hexdigest()
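The SHA-1 branch above reproduces git's blob-object hashing, which is the ETag the Hub reports for regular (non-LFS) files, while the SHA-256 branch matches the LFS content hash. A small synchronous sketch of the blob format, purely for intuition:

```python
import hashlib

def git_blob_sha1(data: bytes) -> str:
    # git hashes the header "blob <size>\0" followed by the raw file bytes
    h = hashlib.sha1()
    h.update(f"blob {len(data)}\0".encode())
    h.update(data)
    return h.hexdigest()

# matches `printf 'hello\n' | git hash-object --stdin`
assert git_blob_sha1(b"hello\n") == "ce013625030ba8dba906f756967f9e9ca394464a"
```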

async def file_meta(repo_id: str, revision: str, path: str) -> Tuple[int, str]:
url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
headers = await get_auth_headers()
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as session:
async with session.head(url, headers=headers) as r:
content_length = int(r.headers.get('x-linked-size') or r.headers.get('content-length') or 0)
etag = r.headers.get('X-Linked-ETag') or r.headers.get('ETag') or r.headers.get('Etag')
assert content_length > 0, f"No content length for {url}"
assert etag is not None, f"No remote hash for {url}"
if (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"): etag = etag[1:-1]
return content_length, etag
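`file_meta` prefers the `X-Linked-*` headers that the Hub sets for LFS files and falls back to the plain `ETag`/`Content-Length`; the length of the returned ETag is what later selects the hash algorithm in `download_file`. A hedged illustration of that selection:

```python
def hash_type_for(etag: str) -> str:
    # 64 hex chars -> LFS SHA-256 of the content; otherwise a 40-char git blob SHA-1
    return "sha256" if len(etag) == 64 else "sha1"

assert hash_type_for("a" * 64) == "sha256"
assert hash_type_for("b" * 40) == "sha1"
```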

@retry(stop=stop_after_attempt(30), wait=wait_fixed(1), retry=retry_if_not_exception_type(FileNotFoundError))
async def download_file(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
temp_file_name = None
try:
if (target_dir/path).exists(): return target_dir/path
await aios.makedirs((target_dir/path).parent, exist_ok=True)
base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
url = urljoin(base_url, path)
if await aios.path.exists(target_dir/path): return target_dir/path
await aios.makedirs((target_dir/path).parent, exist_ok=True)
length, remote_hash = await file_meta(repo_id, revision, path)
partial_path = target_dir/f"{path}.partial"
resume_byte_pos = (await aios.stat(partial_path)).st_size if (await aios.path.exists(partial_path)) else None
if resume_byte_pos != length:
url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
headers = await get_auth_headers()
if resume_byte_pos: headers['Range'] = f'bytes={resume_byte_pos}-'
n_read = resume_byte_pos or 0
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as session:
async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as r:
assert r.status == 200, f"Failed to download {path} from {url}: {r.status}"
length = int(r.headers.get('content-length', 0))
n_read = 0
async with aiofiles.tempfile.NamedTemporaryFile(dir=target_dir, delete=False) as temp_file:
temp_file_name = temp_file.name
while chunk := await r.content.read(1024 * 1024): on_progress(n_read := n_read + await temp_file.write(chunk), length)
await aios.rename(temp_file.name, target_dir/path)
return target_dir/path
finally:
if temp_file_name: # attempt to delete tmp file if it still exists
try: await aios.unlink(temp_file_name)
except: pass
if r.status == 404: raise FileNotFoundError(f"File not found: {url}")
assert r.status in [200, 206], f"Failed to download {path} from {url}: {r.status}"
async with aiofiles.open(partial_path, 'ab' if resume_byte_pos else 'wb') as f:
while chunk := await r.content.read(1024 * 1024): on_progress(n_read := n_read + await f.write(chunk), length)

final_hash = await calc_hash(partial_path, type="sha256" if len(remote_hash) == 64 else "sha1")
integrity = final_hash == remote_hash
if not integrity:
try: await aios.remove(partial_path)
except Exception as e: print(f"Error removing partial file {partial_path}: {e}")
raise Exception(f"Downloaded file {target_dir/path} has hash {final_hash} but remote hash is {remote_hash}")
await aios.rename(partial_path, target_dir/path)
return target_dir/path
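A hedged usage sketch of the resumable helper above: if the process dies mid-transfer, re-running it resumes from the `.partial` file via an HTTP Range request and only renames the file into place after the hash check passes. The repo id, file name, and target directory below are examples only.

```python
import asyncio
from pathlib import Path

from exo.download.new_shard_download import download_file

async def main() -> None:
    # example target dir and repo; any Hugging Face repo/file pair works the same way
    target = Path.home()/".cache"/"exo"/"downloads"/"mlx-community--Llama-3.2-1B-Instruct-4bit"

    def on_progress(done: int, total: int) -> None:
        print(f"\r{done}/{total} bytes", end="", flush=True)

    path = await download_file(
        "mlx-community/Llama-3.2-1B-Instruct-4bit",
        "main",
        "model.safetensors",
        target,
        on_progress=on_progress,
    )
    print(f"\nverified and saved to {path}")

if __name__ == "__main__":
    asyncio.run(main())
```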


def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_progress: Dict[str, RepoFileProgressEvent], all_start_time: float) -> RepoProgressEvent:
@@ -126,7 +162,7 @@ def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_prog
return RepoProgressEvent(shard, repo_id, revision, len([p for p in file_progress.values() if p.downloaded == p.total]), len(file_progress), all_downloaded_bytes, all_downloaded_bytes_this_session, all_total_bytes, all_speed, all_eta, file_progress, status)

async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]:
target_dir = await ensure_exo_tmp()/repo_id.replace("/", "--")
target_dir = (await ensure_exo_tmp())/repo_id.replace("/", "--")
index_file = await download_file(repo_id, revision, "model.safetensors.index.json", target_dir)
async with aiofiles.open(index_file, 'r') as f: index_data = json.loads(await f.read())
return index_data.get("weight_map")
@@ -140,6 +176,12 @@ async def resolve_allow_patterns(shard: Shard, inference_engine_classname: str)
if DEBUG >= 1: traceback.print_exc()
return ["*"]

async def get_downloaded_size(path: Path) -> int:
partial_path = path.with_suffix(path.suffix + ".partial")
if await aios.path.exists(path): return (await aios.stat(path)).st_size
if await aios.path.exists(partial_path): return (await aios.stat(partial_path)).st_size
return 0
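This means an interrupted file still reports its partial progress after a restart: the `.partial` size counts until the verified rename in `download_file` completes. The suffix arithmetic is equivalent to appending `.partial` to the full name, e.g.:

```python
from pathlib import Path

p = Path("model-00001-of-00002.safetensors")
print(p.with_suffix(p.suffix + ".partial"))  # model-00001-of-00002.safetensors.partial
```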

async def download_shard(shard: Shard, inference_engine_classname: str, on_progress: AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]], max_parallel_downloads: int = 6, skip_download: bool = False) -> tuple[Path, RepoProgressEvent]:
if DEBUG >= 2 and not skip_download: print(f"Downloading {shard.model_id=} for {inference_engine_classname}")
repo_id = get_repo(shard.model_id, inference_engine_classname)
@@ -154,7 +196,7 @@ async def download_shard(shard: Shard, inference_engine_classname: str, on_progr
if DEBUG >= 2: print(f"Downloading {shard.model_id=} with {allow_patterns=}")

all_start_time = time.time()
file_list = await fetch_file_list(repo_id, revision)
file_list = await fetch_file_list_with_cache(repo_id, revision)
filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, key=lambda x: x["path"]))
file_progress: Dict[str, RepoFileProgressEvent] = {}
def on_progress_wrapper(file: dict, curr_bytes: int, total_bytes: int):
@@ -166,7 +208,7 @@ def on_progress_wrapper(file: dict, curr_bytes: int, total_bytes: int):
on_progress.trigger_all(shard, calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time))
if DEBUG >= 6: print(f"Downloading {file['path']} {curr_bytes}/{total_bytes} {speed} {eta}")
for file in filtered_file_list:
downloaded_bytes = (await aios.stat(target_dir/file["path"])).st_size if await aios.path.exists(target_dir/file["path"]) else 0
downloaded_bytes = await get_downloaded_size(target_dir/file["path"])
file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], downloaded_bytes, 0, file["size"], 0, timedelta(0), "complete" if downloaded_bytes == file["size"] else "not_started", time.time())

semaphore = asyncio.Semaphore(max_parallel_downloads)
@@ -199,8 +241,9 @@ async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
finally:
if shard in self.active_downloads and self.active_downloads[shard].done(): del self.active_downloads[shard]

async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
return await self.shard_downloader.get_shard_download_status(inference_engine_name)
async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
async for path, status in self.shard_downloader.get_shard_download_status(inference_engine_name):
yield path, status

class CachedShardDownloader(ShardDownloader):
def __init__(self, shard_downloader: ShardDownloader):
@@ -220,8 +263,9 @@ async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
self.cache[(inference_engine_name, shard)] = target_dir
return target_dir

async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
return await self.shard_downloader.get_shard_download_status(inference_engine_name)
async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
async for path, status in self.shard_downloader.get_shard_download_status(inference_engine_name):
yield path, status

class NewShardDownloader(ShardDownloader):
def __init__(self):
@@ -235,9 +279,12 @@ async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
target_dir, _ = await download_shard(shard, inference_engine_name, self.on_progress)
return target_dir

async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
if DEBUG >= 2: print("Getting shard download status for", inference_engine_name)
downloads = await asyncio.gather(*[download_shard(build_full_shard(model_id, inference_engine_name), inference_engine_name, self.on_progress, skip_download=True) for model_id in get_supported_models([[inference_engine_name]])], return_exceptions=True)
if DEBUG >= 6: print("Downloaded shards:", downloads)
if any(isinstance(d, Exception) for d in downloads) and DEBUG >= 1: print("Error downloading shards:", [d for d in downloads if isinstance(d, Exception)])
return [d for d in downloads if not isinstance(d, Exception)]
tasks = [download_shard(build_full_shard(model_id, inference_engine_name), inference_engine_name, self.on_progress, skip_download=True) for model_id in get_supported_models([[inference_engine_name]])]
for task in asyncio.as_completed(tasks):
try:
path, progress = await task
yield (path, progress)
except Exception as e:
print("Error downloading shard:", e)
8 changes: 4 additions & 4 deletions exo/download/shard_download.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Optional, Tuple, Dict
from typing import Optional, Tuple, Dict, AsyncIterator
from pathlib import Path
from exo.inference.shard import Shard
from exo.download.download_progress import RepoProgressEvent
@@ -27,7 +27,7 @@ def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent
pass

@abstractmethod
async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
"""Get the download status of shards.
Returns:
@@ -45,5 +45,5 @@ async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
return AsyncCallbackSystem()

async def get_shard_download_status(self, inference_engine_name: str) -> Optional[Dict[str, float]]:
return None
async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
if False: yield
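The `if False: yield` is the standard trick for declaring an async generator that yields nothing: the mere presence of `yield` in the body makes the method an async generator, so every implementation of `get_shard_download_status` can be consumed uniformly with `async for`, and this placeholder simply produces no items. A standalone illustration:

```python
import asyncio
from typing import AsyncIterator

async def empty_status() -> AsyncIterator[tuple[str, float]]:
    if False:
        yield  # never runs, but makes this function an async generator

async def main() -> None:
    async for item in empty_status():
        print(item)  # never reached; the iterator is exhausted immediately
    print("no statuses")

asyncio.run(main())
```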
7 changes: 3 additions & 4 deletions exo/download/test_new_shard_download.py
@@ -1,14 +1,13 @@
from exo.download.new_shard_download import download_shard, NewShardDownloader
from exo.download.new_shard_download import NewShardDownloader
from exo.inference.shard import Shard
from pathlib import Path
import asyncio

async def test_new_shard_download():
shard_downloader = NewShardDownloader()
shard_downloader.on_progress.register("test").on_next(lambda shard, event: print(shard, event))
await shard_downloader.ensure_shard(Shard(model_id="llama-3.2-1b", start_layer=0, end_layer=0, n_layers=16), "MLXDynamicShardInferenceEngine")
download_statuses = await shard_downloader.get_shard_download_status("MLXDynamicShardInferenceEngine")
print({k: v for k, v in download_statuses if v.downloaded_bytes > 0})
async for path, shard_status in shard_downloader.get_shard_download_status("MLXDynamicShardInferenceEngine"):
print("Shard download status:", path, shard_status)

if __name__ == "__main__":
asyncio.run(test_new_shard_download())
