
Commit f045cfe

Author: vmpuri
Committed: Oct 9, 2024

Download huggingface models to huggingface cache instead of ~/.torchchat

1 parent: 6a2a2e8

File tree: 5 files changed, +89 -84 lines

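At its core, this commit stops copying named Hugging Face models under --model-directory (default ~/.torchchat) and instead resolves checkpoint and tokenizer paths directly inside the Hugging Face hub cache. A minimal standalone sketch of the snapshot lookup performed by the new get_model_dir() in torchchat/cli/download.py (repo id and paths are illustrative; this is not the committed code verbatim):

# Sketch of the hub-cache lookup this commit introduces.
import os
from pathlib import Path

# The commit treats $HF_HOME / $HUGGINGFACE_HUB_CACHE as the hub cache root,
# falling back to ~/.cache/huggingface/hub.
HF_CACHE = Path(
    os.environ.get(
        "HF_HOME",
        os.environ.get(
            "HUGGINGFACE_HUB_CACHE", os.path.expanduser("~/.cache/huggingface/hub")
        ),
    )
)

def snapshot_dir(repo_id: str) -> Path:
    # Hub cache entries are named "models--<org>--<repo>" and contain
    # blobs/, refs/ and snapshots/<commit-hash>/ subdirectories.
    root = HF_CACHE / f"models--{repo_id.replace('/', '--')}"
    ref = root / "refs" / "main"
    if not ref.is_file():
        return root  # nothing downloaded yet
    return root / "snapshots" / ref.read_text().strip()

# Example (illustrative repo id):
# snapshot_dir("meta-llama/Meta-Llama-3-8B-Instruct")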
 

‎torchchat/cli/builder.py

+5 -5

@@ -30,6 +30,7 @@
 
 from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
 
+from torchchat.cli.download import get_model_dir
 from torchchat.model_config.model_config import resolve_model_config
 from torchchat.utils.build_utils import (
     device_sync,
@@ -73,7 +74,7 @@ def __post_init__(self):
             or (self.pte_path and Path(self.pte_path).is_file())
         ):
             raise RuntimeError(
-                "need to specified a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path"
+                f"need to specified a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path {self.checkpoint_path}"
             )
 
         if self.dso_path and self.pte_path:
@@ -109,10 +110,10 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
             model_config = resolve_model_config(args.model)
 
             checkpoint_path = (
-                Path(args.model_directory)
-                / model_config.name
+                get_model_dir(model_config, args.model_directory)
                 / model_config.checkpoint_file
             )
+            print(f"Using checkpoint path: {checkpoint_path}")
             # The transformers config is keyed on the last section
             # of the name/path.
             params_table = (
@@ -264,8 +265,7 @@ def from_args(cls, args: argparse.Namespace) -> "TokenizerArgs":
         elif args.model: # Using a named, well-known model
             model_config = resolve_model_config(args.model)
             tokenizer_path = (
-                Path(args.model_directory)
-                / model_config.name
+                get_model_dir(model_config, args.model_directory)
                 / model_config.tokenizer_file
             )
         elif args.checkpoint_path:

‎torchchat/cli/cli.py

+1 -1

@@ -244,7 +244,7 @@ def _add_jit_downloading_args(parser) -> None:
         "--model-directory",
         type=Path,
         default=default_model_dir,
-        help=f"The directory to store downloaded model artifacts. Default: {default_model_dir}",
+        help=f"The directory to store downloaded model artifacts. Default: {default_model_dir}. This is overriden by the huggingface cache directory if the model is downloaded from HuggingFace.",
     )
 
‎torchchat/cli/convert_hf_checkpoint.py

+18 -56

@@ -3,7 +3,6 @@
 
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
-import glob
 import json
 import os
 import re
@@ -42,12 +41,7 @@ def convert_hf_checkpoint(
     print(f"Model config {config.__dict__}")
 
     # Load the json file containing weight mapping
-    model_map_json_matches = [Path(m) for m in glob.glob(str(model_dir / "*.index.json"))]
-    assert len(model_map_json_matches) <= 1, "Found multiple weight mapping files"
-    if len(model_map_json_matches):
-        model_map_json = model_map_json_matches[0]
-    else:
-        model_map_json = model_dir / "pytorch_model.bin.index.json"
+    model_map_json = model_dir / "pytorch_model.bin.index.json"
 
     # If there is no weight mapping, check for a consolidated model and
     # tokenizer we can move. Llama 2 and Mistral have weight mappings, while
@@ -62,9 +56,10 @@ def convert_hf_checkpoint(
                 str(consolidated_pth), map_location="cpu", mmap=True, weights_only=True
             )
             del loaded_result # No longer needed
-            print(f"Moving checkpoint to {model_dir / 'model.pth'}.")
-            os.rename(consolidated_pth, model_dir / "model.pth")
-            os.rename(tokenizer_pth, model_dir / "tokenizer.model")
+            print(f"Symlinking checkpoint to {model_dir / 'model.pth'}.")
+            consolidated_pth = os.path.realpath(consolidated_pth)
+            os.symlink(consolidated_pth, model_dir / "model.pth")
+            os.symlink(tokenizer_pth, model_dir / "tokenizer.model")
             print("Done.")
             return
     else:
@@ -81,17 +76,10 @@ def convert_hf_checkpoint(
         "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
         "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
         "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
-        "model.layers.{}.self_attn.q_proj.bias": "layers.{}.attention.wq.bias",
-        "model.layers.{}.self_attn.k_proj.bias": "layers.{}.attention.wk.bias",
-        "model.layers.{}.self_attn.v_proj.bias": "layers.{}.attention.wv.bias",
-        "model.layers.{}.self_attn.o_proj.bias": "layers.{}.attention.wo.bias",
         "model.layers.{}.self_attn.rotary_emb.inv_freq": None,
         "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
        "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
         "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
-        "model.layers.{}.mlp.gate_proj.bias": "layers.{}.feed_forward.w1.bias",
-        "model.layers.{}.mlp.up_proj.bias": "layers.{}.feed_forward.w3.bias",
-        "model.layers.{}.mlp.down_proj.bias": "layers.{}.feed_forward.w2.bias",
         "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
         "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
         "model.norm.weight": "norm.weight",
@@ -100,43 +88,19 @@ def convert_hf_checkpoint(
     bin_files = {model_dir / bin for bin in bin_index["weight_map"].values()}
 
     def permute(w, n_heads):
+        dim = config.dim
         return (
-            w.view(n_heads, 2, config.head_dim // 2, *w.shape[1:])
+            w.view(n_heads, 2, config.head_dim // 2, dim)
             .transpose(1, 2)
-            .reshape(w.shape)
+            .reshape(config.head_dim * n_heads, dim)
         )
 
     merged_result = {}
     for file in sorted(bin_files):
-
-        # The state_dict can be loaded from either a torch zip file or
-        # safetensors. We take our best guess from the name and try all
-        # possibilities
-        load_pt_mmap = lambda: torch.load(
+        state_dict = torch.load(
             str(file), map_location="cpu", mmap=True, weights_only=True
         )
-        load_pt_no_mmap = lambda: torch.load(
-            str(file), map_location="cpu", mmap=False, weights_only=True
-        )
-        def load_safetensors():
-            import safetensors.torch
-            with open(file, "rb") as handle:
-                return safetensors.torch.load(handle.read())
-        if "safetensors" in str(file):
-            loaders = [load_safetensors, load_pt_mmap, load_pt_no_mmap]
-        else:
-            loaders = [load_pt_mmap, load_pt_no_mmap, load_safetensors]
-
-        state_dict = None
-        for loader in loaders:
-            try:
-                state_dict = loader()
-                break
-            except Exception:
-                continue
-        assert state_dict is not None, f"Unable to load tensors from {file}"
         merged_result.update(state_dict)
-
     final_result = {}
     for key, value in merged_result.items():
         if "layers" in key:
@@ -152,18 +116,16 @@ def load_safetensors():
             final_result[new_key] = value
 
     for key in tuple(final_result.keys()):
-        if "wq.weight" in key or "wq.bias" in key:
-            wk_key = key.replace("wq", "wk")
-            wv_key = key.replace("wq", "wv")
+        if "wq" in key:
             q = final_result[key]
-            k = final_result[wk_key]
-            v = final_result[wv_key]
+            k = final_result[key.replace("wq", "wk")]
+            v = final_result[key.replace("wq", "wv")]
             q = permute(q, config.n_heads)
             k = permute(k, config.n_local_heads)
             final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
             del final_result[key]
-            del final_result[wk_key]
-            del final_result[wv_key]
+            del final_result[key.replace("wq", "wk")]
+            del final_result[key.replace("wq", "wv")]
     print(f"Saving checkpoint to {model_dir / 'model.pth'}. This may take a while.")
     torch.save(final_result, model_dir / "model.pth")
     print("Done.")
@@ -184,10 +146,10 @@ def convert_hf_checkpoint_to_tune(
     consolidated_pth = model_dir / "original" / "consolidated.pth"
     tokenizer_pth = model_dir / "original" / "tokenizer.model"
     if consolidated_pth.is_file() and tokenizer_pth.is_file():
-        print(f"Moving checkpoint to {model_dir / 'model.pth'}.")
-        os.rename(consolidated_pth, model_dir / "model.pth")
-        print(f"Moving tokenizer to {model_dir / 'tokenizer.model'}.")
-        os.rename(tokenizer_pth, model_dir / "tokenizer.model")
+        print(f"Creating symlink from {consolidated_pth} to {model_dir / 'model.pth'}.")
+        os.symlink(consolidated_pth, model_dir / "model.pth")
+        print(f"Creating symlink from {tokenizer_pth} to {model_dir / 'tokenizer.model'}.")
+        os.symlink(tokenizer_pth, model_dir / "tokenizer.model")
         print("Done.")
     else:
         raise RuntimeError(f"Could not find {consolidated_pth}")
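Note that convert_hf_checkpoint.py now symlinks the consolidated checkpoint and tokenizer instead of renaming them: the source files live inside the shared Hugging Face cache, and moving them would leave that cache broken for other tools, whereas a symlink leaves the cached snapshot intact. A rough sketch of the pattern (helper name and example paths are illustrative, not part of the commit):

# Illustrative helper: expose a cached file in the model directory without
# moving it out of the HF hub cache.
import os
from pathlib import Path

def link_instead_of_move(src: Path, dst: Path) -> None:
    # Resolve through the hub cache's own symlinks so dst points at the real
    # blob rather than at another symlink.
    target = os.path.realpath(src)
    if not os.path.lexists(dst):
        os.symlink(target, dst)

# Example (illustrative):
# link_instead_of_move(model_dir / "original" / "consolidated.pth",
#                      model_dir / "model.pth")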

‎torchchat/cli/download.py

+62 -19

@@ -18,15 +18,19 @@
     resolve_model_config,
 )
 
+# By default, download models from HuggingFace to the Hugginface hub directory.
+# Both $HF_HOME and $HUGGINGFACE_HUB_CACHE are valid environment variables for the same directory.
+HUGGINGFACE_HOME_PATH = Path(os.environ.get("HF_HOME", os.environ.get("HUGGINGFACE_HUB_CACHE", os.path.expanduser("~/.cache/huggingface/hub"))))
 
 def _download_hf_snapshot(
-    model_config: ModelConfig, artifact_dir: Path, hf_token: Optional[str]
+    model_config: ModelConfig, hf_token: Optional[str]
 ):
     from huggingface_hub import model_info, snapshot_download
     from requests.exceptions import HTTPError
 
     # Download and store the HF model artifacts.
-    print(f"Downloading {model_config.name} from HuggingFace...", file=sys.stderr)
+    model_dir = get_model_dir(model_config, None)
+    print(f"Downloading {model_config.name} from Hugging Face to {model_dir}", file=sys.stderr, flush=True)
     try:
         # Fetch the info about the model's repo
         model_info = model_info(model_config.distribution_path, token=hf_token)
@@ -56,8 +60,6 @@ def _download_hf_snapshot(
 
         snapshot_download(
             model_config.distribution_path,
-            local_dir=artifact_dir,
-            local_dir_use_symlinks=False,
             token=hf_token,
             ignore_patterns=ignore_patterns,
         )
@@ -76,16 +78,20 @@ def _download_hf_snapshot(
         else:
             raise e
 
+    # Update the model dir to include the snapshot we just downloaded.
+    model_dir = get_model_dir(model_config, None)
+    print("Model downloaded to", model_dir)
+
     # Convert the Multimodal Llama model to the torchtune format.
     if model_config.name in {"meta-llama/Llama-3.2-11B-Vision-Instruct", "meta-llama/Llama-3.2-11B-Vision"}:
         print(f"Converting {model_config.name} to torchtune format...", file=sys.stderr)
-        convert_hf_checkpoint_to_tune( model_dir=artifact_dir, model_name=model_config.name)
+        convert_hf_checkpoint_to_tune( model_dir=model_dir, model_name=model_config.name)
 
     else:
         # Convert the model to the torchchat format.
         print(f"Converting {model_config.name} to torchchat format...", file=sys.stderr)
         convert_hf_checkpoint(
-            model_dir=artifact_dir, model_name=model_config.name, remove_bin_files=True
+            model_dir=model_dir, model_name=model_config.name, remove_bin_files=True
         )
 
 
@@ -99,12 +105,51 @@ def _download_direct(
     print(f"Downloading {url}...", file=sys.stderr)
     urllib.request.urlretrieve(url, str(local_path.absolute()))
 
+def _get_hf_artifact_dir(model_config: ModelConfig) -> Path:
+    """
+    Returns the directory where the model artifacts are stored.
+
+    This is the root folder with blobs, refs and snapshots
+    """
+    assert(model_config.distribution_channel == ModelDistributionChannel.HuggingFaceSnapshot)
+    return HUGGINGFACE_HOME_PATH / f"models--{model_config.distribution_path.replace("/", "--")}"
+
+
+def get_model_dir(model_config: ModelConfig, models_dir: Optional[Path]) -> Path:
+    """
+    Returns the directory where the model artifacts are stored.
+    For HuggingFace snapshots, this is the HuggingFace cache directory.
+    For all other distribution channels, we use the models_dir.
+
+    For CLI usage, pass in args.model_directory.
+    """
+    if model_config.distribution_channel == ModelDistributionChannel.HuggingFaceSnapshot:
+        artifact_dir = _get_hf_artifact_dir(model_config)
+
+        # If these paths doesn't exist, it means the model hasn't been downloaded yet.
+        if not os.path.isdir(artifact_dir) and not os.path.isdir(artifact_dir / "snapshots"):
+            return artifact_dir
+        snapshot = open(artifact_dir / "refs" / "main", "r").read().strip()
+        return artifact_dir / "snapshots" / snapshot
+    else:
+        return models_dir / model_config.name
+
 
 def download_and_convert(
     model: str, models_dir: Path, hf_token: Optional[str] = None
 ) -> None:
     model_config = resolve_model_config(model)
-    model_dir = models_dir / model_config.name
+    model_dir = get_model_dir(model_config, models_dir)
+
+    # HuggingFace download
+    if (
+        model_config.distribution_channel
+        == ModelDistributionChannel.HuggingFaceSnapshot
+    ):
+        _download_hf_snapshot(model_config, hf_token)
+        return
+
+    # Direct download
 
     # Download into a temporary directory. We'll move to the final
     # location once the download and conversion is complete. This
@@ -117,11 +162,6 @@ def download_and_convert(
 
     try:
         if (
-            model_config.distribution_channel
-            == ModelDistributionChannel.HuggingFaceSnapshot
-        ):
-            _download_hf_snapshot(model_config, temp_dir, hf_token)
-        elif (
             model_config.distribution_channel == ModelDistributionChannel.DirectDownload
         ):
             _download_direct(model_config, temp_dir)
@@ -144,9 +184,9 @@ def download_and_convert(
 
 def is_model_downloaded(model: str, models_dir: Path) -> bool:
     model_config = resolve_model_config(model)
-
+
     # Check if the model directory exists and is not empty.
-    model_dir = models_dir / model_config.name
+    model_dir = get_model_dir(model_config, models_dir)
     return os.path.isdir(model_dir) and os.listdir(model_dir)
 
 
@@ -194,13 +234,16 @@ def remove_main(args) -> None:
        return
 
     model_config = resolve_model_config(args.model)
-    model_dir = args.model_directory / model_config.name
+    model_dir = get_model_dir(model_config, args.model_directory)
 
     if not os.path.isdir(model_dir):
-        print(f"Model {args.model} has no downloaded artifacts.")
+        print(f"Model {args.model} has no downloaded artifacts in {model_dir}.")
         return
+    if model_config.distribution_channel == ModelDistributionChannel.HuggingFaceSnapshot:
+        # For HuggingFace models, we need to remove the entire root directory.
+        model_dir = _get_hf_artifact_dir(model_config)
 
-    print(f"Removing downloaded model artifacts for {args.model}...")
+    print(f"Removing downloaded model artifacts for {args.model} at {model_dir}...")
     shutil.rmtree(model_dir)
     print("Done.")
 
@@ -216,10 +259,10 @@ def where_main(args) -> None:
        return
 
     model_config = resolve_model_config(args.model)
-    model_dir = args.model_directory / model_config.name
+    model_dir = get_model_dir(model_config, args.model_directory)
 
     if not os.path.isdir(model_dir):
-        raise RuntimeError(f"Model {args.model} has no downloaded artifacts.")
+        raise RuntimeError(f"Model {args.model} has no downloaded artifacts in {model_dir}.")
 
     print(str(os.path.abspath(model_dir)))
     exit(0)
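With local_dir removed from the snapshot_download() call, huggingface_hub now manages the downloaded files in its own cache, and torchchat only reads from the resulting snapshot directory. A hedged usage sketch (repo id and ignore pattern are placeholders):

# snapshot_download() without local_dir stores files in the hub cache and
# returns the snapshot path, which get_model_dir() later re-derives.
from huggingface_hub import snapshot_download

snapshot_path = snapshot_download(
    "meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder repo id
    token=None,  # or a Hugging Face access token
    ignore_patterns=["*.safetensors"],  # placeholder pattern
)
print(snapshot_path)
# e.g. ~/.cache/huggingface/hub/models--meta-llama--<repo>/snapshots/<hash>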

‎torchchat/usages/openai_api.py

+3 -3

@@ -23,7 +23,7 @@
 
 from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
 
-from torchchat.cli.download import is_model_downloaded, load_model_configs
+from torchchat.cli.download import is_model_downloaded, load_model_configs, get_model_dir
 from torchchat.generate import Generator, GeneratorArgs
 from torchchat.model import FlamingoModel
 
@@ -522,7 +522,7 @@ def retrieve_model_info(args, model_id: str) -> Union[ModelInfo, None]:
     """
     if model_config := load_model_configs().get(model_id):
         if is_model_downloaded(model_id, args.model_directory):
-            path = args.model_directory / model_config.name
+            path = get_model_dir(model_config, args.model_directory)
             created = int(os.path.getctime(path))
             owned_by = getpwuid(os.stat(path).st_uid).pw_name
 
@@ -545,7 +545,7 @@ def get_model_info_list(args) -> ModelInfo:
     data = []
     for model_id, model_config in load_model_configs().items():
         if is_model_downloaded(model_id, args.model_directory):
-            path = args.model_directory / model_config.name
+            path = get_model_dir(model_config, args.model_directory)
             created = int(os.path.getctime(path))
             owned_by = getpwuid(os.stat(path).st_uid).pw_name
 