Add RLHF guide and dummy demo with Keras/JAX #2117


Open · wants to merge 13 commits into master

Conversation

TrailChai

This commit introduces a new example for Reinforcement Learning from Human Feedback (RLHF).

It includes:

  • `examples/rl/rlhf_dummy_demo.py`: A Python script demonstrating a simple RLHF loop with a dummy environment, a policy model, and a reward model, using Keras with the JAX backend.
  • `examples/rl/md/rlhf_dummy_demo.md`: A Markdown guide explaining the RLHF concept and the implementation details of the demo script.
  • `examples/rl/README.md`: A new README for the RL examples section, now including the RLHF demo.

Note: The Python demo script (`rlhf_dummy_demo.py`) currently experiences timeout issues during the training loop in the development environment, even with significantly reduced computational load. This is documented in the guide and README. The code serves as a structural example of implementing the RLHF components.
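
For orientation, here is a minimal sketch of the kind of policy and reward models such a demo would define with Keras on the JAX backend. The layer sizes and variable names are illustrative assumptions, not taken from the script:

```python
import os
os.environ["KERAS_BACKEND"] = "jax"  # select the JAX backend before importing Keras

import keras
from keras import layers

STATE_DIM, NUM_ACTIONS = 4, 2  # assumed sizes for a dummy environment

# Policy model: maps a state to a probability distribution over actions.
policy_model = keras.Sequential([
    keras.Input(shape=(STATE_DIM,)),
    layers.Dense(32, activation="relu"),
    layers.Dense(NUM_ACTIONS, activation="softmax"),
])

# Reward model: maps a (state, one-hot action) pair to a scalar score,
# standing in for a model trained on human preference feedback.
reward_model = keras.Sequential([
    keras.Input(shape=(STATE_DIM + NUM_ACTIONS,)),
    layers.Dense(32, activation="relu"),
    layers.Dense(1),
])
```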

@divyashreepathihalli (Contributor) left a comment:

Thanks Yasir! Just keep the .py file for now. Once it is approved, we can generate the .ipynb files.

policy_model_params["non_trainable"],
state_input
)
actual_predictions_tensor = predictions_tuple[0]
Contributor:

are we assuming batch size is 1?

Author:

Yes. The code as written assumes a batch size of 1 for all model inputs and gradient calculations
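
(A tiny illustration of that assumption, with hypothetical names: each single observation is wrapped into a batch of one before it reaches the models.)

```python
import jax.numpy as jnp

state = jnp.zeros(4)                          # one dummy observation
state_input = jnp.expand_dims(state, axis=0)  # shape (1, 4): a batch of one
```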

Contributor:

why is that?

Author:

This was a simplification for the demo's clarity and to manage complexity, especially since REINFORCE-style updates can be done on single trajectories. A more advanced setup would definitely use batching across multiple episodes or from a replay buffer for stability and efficiency.

Does that make sense in the context of this simplified demo?
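
(For contrast, a batched variant would simply stack the recorded per-step states before the forward pass; this is illustrative only, not code from the demo.)

```python
import jax.numpy as jnp

# Hypothetical: states recorded over several steps, stacked into one batch.
episode_states = [jnp.zeros(4), jnp.ones(4), 2 * jnp.ones(4)]
batched_states = jnp.stack(episode_states)  # shape (3, 4)
```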

episode_policy_losses.append(current_policy_loss)
policy_grads_step = policy_grads_dict_step["trainable"]
# Accumulate policy gradients
for i, grad in enumerate(policy_grads_step):
Contributor:

For potentially improving performance:
policy_grads_accum = jax.tree_map(lambda acc, new: acc + new if new is not None else acc, policy_grads_accum, policy_grads_step)

@TrailChai (Author), Jun 3, 2025:

Thanks, I have refactored the policy gradient and reward gradient accumulation to use jax.tree_map.
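
Roughly along these lines (a sketch of the pattern, not necessarily the exact code that landed):

```python
import jax
import jax.numpy as jnp

# Toy stand-in for one step's gradient pytree (in the demo it comes from jax.grad).
policy_grads_step = {"trainable": [jnp.ones((3, 2)), jnp.ones((2,))]}

# Initialize the accumulator with zeros of the same pytree structure...
policy_grads_accum = jax.tree_map(jnp.zeros_like, policy_grads_step)

# ...then accumulate element-wise at every step of the episode.
policy_grads_accum = jax.tree_map(
    lambda acc, new: acc + new, policy_grads_accum, policy_grads_step
)
```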

actual_predictions_tensor = predictions_tuple[0]
action_probs = actual_predictions_tensor[0]
log_prob = jnp.log(action_probs[action] + 1e-7)
return -log_prob * predicted_reward_value_stopped
Contributor:

If this predicted_reward_value is just R(s,a), then it's using the immediate predicted reward, which is a very naive and generally ineffective way to train a policy. The log_prob should be multiplied by the cumulative discounted future reward (the return, G_t).
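
(For reference, with $\gamma$ the discount factor and $T$ the episode length, that return is the discounted sum of the rewards actually collected from step $t$ to the end of the episode:)

$$G_t = \sum_{k=0}^{T-t-1} \gamma^k \, r_{t+k} = r_t + \gamma r_{t+1} + \gamma^2 r_{t+2} + \cdots$$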

Author:

Commit c212593 addresses this.

TrailChai and others added 7 commits June 3, 2025 11:35
…rewards.

This commit refines the RLHF demo example (`examples/rl/rlhf_dummy_demo.py`)
to use discounted cumulative actual rewards (G_t) for policy gradient
calculations, aligning it with the REINFORCE algorithm.

Changes include:
- Added a `calculate_discounted_returns` helper function.
- Modified the `rlhf_training_loop` to collect trajectories (states,
  actions, rewards) and compute G_t for each step at the end of an episode.
- Updated the policy loss function to use these G_t values instead of
  immediate predicted rewards.
- The reward model training logic remains focused on predicting immediate
  rewards based on simulated human feedback (environment reward in this demo).
- Updated the corresponding RLHF guide (`examples/rl/md/rlhf_dummy_demo.md`)
  to explain these changes and provide updated code snippets.

The timeout issues with the script in the development environment persist,
but the code now better reflects a standard policy gradient approach.
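
A sketch of what that change amounts to, assuming the helper name from the commit message (details may differ from the actual file):

```python
import jax.numpy as jnp

def calculate_discounted_returns(rewards, gamma=0.99):
    """Compute the return G_t for every step of one episode from its actual rewards."""
    returns = []
    g = 0.0
    for r in reversed(rewards):  # accumulate from the last step backwards
        g = r + gamma * g
        returns.append(g)
    returns.reverse()
    return jnp.asarray(returns)

# The per-step policy loss then weights the log-probability of the taken
# action by that step's return rather than by the immediate predicted reward:
#   loss_t = -log(pi(action_t | state_t)) * G_t
```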
@divyashreepathihalli (Contributor):

The .md files are automatically generated, so you might want to move the explanation content to the .py file.

TrailChai added 2 commits June 6, 2025 23:06
Moving the explanation piece to .py from .md
@TrailChai (Author):

I have deleted the .md files and added the relevant documentation to the .py file.
