51 changes: 46 additions & 5 deletions apps/grpo/main.py
@@ -29,7 +29,7 @@
from forge.data.rewards import MathReward, ThinkingReward
from forge.data_models.completion import Completion
from forge.observability.metric_actors import get_or_create_metric_logger
from forge.observability.metrics import record_metric, Reduce
from forge.observability.metrics import record_episode_sample, record_metric, Reduce
from forge.observability.perf_tracker import Tracer

from forge.types import LauncherConfig, ProvisionerConfig
@@ -47,10 +47,13 @@ class Episode:
    request_len: int
    response_len: int
    target: Any | None = None
    request: str | None = None
    response: str | None = None
    # Processed data
    completion: Completion | None = None
    ref_logprobs: torch.Tensor | None = None
    reward: float | None = None
    reward_breakdown: dict[str, float] | None = None
    advantage: float | None = None

    @property
@@ -73,6 +76,32 @@ def response_tensor(self) -> torch.Tensor:
        tensor = F.pad(tensor, (0, diff), value=self.pad_id)
        return tensor

    def to_dict(self, exclude: list[str] | None = None) -> dict[str, Any]:
        """Convert episode to dict, optionally excluding specified fields."""
        exclude = exclude or []
        result = {
            "episode_id": self.episode_id,
            "policy_version": self.policy_version,
            "prompt": self.request,
            "response": self.response,
            "target": str(self.target),
            "reward": self.reward,
            "advantage": self.advantage,
            "request_len": self.request_len,
            "response_len": self.response_len,
            "pad_id": self.pad_id,
            "ref_logprobs": self.ref_logprobs,
            "completion": self.completion,
        }

        if self.reward_breakdown is not None and "reward_breakdown" not in exclude:
            result.update(self.reward_breakdown)

        for key in exclude:
            result.pop(key, None)

        return result


# Represents the group (G) of episodes in GRPO
Group = list[Episode]
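
The exclude hook on to_dict matters because ref_logprobs and completion hold tensor-valued objects that text-oriented sample backends generally cannot serialize. A minimal usage sketch with illustrative field values (the constructor arguments below are assumptions for the example, not values from this PR):

episode = Episode(
    episode_id="0",
    pad_id=0,
    request_len=4,
    response_len=8,
    policy_version=0,
    target="42",
    request="What is 6 x 7?",
    response="<think>6 x 7 = 42</think> 42",
    reward=1.0,
    reward_breakdown={"MathReward": 1.0, "ThinkingReward": 1.0},
)
row = episode.to_dict(exclude=["ref_logprobs", "completion"])
# row now holds only log-friendly scalars and strings, with the
# per-function rewards from reward_breakdown merged in as extra columns.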
@@ -144,8 +173,11 @@ class RewardActor(ForgeActor):
    reward_functions: list[Callable]

    @endpoint
    async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
    async def evaluate_response(
        self, prompt: str, response: str, target: str
    ) -> tuple[dict[str, float], float]:
        total_rewards = 0.0
        reward_breakdown = {}  # reward breakdown by function
        for reward_fn in self.reward_functions:
            reward = reward_fn(prompt, response, target)
            total_rewards += reward
@@ -154,6 +186,7 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
            reward_fn_name = getattr(
                reward_fn, "__name__", reward_fn.__class__.__name__
            )
            reward_breakdown[reward_fn_name] = reward
            # per function reward
            record_metric(
                f"reward/evaluate_response/sum_{reward_fn_name}_reward",
@@ -183,8 +216,8 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
            Reduce.SUM,
        )

        avg_reward = total_rewards / len(self.reward_functions)
        return avg_reward
        avg_reward: float = total_rewards / len(self.reward_functions)
        return reward_breakdown, avg_reward
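
For context, each entry in reward_functions is just a callable scored as reward_fn(prompt, response, target) -> float, keyed in reward_breakdown by its __name__ or class name. A hedged sketch of a compatible reward callable (ExactMatchReward is hypothetical, not one of forge's built-in rewards like MathReward or ThinkingReward):

class ExactMatchReward:
    def __call__(self, prompt: str, response: str, target: str) -> float:
        # Score 1.0 when the target string appears in the response, else 0.0.
        # Instances have no __name__, so the breakdown key falls back to the
        # class name "ExactMatchReward" via the getattr above.
        return 1.0 if target.strip() in response else 0.0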


@dataclass
@@ -381,9 +414,14 @@ async def continuous_rollouts():
                request_len=max_req_tokens,
                response_len=max_res_tokens,
                target=target,
                request=prompt,
                response=response.text,
                completion=response,
            )
            episode.reward = await reward_actor.evaluate_response.route(
            (
                episode.reward_breakdown,
                episode.reward,
            ) = await reward_actor.evaluate_response.route(
                prompt=prompt, response=response.text, target=target
            )
            episodes.append(episode)
@@ -407,6 +445,9 @@ async def continuous_rollouts():
        for episode, advantage in zip(episodes, advantages):
            episode.advantage = advantage
            await replay_buffer.add.call_one(episode)
            record_episode_sample(
                "main_samples/continuous_rollouts/sample_table", episode
            )

        rollout_count += 1
        record_metric(
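Given the new Reduce.SAMPLE and SampleAccumulator exports below, record_episode_sample is presumably a thin wrapper that logs an episode's to_dict() row under sample-reduction semantics rather than a scalar reduction. A minimal sketch under that assumption — the exclude list and the exact signature are guesses, not forge's confirmed API:

def record_episode_sample(table_name: str, episode: Episode) -> None:
    # Drop tensor-valued fields a text table cannot render, then record the
    # row with SAMPLE reduction so rows accumulate instead of being averaged.
    row = episode.to_dict(exclude=["ref_logprobs", "completion"])
    record_metric(table_name, row, Reduce.SAMPLE)
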
4 changes: 4 additions & 0 deletions src/forge/observability/__init__.py
@@ -21,9 +21,11 @@
    MetricAccumulator,
    MetricCollector,
    MinAccumulator,
    record_episode_sample,
    record_metric,
    Reduce,
    reduce_metrics_states,
    SampleAccumulator,
    StdAccumulator,
    SumAccumulator,
    WandbBackend,
@@ -35,6 +37,7 @@
    # Main API functions
    "record_metric",
    "reduce_metrics_states",
    "record_episode_sample",
    "get_logger_backend_class",
    "get_or_create_metric_logger",
    # Performance tracking
@@ -64,4 +67,5 @@
"MaxAccumulator",
"MinAccumulator",
"StdAccumulator",
"SampleAccumulator",
]
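
With these exports in place, callers can pull the sample-logging entry points straight from the package root, e.g.:

from forge.observability import record_episode_sample, SampleAccumulator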
14 changes: 13 additions & 1 deletion src/forge/observability/metric_actors.py
@@ -26,6 +26,7 @@
    LoggerBackend,
    LoggingMode,
    MetricCollector,
    Reduce,
    reduce_metrics_states,
)

@@ -432,9 +433,20 @@ def extract_values_from_valuemesh(results) -> list[dict[str, Any]]:
        # Reduce metrics from states
        reduced_metrics = reduce_metrics_states(all_local_states)

        # Split into scalar metrics and sample metrics
        scalar_metrics = [
            m for m in reduced_metrics if m.reduction != Reduce.SAMPLE
        ]
        sample_metrics = [
            m for m in reduced_metrics if m.reduction == Reduce.SAMPLE
        ]

        # Log to global backends
        for backend_name, backend in self.global_logger_backends.items():
            await backend.log_batch(reduced_metrics, global_step)
            if scalar_metrics:
                await backend.log_batch(scalar_metrics, global_step)
            if sample_metrics:
                await backend.log_samples(sample_metrics, global_step)
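
The split keeps scalar time-series and row-like samples on separate backend paths. One plausible shape for the log_samples hook a backend now needs — the ConsoleSampleBackend name and the metric's key/value attributes are assumptions for illustration, not forge's actual backend API:

class ConsoleSampleBackend:
    async def log_samples(self, sample_metrics, global_step: int) -> None:
        for metric in sample_metrics:
            # Under SAMPLE reduction, value is assumed to be the list of
            # accumulated row dicts (e.g. episode.to_dict() outputs).
            for row in metric.value:
                print(f"[step {global_step}] {metric.key}: {row}")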

    @endpoint
    async def has_fetcher(self, proc_id: str) -> bool: