Conversation
@claude review
xtuner/v1/rl/agent_loop/utils.py
Outdated
```python
"response": rollout_state.response or "",
"logprobs": rollout_state.logprobs or [],
"response_mask": rollout_state.response_mask or [],
"routed_experts": rollout_state.routed_experts or None,
```
Claude: Bug (Critical): `rollout_state.routed_experts or None` will crash with `RuntimeError: Boolean value of Tensor with more than one element is ambiguous` when `routed_experts` is a multi-element `torch.Tensor`.
Since the type of `routed_experts` was changed to `torch.Tensor | RayObjectRef | None`, the `or` operator will attempt to evaluate the truthiness of the tensor, which is undefined for multi-element tensors in PyTorch.
```diff
- "routed_experts": rollout_state.routed_experts or None,
+ "routed_experts": rollout_state.routed_experts if rollout_state.routed_experts is not None else None,
```
Or more simply: `"routed_experts": rollout_state.routed_experts,` — since `None` is already a valid value.
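The pitfall behind this comment can be reproduced without PyTorch installed. The sketch below uses a hypothetical `FakeTensor` stand-in whose `__bool__` mirrors `torch.Tensor`'s behavior for multi-element tensors, showing why `x or None` crashes while the identity check does not:

```python
# FakeTensor is a stand-in for torch.Tensor, mimicking its ambiguous
# truthiness for multi-element tensors; torch itself is not required here.
class FakeTensor:
    def __init__(self, data):
        self.data = data

    def __bool__(self):
        if len(self.data) != 1:
            raise RuntimeError(
                "Boolean value of Tensor with more than one element is ambiguous"
            )
        return bool(self.data[0])


routed_experts = FakeTensor([1, 2, 3])

# `routed_experts or None` forces a truthiness check -> raises for
# multi-element tensors, exactly the crash described in the review.
try:
    _ = routed_experts or None
    crashed = False
except RuntimeError:
    crashed = True

# The identity comparison sidesteps __bool__ entirely.
safe = routed_experts if routed_experts is not None else None

print(crashed)                  # True
print(safe is routed_experts)   # True
```

The same reasoning applies to the real tensor: `is not None` never invokes `Tensor.__bool__`, so it is safe for any shape.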
xtuner/v1/rl/agent_loop/utils.py
Outdated
```python
cur_routed_experts_shape = cur_routed_experts.shape
history_routed_experts_len = history_routed_experts.shape[0]
assert history_routed_experts_len - 1 <= cur_routed_experts_shape[0], (
    f"Existing routed_experts shape: {history_routed_experts.shape}, current routed_experts shape: {cur_routed_experts_shape}"
)
```
Claude: Warning: The assertion `history_routed_experts_len - 1 <= cur_routed_experts_shape[0]` allows the history length to be up to `cur_shape[0] + 1`, which means the slice `cur_routed_experts[history_routed_experts_len:, :, :]` could yield an empty tensor. The `- 1` in the assertion is unexplained and could mask off-by-one errors.
Consider either:
- Adding a comment explaining why `- 1` is correct (what scenario causes history to be one token longer than current?), or
- Using a stricter assertion: `assert history_routed_experts_len <= cur_routed_experts_shape[0]`
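The off-by-one the warning describes is easy to see with plain lists standing in for tensors (illustrative values only): the lenient `- 1` assertion passes even when the history is one element longer than the current tensor, at which point the "new tokens" slice is empty:

```python
# Hypothetical lengths: history is one longer than current.
history_len = 5
cur = list(range(4))  # current has 4 elements

# The lenient assertion from the diff passes (5 - 1 <= 4)...
assert history_len - 1 <= len(cur)

# ...but the slice that is supposed to pick up newly generated tokens
# is empty, silently dropping the expected new data.
new_part = cur[history_len:]
print(new_part)  # []

# The stricter check suggested in the review would flag this case.
strict_ok = history_len <= len(cur)
print(strict_ok)  # False
```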
xtuner/v1/rl/agent_loop/utils.py
Outdated
```python
expect_tokens_num = len(prompt_ids) + len(response_ids) - 1
assert concat_routed_experts.shape[0] == expect_tokens_num, (
    f"After concatenation, routed_experts shape: {concat_routed_experts.shape}, expected tokens num: {expect_tokens_num}"
)
logger.info(
    f"[PartialRolloutHandler] After postprocess, rollout {rollout_state.uid} routed_experts shape: {concat_routed_experts.shape}, prompt len: {len(prompt_ids)}, response len: {len(response_ids)}, total tokens: {len(tokens)}, history routed_experts shape: {history_routed_experts.shape}, current routed_experts shape: {cur_routed_experts_shape[0]}, concatenated routed_experts shape: {concat_routed_experts.shape}"
)
```
Claude: Nit: This log message is very verbose and contains duplicate information — concat_routed_experts.shape is logged twice (as both "routed_experts shape" and "concatenated routed_experts shape"). Consider trimming it:
```python
logger.info(
    f"[PartialRolloutHandler] Postprocess rollout {rollout_state.uid}: "
    f"concat routed_experts {concat_routed_experts.shape} "
    f"(history={history_routed_experts.shape[0]}, new={cur_routed_experts_shape[0]}), "
    f"prompt={len(prompt_ids)}, response={len(response_ids)}"
)
```

xtuner/v1/rl/agent_loop/utils.py

```python
rollout_state.response_mask = history_dict.get("response_mask", []) + (rollout_state.response_mask or [])
history_routed_experts = _resolve_routed_experts(history_dict.get("routed_experts"))
cur_routed_experts = _resolve_routed_experts(rollout_state.routed_experts)
start_time = time.time()
if history_routed_experts is not None and cur_routed_experts is not None:
    cur_routed_experts_shape = cur_routed_experts.shape
    history_routed_experts_len = history_routed_experts.shape[0]
    assert history_routed_experts_len - 1 <= cur_routed_experts_shape[0], (
        f"Existing routed_experts shape: {history_routed_experts.shape}, current routed_experts shape: {cur_routed_experts_shape}"
    )
    cur_routed_experts = cur_routed_experts[history_routed_experts_len:, :, :]
    concat_routed_experts = torch.cat([history_routed_experts, cur_routed_experts], dim=0)

    prompt_ids = rollout_state.prompt_ids or []
    response_ids = rollout_state.response_ids or []
    tokens = rollout_state.tokens or []
    expect_tokens_num = len(prompt_ids) + len(response_ids) - 1
    assert concat_routed_experts.shape[0] == expect_tokens_num, (
        f"After concatenation, routed_experts shape: {concat_routed_experts.shape}, expected tokens num: {expect_tokens_num}"
    )
    logger.info(
        f"[PartialRolloutHandler] After postprocess, rollout {rollout_state.uid} routed_experts shape: {concat_routed_experts.shape}, prompt len: {len(prompt_ids)}, response len: {len(response_ids)}, total tokens: {len(tokens)}, history routed_experts shape: {history_routed_experts.shape}, current routed_experts shape: {cur_routed_experts_shape[0]}, concatenated routed_experts shape: {concat_routed_experts.shape}"
    )
    rollout_state.routed_experts = concat_routed_experts
elif history_routed_experts is None and cur_routed_experts is not None:
    rollout_state.routed_experts = cur_routed_experts
elif history_routed_experts is not None and cur_routed_experts is None:
    rollout_state.routed_experts = history_routed_experts
end_time = time.time()
logger.info(
    f"[PartialRolloutHandler] Postprocess routed_experts concatenation time: {end_time - start_time:.4f} seconds"
)
return rollout_state
```
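The core of the hunk above is a three-way merge of history and current `routed_experts`. A minimal sketch of that merge, with Python lists standing in for tensors and `torch.cat` (the function name `merge_routed_experts` is hypothetical, and it uses the stricter assertion the review suggests rather than the `- 1` variant):

```python
def merge_routed_experts(history, current):
    """Sketch of the merge logic; lists stand in for torch tensors.

    `current` re-covers the tokens already represented in `history` plus
    newly generated ones, so only the tail of `current` past the history
    length is appended. Either side may be None.
    """
    if history is None:
        return current  # no history: keep current as-is (may also be None)
    if current is None:
        return history  # nothing new was generated this segment
    history_len = len(history)
    # Stricter check than the reviewed code: history must not exceed current.
    assert history_len <= len(current), (
        f"history len {history_len} exceeds current len {len(current)}"
    )
    return history + current[history_len:]


print(merge_routed_experts([0, 1, 2], [0, 1, 2, 3, 4]))  # [0, 1, 2, 3, 4]
print(merge_routed_experts(None, [7, 8]))                # [7, 8]
print(merge_routed_experts([7, 8], None))                # [7, 8]
```

With real tensors the slice would be `current[history_len:, :, :]` and the append would be `torch.cat([history, tail], dim=0)`, but the shape bookkeeping is identical.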
Claude: Nit: start_time is set before the conditionals, but all three branches (both-not-None, one-None, both-None) are timed and logged. When both are None, you're timing and logging nothing useful. Consider moving the timing and logging inside the if block where concatenation actually happens, or at least gating the log message:
```python
if history_routed_experts is not None or cur_routed_experts is not None:
    logger.info(...)
```
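The gating pattern the nit asks for can be sketched as follows. `timed_concat` is a hypothetical helper (lists again stand in for tensors and `torch.cat`); the point is that the both-None branch returns before any timing or logging happens:

```python
import logging
import time

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("PartialRolloutHandler")


def timed_concat(history, current):
    """Time and log only when there is actual concatenation work to do."""
    if history is None and current is None:
        return None  # nothing to concatenate, so no timing and no log line
    start = time.perf_counter()
    left = history if history is not None else []
    right = current if current is not None else []
    result = left + right  # list '+' stands in for torch.cat(..., dim=0)
    elapsed = time.perf_counter() - start
    logger.info(
        "[PartialRolloutHandler] routed_experts concatenation time: %.4f seconds",
        elapsed,
    )
    return result


print(timed_concat(None, None))    # None
print(timed_concat([1, 2], [3]))   # [1, 2, 3]
```

Using `time.perf_counter()` instead of `time.time()` is also the more appropriate clock for measuring short durations.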
Claude: ## Review Summary

This PR adds support for R3 (partial rollout continuation) for MoE models by handling `routed_experts` across rollout segments.

Issues:
- Critical: `rollout_state.routed_experts or None` raises on multi-element tensors.
- Warning: the unexplained `- 1` in the length assertion can mask off-by-one errors.
- Nit: verbose, duplicated log messages and timing/logging of branches that do no work.

Verdict: REQUEST_CHANGES — the Critical truthiness bug must be fixed before merging.
No description provided.