-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Add RLHF guide and dummy demo with Keras/JAX #2117
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
This commit introduces a new example for Reinforcement Learning from Human Feedback (RLHF). It includes: - \`examples/rl/rlhf_dummy_demo.py\`: A Python script demonstrating a simple RLHF loop with a dummy environment, a policy model, and a reward model, using Keras with the JAX backend. - \`examples/rl/md/rlhf_dummy_demo.md\`: A Markdown guide explaining the RLHF concept and the implementation details of the demo script. - \`examples/rl/README.md\`: A new README for the RL examples section, now including the RLHF demo. Note: The Python demo script (\`rlhf_dummy_demo.py\`) currently experiences timeout issues during the training loop in the development environment, even with significantly reduced computational load. This is documented in the guide and README. The code serves as a structural example of implementing the RLHF components.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks Yasir! just keep the .py file for now. once it is approved we can generate the .ipynb files
examples/rl/rlhf_dummy_demo.py
Outdated
policy_model_params["non_trainable"], | ||
state_input | ||
) | ||
actual_predictions_tensor = predictions_tuple[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
are we assuming batch size is 1?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. The code as written assumes a batch size of 1 for all model inputs and gradient calculations
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why is that?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was a simplification for the demo's clarity and to manage complexity, especially since REINFORCE-style updates can be done on single trajectories. A more advanced setup would definitely use batching across multiple episodes or from a replay buffer for stability and efficiency.
Does that make sense in the context of this simplified demo?
examples/rl/rlhf_dummy_demo.py
Outdated
episode_policy_losses.append(current_policy_loss) | ||
policy_grads_step = policy_grads_dict_step["trainable"] | ||
# Accumulate policy gradients | ||
for i, grad in enumerate(policy_grads_step): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For potentially improving performance
policy_grads_accum = jax.tree_map(lambda acc, new: acc + new if new is not None else acc, policy_grads_accum, policy_grads_step)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, I have refactored policy gradients and reward gradients accumulations using jax.tree_map
.
examples/rl/md/rlhf_dummy_demo.md
Outdated
actual_predictions_tensor = predictions_tuple[0] | ||
action_probs = actual_predictions_tensor[0] | ||
log_prob = jnp.log(action_probs[action] + 1e-7) | ||
return -log_prob * predicted_reward_value_stopped |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this predicted_reward_value is just R(s,a), then it's using the immediate predicted reward.
it's a very naive and generally ineffective way to train a policy. Thee log_prob should be multiplied by the cumulative discounted future reward (Return, G_t).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Commit c212593 addresses this.
…rewards. This commit refines the RLHF demo example (`examples/rl/rlhf_dummy_demo.py`) to use discounted cumulative actual rewards (G_t) for policy gradient calculations, aligning it with the REINFORCE algorithm. Changes include: - Added a `calculate_discounted_returns` helper function. - Modified the `rlhf_training_loop` to collect trajectories (states, actions, rewards) and compute G_t for each step at the end of an episode. - Updated the policy loss function to use these G_t values instead of immediate predicted rewards. - The reward model training logic remains focused on predicting immediate rewards based on simulated human feedback (environment reward in this demo). - Updated the corresponding RLHF guide (`examples/rl/md/rlhf_dummy_demo.md`) to explain these changes and provide updated code snippets. The timeout issues with the script in the development environment persist, but the code now better reflects a standard policy gradient approach.
the .md files are automatically generated. SO you might want to move the explanation content part to .py |
Moving the explanation piece to .py from .md
I have deleted the .md files and added relevant documentation pieces to .py file |
This commit introduces a new example for Reinforcement Learning from Human Feedback (RLHF).
It includes:
Note: The Python demo script (`rlhf_dummy_demo.py`) currently experiences timeout issues during the training loop in the development environment, even with significantly reduced computational load. This is documented in the guide and README. The code serves as a structural example of implementing the RLHF components.