@@ -195,7 +195,7 @@ def record_episode_sample(table_name: str, episode):
195195 table_name (str): logging prefix (e.g. "rollout/sample").
196196 episode (Episode): episode object with filled attributes.
197197 """
198- sample = episode .to_dict ()
198+ sample = episode .to_dict (exclude = [ "ref_logprobs" , "completion" ] )
199199 record_metric (table_name , sample , Reduce .SAMPLE )
200200
201201
@@ -675,9 +675,7 @@ def push(self, metric: Metric) -> None:
675675 for backend in self .per_rank_no_reduce_backends :
676676
677677 if metric .reduction == Reduce .SAMPLE :
678- # Wrap singleton Metric into expected {key: [list_of_dicts]} format
679- sample = {metric .key : [metric .value ]}
680- asyncio .create_task (backend .log_samples (sample , self .global_step ))
678+ asyncio .create_task (backend .log_samples ([metric ], self .global_step ))
681679 else :
682680 backend .log_stream (metric = metric , global_step = self .global_step )
683681
@@ -883,11 +881,12 @@ def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None:
883881 async def finish (self ) -> None :
884882 pass
885883
886- async def log_samples (self , samples : Dict [ str , List [dict ] ], step : int ) -> None :
884+ async def log_samples (self , samples : List [Metric ], step : int ) -> None :
887885 """Pretty-print sample-level logs to console."""
888886
889887 logger .info (f"========== SAMPLE LOGS STEP { step } ==========" )
890- for table_name , table_rows in samples .items ():
888+ for sample in samples :
889+ table_name , table_rows = sample .key , sample .value
891890 logger .info (f"[{ table_name } ] ({ len (table_rows )} samples)" )
892891 logger .info (json .dumps (table_rows , indent = 2 , ensure_ascii = False ))
893892 logger .info ("==============================================\n " )
@@ -1039,14 +1038,15 @@ def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None:
10391038 # note: here we dont use step since wandb keeps only the latest value for each step
10401039 self .run .log (log_data )
10411040
1042- async def log_samples (self , samples : Dict [ str , List [dict ] ], step : int ) -> None :
1041+ async def log_samples (self , samples : List [Metric ], step : int ) -> None :
10431042 """Log sample-level data incrementally to persistent WandB Tables."""
10441043 import wandb
10451044
10461045 if not self .run :
10471046 return
10481047
1049- for table_name , table_rows in samples .items ():
1048+ for sample in samples :
1049+ table_name , table_rows = sample .key , sample .value
10501050 if not table_rows :
10511051 continue
10521052
0 commit comments