Refactor(lerobot_train) Abstract Sample Weighting from RABC-Specific Implementation
#2781
base: main
Conversation
…place it with a generic samplerweight class in lerobot_train
Pull request overview
This PR refactors RABC (Reward-Aligned Behavior Cloning) specific training code into a generic sample weighting abstraction. The refactoring improves code organization by introducing a protocol-based design pattern that allows different sample weighting strategies to be plugged into the training pipeline without modifying the core training logic.
Changes:
- Introduced a generic `SampleWeighter` protocol and factory function in a new `sample_weighting.py` module
- Refactored `lerobot_train.py` to use the generic sample weighting interface instead of RABC-specific code
- Updated the configuration from RABC-specific fields to a more flexible `SampleWeightingConfig`
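The protocol-plus-baseline design described above can be sketched as follows. This is a framework-free illustration, not the PR's actual code: the real classes presumably operate on torch tensors, and the method name `get_weights` is an assumption.

```python
from abc import ABC, abstractmethod


class SampleWeighter(ABC):
    """Strategy interface: maps a training batch to per-sample loss weights."""

    @abstractmethod
    def get_weights(self, batch: dict) -> list[float]:
        """Return one weight per sample in the batch."""


class UniformWeighter(SampleWeighter):
    """No-op baseline: every sample contributes equally to the loss."""

    def get_weights(self, batch: dict) -> list[float]:
        # Infer the batch size from the first array-like value in the batch.
        batch_size = len(next(iter(batch.values())))
        return [1.0] * batch_size
```

With this shape, the training loop only ever talks to `SampleWeighter`, so strategies like RA-BC can be swapped in without touching `lerobot_train.py`.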
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| src/lerobot/utils/sample_weighting.py | New module providing generic sample weighting abstraction with Protocol, config dataclass, factory function, and UniformWeighter implementation |
| src/lerobot/scripts/lerobot_train.py | Refactored to use generic sample_weighter instead of rabc_weights_provider; updated parameter names, logging, and added defensive null checks |
| src/lerobot/policies/sarm/rabc.py | Improved documentation, added type hints, cleaned up code (removed shebang), renamed variables for clarity, and updated to work with generic interface |
| src/lerobot/configs/train.py | Replaced RABC-specific configuration fields with generic sample_weighting field using SampleWeightingConfig |
Comments suppressed due to low confidence (1)
src/lerobot/policies/sarm/rabc.py:28
- The arXiv URL format in the documentation appears incorrect. arXiv URLs typically follow the format https://arxiv.org/abs/YYMM.NNNNN where YYMM represents year and month, and NNNNN is a sequence number. The URL "https://arxiv.org/abs/2509.25358" suggests a paper from September 2025 with ID 25358, which seems unusual. Please verify this is the correct arXiv identifier for the SARM paper. If this is a placeholder, consider using a comment to indicate it should be updated once the paper is published.
src/lerobot/scripts/lerobot_train.py
Outdated
```diff
  train_metrics.loss = loss.item()
- train_metrics.grad_norm = grad_norm.item()
+ train_metrics.grad_norm = grad_norm.item() if hasattr(grad_norm, "item") else float(grad_norm)
```
Copilot AI (Jan 12, 2026):
The grad_norm handling change adds a hasattr check for the "item" method. While this defensive coding handles edge cases, it's worth noting that both accelerator.clip_grad_norm_ and torch.nn.utils.clip_grad_norm_ should always return a tensor (or float in older PyTorch versions). This might be unnecessary defensive code that masks underlying issues. Consider adding a comment explaining the specific scenario where grad_norm might not have an "item" method, or verify if this is truly needed.
Suggested change:

```diff
- train_metrics.grad_norm = grad_norm.item() if hasattr(grad_norm, "item") else float(grad_norm)
+ train_metrics.grad_norm = float(grad_norm)
```
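Copilot's simplification rests on the fact that `float()` accepts anything implementing `__float__`, which 0-dim tensors do, so both branches of the `hasattr` check collapse into one call. A torch-free illustration, using a hypothetical stand-in class:

```python
class FakeTensor:
    """Stand-in for a 0-dim tensor: supports both .item() and __float__."""

    def __init__(self, value: float) -> None:
        self.value = value

    def item(self) -> float:
        return self.value

    def __float__(self) -> float:
        return self.value


# float() handles the tensor-like case and the plain-float case uniformly,
# making the hasattr(grad_norm, "item") branch unnecessary.
print(float(FakeTensor(0.25)))  # 0.25
print(float(0.25))              # 0.25
```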
src/lerobot/scripts/lerobot_train.py
Outdated
```python
if sample_weights is not None:
    # Use per-sample loss for weighted training
    # Note: Policies supporting sample weighting must implement forward(batch, reduction="none")
    per_sample_loss, output_dict = policy.forward(batch, reduction="none")  # type: ignore[call-arg]
```
Copilot AI (Jan 12, 2026):
The comment on line 103 states "Policies supporting sample weighting must implement forward(batch, reduction='none')" but there's no validation or clear error message if a policy doesn't support this. When a policy that doesn't support the reduction parameter is used with sample weighting, the type: ignore comment will suppress the type error, but the runtime error might be unclear. Consider catching the TypeError and providing a more helpful error message explaining which policies support sample weighting, or document the required interface more explicitly.
Suggested change:

```diff
- per_sample_loss, output_dict = policy.forward(batch, reduction="none")  # type: ignore[call-arg]
+ try:
+     per_sample_loss, output_dict = policy.forward(  # type: ignore[call-arg]
+         batch,
+         reduction="none",
+     )
+ except TypeError as e:
+     raise TypeError(
+         "Sample weighting is enabled, but the current policy's `forward` method does not "
+         "accept a `reduction` keyword argument. Policies supporting sample weighting must "
+         "implement `forward(batch, reduction: str = 'none')`. Disable sample weighting or "
+         "use a policy that implements this interface."
+     ) from e
```
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Type / Scope
Summary / Motivation
The training script (`lerobot_train.py`) was filled with RABC-specific code in #2639. The training script should be as policy-agnostic as possible. This PR introduces a generic `SampleWeighter` abstraction that decouples sample loss weighting logic from the training script.

Sample weighting is a general training technique where individual samples are assigned weights to influence their contribution to the loss. This enables techniques like RA-BC, importance sampling, curriculum learning, and quality-based filtering. By abstracting this, the training script remains clean while still supporting these advanced training strategies.
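A minimal sketch of how a factory can keep policy-specific construction out of the training script. The names follow this PR's description, but the class bodies here are illustrative stand-ins, not the real implementations:

```python
from dataclasses import dataclass, field


class UniformWeighter:
    """Illustrative no-op baseline weighter."""


class RABCWeights:
    """Illustrative stand-in for the RA-BC weighter in the SARM policy package."""

    def __init__(self, **kwargs) -> None:
        self.params = kwargs


@dataclass
class SampleWeightingConfig:
    type: str = "uniform"
    extra_params: dict = field(default_factory=dict)


def make_sample_weighter(config: SampleWeightingConfig):
    """Dispatch on config.type so lerobot_train.py never hard-codes a strategy."""
    if config.type == "uniform":
        return UniformWeighter()
    if config.type == "rabc":
        # In the real module, the policy-specific import could live inside this
        # branch, keeping the training script free of SARM dependencies.
        return RABCWeights(**config.extra_params)
    raise ValueError(f"Unknown sample weighting type: {config.type!r}")
```

The training script then only calls `make_sample_weighter(cfg.sample_weighting)` and works against the returned interface.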
What changed
- `src/lerobot/configs/train.py`: Removed the RABC-specific config parameters (`use_rabc`, `rabc_progress_path`, `rabc_kappa`, `rabc_epsilon`, `rabc_head_mode`) and replaced them with a single generic `sample_weighting: SampleWeightingConfig | None` field.
- `src/lerobot/scripts/lerobot_train.py`:
  - Renamed the `rabc_weights_provider` parameter to `sample_weighter`
  - Replaced RABC-specific calls with generic `SampleWeighter` interface calls
- `src/lerobot/utils/sample_weighting.py` (NEW):
  - `SampleWeighter` ABC defining the interface for sample weighting strategies
  - `SampleWeightingConfig` dataclass for generic configuration
  - `make_sample_weighter()` factory function that keeps policy-specific initialization out of the training script
  - `UniformWeighter` as a no-op baseline implementation
- `src/lerobot/policies/sarm/rabc.py`:
  - `RABCWeights` now inherits from the `SampleWeighter` ABC
  - Renamed variables for clarity (`raw_mean_weight` → `mean_weight`)
- `tests/utils/test_sample_weighting.py` (NEW): Comprehensive test coverage for the sample weighting infrastructure, including config tests, `UniformWeighter` tests, factory function tests, and integration tests with `RABCWeights`.

How was this tested
Tests added:
- `tests/utils/test_sample_weighting.py` with 18 test cases covering:
  - `SampleWeightingConfig` initialization and validation
  - `UniformWeighter` functionality (batch size determination, device placement)
  - `make_sample_weighter` factory behavior (error handling, type dispatch)
  - Integration with `RABCWeights`

Manual checks: Verified that the refactored code maintains the same training behavior.
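One of the factory error-handling cases might look roughly like this. It is a self-contained sketch with stand-in definitions; the actual tests in `tests/utils/test_sample_weighting.py` will differ:

```python
from dataclasses import dataclass

import pytest


@dataclass
class SampleWeightingConfig:
    """Stand-in config so this sketch runs outside the repo."""

    type: str = "uniform"


def make_sample_weighter(config: SampleWeightingConfig):
    """Stand-in factory mirroring the dispatch-and-reject behavior under test."""
    if config.type == "uniform":
        return object()
    raise ValueError(f"Unknown sample weighting type: {config.type!r}")


def test_factory_rejects_unknown_type():
    with pytest.raises(ValueError):
        make_sample_weighter(SampleWeightingConfig(type="not-a-weighter"))
```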
How to run locally (reviewer)
Run the new tests:
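The command block was lost in extraction; it was presumably an invocation along these lines, targeting the test file added by this PR:

```shell
pytest tests/utils/test_sample_weighting.py -v
```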
Example config usage (for RA-BC training):
Checklist (required before merge)
- Code style checks pass (`pre-commit run -a`)
- Tests pass (`pytest`)

Reviewer notes
- `SampleWeightingConfig` contains some RABC-specific fields (`progress_path`, `head_mode`, `kappa`) for convenience, but the `extra_params` dict allows future weighting strategies to pass custom parameters.
- `UniformWeighter` is intentionally simple and serves as a baseline/no-op option.
- Breaking change: the old config fields (`use_rabc`, `rabc_*`) are removed. Users must migrate to the new `sample_weighting` config structure.
Before (old config):
After (new config):
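Both snippets were lost in extraction; a hedged reconstruction of the contrast, written as commented Python (exact syntax in the real config may differ):

```python
# Before (removed by this PR): flat RABC-specific fields on the train config
#     use_rabc=True,
#     rabc_progress_path=...,
#     rabc_kappa=...,
#     rabc_epsilon=...,
#     rabc_head_mode=...,
#
# After: a single nested, generic field
#     sample_weighting=SampleWeightingConfig(
#         type="rabc",
#         progress_path=...,
#         kappa=...,
#         head_mode=...,
#     ),
```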