feat: support Mixed Preference Optimization (MPO) #1609

Open
LarryLeeee wants to merge 1 commit into InternLM:main from LarryLeeee:feat/add-mpo-support

Conversation

@LarryLeeee

No description provided.

@windreamer
Collaborator

Thanks for contributing!
Hi @LarryLeeee ! We noticed that besides the new MPO feature, some recent XTuner updates appear to have been reverted. Was this intentional? If not, would you mind rebasing your changes on the latest commit or resolving the conflicts on your forked branch? Let us know if you need any help!

@hhaAndroid
Collaborator

@claude review

import os
import gc
from copy import deepcopy
from datetime import datetime

Critical: Missing module xtuner.v1.rl.dpo

This import will fail at runtime. The module xtuner/v1/rl/dpo/ does not exist in the repository and is not created by this PR. The classes DPOLossConfig, DPOLossContext, and DPOLossContextInputItem are imported but never defined anywhere in the diff.

This means the entire DPO trainer feature is non-functional -- it will crash on import.

- hinge, ipo, robust: Other DPO variants

For MPO (Mixed Preference Optimization), use loss_types=["sigmoid", "bco_pair", "sft"]
with appropriate loss_weights.

Critical: Missing imports -- qwen3_vl_dpo_collator and DPOColateItem do not exist

These symbols are not defined in xtuner/v1/datasets/__init__.py nor in collator.py (checked the diff). The Qwen3VLDPOTokenizeFnConfig used in the example configs is also not defined anywhere in this PR. This will cause an ImportError at runtime.

)
self.ref_engine.from_hf(str(ref_load_from))

# Freeze reference model

Critical: Reference model creates a full optimizer unnecessarily, wasting GPU memory

_init_reference_model creates a TrainEngine (or VisionComposeTrainEngine) for the reference model. These engines call build_optimizer() internally, allocating optimizer state (Adam moments, etc.) for all parameters -- which is then immediately wasted since the reference model is frozen (requires_grad = False) and never updated.

For a typical 8B model, Adam's two fp32 moment buffers alone cost roughly 2 x 4 bytes x 8B parameters ≈ 64 GB in aggregate (sharded across data-parallel ranks), all of it wasted on a frozen model.

The reference model should be loaded without an optimizer. Consider either:

  1. Adding a build_inference_only() path that skips optimizer creation, or
  2. Creating the FSDP-wrapped model directly without the engine abstraction for inference-only use.
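Option 1 could look roughly like the following; a minimal, framework-free sketch under assumed names (build_model and load_reference_model are illustrative, not the engine API):

```python
import torch

def load_reference_model(build_model, load_weights=None) -> torch.nn.Module:
    """Build a frozen, inference-only reference model; no optimizer is ever created."""
    model = build_model()
    if load_weights is not None:
        load_weights(model)            # e.g. the engine's from_hf(...) in the real trainer
    model.eval()                       # disable dropout etc. for reference scoring
    for p in model.parameters():
        p.requires_grad_(False)        # frozen: no grads, hence no Adam moment buffers
    return model
```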

# Average loss over batch
total_loss = total_loss / len(batch)

# Backward pass

Bug: _train_step calls self.train_engine.model(...) directly and manages gradient accumulation implicitly

The method calls total_loss.backward() on every invocation, while optimizer.step() and zero_grad() run only every gradient_accumulation_steps calls. Accumulating gradients via repeated .backward() calls is correct, but two issues remain:

  1. The value logged as metrics["loss"] reflects only the current micro-batch, not the accumulated loss over the window. This is misleading for monitoring.

  2. zero_grad() is called only after optimizer.step(), so the first accumulation window relies on gradients happening to be zero right after initialization. The math works out, but the invariant is implicit and fragile. Consider calling optimizer.zero_grad() once at the start of training in fit() to make the intent explicit.
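A minimal, framework-free sketch of the explicit pattern (run_accumulation and train_step_fn are illustrative names, not the trainer's API; train_step_fn is assumed to call backward() internally and return the micro-batch loss):

```python
def run_accumulation(optimizer, micro_batches, accum_steps, train_step_fn):
    """Explicit gradient-accumulation loop: zero once up front, log the window loss."""
    optimizer.zero_grad()                      # make the starting invariant explicit
    logged = []
    window_loss = 0.0
    for i, batch in enumerate(micro_batches, start=1):
        window_loss += train_step_fn(batch)    # backward() happens inside
        if i % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            logged.append(window_loss / accum_steps)   # accumulated, not micro-batch
            window_loss = 0.0
    return logged
```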

Comment on lines +605 to +610

# Forward chosen and rejected separately (policy)
chosen_output = self.train_engine.model(seq_ctx=chosen_seq_ctx, loss_ctx=None)
rejected_output = self.train_engine.model(seq_ctx=rejected_seq_ctx, loss_ctx=None)

chosen_logits = _get_field(chosen_output, "logits")

Bug: _train_step does two full forward passes through the policy model (chosen_output and rejected_output) with loss_ctx=None, but the model is in training mode with gradients enabled.

The forward passes compute chosen_output and rejected_output but their intermediate activations (including the full logits tensors) are held in memory simultaneously. For a typical 8B model with 8K sequence length, this means ~2x the activation memory. This is very likely to OOM.

Consider:

  1. Computing chosen and rejected forward passes sequentially, extracting logprobs immediately, and deleting the logits before the next forward pass.
  2. Using torch.no_grad() for the logits computation and only enabling gradients for the loss backward (which requires custom autograd).


# Warmup function - linear warmup from 0 to base_lr (same as SFT Trainer)
def warmup_fn(x):
    return x / warmup_steps if x < warmup_steps else 1

Warning: LR warmup lambda returns 0 at step 0, making the first optimizer step a no-op

When x = 0 and warmup_steps > 0, warmup_fn returns 0 / warmup_steps = 0, so the learning rate at step 0 is exactly 0 and the first update does nothing. Note that warmup_steps = 0 (e.g. warmup_ratio = 0) does not actually divide by zero as written: x < warmup_steps is never true for x >= 0, so the lambda falls through to 1. Still, an explicit guard makes both behaviors obvious.

Add a guard: if warmup_steps == 0: warmup_fn = lambda x: 1, and consider ramping with (x + 1) / warmup_steps so step 0 is not wasted.
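A self-contained sketch of the guarded version (make_warmup_fn is an illustrative name, not the trainer's API):

```python
def make_warmup_fn(warmup_steps: int):
    """Linear warmup multiplier with an explicit guard for warmup_steps == 0."""
    if warmup_steps <= 0:
        return lambda step: 1.0                 # no warmup: full LR immediately
    # step + 1 keeps the multiplier nonzero at step 0, so the first update is not a no-op
    return lambda step: min(1.0, (step + 1) / warmup_steps)
```

With this, make_warmup_fn(10)(0) is 0.1 rather than 0, and make_warmup_fn(0)(0) is 1.0 rather than an implicit fall-through.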

ref_rejected_logits = _get_field(ref_rejected_output, "logits")

# Compute log probs
# NOTE: _gather_logprobs returns token-level logprobs [B, L]

Warning: Return type annotation says tuple[Tensor, Tensor] but returns None, None

When self.ref_engine is None, the method returns (None, None), which violates the declared return type tuple[torch.Tensor, torch.Tensor]. The type hint should be tuple[torch.Tensor | None, torch.Tensor | None].
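An illustrative sketch of the corrected signature (class and method shapes are assumed; Tensor here is a stand-in so the sketch stays self-contained):

```python
from typing import Any, Optional

Tensor = Any  # stand-in for torch.Tensor in this sketch

class RefLogprobsSketch:
    """Illustrative only; field and method names are assumptions, not the PR's code."""
    def __init__(self, ref_engine=None):
        self.ref_engine = ref_engine

    def _compute_ref_logprobs(self, batch) -> tuple[Optional[Tensor], Optional[Tensor]]:
        if self.ref_engine is None:
            return None, None          # now permitted by the annotation
        return self.ref_engine(batch)  # assumed to yield (chosen, rejected) logprobs
```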


def _gather_logprobs(
self, logits: torch.Tensor, labels: torch.Tensor
) -> torch.Tensor:

Numerical stability: _gather_logprobs computes log_softmax over the full vocabulary without any numerical guard

While F.log_softmax is numerically stable itself, the gathered log probabilities for padding tokens (label = -100, clipped to 0) will return the log probability of token 0, which is meaningless noise. This is handled downstream by masking, but it means the returned tensor contains uninitialized/garbage values in masked positions.

More importantly, this method duplicates the existing gather_logprobs from xtuner/v1/rl/utils.py (which is also imported later in _train_step at line 622). This dual implementation is confusing. Consider using the utility function consistently.
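A sketch of what a single shared utility could look like, under assumed semantics (token-level logprobs of shape [B, L], with label -100 positions zeroed instead of left as noise); the repository's gather_logprobs may differ:

```python
import torch
import torch.nn.functional as F

def gather_logprobs(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Token-level logprobs [B, L] from logits [B, L, V]; -100 positions come back as 0."""
    logp = F.log_softmax(logits.float(), dim=-1)
    mask = labels.ne(-100)
    safe = labels.clamp(min=0)                        # avoid gather() on -100
    token_logp = logp.gather(-1, safe.unsqueeze(-1)).squeeze(-1)
    return token_logp * mask                          # masked positions: exactly 0.0
```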


For MPO (Mixed Preference Optimization), use:
loss_types=["sigmoid", "bco_pair", "sft"]
loss_weights=[0.8, 0.2, 1.0]

Critical: Hardcoded internal storage paths

This config file contains hardcoded paths to internal shared storage (/mnt/shared-storage-user/lisongze/...). These should use environment variables like the pattern in the original rl_qwen3_vl_8B_grpo.py (which this PR also breaks -- see separate comment).

Suggested change
loss_weights=[0.8, 0.2, 1.0]
ceph_config = os.environ.get("CEPH_CONFIG", "")
meta_data_path = os.environ["META_DATA_PATH"]
model_path = os.environ["MODEL_PATH"]
work_dir = os.environ["WORK_DIR"]
tokenizer_cache_dir = os.environ.get("TOKENIZER_CACHE_DIR", os.path.join(work_dir, "tokenizer_cache"))

data_path = os.environ["DATA_PATH"]
eval_data_path = os.environ["EVAL_DATA_PATH"]
work_dir = '/mnt/shared-storage-user/yanziang/test_xtuner/105xtuner/xtuner/examples/v1/config/rl_qwen3_vl_8B_grpo.py'
model_path = "/mnt/shared-storage-user/yanziang/xtuner/Qwen3-VL-8B-Instruct"

Critical: Reverted change -- environment variables replaced with hardcoded internal paths

The existing config used os.environ["WORK_DIR"], os.environ["MODEL_PATH"], etc. This PR replaces them with hardcoded paths to a developer's internal shared storage. This is clearly a debugging artifact that was accidentally committed. This change should be reverted entirely.

Comment on lines +43 to +46
#打印lmdeploy版本 (print the lmdeploy version)
import lmdeploy
print(f"lmdeploy version: {lmdeploy.__version__}")
# breakpoint()

Warning: Debug artifacts left in production code

This adds a lmdeploy version print and a commented-out breakpoint(). These are clearly debugging artifacts that should not be committed. Please remove.

Comment on lines +11 to +12

from pathlib import Path

Warning: Globally disabling torch._dynamo is a heavy-handed side effect

torch._dynamo.config.disable = True is set at module import time, meaning any code that imports this CLI module will have dynamo disabled globally. This conflicts with the torch_compile=True setting in the FSDP configs of the example configs. If dynamo is truly needed to be disabled for DPO, it should be documented why, and scoped more narrowly (e.g., inside main()).

f"Total trainable parameters: {num_total_requires_grad // 1e6}M, total parameters: {num_total // 1e6}M"
)
logger.info(f"Untrainable parameters names: {untrainable_names}")
def build(self, params):

Critical: Breaking change -- MuonConfig removed from optim.py

This PR removes the entire MuonConfig class and the Muon optimizer support. This is a breaking change to the existing codebase that is unrelated to the MPO feature. Anyone using MuonConfig in their training configs will break. The commit message says "feat: support MPO" but this deletion is a separate breaking change.

Per the PR standards: "One logical change per PR. Do not mix bug fixes with features or refactors."

The MuonConfig removal and the AdamWConfig.build() signature change should be in a separate PR, or this PR should not touch these files.

from .chunk_loss import ChunkLoss
from .moe_loss import BalancingLoss, ZLoss
from .rl_loss import LogProbConfig, LogProbContext


Critical: Breaking change -- LogProbConfig and LogProbContext removed from public API

These classes are removed from the __init__.py exports. Any external code that does from xtuner.v1.loss import LogProbConfig will break. This removal is tied to the forward_only signature change in TrainEngine, which also removes the loss_ctx parameter. This is a significant API-breaking refactor that goes beyond "feat: support MPO".

loss_kwargs_list = DPOLossContext.build_batches_loss_kwargs(
    [dpo_input], self.config.loss_cfg
)


Bug: _eval_step accesses chosen_output.hidden_states and rejected_output.hidden_states as attributes, but elsewhere in the trainer _get_field is used because outputs may be dicts

In _train_step, the code uses a helper _get_field(out, key) that handles both dict and attribute access patterns. But in _eval_step, the code directly accesses .hidden_states as an attribute. If the model returns a dict (which is the case when loss_ctx=None is passed to compose models), this will raise AttributeError.

Additionally, _eval_step uses self.train_engine.forward_only(...) which wraps the call in @torch.no_grad(), but then tries to compute loss with loss_ctx.loss_fn(hidden_states, head_weight, head_bias, loss_kwargs). This pattern assumes DPOLossContext.loss_fn exists and has this signature, but since xtuner.v1.rl.dpo is missing, we cannot verify this.

@property
def cur_epoch(self) -> int:
return self._cur_epoch
# [XTuner][2026-01-12 07:36:21][WARNING] Failed to process inputs: PreTrainedTokenizerFast._batch_encode_plus() got an unexpected keyword argument 'images', using text-only No newline at end of file

Nit: Debug log message left at end of file

This line contains a leftover debug/warning message as a comment at the end of the file. Also, the file is missing a trailing newline.

Comment on lines +660 to +665
loss_ctx = DPOLossContext(self.config.loss_cfg, loss_kwargs)
loss = torch.tensor(0.0, device=logits.device, dtype=logits.dtype)
extra_info: dict[str, Any] = {}

for loss_type, weight in zip(self.config.loss_cfg.loss_types, self.config.loss_cfg.loss_weights):
if loss_type == "sigmoid":

Warning: Calling private methods _dpo_loss_sigmoid, _bco_pair_loss, etc. with type: ignore annotations

The loss dispatch in _train_step calls private methods on DPOLossContext (e.g., loss_ctx._dpo_loss_sigmoid(...)) with # type: ignore[attr-defined] annotations. Since the DPOLossContext class doesn't exist in the diff (missing xtuner.v1.rl.dpo module), we cannot verify these methods exist or have the expected signatures.

This pattern of calling private methods and silencing type errors is fragile. Consider adding a proper public dispatch method to DPOLossContext (e.g., compute_loss(loss_type, ...)) that encapsulates the loss selection logic.
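A hypothetical shape for such a dispatch method (the real DPOLossContext is absent from the diff, so all names here are assumptions):

```python
class LossDispatchSketch:
    """Illustrative public dispatch; not the PR's DPOLossContext."""
    def __init__(self, loss_fns):
        self._loss_fns = loss_fns      # e.g. {"sigmoid": fn, "bco_pair": fn, "sft": fn}

    def compute_loss(self, loss_type, *args, **kwargs):
        try:
            fn = self._loss_fns[loss_type]
        except KeyError:
            raise ValueError(f"unknown loss_type {loss_type!r}") from None
        return fn(*args, **kwargs)
```

This keeps the loss-selection logic in one place and removes the need for # type: ignore on private attribute access.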

from typing import List, cast

import torch
import torch.distributed as dist

Warning: VisionComposeTrainEngine lacks a docstring for the class and its public methods

Per CLAUDE.md: "Only provide docstrings for public methods. Private methods do not require docstrings." The class itself and public methods like build_model, from_hf, save_hf, train_step, and maybe_precompute_float8_dynamic_scale_for_fsdp all lack Google-style docstrings with type-hinted parameters.

export NCCL_TIMEOUT=10800
export TORCH_DISTRIBUTED_TIMEOUT=10800
export XTUNER_USE_FA3=1
export PYTHONPATH="$(pwd)"

Warning: Hardcoded absolute paths in shell scripts

Both config_file (line 12) and the torchrun target (line 24) use hardcoded absolute paths to a developer's workspace. These should use relative paths or environment variables, following the pattern of other scripts in the repo.

loss = loss + _l
elif loss_type == "sft":
# SFT loss only on chosen part
_l = loss_ctx._sft_loss( # type: ignore[attr-defined]

Warning: SFT loss slicing logic is fragile and likely incorrect with sequence parallelism

The SFT loss branch slices logits[:, : chosen_mask.shape[1]] to extract the "chosen" portion. But at this point in the code, logits is the concatenation of chosen_logits and rejected_logits along dim=1 (line 614). The chosen_mask comes from loss_kwargs.chosen_mask which was built by DPOLossContext.build_batches_loss_kwargs.

With sequence parallelism enabled, chosen_logits and rejected_logits are already SP-split (each rank has only a portion of the sequence). The concatenation torch.cat([chosen_logits, rejected_logits], dim=1) concatenates these partial sequences. But chosen_mask.shape[1] would be the SP-split chosen length, so logits[:, :chosen_mask.shape[1]] would correctly slice the chosen portion from the concatenated tensor. This is OK only if loss_kwargs.chosen_mask was also SP-split by build_batches_loss_kwargs -- but since that module is missing, this cannot be verified.

Also, the all_reduce of _l (SFT loss) on line 716 sums across SP ranks. If the SFT loss was already normalized per-token within _sft_loss, then summing gives the total loss, not the average. The other loss types (sigmoid, bco_pair, etc.) do NOT have this all_reduce, meaning they operate on SP-local logprobs that were already globally aggregated (lines 638-651). This inconsistency between how different loss types handle SP is a potential correctness bug.

@@ -8,8 +8,6 @@
DatasetConfigList,
DatasetConfigListAdatper,
)

Warning: Breaking change -- CustomPackDataset, CustomSampler, LongTextPretrainTokenizeFunction, and LongTextPretrainTokenizeFunctionConfig removed from public API

These classes are removed from the datasets __init__.py. Any downstream code importing these will break. This removal is unrelated to the MPO feature and should be in a separate refactoring PR.

@@ -51,22 +48,11 @@ class BaseLossKwargs(BaseModel):
model_config = ConfigDict(title="loss keyword arguments", extra="forbid", arbitrary_types_allowed=True)

Warning: Breaking change -- BaseLossKwargs.sp_split() and .to() methods removed

The sp_split and to methods are removed from BaseLossKwargs. These were used by BaseRLLossConfig.build() (also removed) and potentially by downstream code. The GRPOLossKwargs and OrealLossKwargs subclasses now define their own fields inline rather than inheriting from the base.

While the refactor toward immutable construction in sp_split (returning new objects via type(self)(...) in RLLossContextInputItem) is an improvement, removing the base class methods is a breaking change for any code that relied on the base BaseLossKwargs.to() pattern.

@@ -144,7 +150,6 @@ def loss_fn(
self.loss_cfg.policy_loss_cfg,

Warning: Safety assertion removed

The assertion assert old_logprobs is not None was removed here. While the code will still crash with a clear error if old_logprobs is None (since .detach() would fail on None), the explicit assertion provided a more informative error message. Consider keeping defensive checks on critical invariants.
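A hypothetical illustration of the defensive pattern (policy_ratio_terms is an invented name; the real loss_fn operates on tensors):

```python
def policy_ratio_terms(logprobs, old_logprobs):
    """Keep the invariant explicit so a failure carries a clear message."""
    assert old_logprobs is not None, "old_logprobs must be computed before the policy loss"
    return [new - old for new, old in zip(logprobs, old_logprobs)]  # log-ratio terms
```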

@claude claude bot left a comment


Summary

This PR adds Mixed Preference Optimization (MPO) / DPO support via a new DPOTrainer, CLI entry point, and VisionComposeTrainEngine. However, it also includes a large number of unrelated refactors to the loss system, optimizer configs, RL worker, and dataset APIs that introduce breaking changes.

Issues

Critical

  1. [xtuner/v1/train/dpo_trainer.py:18] Missing module xtuner.v1.rl.dpo — The entire DPO trainer imports DPOLossConfig, DPOLossContext, and DPOLossContextInputItem from xtuner.v1.rl.dpo, but this module does not exist and is not created by the PR. The feature is completely non-functional.

  2. [xtuner/v1/train/dpo_trainer.py:12] Missing qwen3_vl_dpo_collator, DPOColateItem, Qwen3VLDPOTokenizeFnConfig, VLMPreferenceJsonlDataset — These symbols are imported/referenced but not defined anywhere in the diff or existing codebase.

  3. [examples/v1/config/rl_qwen3_vl_8B_grpo.py:25] Existing config broken — Environment variable lookups (os.environ["WORK_DIR"], etc.) replaced with hardcoded developer-specific paths. This breaks the existing RL example config for all users.

  4. [xtuner/v1/config/optim.py:29] MuonConfig removed — The entire Muon optimizer support is deleted, breaking any existing configs that use it. This is unrelated to MPO.

  5. [xtuner/v1/loss/__init__.py:5] LogProbConfig/LogProbContext removed from public API — Breaking change to the loss module's public interface, unrelated to MPO.

  6. [xtuner/v1/train/dpo_trainer.py:463] Reference model wastes ~32GB GPU memory — The reference engine creates a full optimizer with Adam states for a model that is immediately frozen. Should use inference-only loading.

Warning

  1. [xtuner/v1/train/dpo_trainer.py:605-610] High OOM risk — Two full forward passes (chosen + rejected) through the policy model are kept in memory simultaneously before backward. For 8B models this roughly doubles activation memory.

  2. [xtuner/v1/train/dpo_trainer.py:413] LR warmup dead first step — warmup_fn(0) returns 0, so the step-0 learning rate is exactly 0; guard warmup_steps == 0 explicitly rather than relying on the x < warmup_steps fall-through.

  3. [xtuner/v1/train/dpo_trainer.py:709] Inconsistent SP handling across loss types — SFT loss does an explicit all_reduce across SP ranks, but sigmoid/bco_pair/etc. do not. This inconsistency may produce incorrect loss values under sequence parallelism.

  4. [xtuner/v1/train/cli/dpo.py:11-12] torch._dynamo globally disabled — Conflicts with torch_compile=True in example FSDP configs.

  5. [xtuner/v1/train/cli/rl.py:43-46] Debug artifacts — lmdeploy version print and commented-out breakpoint() committed to production code.

  6. [examples/] Hardcoded paths — All new config files and shell scripts contain hardcoded paths to internal storage (/mnt/shared-storage-user/lisongze/...).

  7. [xtuner/v1/datasets/__init__.py, xtuner/v1/loss/base_loss_ctx.py, xtuner/v1/rl/base/loss.py] — Multiple breaking API removals (CustomPackDataset, CustomSampler, LongTextPretrainTokenizeFunction, BaseLossKwargs.sp_split(), BaseLossKwargs.to(), BaseRLLossContext, BaseRLLossKwargs, compute_kl_loss_weight) bundled into a feature PR.

Nit

  1. [xtuner/v1/train/dpo_trainer.py:1059] — Leftover debug comment at end of file, missing trailing newline.
  2. [xtuner/v1/train/dpo_trainer.py:497] — Return type annotation tuple[Tensor, Tensor] is wrong when method returns (None, None).
  3. [xtuner/v1/train/dpo_trainer.py:530] — Duplicate _gather_logprobs implementation; xtuner.v1.rl.utils.gather_logprobs already exists and is imported later in the same file.
  4. [xtuner/v1/engine/vision_compose_train_engine.py] — Missing docstrings on class and public methods (CLAUDE.md compliance).
  5. Code style: single quotes used in several places (e.g., _data['annotation']), Chinese comments without English equivalents in public example configs.

Verdict

REQUEST_CHANGES

The PR is not mergeable in its current state:

  1. The core feature does not work — the xtuner.v1.rl.dpo module (containing DPOLossConfig, DPOLossContext, DPOLossContextInputItem) is missing from the PR. Without it, every import fails.

  2. Too many unrelated breaking changes — the PR removes MuonConfig, LogProbConfig/Context, BaseRLLossContext, CustomPackDataset, and others. Per project standards: "One logical change per PR."

  3. Hardcoded paths throughout — example configs and scripts use developer-specific absolute paths instead of environment variables.

Recommended action: split this into (a) the loss system refactor as a separate PR, (b) the MPO feature with all required modules included, and (c) ensure no existing configs/APIs are broken.

Comment on lines +9 to +10
export AWS_ACCESS_KEY_ID=[redacted]
export AWS_SECRET_ACCESS_KEY=[redacted]

Claude: Critical: Leaked credentials in a public repository.

AWS access key ID and secret access key are hardcoded here. These credentials must be revoked immediately and removed from this PR. Use environment variables or a secrets manager instead — credentials should never appear in source code.

Even after removal from the PR, they will remain in git history. Revocation is mandatory.

self.q_lora_rank = q_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_head_dim = qk_rope_head_dim + qk_nope_head_dim

Claude: Bug: Trailing commas on dataclass/config field defaults create tuples, not scalars.

In Python, kv_lora_rank: int = 896, assigns (896,) (a tuple), not 896. Same for dropout: float = 0.0,, qkv_bias: bool = True,, etc. This will cause type validation errors or unexpected behavior at runtime.

Remove the trailing commas from all these field defaults.
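A minimal demonstration of the pitfall (class and field names mirror the review's examples but are illustrative):

```python
class Bad:
    kv_lora_rank: int = 896,    # trailing comma: the value is the tuple (896,)
    dropout: float = 0.0,       # likewise (0.0,), despite the float annotation

class Good:
    kv_lora_rank: int = 896
    dropout: float = 0.0
```

Annotations are not enforced at runtime, so nothing flags the mismatch until a downstream consumer chokes on the tuple.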

Comment on lines +5 to +10
in xtuner v1 framework, following the same pattern as RL configs.

Supported loss types:
- sigmoid: Standard DPO loss for preference learning
- bco_pair: Binary Classifier Optimization for absolute quality
- sft: Supervised Fine-Tuning loss to maintain generation quality

Claude: Critical: The DPO feature is non-functional — missing module xtuner.v1.rl.dpo.

This config imports DPOLossConfig from xtuner.v1.rl.dpo, and the trainer imports DPOLossContext, DPOColateItem, etc., but the module defining these classes is not included in this PR. Additionally:

  • Qwen3VLDPOTokenizeFnConfig is not defined anywhere
  • VLMPreferenceJsonlDataset is not defined anywhere
  • qwen3_vl_dpo_collator is not defined anywhere, and also not in the Literal type for DataloaderConfig.collator

The MPO/DPO feature cannot work as submitted.

Comment on lines +37 to +41
# ============================================================================
# 路径配置 (Path Configuration)
# ============================================================================
ceph_config = "/mnt/shared-storage-user/lisongze/iv3/xtuner/config/petreloss.conf"
meta_data_path = "/mnt/shared-storage-user/lisongze/iv3/xtuner/dpo_config/MMPR.json"

Claude: Critical: Hardcoded internal paths in example config.

Paths like /mnt/shared-storage-user/lisongze/iv3/xtuner/... are developer-specific. Example configs should use environment variables (e.g., os.environ["MODEL_PATH"]) or clearly documented placeholders. Same issue in mpo_qwen3vl_8B_video.py and the shell scripts.

Comment on lines +85 to +89
if sequence_parallel_mesh is not None and sequence_parallel_mesh.size() > 1:
input_ids_list = [torch.empty_like(input_ids) for _ in range(sequence_parallel_mesh.size())]
dist.all_gather(input_ids_list, input_ids, group=sequence_parallel_mesh.get_group())
input_ids = torch.cat(input_ids_list, dim=1) # type: ignore


Claude: Warning: Silently swallowing RuntimeError in forward pass masks real bugs.

Catching RuntimeError and continuing with visual_embeds.sum() * 0.0 is dangerous — it hides real bugs (dimension mismatches, OOM, etc.) and produces silently incorrect training. The visual_embeds.sum() * 0.0 line is also duplicated. Additionally, the fallback path below that creates dummy pixel_values with torch.zeros and runs the full vision pipeline just to add 0.0 is wasteful.
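A sketch of the narrower alternative: validate the inputs up front and let genuine runtime errors propagate, instead of zeroing out the visual contribution (function and argument names are assumptions, not the module's API):

```python
import torch

def scatter_visual_embeds(text_embeds, visual_embeds, image_mask):
    """Fail fast on shape mismatch rather than silently training without vision."""
    n_image_tokens = int(image_mask.sum())
    if n_image_tokens != visual_embeds.shape[0]:
        raise ValueError(
            f"image token count {n_image_tokens} does not match "
            f"visual embeds {visual_embeds.shape[0]}"
        )
    text_embeds = text_embeds.clone()
    text_embeds[image_mask] = visual_embeds   # place visual embeds at image positions
    return text_embeds
```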

@claude

claude bot commented Mar 23, 2026

Claude: Review Summary — PR #1609 "feat: support Mixed Preference Optimization (MPO)"

This PR adds MPO/DPO support (new trainer, loss, configs, qwen3_vl_mla module) but also performs extensive refactoring across 110 files (~5400 additions, ~4100 deletions) that removes existing features and introduces several bugs. As the maintainer already noted, many recent XTuner updates appear to have been reverted.

Critical Issues

1. Leaked AWS Credentials (examples/v1/scripts/run_rl.sh)

Hardcoded AWS access key ID and secret access key in the script. These must be revoked immediately. Even after removal from the PR, they persist in git history.

2. DPO Feature is Non-Functional — Missing Core Module

The config imports DPOLossConfig from xtuner.v1.rl.dpo, but this module is not included in the PR. Also missing: Qwen3VLDPOTokenizeFnConfig, VLMPreferenceJsonlDataset, qwen3_vl_dpo_collator, DPOColateItem. The collator string qwen3_vl_dpo_collator also fails validation against the Literal type that this same PR tightens in DataloaderConfig.

3. MLA Module Corruption — Breaks All Existing MLA Models

Changes to the shared MultiLatentAttention module break all existing MLA models:

  • kv_a_layernorm removed globally (breaks pretrained checkpoint loading)
  • kv_b_proj output dimension formula changed (breaks weight shapes)
  • self.compressed_kv referenced but never defined (AttributeError at runtime)
  • q_proj bias hardcoded to True, MLAConfig.qkv_bias default changed

These should be gated behind config flags, not changed globally.

4. Hardcoded Internal Paths Replace Env-Var-Based Paths

  • examples/v1/config/rl_qwen3_vl_8B_grpo.py: env vars (os.environ) replaced with /mnt/shared-storage-user/... paths
  • examples/v1/scripts/run_rl.sh: CLI args replaced with hardcoded paths
  • All new MPO configs and scripts contain developer-specific paths

5. Trailing Commas Create Tuples in Qwen3Dense8BMLAConfig

kv_lora_rank: int = 896,   # This is (896,), a tuple!
dropout: float = 0.0,       # This is (0.0,), a tuple!

These will cause type errors at runtime.

6. Reverted Changes (Non-Exhaustive)

  • chunk_loss.py: Reverted from torch.autograd.grad back to torch.func.grad_and_value (will fail with frozen LM head — the exact DPO use case)
  • flash_attn/__init__.py: Reverted FlashAttention 3 compat layer (only flash_attn_3_cuda, not flash_attn_3_gpu)
  • flash_attn/gpu.py: Removed configurable XTUNER_SM_MARGIN env var
  • model/base.py: Removed _mark_dynamic (causes torch.compile recompilation) and missing_keys guard in _load_fused_hf_param
  • collator.py: Removed pad chunking (split_to_chunks) — changes flash attention masking behavior
  • data_proto/rl_data.py: f-string dropped, log message now prints literal {is_skipped} text
  • utils/convert_geometry3k.py: Introduced typo .repalce() (will crash)

7. Removed Features (Breaking API Changes)

  • MuonConfig optimizer removed entirely
  • Tool-calling support removed (tool role, tool_calls, tool_formatter, etc.)
  • GatedDeltaNet, zero_centered_rms_norm, Qwen3_5_VLMoE35BA3Config removed
  • proxy_attn_flops, CustomPackDataset, LogProbConfig/Context removed
  • Replay buffer async sampling architecture replaced

Warnings

  • modeling_qwen3_vl.py (qwen3_vl_mla): Silently swallows RuntimeError in forward pass, masking real bugs
  • datasets/config.py: Removed if dset.num_tokens is not None guard — AttributeError when num_tokens is None
  • qwen3_vl_sft_collator: Missing drop_from == 0 handling — crash on oversized samples
  • qwen3_vl_tokenize_fn.py: Default rand_video_max_frames changed from 24 to 512 (21x increase, potential OOM)
  • fsdp.py: Default torch_compile changed from True to False
  • Debug print() statements left in production code (model/base.py, datasets/jsonl.py)

Verdict: REQUEST_CHANGES

Recommendation: Split this PR into smaller, focused PRs:

  1. Infrastructure cleanup (if intentional): proxy_attn_flops removal, tool-calling removal, etc. — each with justification and migration notes
  2. MLA variant: New qwen3_vl_mla module that does NOT modify the shared MultiLatentAttention class
  3. MPO/DPO feature: With all required class definitions included (DPOLossConfig, Qwen3VLDPOTokenizeFnConfig, VLMPreferenceJsonlDataset, qwen3_vl_dpo_collator)

Please also rebase onto the latest main to avoid reverting recent changes, and remove all hardcoded paths and credentials.
