
support r3 for moe model#1605

Open
YanhuiDua wants to merge 3 commits into InternLM:rl_design from YanhuiDua:support_moe_r3

Conversation

@YanhuiDua
Collaborator

No description provided.

@YanhuiDua
Collaborator Author

@claude review

"response": rollout_state.response or "",
"logprobs": rollout_state.logprobs or [],
"response_mask": rollout_state.response_mask or [],
"routed_experts": rollout_state.routed_experts or None,

Claude: Bug (Critical): rollout_state.routed_experts or None will crash with RuntimeError: Boolean value of Tensor with more than one element is ambiguous when routed_experts is a multi-element torch.Tensor.

Since the type of routed_experts was changed to torch.Tensor | RayObjectRef | None, the or operator will attempt to evaluate the truthiness of the tensor, which is undefined for multi-element tensors in PyTorch.

Suggested change
"routed_experts": rollout_state.routed_experts or None,
"routed_experts": rollout_state.routed_experts if rollout_state.routed_experts is not None else None,

Or more simply: "routed_experts": rollout_state.routed_experts, — since None is already a valid value.
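The crash is reproducible without PyTorch: a multi-element torch.Tensor defines __bool__ to raise, and `or` forces that call. A minimal sketch using a stand-in class (FakeTensor is hypothetical, not part of the PR) shows both the buggy and the fixed pattern:

```python
class FakeTensor:
    """Stand-in for a multi-element torch.Tensor, whose __bool__ raises."""

    def __bool__(self):
        raise RuntimeError(
            "Boolean value of Tensor with more than one element is ambiguous"
        )


routed_experts = FakeTensor()

# Buggy pattern from the PR: `or` evaluates truthiness, which raises.
try:
    _ = routed_experts or None
    crashed = False
except RuntimeError:
    crashed = True

# Fixed pattern: an identity check never calls __bool__.
safe = routed_experts if routed_experts is not None else None

print(crashed)                 # True
print(safe is routed_experts)  # True
```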

Comment on lines +88 to +92
cur_routed_experts_shape = cur_routed_experts.shape
history_routed_experts_len = history_routed_experts.shape[0]
assert history_routed_experts_len - 1 <= cur_routed_experts_shape[0], (
    f"Existing routed_experts shape: {history_routed_experts.shape}, current routed_experts shape: {cur_routed_experts_shape}"
)

Claude: Warning: The assertion history_routed_experts_len - 1 <= cur_routed_experts_shape[0] allows history_len to be up to cur_shape[0] + 1, which means the slice cur_routed_experts[history_routed_experts_len:, :, :] could yield an empty tensor. The - 1 in the assertion is unexplained and could mask off-by-one errors.

Consider either:

  1. Adding a comment explaining why - 1 is correct (what scenario causes history to be one token longer than current?), or
  2. Using a stricter assertion: assert history_routed_experts_len <= cur_routed_experts_shape[0]
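The masking risk can be demonstrated with plain lists standing in for the tensors (the lengths below are hypothetical, chosen to trigger the edge case):

```python
# Hypothetical case: history claims one more token than the current rollout has.
cur_routed_experts = [[0], [1], [2]]  # stands in for a tensor of 3 tokens
history_routed_experts_len = 4

# The PR's assertion passes (4 - 1 <= 3)...
assert history_routed_experts_len - 1 <= len(cur_routed_experts)

# ...yet the subsequent slice is empty, silently contributing no new tokens.
new_part = cur_routed_experts[history_routed_experts_len:]
print(new_part)   # []

# The stricter check would have flagged the mismatch.
strict_ok = history_routed_experts_len <= len(cur_routed_experts)
print(strict_ok)  # False
```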

Comment on lines +99 to +104
expect_tokens_num = len(prompt_ids) + len(response_ids) - 1
assert concat_routed_experts.shape[0] == expect_tokens_num, (
    f"After concatenation, routed_experts shape: {concat_routed_experts.shape}, expected tokens num: {expect_tokens_num}"
)
logger.info(
    f"[PartialRolloutHandler] After postprocess, rollout {rollout_state.uid} routed_experts shape: {concat_routed_experts.shape}, prompt len: {len(prompt_ids)}, response len: {len(response_ids)}, total tokens: {len(tokens)}, history routed_experts shape: {history_routed_experts.shape}, current routed_experts shape: {cur_routed_experts_shape[0]}, concatenated routed_experts shape: {concat_routed_experts.shape}"

Claude: Nit: This log message is very verbose and contains duplicate information — concat_routed_experts.shape is logged twice (as both "routed_experts shape" and "concatenated routed_experts shape"). Consider trimming it:

logger.info(
    f"[PartialRolloutHandler] Postprocess rollout {rollout_state.uid}: "
    f"concat routed_experts {concat_routed_experts.shape} "
    f"(history={history_routed_experts.shape[0]}, new={cur_routed_experts_shape[0]}), "
    f"prompt={len(prompt_ids)}, response={len(response_ids)}"
)

Comment on lines 83 to 115
rollout_state.response_mask = history_dict.get("response_mask", []) + (rollout_state.response_mask or [])
history_routed_experts = _resolve_routed_experts(history_dict.get("routed_experts"))
cur_routed_experts = _resolve_routed_experts(rollout_state.routed_experts)
start_time = time.time()
if history_routed_experts is not None and cur_routed_experts is not None:
    cur_routed_experts_shape = cur_routed_experts.shape
    history_routed_experts_len = history_routed_experts.shape[0]
    assert history_routed_experts_len - 1 <= cur_routed_experts_shape[0], (
        f"Existing routed_experts shape: {history_routed_experts.shape}, current routed_experts shape: {cur_routed_experts_shape}"
    )
    cur_routed_experts = cur_routed_experts[history_routed_experts_len:, :, :]
    concat_routed_experts = torch.cat([history_routed_experts, cur_routed_experts], dim=0)

    prompt_ids = rollout_state.prompt_ids or []
    response_ids = rollout_state.response_ids or []
    tokens = rollout_state.tokens or []
    expect_tokens_num = len(prompt_ids) + len(response_ids) - 1
    assert concat_routed_experts.shape[0] == expect_tokens_num, (
        f"After concatenation, routed_experts shape: {concat_routed_experts.shape}, expected tokens num: {expect_tokens_num}"
    )
    logger.info(
        f"[PartialRolloutHandler] After postprocess, rollout {rollout_state.uid} routed_experts shape: {concat_routed_experts.shape}, prompt len: {len(prompt_ids)}, response len: {len(response_ids)}, total tokens: {len(tokens)}, history routed_experts shape: {history_routed_experts.shape}, current routed_experts shape: {cur_routed_experts_shape[0]}, concatenated routed_experts shape: {concat_routed_experts.shape}"
    )
    rollout_state.routed_experts = concat_routed_experts
elif history_routed_experts is None and cur_routed_experts is not None:
    rollout_state.routed_experts = cur_routed_experts
elif history_routed_experts is not None and cur_routed_experts is None:
    rollout_state.routed_experts = history_routed_experts
end_time = time.time()
logger.info(
    f"[PartialRolloutHandler] Postprocess routed_experts concatenation time: {end_time - start_time:.4f} seconds"
)
return rollout_state

Claude: Nit: start_time is set before the conditionals, but all three branches (both-not-None, one-None, both-None) are timed and logged. When both are None, you're timing and logging nothing useful. Consider moving the timing and logging inside the if block where concatenation actually happens, or at least gating the log message:

if history_routed_experts is not None or cur_routed_experts is not None:
    logger.info(...)
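A runnable sketch of the suggested restructuring, with plain lists standing in for the routed_experts tensors and `timed_concat` as a hypothetical helper name (the real code operates on torch tensors inside the handler):

```python
import logging
import time

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("PartialRolloutHandler")


def timed_concat(history, cur):
    """Concatenate history and current routed_experts, timing only real work.

    When both inputs are None there is nothing to concatenate, so we return
    early without starting a timer or emitting a log line.
    """
    if history is None and cur is None:
        return None  # both-None branch: no timing, no log
    start = time.time()
    if history is not None and cur is not None:
        result = history + cur  # stands in for torch.cat([...], dim=0)
    else:
        result = history if history is not None else cur
    logger.info(
        "[PartialRolloutHandler] Postprocess routed_experts concatenation "
        "time: %.4f seconds", time.time() - start,
    )
    return result


print(timed_concat([1, 2], [3]))  # [1, 2, 3]
print(timed_concat([1], None))    # [1]
print(timed_concat(None, None))   # None (and no timing log emitted)
```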

@claude

claude bot commented Mar 20, 2026

Claude: ## Review Summary

This PR adds support for R3 (partial rollout continuation) for MoE models by handling routed_experts tensor concatenation across rollout steps, and renames config files to be model-agnostic.

Issues

Critical

  • utils.py:67: routed_experts or None will crash with RuntimeError when routed_experts is a multi-element torch.Tensor — use explicit is not None check instead.

Warning

  • utils.py:88-92: The - 1 in the assertion history_len - 1 <= cur_shape[0] is unexplained and could mask off-by-one errors. Needs a comment or stricter check.

Nit

  • Renamed file rl_dapo_math_aysnc_filter.py still has the "aysnc" typo — good opportunity to fix it to "async".
  • Log message at utils.py:99-104 logs concat_routed_experts.shape twice redundantly.
  • Timing/logging at utils.py:83-115 runs even when both expert tensors are None.

Verdict

REQUEST_CHANGES — the or None tensor truthiness bug needs to be fixed before merge.

