diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py index 5d37bf8e0..398e0572d 100644 --- a/xtuner/v1/train/rl_trainer.py +++ b/xtuner/v1/train/rl_trainer.py @@ -546,7 +546,7 @@ def _rollout_step(self, rollout_idx: int, step_timer_dict: dict) -> RolloutInfo: rollout_info: RolloutInfo = { "data_groups": dataflow_result["data_groups"], - "multimodal_train_infos": dataflow_result.get("multimodal_train_infos", None), + "multimodal_train_infos": dataflow_result.get("mm_train_infos", None), "task_time": dataflow_result.get("metrics", {}), "replay_buffer_info": ray.get(self._rollout_dataflow.get_replaybuffer_status.remote()), }