10
10
from pathlib import Path
11
11
from typing import Optional
12
12
13
- from torchchat .cli .convert_hf_checkpoint import convert_hf_checkpoint , convert_hf_checkpoint_to_tune
13
+ from torchchat .cli .convert_hf_checkpoint import (
14
+ convert_hf_checkpoint ,
15
+ convert_hf_checkpoint_to_tune ,
16
+ )
14
17
from torchchat .model_config .model_config import (
15
18
load_model_configs ,
16
19
ModelConfig ,
20
23
21
24
# By default, download models from HuggingFace to the Hugginface hub directory.
22
25
# Both $HF_HOME and $HUGGINGFACE_HUB_CACHE are valid environment variables for the same directory.
23
- HUGGINGFACE_HOME_PATH = Path (os .environ .get ("HF_HOME" , os .environ .get ("HUGGINGFACE_HUB_CACHE" , os .path .expanduser ("~/.cache/huggingface/hub" ))))
26
+ HUGGINGFACE_HOME_PATH = Path (
27
+ os .environ .get (
28
+ "HF_HOME" ,
29
+ os .environ .get (
30
+ "HUGGINGFACE_HUB_CACHE" , os .path .expanduser ("~/.cache/huggingface/hub" )
31
+ ),
32
+ )
33
+ )
24
34
25
35
if os .environ .get ("HF_HUB_ENABLE_HF_TRANSFER" , None ) is None :
26
36
os .environ ["HF_HUB_ENABLE_HF_TRANSFER" ] = "1"
27
37
28
- def _download_hf_snapshot (
29
- model_config : ModelConfig , hf_token : Optional [str ]
30
- ):
38
+
39
+ # Previously, all models were stored in the torchchat models directory (by default ~/.torchchat/model-cache)
40
+ # For Hugging Face models, we now store them in the HuggingFace cache directory.
41
+ # This function will delete all model artifacts in the old directory for each model with the Hugging Face distribution path.
42
+ def _cleanup_hf_models_from_torchchat_dir (models_dir : Path ):
43
+ for model_config in load_model_configs ().values ():
44
+ if (
45
+ model_config .distribution_channel
46
+ == ModelDistributionChannel .HuggingFaceSnapshot
47
+ ):
48
+ if os .path .exists (models_dir / model_config .name ):
49
+ print (
50
+ f"Cleaning up old model artifacts in { models_dir / model_config .name } . New artifacts will be downloaded to { HUGGINGFACE_HOME_PATH } "
51
+ )
52
+ shutil .rmtree (models_dir / model_config .name )
53
+
54
+
55
+ def _download_hf_snapshot (model_config : ModelConfig , hf_token : Optional [str ]):
31
56
from huggingface_hub import model_info , snapshot_download
32
57
from requests .exceptions import HTTPError
33
58
34
59
# Download and store the HF model artifacts.
35
60
model_dir = get_model_dir (model_config , None )
36
- print (f"Downloading { model_config .name } from Hugging Face to { model_dir } " , file = sys .stderr , flush = True )
61
+ print (
62
+ f"Downloading { model_config .name } from Hugging Face to { model_dir } " ,
63
+ file = sys .stderr ,
64
+ flush = True ,
65
+ )
37
66
try :
38
67
# Fetch the info about the model's repo
39
68
model_info = model_info (model_config .distribution_path , token = hf_token )
@@ -81,14 +110,17 @@ def _download_hf_snapshot(
81
110
else :
82
111
raise e
83
112
84
- # Update the model dir to include the snapshot we just downloaded.
113
+ # Update the model dir to include the snapshot we just downloaded.
85
114
model_dir = get_model_dir (model_config , None )
86
115
print ("Model downloaded to" , model_dir )
87
116
88
117
# Convert the Multimodal Llama model to the torchtune format.
89
- if model_config .name in {"meta-llama/Llama-3.2-11B-Vision-Instruct" , "meta-llama/Llama-3.2-11B-Vision" }:
118
+ if model_config .name in {
119
+ "meta-llama/Llama-3.2-11B-Vision-Instruct" ,
120
+ "meta-llama/Llama-3.2-11B-Vision" ,
121
+ }:
90
122
print (f"Converting { model_config .name } to torchtune format..." , file = sys .stderr )
91
- convert_hf_checkpoint_to_tune ( model_dir = model_dir , model_name = model_config .name )
123
+ convert_hf_checkpoint_to_tune (model_dir = model_dir , model_name = model_config .name )
92
124
93
125
else :
94
126
# Convert the model to the torchchat format.
@@ -108,32 +140,44 @@ def _download_direct(
108
140
print (f"Downloading { url } ..." , file = sys .stderr )
109
141
urllib .request .urlretrieve (url , str (local_path .absolute ()))
110
142
143
+
111
144
def _get_hf_artifact_dir (model_config : ModelConfig ) -> Path :
112
145
"""
113
146
Returns the directory where the model artifacts are stored.
114
-
147
+
115
148
This is the root folder with blobs, refs and snapshots
116
149
"""
117
- assert (model_config .distribution_channel == ModelDistributionChannel .HuggingFaceSnapshot )
118
- return HUGGINGFACE_HOME_PATH / f"models--{ model_config .distribution_path .replace ('/' , '--' )} "
150
+ assert (
151
+ model_config .distribution_channel
152
+ == ModelDistributionChannel .HuggingFaceSnapshot
153
+ )
154
+ return (
155
+ HUGGINGFACE_HOME_PATH
156
+ / f"models--{ model_config .distribution_path .replace ('/' , '--' )} "
157
+ )
119
158
120
159
121
160
def get_model_dir (model_config : ModelConfig , models_dir : Optional [Path ]) -> Path :
122
161
"""
123
- Returns the directory where the model artifacts are stored.
124
- For HuggingFace snapshots , this is the HuggingFace cache directory.
125
- For all other distribution channels, we use the models_dir.
126
-
127
- For CLI usage, pass in args.model_directory.
162
+ Returns the directory where the model artifacts are expected to be stored.
163
+ For Hugging Face artifacts , this will be the location of the "main" snapshot if it exists, or the expected model directory otherwise .
164
+ For all other distribution channels, we use the models_dir.
165
+
166
+ For CLI usage, pass in args.model_directory.
128
167
"""
129
- if model_config .distribution_channel == ModelDistributionChannel .HuggingFaceSnapshot :
130
- artifact_dir = _get_hf_artifact_dir (model_config )
131
-
168
+ if (
169
+ model_config .distribution_channel
170
+ == ModelDistributionChannel .HuggingFaceSnapshot
171
+ ):
172
+ artifact_dir = _get_hf_artifact_dir (model_config )
173
+
132
174
# If these paths doesn't exist, it means the model hasn't been downloaded yet.
133
- if not os .path .isdir (artifact_dir ) and not os .path .isdir (artifact_dir / "snapshots" ):
175
+ if not os .path .isdir (artifact_dir ) and not os .path .isdir (
176
+ artifact_dir / "snapshots"
177
+ ):
134
178
return artifact_dir
135
179
snapshot = open (artifact_dir / "refs" / "main" , "r" ).read ().strip ()
136
- return artifact_dir / "snapshots" / snapshot
180
+ return artifact_dir / "snapshots" / snapshot
137
181
else :
138
182
return models_dir / model_config .name
139
183
@@ -164,9 +208,7 @@ def download_and_convert(
164
208
os .makedirs (temp_dir , exist_ok = True )
165
209
166
210
try :
167
- if (
168
- model_config .distribution_channel == ModelDistributionChannel .DirectDownload
169
- ):
211
+ if model_config .distribution_channel == ModelDistributionChannel .DirectDownload :
170
212
_download_direct (model_config , temp_dir )
171
213
else :
172
214
raise RuntimeError (
@@ -187,7 +229,7 @@ def download_and_convert(
187
229
188
230
def is_model_downloaded (model : str , models_dir : Path ) -> bool :
189
231
model_config = resolve_model_config (model )
190
-
232
+
191
233
# Check if the model directory exists and is not empty.
192
234
model_dir = get_model_dir (model_config , models_dir )
193
235
return os .path .isdir (model_dir ) and os .listdir (model_dir )
@@ -242,7 +284,10 @@ def remove_main(args) -> None:
242
284
if not os .path .isdir (model_dir ):
243
285
print (f"Model { args .model } has no downloaded artifacts in { model_dir } ." )
244
286
return
245
- if model_config .distribution_channel == ModelDistributionChannel .HuggingFaceSnapshot :
287
+ if (
288
+ model_config .distribution_channel
289
+ == ModelDistributionChannel .HuggingFaceSnapshot
290
+ ):
246
291
# For HuggingFace models, we need to remove the entire root directory.
247
292
model_dir = _get_hf_artifact_dir (model_config )
248
293
@@ -265,12 +310,15 @@ def where_main(args) -> None:
265
310
model_dir = get_model_dir (model_config , args .model_directory )
266
311
267
312
if not os .path .isdir (model_dir ):
268
- raise RuntimeError (f"Model { args .model } has no downloaded artifacts in { model_dir } ." )
313
+ raise RuntimeError (
314
+ f"Model { args .model } has no downloaded artifacts in { model_dir } ."
315
+ )
269
316
270
317
print (str (os .path .abspath (model_dir )))
271
318
exit (0 )
272
319
273
320
274
321
# Subcommand to download model artifacts.
275
322
def download_main (args ) -> None :
323
+ _cleanup_hf_models_from_torchchat_dir (args .model_directory )
276
324
download_and_convert (args .model , args .model_directory , args .hf_token )
0 commit comments