diff --git a/examples/v1/config/rl_qwen25_7B_dapo.py b/examples/v1/config/rl_dapo_math.py similarity index 99% rename from examples/v1/config/rl_qwen25_7B_dapo.py rename to examples/v1/config/rl_dapo_math.py index 96e94795c..500ec31bd 100644 --- a/examples/v1/config/rl_qwen25_7B_dapo.py +++ b/examples/v1/config/rl_dapo_math.py @@ -57,7 +57,7 @@ gpu_memory_utilization=0.8, context_length=max_response_length + max_prompt_length, enable_return_routed_experts=(enable_return_routed_experts == "1"), - rollout_max_batch_size_per_instance=1024 + rollout_max_batch_size_per_instance=2048 ) # 3. judger diff --git a/examples/v1/config/rl_qwen25_7B_dapo_async.py b/examples/v1/config/rl_dapo_math_async.py similarity index 99% rename from examples/v1/config/rl_qwen25_7B_dapo_async.py rename to examples/v1/config/rl_dapo_math_async.py index 0958dfec8..3beeea21e 100644 --- a/examples/v1/config/rl_qwen25_7B_dapo_async.py +++ b/examples/v1/config/rl_dapo_math_async.py @@ -57,7 +57,7 @@ gpu_memory_utilization=0.8, context_length=max_response_length + max_prompt_length, enable_return_routed_experts=(enable_return_routed_experts == "1"), - rollout_max_batch_size_per_instance=1024 + rollout_max_batch_size_per_instance=2048 ) # 3. 
judger diff --git a/examples/v1/config/rl_qwen25_7B_dapo_aysnc_filter.py b/examples/v1/config/rl_dapo_math_async_filter.py similarity index 100% rename from examples/v1/config/rl_qwen25_7B_dapo_aysnc_filter.py rename to examples/v1/config/rl_dapo_math_async_filter.py diff --git a/xtuner/v1/data_proto/rl_data.py b/xtuner/v1/data_proto/rl_data.py index e8009eb34..4aa93d91c 100644 --- a/xtuner/v1/data_proto/rl_data.py +++ b/xtuner/v1/data_proto/rl_data.py @@ -84,11 +84,18 @@ class RolloutState(CacheObj, BaseModel): sample_params: SampleParams = SampleParams() # --- InferEngine 输出 --- + # 每一次推理引擎的实际输出, 在rollout worker中被覆盖写 response: str | None = None response_ids: list[int] | None = None logprobs: list[float] | None = None routed_experts: list[int] | RayObjectRef | None = None finish_reason: str | None = None + # response_mask: 记录response_ids中哪个token算loss, 与response_ids长度相同,每轮rollout在 agent_loop.generate 中覆盖写 + response_mask: list[int] | None = None + # response_rollout_steps:记录 response_ids 中每个 token 是在哪个 rollout_step 生成的,与 response_ids 长度相同,每轮rollout在agent_loop中后处理中覆盖写 + response_rollout_steps: list[int] | None = None + # 记录该样本过期程度,即最先生成的token与当前的训练步数的差值,数值越大表示越过期,在 agent_loop 中后处理中覆盖写 + seq_staleness: int = 0 # --- Judger 输出 --- reward: dict[str, Any] | None = None @@ -98,13 +105,10 @@ class RolloutState(CacheObj, BaseModel): task_name: str | None = None status: Status = Status.INIT error_msg: str | None = None - seq_staleness: int = 0 - response_mask: list[int] | None = None # response_ids的长度 - response_rollout_steps: list[int] | None = None # 记录 response_ids 中每个 token 是在哪个 rollout_step 生成的 extra_fields: dict[str, Any] = {} @field_serializer("routed_experts") - def _serialize_routed_experts(self, value: list[int] | RayObjectRef | None) -> list[int] | None: + def _serialize_routed_experts(self, value: torch.Tensor | list[int] | RayObjectRef | None) -> list | None: """Dump 时跳过 ray.ObjectRef,序列化为 None,避免 PydanticSerializationError。""" if value is None: return None @@ -117,7
+121,7 @@ def _serialize_routed_experts(self, value: list[int] | RayObjectRef | None) -> l pass if type(value).__name__ == "ObjectRef" and "ray" in getattr(type(value), "__module__", ""): return None - return value # list[int] + return value.tolist() if hasattr(value, "tolist") else value def update_status_from_finish_reason(finish_reason: str | None) -> Status: diff --git a/xtuner/v1/rl/agent_loop/utils.py b/xtuner/v1/rl/agent_loop/utils.py index b6054e043..4606bd889 100644 --- a/xtuner/v1/rl/agent_loop/utils.py +++ b/xtuner/v1/rl/agent_loop/utils.py @@ -1,3 +1,7 @@ +import time + +import ray + from xtuner.v1.data_proto import RolloutState, Status, update_seq_staleness from xtuner.v1.utils import get_logger @@ -5,6 +9,17 @@ logger = get_logger() +def _resolve_routed_experts(routed_experts: list[int] | ray.ObjectRef | None) -> list[int] | None: + if routed_experts is None: + return None + if isinstance(routed_experts, ray.ObjectRef): + routed_experts = ray.get(routed_experts) + if hasattr(routed_experts, "tolist"): + routed_experts = routed_experts.tolist() + assert isinstance(routed_experts, list), f"Unexpected routed_experts type: {type(routed_experts)}" + return routed_experts + + class PartialRolloutHandler: """Handle preprocessing and postprocessing for partial rollout continuation.""" @@ -47,6 +62,7 @@ def preprocess(self, rollout_state: RolloutState, enable_partial_rollout: bool = "response": rollout_state.response or "", "logprobs": rollout_state.logprobs or [], "response_mask": rollout_state.response_mask or [], + "routed_experts": rollout_state.routed_experts, } return rollout_state @@ -63,5 +79,40 @@ def postprocess(self, rollout_state: RolloutState, rollout_step: int) -> Rollout rollout_state.response = history_dict.get("response", "") + (rollout_state.response or "") rollout_state.logprobs = history_dict.get("logprobs", []) + (rollout_state.logprobs or []) rollout_state.response_mask = history_dict.get("response_mask", []) + (rollout_state.response_mask or []) + history_routed_experts_ref =
history_dict.get("routed_experts") + cur_routed_experts_ref = rollout_state.routed_experts + if history_routed_experts_ref is not None and cur_routed_experts_ref is not None: + start_time = time.time() + history_routed_experts = _resolve_routed_experts(history_dict.get("routed_experts")) + cur_routed_experts = _resolve_routed_experts(rollout_state.routed_experts) + cur_routed_experts_len = len(cur_routed_experts) + history_routed_experts_len = len(history_routed_experts) + assert history_routed_experts_len - 1 <= cur_routed_experts_len, ( + f"Existing routed_experts len: {history_routed_experts_len}, current routed_experts len: {cur_routed_experts_len}" + ) + cur_routed_experts = cur_routed_experts[history_routed_experts_len:] + concat_routed_experts = history_routed_experts + cur_routed_experts + + prompt_ids = rollout_state.prompt_ids or [] + response_ids = rollout_state.response_ids or [] + expect_tokens_num = len(prompt_ids) + len(response_ids) - 1 + assert len(concat_routed_experts) == expect_tokens_num, ( + f"After concatenation, routed_experts len: {len(concat_routed_experts)}, expected tokens num: {expect_tokens_num}" + ) + logger.info( + f"[PartialRolloutHandler] Postprocess rollout {rollout_state.uid}: " + f"concat routed_experts len={len(concat_routed_experts)} " + f"(history={history_routed_experts_len}, new={cur_routed_experts_len}), " + f"prompt={len(prompt_ids)}, response={len(response_ids)}" + ) + rollout_state.routed_experts = ray.put(concat_routed_experts) + end_time = time.time() + logger.info( + f"[PartialRolloutHandler] Postprocess routed_experts concatenation time: {end_time - start_time:.4f} seconds" + ) + elif history_routed_experts_ref is None and cur_routed_experts_ref is not None: + rollout_state.routed_experts = cur_routed_experts_ref + elif history_routed_experts_ref is not None and cur_routed_experts_ref is None: + rollout_state.routed_experts = history_routed_experts_ref return rollout_state