
Commit 84602c8

Author: vmpuri
Delete models from old location for huggingface download
1 parent 654dbec commit 84602c8

File tree

1 file changed: torchchat/cli/download.py (+76 -28 lines)


torchchat/cli/download.py

+76 -28 lines changed
@@ -10,7 +10,10 @@
 from pathlib import Path
 from typing import Optional
 
-from torchchat.cli.convert_hf_checkpoint import convert_hf_checkpoint, convert_hf_checkpoint_to_tune
+from torchchat.cli.convert_hf_checkpoint import (
+    convert_hf_checkpoint,
+    convert_hf_checkpoint_to_tune,
+)
 from torchchat.model_config.model_config import (
     load_model_configs,
     ModelConfig,
@@ -20,20 +23,46 @@
 
 # By default, download models from HuggingFace to the Hugginface hub directory.
 # Both $HF_HOME and $HUGGINGFACE_HUB_CACHE are valid environment variables for the same directory.
-HUGGINGFACE_HOME_PATH = Path(os.environ.get("HF_HOME", os.environ.get("HUGGINGFACE_HUB_CACHE", os.path.expanduser("~/.cache/huggingface/hub"))))
+HUGGINGFACE_HOME_PATH = Path(
+    os.environ.get(
+        "HF_HOME",
+        os.environ.get(
+            "HUGGINGFACE_HUB_CACHE", os.path.expanduser("~/.cache/huggingface/hub")
+        ),
+    )
+)
 
 if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", None) is None:
     os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
 
-def _download_hf_snapshot(
-    model_config: ModelConfig, hf_token: Optional[str]
-):
+
+# Previously, all models were stored in the torchchat models directory (by default ~/.torchchat/model-cache)
+# For Hugging Face models, we now store them in the HuggingFace cache directory.
+# This function will delete all model artifacts in the old directory for each model with the Hugging Face distribution path.
+def _cleanup_hf_models_from_torchchat_dir(models_dir: Path):
+    for model_config in load_model_configs().values():
+        if (
+            model_config.distribution_channel
+            == ModelDistributionChannel.HuggingFaceSnapshot
+        ):
+            if os.path.exists(models_dir / model_config.name):
+                print(
+                    f"Cleaning up old model artifacts in {models_dir / model_config.name}. New artifacts will be downloaded to {HUGGINGFACE_HOME_PATH}"
+                )
+                shutil.rmtree(models_dir / model_config.name)
+
+
+def _download_hf_snapshot(model_config: ModelConfig, hf_token: Optional[str]):
     from huggingface_hub import model_info, snapshot_download
     from requests.exceptions import HTTPError
 
     # Download and store the HF model artifacts.
     model_dir = get_model_dir(model_config, None)
-    print(f"Downloading {model_config.name} from Hugging Face to {model_dir}", file=sys.stderr, flush=True)
+    print(
+        f"Downloading {model_config.name} from Hugging Face to {model_dir}",
+        file=sys.stderr,
+        flush=True,
+    )
     try:
         # Fetch the info about the model's repo
         model_info = model_info(model_config.distribution_path, token=hf_token)
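For readers skimming the hunk above: the new cleanup helper only touches models whose distribution channel is a Hugging Face snapshot, and leaves everything else in the torchchat models directory alone. A standalone sketch of that behavior (not part of the commit; the model names and the temporary directory below are placeholders):

import shutil
import tempfile
from pathlib import Path

# Placeholder model names: one distributed as a Hugging Face snapshot, one not.
hf_snapshot_models = {"meta-llama/Llama-3.2-11B-Vision"}
other_models = {"stories15M"}

# Stand-in for the old-style torchchat models directory (~/.torchchat/model-cache).
models_dir = Path(tempfile.mkdtemp())
for name in hf_snapshot_models | other_models:
    (models_dir / name).mkdir(parents=True)

# Mirrors the loop added above: only Hugging Face-distributed models are removed.
for name in hf_snapshot_models:
    old_dir = models_dir / name
    if old_dir.exists():
        print(f"Cleaning up old model artifacts in {old_dir}")
        shutil.rmtree(old_dir)

print((models_dir / "stories15M").exists())                       # True
print((models_dir / "meta-llama/Llama-3.2-11B-Vision").exists())  # False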
@@ -81,14 +110,17 @@ def _download_hf_snapshot(
         else:
             raise e
 
-    # Update the model dir to include the snapshot we just downloaded.
+    # Update the model dir to include the snapshot we just downloaded.
     model_dir = get_model_dir(model_config, None)
     print("Model downloaded to", model_dir)
 
     # Convert the Multimodal Llama model to the torchtune format.
-    if model_config.name in {"meta-llama/Llama-3.2-11B-Vision-Instruct", "meta-llama/Llama-3.2-11B-Vision"}:
+    if model_config.name in {
+        "meta-llama/Llama-3.2-11B-Vision-Instruct",
+        "meta-llama/Llama-3.2-11B-Vision",
+    }:
         print(f"Converting {model_config.name} to torchtune format...", file=sys.stderr)
-        convert_hf_checkpoint_to_tune( model_dir=model_dir, model_name=model_config.name)
+        convert_hf_checkpoint_to_tune(model_dir=model_dir, model_name=model_config.name)
 
     else:
         # Convert the model to the torchchat format.
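Side note on where the artifacts land (hedged; the snapshot_download call itself sits in the unchanged middle of _download_hf_snapshot and is not shown in these hunks): by default, huggingface_hub.snapshot_download writes into the Hugging Face hub cache and returns the resolved snapshot directory. A minimal standalone illustration against a small public repo:

from huggingface_hub import snapshot_download

# Without an explicit cache_dir, files go to the hub cache
# (~/.cache/huggingface/hub unless overridden via environment variables)
# and the resolved snapshot path is returned.
local_dir = snapshot_download("sshleifer/tiny-gpt2")  # tiny public repo, purely illustrative
print(local_dir)
# e.g. .../huggingface/hub/models--sshleifer--tiny-gpt2/snapshots/<commit-hash>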
@@ -108,32 +140,44 @@ def _download_direct(
     print(f"Downloading {url}...", file=sys.stderr)
     urllib.request.urlretrieve(url, str(local_path.absolute()))
 
+
 def _get_hf_artifact_dir(model_config: ModelConfig) -> Path:
     """
     Returns the directory where the model artifacts are stored.
-
+
     This is the root folder with blobs, refs and snapshots
     """
-    assert(model_config.distribution_channel == ModelDistributionChannel.HuggingFaceSnapshot)
-    return HUGGINGFACE_HOME_PATH / f"models--{model_config.distribution_path.replace('/', '--')}"
+    assert (
+        model_config.distribution_channel
+        == ModelDistributionChannel.HuggingFaceSnapshot
+    )
+    return (
+        HUGGINGFACE_HOME_PATH
+        / f"models--{model_config.distribution_path.replace('/', '--')}"
+    )
 
 
 def get_model_dir(model_config: ModelConfig, models_dir: Optional[Path]) -> Path:
     """
-    Returns the directory where the model artifacts are stored.
-    For HuggingFace snapshots, this is the HuggingFace cache directory.
-    For all other distribution channels, we use the models_dir.
-
-    For CLI usage, pass in args.model_directory.
+    Returns the directory where the model artifacts are expected to be stored.
+    For Hugging Face artifacts, this will be the location of the "main" snapshot if it exists, or the expected model directory otherwise.
+    For all other distribution channels, we use the models_dir.
+
+    For CLI usage, pass in args.model_directory.
     """
-    if model_config.distribution_channel == ModelDistributionChannel.HuggingFaceSnapshot:
-        artifact_dir = _get_hf_artifact_dir(model_config)
-
+    if (
+        model_config.distribution_channel
+        == ModelDistributionChannel.HuggingFaceSnapshot
+    ):
+        artifact_dir = _get_hf_artifact_dir(model_config)
+
         # If these paths doesn't exist, it means the model hasn't been downloaded yet.
-        if not os.path.isdir(artifact_dir) and not os.path.isdir(artifact_dir / "snapshots"):
+        if not os.path.isdir(artifact_dir) and not os.path.isdir(
+            artifact_dir / "snapshots"
+        ):
             return artifact_dir
         snapshot = open(artifact_dir / "refs" / "main", "r").read().strip()
-        return artifact_dir / "snapshots" / snapshot
+        return artifact_dir / "snapshots" / snapshot
     else:
         return models_dir / model_config.name

@@ -164,9 +208,7 @@ def download_and_convert(
     os.makedirs(temp_dir, exist_ok=True)
 
     try:
-        if (
-            model_config.distribution_channel == ModelDistributionChannel.DirectDownload
-        ):
+        if model_config.distribution_channel == ModelDistributionChannel.DirectDownload:
             _download_direct(model_config, temp_dir)
         else:
             raise RuntimeError(
@@ -187,7 +229,7 @@ def download_and_convert(
 
 def is_model_downloaded(model: str, models_dir: Path) -> bool:
     model_config = resolve_model_config(model)
-
+
     # Check if the model directory exists and is not empty.
     model_dir = get_model_dir(model_config, models_dir)
     return os.path.isdir(model_dir) and os.listdir(model_dir)
@@ -242,7 +284,10 @@ def remove_main(args) -> None:
     if not os.path.isdir(model_dir):
         print(f"Model {args.model} has no downloaded artifacts in {model_dir}.")
         return
-    if model_config.distribution_channel == ModelDistributionChannel.HuggingFaceSnapshot:
+    if (
+        model_config.distribution_channel
+        == ModelDistributionChannel.HuggingFaceSnapshot
+    ):
         # For HuggingFace models, we need to remove the entire root directory.
         model_dir = _get_hf_artifact_dir(model_config)
 
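Hedged note on the remove path above: for a Hugging Face model the target is now the whole cache entry returned by _get_hf_artifact_dir (blobs, refs and snapshots together), roughly equivalent to the following sketch (illustrative repo id, not a torchchat API):

import shutil
from pathlib import Path

# The entry name mirrors the models--<org>--<repo> convention used by _get_hf_artifact_dir.
entry = Path.home() / ".cache/huggingface/hub" / "models--meta-llama--Llama-3.2-11B-Vision"
if entry.is_dir():
    shutil.rmtree(entry)  # removes blobs/, refs/ and snapshots/ in one step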
@@ -265,12 +310,15 @@ def where_main(args) -> None:
     model_dir = get_model_dir(model_config, args.model_directory)
 
     if not os.path.isdir(model_dir):
-        raise RuntimeError(f"Model {args.model} has no downloaded artifacts in {model_dir}.")
+        raise RuntimeError(
+            f"Model {args.model} has no downloaded artifacts in {model_dir}."
+        )
 
     print(str(os.path.abspath(model_dir)))
     exit(0)
 
 
 # Subcommand to download model artifacts.
 def download_main(args) -> None:
+    _cleanup_hf_models_from_torchchat_dir(args.model_directory)
     download_and_convert(args.model, args.model_directory, args.hf_token)

0 commit comments