From bfa1d7d91787a312fd8fab0c5dc149abcd78d289 Mon Sep 17 00:00:00 2001
From: LarryLeeee
Date: Mon, 23 Mar 2026 15:30:04 +0800
Subject: [PATCH 1/2] feat: support Mixed Preference Optimization (MPO)

---
 examples/v1/config/mpo_qwen3_vl_8B.py         |  193 +++
 examples/v1/scripts/run_mpo_qwen3_vl.sh       |   24 +
 xtuner/v1/datasets/__init__.py                |   26 +
 xtuner/v1/datasets/config.py                  |   17 +-
 xtuner/v1/datasets/dpo_collator.py            |  248 ++++
 xtuner/v1/datasets/jsonl.py                   |   12 +-
 xtuner/v1/datasets/preference_dataset.py      |  648 ++++++++++
 xtuner/v1/datasets/qwen3vl_vision_process.py  |  573 +++++++++
 xtuner/v1/engine/train_engine.py              |   19 +-
 .../v1/engine/vision_compose_train_engine.py  |  228 ++++
 xtuner/v1/loss/base_loss_ctx.py               |    5 +-
 xtuner/v1/rl/dpo/__init__.py                  |    9 +
 xtuner/v1/rl/dpo/loss.py                      |  629 ++++++++++
 xtuner/v1/rl/utils.py                         |   12 +
 xtuner/v1/train/cli/dpo.py                    |   83 ++
 xtuner/v1/train/dpo_trainer.py                | 1059 +++++++++++++++++
 16 files changed, 3780 insertions(+), 5 deletions(-)
 create mode 100755 examples/v1/config/mpo_qwen3_vl_8B.py
 create mode 100755 examples/v1/scripts/run_mpo_qwen3_vl.sh
 create mode 100644 xtuner/v1/datasets/dpo_collator.py
 create mode 100644 xtuner/v1/datasets/preference_dataset.py
 create mode 100644 xtuner/v1/datasets/qwen3vl_vision_process.py
 create mode 100644 xtuner/v1/engine/vision_compose_train_engine.py
 create mode 100644 xtuner/v1/rl/dpo/__init__.py
 create mode 100644 xtuner/v1/rl/dpo/loss.py
 create mode 100644 xtuner/v1/train/cli/dpo.py
 create mode 100644 xtuner/v1/train/dpo_trainer.py

diff --git a/examples/v1/config/mpo_qwen3_vl_8B.py b/examples/v1/config/mpo_qwen3_vl_8B.py
new file mode 100755
index 000000000..66a44f917
--- /dev/null
+++ b/examples/v1/config/mpo_qwen3_vl_8B.py
@@ -0,0 +1,193 @@
+"""
+DPO (Direct Preference Optimization) Configuration for Qwen3-VL-8B
+
+This configuration demonstrates how to use DPO/MPO for offline preference learning
+in the xtuner v1 framework, following the same pattern as the RL configs.

+
+Supported loss types:
+- sigmoid: Standard DPO loss for preference learning
+- bco_pair: Binary Classifier Optimization for absolute quality
+- sft: Supervised Fine-Tuning loss to maintain generation quality
+
+For MPO (Mixed Preference Optimization), use:
+    loss_types=["sigmoid", "bco_pair", "sft"]
+    loss_weights=[0.8, 0.2, 1.0]
+
+Usage:
+    # Edit the path variables below before launching
+    # (this example hardcodes paths instead of reading environment variables):
+    #   ceph_config, meta_data_path, model_path,
+    #   work_dir, tokenizer_cache_dir
+
+    # Run with torchrun
+    torchrun --nproc_per_node=8 xtuner/v1/train/cli/dpo.py --config mpo_qwen3_vl_8B.py
+"""
+
+import json
+
+from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig
+from xtuner.v1.datasets import Qwen3VLDPOTokenizeFnConfig
+from xtuner.v1.datasets.config import DatasetConfig, DataloaderConfig
+from xtuner.v1.datasets.mllm_tokenize_fn import OSSLoaderConfig
+from xtuner.v1.model import Qwen3VLDense8BConfig
+from xtuner.v1.rl.dpo import DPOLossConfig
+from xtuner.v1.train.dpo_trainer import DPOTrainerConfig
+
+
+# ============================================================================
+# Path Configuration
+# ============================================================================
+ceph_config = "/mnt/shared-storage-user/lisongze/iv3/xtuner/dpo_config/petreloss.conf"
+meta_data_path = "/mnt/shared-storage-user/lisongze/iv3/xtuner/dpo_config/MMPR.json"
+model_path = "/mnt/shared-storage-user/lisongze/iv3/xtuner_no/Qwen3-VL-8B-Instruct"
+work_dir = "/mnt/shared-storage-user/iv3/mpo/xtuner_saved_model/qwen3vl-8B-mpo-sp-mmpr-fix-bug"
+tokenizer_cache_dir = "/mnt/shared-storage-user/iv3/mpo/xtuner_tokenizer_cache/qwen3vl-8B-mpo-sp-mmpr-fix-bug"
+
+# ============================================================================
+# Training Settings
+# ============================================================================
+# global_batch_size = num_gpus × per_device_batch_size × gradient_accumulation_steps
+# e.g. 8 GPUs × 1 per_device_batch_size × 4 gradient_accumulation_steps × 2 sp_size = 64
+total_epochs = 1
+global_batch_size = 64
+per_device_batch_size = 1
+gradient_accumulation_steps = 4
+max_length = 4096 * 2
+pack_max_length = 4096 * 2
+num_workers = 8
+save_interval = 5000
+log_interval = 1000
+
+# Learning rate settings
+lr = 5e-6  # Lower LR for DPO
+# Paper: cosine decay with minimum learning rate 0
+lr_min = 0
+# Paper: linear warmup for first 5% of total training steps
+warmup_ratio = 0.05
+weight_decay = 0.05
+
+# ============================================================================
+# 1. Model Configuration
+# ============================================================================
+# Optionally freeze vision/projector (freeze_vision=True, freeze_projector=True) to prevent catastrophic forgetting
+model_cfg = Qwen3VLDense8BConfig()
+
+# ============================================================================
+# 2. 
DPO Loss Configuration
+# ============================================================================
+# Option 1: Standard DPO (sigmoid only)
+# loss_cfg = DPOLossConfig(
+#     loss_types=["sigmoid"],
+#     loss_weights=[1.0],
+#     beta=0.1,
+# )
+
+# Option 2: MPO (Mixed Preference Optimization)
+# Combines DPO, BCO, and SFT losses
+loss_cfg = DPOLossConfig(
+    loss_types=["sigmoid", "bco_pair", "sft"],
+    loss_weights=[0.8, 0.2, 1.0],
+    beta=0.1,
+    label_smoothing=0.0,
+    reference_free=False,
+    use_average_log_prob=False,
+    mode="chunk",
+    chunk_size=512,
+    ignore_idx=-100,
+)
+
+# ============================================================================
+# 3. Dataset Configuration - load annotation JSONs one by one (see sft_internvl3.5_8B_config_tiny.py)
+# ============================================================================
+oss_loader_cfg = OSSLoaderConfig(backend_kwargs={"conf_path": ceph_config})
+ds_collections = json.loads(open(meta_data_path).read())
+dataset_config = []
+for name, _data in ds_collections.items():
+    _data_cfg = {
+        "dataset": DatasetConfig(
+            name=name,
+            anno_path=_data['annotation'],
+            media_root=_data.get('media_root', ''),
+            sample_ratio=_data.get('sample_ratio', 1.0),
+            class_name='VLMPreferenceJsonlDataset',  # use the preference dataset class
+            enable_sequential_sampler=True,
+            cache_tag='cache_tags_dpo_v1',
+            cache_dir=tokenizer_cache_dir,
+        ),
+        "tokenize_fn": Qwen3VLDPOTokenizeFnConfig(
+            processor_path=model_path,
+            max_length=max_length,
+            min_pixels=_data.get('min_pixels', None),
+            max_pixels=_data.get('max_pixels', None),
+            oss_loader_cfg=oss_loader_cfg,
+            prompt_key="prompt",
+            chosen_key="chosen",
+            rejected_key="rejected",
+            images_key="images",
+            add_eos_token=True,  # field name defined by the parent config
+            system_message=_data.get('system_message', None),
+            hash=_data.get('hash', None),
+        ),
+    }
+    dataset_config.append(_data_cfg)
+
+dataloader_config = DataloaderConfig(
+    dataset_config_list=dataset_config,
+    pack_max_length=pack_max_length,
+    pack_to_max_length=True,  # must be True when sp_size > 1
+    pack_level="none",  # DPO does not pack samples; each sample is processed independently
+    collator="qwen3_vl_dpo_collator",  # use the DPO collator
+    num_workers=num_workers,
+    group_by_length=False,  # must be False when pack_level="none"
+)
+
+# ============================================================================
+# 4. Optimizer and Learning Rate
+# ============================================================================
+optim_cfg = AdamWConfig(lr=lr, weight_decay=weight_decay, foreach=False)
+lr_cfg = LRConfig(lr_type="cosine", warmup_ratio=warmup_ratio, lr_min=lr_min)
+
+# ============================================================================
+# 5. FSDP Configuration (memory-optimized)
+# ============================================================================
+fsdp_cfg = FSDPConfig(
+    # Gradient checkpointing: 1.0 = fully enabled; trades compute for memory (the biggest memory saver)
+    recompute_ratio=1.0,
+    # Also enable gradient checkpointing for the vision module
+    vision_recompute_ratio=1.0,
+    # Reshard parameters after the forward pass to save memory
+    reshard_after_forward=True,
+    # Skip saving RNG state: saves a little memory; affects only exact reproducibility, not training quality
+    checkpoint_preserve_rng_state=False,
+    # CPU offload: move optimizer state to CPU (slower, but saves GPU memory)
+    # cpu_offload=True,
+    # Disable torch.compile: VLM dynamic shapes are a poor fit for it, and compilation costs extra memory
+    torch_compile=False,
+)
+# ============================================================================
+# 6. 
DPO Trainer Configuration (export as 'trainer' for CLI compatibility)
+# ============================================================================
+trainer = DPOTrainerConfig(
+    model_cfg=model_cfg,
+    optim_cfg=optim_cfg,
+    loss_cfg=loss_cfg,
+    lr_cfg=lr_cfg,
+    fsdp_cfg=fsdp_cfg,
+    load_from=model_path,
+    ref_load_from=None,  # use the same weights for the reference model
+    tokenizer_path=model_path,
+    work_dir=work_dir,
+    sp_size=1,
+    total_epochs=total_epochs,
+    global_batch_size=global_batch_size,
+    per_device_batch_size=per_device_batch_size,
+    gradient_accumulation_steps=gradient_accumulation_steps,
+    max_length=max_length,
+    save_interval=save_interval,
+    log_interval=log_interval,
+    seed=42,
+    freeze_ref_model=True,
+    use_vlm_collator=True,
+    num_workers=num_workers,
+    dataloader_cfg=dataloader_config,
+)
\ No newline at end of file
diff --git a/examples/v1/scripts/run_mpo_qwen3_vl.sh b/examples/v1/scripts/run_mpo_qwen3_vl.sh
new file mode 100755
index 000000000..c0cbd0c33
--- /dev/null
+++ b/examples/v1/scripts/run_mpo_qwen3_vl.sh
@@ -0,0 +1,24 @@
+set -ex
+export XTUNER_PACK_WORKERS=8
+export XTUNER_TOKENIZE_WORKERS=16
+export NCCL_TIMEOUT=10800
+export TORCH_DISTRIBUTED_TIMEOUT=10800
+export XTUNER_USE_FA3=1
+export PYTHONPATH="$(pwd)"
+export HF_HOME="$(pwd)/"
+export TORCHDYNAMO_VERBOSE=1
+
+MASTER_PORT=${MASTER_PORT:-20500}
+config_file="examples/v1/config/mpo_qwen3_vl_8B.py"
+NODE_COUNT=${NODE_COUNT:-1}  # single-node defaults; override via env for multi-node runs
+NODE_RANK=${NODE_RANK:-0}
+MASTER_ADDR=${MASTER_ADDR:-127.0.0.1}
+PROC_PER_NODE=${PROC_PER_NODE:-8}
+
+torchrun \
+    --nnodes=$NODE_COUNT \
+    --node_rank=$NODE_RANK \
+    --master_addr=$MASTER_ADDR \
+    --master_port=$MASTER_PORT \
+    --nproc_per_node=$PROC_PER_NODE \
+    xtuner/v1/train/cli/dpo.py --config ${config_file}
diff --git a/xtuner/v1/datasets/__init__.py b/xtuner/v1/datasets/__init__.py
index 569620e62..aaa3485d8 100644
--- a/xtuner/v1/datasets/__init__.py
+++ b/xtuner/v1/datasets/__init__.py
@@ -8,6 +8,10 @@
     DatasetConfigList,
     DatasetConfigListAdatper,
 )
+from .dpo_collator import (
+    DPOColateItem,
+    qwen3_vl_dpo_collator
+)
 from .custom_pack import CustomPackDataset
 from .custom_sampler import CustomSampler
 from .ftdp import FTDPTokenizeFnConfig, FtdpTokenizeFunction
@@ -19,6 +23,15 @@
     Qwen3VLTokenizeFunction,
 )
 from .packing import ExpandSoftPackDataset, HardPackDataset, MLLMPretrainHybridPackDataset, _LegacySoftPackDataset
+from .preference_dataset import (
+    InMemoryPreferenceDataset,
+    PreferenceDataItem,
+    PreferenceJsonlDataset,
+    PreferenceTokenizeFunction,
+    VLMPreferenceJsonlDataset,
+    Qwen3VLDPOTokenizeFnConfig,
+    Qwen3VLDPOTokenizeFunction,
+)
 from .pt_tokenize_fn import (
     LongTextPretrainTokenizeFunction,
     LongTextPretrainTokenizeFunctionConfig,
@@ -31,6 +44,7 @@
 from .sft_tokenize_fn import OpenaiTokenizeFunction, OpenaiTokenizeFunctionConfig
 from .utils import CachableTokenizeFunction, calculate_file_sha256, calculate_xxhash, tokenizer_hash
 from .vlm_jsonl import VLMJsonlDataset
+from .qwen3vl_vision_process import process_vision_info
 from . 
import _hardcode_patch # isort: skip @@ -78,4 +92,16 @@ "DatasetConfig", "OpenaiTokenizeFunctionConfig", "OpenaiTokenizeFunction", + # DPO collators (xtuner v1 style) + "DPOColateItem", + "qwen3_vl_dpo_collator", + # Preference datasets + "PreferenceDataItem", + "PreferenceTokenizeFunction", + "PreferenceJsonlDataset", + "VLMPreferenceJsonlDataset", + "InMemoryPreferenceDataset", + "Qwen3VLDPOTokenizeFnConfig", + "Qwen3VLDPOTokenizeFunction", + "process_vision_info" ] diff --git a/xtuner/v1/datasets/config.py b/xtuner/v1/datasets/config.py index 20047e8a8..bd75a74f9 100644 --- a/xtuner/v1/datasets/config.py +++ b/xtuner/v1/datasets/config.py @@ -25,6 +25,7 @@ qwen3_vl_sft_collator, sft_llm_collator, ) +from .dpo_collator import qwen3_vl_dpo_collator from .custom_pack import CustomPackDataset from .custom_sampler import CustomSampler from .dataloader import BaseDataloader, Dataloader @@ -33,6 +34,7 @@ from .sampler import LengthGroupedSampler, ParallelSampler from .utils import CachableTokenizeFunction, tokenizer_xxhash from .vlm_jsonl import VLMJsonlDataset +from .preference_dataset import VLMPreferenceJsonlDataset logger = get_logger() @@ -77,6 +79,17 @@ def build( cache_dir=self.cache_dir, cache_tag=self.cache_tag, ) + elif self.class_name == "VLMPreferenceJsonlDataset": + return VLMPreferenceJsonlDataset( + tokenize_fn=tokenize_fn, + anno_path=self.anno_path, + sample_ratio=self.sample_ratio, + enable_sequential_sampler=self.enable_sequential_sampler, + name=self.name, + media_root=self.media_root, + cache_dir=self.cache_dir, + cache_tag=self.cache_tag, + ) else: raise ValueError(f"Unsupported class_name: {self.class_name}") @@ -276,7 +289,7 @@ class DataloaderConfig(BaseDataloaderConfig): dataset_config_list: DatasetConfigList | None = None collator: Annotated[ - Literal["sft_llm_collator", "intern_s1_vl_sft_collator", "qwen3_vl_sft_collator", "fake_collator"] | str, + Literal["sft_llm_collator", "intern_s1_vl_sft_collator", "qwen3_vl_sft_collator", "fake_collator", "qwen3_vl_dpo_collator"] | str, Parameter(help="collator func name"), ] = "sft_llm_collator" pack_to_max_length: Annotated[bool, Parameter(help="whether to pack to max length")] = True @@ -312,6 +325,8 @@ def build_collator(self): return qwen3_vl_sft_collator elif self.collator == "fake_collator": return fake_collator # for RL + elif self.collator == "qwen3_vl_dpo_collator": + return qwen3_vl_dpo_collator else: collator = pydoc.locate(self.collator) if collator is None: diff --git a/xtuner/v1/datasets/dpo_collator.py b/xtuner/v1/datasets/dpo_collator.py new file mode 100644 index 000000000..6e13cb7a2 --- /dev/null +++ b/xtuner/v1/datasets/dpo_collator.py @@ -0,0 +1,248 @@ +# Copyright (c) OpenMMLab. All rights reserved. +""" +Data collator for DPO (Direct Preference Optimization) training. + +This module provides collators for handling preference data (chosen/rejected pairs) +in the xtuner v1 framework. It follows the existing xtuner v1 collator design +pattern, outputting SequenceContext and shifted_labels. 
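
+
+A minimal usage sketch (values are illustrative; the collator name and the
+returned ``DPOColateItem`` fields are the ones defined in this module):
+
+    from functools import partial
+
+    collate_fn = partial(qwen3_vl_dpo_collator, pack_max_length=8192, padding_token_idx=0)
+    items = collate_fn(instances)  # -> list[DPOColateItem], one per input instance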
+""" + +import torch +from typing import Any +from typing_extensions import TypedDict + +from xtuner.v1.data_proto import SequenceContext +from xtuner.v1.utils import IGNORE_INDEX, get_logger +from xtuner.v1.utils.pad import pad_to_max_length + + +logger = get_logger() + + +class DPOColateItem(TypedDict): + """Output of DPO collator, containing both chosen and rejected sequences.""" + # Chosen sequence + chosen_seq_ctx: SequenceContext + chosen_shifted_labels: torch.Tensor + # Rejected sequence + rejected_seq_ctx: SequenceContext + rejected_shifted_labels: torch.Tensor + # Optional: precomputed reference log probabilities + ref_chosen_logps: torch.Tensor | None + ref_rejected_logps: torch.Tensor | None + + +def qwen3_vl_dpo_collator( + instances: list[list[dict]], + pack_max_length: int, + padding_token_idx: int, + pack_to_max_length: bool = True, +) -> list[DPOColateItem]: + """ + Collate DPO preference data for Qwen3-VL models. + + This function handles the additional visual features (pixel_values, + image_grid_thw, position_ids) required by Qwen3-VL models. + + Args: + instances: List of batched data items with visual features. + pack_max_length: Maximum sequence length for packing. + padding_token_idx: Token ID to use for padding. + pack_to_max_length: Whether to pad sequences to max_length. + + Returns: + List of DPOColateItem with visual features included in SequenceContext. + + Example: + >>> from functools import partial + >>> collator = partial(qwen3_vl_dpo_collator, pack_max_length=4096, padding_token_idx=0) + >>> # Use with dataloader + """ + ret: list[DPOColateItem] = [] + + for instance in instances: + if isinstance(instance, dict): + instance = [instance] + + # Process chosen with visual features + chosen_result = _process_qwen3_vl_sequence( + instance, "chosen", pack_max_length, padding_token_idx, pack_to_max_length + ) + + # Process rejected with visual features + rejected_result = _process_qwen3_vl_sequence( + instance, "rejected", pack_max_length, padding_token_idx, pack_to_max_length + ) + + # Handle precomputed ref logps + ref_chosen_logps = None + ref_rejected_logps = None + if "ref_chosen_logps" in instance[0]: + ref_chosen_logps = torch.tensor([instance[0]["ref_chosen_logps"]]) + if "ref_rejected_logps" in instance[0]: + ref_rejected_logps = torch.tensor([instance[0]["ref_rejected_logps"]]) + + ret.append({ + "chosen_seq_ctx": chosen_result["seq_ctx"], + "chosen_shifted_labels": chosen_result["shifted_labels"], + "rejected_seq_ctx": rejected_result["seq_ctx"], + "rejected_shifted_labels": rejected_result["shifted_labels"], + "ref_chosen_logps": ref_chosen_logps, + "ref_rejected_logps": ref_rejected_logps, + }) + + return ret + + +def _process_qwen3_vl_sequence( + instance: list[dict], + seq_type: str, # "chosen" or "rejected" + pack_max_length: int, + padding_token_idx: int, + pack_to_max_length: bool = True, +) -> dict[str, Any]: + """ + Process a single Qwen3-VL sequence for DPO training. + + This handles the visual features specific to Qwen3-VL models. + + Args: + instance: List of data items. + seq_type: "chosen" or "rejected". + pack_max_length: Maximum sequence length. + padding_token_idx: Padding token ID. + pack_to_max_length: Whether to pad to max length. + + Returns: + Dictionary with seq_ctx (including visual features) and shifted_labels. 

+    """
+    # Extract data for the specified sequence type
+    data = []
+    for item in instance:
+        input_ids_key = f"{seq_type}_input_ids"
+        labels_key = f"{seq_type}_labels"
+
+        input_ids = item.get(input_ids_key, item.get("input_ids", []))
+        labels = item.get(labels_key, item.get("labels", []))
+
+        data_item = {
+            "input_ids": input_ids,
+            "labels": labels,
+            "num_tokens": len(input_ids),
+        }
+
+        # Copy visual features (shared between chosen and rejected)
+        if "pixel_values" in item:
+            data_item["pixel_values"] = item["pixel_values"]
+        if "image_grid_thw" in item:
+            data_item["image_grid_thw"] = item["image_grid_thw"]
+        if "num_img_tokens" in item:
+            data_item["num_img_tokens"] = item["num_img_tokens"]
+        if "position_ids" in item:
+            data_item["position_ids"] = item.get(f"{seq_type}_position_ids", item["position_ids"])
+
+        data.append(data_item)
+
+    # Calculate total tokens and truncate if necessary
+    total_num_tokens = sum(d.get("num_tokens", len(d.get("input_ids", []))) for d in data)
+
+    if total_num_tokens > pack_max_length:
+        logger.warning(
+            f"Found sample with {total_num_tokens} tokens > pack_max_length {pack_max_length}. Truncating."
+        )
+        data[0]["input_ids"] = data[0]["input_ids"][:pack_max_length]
+        data[0]["labels"] = data[0]["labels"][:pack_max_length]
+        data[0]["num_tokens"] = pack_max_length
+
+    # Concatenate input_ids and labels
+    input_ids = torch.cat([torch.tensor(d["input_ids"]).view(1, -1) for d in data], dim=-1)
+    labels = torch.cat([torch.tensor(d["labels"]).view(1, -1) for d in data], dim=-1)
+
+    # Handle position_ids for Qwen3-VL (3D: temporal, height, width)
+    all_position_ids_none = all(
+        "position_ids" not in d or d["position_ids"] is None for d in data
+    )
+    position_ids_list = []
+    if not all_position_ids_none:
+        for d in data:
+            if "position_ids" in d and d["position_ids"] is not None:
+                pos_ids = d["position_ids"]
+                if not isinstance(pos_ids, torch.Tensor):
+                    pos_ids = torch.tensor(pos_ids)
+                position_ids_list.append(pos_ids)
+            else:
+                # Create default 3D position_ids of shape (3, 1, seq_len) so they
+                # match processor-produced position_ids and can be concatenated on dim=-1
+                seq_len = len(d["input_ids"])
+                pos_ids = torch.arange(seq_len).view(1, 1, -1).expand(3, 1, -1)
+                position_ids_list.append(pos_ids)
+
+    # Shift for causal LM
+    input_ids = input_ids[:, :-1]
+    shifted_labels = labels[:, 1:]
+
+    position_ids: torch.Tensor | None = None
+    if len(position_ids_list) > 0:
+        position_ids = torch.cat(position_ids_list, dim=-1)
+        position_ids = position_ids[:, :, :-1]
+
+    # Calculate num_tokens
+    num_tokens = [d.get("num_tokens", len(d.get("input_ids", []))) for d in data]
+    if num_tokens[-1] == 1:
+        num_tokens = num_tokens[:-1]
+    else:
+        num_tokens[-1] -= 1
+
+    # Padding
+    if pack_to_max_length:
+        pad_len = pack_max_length - input_ids.shape[-1]
+    else:
+        pad_len = 0
+
+    if pad_len > 0:
+        input_ids = pad_to_max_length(input_ids, padding_token_idx, max_length=pack_max_length, dim=-1)
+        shifted_labels = pad_to_max_length(shifted_labels, IGNORE_INDEX, max_length=pack_max_length, dim=-1)
+        if position_ids is not None:
+            position_ids = pad_to_max_length(position_ids, 0, max_length=pack_max_length, dim=-1)
+        num_tokens = [0] + num_tokens + [pad_len]
+    elif pad_len < 0:
+        raise ValueError(f"Sample length {input_ids.shape[-1]} > pack_max_length {pack_max_length}")
+    else:
+        num_tokens = [0] + num_tokens
+
+    cu_seq_lens = torch.cumsum(torch.IntTensor(num_tokens), dim=0).int()
+
+    # Collect visual features
+    num_img_tokens: list[int] = []
+    for d in data:
+        num_img_tokens.extend(d.get("num_img_tokens", [0]))
+
+    pixel_values: torch.Tensor | 
None = None + pv_list = [d["pixel_values"] for d in data if "pixel_values" in d] + if pv_list: + if all(isinstance(pv, torch.Tensor) for pv in pv_list): + pixel_values = torch.cat(pv_list, dim=0) + + image_grid_thw: torch.Tensor | None = None + thw_list = [d["image_grid_thw"] for d in data if "image_grid_thw" in d] + if thw_list: + if all(isinstance(thw, torch.Tensor) for thw in thw_list): + image_grid_thw = torch.cat(thw_list, dim=0) + + # Create SequenceContext with visual features + seq_ctx = SequenceContext( + input_ids=input_ids, + cu_seq_lens_q=cu_seq_lens, + cu_seq_lens_k=cu_seq_lens, + max_length_q=max(num_tokens), + max_length_k=max(num_tokens), + num_padding=pad_len, + position_ids=position_ids, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + num_img_tokens=num_img_tokens if num_img_tokens else None, + ) + + return { + "seq_ctx": seq_ctx, + "shifted_labels": shifted_labels, + } diff --git a/xtuner/v1/datasets/jsonl.py b/xtuner/v1/datasets/jsonl.py index a7ed5a014..dcffa4270 100644 --- a/xtuner/v1/datasets/jsonl.py +++ b/xtuner/v1/datasets/jsonl.py @@ -624,7 +624,17 @@ def _tokenize_by_offset( ) -> dict: line = data.decode() tokenized: dict = tokenize_fn(json.loads(line)) # type: ignore[assignment] - res = {"num_tokens": tokenized["num_tokens"], "proxy_attn_flops": tokenized["proxy_attn_flops"]} + num_tokens = tokenized["num_tokens"] + # Some legacy tokenize fns in cache mode only return `num_tokens`. + # Fallback to token length as proxy_attn_flops to keep cache building compatible. + if "proxy_attn_flops" in tokenized: + proxy_attn_flops = tokenized["proxy_attn_flops"] + elif isinstance(num_tokens, list): + proxy_attn_flops = [float(nt) for nt in num_tokens] + else: + proxy_attn_flops = float(num_tokens) + + res = {"num_tokens": num_tokens, "proxy_attn_flops": proxy_attn_flops} if "chunks" in tokenized: res["chunks"] = tokenized["chunks"] return res diff --git a/xtuner/v1/datasets/preference_dataset.py b/xtuner/v1/datasets/preference_dataset.py new file mode 100644 index 000000000..4a1897609 --- /dev/null +++ b/xtuner/v1/datasets/preference_dataset.py @@ -0,0 +1,648 @@ +# Copyright (c) OpenMMLab. All rights reserved. +""" +Preference Dataset for DPO (Direct Preference Optimization) training. + +This module provides dataset classes for loading and processing preference data +(chosen/rejected pairs) in the xtuner v1 framework. +""" + +import json +import os +from pathlib import Path +from typing import Any, Callable, Literal, TypeVar + +import torch +from pydantic import BaseModel, ConfigDict +from torch.utils.data import Dataset + +from xtuner.v1.utils import get_logger + +from .data_item import DataItem, QwenVL3DataItem +from .jsonl import JsonlDataset +from .utils import CachableTokenizeFunction + + +logger = get_logger() +T = TypeVar("T") + + +class PreferenceDataItem(DataItem): + """Data item for preference data containing chosen and rejected responses.""" + chosen_input_ids: list[int] + chosen_labels: list[int] + rejected_input_ids: list[int] + rejected_labels: list[int] + # For VLM + pixel_values: torch.Tensor | None + image_grid_thw: torch.Tensor | None + position_ids: torch.Tensor | None + num_img_tokens: list[int] | None + + +class PreferenceTokenizeFunction(CachableTokenizeFunction[PreferenceDataItem]): + """ + Tokenize function for preference data. + + This function tokenizes the prompt, chosen, and rejected responses + separately, then combines them for DPO training. 
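
+
+    A minimal sketch of the expected round trip (keys follow this class's
+    defaults; ``tokenizer`` is any HuggingFace-style tokenizer):
+
+        fn = PreferenceTokenizeFunction(tokenizer, max_length=4096)
+        item = fn({"prompt": "Q: 1+1=", "chosen": "2", "rejected": "3"})
+        # item["chosen_input_ids"] == prompt_ids + chosen_ids (+ EOS), and
+        # item["chosen_labels"] masks the prompt positions with -100.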
+ """ + + def __init__( + self, + tokenizer, + max_length: int = 4096, + max_prompt_length: int | None = None, + prompt_key: str = "prompt", + chosen_key: str = "chosen", + rejected_key: str = "rejected", + add_eos: bool = True, + ): + """ + Initialize the preference tokenize function. + + Args: + tokenizer: The tokenizer to use. + max_length: Maximum total length (prompt + response). + max_prompt_length: Maximum prompt length. If None, no limit. + prompt_key: Key for prompt in the data dict. + chosen_key: Key for chosen response in the data dict. + rejected_key: Key for rejected response in the data dict. + add_eos: Whether to add EOS token to responses. + """ + self.tokenizer = tokenizer + self.max_length = max_length + self.max_prompt_length = max_prompt_length + self.prompt_key = prompt_key + self.chosen_key = chosen_key + self.rejected_key = rejected_key + self.add_eos = add_eos + self._state = "runtime" + + def set_state(self, state: str): + """Set the state of the tokenize function (cache or runtime).""" + self._state = state + + def hash(self) -> str: + """Return a hash for caching purposes.""" + import hashlib + config_str = f"{self.tokenizer.name_or_path}_{self.max_length}_{self.max_prompt_length}_{self.add_eos}" + return hashlib.md5(config_str.encode()).hexdigest()[:16] + + def __call__(self, data: dict[str, Any], **kwargs) -> PreferenceDataItem | dict: + """ + Tokenize the preference data. + + Args: + data: Dictionary containing prompt, chosen, and rejected. + + Returns: + PreferenceDataItem with tokenized sequences. + """ + prompt = data.get(self.prompt_key, "") + chosen = data.get(self.chosen_key, "") + rejected = data.get(self.rejected_key, "") + + # Tokenize prompt + prompt_ids = self.tokenizer(prompt, add_special_tokens=False)["input_ids"] + + # Truncate prompt if needed + if self.max_prompt_length is not None and len(prompt_ids) > self.max_prompt_length: + prompt_ids = prompt_ids[-self.max_prompt_length:] + + # Tokenize chosen and rejected + chosen_ids = self.tokenizer(chosen, add_special_tokens=False)["input_ids"] + rejected_ids = self.tokenizer(rejected, add_special_tokens=False)["input_ids"] + + # Add EOS token + if self.add_eos and self.tokenizer.eos_token_id is not None: + chosen_ids = chosen_ids + [self.tokenizer.eos_token_id] + rejected_ids = rejected_ids + [self.tokenizer.eos_token_id] + + # Calculate max completion length + max_completion_length = self.max_length - len(prompt_ids) + + # Truncate completions if needed + if len(chosen_ids) > max_completion_length: + chosen_ids = chosen_ids[:max_completion_length] + if len(rejected_ids) > max_completion_length: + rejected_ids = rejected_ids[:max_completion_length] + + # Build full sequences + # For DPO: prompt tokens are masked in labels (-100) + chosen_input_ids = prompt_ids + chosen_ids + rejected_input_ids = prompt_ids + rejected_ids + + # Labels: mask prompt tokens + chosen_labels = [-100] * len(prompt_ids) + chosen_ids + rejected_labels = [-100] * len(prompt_ids) + rejected_ids + + # Calculate num_tokens for packing + num_tokens = max(len(chosen_input_ids), len(rejected_input_ids)) + + if self._state == "cache": + return {"num_tokens": num_tokens} + + return { + "input_ids": chosen_input_ids, # For compatibility with SFT collator + "labels": chosen_labels, + "num_tokens": num_tokens, + "chosen_input_ids": chosen_input_ids, + "chosen_labels": chosen_labels, + "rejected_input_ids": rejected_input_ids, + "rejected_labels": rejected_labels, + } + + +class PreferenceJsonlDataset(JsonlDataset[PreferenceDataItem]): + 
""" + JSONL dataset for preference data. + + Expected JSONL format: + {"prompt": "...", "chosen": "...", "rejected": "..."} + + For VLM: + {"prompt": "...", "chosen": "...", "rejected": "...", "images": ["path/to/img.jpg"]} + """ + + def __init__( + self, + anno_path: str | Path, + tokenize_fn: PreferenceTokenizeFunction | CachableTokenizeFunction | None = None, + sample_ratio: float = 1.0, + name: str = "preference", + cache_dir: str | Path | None = None, + max_length: int | None = None, + cache_tag: str | None = None, + **kwargs, + ): + """ + Initialize the preference JSONL dataset. + + Args: + anno_path: Path to the JSONL file. + tokenize_fn: Tokenize function for processing data. + sample_ratio: Ratio of samples to use. + name: Dataset name for logging. + cache_dir: Directory for caching. + max_length: Maximum sequence length. + cache_tag: Tag for caching. + """ + super().__init__( + anno_path=anno_path, + tokenize_fn=tokenize_fn, + sample_ratio=sample_ratio, + name=name, + cache_dir=cache_dir, + max_length=max_length, + cache_tag=cache_tag, + **kwargs, + ) + + +class VLMPreferenceJsonlDataset(PreferenceJsonlDataset): + """ + JSONL dataset for VLM preference data with image support. + + Expected JSONL format: + { + "prompt": "...", + "chosen": "...", + "rejected": "...", + "images": ["path/to/img1.jpg", "path/to/img2.jpg"] + } + """ + + def __init__( + self, + *args, + media_root: str | None = "", + **kwargs, + ): + """ + Initialize the VLM preference dataset. + + Args: + media_root: Root directory for media files. + """ + from functools import wraps + + if media_root is None: + media_root = "" + self.media_root = media_root + + # IMPORTANT: Wrap tokenize_fn BEFORE calling super().__init__() + # because count_tokens() is called during super().__init__() and needs media_root + tokenize_fn = kwargs.get('tokenize_fn') + if tokenize_fn is not None: + original_tokenize_fn = tokenize_fn + _media_root = media_root # Capture in closure + + @wraps(original_tokenize_fn) + def tokenize_fn_with_media_root(data, **fn_kwargs): + # Always inject media_root if not provided + if 'media_root' not in fn_kwargs: + fn_kwargs['media_root'] = _media_root + return original_tokenize_fn(data, **fn_kwargs) + + # Preserve the original methods for CachableTokenizeFunction + tokenize_fn_with_media_root.set_state = original_tokenize_fn.set_state + tokenize_fn_with_media_root.hash = original_tokenize_fn.hash + kwargs['tokenize_fn'] = tokenize_fn_with_media_root + + super().__init__(*args, **kwargs) + + # Fake data for error handling + self.fake_data = { + "id": -1, + "prompt": "你好", + "chosen": "你好呀!很高兴为你服务~", + "rejected": "抱歉,我不太明白。", + } + + def __getitem__(self, item) -> PreferenceDataItem | dict: + """Get an item from the dataset with error handling.""" + try: + with open(self.path) as f: + f.seek(self.offsets[item]) + line = f.readline() + + raw_data = json.loads(line) + + if self.tokenize_fn: + tokenized_data = self.tokenize_fn(raw_data, media_root=self.media_root) + return tokenized_data + else: + return raw_data + except Exception as e: + logger.warning(f"[{os.path.basename(self.path)}]: {e}. 
Falling back to a fake sample.")
+            data = self.tokenize_fn(self.fake_data)
+            assert isinstance(data, dict), f"Expected dict, got {type(data)}"
+            # Mask all labels for fake data
+            if "chosen_labels" in data:
+                data["chosen_labels"] = [-100] * len(data["chosen_input_ids"])
+            if "rejected_labels" in data:
+                data["rejected_labels"] = [-100] * len(data["rejected_input_ids"])
+            return data
+
+
+class InMemoryPreferenceDataset(Dataset):
+    """
+    In-memory preference dataset for smaller datasets.
+
+    This dataset loads all data into memory for faster access during training.
+    """
+
+    def __init__(
+        self,
+        data: list[dict[str, Any]],
+        tokenize_fn: Callable | None = None,
+    ):
+        """
+        Initialize the in-memory dataset.
+
+        Args:
+            data: List of preference data dictionaries.
+            tokenize_fn: Optional tokenize function.
+        """
+        self.data = data
+        self.tokenize_fn = tokenize_fn
+
+        # Pre-tokenize if a function is provided
+        if tokenize_fn is not None:
+            self.tokenized_data = [tokenize_fn(d) for d in data]
+            self.num_tokens = [d.get("num_tokens", 0) for d in self.tokenized_data]
+        else:
+            self.tokenized_data = None
+            self.num_tokens = None
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+    def __getitem__(self, idx: int) -> dict:
+        if self.tokenized_data is not None:
+            return self.tokenized_data[idx]
+        return self.data[idx]
+
+    @classmethod
+    def from_jsonl(
+        cls,
+        path: str | Path,
+        tokenize_fn: Callable | None = None,
+    ) -> "InMemoryPreferenceDataset":
+        """
+        Create a dataset from a JSONL file.
+
+        Args:
+            path: Path to the JSONL file.
+            tokenize_fn: Optional tokenize function.
+
+        Returns:
+            InMemoryPreferenceDataset instance.
+        """
+        data = []
+        with open(path) as f:
+            for line in f:
+                data.append(json.loads(line))
+        return cls(data, tokenize_fn)
+
+    @classmethod
+    def from_hf_dataset(
+        cls,
+        dataset,
+        tokenize_fn: Callable | None = None,
+        prompt_key: str = "prompt",
+        chosen_key: str = "chosen",
+        rejected_key: str = "rejected",
+    ) -> "InMemoryPreferenceDataset":
+        """
+        Create a dataset from a HuggingFace dataset.
+
+        Args:
+            dataset: HuggingFace dataset object.
+            tokenize_fn: Optional tokenize function.
+            prompt_key: Key for prompt in the dataset.
+            chosen_key: Key for chosen response.
+            rejected_key: Key for rejected response.
+
+        Returns:
+            InMemoryPreferenceDataset instance.
+        """
+        data = []
+        for item in dataset:
+            data.append({
+                "prompt": item.get(prompt_key, ""),
+                "chosen": item.get(chosen_key, ""),
+                "rejected": item.get(rejected_key, ""),
+            })
+        return cls(data, tokenize_fn)
+
+# ============================================================================
+# DPO Tokenize Function - Inherits from Qwen3VLTokenizeFunction
+# ============================================================================
+
+from .mllm_tokenize_fn import OSSLoaderConfig
+from .mllm_tokenize_fn.qwen3_vl_tokenize_fn import (
+    Qwen3VLTokenizeFunction,
+    Qwen3VLTokenizeFnConfig,
+)
+
+
+class Qwen3VLDPOTokenizeFnConfig(Qwen3VLTokenizeFnConfig):
+    """
+    Configuration for the Qwen3-VL DPO tokenize function.
+
+    Inherits from Qwen3VLTokenizeFnConfig to reuse all of the VLM processing config.

+    """
+
+    # DPO specific keys
+    prompt_key: str = "question"
+    chosen_key: str = "chosen"
+    rejected_key: str = "rejected"
+    images_key: str = "image"
+
+    def build(
+        self, tokenizer, tokenizer_hash: str | None = None, anno_name: str = "", **kwargs
+    ) -> "Qwen3VLDPOTokenizeFunction":
+        return Qwen3VLDPOTokenizeFunction(
+            tokenizer=tokenizer,
+            processor_path=self.processor_path,
+            anno_name=anno_name,
+            prompt_key=self.prompt_key,
+            chosen_key=self.chosen_key,
+            rejected_key=self.rejected_key,
+            images_key=self.images_key,
+            min_pixels=self.min_pixels,
+            max_pixels=self.max_pixels,
+            oss_loader_cfg=self.oss_loader_cfg,
+            video_min_total_pixels=self.video_min_total_pixels,
+            video_max_total_pixels=self.video_max_total_pixels,
+            video_min_frames=self.video_min_frames,
+            video_max_frames=self.video_max_frames,
+            rand_video_max_frames=self.rand_video_max_frames,
+            fps=self.fps,
+            enable_3d_rope=self.enable_3d_rope,
+            add_vision_id=self.add_vision_id,
+            max_length=self.max_length,
+            system_message=self.system_message,
+            tokenizer_hash=tokenizer_hash,
+            hash=self.hash,
+            debug=self.debug,
+            oss_time_log_thr=self.oss_time_log_thr,
+            add_eos_token=self.add_eos_token,
+            add_bos_token=self.add_bos_token,
+        )
+
+
+class Qwen3VLDPOTokenizeFunction(Qwen3VLTokenizeFunction):
+    """
+    DPO tokenize function for Qwen3-VL.
+
+    Inherits from Qwen3VLTokenizeFunction to reuse all image/video processing logic.
+    Handles the DPO format: {question, chosen, rejected, image} -> two sequences.
+
+    Expected data format (MMPR style):
+    {
+        "question": [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "..."}]}],
+        "chosen": [{"role": "assistant", "content": [{"type": "text", "text": "..."}]}],
+        "rejected": [{"role": "assistant", "content": [{"type": "text", "text": "..."}]}],
+        "image": "path/to/image.jpg"
+    }
+    """
+
+    def __init__(
+        self,
+        tokenizer,
+        processor_path: str,
+        anno_name: str,
+        prompt_key: str = "question",
+        chosen_key: str = "chosen",
+        rejected_key: str = "rejected",
+        images_key: str = "image",
+        **kwargs,
+    ):
+        self.prompt_key = prompt_key
+        self.chosen_key = chosen_key
+        self.rejected_key = rejected_key
+        self.images_key = images_key
+        super().__init__(tokenizer, processor_path, anno_name, **kwargs)
+
+    def _extract_text_from_messages(self, messages) -> str:
+        """Extract text content from messages format."""
+        if isinstance(messages, str):
+            return messages
+        if isinstance(messages, list) and len(messages) > 0:
+            for msg in messages:
+                if isinstance(msg, dict):
+                    content = msg.get('content', '')
+                    if isinstance(content, str):
+                        return content
+                    elif isinstance(content, list):
+                        for item in content:
+                            if isinstance(item, dict) and item.get('type') == 'text':
+                                return item.get('text', '')
+        return str(messages)
+
+    def _convert_mmpr_to_xtuner_format(self, messages: list, images: list, image_wh_list: list | None = None) -> list:
+        """
+        Convert MMPR format messages to xtuner ChatMessages format.
+
+        MMPR format: {"type": "image"}, {"type": "text", "text": "..."}
+        xtuner format: {"type": "image_url", "image_url": {...}}, {"type": "text", "text": "..."}
+
+        The key difference is that xtuner expects an `<image>` placeholder in the
+        text content, while MMPR uses separate {"type": "image"} items.
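
+
+        A minimal sketch of the mapping (paths and sizes are illustrative):
+
+            in:  [{"type": "image"}, {"type": "text", "text": "Describe it"}]
+            out: [{"type": "image_url", "image_url": {"url": "a.jpg", "image_wh": [640, 480]}},
+                  {"type": "text", "text": "Describe it"}]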

+
+        Args:
+            messages: List of message dicts
+            images: List of image paths
+            image_wh_list: List of [width, height] for each image
+        """
+        import copy
+        converted = []
+        image_idx = 0
+
+        for msg in messages:
+            new_msg = copy.deepcopy(msg)
+            content = new_msg.get('content', [])
+
+            if isinstance(content, list):
+                new_content = []
+
+                for item in content:
+                    if isinstance(item, dict):
+                        if item.get('type') == 'image':
+                            # Convert MMPR image format to xtuner format.
+                            # MMPR may have only one {"type": "image"} but multiple images in the list:
+                            # add ALL remaining images when we see the first image placeholder
+                            while image_idx < len(images):
+                                img_path = images[image_idx]
+                                image_url_dict = {"url": img_path}
+                                # Add image_wh if available
+                                if image_wh_list is not None and image_idx < len(image_wh_list):
+                                    image_url_dict["image_wh"] = image_wh_list[image_idx]
+                                new_content.append({
+                                    "type": "image_url",
+                                    "image_url": image_url_dict
+                                })
+                                image_idx += 1
+                        elif item.get('type') == 'text':
+                            # Text is already preprocessed with <image> placeholders, just pass it through
+                            new_content.append({
+                                "type": "text",
+                                "text": item.get('text', '')
+                            })
+                        else:
+                            new_content.append(item)
+                    else:
+                        new_content.append(item)
+
+                new_msg['content'] = new_content
+            converted.append(new_msg)
+
+        return converted
+
+    def _build_dpo_messages(self, data_item: dict, response_messages: list) -> dict:
+        """Build messages format for DPO by combining prompt and response."""
+        prompt_messages = data_item.get(self.prompt_key, [])
+        images = data_item.get(self.images_key, [])
+        image_wh_list = data_item.get("image_wh", None)
+
+        # Normalize images to a list
+        if isinstance(images, str):
+            images = [images] if images else []
+        elif images is None:
+            images = []
+
+        # Ensure prompt_messages is a list
+        if not isinstance(prompt_messages, list):
+            prompt_messages = [{"role": "user", "content": [{"type": "text", "text": str(prompt_messages)}]}]
+
+        # Convert MMPR format to xtuner format (with image_wh)
+        prompt_messages = self._convert_mmpr_to_xtuner_format(prompt_messages, images, image_wh_list)
+        response_messages = self._convert_mmpr_to_xtuner_format(response_messages, [], None)
+
+        # Combine prompt and response
+        combined_messages = prompt_messages + response_messages
+
+        return {"messages": combined_messages}
+
+    def __call__(self, item: dict, media_root: str = "", **kwargs):
+        """
+        Process DPO data item.
+ + Returns dict with: + - chosen_input_ids, chosen_labels + - rejected_input_ids, rejected_labels + - pixel_values, image_grid_thw (shared between chosen and rejected) + - num_tokens + """ + from .data_item import CacheItem + + # Get chosen and rejected responses + chosen_raw = item.get(self.chosen_key, []) + rejected_raw = item.get(self.rejected_key, []) + + # Build combined messages for chosen and rejected + chosen_item = self._build_dpo_messages(item, chosen_raw) + rejected_item = self._build_dpo_messages(item, rejected_raw) + + # Use parent's __call__ to process (handles image extraction and processing) + try: + # Process chosen sequence (this also processes images) + chosen_result = super().__call__(chosen_item, media_root=media_root, **kwargs) + if self.state == "cache": + # For cache mode, also check rejected sequence length + rejected_result = super().__call__(rejected_item, media_root=media_root, **kwargs) + num_tokens = max( + chosen_result.get("num_tokens", 0) if isinstance(chosen_result, dict) else chosen_result.num_tokens, + rejected_result.get("num_tokens", 0) if isinstance(rejected_result, dict) else rejected_result.num_tokens, + ) + return {"num_tokens": num_tokens} + + # Process rejected sequence + rejected_result = super().__call__(rejected_item, media_root=media_root, **kwargs) + + # Normalize outputs: parent may return dict or DataItem-like object + def _as_dict(x: Any) -> dict: + if isinstance(x, dict): + return x + # DataItem / pydantic model / dataclass-like + if hasattr(x, "model_dump"): + return x.model_dump() + if hasattr(x, "__dict__"): + return dict(x.__dict__) + return {"value": x} + + chosen_d = _as_dict(chosen_result) + rejected_d = _as_dict(rejected_result) + + # Build DPO result + result = { + "input_ids": chosen_d["input_ids"], + "labels": chosen_d["labels"], + "num_tokens": max(len(chosen_d["input_ids"]), len(rejected_d["input_ids"])), + "chosen_input_ids": chosen_d["input_ids"], + "chosen_labels": chosen_d["labels"], + "rejected_input_ids": rejected_d["input_ids"], + "rejected_labels": rejected_d["labels"], + } + + # Add visual features (shared between chosen and rejected) + if chosen_d.get("pixel_values", None) is not None: + result["pixel_values"] = chosen_d["pixel_values"] + if chosen_d.get("image_grid_thw", None) is not None: + result["image_grid_thw"] = chosen_d["image_grid_thw"] + if chosen_d.get("position_ids", None) is not None: + result["chosen_position_ids"] = chosen_d["position_ids"] + if rejected_d.get("position_ids", None) is not None: + result["rejected_position_ids"] = rejected_d["position_ids"] + if chosen_d.get("num_img_tokens", None) is not None: + result["num_img_tokens"] = chosen_d["num_img_tokens"] + + return result + + except Exception as e: + logger.warning(f"Failed to process DPO item: {e}") + if self.state == "cache": + return {"num_tokens": 0} + raise diff --git a/xtuner/v1/datasets/qwen3vl_vision_process.py b/xtuner/v1/datasets/qwen3vl_vision_process.py new file mode 100644 index 000000000..08551fed0 --- /dev/null +++ b/xtuner/v1/datasets/qwen3vl_vision_process.py @@ -0,0 +1,573 @@ +import base64 +import copy +import logging +import math +import os +import sys +import time +import warnings +from functools import lru_cache +from io import BytesIO +from typing import Optional, Union, Tuple, List, Any, Dict +from concurrent.futures import ThreadPoolExecutor + +import requests +import torch +import torchvision +from packaging import version +from PIL import Image +import numpy as np +from torchvision import io, transforms +from 
torchvision.transforms import InterpolationMode + + +MAX_RATIO = 200 +SPATIAL_MERGE_SIZE = 2 +IMAGE_MIN_TOKEN_NUM = 4 +IMAGE_MAX_TOKEN_NUM = 16384 +VIDEO_MIN_TOKEN_NUM = 128 +VIDEO_MAX_TOKEN_NUM = 768 + +FPS = 2.0 +FRAME_FACTOR = 2 +FPS_MIN_FRAMES = 4 +FPS_MAX_FRAMES = 768 +MAX_NUM_WORKERS_FETCH_VIDEO = 8 + +MODEL_SEQ_LEN = int(float(os.environ.get('MODEL_SEQ_LEN', 128000))) +logger = logging.getLogger(__name__) + + +def round_by_factor(number: int, factor: int) -> int: + """Returns the closest integer to 'number' that is divisible by 'factor'.""" + return round(number / factor) * factor + + +def ceil_by_factor(number: int, factor: int) -> int: + """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" + return math.ceil(number / factor) * factor + + +def floor_by_factor(number: int, factor: int) -> int: + """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" + return math.floor(number / factor) * factor + + +def smart_resize(height: int, width: int, factor: int, min_pixels: Optional[int] = None, max_pixels: Optional[int] = None) -> Tuple[int, int]: + """ + Rescales the image so that the following conditions are met: + + 1. Both dimensions (height and width) are divisible by 'factor'. + 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. + 3. The aspect ratio of the image is maintained as closely as possible. + """ + max_pixels = max_pixels if max_pixels is not None else (IMAGE_MAX_TOKEN_NUM * factor ** 2) + min_pixels = min_pixels if min_pixels is not None else (IMAGE_MIN_TOKEN_NUM * factor ** 2) + assert max_pixels >= min_pixels, "The max_pixels of image must be greater than or equal to min_pixels." + if max(height, width) / min(height, width) > MAX_RATIO: + raise ValueError( + f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}" + ) + h_bar = max(factor, round_by_factor(height, factor)) + w_bar = max(factor, round_by_factor(width, factor)) + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = floor_by_factor(height / beta, factor) + w_bar = floor_by_factor(width / beta, factor) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = ceil_by_factor(height * beta, factor) + w_bar = ceil_by_factor(width * beta, factor) + return h_bar, w_bar + + +def to_rgb(pil_image: Image.Image) -> Image.Image: + if pil_image.mode == 'RGBA': + white_background = Image.new("RGB", pil_image.size, (255, 255, 255)) + white_background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha channel as mask + return white_background + else: + return pil_image.convert("RGB") + + +def _get_petrel_client(): + """Lazy initialization of petrel client for OSS/Ceph storage. + + Uses PETREL_CONF_PATH environment variable if set. 
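
+
+    For example (the path is illustrative):
+
+        export PETREL_CONF_PATH=/path/to/petreloss.conf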
+ """ + global _petrel_client + if '_petrel_client' not in globals() or _petrel_client is None: + try: + import os + from petrel_client.client import Client + conf_path = os.environ.get("PETREL_CONF_PATH") + if conf_path: + _petrel_client = Client(conf_path=conf_path) + logger.info(f"Initialized petrel client with config: {conf_path}") + else: + _petrel_client = Client() + logger.info("Initialized petrel client with default config") + except ImportError: + logger.warning("petrel_client not available, hceph: paths will not work") + _petrel_client = None + return _petrel_client + +_petrel_client = None + + +def fetch_image(ele: Dict[str, Union[str, Image.Image]], image_patch_size: int = 14) -> Image.Image: + if "image" in ele: + image = ele["image"] + else: + image = ele["image_url"] + + image_obj = None + patch_factor = int(image_patch_size * SPATIAL_MERGE_SIZE) + if isinstance(image, Image.Image): + image_obj = image + elif image.startswith("http://") or image.startswith("https://"): + with requests.get(image, stream=True) as response: + response.raise_for_status() + with BytesIO(response.content) as bio: + image_obj = copy.deepcopy(Image.open(bio)) + elif image.startswith("file://"): + image_obj = Image.open(image[7:]) + elif image.startswith("data:image"): + if "base64," in image: + _, base64_data = image.split("base64,", 1) + data = base64.b64decode(base64_data) + with BytesIO(data) as bio: + image_obj = copy.deepcopy(Image.open(bio)) + elif image.startswith("hceph:") or image.startswith("s3://"): + # Handle OSS/Ceph storage paths + client = _get_petrel_client() + if client is None: + raise ValueError(f"petrel_client not available, cannot load image from {image}") + # Remove hceph: prefix if present + # oss_path = image[6:] if image.startswith("hceph:") else image + oss_path = image + try: + img_bytes = client.get(oss_path) + with BytesIO(img_bytes) as bio: + image_obj = copy.deepcopy(Image.open(bio)) + except Exception as e: + raise ValueError(f"Failed to load image from OSS: {image}, error: {e}") + else: + image_obj = Image.open(image) + if image_obj is None: + raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}") + image = to_rgb(image_obj) + + ## resize + if "resized_height" in ele and "resized_width" in ele: + resized_height, resized_width = smart_resize( + ele["resized_height"], + ele["resized_width"], + factor=patch_factor, + ) + else: + width, height = image.size + min_pixels = ele.get("min_pixels", IMAGE_MIN_TOKEN_NUM * patch_factor ** 2) + max_pixels = ele.get("max_pixels", IMAGE_MAX_TOKEN_NUM * patch_factor ** 2) + resized_height, resized_width = smart_resize( + height, + width, + factor=patch_factor, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + image = image.resize((resized_width, resized_height)) + return image + + +def smart_nframes( + ele: Dict[str, Any], + total_frames: int, + video_fps: Union[int, float], +) -> int: + """calculate the number of frames for video used for model inputs. + + Args: + ele (dict): a dict contains the configuration of video. + support either `fps` or `nframes`: + - nframes: the number of frames to extract for model inputs. + - fps: the fps to extract frames for model inputs. + - min_frames: the minimum number of frames of the video, only used when fps is provided. + - max_frames: the maximum number of frames of the video, only used when fps is provided. + total_frames (int): the original total number of frames of the video. 

+        video_fps (int | float): the original fps of the video.
+
+    Raises:
+        ValueError: nframes should be in the interval [FRAME_FACTOR, total_frames].
+
+    Returns:
+        int: the number of frames for video used for model inputs.
+    """
+    assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`"
+    if "nframes" in ele:
+        nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
+    else:
+        fps = ele.get("fps", FPS)
+        min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
+        max_frames = floor_by_factor(ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR)
+        nframes = total_frames / video_fps * fps
+        if nframes > total_frames:
+            logger.warning(f"smart_nframes: nframes[{nframes}] > total_frames[{total_frames}]")
+        nframes = min(min(max(nframes, min_frames), max_frames), total_frames)
+        nframes = floor_by_factor(nframes, FRAME_FACTOR)
+    if not (FRAME_FACTOR <= nframes and nframes <= total_frames):
+        raise ValueError(f"nframes should be in the interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}.")
+    return nframes
+
+
+def _read_video_torchvision(
+    ele: Dict[str, Any],
+) -> Tuple[torch.Tensor, Dict[str, Any], float]:
+    """read video using torchvision.io.read_video
+
+    Args:
+        ele (dict): a dict contains the configuration of video.
+        support keys:
+            - video: the path of video. support "file://", "http://", "https://" and local path.
+            - video_start: the start time of video.
+            - video_end: the end time of video.
+    Returns:
+        torch.Tensor: the video tensor with shape (T, C, H, W).
+        dict: the video metadata (original fps, sampled frame indices, total frame count).
+        float: the effective sampling fps of the returned frames.
+    """
+    video_path = ele["video"]
+    if version.parse(torchvision.__version__) < version.parse("0.19.0"):
+        if "http://" in video_path or "https://" in video_path:
+            warnings.warn("torchvision < 0.19.0 does not support http/https video path, please upgrade to 0.19.0.")
+        if "file://" in video_path:
+            video_path = video_path[7:]
+    st = time.time()
+    video, audio, info = io.read_video(
+        video_path,
+        start_pts=ele.get("video_start", 0.0),
+        end_pts=ele.get("video_end", None),
+        pts_unit="sec",
+        output_format="TCHW",
+    )
+    total_frames, video_fps = video.size(0), info["video_fps"]
+    logger.info(f"torchvision: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
+    nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
+    idx = torch.linspace(0, total_frames - 1, nframes).round().long()
+    sample_fps = nframes / max(total_frames, 1e-6) * video_fps
+    video = video[idx]
+
+    video_metadata = dict(
+        fps=video_fps,
+        frames_indices=idx,
+        total_num_frames=total_frames,
+        video_backend="torchvision",
+    )
+    return video, video_metadata, sample_fps
+
+
+def is_decord_available() -> bool:
+    import importlib.util
+
+    return importlib.util.find_spec("decord") is not None
+
+
+def calculate_video_frame_range(
+    ele: Dict[str, Any],
+    total_frames: int,
+    video_fps: float,
+) -> Tuple[int, int, int]:
+    """
+    Calculate the start and end frame indices based on the given time range.
+
+    Args:
+        ele (dict): A dictionary containing optional 'video_start' and 'video_end' keys (in seconds).
+        total_frames (int): Total number of frames in the video.
+        video_fps (float): Frames per second of the video.
+
+    Returns:
+        tuple: A tuple containing (start_frame, end_frame, frame_count).
+
+    Raises:
+        ValueError: If input parameters are invalid or the time range is inconsistent.
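
+
+    Example (worked numbers, not from real data): for a 300-frame video at
+    30 fps with video_start=2.0 and video_end=5.0, start_frame = ceil(2.0 * 30) = 60,
+    end_frame = min(floor(5.0 * 30), 299) = 150, and frame_count = 91.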
+ """ + # Validate essential parameters + if video_fps <= 0: + raise ValueError("video_fps must be a positive number") + if total_frames <= 0: + raise ValueError("total_frames must be a positive integer") + + # Get start and end time in seconds + video_start = ele.get("video_start", None) + video_end = ele.get("video_end", None) + if video_start is None and video_end is None: + return 0, total_frames - 1, total_frames + + max_duration = total_frames / video_fps + # Process start frame + if video_start is not None: + video_start_clamped = max(0.0, min(video_start, max_duration)) + start_frame = math.ceil(video_start_clamped * video_fps) + else: + start_frame = 0 + # Process end frame + if video_end is not None: + video_end_clamped = max(0.0, min(video_end, max_duration)) + end_frame = math.floor(video_end_clamped * video_fps) + end_frame = min(end_frame, total_frames - 1) + else: + end_frame = total_frames - 1 + + # Validate frame order + if start_frame >= end_frame: + raise ValueError( + f"Invalid time range: Start frame {start_frame} (at {video_start_clamped if video_start is not None else 0}s) " + f"exceeds end frame {end_frame} (at {video_end_clamped if video_end is not None else max_duration}s). " + f"Video duration: {max_duration:.2f}s ({total_frames} frames @ {video_fps}fps)" + ) + + logger.info(f"calculate video frame range: {start_frame=}, {end_frame=}, {total_frames=} from {video_start=}, {video_end=}, {video_fps=:.3f}") + return start_frame, end_frame, end_frame - start_frame + 1 + + +def _read_video_decord( + ele: Dict[str, Any], +) -> Tuple[torch.Tensor, float]: + """read video using decord.VideoReader + + Args: + ele (dict): a dict contains the configuration of video. + support keys: + - video: the path of video. support "file://", "http://", "https://" and local path. + - video_start: the start time of video. + - video_end: the end time of video. + Returns: + torch.Tensor: the video tensor with shape (T, C, H, W). + """ + import decord + video_path = ele["video"] + st = time.time() + vr = decord.VideoReader(video_path) + total_frames, video_fps = len(vr), vr.get_avg_fps() + start_frame, end_frame, total_frames = calculate_video_frame_range( + ele, + total_frames, + video_fps, + ) + nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps) + idx = torch.linspace(start_frame, end_frame, nframes).round().long().tolist() + video = vr.get_batch(idx).asnumpy() + video = torch.tensor(video).permute(0, 3, 1, 2) # Convert to TCHW format + logger.info(f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s") + sample_fps = nframes / max(total_frames, 1e-6) * video_fps + + video_metadata = dict( + fps=video_fps, + frames_indices=idx, + total_num_frames=total_frames, + video_backend="decord", + ) + return video, video_metadata, sample_fps + + +def is_torchcodec_available() -> bool: + import importlib.util + + return importlib.util.find_spec("torchcodec") is not None + + +def _read_video_torchcodec( + ele: Dict[str, Any], +) -> Tuple[torch.Tensor, float]: + """read video using torchcodec.decoders.VideoDecoder + + Args: + ele (dict): a dict contains the configuration of video. + support keys: + - video: the path of video. support "file://", "http://", "https://" and local path. + - video_start: the start time of video. + - video_end: the end time of video. + Returns: + torch.Tensor: the video tensor with shape (T, C, H, W). 
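
+        dict: the video metadata (original fps, sampled frame indices, total frame count).
+        float: the effective sampling fps of the returned frames.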
+ """ + from torchcodec.decoders import VideoDecoder + TORCHCODEC_NUM_THREADS = int(os.environ.get('TORCHCODEC_NUM_THREADS', 8)) + logger.info(f"set TORCHCODEC_NUM_THREADS: {TORCHCODEC_NUM_THREADS}") + video_path = ele["video"] + st = time.time() + decoder = VideoDecoder(video_path, num_ffmpeg_threads=TORCHCODEC_NUM_THREADS) + video_fps = decoder.metadata.average_fps + total_frames = decoder.metadata.num_frames + start_frame, end_frame, total_frames = calculate_video_frame_range( + ele, + total_frames, + video_fps, + ) + nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps) + idx = torch.linspace(start_frame, end_frame, nframes).round().long().tolist() + sample_fps = nframes / max(total_frames, 1e-6) * video_fps + video = decoder.get_frames_at(indices=idx).data + logger.info(f"torchcodec: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s") + + video_metadata = dict( + fps=video_fps, + frames_indices=idx, + total_num_frames=total_frames, + video_backend="torchcodec", + ) + return video, video_metadata, sample_fps + + +VIDEO_READER_BACKENDS = { + "decord": _read_video_decord, + "torchvision": _read_video_torchvision, + "torchcodec": _read_video_torchcodec, +} + +FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None) + + +@lru_cache(maxsize=1) +def get_video_reader_backend() -> str: + if FORCE_QWENVL_VIDEO_READER is not None: + video_reader_backend = FORCE_QWENVL_VIDEO_READER + elif is_torchcodec_available(): + video_reader_backend = "torchcodec" + elif is_decord_available(): + video_reader_backend = "decord" + else: + video_reader_backend = "torchvision" + print(f"qwen-vl-utils using {video_reader_backend} to read video.", file=sys.stderr) + return video_reader_backend + + +def fetch_video(ele: Dict[str, Any], image_patch_size: int = 14, return_video_sample_fps: bool = False, + return_video_metadata: bool = False) -> Union[torch.Tensor, List[Image.Image]]: + image_factor = image_patch_size * SPATIAL_MERGE_SIZE + VIDEO_FRAME_MIN_PIXELS = VIDEO_MIN_TOKEN_NUM * image_factor * image_factor + VIDEO_FRAME_MAX_PIXELS = VIDEO_MAX_TOKEN_NUM * image_factor * image_factor + if isinstance(ele["video"], str): + video_reader_backend = get_video_reader_backend() + try: + video, video_metadata, sample_fps = VIDEO_READER_BACKENDS[video_reader_backend](ele) + except Exception as e: + logger.warning(f"video_reader_backend {video_reader_backend} error, use torchvision as default, msg: {e}") + video, video_metadata, sample_fps = VIDEO_READER_BACKENDS["torchvision"](ele) + else: + # The input is a list of frames + assert isinstance(ele["video"], (list, tuple)) + process_info = ele.copy() + process_info.pop("type", None) + process_info.pop("video", None) + # use ThreadPoolExecutor to parallel process frames + max_workers = min(MAX_NUM_WORKERS_FETCH_VIDEO, len(ele["video"])) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [ + executor.submit(fetch_image, {"image": video_element, **process_info}, image_factor) + for video_element in ele["video"] + ] + image_list = [future.result() for future in futures] + + nframes = ceil_by_factor(len(image_list), FRAME_FACTOR) + if len(image_list) < nframes: + image_list.extend([image_list[-1]] * (nframes - len(image_list))) + + sample_fps = ele.get("sample_fps", 2.0) + video = torch.stack([ + torch.from_numpy(np.array(image).transpose(2, 0, 1)) + for image in image_list + ]) + + # fake video metadata + raw_fps = process_info.pop("raw_fps", sample_fps) + video_metadata = dict( + 
fps=raw_fps, + frames_indices=[i for i in range(len(video))], + total_num_frames=(nframes / sample_fps) * raw_fps, + ) + + nframes, _, height, width = video.shape + min_pixels = ele.get("min_pixels", VIDEO_FRAME_MIN_PIXELS) + total_pixels = ele.get("total_pixels", MODEL_SEQ_LEN * image_factor * image_factor * 0.9) + max_pixels = max(min(VIDEO_FRAME_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR), int(min_pixels * 1.05)) + max_pixels_supposed = ele.get("max_pixels", max_pixels) + if max_pixels_supposed > max_pixels: + logger.warning(f"The given max_pixels[{max_pixels_supposed}] exceeds limit[{max_pixels}].") + max_pixels = min(max_pixels_supposed, max_pixels) + if "resized_height" in ele and "resized_width" in ele: + resized_height, resized_width = smart_resize( + ele["resized_height"], + ele["resized_width"], + factor=image_factor, + ) + else: + resized_height, resized_width = smart_resize( + height, + width, + factor=image_factor, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + video = transforms.functional.resize( + video, + [resized_height, resized_width], + interpolation=InterpolationMode.BICUBIC, + antialias=True, + ).float() + + final_video = (video, video_metadata) if return_video_metadata else video + if return_video_sample_fps: + return final_video, sample_fps + return final_video + + +def extract_vision_info(conversations: Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]) -> List[Dict[str, Any]]: + vision_infos = [] + if isinstance(conversations[0], dict): + conversations = [conversations] + for conversation in conversations: + for message in conversation: + if isinstance(message["content"], list): + for ele in message["content"]: + if ( + "image" in ele + or "image_url" in ele + or "video" in ele + or ele.get("type", "text") in ("image", "image_url", "video") + ): + vision_infos.append(ele) + return vision_infos + + +def process_vision_info( + conversations: Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]], + return_video_kwargs: bool = False, + return_video_metadata: bool = False, + image_patch_size: int = 14, +) -> Tuple[Optional[List[Image.Image]], Optional[List[Union[torch.Tensor, List[Image.Image]]]], Optional[Dict[str, Any]]]: + + vision_infos = extract_vision_info(conversations) + ## Read images or videos + image_inputs = [] + video_inputs = [] + video_sample_fps_list = [] + for vision_info in vision_infos: + if "image" in vision_info or "image_url" in vision_info: + image_inputs.append(fetch_image(vision_info, image_patch_size=image_patch_size)) + elif "video" in vision_info: + video_input, video_sample_fps = fetch_video(vision_info, return_video_sample_fps=True, + image_patch_size=image_patch_size, return_video_metadata=return_video_metadata) + video_sample_fps_list.append(video_sample_fps) + video_inputs.append(video_input) + else: + raise ValueError("image, image_url or video should in content.") + if len(image_inputs) == 0: + image_inputs = None + if len(video_inputs) == 0: + video_inputs = None + + video_kwargs = {'do_sample_frames': False} + if not return_video_metadata: # BC for qwen2.5vl + video_kwargs.update({'fps': video_sample_fps_list}) + + if return_video_kwargs: + return image_inputs, video_inputs, video_kwargs + return image_inputs, video_inputs \ No newline at end of file diff --git a/xtuner/v1/engine/train_engine.py b/xtuner/v1/engine/train_engine.py index 2a576c59b..bea7435b5 100644 --- a/xtuner/v1/engine/train_engine.py +++ b/xtuner/v1/engine/train_engine.py @@ -8,6 +8,7 @@ import torch import torch.distributed as dist import 
torch.distributed.checkpoint as dcp +from pydantic import ConfigDict from safetensors import safe_open from torch.distributed.checkpoint.state_dict import ( StateDictOptions, @@ -20,6 +21,8 @@ from torch.utils._foreach_utils import ( _device_has_foreach_support, ) +from typing_extensions import NotRequired, TypedDict +from xtuner.v1.model.utils import ModelForwardExtraLogInfo from xtuner.v1.config import FSDPConfig, OptimConfig from xtuner.v1.data_proto.sequence_context import SequenceContext @@ -47,6 +50,20 @@ class TrainStepInfo(DataBatchInfo, BatchForwardInfo): threading_lock = threading.Lock() +class LossLog(TypedDict): + __pydantic_config__ = ConfigDict(arbitrary_types_allowed=True) # type: ignore[misc] + local_loss: float + reduced_llm_loss: float + reduced_balancing_loss: NotRequired[float] + reduced_z_loss: NotRequired[float] + +class OtherLog(TypedDict): + __pydantic_config__ = ConfigDict(arbitrary_types_allowed=True) # type: ignore[misc] + maxvio: NotRequired[float] + step_consumed_tokens: int + step_consumed_img_tokens: NotRequired[float] + extra_info: ModelForwardExtraLogInfo + efficient_attn_ratio: float class CPUThreadTaskCoordinator: def __init__(self, futures, callback): @@ -171,7 +188,7 @@ def data_replicate_size(self) -> int: return self.fsdp_cfg.tp_size @torch.no_grad() - def forward_only(self, seq_ctx: SequenceContext, loss_ctx: LogProbContext): + def forward_only(self, seq_ctx: SequenceContext, loss_ctx: LogProbContext | None = None): output = self.model(seq_ctx=seq_ctx, loss_ctx=loss_ctx) return output diff --git a/xtuner/v1/engine/vision_compose_train_engine.py b/xtuner/v1/engine/vision_compose_train_engine.py new file mode 100644 index 000000000..2173ddabf --- /dev/null +++ b/xtuner/v1/engine/vision_compose_train_engine.py @@ -0,0 +1,228 @@ +from pathlib import Path +from typing import List, cast + +import torch +import torch.distributed as dist +from torch.distributed.device_mesh import DeviceMesh + +from transformers import AutoProcessor +from xtuner.v1.float8.float8_handler import Float8Handler +from xtuner.v1.model.base import ModelItem +from xtuner.v1.model.compose.base import BaseComposeConfig, BaseComposeModel +from xtuner.v1.model.moe.moe import MoEModelOutputs +from xtuner.v1.model.utils import ModelForwardExtraLogInfo +from xtuner.v1.module.router import NoAuxRouterConfig +from xtuner.v1.utils import get_device, get_logger, get_torch_device_module + +from .train_engine import LossLog, OtherLog, TrainEngine + + +logger = get_logger() +DEVICE = get_device() +DEVICE_MODULE = get_torch_device_module() + + +class VisionComposeTrainEngine(TrainEngine): + model_cfg: BaseComposeConfig # type: ignore + model: BaseComposeModel # type: ignore + llm_float8_handler: Float8Handler | None + vision_float8_handler: Float8Handler | None + projector_float8_handler: Float8Handler | None + + def __init__( + self, + model_cfg: BaseComposeConfig, + *args, + **kwargs, + ) -> None: + self._processor = None # only for save + super().__init__(model_cfg, *args, **kwargs) # type: ignore + + def build_model(self) -> BaseComposeModel: # type: ignore + with torch.device("meta"): + model = self.model_cfg.build() + + self.llm_float8_handler = None + self.vision_float8_handler = None + self.projector_float8_handler = None + if self.model_cfg.text_config.float8_cfg is not None and self.model_cfg.text_config.float8_cfg.enable_float8: + self.llm_float8_handler = Float8Handler( + scaling_granularity_gemm=self.model_cfg.text_config.float8_cfg.scaling_granularity_gemm, + 
scaling_granularity_grouped_gemm=self.model_cfg.text_config.float8_cfg.scaling_granularity_grouped_gemm, + ) + if ( + self.model_cfg.vision_config.float8_cfg is not None + and self.model_cfg.vision_config.float8_cfg.enable_float8 + ): + self.vision_float8_handler = Float8Handler( + scaling_granularity_gemm=self.model_cfg.vision_config.float8_cfg.scaling_granularity_gemm, + scaling_granularity_grouped_gemm=self.model_cfg.vision_config.float8_cfg.scaling_granularity_grouped_gemm, + ) + if ( + self.model_cfg.projector_config.float8_cfg is not None + and self.model_cfg.projector_config.float8_cfg.enable_float8 + ): + self.projector_float8_handler = Float8Handler( + scaling_granularity_gemm=self.model_cfg.projector_config.float8_cfg.scaling_granularity_gemm, + scaling_granularity_grouped_gemm=self.model_cfg.projector_config.float8_cfg.scaling_granularity_grouped_gemm, + ) + + model = model.fully_shard(self.fsdp_cfg) + + if dist.get_rank() == 0: + logger.info(model) + + if self.llm_float8_handler: + self.llm_float8_handler.build_reduce_mesh( + model.language_model, cast(DeviceMesh, model.language_model.fsdp_mesh) + ) + if self.vision_float8_handler: + self.vision_float8_handler.build_reduce_mesh( + model.vision_tower, cast(DeviceMesh, model.vision_tower.fsdp_mesh) + ) + if self.projector_float8_handler: + self.projector_float8_handler.build_reduce_mesh( + model.multi_modal_projector, cast(DeviceMesh, model.multi_modal_projector.fsdp_mesh) + ) + return model + + def from_hf(self, hf_path: str | Path, strict: bool = False): + super().from_hf(hf_path, strict) + self._processor = AutoProcessor.from_pretrained(hf_path, trust_remote_code=True) + + def save_hf(self, hf_dir: str, save_dtype: torch.dtype = torch.bfloat16): + super().save_hf(hf_dir, save_dtype) + if self._processor is not None: + self._processor.save_pretrained(hf_dir) + + # this method can be called outside, e.g., at the beginning of compute_actor_logprobs or compute_ref_logprobs during rl training + def maybe_precompute_float8_dynamic_scale_for_fsdp(self): + if self.llm_float8_handler is not None and self.llm_float8_handler.enabled: + self.llm_float8_handler.precompute_float8_dynamic_scale_for_fsdp(self.model.language_model) + if self.vision_float8_handler is not None and self.vision_float8_handler.enabled: + self.vision_float8_handler.precompute_float8_dynamic_scale_for_fsdp(self.model.vision_tower) + if self.projector_float8_handler is not None and self.projector_float8_handler.enabled: + self.projector_float8_handler.precompute_float8_dynamic_scale_for_fsdp(self.model.multi_modal_projector) + + def train_step(self, data_batches: List[ModelItem]) -> tuple[LossLog, OtherLog]: + """Perform a training step with the given data batches and mesh. + + Args: + data_batches (List[Dict]): The input data batches for the training step. 
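+
+        Returns:
+            tuple[LossLog, OtherLog]: the loss values gathered for logging and
+            auxiliary statistics such as consumed tokens, MaxVio and the
+            efficient-attention ratio.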
+ """ + self.maybe_precompute_float8_dynamic_scale_for_fsdp() + + loss_log: LossLog = {} # type: ignore[typeddict-item] + other_log: OtherLog = {} # type: ignore[typeddict-item] + intra_layer_micro_batch = self.intra_layer_micro_batch + assert len(data_batches) % intra_layer_micro_batch == 0, ( + f"data_batches length {len(data_batches)} is not divisible by intra_layer_micro_batch {intra_layer_micro_batch}" + ) + iters_per_step = self.grad_accumulation_steps(len(data_batches)) + + if self._count == 0: + logger.info(f"grad_accumulation_steps: {iters_per_step}") + self._count += 1 + + moe_need_update_bias = ( + isinstance(getattr(self.model_cfg.text_config, "router", None), NoAuxRouterConfig) + and self.model_cfg.text_config.router.router_bias_update_speed > 0 + ) + moe_need_log_maxvio = getattr(self.model_cfg.text_config, "router", None) is not None + if moe_need_log_maxvio: + tokens_per_expert_global_for_bias = torch.zeros( + self.model_cfg.text_config.num_hidden_layers - self.model_cfg.text_config.first_k_dense_replace, + self.model_cfg.text_config.n_routed_experts, + dtype=torch.int64, + device=DEVICE, + ) + + step_loss = torch.tensor(0.0, device=DEVICE) + step_llm_loss = torch.tensor(0.0, device=DEVICE) + step_balancing_loss: torch.Tensor | None = None + step_z_loss: torch.Tensor | None = None + step_consumed_tokens = torch.tensor(0, device=DEVICE) + efficient_forward_tokens = torch.tensor(0, device=DEVICE, dtype=torch.long) + total_forward_tokens = torch.tensor(0, device=DEVICE, dtype=torch.long) + + train_engine_extra_info = ModelForwardExtraLogInfo() + step_consumed_img_tokens = 0.0 + for i in range(0, len(data_batches), intra_layer_micro_batch): + data_batch = data_batches[i : i + intra_layer_micro_batch] + seq_ctx_list = [] + loss_ctx_list = [] + for data in data_batch: + seq_ctx = data["seq_ctx"] + loss_ctx = data["loss_ctx"] + seq_ctx_list.append(seq_ctx) + loss_ctx_list.append(loss_ctx) + step_consumed_tokens += seq_ctx.mask.sum() + + if seq_ctx.num_img_tokens is not None: + step_consumed_img_tokens += sum(seq_ctx.num_img_tokens) + if seq_ctx.sequence_parallel_mesh: + step_consumed_img_tokens /= seq_ctx.sequence_parallel_mesh.size() + + num_tokens = seq_ctx.cu_seq_lens_k[1:] - seq_ctx.cu_seq_lens_k[:-1] + efficient_forward_tokens += (num_tokens**2).sum() + total_forward_tokens += (num_tokens.sum()) ** 2 + + # todo: support intra_layer_micro_batch + output = self.model(seq_ctx=seq_ctx_list[0], loss_ctx=loss_ctx_list[0]) + # llm loss has been global averaged + llm_loss = output["loss"] + step_llm_loss += llm_loss.detach().clone() + + loss = llm_loss + if "extra_info" in output: + train_engine_extra_info.append(output["extra_info"]) + + if "balancing_loss" in output: + output = cast(MoEModelOutputs, output) + loss = loss + output["balancing_loss"] / iters_per_step + step_balancing_loss = ( + output["balancing_loss"] + if step_balancing_loss is None + else step_balancing_loss + output["balancing_loss"] + ) + if "z_loss" in output: + output = cast(MoEModelOutputs, output) + loss = loss + output["z_loss"] / iters_per_step + step_z_loss = output["z_loss"] if step_z_loss is None else step_z_loss + output["z_loss"] + + if moe_need_log_maxvio: + output = cast(MoEModelOutputs, output) + assert "tokens_per_expert_global" in output, "tokens_per_expert_global is required for bias update." 
+ tokens_per_expert_global_for_bias += output["tokens_per_expert_global"] + + del output + loss.backward() + step_loss += loss.detach().clone() + + if moe_need_log_maxvio: + avg_count_load = tokens_per_expert_global_for_bias.float().mean(1) + max_load_i, _ = torch.max(tokens_per_expert_global_for_bias, dim=1) + maxvio_all_layers = (max_load_i - avg_count_load) / avg_count_load + maxvio = maxvio_all_layers.mean() + if moe_need_update_bias: + self.model.language_model.update_bias(tokens_per_expert_global_for_bias, avg_count_load) # type: ignore + other_log["maxvio"] = maxvio.item() + + reduced_llm_loss = step_llm_loss + dist.all_reduce(reduced_llm_loss.div_(dist.get_world_size())) + + loss_log["local_loss"] = step_loss.item() + loss_log["reduced_llm_loss"] = reduced_llm_loss.item() + if step_balancing_loss is not None: + reduced_balancing_loss = step_balancing_loss + dist.all_reduce(reduced_balancing_loss.div_(dist.get_world_size())) + loss_log["reduced_balancing_loss"] = reduced_balancing_loss.item() + if step_z_loss is not None: + reduced_z_loss = step_z_loss + dist.all_reduce(reduced_z_loss.div_(dist.get_world_size())) + loss_log["reduced_z_loss"] = reduced_z_loss.item() + other_log["step_consumed_tokens"] = cast(int, step_consumed_tokens.item()) + other_log["extra_info"] = train_engine_extra_info # type: ignore[assignment] + other_log["efficient_attn_ratio"] = (efficient_forward_tokens / total_forward_tokens).item() + other_log["step_consumed_img_tokens"] = step_consumed_img_tokens + return loss_log, other_log diff --git a/xtuner/v1/loss/base_loss_ctx.py b/xtuner/v1/loss/base_loss_ctx.py index 8ae297b6f..36855eb81 100644 --- a/xtuner/v1/loss/base_loss_ctx.py +++ b/xtuner/v1/loss/base_loss_ctx.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from abc import ABC, abstractmethod -from typing import Annotated, Any, Literal, TypeVar +from typing import Annotated, Any, Generic, Literal, TypeVar import torch import torch.distributed as dist @@ -128,8 +128,9 @@ def build(self, *args, **kwargs) -> "BaseLossContext": # NOTE: Self type for BaseLossContext subclasses (F-bounded polymorphism) _BaseLossContextT = TypeVar("_BaseLossContextT", bound="BaseLossContext") +LossContextInputItem = TypeVar("LossContextInputItem") -class BaseLossContext(nn.Module, ABC): +class BaseLossContext(nn.Module, ABC, Generic[LossContextInputItem]): def __init__(self, loss_cfg: BaseLossConfig, loss_kwargs: BaseLossKwargs): # LossContext需要负责几个功能: # 1. sequence parallel, 借助LossKwargs.sp_split 实现 diff --git a/xtuner/v1/rl/dpo/__init__.py b/xtuner/v1/rl/dpo/__init__.py new file mode 100644 index 000000000..5db1f763f --- /dev/null +++ b/xtuner/v1/rl/dpo/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .loss import DPOLossConfig, DPOLossContext, DPOLossContextInputItem + + +__all__ = [ + "DPOLossConfig", + "DPOLossContext", + "DPOLossContextInputItem", +] diff --git a/xtuner/v1/rl/dpo/loss.py b/xtuner/v1/rl/dpo/loss.py new file mode 100644 index 000000000..d1af401b2 --- /dev/null +++ b/xtuner/v1/rl/dpo/loss.py @@ -0,0 +1,629 @@ +# Copyright (c) OpenMMLab. All rights reserved. +""" +DPO (Direct Preference Optimization) Loss Implementation for XTuner v1. 
+ +This module supports multiple DPO variants and loss combinations: +- sigmoid: Standard DPO loss for preference learning +- bco_pair: Binary Classifier Optimization for absolute quality +- sft: Supervised Fine-Tuning loss to maintain generation quality +- hinge, ipo, robust, nca_pair, sppo_hard: Other DPO variants + +Reference: +- DPO: https://arxiv.org/abs/2305.18290 +- BCO: https://arxiv.org/abs/2404.04656 +- IPO: https://arxiv.org/abs/2310.12036 +""" +from typing import Any, Literal, cast + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from pydantic import BaseModel, ConfigDict, Field +from torch.distributed.device_mesh import DeviceMesh +from typing_extensions import Self + +from xtuner.v1.loss import BaseLossContext, BaseLossKwargs +from xtuner.v1.loss.base_loss_ctx import BaseLossConfig +from xtuner.v1.utils import get_logger + +from ..utils import gather_logprobs, sp_split + + +logger = get_logger() + + +class RunningMoments: + """Running mean tracker for BCO loss.""" + + def __init__(self): + self.mean = 0.0 + self.count = 0 + + def update(self, value: torch.Tensor): + """Update running mean with new value.""" + value = value.detach().float().mean().item() + self.count += 1 + self.mean = self.mean + (value - self.mean) / self.count + + +class DPOLossContextInputItem(BaseModel): + """Input item for DPO loss context. + + This class handles preference pair data (chosen/rejected responses). + + Args: + chosen_shifted_labels (torch.Tensor): Shifted labels for chosen responses. + rejected_shifted_labels (torch.Tensor): Shifted labels for rejected responses. + ref_chosen_logprobs (torch.Tensor | None): Reference model log probs for chosen. + ref_rejected_logprobs (torch.Tensor | None): Reference model log probs for rejected. + """ + + model_config = ConfigDict( + title="DPOLossContextInputItem", extra="forbid", arbitrary_types_allowed=True + ) + chosen_shifted_labels: torch.Tensor + rejected_shifted_labels: torch.Tensor + ref_chosen_logprobs: torch.Tensor | None = None + ref_rejected_logprobs: torch.Tensor | None = None + + def sp_split(self, sp_mesh: DeviceMesh) -> Self: + chosen_shifted_labels = sp_split( + self.chosen_shifted_labels, sp_mesh=sp_mesh, split_dim=1, padding_value=-100 + ) + rejected_shifted_labels = sp_split( + self.rejected_shifted_labels, sp_mesh=sp_mesh, split_dim=1, padding_value=-100 + ) + return type(self)( + chosen_shifted_labels=chosen_shifted_labels, + rejected_shifted_labels=rejected_shifted_labels, + ref_chosen_logprobs=self.ref_chosen_logprobs, + ref_rejected_logprobs=self.ref_rejected_logprobs, + ) + + def to(self, device: torch.device | str) -> Self: + self.chosen_shifted_labels = self.chosen_shifted_labels.to(device) + self.rejected_shifted_labels = self.rejected_shifted_labels.to(device) + if self.ref_chosen_logprobs is not None: + self.ref_chosen_logprobs = self.ref_chosen_logprobs.to(device) + if self.ref_rejected_logprobs is not None: + self.ref_rejected_logprobs = self.ref_rejected_logprobs.to(device) + return self + + +class DPOLossKwargs(BaseLossKwargs): + """Keyword arguments for DPO loss computation. + + Args: + shifted_labels (torch.Tensor): Concatenated shifted labels [chosen; rejected]. + chosen_mask (torch.Tensor): Mask for chosen tokens. + rejected_mask (torch.Tensor): Mask for rejected tokens. + ref_chosen_logprobs (torch.Tensor | None): Reference log probs for chosen. + ref_rejected_logprobs (torch.Tensor | None): Reference log probs for rejected. 
+ loss_weight (torch.Tensor): Weight for each token in loss computation. + num_chosen_tokens (int): Number of valid tokens in chosen responses. + num_rejected_tokens (int): Number of valid tokens in rejected responses. + """ + + chosen_mask: torch.Tensor + rejected_mask: torch.Tensor + ref_chosen_logprobs: torch.Tensor | None = None + ref_rejected_logprobs: torch.Tensor | None = None + loss_weight: torch.Tensor + num_chosen_tokens: int = 0 + num_rejected_tokens: int = 0 + + +class DPOLossConfig(BaseLossConfig): + """Configuration for DPO (Direct Preference Optimization) loss. + + DPO can combine multiple loss types with configurable weights: + - sigmoid: Standard DPO loss + - bco_pair: Binary Classifier Optimization loss + - sft: Supervised Fine-Tuning loss on chosen responses + + Args: + loss_types (list[str]): List of loss types to combine. + Supported: ["sigmoid", "bco_pair", "sft", "hinge", "ipo", "robust"] + loss_weights (list[float] | None): Weights for each loss type. + If None, all weights are set to 1.0. + beta (float): Temperature parameter for DPO/BCO losses. Defaults to 0.1. + label_smoothing (float): Label smoothing for DPO loss. Defaults to 0.0. + reference_free (bool): Whether to use reference-free mode. Defaults to False. + use_average_log_prob (bool): Whether to normalize log probs by sequence length. + Defaults to False. + + Example: + >>> # Standard DPO configuration + >>> config = DPOLossConfig( + ... loss_types=["sigmoid"], + ... loss_weights=[1.0], + ... beta=0.1, + ... ) + >>> # MPO-style configuration (multiple loss types) + >>> config = DPOLossConfig( + ... loss_types=["sigmoid", "bco_pair", "sft"], + ... loss_weights=[0.8, 0.2, 1.0], + ... beta=0.1, + ... ) + """ + + loss_types: list[Literal["sigmoid", "bco_pair", "sft", "hinge", "ipo", "robust", "nca_pair", "sppo_hard"]] = Field( + default=["sigmoid"], + description="List of loss types to combine", + ) + loss_weights: list[float] | None = Field( + default=None, + description="Weights for each loss type. If None, all weights are 1.0", + ) + beta: float = Field( + default=0.1, + description="Temperature parameter for DPO loss", + ) + label_smoothing: float = Field( + default=0.0, + description="Label smoothing for robust DPO variants", + ) + reference_free: bool = Field( + default=False, + description="Whether to use reference-free mode", + ) + use_average_log_prob: bool = Field( + default=False, + description="Whether to normalize log probs by sequence length (used in IPO)", + ) + + @property + def loss_ctx_cls(self) -> type["DPOLossContext"]: + return DPOLossContext + + def model_post_init(self, __context: Any) -> None: + """Validate and set default loss weights.""" + if self.loss_weights is None: + self.loss_weights = [1.0] * len(self.loss_types) + elif len(self.loss_weights) != len(self.loss_types): + raise ValueError( + f"Length of loss_weights ({len(self.loss_weights)}) must match " + f"length of loss_types ({len(self.loss_types)})" + ) + + +class DPOLossContext(BaseLossContext[DPOLossContextInputItem]): + """DPO loss context for preference alignment training. + + This class implements the loss computation for Direct Preference Optimization + and its variants (including MPO-style multi-loss combinations). 
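+
+    The combined objective is a weighted sum over the configured loss types,
+    ``total_loss = sum(w_i * L_i)``; each weighted term is also recorded
+    separately in ``extra_info`` for logging.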
+ """ + + loss_cfg: DPOLossConfig + loss_kwargs: DPOLossKwargs + + def __init__(self, loss_cfg: DPOLossConfig, loss_kwargs: DPOLossKwargs): + super().__init__(loss_cfg, loss_kwargs) + # Running moments for BCO loss + self.running = RunningMoments() + + @staticmethod + def build_batches( # type: ignore[override] + loss_ctx_list: list["DPOLossContext"], *args: Any, **kwargs: Any + ) -> list["DPOLossContext"]: + del args, kwargs + batch_size = len(loss_ctx_list) + for loss_ctx in loss_ctx_list: + loss_ctx._batch_size = batch_size + return loss_ctx_list + + @classmethod + def build_batches_loss_kwargs( + cls, + data_batches: list[DPOLossContextInputItem], + loss_cfg: DPOLossConfig, + cu_seq_lens_list: list[torch.Tensor] | None = None, + sp_mesh: DeviceMesh | None = None, + ) -> list[DPOLossKwargs]: + """Build loss kwargs for each batch. + + This method processes the input data batches and prepares the loss kwargs + for the DPO loss computation. + """ + batches_loss_kwargs = [] + + # Compute global token count for loss normalization + total_chosen_tokens = 0 + total_rejected_tokens = 0 + for item in data_batches: + total_chosen_tokens += (item.chosen_shifted_labels != loss_cfg.ignore_idx).sum() + total_rejected_tokens += (item.rejected_shifted_labels != loss_cfg.ignore_idx).sum() + + total_chosen_tokens = cast(torch.Tensor, total_chosen_tokens) + total_rejected_tokens = cast(torch.Tensor, total_rejected_tokens) + + if dist.is_initialized(): + dist.all_reduce(total_chosen_tokens, op=dist.ReduceOp.SUM) + dist.all_reduce(total_rejected_tokens, op=dist.ReduceOp.SUM) + + global_tokens = total_chosen_tokens + total_rejected_tokens + if global_tokens == 0: + logger.warning( + "Global tokens is 0, which may lead to division by zero in loss weight calculation." + ) + global_tokens = global_tokens + 1 + + for item in data_batches: + # Concatenate chosen and rejected labels + # Shape: [1, chosen_len + rejected_len] + shifted_labels = torch.cat( + [item.chosen_shifted_labels, item.rejected_shifted_labels], dim=1 + ) + + # Create masks + chosen_len = item.chosen_shifted_labels.shape[1] + rejected_len = item.rejected_shifted_labels.shape[1] + total_len = chosen_len + rejected_len + + chosen_mask = torch.zeros(1, total_len, device=shifted_labels.device, dtype=torch.bool) + rejected_mask = torch.zeros(1, total_len, device=shifted_labels.device, dtype=torch.bool) + chosen_mask[:, :chosen_len] = item.chosen_shifted_labels != loss_cfg.ignore_idx + rejected_mask[:, chosen_len:] = item.rejected_shifted_labels != loss_cfg.ignore_idx + + # Compute loss weight + loss_weight = torch.zeros_like(shifted_labels, dtype=torch.float32) + loss_weight[chosen_mask] = 1.0 / global_tokens.float() + loss_weight[rejected_mask] = 1.0 / global_tokens.float() + + num_chosen_tokens = chosen_mask.sum().item() + num_rejected_tokens = rejected_mask.sum().item() + + loss_kwargs = DPOLossKwargs( + shifted_labels=shifted_labels, + chosen_mask=chosen_mask, + rejected_mask=rejected_mask, + ref_chosen_logprobs=item.ref_chosen_logprobs, + ref_rejected_logprobs=item.ref_rejected_logprobs, + loss_weight=loss_weight, + num_chosen_tokens=int(num_chosen_tokens), + num_rejected_tokens=int(num_rejected_tokens), + ) + batches_loss_kwargs.append(loss_kwargs) + + return batches_loss_kwargs + + def _compute_logprobs( + self, + logits: torch.Tensor, + labels: torch.Tensor, + mask: torch.Tensor, + ) -> torch.Tensor: + """Compute log probabilities for the given labels. 
+ + Args: + logits: Model output logits [batch, seq_len, vocab_size] + labels: Target labels [batch, seq_len] + mask: Valid token mask [batch, seq_len] + + Returns: + Sum of log probabilities for valid tokens [batch] + """ + # Gather log probs for target tokens + logprobs = gather_logprobs(logits, labels) # [batch, seq_len] + logprobs = logprobs * mask.float() + + if self.loss_cfg.use_average_log_prob: + # Normalize by sequence length (used in IPO) + return logprobs.sum(dim=-1) / mask.sum(dim=-1).clamp(min=1) + else: + return logprobs.sum(dim=-1) + + def _dpo_loss_sigmoid( + self, + policy_chosen_logps: torch.Tensor, + policy_rejected_logps: torch.Tensor, + ref_chosen_logps: torch.Tensor, + ref_rejected_logps: torch.Tensor, + ) -> torch.Tensor: + """Compute standard DPO sigmoid loss.""" + pi_logratios = policy_chosen_logps - policy_rejected_logps + if self.loss_cfg.reference_free: + ref_logratios = torch.zeros_like(pi_logratios) + else: + ref_logratios = ref_chosen_logps - ref_rejected_logps + + logits = pi_logratios - ref_logratios + + loss = ( + -F.logsigmoid(self.loss_cfg.beta * logits) * (1 - self.loss_cfg.label_smoothing) + - F.logsigmoid(-self.loss_cfg.beta * logits) * self.loss_cfg.label_smoothing + ) + return loss + + def _dpo_loss_robust( + self, + policy_chosen_logps: torch.Tensor, + policy_rejected_logps: torch.Tensor, + ref_chosen_logps: torch.Tensor, + ref_rejected_logps: torch.Tensor, + ) -> torch.Tensor: + """Compute robust DPO loss.""" + pi_logratios = policy_chosen_logps - policy_rejected_logps + if self.loss_cfg.reference_free: + ref_logratios = torch.zeros_like(pi_logratios) + else: + ref_logratios = ref_chosen_logps - ref_rejected_logps + + logits = pi_logratios - ref_logratios + + loss = ( + -F.logsigmoid(self.loss_cfg.beta * logits) * (1 - self.loss_cfg.label_smoothing) + + F.logsigmoid(-self.loss_cfg.beta * logits) * self.loss_cfg.label_smoothing + ) / (1 - 2 * self.loss_cfg.label_smoothing + 1e-8) + return loss + + def _dpo_loss_hinge( + self, + policy_chosen_logps: torch.Tensor, + policy_rejected_logps: torch.Tensor, + ref_chosen_logps: torch.Tensor, + ref_rejected_logps: torch.Tensor, + ) -> torch.Tensor: + """Compute hinge loss (SLiC style).""" + pi_logratios = policy_chosen_logps - policy_rejected_logps + if self.loss_cfg.reference_free: + ref_logratios = torch.zeros_like(pi_logratios) + else: + ref_logratios = ref_chosen_logps - ref_rejected_logps + + logits = pi_logratios - ref_logratios + return torch.relu(1 - self.loss_cfg.beta * logits) + + def _dpo_loss_ipo( + self, + policy_chosen_logps: torch.Tensor, + policy_rejected_logps: torch.Tensor, + ref_chosen_logps: torch.Tensor, + ref_rejected_logps: torch.Tensor, + ) -> torch.Tensor: + """Compute IPO loss.""" + pi_logratios = policy_chosen_logps - policy_rejected_logps + if self.loss_cfg.reference_free: + ref_logratios = torch.zeros_like(pi_logratios) + else: + ref_logratios = ref_chosen_logps - ref_rejected_logps + + logits = pi_logratios - ref_logratios + return (logits - 1 / (2 * self.loss_cfg.beta)) ** 2 + + def _bco_pair_loss( + self, + policy_chosen_logps: torch.Tensor, + policy_rejected_logps: torch.Tensor, + ref_chosen_logps: torch.Tensor, + ref_rejected_logps: torch.Tensor, + ) -> torch.Tensor: + """Compute BCO (Binary Classifier Optimization) pairwise loss. + + BCO optimizes for absolute quality rather than relative preference. 
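+
+        Per pair, with ``r = beta * (policy_logps - ref_logps)`` and ``delta`` the
+        running mean reward tracked by ``RunningMoments``:
+        ``loss = -logsigmoid(r_chosen - delta) - logsigmoid(-(r_rejected - delta))``.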
+ """ + chosen_logratios = policy_chosen_logps - ref_chosen_logps + rejected_logratios = policy_rejected_logps - ref_rejected_logps + + chosen_rewards = self.loss_cfg.beta * chosen_logratios + rejected_rewards = self.loss_cfg.beta * rejected_logratios + + # Update running mean + rewards = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach() + self.running.update(rewards) + delta = self.running.mean + + loss = -F.logsigmoid(chosen_rewards - delta) - F.logsigmoid(-(rejected_rewards - delta)) + return loss + + def _nca_pair_loss( + self, + policy_chosen_logps: torch.Tensor, + policy_rejected_logps: torch.Tensor, + ref_chosen_logps: torch.Tensor, + ref_rejected_logps: torch.Tensor, + ) -> torch.Tensor: + """Compute NCA pairwise loss.""" + chosen_rewards = (policy_chosen_logps - ref_chosen_logps) * self.loss_cfg.beta + rejected_rewards = (policy_rejected_logps - ref_rejected_logps) * self.loss_cfg.beta + + loss = ( + -F.logsigmoid(chosen_rewards) + - 0.5 * F.logsigmoid(-chosen_rewards) + - 0.5 * F.logsigmoid(-rejected_rewards) + ) + return loss + + def _sppo_hard_loss( + self, + policy_chosen_logps: torch.Tensor, + policy_rejected_logps: torch.Tensor, + ref_chosen_logps: torch.Tensor, + ref_rejected_logps: torch.Tensor, + ) -> torch.Tensor: + """Compute SPPO hard loss.""" + a = policy_chosen_logps - ref_chosen_logps + b = policy_rejected_logps - ref_rejected_logps + return (a - 0.5 / self.loss_cfg.beta) ** 2 + (b + 0.5 / self.loss_cfg.beta) ** 2 + + def _sft_loss( + self, + logits: torch.Tensor, + labels: torch.Tensor, + mask: torch.Tensor, + loss_weight: torch.Tensor, + ) -> torch.Tensor: + """Compute SFT (Supervised Fine-Tuning) loss on chosen responses. + + This maintains generation quality during preference optimization. + """ + # Only compute loss on chosen tokens (mask is for chosen part) + sft_logits = logits[:, :mask.shape[1]] + sft_labels = labels[:, :mask.shape[1]] + sft_loss_weight = loss_weight[:, :mask.shape[1]] + + # Cross entropy loss + vocab_size = sft_logits.shape[-1] + sft_loss = F.cross_entropy( + sft_logits.reshape(-1, vocab_size), + sft_labels.reshape(-1), + reduction="none", + ignore_index=self.loss_cfg.ignore_idx, + ).reshape(sft_labels.shape) + + # Apply loss weight and sum + return (sft_loss * sft_loss_weight * mask.float()).sum() + + def loss_fn( + self, + hidden_states: torch.Tensor, + head_weight: torch.Tensor, + head_bias: torch.Tensor | None, + loss_kwargs: DPOLossKwargs, + ) -> tuple[torch.Tensor, tuple[torch.Tensor | None, dict[str, Any]]]: + """Compute combined DPO loss. + + This method computes all configured loss types and combines them + with their respective weights. 
+ """ + # Compute logits + logits = F.linear(hidden_states, head_weight, head_bias) + logits = logits.float() + + shifted_labels = loss_kwargs.shifted_labels + chosen_mask = loss_kwargs.chosen_mask + rejected_mask = loss_kwargs.rejected_mask + loss_weight = loss_kwargs.loss_weight + + # Compute per-token log probabilities + all_logprobs = gather_logprobs(logits, shifted_labels) + + # Split into chosen and rejected parts + chosen_len = chosen_mask.shape[1] + if chosen_mask.any(): + # Get chosen logprobs - using the chosen part of the sequence + chosen_logprobs = all_logprobs * chosen_mask.float() + if self.loss_cfg.use_average_log_prob: + policy_chosen_logps = chosen_logprobs.sum(dim=-1) / chosen_mask.sum(dim=-1).clamp(min=1) + else: + policy_chosen_logps = chosen_logprobs.sum(dim=-1) + else: + policy_chosen_logps = torch.zeros(1, device=logits.device) + + if rejected_mask.any(): + # Get rejected logprobs - using the rejected part of the sequence + rejected_logprobs = all_logprobs * rejected_mask.float() + if self.loss_cfg.use_average_log_prob: + policy_rejected_logps = rejected_logprobs.sum(dim=-1) / rejected_mask.sum(dim=-1).clamp(min=1) + else: + policy_rejected_logps = rejected_logprobs.sum(dim=-1) + else: + policy_rejected_logps = torch.zeros(1, device=logits.device) + + # Get reference logprobs + ref_chosen_logps = loss_kwargs.ref_chosen_logprobs + ref_rejected_logps = loss_kwargs.ref_rejected_logprobs + + if ref_chosen_logps is None: + ref_chosen_logps = torch.zeros_like(policy_chosen_logps) + if ref_rejected_logps is None: + ref_rejected_logps = torch.zeros_like(policy_rejected_logps) + + # Compute combined loss + total_loss = torch.tensor(0.0, device=logits.device, dtype=logits.dtype) + extra_info = {} + + for loss_type, weight in zip(self.loss_cfg.loss_types, self.loss_cfg.loss_weights): + if loss_type == "sigmoid": + loss = self._dpo_loss_sigmoid( + policy_chosen_logps, policy_rejected_logps, + ref_chosen_logps, ref_rejected_logps + ) + loss = loss.mean() * weight + extra_info["dpo_sigmoid_loss"] = loss.detach() + + elif loss_type == "robust": + loss = self._dpo_loss_robust( + policy_chosen_logps, policy_rejected_logps, + ref_chosen_logps, ref_rejected_logps + ) + loss = loss.mean() * weight + extra_info["dpo_robust_loss"] = loss.detach() + + elif loss_type == "hinge": + loss = self._dpo_loss_hinge( + policy_chosen_logps, policy_rejected_logps, + ref_chosen_logps, ref_rejected_logps + ) + loss = loss.mean() * weight + extra_info["dpo_hinge_loss"] = loss.detach() + + elif loss_type == "ipo": + loss = self._dpo_loss_ipo( + policy_chosen_logps, policy_rejected_logps, + ref_chosen_logps, ref_rejected_logps + ) + loss = loss.mean() * weight + extra_info["dpo_ipo_loss"] = loss.detach() + + elif loss_type == "bco_pair": + loss = self._bco_pair_loss( + policy_chosen_logps, policy_rejected_logps, + ref_chosen_logps, ref_rejected_logps + ) + loss = loss.mean() * weight + extra_info["bco_pair_loss"] = loss.detach() + + elif loss_type == "nca_pair": + loss = self._nca_pair_loss( + policy_chosen_logps, policy_rejected_logps, + ref_chosen_logps, ref_rejected_logps + ) + loss = loss.mean() * weight + extra_info["nca_pair_loss"] = loss.detach() + + elif loss_type == "sppo_hard": + loss = self._sppo_hard_loss( + policy_chosen_logps, policy_rejected_logps, + ref_chosen_logps, ref_rejected_logps + ) + loss = loss.mean() * weight + extra_info["sppo_hard_loss"] = loss.detach() + + elif loss_type == "sft": + # SFT loss only on chosen responses + chosen_len_actual = chosen_mask.shape[1] + # Note: We 
need to handle the chosen part separately + chosen_labels = shifted_labels[:, :chosen_len_actual] + chosen_logits = logits[:, :chosen_len_actual] + chosen_loss_weight = loss_weight[:, :chosen_len_actual] + + loss = self._sft_loss( + chosen_logits, + chosen_labels, + chosen_mask[:, :chosen_len_actual], + chosen_loss_weight, + ) + loss = loss * weight + extra_info["sft_loss"] = loss.detach() + + else: + raise ValueError(f"Unknown loss type: {loss_type}") + + total_loss = total_loss + loss + + # Compute rewards for logging + chosen_rewards = self.loss_cfg.beta * (policy_chosen_logps - ref_chosen_logps).detach() + rejected_rewards = self.loss_cfg.beta * (policy_rejected_logps - ref_rejected_logps).detach() + reward_accuracies = (chosen_rewards > rejected_rewards).float() + + extra_info.update({ + "chosen_rewards": chosen_rewards.mean(), + "rejected_rewards": rejected_rewards.mean(), + "reward_margin": (chosen_rewards - rejected_rewards).mean(), + "reward_accuracy": reward_accuracies.mean(), + "policy_chosen_logps": policy_chosen_logps.mean().detach(), + "policy_rejected_logps": policy_rejected_logps.mean().detach(), + }) + + return total_loss, (logits, extra_info) diff --git a/xtuner/v1/rl/utils.py b/xtuner/v1/rl/utils.py index 8da313a36..90ac52ff3 100644 --- a/xtuner/v1/rl/utils.py +++ b/xtuner/v1/rl/utils.py @@ -1,11 +1,23 @@ import atexit import signal import subprocess +from typing import Any import torch.nn.functional as F +from torch.distributed.device_mesh import DeviceMesh +from xtuner.v1.data_proto.utils import pad_to_multiple_of, split_for_sequence_parallel from xtuner.v1.utils.logger import get_logger +def sp_split( + tensor, + sp_mesh: DeviceMesh, + split_dim: int, + padding_value: Any, +): + tensor = pad_to_multiple_of(tensor, padding_value, sp_mesh.size(), split_dim) + tensor = split_for_sequence_parallel(tensor, dim=split_dim, sp_mesh=sp_mesh) + return tensor def gather_logprobs(logits, shifted_labels): logprobs = F.log_softmax(logits, dim=-1) diff --git a/xtuner/v1/train/cli/dpo.py b/xtuner/v1/train/cli/dpo.py new file mode 100644 index 000000000..82f5b22c2 --- /dev/null +++ b/xtuner/v1/train/cli/dpo.py @@ -0,0 +1,83 @@ +# Copyright (c) OpenMMLab. All rights reserved. +""" +CLI entry point for DPO (Direct Preference Optimization) training. + +This module provides a command-line interface for DPO training, similar to +the SFT and RL training interfaces. + +Usage: + torchrun --nproc_per_node=8 xtuner/v1/train/cli/dpo.py --config config.py +""" + +from pathlib import Path +from typing import Annotated + +import torch.distributed as dist +from cyclopts import App, Parameter +from cyclopts.group import Group + +from xtuner.v1.train.dpo_trainer import DPOTrainer +from xtuner.v1.utils import Config + +import torch._dynamo +torch._dynamo.config.disable = True + + + +app = App( + help="XTuner's entry point for DPO (Direct Preference Optimization) training.", +) + + +@app.default() +def main( + *, + config: Annotated[Path, Parameter(group=Group("config-path", sort_key=0))], +): + """Run DPO training with the given configuration file. + + The config file should export a 'trainer' variable of type DPOTrainerConfig. + + Example config file: + from xtuner.v1.train.dpo_trainer import DPOTrainerConfig + trainer = DPOTrainerConfig(...) + + Args: + config: Path to the DPO training configuration file. 
+ """ + from xtuner.v1.train.dpo_trainer import DPOTrainerConfig + + # Load configuration + cfg = Config.fromfile(config) + # The config file should export a 'trainer' key (same pattern as rl.py) + if "trainer" in cfg: + trainer_cfg = cfg["trainer"] + elif "config" in cfg: + trainer_cfg = cfg["config"] + elif "trainer_config" in cfg: + trainer_cfg = cfg["trainer_config"] + else: + raise ValueError( + "Config file must export a 'trainer' variable of type DPOTrainerConfig. " + "Example: trainer = DPOTrainerConfig(...)" + ) + # Validate config type + if not isinstance(trainer_cfg, DPOTrainerConfig): + raise TypeError( + f"Expected DPOTrainerConfig, got {type(trainer_cfg).__name__}. " + "Please ensure your config file exports: trainer = DPOTrainerConfig(...)" + ) + + # Create trainer from config + trainer = DPOTrainer.from_config(trainer_cfg) + + # Run training + trainer.fit() + + # Clean up + if dist.is_initialized(): + dist.destroy_process_group() + + +if __name__ == "__main__": + app(exit_on_error=False) diff --git a/xtuner/v1/train/dpo_trainer.py b/xtuner/v1/train/dpo_trainer.py new file mode 100644 index 000000000..4d8eb22a4 --- /dev/null +++ b/xtuner/v1/train/dpo_trainer.py @@ -0,0 +1,1059 @@ +# Copyright (c) OpenMMLab. All rights reserved. +""" +DPO (Direct Preference Optimization) Trainer for XTuner v1. + +This trainer supports offline preference learning with multiple loss types: +- sigmoid: Standard DPO loss +- bco_pair: Binary Classifier Optimization +- sft: Supervised Fine-Tuning loss +- hinge, ipo, robust: Other DPO variants + +For MPO (Mixed Preference Optimization), use loss_types=["sigmoid", "bco_pair", "sft"] +with appropriate loss_weights. +""" +import json +import os +import gc +from copy import deepcopy +from datetime import datetime +from functools import partial +from pathlib import Path +from typing import Any, Dict, List, Optional, cast + +import torch +import torch.distributed as dist +from mmengine import load +from mmengine.dist import get_rank, init_dist +from mmengine.runner import set_random_seed +from pydantic import BaseModel, ConfigDict, model_validator +from torch.utils.data import DataLoader, Dataset, DistributedSampler +from tqdm import tqdm +from typing_extensions import Self + +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast + +from xtuner.v1.config import FSDPConfig, OptimConfig, LRConfig +from xtuner.v1.data_proto.sequence_context import SequenceContext +from xtuner.v1.data_proto.utils import split_for_sequence_parallel, pad_to_multiple_of +from xtuner.v1.datasets import qwen3_vl_dpo_collator, DPOColateItem +from xtuner.v1.datasets.config import DataloaderConfig +from xtuner.v1.engine.train_engine import TrainEngine +from xtuner.v1.engine.vision_compose_train_engine import VisionComposeTrainEngine +from xtuner.v1.model.base import TransformerConfig +from xtuner.v1.model.compose.base import BaseComposeConfig +from xtuner.v1.rl.dpo import DPOLossConfig, DPOLossContext, DPOLossContextInputItem +from xtuner.v1.utils import XTUNER_DETERMINISTIC, get_logger, is_hf_model_path, record_git_info +from xtuner.v1.utils.device import get_device, get_torch_device_module +from torch.distributed.device_mesh import init_device_mesh + +from .trainer import ExpHistory, ExpInfo, GitInfo, XTunerMeta + + +DEVICE = get_device() +DEVICE_MODULE = get_torch_device_module() +logger = get_logger() + + +class DPOTrainerConfig(BaseModel): + """Configuration for DPO Trainer. + + Args: + model_cfg: Model architecture configuration. 
+ optim_cfg: Optimizer configuration. + loss_cfg: DPO loss configuration. + lr_cfg: Learning rate scheduler configuration. + fsdp_cfg: FSDP configuration for distributed training. + load_from: Path to load model from. + ref_load_from: Path to load reference model from (optional). + tokenizer_path: Path to tokenizer. + work_dir: Working directory for outputs. + total_epochs: Total number of training epochs. + global_batch_size: Global batch size across all devices. + per_device_batch_size: Batch size per device. + gradient_accumulation_steps: Number of gradient accumulation steps. + max_length: Maximum sequence length. + save_interval: Interval for saving checkpoints. + eval_interval: Interval for evaluation. + log_interval: Interval for logging. + seed: Random seed. + """ + + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + + model_cfg: TransformerConfig | BaseComposeConfig + optim_cfg: OptimConfig + loss_cfg: DPOLossConfig + lr_cfg: LRConfig + fsdp_cfg: FSDPConfig + load_from: str | Path + ref_load_from: str | Path | None = None + tokenizer_path: str | Path + work_dir: Path | str = Path("./work_dir") + log_dir: Path | str | None = None + total_epochs: int = 3 + global_batch_size: int = 128 + per_device_batch_size: int = 4 + gradient_accumulation_steps: int = 1 + max_length: int = 4096 + save_interval: int | None = None + eval_interval: int | None = None + log_interval: int = 10 + seed: int = 42 + auto_resume: bool = False + freeze_ref_model: bool = True + num_workers: int = 4 + use_vlm_collator: bool = False # Use qwen3_vl_dpo_collator for VLM + sp_size: int = 1 # Sequence parallel size + + # Dataloader configuration (xtuner v1 style, optional) + dataloader_cfg: DataloaderConfig | None = None + + @model_validator(mode="after") + def _convert_paths(self): + if isinstance(self.work_dir, str): + self.work_dir = Path(self.work_dir) + if isinstance(self.log_dir, str): + self.log_dir = Path(self.log_dir) + return self + + +class DPOTrainer: + """DPO Trainer for offline preference learning. + + This trainer implements Direct Preference Optimization and its variants + for offline preference alignment. It supports: + - Standard DPO (sigmoid loss) + - MPO (Mixed Preference Optimization with multiple loss types) + - IPO, BCO, and other variants + + Args: + config: DPOTrainerConfig instance. + train_dataset: Training dataset with preference pairs. + eval_dataset: Evaluation dataset (optional). + collate_fn: Custom collate function (optional). + + Example: + >>> from functools import partial + >>> from xtuner.v1.datasets import ( + ... PreferenceJsonlDataset, + ... PreferenceTokenizeFunction, + ... ) + >>> + >>> # Prepare dataset + >>> tokenize_fn = PreferenceTokenizeFunction(tokenizer, max_length=4096) + >>> dataset = PreferenceJsonlDataset("data.jsonl", tokenize_fn=tokenize_fn) + >>> collator = partial(dpo_llm_collator, pack_max_length=4096, padding_token_idx=0) + >>> + >>> # Create trainer + >>> config = DPOTrainerConfig( + ... model_cfg=model_cfg, + ... optim_cfg=AdamWConfig(lr=5e-7), + ... loss_cfg=DPOLossConfig( + ... loss_types=["sigmoid", "bco_pair", "sft"], + ... loss_weights=[0.8, 0.2, 1.0], + ... ), + ... lr_cfg=LRConfig(lr_type="cosine"), + ... fsdp_cfg=FSDPConfig(), + ... load_from="Qwen/Qwen3-VL-8B-Instruct", + ... tokenizer_path="Qwen/Qwen3-VL-8B-Instruct", + ... 
) + >>> trainer = DPOTrainer(config, dataset, collate_fn=collator) + >>> trainer.fit() + """ + + META_PATH = ".xtuner_dpo" + + def __init__( + self, + config: DPOTrainerConfig, + dataloader_cfg: DataloaderConfig, + ): + """Initialize DPO Trainer. + + This follows the same pattern as the SFT Trainer - accepting dataloader_cfg + and building the dataloader internally using dataloader_cfg.build(). + + Args: + config: DPO trainer configuration. + dataloader_cfg: Dataloader configuration containing dataset configs. + """ + self.config = config + self._dataloader_config = dataloader_cfg + + self._cur_step = 0 + self._cur_epoch = 0 + + # Initialize distributed if needed + if not dist.is_initialized(): + init_dist("pytorch") + + # Initialize + self._set_deterministic() + self._set_random_seed(config.seed) + + # Setup work directory + if isinstance(config.work_dir, str): + config.work_dir = Path(config.work_dir) + self._work_dir = config.work_dir + + if get_rank() == 0: + self._work_dir.mkdir(parents=True, exist_ok=True) + + # Synchronize before continuing + if dist.is_initialized(): + dist.barrier() + + # Initialize meta for experiment tracking + self._meta = self._init_xtuner_meta(self._work_dir, config.auto_resume) + + # Setup logging + log_dir = config.log_dir or self.exp_dir + if isinstance(log_dir, str): + log_dir = Path(log_dir) + self.logger = self._init_logger(log_dir) + + # Load tokenizer + self.logger.info(f"Loading tokenizer from {config.tokenizer_path}") + self.tokenizer = AutoTokenizer.from_pretrained( + config.tokenizer_path, trust_remote_code=True + ) + + # Initialize training engine (model + optimizer) + # Use VisionComposeTrainEngine for VLM (compose) models, TrainEngine for text-only models + self.logger.info(f"Initializing training engine from {config.load_from}") + if isinstance(config.model_cfg, BaseComposeConfig): + self.train_engine = VisionComposeTrainEngine( + model_cfg=config.model_cfg, + optim_cfg=config.optim_cfg, + fsdp_cfg=config.fsdp_cfg, + ) + else: + self.train_engine = TrainEngine( + model_cfg=config.model_cfg, + optim_cfg=config.optim_cfg, + fsdp_cfg=config.fsdp_cfg, + ) + + # Load model weights using from_hf (same as TrainEngine API) + self.logger.info(f"Loading model weights from {config.load_from}") + self.train_engine.from_hf(str(config.load_from)) + + # Resolve dataloader config with tokenizer (set pad_token_id etc.) 
+ self._sp_size = config.sp_size + self._resolve_dataloader_config(dataloader_cfg) + + # Initialize data mesh (same as SFT Trainer) + tp_size = config.fsdp_cfg.tp_size if config.fsdp_cfg else 1 + sp_size = config.sp_size + self.data_mesh = self._init_data_mesh(tp_size, sp_size, config.fsdp_cfg) + self.sp_mesh = self.data_mesh["sp"] + + # Build dataloader using xtuner v1 pattern (same as SFT Trainer) + self.logger.info("Building dataloader from dataloader_cfg...") + micro_batch_size = config.per_device_batch_size + + # Get dp_mesh from data_mesh + dp_mesh = self.data_mesh["dp"] + + self._dataloader = dataloader_cfg.build( + tokenizer=self.tokenizer, + dp_mesh=dp_mesh, + global_batch_size=config.global_batch_size, + micro_batch_size=micro_batch_size, + seed=config.seed, + ) + + # Evaluation dataloader not supported yet in xtuner v1 pattern + self.eval_dataloader = None + + # Build learning rate scheduler + self._build_lr_scheduler() + + # Initialize reference model if needed + self.ref_engine = None + if not config.loss_cfg.reference_free: + self._init_reference_model() + + # Calculate training steps + self._total_steps = len(self._dataloader) * config.total_epochs + + # Save config + if get_rank() == 0: + config_path = log_dir / "dpo_trainer_config.json" + with config_path.open("w") as f: + f.write(config.model_dump_json(indent=2)) + + self.logger.info(f"DPO Trainer initialized") + self.logger.info(f" Total epochs: {config.total_epochs}") + self.logger.info(f" Total steps: {self._total_steps}") + self.logger.info(f" Loss types: {config.loss_cfg.loss_types}") + self.logger.info(f" Loss weights: {config.loss_cfg.loss_weights}") + self.logger.info(f" Reference free: {config.loss_cfg.reference_free}") + self.logger.info(f" Sequence parallel size: {config.sp_size}") + + def _resolve_dataloader_config(self, dataloader_cfg: DataloaderConfig): + """Resolve dataloader config conflicts, similar to SFT Trainer.""" + if hasattr(self.tokenizer, "pad_token_id"): + pad_token_id = self.tokenizer.pad_token_id + else: + pad_token_id = 0 + + if dataloader_cfg.pad_token_id is None: + dataloader_cfg.pad_token_id = pad_token_id + elif dataloader_cfg.pad_token_id != pad_token_id: + self.logger.warning( + f"Dataloader pad_token_id {dataloader_cfg.pad_token_id} is different from tokenizer " + f"pad_token_id {pad_token_id}. Using tokenizer pad_token_id {pad_token_id}." + ) + dataloader_cfg.pad_token_id = pad_token_id + + # Check sequence parallel constraints + if self._sp_size > 1: + if dataloader_cfg.pack_to_max_length is False: + self.logger.warning( + "pack_to_max_length must be True when using sequence parallel. " + "Setting pack_to_max_length to True." + ) + dataloader_cfg.pack_to_max_length = True + + def _init_data_mesh(self, tp_size: int, sp_size: int, fsdp_cfg: FSDPConfig | None): + """Initialize data mesh for distributed training, same as SFT Trainer.""" + world_size = dist.get_world_size() if dist.is_initialized() else 1 + + if world_size % tp_size != 0: + raise ValueError( + f"Found tp_size {tp_size}, world_size {world_size}. " + "Tensor parallel size must be a divisor of world size." + ) + + if world_size % sp_size != 0: + raise ValueError( + f"Found sp_size {sp_size}, world_size {world_size}. " + "Sequence parallel size must be a divisor of world size." + ) + + if world_size % (tp_size * sp_size) != 0: + raise ValueError( + f"Found tp_size {tp_size}, sp_size {sp_size}, world_size {world_size}. " + "`tp_size * sp_size` must be a divisor of world size." 
+ ) + + dp_size = world_size // (tp_size * sp_size) + + # Use CPU for device mesh if not cpu_offload, else use DEVICE + cpu_offload = fsdp_cfg.cpu_offload if fsdp_cfg else False + device = str(DEVICE) if not cpu_offload else "cpu" + + data_mesh = init_device_mesh( + device, + (dp_size, sp_size, tp_size), + mesh_dim_names=("dp", "sp", "tp"), + ) + return data_mesh + + @classmethod + def from_config(cls, config: DPOTrainerConfig) -> "DPOTrainer": + """Create a DPOTrainer instance from a DPOTrainerConfig. + + This method follows the same pattern as SFT Trainer.from_config, + passing the dataloader_cfg to __init__ which handles building internally. + + Args: + config: DPOTrainerConfig instance containing all configuration parameters. + + Returns: + DPOTrainer instance initialized with the provided config. + + Example: + >>> from xtuner.v1.utils import Config + >>> cfg = Config.fromfile("dpo_config.py") + >>> trainer = DPOTrainer.from_config(cfg["trainer"]) + >>> trainer.fit() + """ + # Validate config has dataloader_cfg + if config.dataloader_cfg is None: + raise ValueError( + "DPOTrainerConfig.dataloader_cfg is required when using from_config(). " + "Please configure dataloader_cfg with dataset_config_list." + ) + + # Create trainer - __init__ handles all building internally + return cls( + config=config, + dataloader_cfg=config.dataloader_cfg, + ) + + def _build_lr_scheduler(self): + """Build learning rate scheduler.""" + from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR, LinearLR, SequentialLR + + # Calculate total OPTIMIZER steps (not batch steps!) + # lr_scheduler.step() is called once per gradient_accumulation_steps batches + steps_per_epoch = len(self._dataloader) + total_batches = steps_per_epoch * self.config.total_epochs + total_steps = total_batches // self.config.gradient_accumulation_steps + + self.logger.info(f"LR Scheduler: total_batches={total_batches}, " + f"gradient_accumulation_steps={self.config.gradient_accumulation_steps}, " + f"total_optimizer_steps={total_steps}") + + lr_cfg = self.config.lr_cfg + optimizer = self.train_engine.optimizer + + # Calculate warmup steps + if lr_cfg.warmup_ratio < 1: + warmup_steps = int(lr_cfg.warmup_ratio * total_steps) + else: + warmup_steps = int(lr_cfg.warmup_ratio) + + self.logger.info(f"LR Scheduler: warmup_ratio={lr_cfg.warmup_ratio}, warmup_steps={warmup_steps}") + + # Warmup function - linear warmup from 0 to base_lr (same as SFT Trainer) + def warmup_fn(x): + return x / warmup_steps if x < warmup_steps else 1 + + warmup_scheduler = LambdaLR(optimizer, warmup_fn) + + # Main scheduler + main_steps = total_steps - warmup_steps + if lr_cfg.lr_type == "linear": + scheduler = LinearLR( + optimizer, + start_factor=1.0, + end_factor=lr_cfg.lr_min / lr_cfg.lr if hasattr(lr_cfg, 'lr') else 0.01, + total_iters=main_steps, + ) + elif lr_cfg.lr_type == "cosine": + scheduler = CosineAnnealingLR(optimizer, T_max=main_steps, eta_min=lr_cfg.lr_min) + elif lr_cfg.lr_type == "constant": + scheduler = LambdaLR(optimizer, lambda x: 1.0) + else: + raise ValueError(f"Unsupported lr type: {lr_cfg.lr_type}") + + self.lr_scheduler = SequentialLR( + optimizer=optimizer, + schedulers=[warmup_scheduler, scheduler], + milestones=[warmup_steps], + ) + + def _init_reference_model(self): + """Initialize reference model for KL divergence computation.""" + ref_load_from = self.config.ref_load_from or self.config.load_from + + self.logger.info(f"Initializing reference model from {ref_load_from}") + + # Create a separate engine for reference model + # 
Use VisionComposeTrainEngine for VLM models, TrainEngine for text-only models + ref_fsdp_cfg = deepcopy(self.config.fsdp_cfg) + + if isinstance(self.config.model_cfg, BaseComposeConfig): + self.ref_engine = VisionComposeTrainEngine( + model_cfg=self.config.model_cfg, + optim_cfg=self.config.optim_cfg, + fsdp_cfg=ref_fsdp_cfg, + ) + else: + self.ref_engine = TrainEngine( + model_cfg=self.config.model_cfg, + optim_cfg=self.config.optim_cfg, + fsdp_cfg=ref_fsdp_cfg, + ) + self.ref_engine.from_hf(str(ref_load_from)) + + # Freeze reference model + if self.config.freeze_ref_model: + for param in self.ref_engine.model.parameters(): + param.requires_grad = False + self.ref_engine.model.eval() + + # _setup_dataloader removed - dataloader is now built in __init__ using dataloader_cfg.build() + + def _compute_ref_logprobs( + self, + chosen_seq_ctx: SequenceContext, + rejected_seq_ctx: SequenceContext, + chosen_labels: torch.Tensor, + rejected_labels: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Compute reference model log probabilities.""" + if self.ref_engine is None: + return None, None + + def _get_field(out: Any, key: str): + if isinstance(out, dict): + return out[key] + return getattr(out, key) + + with torch.no_grad(): + # Forward chosen through reference model + ref_chosen_output = self.ref_engine.forward_only(chosen_seq_ctx) + ref_chosen_logits = _get_field(ref_chosen_output, "logits") + + # Forward rejected through reference model + ref_rejected_output = self.ref_engine.forward_only(rejected_seq_ctx) + ref_rejected_logits = _get_field(ref_rejected_output, "logits") + + # Compute log probs + # NOTE: _gather_logprobs returns token-level logprobs [B, L] + ref_chosen_token_logprobs = self._gather_logprobs(ref_chosen_logits, chosen_labels) + ref_rejected_token_logprobs = self._gather_logprobs(ref_rejected_logits, rejected_labels) + + # Compute masks + chosen_mask = (chosen_labels != -100) + rejected_mask = (rejected_labels != -100) + + # Sum local log probs and counts + local_chosen_sum = (ref_chosen_token_logprobs * chosen_mask.float()).sum(dim=-1) + local_rejected_sum = (ref_rejected_token_logprobs * rejected_mask.float()).sum(dim=-1) + local_chosen_count = chosen_mask.sum(dim=-1).float() + local_rejected_count = rejected_mask.sum(dim=-1).float() + + # Aggregate across SP ranks if using sequence parallelism + if self.sp_mesh.size() > 1: + dist.all_reduce(local_chosen_sum, op=dist.ReduceOp.SUM, group=self.sp_mesh.get_group()) + dist.all_reduce(local_rejected_sum, op=dist.ReduceOp.SUM, group=self.sp_mesh.get_group()) + dist.all_reduce(local_chosen_count, op=dist.ReduceOp.SUM, group=self.sp_mesh.get_group()) + dist.all_reduce(local_rejected_count, op=dist.ReduceOp.SUM, group=self.sp_mesh.get_group()) + + # Compute final log probs (average or sum based on config) + if self.config.loss_cfg.use_average_log_prob: + ref_chosen_logprobs = local_chosen_sum / local_chosen_count.clamp(min=1) + ref_rejected_logprobs = local_rejected_sum / local_rejected_count.clamp(min=1) + else: + ref_chosen_logprobs = local_chosen_sum + ref_rejected_logprobs = local_rejected_sum + + return ref_chosen_logprobs, ref_rejected_logprobs + + def _gather_logprobs( + self, logits: torch.Tensor, labels: torch.Tensor + ) -> torch.Tensor: + """Gather log probabilities for the given labels.""" + import torch.nn.functional as F + + # Shift logits and labels for causal LM + # logits: [batch, seq_len, vocab_size] + # labels: [batch, seq_len] (already shifted in collator) + logprobs = F.log_softmax(logits, dim=-1) + + # 
Gather log probs for target tokens + # Handle -100 (ignore index) by clipping + gathered = logprobs.gather( + dim=-1, index=labels.clip(min=0).unsqueeze(-1) + ).squeeze(-1) + + return gathered + + def _train_step(self, batch: List[DPOColateItem]) -> Dict[str, float]: + """Perform a single training step.""" + total_loss = torch.tensor(0.0, device=DEVICE) + all_metrics = {} + + def _get_field(out: Any, key: str): + if isinstance(out, dict): + return out[key] + return getattr(out, key) + + for item in batch: + # Extract data from DPOColateItem + chosen_seq_ctx = item["chosen_seq_ctx"] + rejected_seq_ctx = item["rejected_seq_ctx"] + chosen_labels = item["chosen_shifted_labels"] + rejected_labels = item["rejected_shifted_labels"] + + # Move to device + chosen_seq_ctx = chosen_seq_ctx.to(DEVICE) + rejected_seq_ctx = rejected_seq_ctx.to(DEVICE) + chosen_labels = chosen_labels.to(DEVICE) + rejected_labels = rejected_labels.to(DEVICE) + + # Apply sequence parallel split if enabled + if self.sp_mesh.size() > 1: + sp_size = self.sp_mesh.size() + + chosen_seq_ctx = chosen_seq_ctx.split(sequence_parallel_mesh=self.sp_mesh) + rejected_seq_ctx = rejected_seq_ctx.split(sequence_parallel_mesh=self.sp_mesh) + chosen_labels = pad_to_multiple_of(chosen_labels, -100, sp_size, dim=1) + chosen_labels = split_for_sequence_parallel(chosen_labels, dim=1, sp_mesh=self.sp_mesh) + rejected_labels = pad_to_multiple_of(rejected_labels, -100, sp_size, dim=1) + rejected_labels = split_for_sequence_parallel(rejected_labels, dim=1, sp_mesh=self.sp_mesh) + + # Compute reference log probs if needed + ref_chosen_logprobs = item.get("ref_chosen_logps") + ref_rejected_logprobs = item.get("ref_rejected_logps") + + if ref_chosen_logprobs is None and self.ref_engine is not None: + ref_chosen_logprobs, ref_rejected_logprobs = self._compute_ref_logprobs( + chosen_seq_ctx, rejected_seq_ctx, chosen_labels, rejected_labels + ) + + # Create DPO input item + # NOTE: chosen_labels and rejected_labels are ALREADY SP-split above, + # so we should NOT call dpo_input.sp_split() again (double split bug) + dpo_input = DPOLossContextInputItem( + chosen_shifted_labels=chosen_labels, + rejected_shifted_labels=rejected_labels, + ref_chosen_logprobs=ref_chosen_logprobs, + ref_rejected_logprobs=ref_rejected_logprobs, + ) + + # Build loss kwargs + loss_kwargs_list = DPOLossContext.build_batches_loss_kwargs( + [dpo_input], self.config.loss_cfg, + sp_mesh=self.sp_mesh if self.sp_mesh.size() > 1 else None, + ) + + # Forward chosen and rejected separately (policy) + chosen_output = self.train_engine.model(seq_ctx=chosen_seq_ctx, loss_ctx=None) + rejected_output = self.train_engine.model(seq_ctx=rejected_seq_ctx, loss_ctx=None) + + chosen_logits = _get_field(chosen_output, "logits") + rejected_logits = _get_field(rejected_output, "logits") + + # Concat along seq_len to match DPOLossKwargs.shifted_labels layout + logits = torch.cat([chosen_logits, rejected_logits], dim=1).float() + + loss_kwargs = loss_kwargs_list[0] + shifted_labels = loss_kwargs.shifted_labels + chosen_mask = loss_kwargs.chosen_mask + rejected_mask = loss_kwargs.rejected_mask + loss_weight = loss_kwargs.loss_weight + + # Compute loss using DPOLossContext's logprob math (logits already computed) + from xtuner.v1.rl.utils import gather_logprobs as _gather_logprobs + + all_logprobs = _gather_logprobs(logits, shifted_labels) # [B, L] + + # Compute local log probs (sum over local tokens) + chosen_logprobs = all_logprobs * chosen_mask.float() + rejected_logprobs = all_logprobs * 
rejected_mask.float() + + # Sum local log probs + local_chosen_sum = chosen_logprobs.sum(dim=-1) + local_rejected_sum = rejected_logprobs.sum(dim=-1) + local_chosen_count = chosen_mask.sum(dim=-1).float() + local_rejected_count = rejected_mask.sum(dim=-1).float() + + # Aggregate across SP ranks if using sequence parallelism + if self.sp_mesh.size() > 1: + # All-reduce to get global sum of log probs across SP group + dist.all_reduce(local_chosen_sum, op=dist.ReduceOp.SUM, group=self.sp_mesh.get_group()) + dist.all_reduce(local_rejected_sum, op=dist.ReduceOp.SUM, group=self.sp_mesh.get_group()) + dist.all_reduce(local_chosen_count, op=dist.ReduceOp.SUM, group=self.sp_mesh.get_group()) + dist.all_reduce(local_rejected_count, op=dist.ReduceOp.SUM, group=self.sp_mesh.get_group()) + + # Compute final log probs (average or sum based on config) + if self.config.loss_cfg.use_average_log_prob: + policy_chosen_logps = local_chosen_sum / local_chosen_count.clamp(min=1) + policy_rejected_logps = local_rejected_sum / local_rejected_count.clamp(min=1) + else: + policy_chosen_logps = local_chosen_sum + policy_rejected_logps = local_rejected_sum + + ref_chosen_logps = loss_kwargs.ref_chosen_logprobs + ref_rejected_logps = loss_kwargs.ref_rejected_logprobs + if ref_chosen_logps is None: + ref_chosen_logps = torch.zeros_like(policy_chosen_logps) + if ref_rejected_logps is None: + ref_rejected_logps = torch.zeros_like(policy_rejected_logps) + + loss_ctx = DPOLossContext(self.config.loss_cfg, loss_kwargs) + loss = torch.tensor(0.0, device=logits.device, dtype=logits.dtype) + extra_info: dict[str, Any] = {} + + for loss_type, weight in zip(self.config.loss_cfg.loss_types, self.config.loss_cfg.loss_weights): + if loss_type == "sigmoid": + _l = loss_ctx._dpo_loss_sigmoid( # type: ignore[attr-defined] + policy_chosen_logps, policy_rejected_logps, ref_chosen_logps, ref_rejected_logps + ).mean() * weight + extra_info["dpo_sigmoid_loss"] = _l.detach() + loss = loss + _l + elif loss_type == "robust": + _l = loss_ctx._dpo_loss_robust( # type: ignore[attr-defined] + policy_chosen_logps, policy_rejected_logps, ref_chosen_logps, ref_rejected_logps + ).mean() * weight + extra_info["dpo_robust_loss"] = _l.detach() + loss = loss + _l + elif loss_type == "hinge": + _l = loss_ctx._dpo_loss_hinge( # type: ignore[attr-defined] + policy_chosen_logps, policy_rejected_logps, ref_chosen_logps, ref_rejected_logps + ).mean() * weight + extra_info["dpo_hinge_loss"] = _l.detach() + loss = loss + _l + elif loss_type == "ipo": + _l = loss_ctx._dpo_loss_ipo( # type: ignore[attr-defined] + policy_chosen_logps, policy_rejected_logps, ref_chosen_logps, ref_rejected_logps + ).mean() * weight + extra_info["dpo_ipo_loss"] = _l.detach() + loss = loss + _l + elif loss_type == "bco_pair": + _l = loss_ctx._bco_pair_loss( # type: ignore[attr-defined] + policy_chosen_logps, policy_rejected_logps, ref_chosen_logps, ref_rejected_logps + ).mean() * weight + extra_info["bco_pair_loss"] = _l.detach() + loss = loss + _l + elif loss_type == "nca_pair": + _l = loss_ctx._nca_pair_loss( # type: ignore[attr-defined] + policy_chosen_logps, policy_rejected_logps, ref_chosen_logps, ref_rejected_logps + ).mean() * weight + extra_info["nca_pair_loss"] = _l.detach() + loss = loss + _l + elif loss_type == "sppo_hard": + _l = loss_ctx._sppo_hard_loss( # type: ignore[attr-defined] + policy_chosen_logps, policy_rejected_logps, ref_chosen_logps, ref_rejected_logps + ).mean() * weight + extra_info["sppo_hard_loss"] = _l.detach() + loss = loss + _l + elif loss_type == "sft": 
+ # SFT loss only on chosen part + _l = loss_ctx._sft_loss( # type: ignore[attr-defined] + logits=logits[:, : chosen_mask.shape[1]], + labels=shifted_labels[:, : chosen_mask.shape[1]], + mask=chosen_mask[:, : chosen_mask.shape[1]], + loss_weight=loss_weight[:, : chosen_mask.shape[1]], + ) * weight + # Aggregate SFT loss across SP ranks (each rank only has partial tokens) + if self.sp_mesh.size() > 1: + dist.all_reduce(_l, op=dist.ReduceOp.SUM, group=self.sp_mesh.get_group()) + extra_info["sft_loss"] = _l.detach() + loss = loss + _l + else: + raise ValueError(f"Unsupported loss_type: {loss_type}") + + total_loss = total_loss + loss + + # Collect metrics + for k, v in extra_info.items(): + if k not in all_metrics: + all_metrics[k] = [] + if isinstance(v, torch.Tensor): + all_metrics[k].append(v.detach()) + else: + all_metrics[k].append(v) + + # Average loss over batch + total_loss = total_loss / len(batch) + + # Backward pass + total_loss.backward() + + # Gradient accumulation + if (self._cur_step + 1) % self.config.gradient_accumulation_steps == 0: + self.train_engine.optimizer.step() + self.lr_scheduler.step() + self.train_engine.optimizer.zero_grad() + + # Prepare metrics + metrics = {"loss": total_loss.item()} + for k, v_list in all_metrics.items(): + if isinstance(v_list[0], torch.Tensor): + metrics[k] = torch.stack(v_list).mean().item() + elif isinstance(v_list[0], (int, float)): + metrics[k] = sum(v_list) / len(v_list) + + return metrics + + def fit(self): + """Run the DPO training loop.""" + self.logger.info("Starting DPO training") + + for epoch in range(self.config.total_epochs): + self._cur_epoch = epoch + + # Set epoch for distributed sampler + if hasattr(self._dataloader, "set_epoch"): + self._dataloader.set_epoch(epoch) + elif hasattr(self._dataloader.sampler, "set_epoch"): + self._dataloader.sampler.set_epoch(epoch) + + self.logger.info(f"Epoch {epoch + 1}/{self.config.total_epochs}") + + # Training loop + self.train_engine.model.train() + epoch_metrics = [] + + progress_bar = tqdm( + self._dataloader, + desc=f"Epoch {epoch + 1}", + disable=get_rank() != 0, + ) + + for batch_idx, batch in enumerate(progress_bar): + metrics = self._train_step(batch) + epoch_metrics.append(metrics) + self._cur_step += 1 + + # Logging + if self._cur_step % self.config.log_interval == 0: + avg_metrics = self._average_metrics(epoch_metrics[-self.config.log_interval:]) + avg_metrics["lr"] = self.lr_scheduler.get_last_lr()[0] + # Format lr with scientific notation, other metrics with 4 decimal places + log_parts = [] + for k, v in avg_metrics.items(): + if k == "lr": + log_parts.append(f"{k}={v:.2e}") + else: + log_parts.append(f"{k}={v:.4f}") + log_str = f"Step {self._cur_step}: " + ", ".join(log_parts) + self.logger.info(log_str) + progress_bar.set_postfix(avg_metrics) + + # Save checkpoint + if self.config.save_interval and self._cur_step % self.config.save_interval == 0: + self._save_checkpoint() + + # Evaluation + if ( + self.config.eval_interval + and self._cur_step % self.config.eval_interval == 0 + and self.eval_dataloader is not None + ): + eval_metrics = self._evaluate() + self.logger.info(f"Evaluation: {eval_metrics}") + + # Garbage collection for memory management (same as SFT Trainer) + if self._cur_step % 50 == 0: + gc.collect() + + # End of epoch + avg_epoch_metrics = self._average_metrics(epoch_metrics) + self.logger.info(f"Epoch {epoch + 1} average metrics: {avg_epoch_metrics}") + + # Save at end of epoch + self._save_checkpoint() + + self.logger.info("Training completed!") + + def 
_evaluate(self) -> Dict[str, float]:
+        """Run evaluation."""
+        self.train_engine.model.eval()
+        eval_metrics = []
+
+        with torch.no_grad():
+            for batch in tqdm(self.eval_dataloader, desc="Evaluating", disable=get_rank() != 0):
+                metrics = self._eval_step(batch)
+                eval_metrics.append(metrics)
+
+        self.train_engine.model.train()
+        return self._average_metrics(eval_metrics)
+
+    def _eval_step(self, batch: List[DPOColateItem]) -> Dict[str, float]:
+        """Perform a single evaluation step."""
+        total_loss = torch.tensor(0.0, device=DEVICE)
+        all_metrics = {}
+
+        for item in batch:
+            chosen_seq_ctx = item["chosen_seq_ctx"].to(DEVICE)
+            rejected_seq_ctx = item["rejected_seq_ctx"].to(DEVICE)
+            chosen_labels = item["chosen_shifted_labels"].to(DEVICE)
+            rejected_labels = item["rejected_shifted_labels"].to(DEVICE)
+
+            # Apply sequence parallel split if enabled
+            if self.sp_mesh.size() > 1:
+                sp_size = self.sp_mesh.size()
+                chosen_seq_ctx = chosen_seq_ctx.split(sequence_parallel_mesh=self.sp_mesh)
+                rejected_seq_ctx = rejected_seq_ctx.split(sequence_parallel_mesh=self.sp_mesh)
+                chosen_labels = pad_to_multiple_of(chosen_labels, -100, sp_size, dim=1)
+                chosen_labels = split_for_sequence_parallel(chosen_labels, dim=1, sp_mesh=self.sp_mesh)
+                rejected_labels = pad_to_multiple_of(rejected_labels, -100, sp_size, dim=1)
+                rejected_labels = split_for_sequence_parallel(rejected_labels, dim=1, sp_mesh=self.sp_mesh)
+
+            ref_chosen_logprobs = item.get("ref_chosen_logps")
+            ref_rejected_logprobs = item.get("ref_rejected_logps")
+
+            if ref_chosen_logprobs is None and self.ref_engine is not None:
+                ref_chosen_logprobs, ref_rejected_logprobs = self._compute_ref_logprobs(
+                    chosen_seq_ctx, rejected_seq_ctx, chosen_labels, rejected_labels
+                )
+
+            dpo_input = DPOLossContextInputItem(
+                chosen_shifted_labels=chosen_labels,
+                rejected_shifted_labels=rejected_labels,
+                ref_chosen_logprobs=ref_chosen_logprobs,
+                ref_rejected_logprobs=ref_rejected_logprobs,
+            )
+
+            # Match _train_step: pass the SP mesh so loss kwargs are built
+            # consistently when sequence parallelism is enabled
+            loss_kwargs_list = DPOLossContext.build_batches_loss_kwargs(
+                [dpo_input], self.config.loss_cfg,
+                sp_mesh=self.sp_mesh if self.sp_mesh.size() > 1 else None,
+            )
+
+            chosen_output = self.train_engine.forward_only(chosen_seq_ctx)
+            rejected_output = self.train_engine.forward_only(rejected_seq_ctx)
+
+            hidden_states = torch.cat([chosen_output.hidden_states, rejected_output.hidden_states], dim=1)
+
+            lm_head = self.train_engine.model.get_output_embeddings()
+            head_weight = lm_head.weight
+            head_bias = getattr(lm_head, "bias", None)
+
+            loss_ctx = DPOLossContext(self.config.loss_cfg, loss_kwargs_list[0])
+            loss, (_, extra_info) = loss_ctx.loss_fn(
+                hidden_states, head_weight, head_bias, loss_kwargs_list[0]
+            )
+
+            total_loss = total_loss + loss
+
+            for k, v in extra_info.items():
+                if k not in all_metrics:
+                    all_metrics[k] = []
+                if isinstance(v, torch.Tensor):
+                    all_metrics[k].append(v.detach())
+                else:
+                    all_metrics[k].append(v)
+
+        total_loss = total_loss / len(batch)
+
+        metrics = {"eval_loss": total_loss.item()}
+        for k, v_list in all_metrics.items():
+            if isinstance(v_list[0], torch.Tensor):
+                metrics[f"eval_{k}"] = torch.stack(v_list).mean().item()
+
+        return metrics
+
+    def _average_metrics(self, metrics_list: List[Dict[str, float]]) -> Dict[str, float]:
+        """Average metrics over a list."""
+        if not metrics_list:
+            return {}
+
+        avg_metrics = {}
+        for key in metrics_list[0].keys():
+            values = [m[key] for m in metrics_list if key in m]
+            avg_metrics[key] = sum(values) / len(values)
+
+        return avg_metrics
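+
+    # Illustration (not part of the trainer, assuming _dpo_loss_sigmoid follows
+    # the standard DPO objective): the "sigmoid" preference loss combined in
+    # _train_step is -logsigmoid(beta * ((pi_c - ref_c) - (pi_r - ref_r))):
+    #
+    #     >>> import torch
+    #     >>> import torch.nn.functional as F
+    #     >>> beta = 0.1
+    #     >>> pi_c, pi_r, ref_c, ref_r = -10.0, -14.0, -11.0, -13.0
+    #     >>> margin = (pi_c - ref_c) - (pi_r - ref_r)  # 1.0 - (-1.0) = 2.0
+    #     >>> -F.logsigmoid(torch.tensor(beta * margin))
+    #     tensor(0.5981)
+
+    def _save_checkpoint(self):
+        """Save model checkpoint."""
+        save_path = self.exp_dir / f"checkpoint-{self._cur_step}"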
f"checkpoint-{self._cur_step}" + self.logger.info(f"Saving checkpoint to {save_path}") + + if get_rank() == 0: + save_path.mkdir(parents=True, exist_ok=True) + + # Synchronize before saving + if dist.is_initialized(): + dist.barrier() + + # Save model + self.train_engine.save_hf(str(save_path)) + + # Save tokenizer + if get_rank() == 0: + if isinstance(self.tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): + self.tokenizer.save_pretrained(str(save_path)) + + # Update meta + if get_rank() == 0: + self._meta.latest_exp.hf_checkpoint_list.append(str(save_path)) + + # Save meta + meta_path = self._work_dir / self.META_PATH + with meta_path.open("w") as f: + f.write(self._meta.model_dump_json(indent=2)) + + if dist.is_initialized(): + dist.barrier() + + def _init_logger(self, work_dir: Path): + """Initialize logger.""" + return get_logger(log_dir=work_dir, tag="DPOTrainer") + + def _set_deterministic(self): + """Set deterministic mode.""" + if XTUNER_DETERMINISTIC: + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" + torch.use_deterministic_algorithms(True, warn_only=True) + + def _set_random_seed(self, seed: int): + """Set random seed.""" + set_random_seed(seed) + + def _init_xtuner_meta(self, work_dir: Path, resume: bool) -> XTunerMeta: + """Initialize experiment metadata. + + This follows the same pattern as RLTrainer._init_xtuner_meta. + """ + if not work_dir.exists(): + work_dir.mkdir(parents=True, exist_ok=True) + + meta_path = work_dir / self.META_PATH + if not meta_path.exists(): + meta = XTunerMeta(exps=[]) + with open(meta_path, "w") as f: + f.write(meta.model_dump_json(indent=2)) + + meta = cast(XTunerMeta, XTunerMeta.model_validate(load(meta_path, file_format="json"))) + + resume = resume and bool(meta.exps) + + if resume: + # Resume from existing experiment + latest_exp = meta.exps[-1] + latest_exp_history = latest_exp.history[-1] + + begin = cast(int, latest_exp_history.get("end") or latest_exp_history["begin"]) + exp_dir = Path(latest_exp.exp_dir) + git_dir = exp_dir / f"git-info-begin-{begin}" + + if not git_dir.exists(): + git_dir.mkdir(parents=True, exist_ok=True) + + staged_path, unstaged_path = git_dir / "staged.diff", git_dir / "unstaged.diff" + + commit = record_git_info(staged_path, unstaged_path) + git_info = GitInfo( + commit=commit, + staged=str(staged_path), + unstaged=str(unstaged_path), + ) + + timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + new_exp_history = ExpHistory( + begin=begin, + timestamp=timestamp, + git_info=git_info, + ) + latest_exp.history.append(new_exp_history) + else: + # Start new experiment + timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + exp_dir = work_dir / timestamp + git_dir = Path(f"{exp_dir}/git-info-begin-0") + + if not git_dir.exists(): + git_dir.mkdir(parents=True, exist_ok=True) + + staged_path, unstaged_path = git_dir / "staged.diff", git_dir / "unstaged.diff" + commit = record_git_info(staged_path, unstaged_path) + git_info = GitInfo( + commit=commit, + staged=str(staged_path), + unstaged=str(unstaged_path), + ) + + new_history = ExpHistory( + begin=0, + timestamp=timestamp, + git_info=git_info, + ) + new_exp = ExpInfo(history=[new_history], exp_dir=str(exp_dir)) + meta.exps.append(new_exp) + + return meta + + @property + def work_dir(self) -> Path: + return self._work_dir + + @property + def exp_dir(self) -> Path: + return Path(self._meta.latest_exp.exp_dir) + + @property + def cur_step(self) -> int: + return self._cur_step + + @property + def cur_epoch(self) -> int: + return self._cur_epoch +# 
\ No newline at end of file

From b1c8b5f71bbb74312273440caebfe9e9375985ea Mon Sep 17 00:00:00 2001
From: LarryLeeee
Date: Mon, 23 Mar 2026 15:50:31 +0800
Subject: [PATCH 2/2] feat: support Mixed Preference Optimization (MPO)

---
 examples/v1/config/mpo_qwen3_vl_8B.py   | 79 +++++++++----------------
 examples/v1/scripts/run_mpo_qwen3_vl.sh |  2 +-
 2 files changed, 28 insertions(+), 53 deletions(-)

diff --git a/examples/v1/config/mpo_qwen3_vl_8B.py b/examples/v1/config/mpo_qwen3_vl_8B.py
index 66a44f917..e5680f626 100755
--- a/examples/v1/config/mpo_qwen3_vl_8B.py
+++ b/examples/v1/config/mpo_qwen3_vl_8B.py
@@ -9,7 +9,7 @@
 - bco_pair: Binary Classifier Optimization for absolute quality
 - sft: Supervised Fine-Tuning loss to maintain generation quality
 
-For MPO (Mixed Preference Optimization), use:
+For MPO (Mixed Preference Optimization), as used in the MPO paper, use:
     loss_types=["sigmoid", "bco_pair", "sft"]
     loss_weights=[0.8, 0.2, 1.0]
 
@@ -18,9 +18,10 @@
     export WORK_DIR=/path/to/work_dir
     export MODEL_PATH=/path/to/model
     export META_DATA_PATH=/path/to/meta.json
-
+    export CEPH_CONFIG=/path/to/ceph.conf
+    export TOKENIZER_CACHE_DIR=/path/to/tokenizer_cache_dir
     # Run with torchrun
-    torchrun --nproc_per_node=8 xtuner/v1/train/cli/dpo.py --config dpo_qwen3_vl_8B.py
+    torchrun --nproc_per_node=8 xtuner/v1/train/cli/dpo.py --config mpo_qwen3_vl_8B.py
 """
 
 import json
@@ -32,22 +33,17 @@
 from xtuner.v1.model import Qwen3VLDense8BConfig
 from xtuner.v1.rl.dpo import DPOLossConfig
 from xtuner.v1.train.dpo_trainer import DPOTrainerConfig
+import os
+
+ceph_config = os.environ["CEPH_CONFIG"]
+meta_data_path = os.environ["META_DATA_PATH"]
+model_path = os.environ["MODEL_PATH"]
+work_dir = os.environ["WORK_DIR"]
+tokenizer_cache_dir = os.environ["TOKENIZER_CACHE_DIR"]
 
-# ============================================================================
-# 路径配置 (Path Configuration)
-# ============================================================================
-ceph_config = "/mnt/shared-storage-user/lisongze/iv3/xtuner/dpo_config/petreloss.conf"
-meta_data_path = "/mnt/shared-storage-user/lisongze/iv3/xtuner/dpo_config/MMPR.json"
-model_path = "/mnt/shared-storage-user/lisongze/iv3/xtuner_no/Qwen3-VL-8B-Instruct" #"/mnt/shared-storage-user/lisongze/cache/sp-mla-hf-800"
-work_dir = "/mnt/shared-storage-user/iv3/mpo/xtuner_saved_model/qwen3vl-8B-mpo-sp-mmpr-fix-bug"
-tokenizer_cache_dir = "/mnt/shared-storage-user/iv3/mpo/xtuner_tokenizer_cache/qwen3vl-8B-mpo-sp-mmpr-fix-bug"
-
-# ============================================================================
-# Training Settings (训练超参数)
-# ============================================================================
-# global_batch_size = num_gpus × per_device_batch_size × gradient_accumulation_steps
-# 8 GPUs × 4 gradient_accumulation_steps × 1 per_device_batch_size x 2 sp_size = 64
+# basic settings
+# global_batch_size = num_gpus × per_device_batch_size × gradient_accumulation_steps × sp_size
 total_epochs = 1
 global_batch_size = 64 # suppose 256
 per_device_batch_size = 1
@@ -61,20 +57,14 @@
 # Learning rate settings
 lr = 5e-6  # Lower LR for DPO
 # Paper: cosine decay with minimum learning rate 0
-lr_min = 0#1e-8
+lr_min = 0
 # Paper: linear warmup for first 5% of total training steps
 warmup_ratio = 0.05
 weight_decay = 0.05
 
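+# Illustration (not part of this config): with the MPO settings below, the
+# total training loss is the weighted sum
+#     L = 0.8 * L_sigmoid + 0.2 * L_bco_pair + 1.0 * L_sft
+# e.g. per-loss values of 0.60, 0.40 and 1.20 combine to
+#     0.8 * 0.60 + 0.2 * 0.40 + 1.0 * 1.20 = 1.76
+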
-# ============================================================================
-# 1. Model Configuration
-# ============================================================================
-# Freeze vision encoder to prevent catastrophic forgetting
-model_cfg = Qwen3VLDense8BConfig()#freeze_vision=True, freeze_projector=True)
+model_cfg = Qwen3VLDense8BConfig()
 
-# ============================================================================
-# 2. DPO Loss Configuration
-# ============================================================================
+# DPO Loss Configuration
 # Option 1: Standard DPO (sigmoid only)
 # loss_cfg = DPOLossConfig(
 #     loss_types=["sigmoid"],
@@ -82,8 +72,7 @@
 #     beta=0.1,
 # )
 
-# Option 2: MPO (Mixed Preference Optimization)
-# Combines DPO, BCO, and SFT losses
+# MPO (Mixed Preference Optimization) - combines DPO, BCO, and SFT losses
 loss_cfg = DPOLossConfig(
     loss_types=["sigmoid", "bco_pair", "sft"],
     loss_weights=[0.8, 0.2, 1.0],
@@ -96,9 +85,7 @@
     ignore_idx=-100,
 )
 
-# ============================================================================
-# 3. Dataset Configuration - 逐个 JSON 加载 (参考 sft_internvl3.5_8B_config_tiny.py)
-# ============================================================================
+# Dataset Configuration - refer to sft_internvl3.5_8B_config_tiny.py
 oss_loader_cfg = OSSLoaderConfig(backend_kwargs={"conf_path": ceph_config})
 ds_collections = json.loads(open(meta_data_path).read())
 dataset_config = []
@@ -109,7 +96,7 @@
         anno_path=_data['annotation'],
         media_root=_data.get('media_root', ''),
         sample_ratio=_data.get('sample_ratio', 1.0),
-        class_name='VLMPreferenceJsonlDataset', # 使用偏好数据集类
+        class_name='VLMPreferenceJsonlDataset', # use the preference dataset class
         enable_sequential_sampler=True,
         cache_tag='cache_tags_dpo_v1',
         cache_dir=tokenizer_cache_dir,
@@ -124,7 +111,7 @@
         chosen_key="chosen",
         rejected_key="rejected",
         images_key="images",
-        add_eos_token=True, # 使用父类定义的字段名
+        add_eos_token=True, # field names follow the parent class definition
         system_message=_data.get('system_message', None),
         hash=_data.get('hash', None),
     ),
@@ -135,38 +122,26 @@
     dataset_config_list=dataset_config,
     pack_max_length=pack_max_length,
     pack_to_max_length=True,# must set to True if using sp_size>1
-    pack_level="none", # DPO 不需要 packing,每个样本独立处理
-    collator="qwen3_vl_dpo_collator", # 使用 DPO collator
+    pack_level="none",
+    collator="qwen3_vl_dpo_collator", # use the DPO collator
     num_workers=num_workers,
-    group_by_length=False, # pack_level=none 时必须为 False
+    group_by_length=False, # must be False if pack_level=none
 )
 
-# ============================================================================
-# 4. Optimizer and Learning Rate
-# ============================================================================
+# Optimizer and Learning Rate
 optim_cfg = AdamWConfig(lr=lr, weight_decay=weight_decay, foreach=False)
 lr_cfg = LRConfig(lr_type="cosine", warmup_ratio=warmup_ratio, lr_min=lr_min)
 
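+# Illustration (hypothetical values): META_DATA_PATH is expected to point at a
+# JSON dict of dataset entries, each consumed by the ds_collections loop above,
+# e.g.
+#
+#     {
+#       "mmpr_subset": {
+#         "annotation": "path/to/annotations.jsonl",
+#         "media_root": "path/to/images/",
+#         "sample_ratio": 1.0
+#       }
+#     }
+
-# ============================================================================
-# 5. 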
FSDP Configuration (内存优化版) -# ============================================================================ +# FSDP Configuration fsdp_cfg = FSDPConfig( - # Gradient checkpointing: 1.0 = 完全启用,用计算换内存(最重要的内存优化) recompute_ratio=1.0, - # Vision 模块也启用 gradient checkpointing vision_recompute_ratio=1.0, - # 前向传播后重新分片参数,节省内存 reshard_after_forward=True, - # 关闭 RNG 状态保存,节省少量内存(不影响训练质量,只影响精确复现) checkpoint_preserve_rng_state=False, - # CPU offload:将优化器状态卸载到 CPU(会降速但省显存) - # cpu_offload=True, - # 关闭 torch.compile:VLM 动态 shape 不适合 compile,且编译时额外耗内存 torch_compile=True, ) -# ============================================================================ -# 6. DPO Trainer Configuration (export as 'trainer' for CLI compatibility) -# ============================================================================ + +# DPO Trainer Configuration trainer = DPOTrainerConfig( model_cfg=model_cfg, optim_cfg=optim_cfg, diff --git a/examples/v1/scripts/run_mpo_qwen3_vl.sh b/examples/v1/scripts/run_mpo_qwen3_vl.sh index c0cbd0c33..5c32a6466 100755 --- a/examples/v1/scripts/run_mpo_qwen3_vl.sh +++ b/examples/v1/scripts/run_mpo_qwen3_vl.sh @@ -9,7 +9,7 @@ export HF_HOME="$(pwd)/" export TORCHDYNAMO_VERBOSE=1 MASTER_PORT=20500 -config_file="/mnt/shared-storage-user/lisongze/xtuner/examples/v1/config/mpo_qwen3_vl_8B.py" +config_file="xtuner/examples/v1/config/mpo_qwen3_vl_8B.py" # NODE_COUNT=1 # NODE_RANK=0 # MASTER_ADDR=127.0.0.1