 from torchchat.distributed.logging_utils import SingletonLogger
 
 # TODO - these are not distributed specific, consider moving to new package
-from torchchat.distributed.safetensor_utils import (
+from torchchat.distributed.checkpoint_utils import (
     get_hf_config_file,
-    get_hf_weight_map_and_path,
-    load_safetensor_weights,
+    load_weights_from_hf_format,
+    load_weights_from_torchchat_format,
 )
 from torchchat.distributed.utils import (
     bytes_to_readable,

@@ -129,26 +129,33 @@ def _build_chat_tokenizer(
     return tokenizer
 
 
-def _load_model_weights(stage_module, distribution, device, model_config):
+def _load_model_weights(
+    stage_module: torch.nn.Module,
+    distribution: str,
+    device: torch.device,
+    model_config: ModelArgs,
+    chpt_from: str,
+):
     """Load the weights from the safetensor file(s) into the model stage.
     Model config is needed b/c we permute wq and wk weights based on attn heads.
-    """
 
-    weight_map, weight_path, key_map = get_hf_weight_map_and_path(distribution)
-
-    num_loaded_weights, num_missing_weights = load_safetensor_weights(
-        stage_module,
-        weight_map,
-        weight_path,
-        key_map,
-        device,
-        model_config=model_config,
-    )
-    logger.info(
-        f"Success - Loaded {num_loaded_weights} weights, {num_missing_weights} missing weights"
-    )
-    if num_missing_weights > 0:
-        raise ValueError(f"Missing {num_missing_weights} weights")
+    Args:
+        stage_module (torch.nn.Module): The model stage to load the weights into.
+        distribution (str): The distribution name, e.g. "meta-llama/Meta-Llama-3-8B-Instruct".
+        device (torch.device): The device to load the weights onto.
+        model_config (ModelArgs): The model config.
+        chpt_from (str): The checkpoint format to load the weights from, e.g. "torchchat" or "hf".
+    """
+    if chpt_from == "hf":
+        # This format stands for: index file + multiple binary files
+        load_weights_from_hf_format(stage_module, distribution, device, model_config)
+    elif chpt_from == "torchchat":
+        # This format stands for:
+        # single binary file, OR
+        # multiple binary files without index files.
+        load_weights_from_torchchat_format(stage_module, distribution, device, model_config)
+    else:
+        raise ValueError(f"Unknown checkpoint format: {chpt_from}")
 
 
 def _encode_strings(

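For reference, the `hf` branch's "index file + multiple binary files" is the Hugging Face sharded-checkpoint layout: an index JSON maps every parameter name to the shard file that stores it, so a pipeline stage only needs to read the shards holding its own parameters. Below is a minimal sketch of that pattern, assuming safetensors shards in a local `snapshot_dir` and a stage module already materialized on `device`; the helper name is hypothetical, and it omits the HF-to-torchchat key remapping and the wq/wk permutation that the real `checkpoint_utils` code performs:

    import json
    import os
    from collections import defaultdict

    import torch
    from safetensors.torch import load_file


    def load_hf_sharded_sketch(stage_module, snapshot_dir, device):
        # Hypothetical sketch, not the torchchat implementation.
        # The index maps parameter name -> shard file name.
        with open(os.path.join(snapshot_dir, "model.safetensors.index.json")) as f:
            weight_map = json.load(f)["weight_map"]

        # Group this stage's parameter names by shard so each file is read once.
        state_dict = stage_module.state_dict()
        names_per_shard = defaultdict(list)
        for name in state_dict:
            if name in weight_map:
                names_per_shard[weight_map[name]].append(name)

        with torch.no_grad():
            for shard_file, names in names_per_shard.items():
                shard = load_file(os.path.join(snapshot_dir, shard_file))
                for name in names:
                    state_dict[name].copy_(shard[name].to(device))

The torchchat branch, by contrast, covers checkpoints with no index at all: a single binary file, or multiple binaries where every file must be scanned.
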
@@ -306,7 +313,7 @@ def main(args):
     logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}")
 
     distribution, model_dtype = NAME_TO_DISTRIBUTION_AND_DTYPE[model_name]
-    logger.info(f"Using HF model weights from {distribution} and dtype {model_dtype}")
+    logger.info(f"Using model weights from {distribution} and dtype {model_dtype}")
 
     # Model-level config
     model_config = ModelArgs.from_name(distribution)

@@ -368,7 +375,7 @@ def main(args):
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
     with CUDATrackTime() as timer:
-        _load_model_weights(model, distribution, device=device, model_config=config)
+        _load_model_weights(model, distribution, device, config, args.chpt_from)
 
     logger.info(
         f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"

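The load is wrapped in `CUDATrackTime`, a CUDA-event-based timer. As a sketch of that timing pattern — matching only the interface the log line relies on (`get_time()` and `unit`), not torchchat's actual implementation:

    import torch


    class CUDATimerSketch:
        """Illustrative stand-in for CUDATrackTime (assumed interface)."""

        unit = "seconds"

        def __enter__(self):
            self._start = torch.cuda.Event(enable_timing=True)
            self._end = torch.cuda.Event(enable_timing=True)
            self._start.record()
            return self

        def __exit__(self, *exc_info):
            self._end.record()
            # Event timestamps are recorded asynchronously; sync before reading.
            torch.cuda.synchronize()
            self._elapsed_s = self._start.elapsed_time(self._end) / 1000.0  # ms -> s
            return False

        def get_time(self):
            return round(self._elapsed_s, 4)
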
@@ -602,6 +609,13 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         default=False,
         help="Whether to decode token into string in flight",
     )
+    parser.add_argument(
+        "--chpt-from",
+        type=str,
+        default="hf",  # TODO: change to torchchat once we support it well
+        help="Checkpoint format to load from",
+        choices=["hf", "torchchat"],
+    )
     args = parser.parse_args()
 
     main(args)

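With the new flag, the checkpoint format becomes a launch-time choice. A hypothetical invocation — the entry-point name, model alias, and launcher flags here are illustrative, not taken from this PR:

    # Default is the HF sharded format for now (see the TODO above)
    torchrun --nproc-per-node 4 dist_run.py llama3 --chpt-from hf

    # torchchat-native checkpoints: one binary file, or shards without an index
    torchrun --nproc-per-node 4 dist_run.py llama3 --chpt-from torchchat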