
Commit cfd1fdf: fix pylint
1 parent: 2db49c4

File tree: 9 files changed (+28, -28 lines)

internlm/checkpoint/components.py (+5, -9)

@@ -4,7 +4,6 @@
 from collections import defaultdict
 
 import torch
-from torch.distributed._shard.api import load_with_process_group
 
 from internlm.accelerator import get_accelerator
 from internlm.core.context import ParallelMode
@@ -15,13 +14,10 @@
 from internlm.utils.common import get_current_device
 from internlm.utils.lazy import LazyObject
 from internlm.utils.logger import get_logger
-from internlm.utils.parallel import is_using_hf, is_using_fsdp, is_using_isp
+from internlm.utils.parallel import is_using_fsdp, is_using_hf, is_using_isp
 from internlm.utils.storage_manager import get_fns, llm_load, llm_save
 
-from .utils import (
-    get_model_topology,
-    get_non_moe_state_dict,
-)
+from .utils import get_model_topology, get_non_moe_state_dict
 
 try:
     import torch.distributed.checkpoint as dcp
@@ -194,7 +190,7 @@ def load_model_checkpoint(folder, model):
     else:
         should_load_name = f"model_tp{tp_rank}_pp{pp_rank}.pt"
     fp = os.path.join(folder, should_load_name)
-
+
     states = llm_load(fp, map_location=get_current_device())
     """
     # need convert the gate parameters to float32 (to fit deepspeed style mechanism), it may cause round-off in
@@ -366,7 +362,7 @@ def load_optimizer_checkpoint(folder, optim):
             max_pp = max(max_pp, int(pp[2:]))
         else:
             _, fsdp = os.path.splitext(fn)[0].split("_")
-            max_fsdp = max(max_fsdp, int(fsdp[4:]))
+            max_fsdp = max(max_fsdp, int(fsdp[4:]))
 
     fsdp_size = gpc.get_world_size(ParallelMode.GLOBAL)
     zero_size = gpc.get_world_size(ParallelMode.ZERO1)
@@ -399,7 +395,7 @@ def load_optimizer_checkpoint(folder, optim):
     tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
     wp_rank = gpc.get_local_rank(ParallelMode.WEIGHT)
     pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
-
+
     if isinstance(optim, (HybridZeroOptimizer, HybridZeroOptimizer_v2)):
         if is_using_isp():
             fp = f"optimizer_wp{wp_rank}_pp{pp_rank}_zo{zero_rank}.pt"

internlm/checkpoint/utils.py (+2, -2)

@@ -2,10 +2,10 @@
 # -*- encoding: utf-8 -*-
 
 import itertools
+
 import numpy as np
 import torch
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.fsdp import StateDictType
 
 from internlm.core.context import global_context as gpc
 from internlm.core.parallel.shard import split_data_for_sequence_parallel
@@ -116,4 +116,4 @@ def init_fsdp_v1(model: FSDP, device: torch.device) -> FSDP:
 
     # run a forward pass with dummy_input to initialize FSDP
     _ = model(**dummy_input)
-    return model
\ No newline at end of file
+    return model
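The second hunk only restores the final newline, but the surrounding code shows the idiom init_fsdp_v1 relies on: run one forward pass with dummy input so the wrapped model can materialize its state. A single-process sketch of the same idiom, using torch.nn.LazyLinear as a stand-in for FSDP so it runs without any distributed setup:

import torch
from torch import nn

# Lazily initialized modules have no concrete parameters until the first
# forward pass, analogous to the dummy-input forward that initializes FSDP.
model = nn.Sequential(nn.LazyLinear(16), nn.ReLU(), nn.LazyLinear(4))
dummy_input = {"x": torch.randn(2, 8)}

_ = model(dummy_input["x"])  # one forward pass materializes the parameters
print(sum(p.numel() for p in model.parameters()))  # now well-defined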

internlm/core/context/process_group_initializer.py (-1)

@@ -1107,7 +1107,6 @@ def init_dist_group(self, use_cpu: bool = False):
         return groups
 
 
-
 class Initializer_Weight(ProcessGroupInitializer):
     """A ProcessGroupInitializer for model weight parallelism.
 
internlm/core/trainer_builder.py (+1, -1)

@@ -98,7 +98,7 @@ def __init__(
         self.current_time = self._setup_time_and_logging()
         # load config_lines
         config_lines = self._read_config(kwargs["config"])
-
+
         # inject model for amp, parallel setting, parameter syncing and others
         model, isp_communicator = inject_model(model)
internlm/model/builder.py (+8, -5)

@@ -1,19 +1,21 @@
 from typing import List, Union
 
-from torch import nn
 import torch
+from torch import nn
 
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.parallel.shard import pipeline_parallel_sharding_wrapper
 from internlm.model.base_model import BaseModel
-from internlm.model.modules.linear import ParallelLinearWithCommExt, ScaleColumnParallelLinear
+from internlm.model.modules.linear import (
+    ParallelLinearWithCommExt,
+    ScaleColumnParallelLinear,
+)
 from internlm.model.registry import model_initializer
-from internlm.utils.parallel import is_using_hf
 from internlm.utils.common import get_current_device
 from internlm.utils.lazy import LazyObject
 from internlm.utils.logger import get_logger
-from internlm.utils.parallel import is_using_fsdp, is_using_isp
+from internlm.utils.parallel import is_using_fsdp, is_using_hf, is_using_isp
 
 logger = get_logger(__file__)
 
@@ -58,6 +60,7 @@ def create_model_builtin(model_type) -> Union[nn.Module, List[nn.Module]]:
 
     return model
 
+
 def create_model_hf(hf: dict) -> nn.Module:
     cfg = LazyObject(hf.cfg, hf.cfg_cls)
     cfg = cfg.build()
@@ -123,4 +126,4 @@ def traverse(module):
     else:
         traverse(model)
 
-    return model
\ No newline at end of file
+    return model

internlm/solver/activation_checkpoint.py (+1, -3)

@@ -4,11 +4,9 @@
 import weakref
 
 import torch
-
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper as ptd_checkpoint_wrapper,
 )
-
 from torch.utils.checkpoint import check_backward_validity, detach_variable
 
 from internlm.accelerator import get_accelerator
@@ -287,4 +285,4 @@ def apply_ac_to_transformer_block(module: torch.nn.Module, checkpoint):
     if ptd_checkpoint_wrapper._count % ac_freq == 0:
         return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
     else:
-        return module
\ No newline at end of file
+        return module
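The tail of apply_ac_to_transformer_block shows the selective checkpointing pattern this file implements: a counter stored on the wrapper decides which transformer blocks get activation checkpointing. A hedged sketch of that pattern; apply_ac is an illustrative name, while the counter attribute, ac_freq, and the wrapper call follow the hunk above:

import torch
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper as ptd_checkpoint_wrapper,
)

# Wrap every ac_freq-th block with activation checkpointing;
# pass the other blocks through unchanged.
def apply_ac(module: torch.nn.Module, ac_freq: int) -> torch.nn.Module:
    ptd_checkpoint_wrapper._count = getattr(ptd_checkpoint_wrapper, "_count", 0) + 1
    if ptd_checkpoint_wrapper._count % ac_freq == 0:
        return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
    return module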

internlm/solver/optimizer/fsdp_optimizer.py (+1, -1)

@@ -7,7 +7,7 @@
 import torch.distributed as dist
 from torch.optim import Optimizer
 
-from internlm.accelerator import AcceleratorType, get_accelerator
+from internlm.accelerator import get_accelerator
 from internlm.core.context import Config, ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.solver.optimizer.base_optimizer import BaseOptimizer

internlm/train/pipeline.py (+9, -6)

@@ -10,13 +10,13 @@
 
 import torch
 from torch import nn
-from torch.utils.data import DataLoader
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.fully_sharded_data_parallel import (
     BackwardPrefetch,
     ShardingStrategy,
 )
 from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
+from torch.utils.data import DataLoader
 
 from internlm.accelerator import AcceleratorType, get_accelerator
 from internlm.checkpoint.utils import init_fsdp_v1
@@ -96,8 +96,8 @@
     is_replica_zero_parallel_parameter,
     is_tensor_expert_data_parallel_parameter,
     is_tensor_zero_parallel_parameter,
-    is_using_hf,
     is_using_fsdp,
+    is_using_hf,
     is_using_isp,
     is_weight_expert_data_parallel_parameter,
     is_weight_zero_parallel_parameter,
@@ -256,7 +256,8 @@ def _check_module(name, module):
     # special case for pure dp mode
     if (
         isinstance(gpc.config.parallel["tensor"], dict)
-        and gpc.config.parallel["tensor"].get("mode", TensorParallelMode.mtp.name) == TensorParallelMode.mtp.name
+        and gpc.config.parallel["tensor"].get("mode", TensorParallelMode.mtp.name)
+        == TensorParallelMode.mtp.name
        and gpc.get_world_size(ParallelMode.DATA) == gpc.get_world_size(ParallelMode.GLOBAL)
     ):
         _check_module_func = _check_module_pure_dp
@@ -278,7 +279,9 @@ def _check_module(name, module):
 
 
 @llm_timeout(func_name="initialize_model_and_parallel_communicator")
-def initialize_model_and_parallel_communicator(pre_process_func: Optional[Callable] = None, post_process_func: Optional[Callable] = None):
+def initialize_model_and_parallel_communicator(
+    pre_process_func: Optional[Callable] = None, post_process_func: Optional[Callable] = None
+):
     """
     Initialize model with Automatic Mixed Precision.
     Returns:
@@ -362,10 +365,10 @@ def inject_model(model):
     # state in the same dp group are all the same.
     random_mode = ParallelMode.WEIGHT_DATA if is_using_isp() else ParallelMode.DATA
     set_mode(random_mode)
-
+
     # initialize isp communicator
     isp_communicator = initialize_parallel_communicator(model)
-
+
     model = wrap_FSDP_model(model)
 
     # set is_injected flag
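The reflowed condition in the -256,7 hunk is the "pure data parallel" check: tensor parallelism is at its default mtp mode and the data-parallel group spans the whole world. A self-contained sketch of the same predicate, with literal stand-ins for the gpc config and world-size lookups:

# Hypothetical stand-ins for gpc.config.parallel["tensor"] and
# gpc.get_world_size(...); "mtp" mirrors TensorParallelMode.mtp.name.
tensor_cfg = {"size": 1, "mode": "mtp"}
data_world_size = global_world_size = 8

is_pure_dp = (
    isinstance(tensor_cfg, dict)
    and tensor_cfg.get("mode", "mtp") == "mtp"
    and data_world_size == global_world_size
)
print(is_pure_dp)  # True: every rank is a plain data-parallel replica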

internlm/utils/parallel.py (+1)

@@ -27,6 +27,7 @@ def is_using_fsdp():
         and gpc.config.parallel["fsdp"].get("enable", False)
     )
 
+
 def is_using_sequence_parallel():
     return (
         isinstance(gpc.config.parallel["tensor"], dict)
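Both helpers in this hunk read nested sections of gpc.config.parallel with safe defaults, so a missing section or flag counts as disabled. A standalone sketch of that lookup pattern; the dict literal is an assumed example, not a real training config:

# Assumed example config; real values come from gpc.config.parallel.
parallel_cfg = {
    "fsdp": {"enable": True},
    "tensor": {"size": 1, "mode": "mtp"},
}

def is_using_fsdp(cfg: dict) -> bool:
    # a missing "fsdp" section or "enable" key is treated as disabled
    fsdp = cfg.get("fsdp")
    return isinstance(fsdp, dict) and fsdp.get("enable", False)

print(is_using_fsdp(parallel_cfg))  # True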
