4 changes: 4 additions & 0 deletions tests/train/test_trainer.py
@@ -112,6 +112,7 @@ def create_pg(self, device):
 
     @patch("xtuner.v1.train.trainer.is_hf_model_path", Mock(return_value=True))
     @patch("xtuner.v1.train.trainer.Trainer.build_engine", Mock(side_effect=lambda *args, **kwargs: FakeEngine()))
+    @patch("xtuner.v1.train.trainer.Trainer._prepare_model_input", Mock(return_value=[]))
     @prepare
     def test_save_hf_interval(self):
         """Test save_hf is called at correct intervals during training."""
@@ -184,6 +185,7 @@ def test_save_hf_interval(self):
 
     @patch("xtuner.v1.train.trainer.is_hf_model_path", Mock(return_value=True))
     @patch("xtuner.v1.train.trainer.Trainer.build_engine", Mock(side_effect=lambda *args, **kwargs: FakeEngine()))
+    @patch("xtuner.v1.train.trainer.Trainer._prepare_model_input", Mock(return_value=[]))
     @prepare
     def test_save_checkpoint_interval(self):
         self.create_pg(DEVICE)
@@ -258,6 +260,7 @@ def test_save_checkpoint_interval(self):
 
     @patch("xtuner.v1.train.trainer.is_hf_model_path", Mock(return_value=True))
     @patch("xtuner.v1.train.trainer.Trainer.build_engine", Mock(side_effect=lambda *args, **kwargs: FakeEngine()))
+    @patch("xtuner.v1.train.trainer.Trainer._prepare_model_input", Mock(return_value=[]))
     @prepare
     def test_resume(self):
         self.create_pg(DEVICE)
@@ -738,6 +741,7 @@ def __call__(self, checkpoint, step, epoch, total_step, total_epoch):
     assert len(loaded.get_hooks(HookStage.AFTER_SAVE_DCP)) == 1
 
 
+@patch("xtuner.v1.train.trainer.Trainer._prepare_model_input", Mock(return_value=[]))
 @patch("xtuner.v1.train.trainer.Trainer.build_engine", Mock(side_effect=lambda *args, **kwargs: FakeEngine()))
 def test_resume_and_load_checkpoint_cfg(tmp_path: Path):
     # 0. prepare environment
3 changes: 2 additions & 1 deletion xtuner/v1/model/__init__.py
@@ -23,7 +23,7 @@
 from .dense.qwen3 import Qwen3Dense0P6BConfig, Qwen3Dense4BConfig, Qwen3Dense8BConfig, Qwen3DenseConfig
 from .moe.deepseek_v3 import DeepSeekV3Config
 from .moe.gpt_oss import GptOss21BA3P6Config, GptOss117BA5P8Config, GptOssConfig
-from .moe.moe import BalancingLossConfig, MoE, MoEModelOutputs, ZLossConfig
+from .moe.moe import BalancingLossConfig, MoE, MoEConfig, MoEModelOutputs, ZLossConfig
 from .moe.qwen3 import Qwen3MoE30BA3Config, Qwen3MoEConfig, Qwen3MoEFoPEConfig


@@ -87,6 +87,7 @@ def get_model_config_from_hf(model_path: Path):
     "get_model_config",
     "get_model_config_from_hf",
     "MoE",
+    "MoEConfig",
     "MoEModelOutputs",
     "BalancingLossConfig",
     "ZLossConfig",
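With MoEConfig now re-exported at the package root, downstream code can import it directly from xtuner.v1.model. A minimal usage sketch, assuming the preset Qwen3MoE30BA3Config constructs with its defaults and derives from MoEConfig (implied by the import block above, not shown in this diff):

from xtuner.v1.model import MoEConfig, Qwen3MoE30BA3Config

cfg = Qwen3MoE30BA3Config()
# Assumption: the size-specific preset subclasses MoEConfig, so it can be
# passed anywhere an MoEConfig is expected (e.g. as a text_config).
assert isinstance(cfg, MoEConfig)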
3 changes: 3 additions & 0 deletions xtuner/v1/model/base.py
@@ -1374,6 +1374,9 @@ def _load_fused_hf_param(
                 continue
             _loaded_tensor.append(weight.to(local_tensor.device))
 
+        if not _loaded_tensor:
+            return missing_keys
+
         if not hf_keys:
             # fp8 pad
             assert self.config.float8_cfg is not None
6 changes: 3 additions & 3 deletions xtuner/v1/model/compose/qwen3_5/qwen3_5_config.py
@@ -1,4 +1,4 @@
-from xtuner.v1.model.base import TransformerConfig
+from xtuner.v1.model.moe.moe import MoEConfig
 from xtuner.v1.model.moe.qwen3_5_text import Qwen3_5_VLTextMoE35BA3BConfig
 from xtuner.v1.utils import get_logger

@@ -19,7 +19,7 @@ class Qwen3_5_ProjectorConfig(Qwen3VLProjectorConfig):
 class Qwen3_5_BaseConfig(Qwen3VLBaseConfig):
     vision_config: Qwen3_5_VisionConfig
     projector_config: Qwen3_5_ProjectorConfig
-    text_config: TransformerConfig
+    text_config: MoEConfig
 
     image_token_id: int = 248056
     video_token_id: int = 248057
@@ -30,4 +30,4 @@ class Qwen3_5_BaseConfig(Qwen3VLBaseConfig):
 class Qwen3_5_VLMoE35BA3Config(Qwen3_5_BaseConfig):
     vision_config: Qwen3_5_VisionConfig = Qwen3_5_VisionConfig()
     projector_config: Qwen3_5_ProjectorConfig = Qwen3_5_ProjectorConfig()
-    text_config: TransformerConfig = Qwen3_5_VLTextMoE35BA3BConfig()
+    text_config: MoEConfig = Qwen3_5_VLTextMoE35BA3BConfig()
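The narrowed annotation also surfaces at the composed-model config level. A short sketch, assuming Qwen3_5_VLMoE35BA3Config constructs with its defaults (module paths taken from the files in this diff):

from xtuner.v1.model.compose.qwen3_5.qwen3_5_config import Qwen3_5_VLMoE35BA3Config
from xtuner.v1.model.moe.moe import MoEConfig

cfg = Qwen3_5_VLMoE35BA3Config()
# text_config is now declared as MoEConfig, so MoE-specific fields are
# visible to type checkers instead of being hidden behind TransformerConfig.
assert isinstance(cfg.text_config, MoEConfig)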
24 changes: 16 additions & 8 deletions xtuner/v1/model/dense/qwen3vl_text.py
@@ -1,9 +1,10 @@
 import re
 
 import torch
+import torch.nn.functional as F
 
 from xtuner.v1.data_proto import SequenceContext
-from xtuner.v1.loss import CELossContext
+from xtuner.v1.loss import BaseLossContext
 from xtuner.v1.model.base import ModelOutputs
 
 from .qwen3 import Qwen3Dense, Qwen3Dense4BConfig, Qwen3Dense8BConfig
@@ -34,10 +35,10 @@ def _deepstack_process(
         hidden_states[visual_pos_masks, :] = local_this
         return hidden_states
 
-    def forward(
+    def forward(  # type: ignore[override]
         self,
         seq_ctx: SequenceContext,  # todo(@yehaochen): support intra layer micro-batch
-        loss_ctx: CELossContext,
+        loss_ctx: dict[str, BaseLossContext | list[BaseLossContext]] | None = None,
     ) -> ModelOutputs:
         input_ids = seq_ctx.input_ids
         position_ids = seq_ctx.position_ids
@@ -78,11 +79,18 @@ def forward(
 
         hidden_states = self.norm(hidden_states)
 
-        loss, (logits, extra_info) = self.lm_head(hidden_states, loss_ctx)
-        output["loss"] = loss
-        output["logits"] = logits
-        output["extra_info"] = extra_info
-        return ModelOutputs(**output)  # type: ignore[typeddict-item]
+        if loss_ctx is None:
+            # Inference mode
+            logits = F.linear(hidden_states, self.lm_head.weight, self.lm_head.bias)
+            output["logits"] = logits
+        else:
+            # Training mode
+            loss, (logits, extra_info) = self.lm_head(hidden_states, loss_ctx["lm"])  # type: ignore[call-overload]
+            output["loss"] = loss
+            output["logits"] = logits
+            output["extra_info"] = extra_info
+
+        return ModelOutputs(**output)
 
 
 class Qwen3VLTextDense4BConfig(Qwen3Dense4BConfig):
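After this change the text model's forward has two modes: a training call that passes the loss contexts as a dict keyed by "lm", and an inference call with loss_ctx=None that returns plain logits. A hedged sketch of both call sites; model, seq_ctx, and lm_loss_ctx are placeholders for objects built elsewhere, not this repo's exact construction API:

# Training: the fused lm_head consumes the "lm" loss context and the loss
# comes back in the outputs dict.
outputs = model(seq_ctx=seq_ctx, loss_ctx={"lm": lm_loss_ctx})
loss = outputs["loss"]

# Inference: with loss_ctx=None the head is applied as a plain F.linear and
# only the logits are populated.
outputs = model(seq_ctx=seq_ctx, loss_ctx=None)
next_token = outputs["logits"][:, -1, :].argmax(dim=-1)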