read torch titan #208
base: main
Conversation
(24.12.04) torch distributed study: read through the PP (pipeline parallel) part.
def pipeline_llama(
    model: nn.Module,
    pp_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
    device: DeviceType,
    model_config: ModelArgs,
    loss_fn: Callable[..., torch.Tensor],
):
    stages, models = pipeline_llama_manual_split(
        model, pp_mesh, parallel_dims, job_config, device, model_config
    )

    pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)

    return pp_schedule, models
PP recap -- notion
The PP entry point splits into two parts: 1) pipeline_llama_manual_split, which splits the model into stages, and 2) build_pipeline_schedule, which sets up the pipeline schedule (micro-batches, etc.). A sketch of the scheduling half follows.
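A minimal sketch of what the scheduling half could look like, using the torch.distributed.pipelining schedule classes and assuming a single pipeline stage per rank; the schedule names and job_config field names here are assumptions, not the PR's actual config keys:

from torch.distributed.pipelining import Schedule1F1B, ScheduleGPipe

def build_pipeline_schedule_sketch(job_config, stage, loss_fn):
    # pick a schedule class by name from the config (assumed field names)
    schedule_classes = {"1f1b": Schedule1F1B, "gpipe": ScheduleGPipe}
    schedule_cls = schedule_classes[job_config.experimental.pipeline_parallel_schedule]
    # n_microbatches controls how the global batch is chunked across the pipeline
    return schedule_cls(
        stage,
        n_microbatches=job_config.experimental.pipeline_parallel_microbatches,
        loss_fn=loss_fn,
    )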
def pipeline_llama_manual_split(
    whole_model: nn.Module,
    pp_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
    device: DeviceType,
    model_config: ModelArgs,
):
The function that splits the llama model. DeviceMesh is torch-native, while ParallelDims is defined inside torchtitan.
@dataclass
class ParallelDims:
    dp_replicate: int
    dp_shard: int
    cp: int
    tp: int
    pp: int
    world_size: int
    enable_loss_parallel: bool
A class that defines dp_replicate, dp_shard, cp, tp, pp, and so on.
dp = dp_replicate * dp_shard
if dp < 0:
    dp = self.world_size // (cp * tp * pp)
    self.dp_shard = dp_shard = dp // dp_replicate
dp_shard * dp_replicate must equal world_size divided by (cp * tp * pp), i.e. data parallelism runs over whatever ranks are left after model parallelism. A worked example follows.
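A worked example of the inference above, with assumed degrees (not taken from the PR):

# assumed degrees: 16 ranks, pp=2, tp=2, cp=1, dp_replicate=2, dp_shard inferred
world_size, pp, tp, cp = 16, 2, 2, 1
dp_replicate, dp_shard = 2, -1          # dp_shard == -1 means "infer it"
dp = dp_replicate * dp_shard            # -2 < 0, so fall back to inference
if dp < 0:
    dp = world_size // (cp * tp * pp)   # 16 // 4 = 4 ranks left for DP
    dp_shard = dp // dp_replicate       # 4 // 2 = 2
assert dp_replicate * dp_shard * cp * tp * pp == world_size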
def build_mesh(self, device_type):
    dims = []
    names = []
    for d, name in zip(
        [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp],
        ["pp", "dp_replicate", "dp_shard", "cp", "tp"],
    ):
        if d > 1:
            dims.append(d)
            if (name == "dp_replicate" and self.dp_shard == 1) or (
                name == "dp_shard" and self.dp_replicate == 1
            ):
                names.append("dp")
            else:
                names.append(name)

    logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
    names = tuple(names)
    mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)
It also has the functionality to build the device mesh; a standalone init_device_mesh sketch follows.
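A standalone sketch of init_device_mesh for the kind of mesh build_mesh produces; the degrees (pp=2, dp=2, tp=2, i.e. 8 ranks under torchrun on CUDA) are assumptions for illustration:

from torch.distributed.device_mesh import init_device_mesh

# 3-D mesh over 8 ranks: pp x dp x tp = 2 x 2 x 2
world_mesh = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("pp", "dp", "tp"))

# sub-meshes are selected by dimension name, e.g. the pp_mesh used in the PP code above
pp_mesh = world_mesh["pp"]
dp_mesh = world_mesh["dp"]
print(pp_mesh.size(), dp_mesh.get_local_rank())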
# init distributed
world_size = int(os.environ["WORLD_SIZE"])
parallel_dims = ParallelDims(
    dp_shard=job_config.training.data_parallel_shard_degree,
    dp_replicate=job_config.training.data_parallel_replicate_degree,
    cp=job_config.experimental.context_parallel_degree,
    tp=job_config.training.tensor_parallel_degree,
    pp=job_config.experimental.pipeline_parallel_degree,
    world_size=world_size,
    enable_loss_parallel=job_config.training.enable_loss_parallel,
)
ParallelDims is constructed here from the job config.
logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}")

# build meshes
world_mesh = parallel_dims.build_mesh(device_type=device_type)
build mesh
if parallel_dims.pp_enabled:
    # apply PT-D Pipeline Parallel
    pp_schedule, model_parts = models_pipelining_fns[model_name](
        model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn
    )

    # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
    # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
    # optimizer, and checkpointing
    for m in model_parts:
        # apply SPMD-style PT-D techniques
        models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
        m.to_empty(device=init_device)
        m.init_weights(buffer_device=buffer_device)
        m.train()
This is where models_pipelining_fns is looked up by model name and called; the registry pattern is sketched below.
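The dispatch is just a dict keyed by model name; a sketch of the registry pattern, where the exact keys and registered functions are assumptions for illustration:

# registry pattern: model name -> function that applies the technique
models_parallelize_fns = {"llama3": parallelize_llama}
models_pipelining_fns = {"llama3": pipeline_llama}

# the trainer only needs the model name from the config to pick the right function
pipeline_fn = models_pipelining_fns[model_name]
pp_schedule, model_parts = pipeline_fn(
    model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn
)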
with maybe_enable_profiling(
    job_config, global_step=train_state.step
) as torch_profiler, maybe_enable_memory_snapshot(
    job_config, global_step=train_state.step
) as memory_profiler:
    while train_state.step < job_config.training.steps:
        train_state.step += 1
        gc_handler.run(train_state.step)

        # get batch
        data_load_start = time.perf_counter()
        batch = next(data_iterator)
This is where the actual training happens.
if parallel_dims.pp_enabled:
    # Pipeline Parallel forward / backward inside step() call
    is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1

    with train_context(optional_context_parallel_ctx):
        if pp_mesh.get_local_rank() == 0:
            pp_schedule.step(input_ids)
        elif is_last_stage:
            losses = []
            pp_schedule.step(target=labels, losses=losses)
        else:
            pp_schedule.step()

    # accumulate losses across pipeline microbatches
    loss = (
        torch.mean(torch.stack(losses))
        if is_last_stage
        else torch.Tensor([-1.0])
    )
pp_schedule.step() appears to run the forward (and backward) over the micro-batches; a single-process toy of that loop follows.
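Conceptually, step() chunks the input into micro-batches, runs forward/backward per chunk, and the last stage collects one loss per micro-batch; a single-process toy of that loop (the model, shapes, and loss here are made up):

import torch
import torch.nn as nn

model = nn.Linear(8, 1)
loss_fn = nn.MSELoss()
batch_x, batch_y = torch.randn(32, 8), torch.randn(32, 1)

n_microbatches = 4
losses = []
for mb_x, mb_y in zip(batch_x.chunk(n_microbatches), batch_y.chunk(n_microbatches)):
    loss = loss_fn(model(mb_x), mb_y)
    loss.backward()                      # grads accumulate across micro-batches
    losses.append(loss.detach())

# same reduction as in the trainer code above
loss = torch.mean(torch.stack(losses))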
(24.12.18) Read through the FSDP2 part.
def parallelize_llama(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
The llama-specific parallelization logic is collected here.
        and not job_config.training.compile
    ):
        raise RuntimeError("Async TP requires --training.compile")
    apply_tp(
TP
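A minimal sketch of what apply_tp does underneath, using the torch.distributed.tensor.parallel API on one llama-style transformer block; the sub-module names in the plan are assumptions based on the usual llama attention/MLP layout:

from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

def apply_tp_sketch(transformer_block, tp_mesh):
    # column-parallel in, row-parallel out, so each attn/MLP pair needs one all-reduce
    plan = {
        "attention.wq": ColwiseParallel(),
        "attention.wk": ColwiseParallel(),
        "attention.wv": ColwiseParallel(),
        "attention.wo": RowwiseParallel(),
        "feed_forward.w1": ColwiseParallel(),
        "feed_forward.w2": RowwiseParallel(),
        "feed_forward.w3": ColwiseParallel(),
    }
    parallelize_module(transformer_block, tp_mesh, plan)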
    )

if job_config.activation_checkpoint.mode != "none":
    apply_ac(model, job_config.activation_checkpoint)
activation checkpointing
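apply_ac wraps each transformer block so activations are recomputed in backward instead of stored; a sketch using the composable checkpoint wrapper (the import path can differ across torch versions, and `model.layers` as a module container is an assumption):

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)

def apply_ac_sketch(model):
    # full activation checkpointing: recompute each block's activations in backward
    for layer_id, block in model.layers.named_children():
        model.layers.register_module(layer_id, checkpoint_wrapper(block))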
def parallelize_llama(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """
    Apply tensor parallelism, activation checkpointing, torch.compile, and data
    parallelism to the model.

    NOTE: The passed-in model preferably should be on meta device. Otherwise,
    the model must fit on GPU or CPU memory.
    """
This is the part that parallelizes llama.
if parallel_dims.tp_enabled:
    if (
        job_config.experimental.enable_async_tensor_parallel
        and not job_config.training.compile
    ):
        raise RuntimeError("Async TP requires --training.compile")
    apply_tp(
TP
module (Union[nn.Module, List[nn.Module]]): The module or modules to
    shard with FSDP and group together for communication.
mesh (Optional[DeviceMesh]): This data parallel mesh defines the
    sharding and device. If 1D, then parameters are fully sharded
    across the 1D mesh (FSDP) with ``(Shard(0),)`` placement. If 2D,
    then parameters are sharded across the 1st dim and replicated
    across the 0th dim (HSDP) with ``(Replicate(), Shard(0))``
    placement. The mesh's device type gives the device type used for
    communication; if a CUDA or CUDA-like device type, then we use the
    current device.
The module and mesh arguments: a 1D mesh gives FSDP, a 2D mesh gives HSDP. A usage sketch follows.
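For context, fully_shard is the composable FSDP2 entry point; a sketch of calling it per block and then on the root, torchtitan-style. The import path shown is the PyTorch 2.4/2.5 private one and may differ in other versions, and `model.layers` is an assumed container:

import torch
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard

def apply_fsdp_sketch(model, dp_mesh):
    # a 1D dp_mesh gives FSDP; a 2D (replicate, shard) mesh would give HSDP
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16, reduce_dtype=torch.float32
    )
    for block in model.layers.children():
        # each block becomes one FSDP param group (one all-gather per block)
        fully_shard(block, mesh=dp_mesh, mp_policy=mp_policy)
    # the root call picks up the remaining params (embeddings, final norm/output)
    fully_shard(model, mesh=dp_mesh, mp_policy=mp_policy)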
mesh = mesh or _init_default_fully_shard_mesh()
if mesh.ndim not in (1, 2):
    raise ValueError(f"fully_shard expects a 1D or 2D DeviceMesh but got {mesh}")
elif mesh.ndim == 1:
    mesh_info = FSDPMeshInfo(mesh, shard_mesh_dim=0)
else:
    if mesh.mesh_dim_names is None:
        raise AssertionError(
            "Please init the 2D mesh for HSDP with mesh_dim_names specified"
        )
    mesh_info = HSDPMeshInfo(mesh, shard_mesh_dim=1, replicate_mesh_dim=0)
device = _get_device_from_mesh(mesh)
post_forward_mesh_info = _get_post_forward_mesh_info(
    reshard_after_forward, mesh_info
)
The part that sets up the mesh info.
arg_module = module
modules = (
    (module,) if isinstance(module, nn.Module) else tuple(_get_root_modules(module))
)
state = fully_shard.state(modules[0])
state.init(modules, device, mp_policy)
Sets up the FSDP state; it seems to register forward hooks and the like (a standalone hook sketch follows the FSDPState class below).
class FSDPState(_State):
def __init__(self) -> None:
super().__init__()
self._fsdp_param_group: Optional[FSDPParamGroup] = None
self._is_root: Optional[bool] = None # root set during lazy init
self._state_ctx = FSDPStateContext()
self._comm_ctx = FSDPCommContext()
self._training_state: TrainingState = TrainingState.IDLE
self._states_to_forward_prefetch: List[FSDPState] = []
self._states_to_backward_prefetch: List[FSDPState] = []
self._modules_to_run_forward: Set[nn.Module] = set()
# Define a separate init since `__init__` is called in the contract
def init(
self,
modules: Tuple[nn.Module, ...],
device: torch.device,
mp_policy: MixedPrecisionPolicy,
) -> None:
for module in modules:
_insert_module_state(module, self)
self._modules = modules
self._device = device
self._device_handle = _get_device_handle(device.type)
self._mp_policy = mp_policy
if len(modules) == 1:
self._pre_forward_hook_handle = modules[0].register_forward_pre_hook(
self._pre_forward, prepend=True, with_kwargs=True
)
self._post_forward_hook_handle = modules[0].register_forward_hook(
self._post_forward, prepend=False
)
else:
hook_handle = _register_group_forward_hooks(
modules,
self._pre_forward,
self._post_forward,
self._modules_to_run_forward,
)
self._pre_forward_hook_handle = hook_handle
self._post_forward_hook_handle = hook_handle
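The hook registration in init above uses the standard nn.Module hook API; a standalone toy showing the same pre/post forward hooks, where the module and the hook bodies are made up:

import torch
import torch.nn as nn

def pre_forward(module, args, kwargs):
    # FSDP would all-gather parameters here; returning (args, kwargs) can rewrite inputs
    print("pre-forward:", module.__class__.__name__)
    return args, kwargs

def post_forward(module, args, output):
    # FSDP would reshard parameters and attach backward hooks to the output here
    print("post-forward:", module.__class__.__name__)
    return output

layer = nn.Linear(4, 4)
layer.register_forward_pre_hook(pre_forward, prepend=True, with_kwargs=True)
layer.register_forward_hook(post_forward, prepend=False)
layer(torch.randn(2, 4))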
if params:
    state._fsdp_param_group = FSDPParamGroup(
        params,
        modules,
        mesh_info,
        post_forward_mesh_info,
        device,
        shard_placement_fn,
        mp_policy,
        offload_policy,
    )
# Place FSDP leftmost for highest priority in the method resolution order
for module in modules:
    cls = module.__class__
    new_cls = cls_to_fsdp_cls.get(cls, None)
    if not new_cls:
        dct = {"__deepcopy__": _unimplemented_deepcopy}
        new_cls = type(f"FSDP{cls.__name__}", (FSDPModule, cls), dct)
        cls_to_fsdp_cls[cls] = new_cls
    module.__class__ = new_cls
return arg_module
There is a part that wraps the class (swapping module.__class__ to an FSDP subclass); this seems similar to what came before. The trick is sketched in isolation below.
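The same trick in isolation: a new subclass is created at runtime with the mixin leftmost in the MRO, and the instance's __class__ is swapped so it picks up the mixin's methods without re-constructing the object. The Counter/ResetMixin names are made up for this toy:

class ResetMixin:
    # plays the role of FSDPModule: adds methods to an existing instance
    def reset(self):
        self.count = 0

class Counter:
    def __init__(self):
        self.count = 0
    def bump(self):
        self.count += 1

cls_cache = {}

obj = Counter()
cls = obj.__class__
new_cls = cls_cache.get(cls)
if new_cls is None:
    # mixin goes leftmost so its methods win in the method resolution order
    new_cls = type(f"Reset{cls.__name__}", (ResetMixin, cls), {})
    cls_cache[cls] = new_cls
obj.__class__ = new_cls

obj.bump()
obj.reset()
print(type(obj).__name__, obj.count)    # ResetCounter 0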
No description provided.