Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
gpu_memory_utilization=0.8,
context_length=max_response_length + max_prompt_length,
enable_return_routed_experts=(enable_return_routed_experts == "1"),
rollout_max_batch_size_per_instance=1024
rollout_max_batch_size_per_instance=2048
)

# 3. judger
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
gpu_memory_utilization=0.8,
context_length=max_response_length + max_prompt_length,
enable_return_routed_experts=(enable_return_routed_experts == "1"),
rollout_max_batch_size_per_instance=1024
rollout_max_batch_size_per_instance=2048
)

# 3. judger
Expand Down
14 changes: 9 additions & 5 deletions xtuner/v1/data_proto/rl_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,18 @@ class RolloutState(CacheObj, BaseModel):
sample_params: SampleParams = SampleParams()

# --- InferEngine 输出 ---
# 每一次推理引擎的实际输出, 在rollout worker中被覆盖写
response: str | None = None
response_ids: list[int] | None = None
logprobs: list[float] | None = None
routed_experts: list[int] | RayObjectRef | None = None
finish_reason: str | None = None
# response_mask: 记录response_ids中哪个token算loss, 与response_ids长度相同,每轮rollout在 agent_loop.generate 中覆盖写
response_mask: list[int] | None = None
# response_rollout_steps:记录 response_ids 中每个 token 是在哪个 rollout_step 生成的,与 response_ids 长度相同,每轮rollout在agent_loop中后处理中覆盖写
response_rollout_steps: list[int] | None = None
# 记录该样本过期程度,即最先生成的token与当前的训练步数的差值,数值越大表示越过期,在 agent_loop 中后处理中覆盖写
seq_staleness: int = 0

# --- Judger 输出 ---
reward: dict[str, Any] | None = None
Expand All @@ -98,13 +105,10 @@ class RolloutState(CacheObj, BaseModel):
task_name: str | None = None
status: Status = Status.INIT
error_msg: str | None = None
seq_staleness: int = 0
response_mask: list[int] | None = None # response_ids的长度
response_rollout_steps: list[int] | None = None # 记录 response_ids 中每个 token 是在哪个 rollout_step 生成的
extra_fields: dict[str, Any] = {}

@field_serializer("routed_experts")
def _serialize_routed_experts(self, value: list[int] | RayObjectRef | None) -> list[int] | None:
def _serialize_routed_experts(self, value: torch.Tensor | RayObjectRef | None) -> list | None:
"""Dump 时跳过 ray.ObjectRef,序列化为 None,避免 PydanticSerializationError。"""
if value is None:
return None
Expand All @@ -117,7 +121,7 @@ def _serialize_routed_experts(self, value: list[int] | RayObjectRef | None) -> l
pass
if type(value).__name__ == "ObjectRef" and "ray" in getattr(type(value), "__module__", ""):
return None
return value # list[int]
return value.tolist()


def update_status_from_finish_reason(finish_reason: str | None) -> Status:
Expand Down
51 changes: 51 additions & 0 deletions xtuner/v1/rl/agent_loop/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,25 @@
import time

import ray

from xtuner.v1.data_proto import RolloutState, Status, update_seq_staleness
from xtuner.v1.utils import get_logger


logger = get_logger()


def _resolve_routed_experts(routed_experts: list[int] | ray.ObjectRef | None) -> list[int] | None:
    """Materialize ``routed_experts`` into a plain Python list.

    The value may arrive as ``None``, an already-resolved list, or a
    ``ray.ObjectRef`` that must be fetched from the Ray object store.
    After resolution the payload may expose ``tolist()`` — presumably a
    torch tensor or numpy array produced by the inference engine; TODO
    confirm the concrete type against the rollout worker's output.

    Returns ``None`` when the input is ``None``, otherwise a ``list``.
    Raises ``AssertionError`` if the resolved value is not a list.
    """
    if routed_experts is None:
        return None
    # Dereference through Ray's object store if we only hold a handle.
    if isinstance(routed_experts, ray.ObjectRef):
        routed_experts = ray.get(routed_experts)
    # Normalize tensor-like payloads (anything with tolist()) to a plain list.
    if hasattr(routed_experts, "tolist"):
        routed_experts = routed_experts.tolist()
    assert isinstance(routed_experts, list), f"Unexpected routed_experts type: {type(routed_experts)}"
    return routed_experts


class PartialRolloutHandler:
"""Handle preprocessing and postprocessing for partial rollout
continuation."""
Expand Down Expand Up @@ -47,6 +62,7 @@ def preprocess(self, rollout_state: RolloutState, enable_partial_rollout: bool =
"response": rollout_state.response or "",
"logprobs": rollout_state.logprobs or [],
"response_mask": rollout_state.response_mask or [],
"routed_experts": rollout_state.routed_experts,
}
return rollout_state

Expand All @@ -63,5 +79,40 @@ def postprocess(self, rollout_state: RolloutState, rollout_step: int) -> Rollout
rollout_state.response = history_dict.get("response", "") + (rollout_state.response or "")
rollout_state.logprobs = history_dict.get("logprobs", []) + (rollout_state.logprobs or [])
rollout_state.response_mask = history_dict.get("response_mask", []) + (rollout_state.response_mask or [])
history_routed_experts_ref = history_dict.get("routed_experts")
cur_routed_experts_ref = rollout_state.routed_experts
if history_routed_experts_ref is not None and cur_routed_experts_ref is not None:
start_time = time.time()
history_routed_experts = _resolve_routed_experts(history_dict.get("routed_experts"))
cur_routed_experts = _resolve_routed_experts(rollout_state.routed_experts)
cur_routed_experts_len = len(cur_routed_experts)
history_routed_experts_len = len(history_routed_experts)
assert history_routed_experts_len - 1 <= cur_routed_experts_len, (
f"Existing routed_experts len: {history_routed_experts_len}, current routed_experts len: {cur_routed_experts_len}"
)
cur_routed_experts = cur_routed_experts[history_routed_experts_len:]
concat_routed_experts = history_routed_experts + cur_routed_experts

prompt_ids = rollout_state.prompt_ids or []
response_ids = rollout_state.response_ids or []
expect_tokens_num = len(prompt_ids) + len(response_ids) - 1
assert len(concat_routed_experts) == expect_tokens_num, (
f"After concatenation, routed_experts len: {len(concat_routed_experts)}, expected tokens num: {expect_tokens_num}"
)
logger.info(
f"[PartialRolloutHandler] Postprocess rollout {rollout_state.uid}: "
f"concat routed_experts len={len(concat_routed_experts)} "
f"(history={history_routed_experts_len}, new={cur_routed_experts_len}), "
f"prompt={len(prompt_ids)}, response={len(response_ids)}"
)
rollout_state.routed_experts = ray.put(concat_routed_experts)
end_time = time.time()
logger.info(
f"[PartialRolloutHandler] Postprocess routed_experts concatenation time: {end_time - start_time:.4f} seconds"
)
elif history_routed_experts_ref is None and cur_routed_experts_ref is not None:
rollout_state.routed_experts = cur_routed_experts_ref
elif history_routed_experts_ref is not None and cur_routed_experts_ref is None:
rollout_state.routed_experts = history_routed_experts_ref

return rollout_state
Comment on lines 81 to 118
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude: Nit: start_time is set before the conditionals, but all three branches (both-not-None, one-None, both-None) are timed and logged. When both are None, you're timing and logging nothing useful. Consider moving the timing and logging inside the if block where concatenation actually happens, or at least gating the log message:

if history_routed_experts is not None or cur_routed_experts is not None:
    logger.info(...)

Loading