
Commit dfec431

[Distributed] Add support for torchchat checkpoint format (#1268)

Authored by kwen2501 and lessw2020

* Create load path from HF format
* Add purge_fqn_prefix
* Remove weight map and file from update_state_dict
* Add load support for torchchat checkpoint
* Rename safetensor_utils to checkpoint_utils

Co-authored-by: Less Wright <[email protected]>

1 parent 7a67429, commit dfec431
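Of the bullets above, purge_fqn_prefix and the update_state_dict change land in the third changed file (per the rename, presumably torchchat/distributed/checkpoint_utils.py), which is not expanded on this page. A plausible minimal sketch of what a prefix-purging helper does, with the default prefix chosen here as an assumption:

def purge_fqn_prefix(state_dict: dict, prefix: str = "model.") -> dict:
    # Hypothetical sketch; the actual helper lives in checkpoint_utils.py.
    # Strip a fully-qualified-name prefix from every key so checkpoint keys
    # line up with the stage module's own parameter names.
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }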

File tree: 3 files changed (+195, -91 lines)


dist_run.py

Lines changed: 36 additions & 22 deletions
@@ -23,10 +23,10 @@
 from torchchat.distributed.logging_utils import SingletonLogger

 # TODO - these are not distributed specific, consider moving to new package
-from torchchat.distributed.safetensor_utils import (
+from torchchat.distributed.checkpoint_utils import (
     get_hf_config_file,
-    get_hf_weight_map_and_path,
-    load_safetensor_weights,
+    load_weights_from_hf_format,
+    load_weights_from_torchchat_format,
 )
 from torchchat.distributed.utils import (
     bytes_to_readable,
@@ -129,26 +129,33 @@ def _build_chat_tokenizer(
     return tokenizer


-def _load_model_weights(stage_module, distribution, device, model_config):
+def _load_model_weights(
+    stage_module: torch.nn.Module,
+    distribution: str,
+    device: torch.device,
+    model_config: ModelArgs,
+    chpt_from: str,
+):
     """Load the weights from the safetensor file(s) into the model stage.
     Model config is needed b/c we permute wq and wk weights based on attn heads.
-    """

-    weight_map, weight_path, key_map = get_hf_weight_map_and_path(distribution)
-
-    num_loaded_weights, num_missing_weights = load_safetensor_weights(
-        stage_module,
-        weight_map,
-        weight_path,
-        key_map,
-        device,
-        model_config=model_config,
-    )
-    logger.info(
-        f"Success - Loaded {num_loaded_weights} weights, {num_missing_weights} missing weights"
-    )
-    if num_missing_weights > 0:
-        raise ValueError(f"Missing {num_missing_weights} weights")
+    Args:
+        stage_module (torch.nn.Module): The model stage to load the weights into.
+        distribution (str): The distribution name, e.g. "meta-llama/Meta-Llama-3-8B-Instruct".
+        device (torch.device): The device to load the weights onto.
+        model_config (ModelArgs): The model config.
+        chpt_from (str): The checkpoint format to load the weights from, e.g. "torchchat" or "hf".
+    """
+    if chpt_from == "hf":
+        # This format stands for: index file + multiple binary files
+        load_weights_from_hf_format(stage_module, distribution, device, model_config)
+    elif chpt_from == "torchchat":
+        # This format stands for:
+        # single binary file, OR
+        # multiple binary files without index files.
+        load_weights_from_torchchat_format(stage_module, distribution, device, model_config)
+    else:
+        raise ValueError(f"Unknown checkpoint format: {chpt_from}")


 def _encode_strings(
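The two branches above pick a loader based on checkpoint layout: "hf" means an index file plus multiple binary files, while "torchchat" means a single binary file or multiple binary files without an index. As a rough illustration of what the torchchat branch has to handle, here is a minimal sketch; the helper name _load_torchchat_state_dict is hypothetical, and the real logic lives in torchchat/distributed/checkpoint_utils.py:

import os

import torch


def _load_torchchat_state_dict(checkpoint_path: str) -> dict:
    # Hypothetical sketch, not the actual checkpoint_utils implementation.
    if os.path.isfile(checkpoint_path):
        # Layout 1: a single binary file holding the whole state dict.
        return torch.load(checkpoint_path, mmap=True, weights_only=True)
    # Layout 2: a directory of shard files with no index file, so every
    # shard must be read and merged (unlike HF, where the index records
    # which file holds which tensor).
    state_dict = {}
    for fname in sorted(os.listdir(checkpoint_path)):
        if fname.endswith(".pth"):
            shard = torch.load(
                os.path.join(checkpoint_path, fname), mmap=True, weights_only=True
            )
            state_dict.update(shard)
    return state_dict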
@@ -306,7 +313,7 @@ def main(args):
     logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}")

     distribution, model_dtype = NAME_TO_DISTRIBUTION_AND_DTYPE[model_name]
-    logger.info(f"Using HF model weights from {distribution} and dtype {model_dtype}")
+    logger.info(f"Using model weights from {distribution} and dtype {model_dtype}")

     # Model-level config
     model_config = ModelArgs.from_name(distribution)
@@ -368,7 +375,7 @@ def main(args):
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
     with CUDATrackTime() as timer:
-        _load_model_weights(model, distribution, device=device, model_config=config)
+        _load_model_weights(model, distribution, device, config, args.chpt_from)

     logger.info(
         f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
@@ -602,6 +609,13 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         default=False,
         help="Whether to decode token into string in flight",
     )
+    parser.add_argument(
+        "--chpt-from",
+        type=str,
+        default="hf",  # TODO: change to torchchat once we support it well
+        help="Checkpoint format to load from",
+        choices=["hf", "torchchat"],
+    )
     args = parser.parse_args()

     main(args)
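For reference, the new flag can be exercised on its own; this standalone snippet mirrors the argument definition from the diff (the flag name, default, and choices are taken verbatim) and shows how argparse maps and validates it:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--chpt-from",
    type=str,
    default="hf",
    help="Checkpoint format to load from",
    choices=["hf", "torchchat"],
)

# Dashes in the flag become underscores in the parsed namespace.
print(parser.parse_args([]).chpt_from)  # -> "hf" (the default)
print(parser.parse_args(["--chpt-from", "torchchat"]).chpt_from)  # -> "torchchat"
# parser.parse_args(["--chpt-from", "gguf"]) exits with an error,
# because "gguf" is not among the declared choices.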

torchchat/cli/builder.py

Lines changed: 11 additions & 5 deletions
@@ -335,11 +335,7 @@ def _load_model_gguf(builder_args: BuilderArgs) -> Model:
     return model


-def _load_model_default(builder_args: BuilderArgs) -> Model:
-    assert not builder_args.gguf_path
-
-    model: Model = _init_model_on_meta_device(builder_args)
-
+def _load_checkpoint(builder_args: BuilderArgs):
     if builder_args.params_table and builder_args.params_table.endswith("Tune"):
         print("Loading Tune checkpoint")
         meta_checkpoint = torch.load(
@@ -377,6 +373,16 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:
         mmap=True,
         weights_only=True,
     )
+    return checkpoint
+
+
+def _load_model_default(builder_args: BuilderArgs) -> Model:
+    assert not builder_args.gguf_path
+
+    model: Model = _init_model_on_meta_device(builder_args)
+
+    # Load checkpoint from filesystem
+    checkpoint = _load_checkpoint(builder_args)

     if "model" in checkpoint and "stories" in str(builder_args.checkpoint_path):
         checkpoint = checkpoint["model"]
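The net effect of this refactor is that checkpoint reading is decoupled from model construction, so other call sites (such as the distributed torchchat-format loader) can share it. A sketch of that reuse, assuming _load_checkpoint is importable from torchchat/cli/builder.py; the function below and its exact loading strategy are hypothetical:

from torchchat.cli.builder import _load_checkpoint


def load_stage_from_torchchat_checkpoint(stage_module, builder_args):
    # Hypothetical consumer: reuse the same checkpoint-reading logic as the
    # single-process path...
    checkpoint = _load_checkpoint(builder_args)
    # ...then load only the entries this pipeline stage owns. strict=False
    # skips keys for layers that live on other stages; assign=True replaces
    # meta-device parameters with the loaded tensors.
    stage_module.load_state_dict(checkpoint, strict=False, assign=True)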
