@@ -4,7 +4,6 @@
 from collections import defaultdict
 
 import torch
-from torch.distributed._shard.api import load_with_process_group
 
 from internlm.accelerator import get_accelerator
 from internlm.core.context import ParallelMode
@@ -15,13 +14,10 @@
 from internlm.utils.common import get_current_device
 from internlm.utils.lazy import LazyObject
 from internlm.utils.logger import get_logger
-from internlm.utils.parallel import is_using_hf, is_using_fsdp, is_using_isp
+from internlm.utils.parallel import is_using_fsdp, is_using_hf, is_using_isp
 from internlm.utils.storage_manager import get_fns, llm_load, llm_save
 
-from .utils import (
-    get_model_topology,
-    get_non_moe_state_dict,
-)
+from .utils import get_model_topology, get_non_moe_state_dict
 
 try:
     import torch.distributed.checkpoint as dcp
@@ -194,7 +190,7 @@ def load_model_checkpoint(folder, model):
     else:
         should_load_name = f"model_tp{tp_rank}_pp{pp_rank}.pt"
     fp = os.path.join(folder, should_load_name)
-    
+
     states = llm_load(fp, map_location=get_current_device())
     """
     # need convert the gate parameters to float32 (to fit deepspeed style mechanism), it may cause round-off in
@@ -366,7 +362,7 @@ def load_optimizer_checkpoint(folder, optim):
             max_pp = max(max_pp, int(pp[2:]))
         else:
             _, fsdp = os.path.splitext(fn)[0].split("_")
-            max_fsdp = max(max_fsdp, int(fsdp[4:])) 
+            max_fsdp = max(max_fsdp, int(fsdp[4:]))
 
     fsdp_size = gpc.get_world_size(ParallelMode.GLOBAL)
     zero_size = gpc.get_world_size(ParallelMode.ZERO1)
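
The line touched here changes only in trailing whitespace; it sits in a scan that infers the saved FSDP world size from filenames. A sketch under the assumption that FSDP optimizer shards are named `optimizer_fsdp{rank}.pt`, which is what the `split("_")` / `fsdp[4:]` parsing implies:

```python
import os


def max_fsdp_rank(fns: list[str]) -> int:
    # Hypothetical helper mirroring the scan in load_optimizer_checkpoint.
    max_fsdp = 0
    for fn in fns:
        stem = os.path.splitext(fn)[0]  # e.g. "optimizer_fsdp3"
        _, fsdp = stem.split("_")       # -> ("optimizer", "fsdp3")
        max_fsdp = max(max_fsdp, int(fsdp[4:]))  # strip the "fsdp" prefix
    return max_fsdp


assert max_fsdp_rank(["optimizer_fsdp0.pt", "optimizer_fsdp3.pt"]) == 3
```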
@@ -399,7 +395,7 @@ def load_optimizer_checkpoint(folder, optim):
     tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
     wp_rank = gpc.get_local_rank(ParallelMode.WEIGHT)
     pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
-    
+
     if isinstance(optim, (HybridZeroOptimizer, HybridZeroOptimizer_v2)):
         if is_using_isp():
             fp = f"optimizer_wp{wp_rank}_pp{pp_rank}_zo{zero_rank}.pt"
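
The ISP branch above addresses optimizer shards by weight-parallel rather than tensor-parallel rank. A minimal sketch of that selection; the non-ISP pattern is assumed by analogy with the model checkpoint naming and is not shown in this hunk:

```python
def optimizer_shard_name(use_isp: bool, tp_rank: int, wp_rank: int,
                         pp_rank: int, zero_rank: int) -> str:
    # Hypothetical helper; the real code reads ranks from gpc and also
    # branches on the optimizer type.
    if use_isp:
        # ISP shards: weight-parallel rank replaces tensor-parallel rank.
        return f"optimizer_wp{wp_rank}_pp{pp_rank}_zo{zero_rank}.pt"
    # Assumed non-ISP pattern (by analogy; not confirmed by this diff).
    return f"optimizer_tp{tp_rank}_pp{pp_rank}_zo{zero_rank}.pt"
```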