read torch titan #208

Open

long8v wants to merge 3 commits into main
Conversation

@long8v long8v commented Dec 4, 2024

No description provided.


@long8v long8v left a comment


(24.12.04) torch distributed study: looked at the PP side.

Comment on lines +31 to +46
def pipeline_llama(
    model: nn.Module,
    pp_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
    device: DeviceType,
    model_config: ModelArgs,
    loss_fn: Callable[..., torch.Tensor],
):
    stages, models = pipeline_llama_manual_split(
        model, pp_mesh, parallel_dims, job_config, device, model_config
    )

    pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)

    return pp_schedule, models

PP recap -- notion
The main PP logic splits into two parts: 1) pipeline_llama_manual_split, which chops the model into stages, and 2) build_pipeline_schedule, which sets up the pipeline schedule (micro-batches, etc.).
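
Roughly the same two-phase pattern can be sketched with the public torch.distributed.pipelining API (a minimal sketch assuming one stage per rank under torchrun; build_pp_schedule, stage_module and the GPipe schedule choice are placeholders, not what pipeline_llama actually does internally):

import torch
import torch.distributed as dist
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe


def build_pp_schedule(stage_module: torch.nn.Module, loss_fn, n_microbatches: int = 4):
    # Phase 1: wrap this rank's model chunk as one pipeline stage.
    rank, world_size = dist.get_rank(), dist.get_world_size()
    device = torch.device("cuda", rank % torch.cuda.device_count())
    stage = PipelineStage(stage_module, stage_index=rank, num_stages=world_size, device=device)

    # Phase 2: pick a schedule that decides how microbatches flow through the stages.
    return ScheduleGPipe(stage, n_microbatches=n_microbatches, loss_fn=loss_fn)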

Comment on lines +49 to +56
def pipeline_llama_manual_split(
    whole_model: nn.Module,
    pp_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
    device: DeviceType,
    model_config: ModelArgs,
):

The function that splits the llama model. DeviceMesh is torch-native, while ParallelDims is torchtitan-internal.

Comment on lines +14 to +22
@dataclass
class ParallelDims:
    dp_replicate: int
    dp_shard: int
    cp: int
    tp: int
    pp: int
    world_size: int
    enable_loss_parallel: bool

The class that defines dp_replicate, dp_shard, cp, tp, pp, etc.

Comment on lines +39 to +42
dp = dp_replicate * dp_shard
if dp < 0:
    dp = self.world_size // (cp * tp * pp)
    self.dp_shard = dp_shard = dp // dp_replicate

dp_shard * dp_replicate has to equal world_size divided by (cp * tp * pp) (DP uses whatever dimension is left over after model parallelism).
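
A quick sanity check of that constraint with made-up numbers:

# Hypothetical 64-GPU run with tp=8, pp=2, cp=1 and dp_shard left at -1 ("infer").
world_size, cp, tp, pp = 64, 1, 8, 2
dp_replicate, dp_shard = 1, -1

dp = dp_replicate * dp_shard            # -1  -> needs to be inferred
if dp < 0:
    dp = world_size // (cp * tp * pp)   # 64 // 16 = 4 ranks left for data parallel
    dp_shard = dp // dp_replicate       # 4
assert dp_replicate * dp_shard * cp * tp * pp == world_size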

Comment on lines +54 to +72
def build_mesh(self, device_type):
    dims = []
    names = []
    for d, name in zip(
        [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp],
        ["pp", "dp_replicate", "dp_shard", "cp", "tp"],
    ):
        if d > 1:
            dims.append(d)
            if (name == "dp_replicate" and self.dp_shard == 1) or (
                name == "dp_shard" and self.dp_replicate == 1
            ):
                names.append("dp")
            else:
                names.append(name)

    logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
    names = tuple(names)
    mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)

It also has the functionality to build the device_mesh.
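
For a concrete feel, the torch-native call it ends up making for a hypothetical pp=2, dp_shard=4, tp=8 run (size-1 dims are dropped, and the shard dim gets renamed to "dp" because dp_replicate == 1) would look roughly like:

from torch.distributed.device_mesh import init_device_mesh

world_mesh = init_device_mesh("cuda", (2, 4, 8), mesh_dim_names=("pp", "dp", "tp"))
pp_mesh = world_mesh["pp"]   # 1-D sub-mesh, e.g. the one handed to the pipeline schedule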

Comment on lines +54 to +64
# init distributed
world_size = int(os.environ["WORLD_SIZE"])
parallel_dims = ParallelDims(
    dp_shard=job_config.training.data_parallel_shard_degree,
    dp_replicate=job_config.training.data_parallel_replicate_degree,
    cp=job_config.experimental.context_parallel_degree,
    tp=job_config.training.tensor_parallel_degree,
    pp=job_config.experimental.pipeline_parallel_degree,
    world_size=world_size,
    enable_loss_parallel=job_config.training.enable_loss_parallel,
)

ParallelDims is constructed here from the job config.

logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}")

# build meshes
world_mesh = parallel_dims.build_mesh(device_type=device_type)

build mesh

Comment on lines +154 to +168
if parallel_dims.pp_enabled:
    # apply PT-D Pipeline Parallel
    pp_schedule, model_parts = models_pipelining_fns[model_name](
        model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn
    )

    # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
    # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
    # optimizer, and checkpointing
    for m in model_parts:
        # apply SPMD-style PT-D techniques
        models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
        m.to_empty(device=init_device)
        m.init_weights(buffer_device=buffer_device)
        m.train()

This is where models_pipelining_fns gets looked up and called.

Comment on lines +256 to +267
with maybe_enable_profiling(
    job_config, global_step=train_state.step
) as torch_profiler, maybe_enable_memory_snapshot(
    job_config, global_step=train_state.step
) as memory_profiler:
    while train_state.step < job_config.training.steps:
        train_state.step += 1
        gc_handler.run(train_state.step)

        # get batch
        data_load_start = time.perf_counter()
        batch = next(data_iterator)

This is where the actual training happens.

Comment on lines +288 to +306
if parallel_dims.pp_enabled:
    # Pipeline Parallel forward / backward inside step() call
    is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1

    with train_context(optional_context_parallel_ctx):
        if pp_mesh.get_local_rank() == 0:
            pp_schedule.step(input_ids)
        elif is_last_stage:
            losses = []
            pp_schedule.step(target=labels, losses=losses)
        else:
            pp_schedule.step()

    # accumulate losses across pipeline microbatches
    loss = (
        torch.mean(torch.stack(losses))
        if is_last_stage
        else torch.Tensor([-1.0])
    )

Looks like pp_schedule.step() is what runs the micro-batch forward passes.


@long8v long8v left a comment


(24.12.18) Read the FSDP2 part.

Comment on lines +40 to +45
def parallelize_llama(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):

The parallelization logic for llama is all gathered here.

    and not job_config.training.compile
):
    raise RuntimeError("Async TP requires --training.compile")
apply_tp(

TP

)

if job_config.activation_checkpoint.mode != "none":
    apply_ac(model, job_config.activation_checkpoint)

activation checkpointing

Comment on lines +40 to +52
def parallelize_llama(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """
    Apply tensor parallelism, activation checkpointing, torch.compile, and data
    parallelism to the model.

    NOTE: The passed-in model preferably should be on meta device. Otherwise,
    the model must fit on GPU or CPU memory.
    """

This is the part that parallelizes llama.

Comment on lines +54 to +60
if parallel_dims.tp_enabled:
    if (
        job_config.experimental.enable_async_tensor_parallel
        and not job_config.training.compile
    ):
        raise RuntimeError("Async TP requires --training.compile")
    apply_tp(

TP

Comment on lines +97 to +106
module (Union[nn.Module, List[nn.Module]): The module or modules to
    shard with FSDP and group together for communication.
mesh (Optional[DeviceMesh]): This data parallel mesh defines the
    sharding and device. If 1D, then parameters are fully sharded
    across the 1D mesh (FSDP) with ``(Shard(0),)`` placement. If 2D,
    then parameters are sharded across the 1st dim and replicated
    across the 0th dim (HSDP) with ``(Replicate(), Shard(0))``
    placement. The mesh's device type gives the device type used for
    communication; if a CUDA or CUDA-like device type, then we use the
    current device.

The module and mesh arguments: a 1D mesh gives plain FSDP, a 2D mesh gives HSDP.
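
A minimal sketch of those two cases (the mesh shapes are made up and assume an 8-rank torchrun launch; the import path is the FSDP2 one in use around this torch version and may differ in others):

import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh

model = nn.Linear(1024, 1024)

# 1D mesh: fully shard parameters across all 8 ranks -> (Shard(0),) placement.
fsdp_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("dp",))
fully_shard(model, mesh=fsdp_mesh)

# 2D mesh (HSDP): replicate across dim 0, shard across dim 1
# -> (Replicate(), Shard(0)) placement.
# hsdp_mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp_replicate", "dp_shard"))
# fully_shard(model, mesh=hsdp_mesh)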

Comment on lines +149 to +163
mesh = mesh or _init_default_fully_shard_mesh()
if mesh.ndim not in (1, 2):
    raise ValueError(f"fully_shard expects a 1D or 2D DeviceMesh but got {mesh}")
elif mesh.ndim == 1:
    mesh_info = FSDPMeshInfo(mesh, shard_mesh_dim=0)
else:
    if mesh.mesh_dim_names is None:
        raise AssertionError(
            "Please init the 2D mesh for HSDP with mesh_dim_names specified"
        )
    mesh_info = HSDPMeshInfo(mesh, shard_mesh_dim=1, replicate_mesh_dim=0)
device = _get_device_from_mesh(mesh)
post_forward_mesh_info = _get_post_forward_mesh_info(
    reshard_after_forward, mesh_info
)

The part that sets up the mesh info.

Comment on lines +165 to +170
arg_module = module
modules = (
    (module,) if isinstance(module, nn.Module) else tuple(_get_root_modules(module))
)
state = fully_shard.state(modules[0])
state.init(modules, device, mp_policy)

Sets up the FSDP state. Seems to register hooks and so on (see the toy hook example after the FSDPState class below).

class FSDPState(_State):
    def __init__(self) -> None:
        super().__init__()
        self._fsdp_param_group: Optional[FSDPParamGroup] = None
        self._is_root: Optional[bool] = None  # root set during lazy init
        self._state_ctx = FSDPStateContext()
        self._comm_ctx = FSDPCommContext()
        self._training_state: TrainingState = TrainingState.IDLE
        self._states_to_forward_prefetch: List[FSDPState] = []
        self._states_to_backward_prefetch: List[FSDPState] = []
        self._modules_to_run_forward: Set[nn.Module] = set()

    # Define a separate init since `__init__` is called in the contract
    def init(
        self,
        modules: Tuple[nn.Module, ...],
        device: torch.device,
        mp_policy: MixedPrecisionPolicy,
    ) -> None:
        for module in modules:
            _insert_module_state(module, self)
        self._modules = modules
        self._device = device
        self._device_handle = _get_device_handle(device.type)
        self._mp_policy = mp_policy
        if len(modules) == 1:
            self._pre_forward_hook_handle = modules[0].register_forward_pre_hook(
                self._pre_forward, prepend=True, with_kwargs=True
            )
            self._post_forward_hook_handle = modules[0].register_forward_hook(
                self._post_forward, prepend=False
            )
        else:
            hook_handle = _register_group_forward_hooks(
                modules,
                self._pre_forward,
                self._post_forward,
                self._modules_to_run_forward,
            )
            self._pre_forward_hook_handle = hook_handle
            self._post_forward_hook_handle = hook_handle
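
A toy illustration of the hook mechanism init relies on, with no FSDP involved: a forward pre-hook may rewrite args/kwargs before forward and a forward hook may rewrite the output after, which is where FSDP2 hangs its unshard/reshard work.

import torch
import torch.nn as nn


def pre_forward(module, args, kwargs):
    print("pre-forward: e.g. all-gather sharded params here")
    return args, kwargs          # with_kwargs=True hooks may return (args, kwargs)


def post_forward(module, inputs, output):
    print("post-forward: e.g. reshard params / set up backward state here")
    return output


lin = nn.Linear(4, 4)
lin.register_forward_pre_hook(pre_forward, prepend=True, with_kwargs=True)
lin.register_forward_hook(post_forward)
lin(torch.randn(2, 4))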

Comment on lines +175 to +185
if params:
    state._fsdp_param_group = FSDPParamGroup(
        params,
        modules,
        mesh_info,
        post_forward_mesh_info,
        device,
        shard_placement_fn,
        mp_policy,
        offload_policy,
    )

Sets up the FSDPParamGroup. Probably fine not to dig into the internals..?! haha

Comment on lines +192 to +201
# Place FSDP leftmost for highest priority in the method resolution order
for module in modules:
    cls = module.__class__
    new_cls = cls_to_fsdp_cls.get(cls, None)
    if not new_cls:
        dct = {"__deepcopy__": _unimplemented_deepcopy}
        new_cls = type(f"FSDP{cls.__name__}", (FSDPModule, cls), dct)
        cls_to_fsdp_cls[cls] = new_cls
    module.__class__ = new_cls
return arg_module

There is a class-wrapping step here; this part seems similar to how it was before.
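
A pure-Python toy version of that swap (no FSDP), just to show the MRO trick: the generated class puts the mixin leftmost so its methods win method resolution, and the instance's __class__ is replaced in place.

class FSDPModuleMixin:
    def unshard(self):
        return f"unshard called on {type(self).__name__}"


class MyModel:
    pass


model = MyModel()
_cls_cache: dict[type, type] = {}

cls = model.__class__
new_cls = _cls_cache.get(cls)
if new_cls is None:
    new_cls = type(f"FSDP{cls.__name__}", (FSDPModuleMixin, cls), {})
    _cls_cache[cls] = new_cls
model.__class__ = new_cls

print(type(model).__name__)   # FSDPMyModel
print(model.unshard())        # mixin method is now available on the instance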
