
Clean S* TODOs #116

Merged
DachengLi1 merged 1 commit into main from S_clean on May 18, 2025

Conversation

@DachengLi1 (Collaborator)

No description provided.

DachengLi1 merged commit cf3b941 into main on May 18, 2025
1 check passed
yllkryeziu pushed a commit to yllkryeziu/adaptive-compute-rewrite that referenced this pull request Dec 12, 2025
lru0612 pushed a commit to lru0612/SkyThought that referenced this pull request Feb 24, 2026
…+ refactor adv estimator registry to allow registration outside ray workers (#126)

# Overview
- Adds support for registering custom policy loss functions, similar to
NovaSky-AI#115
- Refactors the policy loss to be a function in `ppo_utils.py` instead
of an `nn.Module` in `worker.py`
- Introduces a breaking change: `trainer.algorithm.ppo_loss_type` is
renamed to `trainer.algorithm.policy_loss_type`
- Addresses Issue NovaSky-AI#116 by creating a new `BaseFunctionRegistry` class
that uses a [named
actor](https://docs.ray.io/en/latest/ray-core/actors/named-actors.html)
to support the following pattern:

```python
import ray
import torch
import hydra
from omegaconf import DictConfig

# Project-specific imports (PolicyLossRegistry, BasePPOExp, validate_cfg,
# initialize_ray, config_dir) are omitted here, as in the original description.


# Example of custom policy loss: "simple_baseline"
def compute_simple_baseline_policy_loss(
    log_probs: torch.Tensor,
    ...
):
    return torch.randn(1, device=log_probs.device), 0.0

# Register the custom policy loss - outside of the ray worker
PolicyLossRegistry.register("simple_baseline", compute_simple_baseline_policy_loss)


@ray.remote(num_cpus=1)
def skyrl_entrypoint(cfg: DictConfig):
    exp = BasePPOExp(cfg)
    exp.run()


@hydra.main(config_path=config_dir, config_name="ppo_base_config", version_base=None)
def main(cfg: DictConfig) -> None:
    # validate the arguments
    validate_cfg(cfg)

    initialize_ray(cfg)

    ray.get(skyrl_entrypoint.remote(cfg))
```
This change was necessary for `PolicyLossRegistry` to be accessible from
inside the worker: the worker's `actor_loss_fn` attribute is set in
`init_model` within the worker actor, which is itself a Ray actor created
from within the `skyrl_entrypoint` Ray task, so a plain in-process
registration in the entrypoint would not propagate down another layer.
- Updates `AdvantageEstimatorRegistry` to extend the same
`BaseFunctionRegistry` class (a rough sketch of the named-actor registry
pattern follows below)
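To make the mechanism concrete, here is a rough sketch of what a named-actor-backed function registry can look like. This is an illustration only, not the actual `BaseFunctionRegistry` code; the actor name, method names, and the `get_if_exists`/detached-lifetime choices are assumptions.

```python
import ray


@ray.remote
class _RegistryActor:
    """Holds name -> function mappings; Ray pickles the functions when they are passed in."""

    def __init__(self):
        self._fns = {}

    def put(self, name, fn):
        self._fns[name] = fn

    def get(self, name):
        return self._fns[name]


class FunctionRegistry:
    """Minimal sketch of a named-actor-backed registry (illustrative only)."""

    ACTOR_NAME = "function_registry"  # assumed name, for illustration

    @classmethod
    def _actor(cls):
        # get_if_exists avoids races when several processes touch the registry first
        return _RegistryActor.options(
            name=cls.ACTOR_NAME, get_if_exists=True, lifetime="detached"
        ).remote()

    @classmethod
    def register(cls, name, fn):
        ray.get(cls._actor().put.remote(name, fn))

    @classmethod
    def get(cls, name):
        return ray.get(cls._actor().get.remote(name))
```

Because the actor is looked up by name, a `register` call made in the driver and a `get` call made inside a worker actor (even one created from within another Ray task) resolve to the same actor, which addresses the propagation problem described above.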


Example runs:

Custom advantage (mean of reward), sketched below:
<img width="956" height="326" alt="Custom advantage (mean of reward) training curves"
src="https://github.com/user-attachments/assets/1b7222bc-fbb9-49b1-876d-265b71201087"
/>
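As a reference for this run, a "mean of reward" estimator could look roughly like the sketch below. The exact signature `AdvantageEstimatorRegistry` expects is not shown in this description, so the argument names, the `(advantages, returns)` return shape, and the import path are assumptions.

```python
import torch

from skyrl_train.utils.ppo_utils import AdvantageEstimatorRegistry  # import path assumed


def compute_mean_reward_advantage(token_level_rewards: torch.Tensor, **kwargs):
    # Hypothetical estimator: use the batch-mean reward as a constant
    # "advantage" for every token (illustration only).
    advantages = torch.ones_like(token_level_rewards) * token_level_rewards.mean()
    returns = advantages.clone()
    return advantages, returns


# Registered in the driver process, before the Ray entrypoint task is launched.
AdvantageEstimatorRegistry.register("mean_reward", compute_mean_reward_advantage)
```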

Custom policy loss (REINFORCE: just `(-logprobs * advantages).mean()`), sketched below:
<img width="939" height="330" alt="Custom policy loss (REINFORCE) training curves"
src="https://github.com/user-attachments/assets/cbed7ef5-b3e7-4e32-beba-b52b80879f47"
/>
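And the REINFORCE-style loss from this run, following the same registration pattern as the `simple_baseline` example above. The remaining arguments and the meaning of the second return value are not spelled out in this description, so they are assumptions here.

```python
import torch


def compute_reinforce_policy_loss(
    log_probs: torch.Tensor,
    advantages: torch.Tensor,
    **kwargs,
):
    # Plain REINFORCE: maximize the advantage-weighted log-probability.
    loss = (-log_probs * advantages).mean()
    # The second value mirrors the (loss, 0.0) shape of the simple_baseline
    # example; its exact meaning is an assumption here.
    return loss, 0.0


# Registered outside the Ray worker, like the simple_baseline example.
PolicyLossRegistry.register("reinforce", compute_reinforce_policy_loss)
```

Selecting it would then presumably be a matter of setting the renamed config key, e.g. `trainer.algorithm.policy_loss_type=reinforce`.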