diff --git a/ci_scripts/train/generate_config.py b/ci_scripts/train/generate_config.py
index bfb68b74d..d1a34940a 100644
--- a/ci_scripts/train/generate_config.py
+++ b/ci_scripts/train/generate_config.py
@@ -5,7 +5,7 @@
 import os
 
 from ci_scripts.common import com_func
-from internlm.core.context.config import Config
+from internlm.core.context.parallel_context import Config
 
 
 def generate_new_config(config_py_file, test_config_json, case_name):
diff --git a/internlm/core/context/config.py b/internlm/core/context/config.py
deleted file mode 100644
index 0bedbb420..000000000
--- a/internlm/core/context/config.py
+++ /dev/null
@@ -1,95 +0,0 @@
-import inspect
-import sys
-from importlib.machinery import SourceFileLoader
-from pathlib import Path
-
-
-class Config(dict):
-    """This is a wrapper class for dict objects so that values of which can be
-    accessed as attributes.
-
-    Args:
-        config (dict): The dict object to be wrapped.
-    """
-
-    def __init__(self, config: dict = None):  # pylint: disable=W0231
-        if config is not None:
-            for k, v in config.items():
-                self._add_item(k, v)
-
-    def __missing__(self, key):
-        raise KeyError(key)
-
-    def __getattr__(self, key):
-        try:
-            value = super().__getitem__(key)
-            return value
-        except KeyError:
-            raise AttributeError(key)
-
-    def __setattr__(self, key, value):
-        super().__setitem__(key, value)
-
-    def _add_item(self, key, value):
-        if isinstance(value, dict):
-            self.__setattr__(key, Config(value))
-        else:
-            self.__setattr__(key, value)
-
-    def update(self, config):
-        assert isinstance(config, (Config, dict)), "can only update dictionary or Config objects."
-        for k, v in config.items():
-            self._add_item(k, v)
-        return self
-
-    @staticmethod
-    def from_file(filename: str):
-        """Reads a python file and constructs a corresponding :class:`Config` object.
-
-        Args:
-            filename (str): Name of the file to construct the return object.
-
-        Returns:
-            :class:`Config`: A :class:`Config` object constructed with information in the file.
-
-        Raises:
-            AssertionError: Raises an AssertionError if the file does not exist, or the file is not .py file
-        """
-
-        # check config path
-        if isinstance(filename, str):
-            filepath = Path(filename).absolute()
-        elif isinstance(filename, Path):
-            filepath = filename.absolute()
-
-        assert filepath.exists(), f"{filename} is not found, please check your configuration path"
-
-        # check extension
-        extension = filepath.suffix
-        assert extension == ".py", "only .py files are supported"
-
-        # import the config as module
-        remove_path = False
-        if filepath.parent not in sys.path:
-            sys.path.insert(0, (filepath))
-            remove_path = True
-
-        module_name = filepath.stem
-        source_file = SourceFileLoader(fullname=str(module_name), path=str(filepath))
-        module = source_file.load_module()  # pylint: disable=W4902,E1120,W1505
-
-        # load into config
-        config = Config()
-
-        for k, v in module.__dict__.items():
-            if k.startswith("__") or inspect.ismodule(v) or inspect.isclass(v):
-                continue
-            else:
-                config._add_item(k, v)
-
-        # remove module
-        del sys.modules[module_name]
-        if remove_path:
-            sys.path.pop(0)
-
-        return config
diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py
index f4868e424..1bd7ede38 100644
--- a/internlm/core/context/parallel_context.py
+++ b/internlm/core/context/parallel_context.py
@@ -3,8 +3,12 @@
 
 # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context
 
+from importlib.machinery import SourceFileLoader
+import inspect
+from pathlib import Path
 import random
 import socket
+import sys
 from typing import Union
 
 import numpy as np
@@ -12,7 +16,6 @@
 import torch.distributed as dist
 
 from internlm.accelerator import get_accelerator
-from internlm.core.context.config import Config
 from internlm.utils.logger import get_logger
 from internlm.utils.common import SingletonMeta
 from internlm.utils.timeout import LLM_NCCL_TIMEOUT
@@ -43,6 +46,97 @@
 internlm_accelerator = get_accelerator()
 
 
+class Config(dict):
+    """This is a wrapper class for dict objects so that values of which can be
+    accessed as attributes.
+
+    Args:
+        config (dict): The dict object to be wrapped.
+    """
+
+    def __init__(self, config: dict = None):  # pylint: disable=W0231
+        if config is not None:
+            for k, v in config.items():
+                self._add_item(k, v)
+
+    def __missing__(self, key):
+        raise KeyError(key)
+
+    def __getattr__(self, key):
+        try:
+            value = super().__getitem__(key)
+            return value
+        except KeyError:
+            raise AttributeError(key)
+
+    def __setattr__(self, key, value):
+        super().__setitem__(key, value)
+
+    def _add_item(self, key, value):
+        if isinstance(value, dict):
+            self.__setattr__(key, Config(value))
+        else:
+            self.__setattr__(key, value)
+
+    def update(self, config):
+        assert isinstance(config, (Config, dict)), "can only update dictionary or Config objects."
+        for k, v in config.items():
+            self._add_item(k, v)
+        return self
+
+    @staticmethod
+    def from_file(filename: str):
+        """Reads a python file and constructs a corresponding :class:`Config` object.
+
+        Args:
+            filename (str): Name of the file to construct the return object.
+
+        Returns:
+            :class:`Config`: A :class:`Config` object constructed with information in the file.
+
+        Raises:
+            AssertionError: Raises an AssertionError if the file does not exist, or the file is not a .py file
+        """
+
+        # check config path
+        if isinstance(filename, str):
+            filepath = Path(filename).absolute()
+        elif isinstance(filename, Path):
+            filepath = filename.absolute()
+
+        assert filepath.exists(), f"{filename} is not found, please check your configuration path"
+
+        # check extension
+        extension = filepath.suffix
+        assert extension == ".py", "only .py files are supported"
+
+        # import the config as module; add its parent directory to sys.path so
+        # the config file can import modules that live next to it
+        remove_path = False
+        if filepath.parent not in sys.path:
+            sys.path.insert(0, str(filepath.parent))
+            remove_path = True
+
+        module_name = filepath.stem
+        source_file = SourceFileLoader(fullname=str(module_name), path=str(filepath))
+        module = source_file.load_module()  # pylint: disable=W4902,E1120,W1505
+
+        # load into config
+        config = Config()
+
+        for k, v in module.__dict__.items():
+            if k.startswith("__") or inspect.ismodule(v) or inspect.isclass(v):
+                continue
+            else:
+                config._add_item(k, v)
+
+        # remove module
+        del sys.modules[module_name]
+        if remove_path:
+            sys.path.pop(0)
+
+        return config
+
+
 class ParallelContext(metaclass=SingletonMeta):
     """This class provides interface functions for users to get the parallel context,
     such as the global rank, the local rank, the world size, etc. of each device.
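Note that parallel_context.py itself no longer re-imports Config: the class now lives in this module, and a module cannot import a name from itself while it is still initializing. For all other callers only the import path changes; behavior is unchanged. A minimal usage sketch of the wrapper under its new path (the config keys are invented for illustration, assuming internlm is importable):

    from internlm.core.context.parallel_context import Config

    cfg = Config({"model": {"num_layers": 32, "hidden_size": 4096}, "seed": 1024})

    # Nested dicts are wrapped recursively, so values read as attributes or keys.
    assert cfg.model.num_layers == 32
    assert cfg["model"]["hidden_size"] == 4096

    # update() accepts a plain dict or another Config and merges top-level keys.
    cfg.update({"seed": 42})
    assert cfg.seed == 42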
diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index a73676b5c..4247d99f0 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -9,7 +9,7 @@
 import torch
 
 from internlm.accelerator import AcceleratorType, get_accelerator
-from internlm.core.context.config import Config
+from internlm.core.context.parallel_context import Config
 from internlm.core.context.parallel_context import global_context as gpc
 from internlm.core.context.process_group_initializer import ParallelMode
 from internlm.utils.common import get_master_node
diff --git a/internlm/solver/optimizer/fsdp_optimizer.py b/internlm/solver/optimizer/fsdp_optimizer.py
index d434f6dbc..4bdb7aef0 100644
--- a/internlm/solver/optimizer/fsdp_optimizer.py
+++ b/internlm/solver/optimizer/fsdp_optimizer.py
@@ -8,7 +8,7 @@
 from torch.optim import Optimizer
 
 from internlm.accelerator import get_accelerator
-from internlm.core.context.config import Config
+from internlm.core.context.parallel_context import Config
 from internlm.core.context.parallel_context import ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
 from internlm.solver.optimizer.base_optimizer import BaseOptimizer
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 6c88fd49f..ba8d947a3 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -11,7 +11,7 @@
 from torch.optim import Optimizer
 
 from internlm.accelerator import AcceleratorType, get_accelerator
-from internlm.core.context.config import Config
+from internlm.core.context.parallel_context import Config
 from internlm.core.context.parallel_context import (
     IS_REPLICA_EXPERT_DATA_PARALLEL,
     IS_REPLICA_ZERO_PARALLEL,
diff --git a/internlm/solver/optimizer/hybrid_zero_optim_v2.py b/internlm/solver/optimizer/hybrid_zero_optim_v2.py
index 614bea647..42c3ab177 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim_v2.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim_v2.py
@@ -7,7 +7,7 @@
 import torch.distributed as dist
 from torch.optim import Optimizer
 
-from internlm.core.context.config import Config
+from internlm.core.context.parallel_context import Config
 from internlm.core.context.parallel_context import (
     IS_REPLICA_ZERO_PARALLEL,
     IS_TENSOR_EXPERT_DATA_PARALLEL,
diff --git a/internlm/utils/common.py b/internlm/utils/common.py
index 1381ca77b..15c984509 100644
--- a/internlm/utils/common.py
+++ b/internlm/utils/common.py
@@ -10,6 +10,7 @@
 from collections import ChainMap
 from contextlib import contextmanager
 from datetime import datetime
+import threading
 from typing import Union
 
 import numpy as np
diff --git a/tests/common_fixture.py b/tests/common_fixture.py
index fa9e0acbe..3362099f5 100644
--- a/tests/common_fixture.py
+++ b/tests/common_fixture.py
@@ -8,7 +8,7 @@
 from internlm.initialize.launch import launch_from_torch
 
 from internlm.accelerator import get_accelerator
 from internlm.core.context.parallel_context import global_context as gpc
-from internlm.core.context.config import Config
+from internlm.core.context.parallel_context import Config
 from internlm.data.utils import unpack_type_ids
 from internlm.initialize.launch import args_sanity_check
diff --git a/tests/test_core/test_pipeline.py b/tests/test_core/test_pipeline.py
index 416591ce3..28cb3959e 100644
--- a/tests/test_core/test_pipeline.py
+++ b/tests/test_core/test_pipeline.py
@@ -5,7 +5,7 @@
 
 from internlm.core.context import ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
-from internlm.core.context.config import Config
+from internlm.core.context.parallel_context import Config
 from internlm.solver.optimizer.compatible_adamw import new_compatible_adamw
 from internlm.utils.common import get_current_device
 from tests.test_core.utils import (
diff --git a/tests/test_data/test_batch_sampler.py b/tests/test_data/test_batch_sampler.py
index 531188791..e1dea8fa0 100644
--- a/tests/test_data/test_batch_sampler.py
+++ b/tests/test_data/test_batch_sampler.py
@@ -8,7 +8,7 @@
 from internlm.core.context.parallel_context import global_context as gpc
 
 # from internlm.core.context import ParallelMode
-from internlm.core.context.config import Config
+from internlm.core.context.parallel_context import Config
 from internlm.core.trainer import TrainState
 from internlm.data import (
     build_train_loader_with_data_type,
diff --git a/tests/test_model/test_model_internlm.py b/tests/test_model/test_model_internlm.py
index 3011e7121..310d2f1d3 100644
--- a/tests/test_model/test_model_internlm.py
+++ b/tests/test_model/test_model_internlm.py
@@ -9,7 +9,7 @@
 from internlm.initialize.launch import launch_from_torch
 from internlm.accelerator import get_accelerator
 from internlm.core.context import ParallelMode
-from internlm.core.context.config import Config
+from internlm.core.context.parallel_context import Config
 from internlm.core.context.parallel_context import global_context as gpc
 from internlm.core.parallel.comm.tensor import (
     HeadTensorParallelCommunicator,
diff --git a/tests/test_model/test_npu_ops/test_flash_attention.py b/tests/test_model/test_npu_ops/test_flash_attention.py
index bf40da712..81166bf9f 100644
--- a/tests/test_model/test_npu_ops/test_flash_attention.py
+++ b/tests/test_model/test_npu_ops/test_flash_attention.py
@@ -12,7 +12,7 @@
 from torch import nn
 
 from internlm.accelerator import AcceleratorType, get_accelerator
-from internlm.core.context.config import Config
+from internlm.core.context.parallel_context import Config
 from internlm.core.context.parallel_context import global_context as gpc
 from internlm.model.ops.attention import SelfAttention
 from internlm.model.ops.utils import pack_output_after_attn, unpack_qkv_before_attn
diff --git a/tests/test_solver/test_optimizer.py b/tests/test_solver/test_optimizer.py
index 43a2ba367..11d0ebc7e 100644
--- a/tests/test_solver/test_optimizer.py
+++ b/tests/test_solver/test_optimizer.py
@@ -11,7 +11,7 @@
 from internlm.initialize.launch import launch_from_torch
 
 from internlm.accelerator import get_accelerator
-from internlm.core.context.config import Config, ParallelMode
+from internlm.core.context.parallel_context import Config, ParallelMode
 from internlm.core.parallel.comm.zero import ParamAsyncBcastHandler
 from internlm.solver.optimizer import HybridZeroOptimizer
 from internlm.utils.common import get_current_device
diff --git a/tests/test_training/test_forward_output_no_fa.py b/tests/test_training/test_forward_output_no_fa.py
index bb6c4642b..715a362a9 100644
--- a/tests/test_training/test_forward_output_no_fa.py
+++ b/tests/test_training/test_forward_output_no_fa.py
@@ -12,7 +12,7 @@
 from internlm.accelerator import get_accelerator
 from internlm.core.context import ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
-from internlm.core.context.config import Config
+from internlm.core.context.parallel_context import Config
 from internlm.core.trainer import Trainer
 from internlm.data import build_train_loader_with_data_type
 from internlm.initialize.launch import args_sanity_check
diff --git a/tests/test_training/test_loss.py b/tests/test_training/test_loss.py
index cc2f543eb..1f61544fd 100644
--- a/tests/test_training/test_loss.py
+++ b/tests/test_training/test_loss.py
@@ -8,7 +8,7 @@
 from internlm.initialize.initialize_trainer import initialize_trainer
 from internlm.accelerator import AcceleratorType, get_accelerator
 from internlm.checkpoint import CheckpointManager
-from internlm.core.context.config import Config
+from internlm.core.context.parallel_context import Config
 from internlm.core.context.parallel_context import ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
 from internlm.core.trainer import Trainer, TrainState
diff --git a/tests/test_training/test_swap_nb_loss_and_gradnorm.py b/tests/test_training/test_swap_nb_loss_and_gradnorm.py
index 0d0783e16..5b6eeca30 100644
--- a/tests/test_training/test_swap_nb_loss_and_gradnorm.py
+++ b/tests/test_training/test_swap_nb_loss_and_gradnorm.py
@@ -14,7 +14,7 @@
 from internlm.accelerator import get_accelerator
 from internlm.core.context import ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
-from internlm.core.context.config import Config
+from internlm.core.context.parallel_context import Config
 from internlm.core.trainer import Trainer
 from internlm.data import (
     build_train_loader_with_data_type,
diff --git a/tests/test_utils/common_fixture.py b/tests/test_utils/common_fixture.py
index daf8162ab..8c51a161a 100644
--- a/tests/test_utils/common_fixture.py
+++ b/tests/test_utils/common_fixture.py
@@ -6,7 +6,7 @@
 import torch
 
 from internlm.core.context.parallel_context import global_context as gpc
-from internlm.core.context.config import Config
+from internlm.core.context.parallel_context import Config
 from internlm.core.naive_amp import NaiveAMPModel
 from internlm.model.builder import create_model
 from internlm.model.registry import register_model_initializer
diff --git a/tests/test_utils/test_model_checkpoint.py b/tests/test_utils/test_model_checkpoint.py
index 010da60b4..5cfdf10c7 100644
--- a/tests/test_utils/test_model_checkpoint.py
+++ b/tests/test_utils/test_model_checkpoint.py
@@ -10,7 +10,7 @@
 import torch.distributed as dist
 
 from internlm.checkpoint import CheckpointManager
-from internlm.core.context.config import Config
+from internlm.core.context.parallel_context import Config
 from internlm.core.trainer import TrainState
 from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer
 from internlm.utils.common import SingletonMeta
diff --git a/tests/test_utils/test_storage_manager.py b/tests/test_utils/test_storage_manager.py
index 99eb63388..9454a8369 100644
--- a/tests/test_utils/test_storage_manager.py
+++ b/tests/test_utils/test_storage_manager.py
@@ -3,7 +3,7 @@
 import pytest
 import torch
 
-from internlm.core.context.config import Config
+from internlm.core.context.parallel_context import Config
 from internlm.initialize.launch import get_config_value
 from tests.test_utils.common_fixture import (  # noqa # pylint: disable=unused-import
     ALI_SAVE_PATH,
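As a quick sanity check of the relocated loader, a hedged round-trip sketch of Config.from_file() under the new module path; the file name and config keys below are hypothetical, and the snippet assumes internlm is importable:

    import tempfile
    from pathlib import Path

    from internlm.core.context.parallel_context import Config

    with tempfile.TemporaryDirectory() as tmp:
        cfg_file = Path(tmp) / "demo_config.py"
        cfg_file.write_text("JOB_NAME = 'demo'\nparallel = dict(zero1=8, tensor=1)\n")

        # from_file() executes the .py file as a module and collects its
        # top-level variables, skipping dunders, imported modules, and classes.
        cfg = Config.from_file(str(cfg_file))
        assert cfg.JOB_NAME == "demo"
        assert cfg.parallel.zero1 == 8  # nested dict comes back as a Config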