diff --git a/tests/train/test_trainer.py b/tests/train/test_trainer.py index 1dcc81a5b..4295dd54e 100644 --- a/tests/train/test_trainer.py +++ b/tests/train/test_trainer.py @@ -112,6 +112,7 @@ def create_pg(self, device): @patch("xtuner.v1.train.trainer.is_hf_model_path", Mock(return_value=True)) @patch("xtuner.v1.train.trainer.Trainer.build_engine", Mock(side_effect=lambda *args, **kwargs: FakeEngine())) + @patch("xtuner.v1.train.trainer.Trainer._prepare_model_input", Mock(return_value=[])) @prepare def test_save_hf_interval(self): """Test save_hf is called at correct intervals during training.""" @@ -184,6 +185,7 @@ def test_save_hf_interval(self): @patch("xtuner.v1.train.trainer.is_hf_model_path", Mock(return_value=True)) @patch("xtuner.v1.train.trainer.Trainer.build_engine", Mock(side_effect=lambda *args, **kwargs: FakeEngine())) + @patch("xtuner.v1.train.trainer.Trainer._prepare_model_input", Mock(return_value=[])) @prepare def test_save_checkpoint_interval(self): self.create_pg(DEVICE) @@ -258,6 +260,7 @@ def test_save_checkpoint_interval(self): @patch("xtuner.v1.train.trainer.is_hf_model_path", Mock(return_value=True)) @patch("xtuner.v1.train.trainer.Trainer.build_engine", Mock(side_effect=lambda *args, **kwargs: FakeEngine())) + @patch("xtuner.v1.train.trainer.Trainer._prepare_model_input", Mock(return_value=[])) @prepare def test_resume(self): self.create_pg(DEVICE) @@ -738,6 +741,7 @@ def __call__(self, checkpoint, step, epoch, total_step, total_epoch): assert len(loaded.get_hooks(HookStage.AFTER_SAVE_DCP)) == 1 +@patch("xtuner.v1.train.trainer.Trainer._prepare_model_input", Mock(return_value=[])) @patch("xtuner.v1.train.trainer.Trainer.build_engine", Mock(side_effect=lambda *args, **kwargs: FakeEngine())) def test_resume_and_load_checkpoint_cfg(tmp_path: Path): # 0. 
prepare environment diff --git a/xtuner/v1/model/__init__.py b/xtuner/v1/model/__init__.py index b0744864f..32cc7fa69 100644 --- a/xtuner/v1/model/__init__.py +++ b/xtuner/v1/model/__init__.py @@ -23,7 +23,7 @@ from .dense.qwen3 import Qwen3Dense0P6BConfig, Qwen3Dense4BConfig, Qwen3Dense8BConfig, Qwen3DenseConfig from .moe.deepseek_v3 import DeepSeekV3Config from .moe.gpt_oss import GptOss21BA3P6Config, GptOss117BA5P8Config, GptOssConfig -from .moe.moe import BalancingLossConfig, MoE, MoEModelOutputs, ZLossConfig +from .moe.moe import BalancingLossConfig, MoE, MoEConfig, MoEModelOutputs, ZLossConfig from .moe.qwen3 import Qwen3MoE30BA3Config, Qwen3MoEConfig, Qwen3MoEFoPEConfig @@ -87,6 +87,7 @@ def get_model_config_from_hf(model_path: Path): "get_model_config", "get_model_config_from_hf", "MoE", + "MoEConfig", "MoEModelOutputs", "BalancingLossConfig", "ZLossConfig", diff --git a/xtuner/v1/model/base.py b/xtuner/v1/model/base.py index 92d39caee..cb3eb3ae4 100644 --- a/xtuner/v1/model/base.py +++ b/xtuner/v1/model/base.py @@ -1374,6 +1374,9 @@ def _load_fused_hf_param( continue _loaded_tensor.append(weight.to(local_tensor.device)) + if not _loaded_tensor: + return missing_keys + if not hf_keys: # fp8 pad assert self.config.float8_cfg is not None diff --git a/xtuner/v1/model/compose/qwen3_5/qwen3_5_config.py b/xtuner/v1/model/compose/qwen3_5/qwen3_5_config.py index 884cd48b1..5bded9d90 100644 --- a/xtuner/v1/model/compose/qwen3_5/qwen3_5_config.py +++ b/xtuner/v1/model/compose/qwen3_5/qwen3_5_config.py @@ -1,4 +1,4 @@ -from xtuner.v1.model.base import TransformerConfig +from xtuner.v1.model.moe.moe import MoEConfig from xtuner.v1.model.moe.qwen3_5_text import Qwen3_5_VLTextMoE35BA3BConfig from xtuner.v1.utils import get_logger @@ -19,7 +19,7 @@ class Qwen3_5_ProjectorConfig(Qwen3VLProjectorConfig): class Qwen3_5_BaseConfig(Qwen3VLBaseConfig): vision_config: Qwen3_5_VisionConfig projector_config: Qwen3_5_ProjectorConfig - text_config: TransformerConfig + text_config: 
MoEConfig image_token_id: int = 248056 video_token_id: int = 248057 @@ -30,4 +30,4 @@ class Qwen3_5_BaseConfig(Qwen3VLBaseConfig): class Qwen3_5_VLMoE35BA3Config(Qwen3_5_BaseConfig): vision_config: Qwen3_5_VisionConfig = Qwen3_5_VisionConfig() projector_config: Qwen3_5_ProjectorConfig = Qwen3_5_ProjectorConfig() - text_config: TransformerConfig = Qwen3_5_VLTextMoE35BA3BConfig() + text_config: MoEConfig = Qwen3_5_VLTextMoE35BA3BConfig() diff --git a/xtuner/v1/model/dense/qwen3vl_text.py b/xtuner/v1/model/dense/qwen3vl_text.py index 15e40c40b..308a9817d 100644 --- a/xtuner/v1/model/dense/qwen3vl_text.py +++ b/xtuner/v1/model/dense/qwen3vl_text.py @@ -1,9 +1,10 @@ import re import torch +import torch.nn.functional as F from xtuner.v1.data_proto import SequenceContext -from xtuner.v1.loss import CELossContext +from xtuner.v1.loss import BaseLossContext from xtuner.v1.model.base import ModelOutputs from .qwen3 import Qwen3Dense, Qwen3Dense4BConfig, Qwen3Dense8BConfig @@ -34,10 +35,10 @@ def _deepstack_process( hidden_states[visual_pos_masks, :] = local_this return hidden_states - def forward( + def forward( # type: ignore[override] self, seq_ctx: SequenceContext, # todo(@yehaochen): support intra layer micro-batch - loss_ctx: CELossContext, + loss_ctx: dict[str, BaseLossContext | list[BaseLossContext]] | None = None, ) -> ModelOutputs: input_ids = seq_ctx.input_ids position_ids = seq_ctx.position_ids @@ -78,11 +79,18 @@ def forward( hidden_states = self.norm(hidden_states) - loss, (logits, extra_info) = self.lm_head(hidden_states, loss_ctx) - output["loss"] = loss - output["logits"] = logits - output["extra_info"] = extra_info - return ModelOutputs(**output) # type: ignore[typeddict-item] + if loss_ctx is None: + # Inference mode + logits = F.linear(hidden_states, self.lm_head.weight, self.lm_head.bias) + output["logits"] = logits + else: + # Training mode + loss, (logits, extra_info) = self.lm_head(hidden_states, loss_ctx["lm"]) # type: ignore[call-overload] + 
output["loss"] = loss + output["logits"] = logits + output["extra_info"] = extra_info + + return ModelOutputs(**output) class Qwen3VLTextDense4BConfig(Qwen3Dense4BConfig): diff --git a/xtuner/v1/model/moe/moe.py b/xtuner/v1/model/moe/moe.py index b62d5aae6..1657f4dbe 100644 --- a/xtuner/v1/model/moe/moe.py +++ b/xtuner/v1/model/moe/moe.py @@ -2,7 +2,7 @@ import os import types from pathlib import Path -from typing import Annotated, Literal, Self, Sequence, cast, TypedDict +from typing import TYPE_CHECKING, Annotated, Literal, Self, Sequence, TypedDict, cast import torch import torch.distributed as dist @@ -28,12 +28,10 @@ from xtuner.v1.data_proto import SequenceContext from xtuner.v1.float8.float8_handler import Float8Handler from xtuner.v1.loss import ( - BalancingLoss, BalancingLossConfig, BalancingLossContext, - BalancingLossKwargs, - CELossContext, - ZLoss, + BaseLossContext, + LMHeadLossContext, ZLossConfig, ZLossContext, ZLossKwargs, @@ -62,6 +60,7 @@ ) from xtuner.v1.module.decoder_layer.dense_decoder_layer import DenseDecoderLayer from xtuner.v1.module.decoder_layer.moe_decoder_layer import MoEActFnConfig, MoEBlock, MoEDecoderLayer +from xtuner.v1.module.mtp import MTPBlock, MTPConfig, MTPLayer, roll_packed_tensor from xtuner.v1.utils import ( get_device, get_logger, @@ -69,6 +68,10 @@ from xtuner.v1.utils.activation_offload import async_save_on_cpu +if TYPE_CHECKING: + from xtuner.v1.datasets.collator import ColateItem + + DEVICE = get_device() logger = get_logger() @@ -99,6 +102,7 @@ class MoEModelOutputs(ModelOutputs): balancing_loss: torch.Tensor | None = None z_loss: torch.Tensor | None = None tokens_per_expert_global: torch.Tensor + mtp_loss: torch.Tensor | None = None def free_nongrad_feature(self): """Release large intermediate tensors not needed for backward or @@ -124,6 +128,7 @@ class MoELossContextDict(TypedDict): lm: BaseLossContext balancing: BalancingLossContext | None z_loss: ZLossContext | None + mtp: list[BaseLossContext] | None class 
MoEConfig(TransformerConfig): @@ -144,6 +149,7 @@ class MoEConfig(TransformerConfig): gate_bias: bool = False moe_bias: bool = False moe_act_fn_cfg: MoEActFnConfig = MoEActFnConfig() + mtp_config: MTPConfig | None = None freeze_routers: bool = False def build(self) -> "MoE": @@ -187,6 +193,7 @@ def __init__(self, config: MoEConfig): self.layers = self.build_layers(config) self.rotary_emb = self.build_rotary_embedding(config) self.embed_tokens = self.build_embeddings(config) + self.mtp_block = self.build_mtp_block(config) if config.mtp_config is not None else None self.fp32_layers = [self.rotary_emb] @@ -301,6 +308,7 @@ def build_loss_ctx_batch( - "lm": LM loss context - "balancing": Balancing loss context (if configured) - "z_loss": Z-loss context (if configured) + - "mtp": MTP loss contexts (if configured) Note: Auxiliary loss contexts are built without parameters. @@ -309,13 +317,40 @@ def build_loss_ctx_batch( - z_loss_ctx(router_logits) """ # Build LM loss context - res = super().build_loss_ctx_batch(data_batch, sp_mesh) + _data_batch: list[dict] = data_batch # type: ignore[assignment] + res: list[dict] = super().build_loss_ctx_batch(_data_batch, sp_mesh) + cu_seq_lens_list = [data["seq_ctx"].cu_seq_lens_k for data in data_batch] # Add auxiliary losses - self._add_auxiliary_loss("balancing", self.config.balancing_loss_cfg, data_batch, res) - self._add_auxiliary_loss("z_loss", self.config.z_loss_cfg, data_batch, res) + self._add_auxiliary_loss("balancing", self.config.balancing_loss_cfg, _data_batch, res) + self._add_auxiliary_loss("z_loss", self.config.z_loss_cfg, _data_batch, res) + + # Add MTP loss contexts if MTP is enabled + if self.config.mtp_config is not None: + # Build MTP loss contexts using the same approach as LM loss + # Each MTP depth needs its own loss context + for mtp_idx in range(self.config.mtp_config.num_layers): + # MTP needs to shift labels multiple times. 
Since rebuild the `shifted_labels` in data_batch + mtp_loss_ctx_list = self._build_loss_ctx(self.config.lm_loss_cfg, _data_batch, sp_mesh) + if mtp_loss_ctx_list is not None: + loss_ctx_cls = mtp_loss_ctx_list[0].__class__ + mtp_loss_ctx_list = loss_ctx_cls.build_batches( + mtp_loss_ctx_list, cu_seq_lens_list=cu_seq_lens_list, sp_mesh=sp_mesh + ) + for i, mtp_loss_ctx in enumerate(mtp_loss_ctx_list): + if "mtp" not in res[i]: + res[i]["mtp"] = [] + res[i]["mtp"].append(mtp_loss_ctx) # type: ignore[union-attr] + + # Ensure all microbatches have mtp key + for loss_ctx_dict in res: + if "mtp" not in loss_ctx_dict: + loss_ctx_dict["mtp"] = None + else: + for loss_ctx_dict in res: + loss_ctx_dict["mtp"] = None - return res + return res # type: ignore[return-value] def forward( self, @@ -340,6 +375,11 @@ def forward( ) if loss_ctx is None: raise NotImplementedError("loss_ctx must be provided for intra-layer bsz > 1") + if self.mtp_block is not None: + raise NotImplementedError( + "MTP is not supported in micro-batch forward mode (intra_layer_micro_batch > 1). " + "Please set intra_layer_micro_batch=1 when using MTP." 
+ ) return self._micro_batch_forward( seq_ctx_list=seq_ctx, @@ -351,12 +391,8 @@ def post_micro_batch_forward(self, batch_outputs: Sequence[MoEModelOutputs]) -> base_info = super().post_micro_batch_forward(batch_outputs) logs_info = base_info["logs_info"] - tokens_per_expert_global = torch.zeros( - self.config.num_hidden_layers - self.config.first_k_dense_replace, - self.config.n_routed_experts, - dtype=torch.int64, - device=DEVICE, - ) + first_tokens_per_expert = batch_outputs[0]["tokens_per_expert_global"] + tokens_per_expert_global = torch.zeros_like(first_tokens_per_expert) for output in batch_outputs: tokens_per_expert_global += output["tokens_per_expert_global"] @@ -479,7 +515,7 @@ def _micro_batch_forward( # Extract LM loss context from dict lm_loss_ctx_list = [loss_ctx_dict["lm"] for loss_ctx_dict in loss_ctx_list] cat_loss_ctx = type(lm_loss_ctx_list[0]).cat(lm_loss_ctx_list) - loss, (logits, extra_info) = self.lm_head(cat_hidden_states, cat_loss_ctx) + loss, (logits, extra_info) = self.lm_head(cat_hidden_states, cast(LMHeadLossContext, cat_loss_ctx)) # Aggregate losses (mean across micro-batches) output["loss"] = loss.sum() @@ -628,6 +664,7 @@ def _forward( if self.config.return_hidden_states: output["hidden_states"].append(hidden_states) + layer_hidden_states = hidden_states hidden_states = self.norm(hidden_states) # Get LM loss context from dict @@ -637,6 +674,49 @@ def _forward( output["logits"] = logits output["extra_info"] = extra_info + # MTP forward pass and loss computation + if ( + self.mtp_block is not None + and loss_ctx is not None + and (mtp_loss_ctx_list := loss_ctx.get("mtp")) is not None + ): + mtp_seq_ctx = seq_ctx.copy( + input_ids=input_ids.clone() if input_ids is not None else None, + position_ids=position_ids.clone(), + inputs_embeds=seq_ctx.inputs_embeds.clone() if seq_ctx.inputs_embeds is not None else None, + ) + + # Forward through MTP block + mtp_outputs = self.mtp_block( + hidden_states=layer_hidden_states, + 
embed_tokens_fn=self.embed_tokens, + position_embeddings=position_embeddings, + seq_ctx=mtp_seq_ctx, + ) + + # Compute MTP losses for each depth + mtp_losses = torch.tensor(0.0, device=DEVICE) + for idx, (mtp_hidden, mtp_ctx) in enumerate(zip(mtp_outputs, mtp_loss_ctx_list)): + shifted_tensor = mtp_ctx.loss_kwargs.shifted_labels + mtp_ctx.loss_kwargs.shifted_labels = roll_packed_tensor( + shifted_tensor, seq_ctx.cu_seq_lens_k, -idx - 1, dim=-1 + ) + + mtp_hidden_states, mtp_router_results, mtp_router_weights = mtp_hidden + mtp_loss, _ = self.lm_head(mtp_hidden_states, cast(LMHeadLossContext, mtp_ctx)) + mtp_losses += mtp_loss + + output["router_logits"][f"mtp_layer{idx}"] = mtp_router_results + output["router_weights"][f"mtp_layer{idx}"] = mtp_router_weights + + # Average MTP losses across depths and scale + mtp_losses = mtp_losses / len(mtp_loss_ctx_list) + scaled_mtp_loss = mtp_losses * self.config.mtp_config.loss_scaling_factor # type: ignore + + # Add to total loss + output["loss"] = output["loss"] + scaled_mtp_loss + output["mtp_loss"] = scaled_mtp_loss + router_logits_list = list(output["router_logits"].values()) # type: ignore router_weights_list = list(output["router_weights"].values()) # type: ignore router_logits = self._select_non_pad_router_logits(router_logits_list, seq_ctx.mask) @@ -752,6 +832,74 @@ def build_rotary_embedding(self, config: MoEConfig) -> RotaryEmbeddingProtocol: with torch.device(DEVICE): return get_rope_embedding(config=config) + def build_mtp_block(self, config: MoEConfig) -> MTPBlock: + """Build MTP block with MoE decoder layers. + + Args: + config (MoEConfig): Model configuration. + + Returns: + MTPBlock: Constructed MTP block. 
+ """ + mtp_config = config.mtp_config + assert mtp_config is not None, "mtp_config must be provided" + + mtp_layers = [] + # Get attention config for MTP layers (use last layer's config) + last_layer_idx = config.num_hidden_layers - 1 + layers_type_list = config.layers_type + attention_config: MLAConfig | MHAConfig | GatedDeltaNetConfig + if layers_type_list[last_layer_idx] in ["full_attention", "sliding_attention"]: + attention_config = config.attention + elif layers_type_list[last_layer_idx] == "linear_attention": + assert config.linear_attention is not None, ( + "linear_attention config must be provided for linear_attention layer" + ) + attention_config = config.linear_attention + else: + raise ValueError(f"Unsupported layer type {layers_type_list[last_layer_idx]}") + + for i in range(mtp_config.num_layers): + # Build MoE decoder layer for MTP + decoder_layer = MoEDecoderLayer( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + moe_intermediate_size=config.moe_intermediate_size, + mlp_bias=config.mlp_bias, + gate_bias=config.gate_bias, + moe_bias=config.moe_bias, + hidden_act=config.hidden_act, + rms_norm_eps=config.rms_norm_eps, + rms_norm_type=config.rms_norm_type, + num_experts_per_tok=config.num_experts_per_tok, + n_routed_experts=config.n_routed_experts, + n_shared_experts=config.n_shared_experts, + with_shared_expert_gate=config.with_shared_expert_gate, + hidden_factor=config.hidden_factor, + layer_type=layers_type_list[last_layer_idx], + attention_config=attention_config, + rope_scaling_cfg=config.rope_scaling_cfg, + generate_config=config.generate_config, + router_config=config.router, + moe_act_fn_cfg=config.moe_act_fn_cfg, + float8_cfg=config.float8_cfg, + layer_idx=config.num_hidden_layers + i, + dispatcher=config.dispatcher, + ep_mesh=self.ep_mesh, + ) + + # Wrap decoder layer in MTPLayer + mtp_layer = MTPLayer( + hidden_size=config.hidden_size, + rms_norm_eps=config.rms_norm_eps, + rms_norm_type=config.rms_norm_type, 
+ decoder_layer=decoder_layer, + float8_cfg=config.float8_cfg, + ) + mtp_layers.append(mtp_layer) + + return MTPBlock(mtp_layers=mtp_layers) + @override def from_hf(self, hf_path: str | Path, strict: bool = True) -> tuple: # If model is built on meta device, we need to rebuild rotary embedding since from_hf will not @@ -815,18 +963,21 @@ def fully_shard( mp_policy = MixedPrecisionPolicy( param_dtype=self.fsdp_config.param_dtype, reduce_dtype=fsdp_config.reduce_dtype ) - num_recompute_layers = int(self.config.num_hidden_layers * self.fsdp_config.recompute_ratio) for layer_idx, layer in tqdm(self.layers.items(), desc="[FSDP Sharding]"): layer_idx = int(layer_idx) - if layer_idx < num_recompute_layers - 1: + if self._should_recompute( + layer_idx=layer_idx, + mtp_idx=None, + ): layer = checkpoint_wrapper(layer, checkpoint_impl=CheckpointImpl.REENTRANT) self.layers[str(layer_idx)] = layer - if layer_idx >= len(self.layers) - 1: + if layer_idx >= len(self.layers) - 1 and self.mtp_block is None: reshard_after_forward = False else: reshard_after_forward = self.fsdp_config.reshard_after_forward + fully_shard( layer, mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh, @@ -861,10 +1012,35 @@ def fully_shard( self.lm_head, mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh, mp_policy=lm_head_mp_policy, - reshard_after_forward=self.fsdp_config.reshard_after_forward, + reshard_after_forward=self.fsdp_config.reshard_after_forward if self.mtp_block is None else False, offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None, ) + # Shard MTP block if it exists + if self.mtp_block is not None: + for mtp_idx, mtp_layer in enumerate(self.mtp_block.layers): + if self._should_recompute(None, mtp_idx=mtp_idx): + mtp_layer = checkpoint_wrapper(mtp_layer, checkpoint_impl=CheckpointImpl.REENTRANT) + self.mtp_block.layers[mtp_idx] = mtp_layer + + reshard_after_forward = mtp_idx != len(self.mtp_block.layers) - 1 + fully_shard( + mtp_layer, + 
mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh, + mp_policy=mp_policy, + reshard_after_forward=reshard_after_forward, + offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None, + ) + if mtp_idx == 0: + layer_next.set_modules_to_forward_prefetch([mtp_layer]) # type: ignore + + if self.config.mtp_config is not None and self.config.mtp_config.num_layers > 0: + for prev_mtp_layer, next_mtp_layer in zip( + list(self.mtp_block.layers)[:-1], + list(self.mtp_block.layers)[1:], + ): + prev_mtp_layer.set_modules_to_forward_prefetch([next_mtp_layer]) # type: ignore + fully_shard( self, mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh, @@ -1015,6 +1191,62 @@ def patched_emb_forward(self, input): self.sparse, ) + def _should_recompute( + self, + layer_idx: int | None, + mtp_idx: int | None, + ) -> bool: + """Determine if a layer should use gradient checkpointing + (recomputation). + + The recomputation strategy treats decoder layers and MTP layers as a single + sequence. The recompute_ratio is applied to the total layer count. The last + layer in the entire model is never recomputed to avoid unnecessary overhead. + + Args: + layer_idx (int | None): Index of the decoder layer (0-based). None if this + is an MTP layer. + mtp_idx (int | None): Index of the MTP layer (0-based). None if this is a + decoder layer. + + Returns: + bool: True if the layer should use gradient checkpointing, False otherwise. 
+ + Example: + Configuration: 7 decoder layers, 3 MTP layers, recompute_ratio=0.8 + - Total layers: 10 + - Recompute layers: int(10 * 0.8) = 8 + - Layer mapping: + * Decoder 0-6 → global index 0-6 (7 layers) + * MTP 0-2 → global index 7-9 (3 layers) + - Recomputation decision: + * Global 0-7 (decoder 0-6, MTP 0): recompute ✓ + * Global 8 (MTP 1): no recompute + * Global 9 (MTP 2, last layer): no recompute (forced) + """ + num_layers = self.config.num_hidden_layers + mtp_layers = self.config.mtp_config.num_layers if self.config.mtp_config is not None else 0 + recompute_ratio = self.fsdp_config.recompute_ratio if self.fsdp_config is not None else 0.0 + + total_layers = num_layers + mtp_layers + num_recompute_layers = int(total_layers * recompute_ratio) + + # Determine the global layer index (0-based) + if layer_idx is not None: + # This is a decoder layer + global_idx = layer_idx + else: + # This is an MTP layer (comes after all decoder layers) + assert mtp_idx is not None, "Either layer_idx or mtp_idx must be provided" + global_idx = num_layers + mtp_idx + + # Last layer is never recomputed + if global_idx == total_layers - 1: + return False + + # Recompute if within the recompute range + return global_idx < num_recompute_layers + # NOTE: Add this overload for inferring the return type for easier type checking and using @overload # type: ignore def __call__( # type: ignore diff --git a/xtuner/v1/model/moe/qwen3_5_text.py b/xtuner/v1/model/moe/qwen3_5_text.py index 3afc4e5dd..74ac1c277 100644 --- a/xtuner/v1/model/moe/qwen3_5_text.py +++ b/xtuner/v1/model/moe/qwen3_5_text.py @@ -42,6 +42,57 @@ class Qwen3_5_VLTextMoE(Qwen3VLTextMoE): def to_hf_key_list(self, key: str) -> list[str]: + # Handle MTP parameters + if key.startswith("mtp_block."): + # Remove "mtp_block." 
prefix + key = key.replace("mtp_block.", "", 1) + + # Handle MTP layer-specific parameters + # xtuner: mtp_block.layers.{idx}.decoder_layer.{param} + # HF: mtp.layers.{idx}.{param} + key = re.sub(r"layers\.(\d+)\.decoder_layer\.", r"layers.\1.", key) + + # Handle MTP normalization layers + # xtuner: mtp_block.layers.{idx}.enorm -> HF: mtp.pre_fc_norm_embedding + # xtuner: mtp_block.layers.{idx}.hnorm -> HF: mtp.pre_fc_norm_hidden + # xtuner: mtp_block.layers.{idx}.final_layernorm -> HF: mtp.norm + # Note: Currently assuming single MTP layer (idx=0), may need adjustment for multiple layers + if ".enorm." in key: + key = re.sub(r"layers\.\d+\.enorm\.", "pre_fc_norm_embedding.", key) + elif ".hnorm." in key: + key = re.sub(r"layers\.\d+\.hnorm\.", "pre_fc_norm_hidden.", key) + elif ".final_layernorm." in key: + key = re.sub(r"layers\.\d+\.final_layernorm\.", "norm.", key) + + # Handle MTP projection layer + # xtuner: mtp_block.layers.{idx}.eh_proj -> HF: mtp.fc + if ".eh_proj." in key: + key = re.sub(r"layers\.\d+\.eh_proj\.", "fc.", key) + + # Handle MoE-specific transformations within MTP layers + key = re.sub(r"layers\.(\d+)\.(experts|gate|shared_experts|shared_expert_gate)", r"layers.\1.mlp.\2", key) + key = key.replace("shared_experts", "shared_expert") + + # Handle fused weights + n_routed_experts = self.config.n_routed_experts + if "fused_w1w3.weight" in key: + w1w3_keys: list[str] = [] + + for i in range(n_routed_experts): + w1w3_keys.append(key.replace("fused_w1w3.weight", f"{i}.gate_proj.weight")) + w1w3_keys.append(key.replace("fused_w1w3.weight", f"{i}.up_proj.weight")) + + return [f"mtp.{key}" for key in w1w3_keys] + + elif "fused_w2.weight" in key: + w2_keys: list[str] = [] + for i in range(n_routed_experts): + w2_keys.append(key.replace("fused_w2.weight", f"{i}.down_proj.weight")) + return [f"mtp.{key}" for key in w2_keys] + else: + return ["mtp." 
+ key] + + # Handle main model parameters if "layers" in key or "embed_tokens" in key: key = "model.language_model." + key @@ -85,13 +136,13 @@ def safetensors_to_params( else: loaded_tensor = safetensors[0] - if "fused_w1w3.weight" in param_name: + if "fused_w1w3.weight" in param_name and "mtp" not in param_name: # hf: num_experts, 2 * expert_dim, hidden_size # xtuner: num_experts * 2 * expert_dim, hidden_size # num_experts * 2 * expert_dim, hidden_size loaded_tensor = loaded_tensor.flatten(0, 1) - elif "fused_w2.weight" in param_name: + elif "fused_w2.weight" in param_name and "mtp" not in param_name: # hf: num_experts, hidden_size, expert_dim # xtuner: num_experts * hidden_size, expert_dim loaded_tensor = loaded_tensor.flatten(0, 1) @@ -116,6 +167,9 @@ def param_to_safetensor( safetensor: torch.Tensor, hf_param_name: str, ): + if "mtp" in hf_param_name: + return super().param_to_safetensor(safetensor, hf_param_name) + assert isinstance(hf_param_name, str) if "gate_up_proj" in hf_param_name: # xtuner: num_experts * 2 * expert_dim, hidden_size diff --git a/xtuner/v1/model/moe/qwen3vl_text.py b/xtuner/v1/model/moe/qwen3vl_text.py index 1996edcf7..e0fac316e 100644 --- a/xtuner/v1/model/moe/qwen3vl_text.py +++ b/xtuner/v1/model/moe/qwen3vl_text.py @@ -4,10 +4,9 @@ import torch from xtuner.v1.data_proto import SequenceContext -from xtuner.v1.loss import CELossContext from xtuner.v1.utils.activation_offload import async_save_on_cpu -from .moe import MoEModelOutputs +from .moe import MoELossContextDict, MoEModelOutputs from .qwen3 import Qwen3MoE, Qwen3MoE30BA3Config, Qwen3MoE235BA22Config @@ -112,9 +111,12 @@ def _deepstack_process( def _forward( self, seq_ctx: SequenceContext, # todo(@yehaochen): support intra layer micro-batch - loss_ctx: CELossContext | None, + loss_ctx: MoELossContextDict | None, return_router_logits: bool = False, ) -> MoEModelOutputs: + if seq_ctx.deepstack_visual_embeds is None: + return super()._forward(seq_ctx, loss_ctx, return_router_logits) 
+ input_ids = seq_ctx.input_ids position_ids = seq_ctx.position_ids @@ -183,7 +185,9 @@ def _forward( hidden_states = self.norm(hidden_states) - loss, (logits, extra_info) = self.lm_head(hidden_states, loss_ctx) # type: ignore + # Get LM loss context from dict + lm_loss_ctx = loss_ctx["lm"] if loss_ctx is not None else None + loss, (logits, extra_info) = self.lm_head(hidden_states, lm_loss_ctx) # type: ignore output["loss"] = loss output["logits"] = logits output["extra_info"] = extra_info @@ -193,17 +197,23 @@ def _forward( router_logits = self._select_non_pad_router_logits(router_logits_list, seq_ctx.mask) router_weights = self._select_non_pad_router_logits(router_weights_list, seq_ctx.mask) - if self.balancing_loss: - balancing_loss = self.balancing_loss( - router_weights=router_weights, - n_routed_experts=self.config.n_routed_experts, - num_experts_per_tok=self.config.num_experts_per_tok, - ) - output["balancing_loss"] = balancing_loss - - if self.z_loss: - z_loss = self.z_loss(router_logits=router_logits) - output["z_loss"] = z_loss + # Calculate balancing loss using loss context + if loss_ctx is not None: + balancing_ctx = loss_ctx.get("balancing") + if balancing_ctx is not None: + balancing_loss = balancing_ctx( + router_weights, + self.config.n_routed_experts, + self.config.num_experts_per_tok, + ) + output["balancing_loss"] = balancing_loss + + # Calculate z-loss using loss context + if loss_ctx is not None: + z_loss_ctx = loss_ctx.get("z_loss") + if z_loss_ctx is not None: + z_loss = z_loss_ctx(router_logits) + output["z_loss"] = z_loss tokens_per_expert_global = self._cal_tokens_per_expert(router_logits) output["tokens_per_expert_global"] = tokens_per_expert_global diff --git a/xtuner/v1/module/mtp/__init__.py b/xtuner/v1/module/mtp/__init__.py new file mode 100644 index 000000000..8ced4cbaa --- /dev/null +++ b/xtuner/v1/module/mtp/__init__.py @@ -0,0 +1,7 @@ +from .config import MTPConfig +from .mtp_block import MTPBlock +from .mtp_layer import MTPLayer 
+from .utils import roll_packed_tensor, roll_sequence_context + + +__all__ = ["MTPConfig", "MTPBlock", "MTPLayer", "roll_packed_tensor", "roll_sequence_context"] diff --git a/xtuner/v1/module/mtp/config.py b/xtuner/v1/module/mtp/config.py new file mode 100644 index 000000000..1d362afcb --- /dev/null +++ b/xtuner/v1/module/mtp/config.py @@ -0,0 +1,41 @@ +"""Configuration for Multi-Token Prediction (MTP).""" + +from typing import Annotated + +from cyclopts import Parameter +from pydantic import BaseModel, ConfigDict + + +class MTPConfig(BaseModel): + """Configuration for Multi-Token Prediction (MTP). + + MTP extends the prediction scope to multiple future tokens at each position, + creating denser training signals and potentially improving data efficiency. + + This config only contains training-related hyperparameters. The actual + construction of MTP layers (including choosing Dense vs MoE decoder layers) + is handled by the model (Dense/MoE) which knows how to create the appropriate + decoder layers. + + Args: + num_layers (int): Number of MTP layers (prediction depths). Each layer + predicts tokens at increasing future positions (i+1, i+2, ..., i+D). + loss_scaling_factor (float): Scaling factor for MTP loss. The total MTP loss + is computed as the average of losses across all depths, multiplied by + this factor. Default: 0.1. + + Example: + >>> # In model config + >>> config = TransformerConfig( + ... ..., + ... mtp_config=MTPConfig( + ... num_layers=2, + ... loss_scaling_factor=0.1, + ... ), + ... 
class MTPBlock(nn.Module):
    """Multi-Token Prediction (MTP) block containing multiple MTP layers.

    This block manages D sequential MTP layers, where each layer predicts
    a future token at increasing depths (i+1, i+2, ..., i+D).

    The k-th layer receives:
    - Hidden states from the (k-1)-th layer (or the main model for k=0)
    - Embeddings of tokens at position (i+k)

    This forms a sequential prediction chain where deeper layers build upon
    the predictions of shallower layers.

    Args:
        mtp_layers (list[MTPLayer]): List of MTP layers. Each layer should be a
            fully constructed MTPLayer instance. The number of layers determines
            the prediction depth (D).

    Raises:
        ValueError: If ``mtp_layers`` is empty.
    """

    def __init__(self, *, mtp_layers: list[MTPLayer]):
        super().__init__()
        if not mtp_layers:
            raise ValueError("mtp_layers cannot be empty")

        self.layers = nn.ModuleList(mtp_layers)
        self.num_layers = len(mtp_layers)

    def forward(
        self,
        hidden_states: torch.Tensor,
        embed_tokens_fn: Callable[[torch.Tensor], torch.Tensor],
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        seq_ctx: SequenceContext,
    ) -> list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """Forward pass through all MTP layers.

        Args:
            hidden_states (torch.Tensor): Hidden states from the main model,
                shape [batch, seq_len, hidden_size].
            embed_tokens_fn (Callable): Function to embed tokens. Takes token IDs
                and returns embeddings.
            position_embeddings (tuple[torch.Tensor, torch.Tensor]): Rotary
                position embeddings (cos, sin).
            seq_ctx (SequenceContext): Sequence context containing input_ids,
                position_ids, attention mask, etc.

        Returns:
            list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: One 3-tuple
            ``(hidden_states, router_results, router_weights)`` per MTP depth
            (the tuple order follows ``MTPLayer.forward``). Length equals
            ``num_layers``; ``outputs[k]`` corresponds to predicting the token
            at position (i + k + 1).
        """
        mtp_outputs: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = []
        current_hidden_states = hidden_states
        current_seq_ctx = seq_ctx

        for layer in self.layers:
            # Roll sequence context to get future tokens. This shifts each
            # packed sequence independently, respecting boundaries.
            current_seq_ctx = roll_sequence_context(current_seq_ctx, shifts=-1)

            # Get embeddings for the future tokens at this depth.
            if current_seq_ctx.inputs_embeds is None:
                future_embeddings = embed_tokens_fn(current_seq_ctx.input_ids)  # type: ignore[arg-type]
            else:
                future_embeddings = current_seq_ctx.inputs_embeds

            # Forward through the MTP layer for this depth.
            layer_output = layer(
                hidden_states=current_hidden_states,
                future_embeddings=future_embeddings,
                position_embeddings=position_embeddings,
                seq_ctx=current_seq_ctx,
            )
            # MTPLayer returns (hidden_states, router_results, router_weights);
            # only the hidden states feed the next depth. The previous code fed
            # the whole tuple back in as hidden_states, which broke prediction
            # chains deeper than one layer.
            current_hidden_states = layer_output[0]
            # Save the full per-depth output, matching the documented return type.
            mtp_outputs.append(layer_output)

        return mtp_outputs
class MTPLayer(nn.Module):
    """Single Multi-Token Prediction (MTP) layer.

    MTP Layer wraps a standard decoder layer with MTP-specific preprocessing
    and postprocessing. The structure is:

        [enorm + hnorm + projection] -> [DecoderLayer] -> [final_layernorm]

    The k-th MTP layer predicts the (i+k)-th token by combining:
    1. Hidden states from the previous MTP layer (or main model)
    2. Embedding of the future token at position (i+k)

    Note: The decoder layer's internal normalization (input_layernorm) is
    preserved for simplicity and modularity. While this adds a small
    computational overhead, it allows MTP to work with any decoder layer
    implementation (Dense, MoE, etc.) without modification.

    Args:
        hidden_size (int): Hidden dimension size.
        rms_norm_eps (float): Epsilon for RMSNorm.
        rms_norm_type (str): Type of RMSNorm ("default" or "zero_centered").
        decoder_layer (nn.Module): A fully constructed decoder layer instance
            (DenseDecoderLayer, MoEDecoderLayer, or any custom layer). It must
            return a 3-tuple ``(hidden_states, router_results, router_weights)``
            when called as ``decoder_layer(x, position_embeddings=..., seq_ctx=...)``.
        float8_cfg: Float8 configuration for the projection layer.
    """

    def __init__(
        self,
        *,
        hidden_size: int,
        rms_norm_eps: float,
        rms_norm_type: Literal["default", "zero_centered"],
        decoder_layer: nn.Module,
        float8_cfg=None,
    ):
        super().__init__()
        self.hidden_size = hidden_size

        # MTP-specific preprocessing: separate norms for the embedding stream
        # and the hidden-state stream, plus a projection fusing the two.
        self.enorm = RMSNorm(hidden_size, eps=rms_norm_eps, type=rms_norm_type)
        self.hnorm = RMSNorm(hidden_size, eps=rms_norm_eps, type=rms_norm_type)
        self.eh_proj = build_linear(
            hidden_size * 2,
            hidden_size,
            bias=False,
            float8_cfg=float8_cfg,
        )

        # Core decoder layer (Dense, MoE, or any custom implementation).
        self.decoder_layer = decoder_layer

        # MTP-specific postprocessing component.
        self.final_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps, type=rms_norm_type)

    def forward(
        self,
        hidden_states: torch.Tensor,
        future_embeddings: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        seq_ctx: SequenceContext,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward pass through the MTP layer.

        Args:
            hidden_states (torch.Tensor): Hidden states from previous layer,
                shape [batch, seq_len, hidden_size].
            future_embeddings (torch.Tensor): Embeddings of future tokens,
                shape [batch, seq_len, hidden_size].
            position_embeddings (tuple[torch.Tensor, torch.Tensor]): Rotary
                position embeddings (cos, sin).
            seq_ctx (SequenceContext): Sequence context containing attention
                mask, etc.

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A 3-tuple
            ``(hidden_states, router_results, router_weights)``, in the same
            order as the wrapped decoder layer's output. (An earlier docstring
            listed weights before results, which did not match the code.)
        """
        # Step 1: Normalize embeddings and hidden states separately so both
        # inputs are in a comparable numerical range.
        normalized_embedding = self.enorm(future_embeddings)
        normalized_hidden = self.hnorm(hidden_states)

        # Step 2: Concatenate and project to combine information.
        # [B, S, H] + [B, S, H] -> [B, S, 2H] -> [B, S, H]
        combined = torch.cat([normalized_embedding, normalized_hidden], dim=-1)
        projected = self.eh_proj(combined)

        # Step 3: Pass through the standard decoder layer (attention, MLP,
        # and their respective normalizations).
        # TODO: TMP hardcode here.
        hidden_states, router_results, router_weights = self.decoder_layer(
            projected,
            position_embeddings=position_embeddings,
            seq_ctx=seq_ctx,
        )

        # Step 4: Final normalization before output.
        hidden_states = self.final_layernorm(hidden_states)
        return hidden_states, router_results, router_weights
def roll_packed_tensor(
    tensor: torch.Tensor,
    cu_seq_lens: torch.IntTensor,
    shifts: int = -1,
    dim: int = -1,
) -> torch.Tensor:
    """Roll a packed tensor along the specified dimension.

    This function respects sequence boundaries in packed sequences, shifting
    each sequence independently without crossing boundaries. Positions vacated
    by the shift are zeroed to prevent information leakage across sequences.

    Args:
        tensor (torch.Tensor): Input packed tensor to roll.
        cu_seq_lens (torch.IntTensor): Cumulative sequence lengths defining
            packed sequence boundaries. Shape [num_sequences + 1].
        shifts (int): Number of positions to shift. Use -1 for left shift
            (default). Only non-positive shifts are supported. If the shift
            magnitude meets or exceeds a sequence's length, that whole
            sequence is zeroed (no valid future token exists within it).
        dim (int): Dimension along which to roll. The ``cu_seq_lens``
            boundaries are applied on this dimension. Default is -1.

    Returns:
        torch.Tensor: Rolled tensor with boundary positions zeroed. Regions
        outside the ``cu_seq_lens`` coverage are copied through unchanged.

    Example:
        For packed sequences [1,2,3] and [4,5,6] with shifts=-1, dim=-1:
        >>> tensor = torch.tensor([[1, 2, 3, 4, 5, 6]])
        >>> cu_seq_lens = torch.tensor([0, 3, 6], dtype=torch.int32)
        >>> roll_packed_tensor(tensor, cu_seq_lens, shifts=-1, dim=-1)
        tensor([[2, 3, 0, 5, 6, 0]])

        For a 3D tensor with dim=-2 (e.g., inputs_embeds [1, seq_len, hidden]):
        >>> tensor = torch.arange(12).reshape(1, 6, 2)
        >>> rolled = roll_packed_tensor(tensor, cu_seq_lens, shifts=-1, dim=-2)
        >>> rolled[0, 2]  # boundary zeroed
        tensor([0, 0])
    """
    assert shifts <= 0, "Only negative shift is supported"

    # Normalize dim to a positive index.
    dim = dim % tensor.dim()

    rolled_tensor = tensor.clone()
    if shifts == 0:
        # Nothing to shift; return an independent copy for consistency.
        return rolled_tensor

    # Roll each packed sequence independently within its boundaries.
    for i in range(len(cu_seq_lens) - 1):
        start_idx = int(cu_seq_lens[i])
        end_idx = int(cu_seq_lens[i + 1])
        seq_len = end_idx - start_idx
        if seq_len == 0:
            # Empty segment (duplicate boundary); nothing to do.
            continue

        # Extract the sequence slice along the specified dimension and roll it.
        seq_slice = tensor.narrow(dim, start_idx, seq_len)  # type: ignore[arg-type]
        rolled_seq = torch.roll(seq_slice, shifts=shifts, dims=dim)

        # Zero out the last |shifts| positions along dim to avoid information
        # leakage across sequences. Clamp to the sequence length so a shift
        # larger than the sequence zeroes it entirely instead of raising in
        # narrow() with a negative start (previous behavior).
        zero_len = min(-shifts, seq_len)
        zero_start = seq_len - zero_len
        rolled_seq.narrow(dim, zero_start, zero_len).zero_()  # type: ignore[arg-type]

        # Write the shifted segment back into the output tensor.
        rolled_tensor.narrow(dim, start_idx, seq_len).copy_(rolled_seq)  # type: ignore[arg-type]

    return rolled_tensor
def roll_sequence_context(
    seq_ctx: SequenceContext,
    shifts: int = -1,
) -> SequenceContext:
    """Roll the sequence context to get future tokens for MTP prediction.

    Each packed sequence is shifted independently so that no information
    crosses sequence boundaries; vacated boundary positions are zeroed.
    The input context is left untouched — a fresh ``SequenceContext`` is
    returned.

    Args:
        seq_ctx (SequenceContext): Input sequence context with packed sequences.
        shifts (int): Number of positions to shift. Use -1 for left shift
            (default). Only -1 is currently supported.

    Returns:
        SequenceContext: A new context whose input_ids and/or inputs_embeds
        have been shifted, with boundary positions zeroed.

    Example:
        For packed sequences [1,2,3] and [4,5,6] with shifts=-1:
        Original input_ids: [1, 2, 3, 4, 5, 6]
        Rolled input_ids:   [2, 3, 0, 5, 6, 0]
    """
    assert seq_ctx.sequence_parallel_mesh is None, "Sequence parallel is not yet supported"

    updates: dict = {}
    token_ids = seq_ctx.input_ids
    embeddings = seq_ctx.inputs_embeds
    boundaries = seq_ctx.cu_seq_lens_q

    # Token IDs live on the last axis; embeddings carry a trailing hidden
    # dimension, so their sequence axis is the second to last.
    if token_ids is not None:
        updates["input_ids"] = roll_packed_tensor(
            tensor=token_ids,
            cu_seq_lens=boundaries,
            shifts=shifts,
            dim=-1,
        )
    if embeddings is not None:
        updates["inputs_embeds"] = roll_packed_tensor(
            tensor=embeddings,
            cu_seq_lens=boundaries,
            shifts=shifts,
            dim=-2,
        )

    return seq_ctx.copy(**updates)