Commit bdb44d6

revert back changes

zigzagcai committed Feb 14, 2025
1 parent 8ad2505 commit bdb44d6
Showing 20 changed files with 114 additions and 113 deletions.
2 changes: 1 addition & 1 deletion ci_scripts/train/generate_config.py
@@ -5,7 +5,7 @@
 import os
 
 from ci_scripts.common import com_func
-from internlm.core.context.config import Config
+from internlm.core.context.parallel_context import Config
 
 
 def generate_new_config(config_py_file, test_config_json, case_name):
95 changes: 0 additions & 95 deletions internlm/core/context/config.py

This file was deleted.

97 changes: 96 additions & 1 deletion internlm/core/context/parallel_context.py
@@ -3,16 +3,20 @@
 
 # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context
 
+from importlib.machinery import SourceFileLoader
+import inspect
+from pathlib import Path
 import random
 import socket
+import sys
 from typing import Union
 
 import numpy as np
 import torch
 import torch.distributed as dist
 
 from internlm.accelerator import get_accelerator
-from internlm.core.context.config import Config
+from internlm.core.context.parallel_context import Config
 from internlm.utils.logger import get_logger
 from internlm.utils.common import SingletonMeta
 from internlm.utils.timeout import LLM_NCCL_TIMEOUT
@@ -43,6 +47,97 @@
 internlm_accelerator = get_accelerator()
 
 
+class Config(dict):
+    """This is a wrapper class for dict objects so that values of which can be
+    accessed as attributes.
+
+    Args:
+        config (dict): The dict object to be wrapped.
+    """
+
+    def __init__(self, config: dict = None):  # pylint: disable=W0231
+        if config is not None:
+            for k, v in config.items():
+                self._add_item(k, v)
+
+    def __missing__(self, key):
+        raise KeyError(key)
+
+    def __getattr__(self, key):
+        try:
+            value = super().__getitem__(key)
+            return value
+        except KeyError:
+            raise AttributeError(key)
+
+    def __setattr__(self, key, value):
+        super().__setitem__(key, value)
+
+    def _add_item(self, key, value):
+        if isinstance(value, dict):
+            self.__setattr__(key, Config(value))
+        else:
+            self.__setattr__(key, value)
+
+    def update(self, config):
+        assert isinstance(config, (Config, dict)), "can only update dictionary or Config objects."
+        for k, v in config.items():
+            self._add_item(k, v)
+        return self
+
+    @staticmethod
+    def from_file(filename: str):
+        """Reads a python file and constructs a corresponding :class:`Config` object.
+
+        Args:
+            filename (str): Name of the file to construct the return object.
+
+        Returns:
+            :class:`Config`: A :class:`Config` object constructed with information in the file.
+
+        Raises:
+            AssertionError: Raises an AssertionError if the file does not exist, or the file is not .py file
+        """
+
+        # check config path
+        if isinstance(filename, str):
+            filepath = Path(filename).absolute()
+        elif isinstance(filename, Path):
+            filepath = filename.absolute()
+
+        assert filepath.exists(), f"{filename} is not found, please check your configuration path"
+
+        # check extension
+        extension = filepath.suffix
+        assert extension == ".py", "only .py files are supported"
+
+        # import the config as module
+        remove_path = False
+        if filepath.parent not in sys.path:
+            sys.path.insert(0, (filepath))
+            remove_path = True
+
+        module_name = filepath.stem
+        source_file = SourceFileLoader(fullname=str(module_name), path=str(filepath))
+        module = source_file.load_module()  # pylint: disable=W4902,E1120,W1505
+
+        # load into config
+        config = Config()
+
+        for k, v in module.__dict__.items():
+            if k.startswith("__") or inspect.ismodule(v) or inspect.isclass(v):
+                continue
+            else:
+                config._add_item(k, v)
+
+        # remove module
+        del sys.modules[module_name]
+        if remove_path:
+            sys.path.pop(0)
+
+        return config
+
+
 class ParallelContext(metaclass=SingletonMeta):
     """This class provides interface functions for users to get the parallel context,
     such as the global rank, the local rank, the world size, etc. of each device.
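The hunk above moves Config back into parallel_context.py: a thin dict subclass whose values are reachable as attributes, with nested dicts wrapped recursively by _add_item. A minimal sketch of that behavior, assuming the reverted module imports cleanly (the config keys below are invented for illustration):

from internlm.core.context.parallel_context import Config

# Nested dicts are wrapped into Config recursively, so values can be
# read either as mapping keys or as attributes.
cfg = Config({"model": {"num_layers": 32, "hidden_size": 4096}})
assert cfg.model.num_layers == 32
assert cfg["model"]["hidden_size"] == 4096

# update() merges a dict (or another Config) in place and returns self.
cfg.update({"data": {"seq_len": 2048}})
assert cfg.data.seq_len == 2048

# A missing attribute raises AttributeError (via __getattr__);
# a missing key raises KeyError (via __missing__).
try:
    _ = cfg.optimizer
except AttributeError:
    pass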
2 changes: 1 addition & 1 deletion internlm/initialize/launch.py
@@ -9,7 +9,7 @@
 import torch
 
 from internlm.accelerator import AcceleratorType, get_accelerator
-from internlm.core.context.config import Config
+from internlm.core.context.parallel_context import Config
 from internlm.core.context.parallel_context import global_context as gpc
 from internlm.core.context.process_group_initializer import ParallelMode
 from internlm.utils.common import get_master_node
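launch.py is one of the call sites that loads training configs through this import. A hedged sketch of the Config.from_file round trip it supports, assuming the internlm package is importable (the file name and keys below are hypothetical):

import tempfile
from pathlib import Path

from internlm.core.context.parallel_context import Config

# from_file executes the .py file and keeps module-level names that are
# not dunders, imported modules, or classes.
with tempfile.TemporaryDirectory() as tmp:
    cfg_path = Path(tmp) / "demo_config.py"
    cfg_path.write_text("learning_rate = 1e-4\nparallel = dict(zero1=8, tensor=1)\n")

    cfg = Config.from_file(str(cfg_path))
    assert cfg.learning_rate == 1e-4
    assert cfg.parallel.zero1 == 8  # nested dict became a Config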
2 changes: 1 addition & 1 deletion internlm/solver/optimizer/fsdp_optimizer.py
@@ -8,7 +8,7 @@
 from torch.optim import Optimizer
 
 from internlm.accelerator import get_accelerator
-from internlm.core.context.config import Config
+from internlm.core.context.parallel_context import Config
 from internlm.core.context.parallel_context import ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
 from internlm.solver.optimizer.base_optimizer import BaseOptimizer
2 changes: 1 addition & 1 deletion internlm/solver/optimizer/hybrid_zero_optim.py
@@ -11,7 +11,7 @@
 from torch.optim import Optimizer
 
 from internlm.accelerator import AcceleratorType, get_accelerator
-from internlm.core.context.config import Config
+from internlm.core.context.parallel_context import Config
 from internlm.core.context.parallel_context import (
     IS_REPLICA_EXPERT_DATA_PARALLEL,
     IS_REPLICA_ZERO_PARALLEL,
2 changes: 1 addition & 1 deletion internlm/solver/optimizer/hybrid_zero_optim_v2.py
@@ -7,7 +7,7 @@
 import torch.distributed as dist
 from torch.optim import Optimizer
 
-from internlm.core.context.config import Config
+from internlm.core.context.parallel_context import Config
 from internlm.core.context.parallel_context import (
     IS_REPLICA_ZERO_PARALLEL,
     IS_TENSOR_EXPERT_DATA_PARALLEL,
1 change: 1 addition & 0 deletions internlm/utils/common.py
@@ -10,6 +10,7 @@
 from collections import ChainMap
 from contextlib import contextmanager
 from datetime import datetime
+import threading
 from typing import Union
 
 import numpy as np
2 changes: 1 addition & 1 deletion tests/common_fixture.py
@@ -8,7 +8,7 @@
 from internlm.initialize.launch import launch_from_torch
 from internlm.accelerator import get_accelerator
 from internlm.core.context.parallel_context import global_context as gpc
-from internlm.core.context.config import Config
+from internlm.core.context.parallel_context import Config
 from internlm.data.utils import unpack_type_ids
 from internlm.initialize.launch import args_sanity_check
 
2 changes: 1 addition & 1 deletion tests/test_core/test_pipeline.py
@@ -5,7 +5,7 @@
 
 from internlm.core.context import ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
-from internlm.core.context.config import Config
+from internlm.core.context.parallel_context import Config
 from internlm.solver.optimizer.compatible_adamw import new_compatible_adamw
 from internlm.utils.common import get_current_device
 from tests.test_core.utils import (
2 changes: 1 addition & 1 deletion tests/test_data/test_batch_sampler.py
@@ -8,7 +8,7 @@
 from internlm.core.context.parallel_context import global_context as gpc
 
 # from internlm.core.context import ParallelMode
-from internlm.core.context.config import Config
+from internlm.core.context.parallel_context import Config
 from internlm.core.trainer import TrainState
 from internlm.data import (
     build_train_loader_with_data_type,
2 changes: 1 addition & 1 deletion tests/test_model/test_model_internlm.py
@@ -9,7 +9,7 @@
 from internlm.initialize.launch import launch_from_torch
 from internlm.accelerator import get_accelerator
 from internlm.core.context import ParallelMode
-from internlm.core.context.config import Config
+from internlm.core.context.parallel_context import Config
 from internlm.core.context.parallel_context import global_context as gpc
 from internlm.core.parallel.comm.tensor import (
     HeadTensorParallelCommunicator,
2 changes: 1 addition & 1 deletion tests/test_model/test_npu_ops/test_flash_attention.py
@@ -12,7 +12,7 @@
 from torch import nn
 
 from internlm.accelerator import AcceleratorType, get_accelerator
-from internlm.core.context.config import Config
+from internlm.core.context.parallel_context import Config
 from internlm.core.context.parallel_context import global_context as gpc
 from internlm.model.ops.attention import SelfAttention
 from internlm.model.ops.utils import pack_output_after_attn, unpack_qkv_before_attn
2 changes: 1 addition & 1 deletion tests/test_solver/test_optimizer.py
@@ -11,7 +11,7 @@
 
 from internlm.initialize.launch import launch_from_torch
 from internlm.accelerator import get_accelerator
-from internlm.core.context.config import Config, ParallelMode
+from internlm.core.context.parallel_context import Config, ParallelMode
 from internlm.core.parallel.comm.zero import ParamAsyncBcastHandler
 from internlm.solver.optimizer import HybridZeroOptimizer
 from internlm.utils.common import get_current_device
2 changes: 1 addition & 1 deletion tests/test_training/test_forward_output_no_fa.py
@@ -12,7 +12,7 @@
 from internlm.accelerator import get_accelerator
 from internlm.core.context import ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
-from internlm.core.context.config import Config
+from internlm.core.context.parallel_context import Config
 from internlm.core.trainer import Trainer
 from internlm.data import build_train_loader_with_data_type
 from internlm.initialize.launch import args_sanity_check
2 changes: 1 addition & 1 deletion tests/test_training/test_loss.py
@@ -8,7 +8,7 @@
 from internlm.initialize.initialize_trainer import initialize_trainer
 from internlm.accelerator import AcceleratorType, get_accelerator
 from internlm.checkpoint import CheckpointManager
-from internlm.core.context.config import Config
+from internlm.core.context.parallel_context import Config
 from internlm.core.context.parallel_context import ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
 from internlm.core.trainer import Trainer, TrainState
2 changes: 1 addition & 1 deletion tests/test_training/test_swap_nb_loss_and_gradnorm.py
@@ -14,7 +14,7 @@
 from internlm.accelerator import get_accelerator
 from internlm.core.context import ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
-from internlm.core.context.config import Config
+from internlm.core.context.parallel_context import Config
 from internlm.core.trainer import Trainer
 from internlm.data import (
     build_train_loader_with_data_type,
2 changes: 1 addition & 1 deletion tests/test_utils/common_fixture.py
@@ -6,7 +6,7 @@
 import torch
 
 from internlm.core.context.parallel_context import global_context as gpc
-from internlm.core.context.config import Config
+from internlm.core.context.parallel_context import Config
 from internlm.core.naive_amp import NaiveAMPModel
 from internlm.model.builder import create_model
 from internlm.model.registry import register_model_initializer
2 changes: 1 addition & 1 deletion tests/test_utils/test_model_checkpoint.py
@@ -10,7 +10,7 @@
 import torch.distributed as dist
 
 from internlm.checkpoint import CheckpointManager
-from internlm.core.context.config import Config
+from internlm.core.context.parallel_context import Config
 from internlm.core.trainer import TrainState
 from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer
 from internlm.utils.common import SingletonMeta
2 changes: 1 addition & 1 deletion tests/test_utils/test_storage_manager.py
@@ -3,7 +3,7 @@
 import pytest
 import torch
 
-from internlm.core.context.config import Config
+from internlm.core.context.parallel_context import Config
 from internlm.initialize.launch import get_config_value
 from tests.test_utils.common_fixture import (  # noqa # pylint: disable=unused-import
     ALI_SAVE_PATH,
