From 292388f6e432fe8228a79eeccd86290552f032ea Mon Sep 17 00:00:00 2001 From: HaochenYuan Date: Wed, 1 Jul 2026 23:55:30 -0700 Subject: [PATCH] partial cuda graph support for dynamic cp Signed-off-by: HaochenYuan --- megatron/core/datasets/data_schedule.py | 16 +- megatron/core/datasets/data_schedule_utils.py | 12 +- megatron/core/model_parallel_config.py | 4 + megatron/core/packed_seq_params.py | 54 +++-- megatron/core/transformer/cuda_graphs.py | 196 ++++++++++++++---- megatron/core/transformer/module.py | 2 + .../transformer/multi_latent_attention.py | 14 ++ .../core/transformer/transformer_config.py | 8 + .../core/transformer/transformer_layer.py | 58 +++++- megatron/training/arguments.py | 5 + megatron/training/training.py | 3 + .../transformer/test_thd_cuda_graph.py | 105 +++++++--- 12 files changed, 365 insertions(+), 112 deletions(-) diff --git a/megatron/core/datasets/data_schedule.py b/megatron/core/datasets/data_schedule.py index d5bf07053cc..62cf027ca9b 100644 --- a/megatron/core/datasets/data_schedule.py +++ b/megatron/core/datasets/data_schedule.py @@ -388,6 +388,9 @@ def run( ) ) num_micro_batches = int(num_micro_batches) + graph_slots = getattr(config, '_cuda_graph_num_microbatches', None) + if graph_slots is not None and num_micro_batches > graph_slots: + raise ValueError(f"{num_micro_batches=} exceeds captured CUDA graph {graph_slots=}.") # Step 8: Broadcast to TP group and create data_iterator new_data_iterator = create_data_iterator( @@ -431,6 +434,7 @@ def get_groups_and_subsamples(self, sample_id_seqlens): self.total_hdp_gpus, max_seq_len_per_rank=mslpr, min_cp_size=min_cp, + max_num_seqs=self.max_num_seqs, ) sample_id_groups.append(sample_ids) @@ -455,8 +459,8 @@ def _get_scheduler_max_real_num_seqs(config) -> Optional[int]: """Return the scheduler cap for real THD sequences. ``thd_max_packed_sequences`` is the final static THD capacity, including the - optional dummy sequence appended for a padding tail. The dp_balanced - scheduler only packs real sequences, so reserve one slot when dummy-tail + optional dummy sequence appended for a padding tail. Packing schedulers + only place real sequences, so reserve one slot when dummy-tail padding is enabled. """ max_num_seqs = getattr(config, 'thd_max_packed_sequences', None) @@ -522,11 +526,7 @@ def wrap_data_iterator( if scheduler_type == 'default_dynamic_cp': scheduler_kwargs['min_cp_size'] = config.min_dynamic_context_parallel_size - scheduler_max_num_seqs = ( - _get_scheduler_max_real_num_seqs(config) - if scheduler_type == 'dp_balanced' - else getattr(config, 'thd_max_packed_sequences', None) - ) + scheduler_max_num_seqs = _get_scheduler_max_real_num_seqs(config) scheduler = scheduler_map[scheduler_type]( config.max_seqlen_per_dp_cp_rank, @@ -778,7 +778,7 @@ def get_batch_on_this_rank_for_sequence_packing( max_seqlen_kv=max_seqlen, local_cp_size=local_cp_size, cp_group=cp_group, - pad_between_seqs=False, + pad_between_seqs=True, ) # Pad the already-packed THD tensors at the end when requested. CUDA Graph diff --git a/megatron/core/datasets/data_schedule_utils.py b/megatron/core/datasets/data_schedule_utils.py index 190f898cfc3..fab9fc98fa1 100644 --- a/megatron/core/datasets/data_schedule_utils.py +++ b/megatron/core/datasets/data_schedule_utils.py @@ -536,6 +536,7 @@ def next_hdp_group_packing_aware( total_gpus: int, max_seq_len_per_rank: int, min_cp_size: int = 1, + max_num_seqs: Optional[int] = None, ) -> Tuple[List[List[int]], List[Tuple[int, int]], List[float], List[List[int]]]: """Form one DCP microbatch with packing-aware CP group selection. @@ -549,7 +550,7 @@ def next_hdp_group_packing_aware( The scheduler keeps the legacy invariant that each returned microbatch has no empty DPxCP rank after the fill step. For non-power-of-two DPxCP layouts, it falls back to the full DPxCP group if power-of-two expansion cannot fill - every rank. + every rank. ``max_num_seqs`` optionally caps the real sequences per subgroup. """ if not sample_seqlens: return ( @@ -610,6 +611,11 @@ def workload(seq_len: int, cp_size: int) -> float: for group_id, size in list(group_size.items()): if size != cp_size: continue + if ( + max_num_seqs is not None + and len(micro_batches[group_members[group_id][0]]) >= max_num_seqs + ): + continue if packing_sequence_len.get(group_id, 0) + seq_len / cp_size > max_seq_len_per_rank: continue members = group_members[group_id] @@ -745,7 +751,9 @@ def fill_with_full_dpxcp_group() -> None: for sample_id, seq_len in sample_seqlens: per_rank_len = seq_len / total_gpus - if packed_sequence_len + per_rank_len <= max_seq_len_per_rank: + if ( + max_num_seqs is None or len(selected) < max_num_seqs + ) and packed_sequence_len + per_rank_len <= max_seq_len_per_rank: selected.append((sample_id, seq_len)) packed_sequence_len += per_rank_len else: diff --git a/megatron/core/model_parallel_config.py b/megatron/core/model_parallel_config.py index 5c2786285b1..03f96b00c76 100644 --- a/megatron/core/model_parallel_config.py +++ b/megatron/core/model_parallel_config.py @@ -125,6 +125,10 @@ class ModelParallelConfig: pad_packed_seq_alignment, but cu_seqlens sequence boundaries are not extended for the padding tail. CUDA Graph static-input padding may still pad the cu_seqlens tensors to thd_max_packed_sequences + 1 entries. + + Fused RoPE is unsafe only when disabling this option creates a hidden-only + tail beyond the last padded cu_seqlens boundary; inputs without such a tail + are unaffected. """ expert_model_parallel_size: int = 1 diff --git a/megatron/core/packed_seq_params.py b/megatron/core/packed_seq_params.py index 3095b1b8464..e567fbd757f 100644 --- a/megatron/core/packed_seq_params.py +++ b/megatron/core/packed_seq_params.py @@ -144,18 +144,16 @@ def _pad_cu_seqlens(cu_seqlens: Optional[Tensor], target_entries: int) -> Option return padded -def _append_dummy_seq(cu_seqlens: Optional[Tensor], dummy_end: int) -> Optional[Tensor]: - """Append a dummy sequence boundary to a cu_seqlens tensor. - - ``dummy_end`` is the padded target length. Appending it to both - ``cu_seqlens_*`` and ``cu_seqlens_*_padded`` represents the post-pack - alignment tail as an ordinary dummy sequence. That keeps every token row - covered by THD metadata without enabling TE's pad-between-sequences mode. - """ +def _append_dummy_seq( + cu_seqlens: Optional[Tensor], dummy_end: Optional[int] = None +) -> Optional[Tensor]: + """Append a boundary, repeating the final offset by default for a zero-token dummy.""" if cu_seqlens is None: return None - - dummy = torch.full((1,), int(dummy_end), dtype=cu_seqlens.dtype, device=cu_seqlens.device) + if dummy_end is None: + dummy = cu_seqlens[-1:] + else: + dummy = cu_seqlens.new_full((1,), int(dummy_end)) return torch.cat((cu_seqlens, dummy), dim=0) @@ -206,7 +204,7 @@ def _resolve_thd_padding_lengths( Returns: local_actual_T: Current rank's token-like tensor length. - global_actual_T: Global packed length represented by THD metadata. + global_actual_T: Global physical packed length represented by THD metadata. local_target_len: Current rank's padded token-like tensor length. global_target_len: Global padded endpoint represented by THD metadata. mask_device: Device used to build the returned padding mask. @@ -227,10 +225,15 @@ def _resolve_thd_padding_lengths( # Prefer THD metadata for the global packed length when it is available. has_local_tensor = local_tensor_T is not None - if packed_seq_params.cu_seqlens_q is not None: - global_actual_T = int(packed_seq_params.cu_seqlens_q[-1].item()) + physical_cu_seqlens = ( + packed_seq_params.cu_seqlens_q_padded + if packed_seq_params.cu_seqlens_q_padded is not None + else packed_seq_params.cu_seqlens_q + ) + if physical_cu_seqlens is not None: + global_actual_T = int(physical_cu_seqlens[-1].item()) if mask_device is None: - mask_device = packed_seq_params.cu_seqlens_q.device + mask_device = physical_cu_seqlens.device else: assert has_local_tensor, ( "packed_seq_params.cu_seqlens_q must be available to derive padding_mask " @@ -386,8 +389,8 @@ def pad_sequence_for_thd( stages, Megatron asks TE which packed rows this CP rank would receive and uses that row count as the local length instead of assuming equal division by CP size. - - When ``pad_by_appending_dummy_seq`` is true, the padding tail is also - represented as an ordinary dummy sequence in cu_seqlens metadata. + - When ``pad_by_appending_dummy_seq`` is true, the padding tail is represented + by a zero-valid-token dummy sequence whose padded boundary spans the tail. - ``max_num_seqs`` pads all four cu_seqlens tensors; this is required by CUDA Graph replay because those tensors are graph inputs. @@ -437,16 +440,25 @@ def pad_sequence_for_thd( cu_seqlens_q_padded = packed_seq_params.cu_seqlens_q_padded cu_seqlens_kv_padded = packed_seq_params.cu_seqlens_kv_padded + physical_q = cu_seqlens_q_padded if cu_seqlens_q_padded is not None else cu_seqlens_q + physical_kv = cu_seqlens_kv_padded if cu_seqlens_kv_padded is not None else cu_seqlens_kv + # Represent post-pack padding as a dummy sequence when requested. target_cu_entries = None if max_num_seqs is None else max_num_seqs + 1 has_dummy_padding_seq = pad_by_appending_dummy_seq and global_target_len > global_actual_T dummy_seq_len = global_target_len - global_actual_T if has_dummy_padding_seq else 0 if has_dummy_padding_seq: - cu_seqlens_q = _append_dummy_seq(cu_seqlens_q, global_target_len) - cu_seqlens_kv = _append_dummy_seq(cu_seqlens_kv, global_target_len) - cu_seqlens_q_padded = _append_dummy_seq(cu_seqlens_q_padded, global_target_len) - cu_seqlens_kv_padded = _append_dummy_seq(cu_seqlens_kv_padded, global_target_len) + if physical_q is not None and physical_kv is not None and physical_q is not physical_kv: + assert physical_q[-1].item() == physical_kv[-1].item(), ( + "One appended THD tail dummy requires matching Q and KV physical endpoints." + ) + # Keep the logical endpoint unchanged and span the physical tail only in the + # padded metadata so fused THD helpers initialize every tensor row. + cu_seqlens_q = _append_dummy_seq(cu_seqlens_q) + cu_seqlens_kv = _append_dummy_seq(cu_seqlens_kv) + cu_seqlens_q_padded = _append_dummy_seq(physical_q, global_target_len) + cu_seqlens_kv_padded = _append_dummy_seq(physical_kv, global_target_len) # Pad cu_seqlens entry counts for static CUDA Graph inputs. if target_cu_entries is not None: @@ -475,7 +487,7 @@ def pad_sequence_for_thd( local_cp_size=packed_seq_params.local_cp_size, cp_group=packed_seq_params.cp_group, total_tokens=local_target_len if target_cu_entries is None else None, - pad_between_seqs=False if has_dummy_padding_seq else packed_seq_params.pad_between_seqs, + pad_between_seqs=has_dummy_padding_seq or packed_seq_params.pad_between_seqs, ) # True marks padded local token slots for routing/loss paths. diff --git a/megatron/core/transformer/cuda_graphs.py b/megatron/core/transformer/cuda_graphs.py index d59d1fbf5b0..d9344c3545a 100644 --- a/megatron/core/transformer/cuda_graphs.py +++ b/megatron/core/transformer/cuda_graphs.py @@ -8,7 +8,7 @@ import os import time from collections import defaultdict -from contextlib import nullcontext +from contextlib import nullcontext, suppress from copy import deepcopy from dataclasses import dataclass, is_dataclass from enum import Enum @@ -1724,6 +1724,11 @@ def _layer_is_graphable(layer, config): if not isinstance(layer, GraphableMegatronModule): return False + if getattr(config, 'dynamic_context_parallel', False) and not hasattr( + layer, '_activate_dynamic_cp_cuda_graph' + ): + return False + # If cuda_graph_modules is not set, every layer is graphed. if not config.cuda_graph_modules: return True @@ -2021,12 +2026,18 @@ def get_rotary_pos_emb(transformer_module, transformer_input): transformer_module.position_embedding_type == 'rope' and not self.config.multi_latent_attention ): - rotary_seq_len = transformer_module.rotary_pos_emb.get_rotary_seq_len( - None, transformer_module.decoder, transformer_input, self.config, None - ) + packed_seq = hasattr(layer, "_is_thd_cuda_graph") and layer._is_thd_cuda_graph() + capture_cp_group = None + if packed_seq: + capture_cp_size, capture_cp_group = layer._get_thd_cuda_graph_capture_cp() + rotary_seq_len = self.config.max_seqlen_per_dp_cp_rank * capture_cp_size + else: + rotary_seq_len = transformer_module.rotary_pos_emb.get_rotary_seq_len( + None, transformer_module.decoder, transformer_input, self.config, None + ) if rotary_seq_len not in rotary_pos_emb_cache: rotary_pos_emb_cache[rotary_seq_len] = transformer_module.rotary_pos_emb( - rotary_seq_len + rotary_seq_len, packed_seq=packed_seq, cp_group=capture_cp_group ) return rotary_pos_emb_cache[rotary_seq_len] else: @@ -2323,6 +2334,17 @@ def _get_thd_varlen_max_num_microbatches( self, runtime_num_microbatches, microbatch_group_size_per_vp_stage ): """Return the THD packing upper bound used for dynamic CUDA graph capture.""" + if self.config.sequence_packing_scheduler == 'default_dynamic_cp': + max_num_microbatches = runtime_num_microbatches * self.micro_batch_size + max_num_microbatches *= self.dp_group.size() * getattr( + self.config, + 'thd_max_subsamples_per_item', + self.config.thd_max_packed_sequences, + ) + return ( + max(1, max_num_microbatches), + "dynamic_cp_global_subsample_upper_bound", + ) if self.config.sequence_packing_scheduler != 'dp_balanced': return runtime_num_microbatches, "runtime" if self.config.max_seqlen_per_dp_cp_rank is None: @@ -2592,6 +2614,31 @@ def _get_fp8_enabled(): kwargs = get_make_graphed_callables_kwargs() return sample_args, kwargs + def _get_dynamic_cp_capture_contexts(self): + """Return ``(local_cp_size, process_group)`` contexts to capture eagerly.""" + if not self.config.dynamic_context_parallel: + return [(None, None)] + + max_size = self.dp_cp_group.size() + min_size = self.config.min_dynamic_context_parallel_size + if min_size > max_size or any( + size < 1 or size & (size - 1) for size in (min_size, max_size) + ): + raise ValueError("Dynamic CP graph sizes must be powers of two with min <= max.") + largest, smallest = (int(math.log2(size)) for size in (max_size, min_size)) + return [ + (size, parallel_state.get_dynamic_data_context_parallel_groups(group_size=size)) + for size in (2**exponent for exponent in range(largest, smallest - 1, -1)) + ] + + @staticmethod + def _clear_cuda_graph_state(layer): + layer.cuda_graphs = [] + layer.cuda_graphs_by_dynamic_cp_size = {} + layer.cuda_graph_manual_hooks = [] + for module in layer.modules(): + vars(module).pop('_cuda_graph_static_tensor_refs', None) + def _start_capturing(self): """ Start capturing CUDA Graphs. @@ -2662,52 +2709,110 @@ def _finish_capturing(self, start_time): self._capture_finished = True + def _abort_capturing(self, captured_graphs): + """Best-effort, non-collective cleanup that preserves the capture error.""" + _set_capture_end() + self._graphs_created = False + if FREEZE_GC: + with suppress(BaseException): + gc.unfreeze() + + for graph in chain.from_iterable(captured_graphs.values()): + with suppress(BaseException): + graph.reset() + + for layers in self.callables_per_chunk: + for layer in layers: + with suppress(BaseException): + self._clear_cuda_graph_state(layer) + + if self.config.fine_grained_activation_offloading: + with suppress(BaseException): + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + FineGrainedActivationOffloadingInterface as off_interface, + ) + + off_interface.reset() + + for cleanup in (self._reset_after_capture, gc.collect, torch.cuda.empty_cache): + with suppress(BaseException): + cleanup() + + def _install_captured_graphs(self, capture_contexts, captured_graphs): + """Install completed graph-bank variants on their corresponding layers.""" + num_layers_accumulated = 0 + for layers in self.callables_per_chunk: + for layer_number, layer in enumerate(layers): + base = ( + (num_layers_accumulated + layer_number) * self.num_microbatches + if self.config.overlap_moe_expert_parallel_comm + else num_layers_accumulated * self.num_microbatches + layer_number + ) + stride = 1 if self.config.overlap_moe_expert_parallel_comm else len(layers) + for cp_size, _ in capture_contexts: + graphs = captured_graphs[cp_size] + layer_graphs = [ + graphs[base + batch_number * stride] + for batch_number in range(self.num_microbatches) + ] + + layer.cuda_graphs = layer_graphs + if cp_size is not None: + layer.cuda_graphs_by_dynamic_cp_size[cp_size] = layer_graphs + num_layers_accumulated += len(layers) + + self._graphs_created = True + def create_cudagraphs(self): """ Capture CUDA Graphs per TransformerLayer per microbatch. """ - start_time = self._start_capturing() - - if not self.flattened_callables: - # Check if there are any graphable layers. If not, log a warning and skip capture, - # but still call _finish_capturing to ensure all ranks complete the capture phase. - logger.warning( - 'TECudaGraphHelper: No graphable layers found. Skipping CUDA graph capture.' - ) - else: - # Prepare CUDA Graph capturing input data and call `make_graphed_callables`. - sample_args, kwargs = self._get_cuda_graph_input_data() - if self.config.sequence_parallel: - rng_context = get_cuda_rng_tracker().fork() - else: - rng_context = nullcontext() - with rng_context: - graphs = make_graphed_callables( - tuple(self.flattened_callables), sample_args, **kwargs + captured_graphs = {} + try: + start_time = self._start_capturing() + has_graphable_layers = bool(self.flattened_callables) + if not has_graphable_layers: + # Still participate in every schedule collective and barrier so + # PP stages with graphable layers cannot deadlock. + logger.warning( + 'TECudaGraphHelper: No graphable layers found. Skipping CUDA graph capture.' ) - # Push the captured graphs to the corresponding TransformerBlock. - num_layers_accumulated = 0 - for layers in self.callables_per_chunk: - for layer_number, layer in enumerate(layers): - layer.cuda_graphs = [] - for batch_number in range(self.num_microbatches): - if self.config.overlap_moe_expert_parallel_comm: - graph_idx = ( - num_layers_accumulated + layer_number - ) * self.num_microbatches + batch_number + # Keep variants helper-local so later captures cannot replay earlier ones. + capture_contexts = self._get_dynamic_cp_capture_contexts() + try: + for capture_idx, (cp_size, cp_group) in enumerate(capture_contexts): + if cp_size is not None: + self.config._cuda_graph_capture_dynamic_cp = (cp_size, cp_group) + + # This call contains PP-group collectives when dynamic graph + # slots are enabled, so all PP stages must participate. + sample_args, kwargs = self._get_cuda_graph_input_data() + if has_graphable_layers: + if self.config.sequence_parallel: + rng_context = get_cuda_rng_tracker().fork() else: - graph_idx = ( - num_layers_accumulated * self.num_microbatches - + batch_number * len(layers) - + layer_number + rng_context = nullcontext() + with rng_context: + captured_graphs[cp_size] = make_graphed_callables( + tuple(self.flattened_callables), sample_args, **kwargs ) - layer.cuda_graphs.append(graphs[graph_idx]) - num_layers_accumulated += len(layers) - self._graphs_created = True + if capture_idx + 1 < len(capture_contexts): + torch.cuda.synchronize() + torch.distributed.barrier() + self._reset_after_capture() + finally: + vars(self.config).pop('_cuda_graph_capture_dynamic_cp', None) - self._finish_capturing(start_time) + if has_graphable_layers: + self._install_captured_graphs(capture_contexts, captured_graphs) + self._finish_capturing(start_time) + if self.config.dynamic_context_parallel and self.pp_group.size() > 1: + self.config._cuda_graph_num_microbatches = self.num_microbatches + except BaseException: + self._abort_capturing(captured_graphs) + raise def cuda_graph_set_manual_hooks(self): """ @@ -2729,14 +2834,17 @@ def delete_cuda_graphs(self): graphs_reset, graphs_not_reset = 0, 0 for layers in self.callables_per_chunk: for layer in layers: - for graph in layer.cuda_graphs: + graph_bank = layer.cuda_graphs_by_dynamic_cp_size + graph_lists = graph_bank.values() if graph_bank else (layer.cuda_graphs,) + for graph in chain.from_iterable(graph_lists): if graph_resettable: graph.reset() graphs_reset += 1 else: graphs_not_reset += 1 - layer.cuda_graphs = [] - layer.cuda_graph_manual_hooks = [] + self._clear_cuda_graph_state(layer) + + vars(self.config).pop('_cuda_graph_num_microbatches', None) log_on_each_pipeline_stage( logger=logger, diff --git a/megatron/core/transformer/module.py b/megatron/core/transformer/module.py index c5211e3e6d6..4558a9f39b6 100644 --- a/megatron/core/transformer/module.py +++ b/megatron/core/transformer/module.py @@ -183,6 +183,8 @@ def __init__(self, config: TransformerConfig, vp_stage: Optional[int] = None): # script with the graphs returned by make_graphed_callables API before the first # training step. self.cuda_graphs = [] + # DCP communicators are capture-time constants, so keep one graph list per CP size. + self.cuda_graphs_by_dynamic_cp_size = {} # List to store forward pre-hooks. Forward pre-hooks are not captured into CUDA # graphs. Those hooks and args are collected in this list and should be manually # triggered before CUDA Graph running. This is required to ensure the correct param diff --git a/megatron/core/transformer/multi_latent_attention.py b/megatron/core/transformer/multi_latent_attention.py index 0523096bea7..44b7c75a68e 100644 --- a/megatron/core/transformer/multi_latent_attention.py +++ b/megatron/core/transformer/multi_latent_attention.py @@ -134,6 +134,18 @@ class MultiLatentAttention(Attention): "cross attn" specializations. """ + def _retain_cuda_graph_rope_tensors(self, *rope_tensors: Optional[torch.Tensor]) -> None: + """Retain MLA-internal RoPE allocations used by TE CUDA graphs.""" + if getattr(self.config, 'cuda_graph_impl', 'none') != 'transformer_engine': + return + + from megatron.core.transformer.cuda_graphs import is_graph_capturing + + if not is_graph_capturing(): + return + refs = self.__dict__.setdefault('_cuda_graph_static_tensor_refs', {}) + refs.update((id(tensor), tensor) for tensor in rope_tensors if tensor is not None) + def __init__( self, config: MLATransformerConfig, @@ -719,6 +731,8 @@ def get_query_key_value_tensors( rotary_seq_len, packed_seq=thd_packed_seq ) + self._retain_cuda_graph_rope_tensors(rotary_pos_emb, rotary_pos_cos, rotary_pos_sin) + if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd': if packed_seq_params.cu_seqlens_q_padded is not None: cu_seqlens_q = packed_seq_params.cu_seqlens_q_padded diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 8f01db519de..53e8c6e9ffe 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -3077,6 +3077,14 @@ def _scope_to_str(s): f"({self.max_seqlen_per_dp_cp_rank}), got {self.pad_packed_seq_alignment}." ) + if self.dynamic_context_parallel and self.cuda_graph_impl != "none": + if self.cuda_graph_impl != "transformer_engine": + raise ValueError("Dynamic CP supports only layer-wise TE CUDA graphs.") + if not self.cuda_graph_dynamic_microbatches: + raise ValueError("Dynamic CP CUDA graphs require dynamic microbatch slots.") + if self.delay_wgrad_compute or self.overlap_moe_expert_parallel_comm: + raise ValueError("Dynamic CP graphs do not support delayed wgrad or EP overlap.") + if self.sequence_packing_scheduler is not None: # Check TE version. if not HAVE_PACKAGING: diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index d3571152966..0c1d3369e02 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -1181,6 +1181,12 @@ def get_qkv_layer_norm_weights(self) -> Tensor: """ return self.self_attention.linear_qkv.layer_norm_weight.data + def _get_thd_cuda_graph_capture_cp(self): + """Return the CP size/group whose constants are being captured.""" + if self.config.dynamic_context_parallel: + return self.config._cuda_graph_capture_dynamic_cp + return self.config.context_parallel_size, None + def get_layer_static_inputs(self, seq_length, micro_batch_size): """ Get the static inputs for the transformer layer. Besides the hidden_states that is @@ -1209,7 +1215,8 @@ def get_layer_static_inputs(self, seq_length, micro_batch_size): # cu_seqlens = [0, max_T, max_T, ..., max_T] # which represents a single packed sequence followed by zero-length # entries. cu_seqlens_q / kv / *_padded all share this layout. - max_T = self.config.max_seqlen_per_dp_cp_rank * self.config.context_parallel_size + capture_cp_size, _ = self._get_thd_cuda_graph_capture_cp() + max_T = self.config.max_seqlen_per_dp_cp_rank * capture_cp_size max_num_seqs = self.config.thd_max_packed_sequences cu_seqlens = torch.zeros(max_num_seqs + 1, dtype=torch.int32, device=device) cu_seqlens[1:] = max_T @@ -1318,7 +1325,8 @@ def _reconstruct_packed_seq_params_from_kwargs(self, kwargs): """ if 'cu_seqlens_q' not in kwargs: return - max_seqlen = self.config.max_seqlen_per_dp_cp_rank * self.config.context_parallel_size + capture_cp_size, capture_cp_group = self._get_thd_cuda_graph_capture_cp() + max_seqlen = self.config.max_seqlen_per_dp_cp_rank * capture_cp_size packed_seq_params = PackedSeqParams( qkv_format='thd', cu_seqlens_q=kwargs.pop('cu_seqlens_q'), @@ -1327,10 +1335,32 @@ def _reconstruct_packed_seq_params_from_kwargs(self, kwargs): cu_seqlens_kv_padded=kwargs.pop('cu_seqlens_kv_padded'), max_seqlen_q=max_seqlen, max_seqlen_kv=max_seqlen, - pad_between_seqs=False, + local_cp_size=(capture_cp_size if self.config.dynamic_context_parallel else None), + cp_group=(capture_cp_group if self.config.dynamic_context_parallel else None), + # Runtime THD batches carry distinct logical and padded boundaries. + # Capture must select the same TE context-parallel path as eager. + pad_between_seqs=True, ) kwargs['packed_seq_params'] = packed_seq_params + def _activate_dynamic_cp_cuda_graph(self, packed_seq_params): + """Select the graph-bank entry for a runtime DCP microbatch.""" + graph_bank = self.cuda_graphs_by_dynamic_cp_size + if not graph_bank: + return + + assert packed_seq_params is not None and packed_seq_params.local_cp_size is not None + dynamic_cp_size = int(packed_seq_params.local_cp_size) + assert dynamic_cp_size in graph_bank, ( + f"No layer CUDA graph bank entry for local_cp_size={dynamic_cp_size}; " + f"available sizes are {sorted(graph_bank)}." + ) + expected_group = parallel_state.get_dynamic_data_context_parallel_groups( + group_size=dynamic_cp_size + ) + assert packed_seq_params.cp_group is expected_group + self.cuda_graphs = graph_bank[dynamic_cp_size] + def _te_cuda_graph_capture(self, *args, **kwargs): """ CUDA Graph capture for this layer using TE interface. @@ -1341,6 +1371,14 @@ def _te_cuda_graph_capture(self, *args, **kwargs): For THD format, PackedSeqParams is reconstructed from tensor kwargs. """ self._reconstruct_packed_seq_params_from_kwargs(kwargs) + if self.config.dynamic_context_parallel and kwargs.get('packed_seq_params') is None: + # Scopes such as moe_router may leave attention eager and therefore + # have no cu_seqlens tensor inputs. The router still needs the DCP + # group for its captured auxiliary-loss reduction. + capture_cp_size, capture_cp_group = self._get_thd_cuda_graph_capture_cp() + kwargs['packed_seq_params'] = PackedSeqParams( + qkv_format='thd', local_cp_size=capture_cp_size, cp_group=capture_cp_group + ) # Record the backward event on cuda graph stream in backward pass. # This is to ensure the main stream waits for computing on cuda graph stream to complete, @@ -1404,6 +1442,8 @@ def _te_cuda_graph_replay(self, *args, **kwargs): Hence, `inference_context` and `packed_seq_params` are excluded from input list. For THD format, PackedSeqParams is decomposed into individual tensor kwargs. """ + self._activate_dynamic_cp_cuda_graph(kwargs.get('packed_seq_params')) + eager_packed_seq_params = kwargs.get('packed_seq_params') context = None padding_mask = kwargs.get("padding_mask", None) if ( @@ -1433,7 +1473,9 @@ def _te_cuda_graph_replay(self, *args, **kwargs): self.off_interface.enter_replay() try: - return self._te_cuda_graph_replay_impl(args, kwargs, context) + return self._te_cuda_graph_replay_impl( + args, kwargs, context, eager_packed_seq_params=eager_packed_seq_params + ) finally: if self.config.delay_offload_until_cuda_graph: self.off_interface.exit_replay() @@ -1490,7 +1532,7 @@ def resume_moe_experts_after_partial_cudagraph(self, cuda_graph_output): nvtx_range_pop(suffix="mlp") return mlp_output_with_bias - def _te_cuda_graph_replay_impl(self, args, kwargs, context): + def _te_cuda_graph_replay_impl(self, args, kwargs, context, eager_packed_seq_params=None): """Implementation of _te_cuda_graph_replay, separated for replay mode cleanup.""" cuda_graph_output = list(super()._te_cuda_graph_replay(*args, **kwargs)) @@ -1606,7 +1648,7 @@ def _te_cuda_graph_replay_impl(self, args, kwargs, context): hidden_states, padding_mask=kwargs.get("padding_mask", None), input_ids=kwargs.get("input_ids", None), - packed_seq_params=kwargs.get("packed_seq_params", None), + packed_seq_params=eager_packed_seq_params, ) return output, context @@ -2223,7 +2265,7 @@ def _forward_post_mlp_with_fused_hyper_connection( ) return output - def _te_cuda_graph_replay_impl(self, args, kwargs, context): + def _te_cuda_graph_replay_impl(self, args, kwargs, context, eager_packed_seq_params=None): """Implementation of _te_cuda_graph_replay with hyper connection support. Overrides the parent's _te_cuda_graph_replay_impl so that the @@ -2317,7 +2359,7 @@ def _te_cuda_graph_replay_impl(self, args, kwargs, context): output = self._forward_mlp( *cuda_graph_output, input_ids=kwargs.get("input_ids", None), - packed_seq_params=kwargs.get("packed_seq_params", None), + packed_seq_params=eager_packed_seq_params, ) return output, context diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 5125c5abc92..c0ff66093ce 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1583,6 +1583,11 @@ def validate_args(args, defaults={}): f'min_dynamic_context_parallel_size ({args.min_dynamic_context_parallel_size}) ' f'must be <= dp_size * cp_size ({dp_cp_size})' ) + if ( + args.cuda_graph_impl == 'transformer_engine' + and args.step_batch_size_schedule is not None + ): + raise ValueError('Dynamic CP CUDA graphs do not support step_batch_size_schedule.') import warnings diff --git a/megatron/training/training.py b/megatron/training/training.py index ae6216c260b..15878856d23 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -3790,6 +3790,9 @@ def trace_handler(p): # Initialize CUDA Graphs helper. if args.cuda_graph_impl == "transformer_engine": + config.thd_max_subsamples_per_item = ( + 1 if args.use_varlen_dataset else args.thd_max_packed_sequences + ) cuda_graph_helper = TECudaGraphHelper( model=model, config=config, diff --git a/tests/unit_tests/transformer/test_thd_cuda_graph.py b/tests/unit_tests/transformer/test_thd_cuda_graph.py index 92a81b2fcdc..e6dde90a049 100644 --- a/tests/unit_tests/transformer/test_thd_cuda_graph.py +++ b/tests/unit_tests/transformer/test_thd_cuda_graph.py @@ -25,7 +25,9 @@ import re import socket import subprocess +import weakref from pathlib import Path +from types import SimpleNamespace import pytest import torch @@ -268,18 +270,16 @@ def teardown_method(self): @pytest.mark.internal @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_generic_alignment_appends_dummy_padding_sequence(self): - """Generic THD padding covers tail slots with an independent dummy sequence.""" + """Generic THD padding covers tail slots with a zero-valid-token dummy sequence.""" seqlens, total_T = [50, 30], 80 psp = _make_psp(seqlens) - orig = psp.cu_seqlens_q.clone() p_tok, _, _, _, p, mask = pad_sequence_for_thd( torch.ones(1, total_T, device="cuda"), None, None, None, psp, alignment=64 ) assert p_tok.shape == (1, 128) - expected = torch.cat((orig, torch.tensor([128], dtype=orig.dtype, device=orig.device))) - assert torch.equal(p.cu_seqlens_q, expected) - assert torch.equal(p.cu_seqlens_q_padded, expected) - assert p.pad_between_seqs is False + assert p.cu_seqlens_q.tolist() == [0, 50, 80, 80] + assert p.cu_seqlens_q_padded.tolist() == [0, 50, 80, 128] + assert p.pad_between_seqs is True assert mask.shape == (1, 128) assert not mask[0, :total_T].any() and mask[0, total_T:].all() @@ -291,18 +291,19 @@ def test_cp_alignment_uses_global_cu_seqlens_length(self): Utils.initialize_model_parallel(tensor_model_parallel_size=1, context_parallel_size=2) psp = _make_psp([140]) + psp.cu_seqlens_q_padded[-1] = psp.cu_seqlens_kv_padded[-1] = 160 + psp.max_seqlen_q = psp.max_seqlen_kv = 160 local_T = 80 p_tok, _, _, _, p, mask = pad_sequence_for_thd( torch.ones(1, local_T, device="cuda"), None, None, None, psp, alignment=128 ) assert p_tok.shape[-1] >= local_T - assert p.cu_seqlens_q[-1].item() == 256 + assert p.cu_seqlens_q[-1].item() == 140 assert p.cu_seqlens_q_padded[-1].item() == 256 - assert p.max_seqlen_q == 140 - assert p.max_seqlen_kv == 140 + assert p.max_seqlen_q == 160 + assert p.max_seqlen_kv == 160 assert mask.shape[-1] == p_tok.shape[-1] - assert not mask[0, :local_T].any() @pytest.mark.internal @_REQUIRES_TWO_RANKS @@ -318,7 +319,7 @@ def test_cp_alignment_covers_local_padding_tail(self): ) assert p_tok.shape[-1] == 1664 - assert p.cu_seqlens_q[-1].item() == 3328 + assert p.cu_seqlens_q[-1].item() == 3200 assert p.cu_seqlens_q_padded[-1].item() == 3328 assert mask.shape[-1] == p_tok.shape[-1] assert not mask[0, :local_T].any() @@ -375,16 +376,19 @@ def test_shapes_and_data_preservation(self): p_params.cu_seqlens_kv_padded, ): assert cu.shape[0] == max_num_seqs + 1 - expected_cu = torch.tensor( + expected_actual_cu = torch.tensor( + [0, 100, 150, 180, 180, 180, 180, 180, 180], dtype=torch.int32, device="cuda" + ) + expected_padded_cu = torch.tensor( [0, 100, 150, 180, 256, 256, 256, 256, 256], dtype=torch.int32, device="cuda" ) - assert torch.equal(p_params.cu_seqlens_q, expected_cu) - assert torch.equal(p_params.cu_seqlens_kv, expected_cu) - assert torch.equal(p_params.cu_seqlens_q_padded, expected_cu) - assert torch.equal(p_params.cu_seqlens_kv_padded, expected_cu) + assert torch.equal(p_params.cu_seqlens_q, expected_actual_cu) + assert torch.equal(p_params.cu_seqlens_kv, expected_actual_cu) + assert torch.equal(p_params.cu_seqlens_q_padded, expected_padded_cu) + assert torch.equal(p_params.cu_seqlens_kv_padded, expected_padded_cu) assert p_params.max_seqlen_q == max_seqlen assert p_params.max_seqlen_kv == max_seqlen - assert p_params.pad_between_seqs is False + assert p_params.pad_between_seqs is True assert p_mask.shape == (1, max_seqlen) and p_mask.dtype == torch.bool assert torch.equal(p_tok[0, :total_T], tokens[0]) assert (p_tok[0, total_T:] == 0).all() @@ -392,10 +396,11 @@ def test_shapes_and_data_preservation(self): @pytest.mark.internal @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_eager_pad_to_max_adds_dummy_padding_sequence(self): - """Eager pad-to-max represents the tail as an independent dummy sequence.""" + """Eager pad-to-max represents the tail as a zero-valid-token dummy sequence.""" seqlens, total_T, target_len = [50, 30], 80, 8192 psp = _make_psp(seqlens) - orig_cu = psp.cu_seqlens_q.clone() + psp.cu_seqlens_q_padded = None + psp.cu_seqlens_kv_padded = None alignment, pad_target_len, max_num_seqs = get_thd_padding_kwargs( pad_packed_seq_alignment="max", max_seqlen_per_dp_cp_rank=target_len, @@ -415,16 +420,12 @@ def test_eager_pad_to_max_adds_dummy_padding_sequence(self): ) assert p_tok.shape == (1, target_len) - expected = torch.cat( - (orig_cu, torch.tensor([target_len], dtype=orig_cu.dtype, device=orig_cu.device)) - ) - assert torch.equal(p_params.cu_seqlens_q, expected) - assert torch.equal(p_params.cu_seqlens_q_padded, expected) - assert p_params.cu_seqlens_q.shape[0] == orig_cu.shape[0] + 1 + assert p_params.cu_seqlens_q.tolist() == [0, 50, 80, 80] + assert p_params.cu_seqlens_q_padded.tolist() == [0, 50, 80, target_len] assert p_params.max_seqlen_q == target_len - total_T assert p_params.max_seqlen_kv == target_len - total_T assert p_params.total_tokens == target_len - assert p_params.pad_between_seqs is False + assert p_params.pad_between_seqs is True assert p_mask.shape == (1, target_len) assert not p_mask[0, :total_T].any() assert p_mask[0, total_T:].all() @@ -489,10 +490,10 @@ def test_cu_seqlens_fill_value(self): max_num_seqs=32, ) assert p.cu_seqlens_q[0] == 0 and p.cu_seqlens_q[2] == 80 - assert (p.cu_seqlens_q[3:] == 128).all() + assert (p.cu_seqlens_q[3:] == 80).all() assert p.cu_seqlens_q_padded[0] == 0 and p.cu_seqlens_q_padded[2] == 80 assert (p.cu_seqlens_q_padded[3:] == 128).all() - assert p.pad_between_seqs is False + assert p.pad_between_seqs is True @pytest.mark.internal @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -540,7 +541,7 @@ def test_round_trip(self): layer._reconstruct_packed_seq_params_from_kwargs(kw) r = kw['packed_seq_params'] assert r.qkv_format == 'thd' and r.max_seqlen_q == 128 - assert r.pad_between_seqs is False + assert r.pad_between_seqs is True for k, v in orig.items(): assert torch.equal(getattr(r, k), v) @@ -581,6 +582,52 @@ def test_thd_static_padding_mask_is_unmasked_for_capture(self): class TestDynamicMicrobatchSlots: + @pytest.mark.internal + def test_dynamic_cp_graph_bank_and_capture_contexts(self, monkeypatch): + from megatron.core import parallel_state + from megatron.core.transformer.cuda_graphs import TECudaGraphHelper + + groups = {size: object() for size in (1, 2, 4, 8)} + monkeypatch.setattr( + parallel_state, + 'get_dynamic_data_context_parallel_groups', + lambda group_size: groups[group_size], + ) + helper = TECudaGraphHelper.__new__(TECudaGraphHelper) + helper.config = SimpleNamespace( + dynamic_context_parallel=True, min_dynamic_context_parallel_size=1 + ) + helper.dp_cp_group = SimpleNamespace(size=lambda: 8) + assert helper._get_dynamic_cp_capture_contexts() == [ + (size, groups[size]) for size in (8, 4, 2, 1) + ] + + bank = {size: [f'cp{size}'] for size in groups} + layer = SimpleNamespace(cuda_graphs=[], cuda_graphs_by_dynamic_cp_size=bank) + params = SimpleNamespace(local_cp_size=4, cp_group=groups[4]) + TransformerLayer._activate_dynamic_cp_cuda_graph(layer, params) + assert layer.cuda_graphs is bank[4] + + @pytest.mark.internal + def test_mla_rope_tensor_follows_te_graph_lifetime(self, monkeypatch): + from megatron.core.transformer import cuda_graphs + from megatron.core.transformer.cuda_graphs import TECudaGraphHelper + from megatron.core.transformer.multi_latent_attention import MLASelfAttention + + mla = MLASelfAttention.__new__(MLASelfAttention) + torch.nn.Module.__init__(mla) + mla.config = SimpleNamespace(cuda_graph_impl='transformer_engine') + monkeypatch.setattr(cuda_graphs, 'is_graph_capturing', lambda: True) + + tensor = torch.ones(1) + tensor_ref = weakref.ref(tensor) + mla._retain_cuda_graph_rope_tensors(tensor) + del tensor + assert tensor_ref() is not None + + TECudaGraphHelper._clear_cuda_graph_state(mla) + assert tensor_ref() is None + @pytest.mark.internal def test_pp2_slots_track_max_outstanding_microbatches(self): from megatron.core.transformer.cuda_graphs import TECudaGraphHelper