Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions xtuner/v1/engine/train_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from xtuner.v1.config import FSDPConfig, OptimConfig
from xtuner.v1.data_proto.sequence_context import SequenceContext
from xtuner.v1.loss import LogProbContext
from xtuner.v1.model.base import (
BaseModel,
BatchForwardInfo,
Expand Down Expand Up @@ -170,8 +171,8 @@ def data_replicate_size(self) -> int:
return self.fsdp_cfg.tp_size

@torch.no_grad()
def forward_only(self, seq_ctx: SequenceContext, loss_ctx: LogProbContext | None = None):
    """Run a single gradient-free forward pass through the wrapped model.

    Args:
        seq_ctx: Packed sequence inputs handed straight to ``self.model``.
        loss_ctx: Optional log-prob context. When provided, the model's head
            gathers per-token log-probs (chunked if so configured) instead of
            returning raw logits. Defaults to ``None`` to stay backward
            compatible with callers written before ``loss_ctx`` existed,
            which relied on ``loss_ctx=None`` being forwarded.

    Returns:
        The model's output mapping; its contents depend on ``loss_ctx``
        (presumably ``"logits"`` when ``None``, ``"loss"`` with per-token
        log-probs otherwise — TODO confirm against the model head).
    """
    output = self.model(seq_ctx=seq_ctx, loss_ctx=loss_ctx)
    return output

def grad_accumulation_steps(self, data_batches_len: int):
Expand Down
3 changes: 3 additions & 0 deletions xtuner/v1/loss/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .ce_loss import CELossConfig, CELossContext
from .chunk_loss import ChunkLoss
from .moe_loss import BalancingLoss, ZLoss
from .rl_loss import LogProbConfig, LogProbContext


__all__ = [
Expand All @@ -13,6 +14,8 @@
"BaseLossConfig",
"BaseLossContext",
"BaseLossKwargs",
"LogProbConfig",
"LogProbContext",
]

import torch
Expand Down
87 changes: 87 additions & 0 deletions xtuner/v1/loss/rl_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from typing import Any

import torch
import torch.nn.functional as F
from torch.distributed.device_mesh import DeviceMesh

from xtuner.v1.rl.utils import gather_logprobs
from xtuner.v1.utils.device import get_device

from .base_loss_ctx import BaseLossConfig, BaseLossContext, BaseLossKwargs


DEVICE = get_device()


class LogProbConfig(BaseLossConfig):
    """Config for computing per-token log-probabilities (e.g. RL old/ref logprobs).

    Unlike a training-loss config, contexts built from this config return the
    gathered log-probs themselves rather than a reduced scalar loss.
    """

    @property
    def loss_ctx_cls(self) -> type["LogProbContext"]:
        """Context class instantiated by :meth:`build`.

        NOTE: the return annotation is deliberately ``LogProbContext`` (not
        ``Self``) — ``build`` produces a context, not another config.
        """
        return LogProbContext

    def build(self, shifted_labels: torch.Tensor, sp_mesh: DeviceMesh | None = None) -> "LogProbContext":
        """Build a :class:`LogProbContext` for one batch of sequences.

        Args:
            shifted_labels: Target token ids already shifted by one position
                relative to the logits. Presumably shape ``(batch, seq_len)``
                — TODO confirm against callers.
            sp_mesh: Optional sequence-parallel mesh. When its size exceeds 1
                the labels are split along the sequence dimension so each rank
                keeps only its own shard.

        Returns:
            A context carrying this config and the (possibly sharded) labels.
        """
        loss_kwargs = LogProbKwargs(shifted_labels=shifted_labels)
        if sp_mesh is not None and sp_mesh.size() > 1:
            loss_kwargs = loss_kwargs.sp_split(sp_mesh)
        return self.loss_ctx_cls(self, loss_kwargs)


class LogProbKwargs(BaseLossKwargs):
    """Per-batch inputs needed to gather log-probs in :class:`LogProbContext`."""

    # Target token ids, already shifted one step relative to the logits; the
    # log-prob of each id is gathered per position. Presumably shape
    # (batch, seq_len) — TODO confirm against LogProbConfig.build callers.
    shifted_labels: torch.Tensor


class LogProbContext(BaseLossContext):
    """Loss context that produces per-token log-probabilities.

    Instead of reducing to a scalar training loss, ``forward`` returns the
    log-prob of each shifted label token, either in one shot (eager) or in
    sequence chunks to bound peak logits memory.
    """

    loss_cfg: LogProbConfig
    loss_kwargs: LogProbKwargs

    @staticmethod
    def build_batches(  # type: ignore[override]
        loss_ctx_list: list["LogProbContext"], *args: Any, **kwargs: Any
    ) -> list["LogProbContext"]:
        """Stamp the batch size onto every context; no merging is performed."""
        del args, kwargs
        n_ctx = len(loss_ctx_list)
        for ctx in loss_ctx_list:
            ctx._batch_size = n_ctx
        return loss_ctx_list

    def loss_fn(
        self,
        hidden_states: torch.Tensor,
        head_weight: torch.Tensor,
        head_bias: torch.Tensor | None,
        loss_kwargs: LogProbKwargs,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor | None, dict[str, Any]]]:
        """Project hidden states through the LM head and gather label log-probs.

        Logits are upcast to float32 before gathering for numerical stability.
        """
        full_logits = F.linear(hidden_states, head_weight, head_bias).float()
        token_logprobs = gather_logprobs(full_logits, loss_kwargs.shifted_labels)
        return token_logprobs, (None, {})

    def chunk_mode(
        self,
        hidden_states: torch.Tensor,
        head_weight: torch.Tensor,
        head_bias: torch.Tensor | None,
        loss_kwargs: LogProbKwargs,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor | None, dict[str, Any]]]:
        """Same result as :meth:`loss_fn`, but materializes logits one
        sequence chunk at a time so only ``chunk_size`` columns of the
        vocabulary projection live in memory at once."""
        assert self.loss_cfg.chunk_size is not None, "chunk_size must be set in chunk mode"

        step = self.loss_cfg.chunk_size
        bs, seq_len = loss_kwargs.shifted_labels.shape
        # Preallocated output; each loop iteration fills one sequence slice.
        token_logprobs = torch.zeros((bs, seq_len), device=loss_kwargs.shifted_labels.device)
        for start in range(0, seq_len, step):
            window = slice(start, start + step)
            chunk_logits = F.linear(hidden_states[:, window, :], head_weight, head_bias).float()
            token_logprobs[:, window] = gather_logprobs(chunk_logits, loss_kwargs.shifted_labels[:, window])
        return token_logprobs, (None, {})

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_weight: torch.Tensor,
        head_bias: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor | None, dict[str, Any]]]:
        """Dispatch to chunked or eager log-prob computation per the config.

        ``eager_mode`` is presumably provided by ``BaseLossContext`` and routes
        through :meth:`loss_fn` — confirm against the base class.
        """
        assert self.loss_kwargs is not None, "loss_kwargs must be set before calling forward"
        compute = self.chunk_mode if self.loss_cfg.mode == "chunk" else self.eager_mode
        token_logprobs, _ = compute(hidden_states, head_weight, head_bias, self.loss_kwargs)
        return token_logprobs, (None, {})
14 changes: 9 additions & 5 deletions xtuner/v1/rl/base/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from xtuner.v1.datasets.dataloader import Dataloader
from xtuner.v1.engine.train_engine import TrainEngine, TrainStepInfo
from xtuner.v1.float8.float8_handler import Float8Handler
from xtuner.v1.loss import CELossConfig
from xtuner.v1.loss import CELossConfig, LogProbConfig
from xtuner.v1.loss.ce_loss import CELossContext
from xtuner.v1.model.base import BaseModel as XtunerBaseModel
from xtuner.v1.model.base import ModelItem, TransformerConfig
Expand Down Expand Up @@ -248,6 +248,11 @@ def __init__(
self.rollout_cfg_info: dict = dict()
self.endpoints: dict[str, str] = dict()
self.endpoints["update_weights"] = "update_weights"
if worker_cfg.loss_cfg.chunk_size is not None:
mode = "chunk"
else:
mode = "eager"
self.logprob_cfg = LogProbConfig(chunk_size=worker_cfg.loss_cfg.chunk_size, mode=mode)

def _init_sft(self, worker_cfg: WorkerConfig):
self._sft_dataloader_config = worker_cfg.sft_dataloader_cfg
Expand Down Expand Up @@ -294,7 +299,6 @@ def _build_engine(self, worker_cfg: WorkerConfig) -> TrainEngine:
fsdp_cfg=worker_cfg.fsdp_cfg,
model_cfg=worker_cfg.model_cfg,
)

if worker_cfg.load_from is not None:
engine.from_hf(worker_cfg.load_from)

Expand Down Expand Up @@ -373,9 +377,9 @@ def compute_actor_logprobs(
self._engine._maybe_precompute_float8_dynamic_scale_for_fsdp()
old_logprobs_list: list[torch.Tensor] = []
for seq_ctx, shifted_labels in zip(seq_ctx_list, shifted_labels_list):
output = self._engine.forward_only(seq_ctx=seq_ctx)
old_logprobs = gather_logprobs(output["logits"], shifted_labels)
old_logprobs_list.append(old_logprobs)
loss_ctx = self.logprob_cfg.build(shifted_labels=shifted_labels)
output = self._engine.forward_only(seq_ctx=seq_ctx, loss_ctx=loss_ctx)
old_logprobs_list.append(output["loss"])
return old_logprobs_list

def compute_ref_logprobs(
Expand Down
Loading