Skip to content

Commit ab59105

Browse files
committed
log_samples takes a list of metrics
1 parent 2efca40 commit ab59105

File tree

3 files changed

+12
-10
lines changed

3 files changed

+12
-10
lines changed

apps/grpo/main.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,11 @@ def to_dict(self, exclude: list[str] | None = None) -> dict[str, Any]:
8989
"request_len": self.request_len,
9090
"response_len": self.response_len,
9191
"pad_id": self.pad_id,
92+
"ref_logprobs": self.ref_logprobs,
93+
"completion": self.completion,
9294
}
9395

94-
if self.reward_breakdown is not None:
96+
if self.reward_breakdown is not None and "reward_breakdown" not in exclude:
9597
result.update(self.reward_breakdown)
9698

9799
if exclude:

src/forge/observability/metric_actors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ def extract_values_from_valuemesh(results) -> list[dict[str, Any]]:
438438
m for m in reduced_metrics if m.reduction != Reduce.SAMPLE
439439
]
440440
sample_metrics = {
441-
m.key: m.value for m in reduced_metrics if m.reduction == Reduce.SAMPLE
441+
m for m in reduced_metrics if m.reduction == Reduce.SAMPLE
442442
}
443443

444444
# Log to global backends

src/forge/observability/metrics.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def record_episode_sample(table_name: str, episode):
195195
table_name (str): logging prefix (e.g. "rollout/sample").
196196
episode (Episode): episode object with filled attributes.
197197
"""
198-
sample = episode.to_dict()
198+
sample = episode.to_dict(exclude=["ref_logprobs", "completion"])
199199
record_metric(table_name, sample, Reduce.SAMPLE)
200200

201201

@@ -675,9 +675,7 @@ def push(self, metric: Metric) -> None:
675675
for backend in self.per_rank_no_reduce_backends:
676676

677677
if metric.reduction == Reduce.SAMPLE:
678-
# Wrap singleton Metric into expected {key: [list_of_dicts]} format
679-
sample = {metric.key: [metric.value]}
680-
asyncio.create_task(backend.log_samples(sample, self.global_step))
678+
asyncio.create_task(backend.log_samples([metric], self.global_step))
681679
else:
682680
backend.log_stream(metric=metric, global_step=self.global_step)
683681

@@ -883,11 +881,12 @@ def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None:
883881
async def finish(self) -> None:
884882
pass
885883

886-
async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
884+
async def log_samples(self, samples: List[Metric], step: int) -> None:
887885
"""Pretty-print sample-level logs to console."""
888886

889887
logger.info(f"========== SAMPLE LOGS STEP {step} ==========")
890-
for table_name, table_rows in samples.items():
888+
for sample in samples:
889+
table_name, table_rows = sample.key, sample.value
891890
logger.info(f"[{table_name}] ({len(table_rows)} samples)")
892891
logger.info(json.dumps(table_rows, indent=2, ensure_ascii=False))
893892
logger.info("==============================================\n")
@@ -1039,14 +1038,15 @@ def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None:
10391038
# note: here we don't use step since wandb keeps only the latest value for each step
10401039
self.run.log(log_data)
10411040

1042-
async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
1041+
async def log_samples(self, samples: List[Metric], step: int) -> None:
10431042
"""Log sample-level data incrementally to persistent WandB Tables."""
10441043
import wandb
10451044

10461045
if not self.run:
10471046
return
10481047

1049-
for table_name, table_rows in samples.items():
1048+
for sample in samples:
1049+
table_name, table_rows = sample.key, sample.value
10501050
if not table_rows:
10511051
continue
10521052

0 commit comments

Comments
 (0)