Skip to content

Commit 981e845

Browse files
committed
debug
1 parent e535fb5 commit 981e845

File tree

1 file changed

+25
-11
lines changed

1 file changed

+25
-11
lines changed

src/forge/observability/metrics.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -141,12 +141,25 @@ def reduce_metrics_states(states: list[dict[str, dict[str, Any]]]) -> list[Metri
141141
list[Metric]: List of reduced metrics
142142
143143
Example:
144-
states = [
145-
{"loss": {"count": 5, "sum": 14, "reduction_type": Reduce.MEAN}},
146-
{"loss": {"count": 10, "sum": 16, "reduction_type": Reduce.MEAN}},
144+
>>> states = [
145+
... {
146+
... "loss": {"count": 5, "sum": 14, "reduction_type": "mean"},
147+
... "reward/sample": {
148+
... "reduction_type": "sample",
149+
... "samples": [{"episode_id": 1, "reward": 0.5}],
150+
... },
151+
... },
152+
... {"loss": {"count": 10, "sum": 16, "reduction_type": Reduce.MEAN}},
153+
... ]
154+
>>> reduce_metrics_states(states)
155+
[
156+
Metric(key='loss', value=2.0, reduction=Reduce.MEAN),
157+
Metric(
158+
key='reward/sample',
159+
value=[{'episode_id': 1, 'reward': 0.5}],
160+
reduction=Reduce.SAMPLE,
161+
)
147162
]
148-
reduce_metrics_states(states)
149-
>>> [Metric(key="loss", value=2.0, reduction=Reduce.MEAN)]
150163
151164
Raises:
152165
ValueError: on mismatched reduction types for the same metric key.
@@ -649,7 +662,6 @@ def push(self, metric: Metric) -> None:
649662

650663
# For PER_RANK_NO_REDUCE backends: stream without reduce
651664
for backend in self.per_rank_no_reduce_backends:
652-
653665
if metric.reduction == Reduce.SAMPLE:
654666
asyncio.create_task(backend.log_samples([metric], self.global_step))
655667
else:
@@ -712,11 +724,9 @@ async def flush(
712724
scalar_metrics = [
713725
m for m in metrics_for_backends if m.reduction != Reduce.SAMPLE
714726
]
715-
sample_metrics = {
716-
m.key: m.value
717-
for m in metrics_for_backends
718-
if m.reduction == Reduce.SAMPLE
719-
}
727+
sample_metrics = [
728+
m for m in metrics_for_backends if m.reduction == Reduce.SAMPLE
729+
]
720730

721731
for backend in self.per_rank_reduce_backends:
722732
if scalar_metrics:
@@ -1026,6 +1036,10 @@ async def log_samples(self, samples: List[Metric], step: int) -> None:
10261036
if not table_rows:
10271037
continue
10281038

1039+
# Convert to list if single sample. This happens when logging stream
1040+
if isinstance(table_rows, dict):
1041+
table_rows = [table_rows]
1042+
10291043
# If table doesn't exist yet, create it in INCREMENTAL mode
10301044
if table_name not in self._tables:
10311045
# Collect all unique columns from all rows

0 commit comments

Comments
 (0)