@michel-aractingi michel-aractingi commented Jan 12, 2026

Type / Scope

  • Type: Refactor
  • Scope: train, utils, policies/sarm

Summary / Motivation

The training script (lerobot_train.py) was filled with RABC-specific code in #2639. The training script should be as policy-agnostic as possible, so this PR introduces a generic SampleWeighter abstraction that decouples sample-loss weighting logic from the training script.

Sample weighting is a general training technique where individual samples are assigned weights to influence their contribution to the loss. This enables techniques like RA-BC, importance sampling, curriculum learning, and quality-based filtering. By abstracting this, the training script remains clean while still supporting these advanced training strategies.
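As a hypothetical illustration (not the library's API — the function name is invented), per-sample weighting amounts to scaling each sample's loss before reduction:

```python
def weighted_loss(per_sample_losses, weights):
    """Weighted mean of per-sample losses; uniform weights recover the plain mean."""
    total = sum(weights)
    return sum(l * w for l, w in zip(per_sample_losses, weights)) / total

losses = [0.5, 1.0, 2.0]
uniform = weighted_loss(losses, [1.0, 1.0, 1.0])  # plain mean over the batch
skewed = weighted_loss(losses, [2.0, 1.0, 0.0])   # emphasize sample 0, drop sample 2
```

Importance sampling, curriculum learning, and quality filtering all reduce to choosing the weight vector differently.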

What changed

  • src/lerobot/configs/train.py: Removed RABC-specific config parameters (use_rabc, rabc_progress_path, rabc_kappa, rabc_epsilon, rabc_head_mode) and replaced with a single generic sample_weighting: SampleWeightingConfig | None field.

  • src/lerobot/scripts/lerobot_train.py:

    • Renamed rabc_weights_provider parameter to sample_weighter
    • Replaced RABC-specific weight computation and logging with generic SampleWeighter interface calls
    • Removed policy-specific initialization logic (now handled by factory function)
  • src/lerobot/utils/sample_weighting.py (NEW):

    • SampleWeighter ABC defining the interface for sample weighting strategies
    • SampleWeightingConfig dataclass for generic configuration
    • make_sample_weighter() factory function that keeps policy-specific initialization out of the training script
    • UniformWeighter as a no-op baseline implementation
  • src/lerobot/policies/sarm/rabc.py:

    • RABCWeights now inherits from SampleWeighter ABC
    • Added proper type hints throughout
    • Updated docstrings to reference the sample weighting infrastructure
    • Standardized stats keys (raw_mean_weight → mean_weight)
  • tests/utils/test_sample_weighting.py (NEW): Comprehensive test coverage for the sample weighting infrastructure including config tests, UniformWeighter tests, factory function tests, and integration tests with RABCWeights.

How was this tested

  • Tests added: tests/utils/test_sample_weighting.py with 18 test cases covering:

    • SampleWeightingConfig initialization and validation
    • UniformWeighter functionality (batch size determination, device placement)
    • make_sample_weighter factory behavior (error handling, type dispatch)
    • Integration with RABCWeights
  • Manual checks: Verified the refactored code maintains the same training behavior

How to run locally (reviewer)

  • Run the new tests:

    pytest tests/utils/test_sample_weighting.py -v
  • Example config usage (for RA-BC training):

    lerobot-train \
      --policy.type=act \
      --dataset.repo_id=my-dataset \
      --sample_weighting.type=rabc \
      --sample_weighting.progress_path=hf://datasets/my-dataset/sarm_progress.parquet

Checklist (required before merge)

  • Linting/formatting run (pre-commit run -a)
  • All tests pass locally (pytest)
  • Documentation updated
  • CI is green

Reviewer notes

  • The key design decision was using an ABC (Abstract Base Class) over a Protocol. ABC was chosen for explicit inheritance and clearer error messages when methods aren't implemented.
  • SampleWeightingConfig contains some RABC-specific fields (progress_path, head_mode, kappa) for convenience, but the extra_params dict allows future weighting strategies to pass custom parameters.
  • The UniformWeighter is intentionally simple and serves as a baseline/no-op option.
  • Breaking change: Old config parameters (use_rabc, rabc_*) are removed. Users must migrate to the new sample_weighting config structure.
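The ABC-over-Protocol trade-off shows up in a toy example: an incomplete ABC subclass fails loudly at construction time, whereas a structural Protocol would only fail when the missing method is called. Class names below are illustrative, not from the PR:

```python
from abc import ABC, abstractmethod

class Weighter(ABC):
    @abstractmethod
    def compute_weights(self, batch): ...

class Broken(Weighter):
    pass  # forgot to implement compute_weights

try:
    Broken()
except TypeError as err:
    # Raised immediately at instantiation, naming the missing abstract method
    print(err)
```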

Migration Guide

Before (old config):

TrainPipelineConfig(
    use_rabc=True,
    rabc_progress_path="path/to/progress.parquet",
    rabc_kappa=0.01,
    rabc_epsilon=1e-6,
    rabc_head_mode="sparse",
)

After (new config):

from lerobot.utils.sample_weighting import SampleWeightingConfig

TrainPipelineConfig(
    sample_weighting=SampleWeightingConfig(
        type="rabc",
        progress_path="path/to/progress.parquet",
        kappa=0.01,
        epsilon=1e-6,
        head_mode="sparse",
    ),
)

Copilot AI review requested due to automatic review settings January 12, 2026 10:50
@github-actions github-actions bot added policies Items related to robot policies configuration Problems with configuration files or settings labels Jan 12, 2026

Copilot AI left a comment


Pull request overview

This PR refactors RABC (Reward-Aligned Behavior Cloning) specific training code into a generic sample weighting abstraction. The refactoring improves code organization by introducing a protocol-based design pattern that allows different sample weighting strategies to be plugged into the training pipeline without modifying the core training logic.

Changes:

  • Introduced a generic SampleWeighter protocol and factory function in a new sample_weighting.py module
  • Refactored lerobot_train.py to use the generic sample weighting interface instead of RABC-specific code
  • Updated configuration from RABC-specific fields to a more flexible SampleWeightingConfig

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 6 comments.

File Description
src/lerobot/utils/sample_weighting.py New module providing generic sample weighting abstraction with Protocol, config dataclass, factory function, and UniformWeighter implementation
src/lerobot/scripts/lerobot_train.py Refactored to use generic sample_weighter instead of rabc_weights_provider; updated parameter names, logging, and added defensive null checks
src/lerobot/policies/sarm/rabc.py Improved documentation, added type hints, cleaned up code (removed shebang), renamed variables for clarity, and updated to work with generic interface
src/lerobot/configs/train.py Replaced RABC-specific configuration fields with generic sample_weighting field using SampleWeightingConfig
Comments suppressed due to low confidence (1)

src/lerobot/policies/sarm/rabc.py:28

  • The arXiv URL format in the documentation appears incorrect. arXiv URLs typically follow the format https://arxiv.org/abs/YYMM.NNNNN where YYMM represents year and month, and NNNNN is a sequence number. The URL "https://arxiv.org/abs/2509.25358" suggests a paper from September 2025 with ID 25358, which seems unusual. Please verify this is the correct arXiv identifier for the SARM paper. If this is a placeholder, consider using a comment to indicate it should be updated once the paper is published.



Referenced diff (src/lerobot/scripts/lerobot_train.py):

      train_metrics.loss = loss.item()
    - train_metrics.grad_norm = grad_norm.item()
    + train_metrics.grad_norm = grad_norm.item() if hasattr(grad_norm, "item") else float(grad_norm)
Copilot AI Jan 12, 2026


The grad_norm handling change adds a hasattr check for the "item" method. While this defensive coding handles edge cases, it's worth noting that both accelerator.clip_grad_norm_ and torch.nn.utils.clip_grad_norm_ should always return a tensor (or float in older PyTorch versions). This might be unnecessary defensive code that masks underlying issues. Consider adding a comment explaining the specific scenario where grad_norm might not have an "item" method, or verify if this is truly needed.

Suggested change:

    - train_metrics.grad_norm = grad_norm.item() if hasattr(grad_norm, "item") else float(grad_norm)
    + train_metrics.grad_norm = float(grad_norm)

Referenced code (src/lerobot/scripts/lerobot_train.py):

    if sample_weights is not None:
        # Use per-sample loss for weighted training
        # Note: Policies supporting sample weighting must implement forward(batch, reduction="none")
        per_sample_loss, output_dict = policy.forward(batch, reduction="none")  # type: ignore[call-arg]
Copilot AI Jan 12, 2026


The comment on line 103 states "Policies supporting sample weighting must implement forward(batch, reduction='none')" but there's no validation or clear error message if a policy doesn't support this. When a policy that doesn't support the reduction parameter is used with sample weighting, the type: ignore comment will suppress the type error, but the runtime error might be unclear. Consider catching the TypeError and providing a more helpful error message explaining which policies support sample weighting, or document the required interface more explicitly.

Suggested change:

    - per_sample_loss, output_dict = policy.forward(batch, reduction="none")  # type: ignore[call-arg]
    + try:
    +     per_sample_loss, output_dict = policy.forward(  # type: ignore[call-arg]
    +         batch,
    +         reduction="none",
    +     )
    + except TypeError as e:
    +     raise TypeError(
    +         "Sample weighting is enabled, but the current policy's `forward` method does not "
    +         "accept a `reduction` keyword argument. Policies supporting sample weighting must "
    +         "implement `forward(batch, reduction: str = 'none')`. Disable sample weighting or "
    +         "use a policy that implements this interface."
    +     ) from e

@github-actions github-actions bot added the tests Problems with test coverage, failures, or improvements to testing label Jan 12, 2026
@michel-aractingi michel-aractingi changed the title Refactor/lerobot train rabc Refactor(lerobot_train) Abstract Sample Weighting from RABC-Specific Implementation Jan 14, 2026
@github-actions github-actions bot added the documentation Improvements or fixes to the project’s docs label Jan 14, 2026
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
