perf(policies): Make ACT policy compatible with torch.compile
#2159
Conversation
Adds `benchmark_inference_compile_lerobot.py` from https://gist.github.com/AdilZouitine/3574664e4cf71605986b49e9148d29ab.
…pile graph break

Removed `.item()` calls from `loss_dict` in `forward()` to avoid breaking the `torch.compile` computation graph. The tensor-to-scalar conversion is now handled in the training script instead.
In the inference benchmark in `benchmark_inference_compile_lerobot.py`, the `test_correctness` method failed to properly compare compiled `ACTPolicy` inference results because `policy_original` and `policy_compiled` shared the same `_action_queue` object. Previously, the call order was:

1. `policy_original.reset()`
2. `policy_compiled.reset()`
3. `policy_original.select_action()`
4. `policy_compiled.select_action()`

Because the `_action_queue` is shared, `policy_original.select_action()` would run inference (`predict_action_chunk`) and extend the queue with `n_action_steps` actions. `policy_compiled.select_action()` would then find a non-empty queue and simply pop an action, bypassing its own compiled inference logic. This commit reorders the calls to:

1. `policy_original.reset()`
2. `policy_original.select_action()`
3. `policy_compiled.reset()`
4. `policy_compiled.select_action()`

This change ensures that `policy_compiled.reset()` clears the shared queue *after* the original policy's action selection. Consequently, `policy_compiled.select_action()` finds an empty queue and executes its own compiled inference, allowing for a correct comparison. With this fix, the compiled `ACTPolicy` inference check within `test_correctness` now passes, validating that the compiled inference output matches the original.
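The shared-queue pitfall can be reproduced with a toy sketch. `ToyPolicy` is a hypothetical stand-in for `ACTPolicy` that models only the queue logic; the real implementation lives in `src/lerobot/policies/act/modeling_act.py`:

```python
from collections import deque

class ToyPolicy:
    """Minimal stand-in for ACTPolicy's queue-based action selection."""

    def __init__(self, action_queue, label, n_action_steps=3):
        # Both instances can be handed the SAME deque, reproducing the bug.
        self._action_queue = action_queue
        self.label = label
        self.n_action_steps = n_action_steps
        self.ran_inference = False

    def reset(self):
        self._action_queue.clear()
        self.ran_inference = False

    def select_action(self):
        if len(self._action_queue) == 0:
            # Stand-in for predict_action_chunk(): refill the queue.
            self.ran_inference = True
            self._action_queue.extend(
                f"{self.label}-act{i}" for i in range(self.n_action_steps)
            )
        return self._action_queue.popleft()

shared_queue = deque()
original = ToyPolicy(shared_queue, "orig")
compiled = ToyPolicy(shared_queue, "comp")

# Buggy order: compiled pops a leftover action instead of running inference.
original.reset(); compiled.reset()
original.select_action()
compiled.select_action()
assert not compiled.ran_inference  # compiled inference was silently skipped

# Fixed order: compiled.reset() clears the shared queue first.
original.reset(); original.select_action()
compiled.reset()
action = compiled.select_action()
assert compiled.ran_inference and action == "comp-act0"
```

With the reordered calls, the compiled policy's own inference path is actually exercised, which is what `test_correctness` needs to compare.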
- Added `import copy`.
- Use `copy.deepcopy(policy)` before `torch.compile`.
- Introduced `self.fullgraph` attribute in `TorchCompileBenchmark`.
- Pass `fullgraph=self.fullgraph` when calling `torch.compile`.
- Added CLI argument `--fullgraph` to enable full-graph compilation, raising an error if the graph breaks.
- Added `--matmul-precision` argument with choices: `highest`, `high`, `medium`.
- Applied only when a CUDA device is selected.
- Allows benchmarking with different float32 matmul precision settings.
- Add `--disable-cudnn-tf32` CLI argument to disallow the use of TensorFloat-32 tensor cores in cuDNN convolutions (CUDA only).
- Apply `torch.backends.cudnn.allow_tf32 = False` when the argument is used.
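These two precision-related commits amount to a couple of standard PyTorch API calls. A sketch, with a hypothetical helper name (`apply_precision_flags`); the authoritative logic lives in the benchmark script:

```python
import torch

def apply_precision_flags(matmul_precision=None, disable_cudnn_tf32=False, device="cuda"):
    """Hypothetical helper mirroring the two benchmark flags described above."""
    if device == "cuda" and matmul_precision is not None:
        # One of "highest", "high", "medium": trades float32 matmul
        # accuracy for speed on tensor-core hardware.
        torch.set_float32_matmul_precision(matmul_precision)
    if disable_cudnn_tf32:
        # Disallow TF32 tensor cores in cuDNN convolutions.
        torch.backends.cudnn.allow_tf32 = False

apply_precision_flags(matmul_precision="medium", device="cuda")
```

Both settings are process-global, which is why the benchmark applies them once at startup from the CLI arguments.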
- Add `--disable-dropout` CLI argument to set the dropout rate to 0 in policies.
- Apply the argument by setting `cfg.dropout = 0.0` if the policy config has a dropout attribute.
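The dropout override described above reduces to a small guard. A sketch with a hypothetical helper name (`apply_disable_dropout`):

```python
from types import SimpleNamespace

def apply_disable_dropout(cfg):
    # Zero the dropout rate only if the policy config actually has one,
    # as described in the commit note above.
    if hasattr(cfg, "dropout"):
        cfg.dropout = 0.0
    return cfg

# Config with a dropout field gets zeroed.
act_like_cfg = apply_disable_dropout(SimpleNamespace(dropout=0.1))
assert act_like_cfg.dropout == 0.0

# Config without the field passes through untouched.
no_dropout_cfg = apply_disable_dropout(SimpleNamespace())
assert not hasattr(no_dropout_cfg, "dropout")
```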
Remove the conditional `if args.fullgraph` check and assign `benchmark.fullgraph` directly from `args.fullgraph`. This ensures the benchmark always reflects the CLI flag.
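Taken together, the CLI additions from these commits could be wired up roughly as follows. This is a sketch of the argument surface only; the authoritative parser lives in the benchmark script:

```python
import argparse

parser = argparse.ArgumentParser(description="torch.compile benchmark options (sketch)")
parser.add_argument("--fullgraph", action="store_true",
                    help="Compile the entire model as a single graph; error on graph breaks.")
parser.add_argument("--matmul-precision", choices=["highest", "high", "medium"],
                    default="highest", help="float32 matmul precision (CUDA only).")
parser.add_argument("--disable-cudnn-tf32", action="store_true",
                    help="Disallow TF32 tensor cores in cuDNN convolutions (CUDA only).")
parser.add_argument("--disable-dropout", action="store_true",
                    help="Set the dropout rate to 0 in the policy config.")

args = parser.parse_args(["--fullgraph", "--matmul-precision", "medium"])
```

Assigning `benchmark.fullgraph = args.fullgraph` unconditionally, as the commit above does, keeps the benchmark state in sync with the flag whether or not it was passed.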
- Add `compile_mode` to `TorchCompileBenchmark` and expose it through the command-line argument `--compile-mode`, supporting both `default` and `reduce-overhead` modes.
- Update the benchmark compilation strategy by compiling `forward` and `select_action` individually instead of compiling the entire model, improving control over compilation behavior and inference performance.
- Extend `ACTConfig` with `compile_model` and `compile_mode` to support optional model compilation through configuration.
- Update `ACTPolicy` to conditionally compile `forward` and `select_action` during initialization when `compile_model` is enabled in the policy configuration.
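The configuration-driven compilation described above might look like the following sketch. `PolicySketch` and `DummyConfig` are hypothetical names; only the wiring around the `compile_model`/`compile_mode` fields from this PR is modeled:

```python
from dataclasses import dataclass

import torch

@dataclass
class DummyConfig:
    # Mirrors the two new ACTConfig fields described above.
    compile_model: bool = True
    compile_mode: str = "default"

class PolicySketch(torch.nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.net = torch.nn.Linear(4, 4)
        if cfg.compile_model:
            # Compile the two entry points individually instead of wrapping
            # the whole module, as described in the commit message above.
            self.forward = torch.compile(self.forward, mode=cfg.compile_mode)
            self.select_action = torch.compile(self.select_action, mode=cfg.compile_mode)

    def forward(self, batch):
        return self.net(batch)

    def select_action(self, obs):
        return self.net(obs).argmax(dim=-1)

policy = PolicySketch(DummyConfig())
# torch.compile wraps lazily: the compiled callables shadow the bound
# methods on the instance, but nothing is traced until the first call.
assert "forward" in policy.__dict__ and "select_action" in policy.__dict__
```

Assigning the compiled callables onto the instance keeps the class methods untouched, so a policy built with `compile_model=False` behaves exactly as before.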
Pull request overview
This PR makes the ACT policy compatible with torch.compile by removing .item() calls from the forward method (which cause graph breaks) and delegating scalar extraction to the training script. It also adds optional torch.compile support via configuration flags and includes a comprehensive benchmark script.
Key Changes:
- Removed `.item()` calls from ACT policy's `forward()` method to avoid graph breaks during compilation
- Modified training script to handle tensor-to-scalar conversion for loss dictionaries
- Added `compile_model` and `compile_mode` configuration options to `ACTConfig`
- Introduced benchmark script for evaluating `torch.compile` performance
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| `src/lerobot/policies/act/modeling_act.py` | Removed `.item()` calls from loss dict; added conditional `torch.compile` of `forward` and `select_action` methods |
| `src/lerobot/policies/act/configuration_act.py` | Added `compile_model` and `compile_mode` configuration fields with documentation |
| `src/lerobot/scripts/lerobot_train.py` | Added dictionary comprehension to convert tensor values to scalars after policy forward pass |
| `benchmarks/policies_compilation/benchmark_inference_compile_lerobot.py` | New comprehensive benchmark script with compile options and performance testing |
```python
# Let accelerator handle mixed precision
with accelerator.autocast():
    loss, output_dict = policy.forward(batch)
output_dict = {k: v.item() if isinstance(v, torch.Tensor) else v for k, v in output_dict.items()}
```
**Copilot AI** commented on Dec 14, 2025:
This line will crash when `output_dict` is `None` (e.g., for `DiffusionPolicy`, which returns `None`). The code should check that `output_dict` is not `None` before attempting to call `.items()` on it. Consider: `output_dict = {k: v.item() if isinstance(v, torch.Tensor) else v for k, v in output_dict.items()} if output_dict is not None else {}`
Suggested change:

```diff
-output_dict = {k: v.item() if isinstance(v, torch.Tensor) else v for k, v in output_dict.items()}
+output_dict = {k: v.item() if isinstance(v, torch.Tensor) else v for k, v in output_dict.items()} if output_dict is not None else {}
```
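A quick runnable check of the suggested guard, with a hypothetical helper name (`to_scalars`) standing in for the inline comprehension:

```python
import torch

def to_scalars(output_dict):
    # None-safe version of the dict comprehension suggested above:
    # policies such as DiffusionPolicy may return None instead of a dict.
    if output_dict is None:
        return {}
    return {k: v.item() if isinstance(v, torch.Tensor) else v
            for k, v in output_dict.items()}

assert to_scalars(None) == {}
assert to_scalars({"l1_loss": torch.tensor(0.5), "step": 3}) == {"l1_loss": 0.5, "step": 3}
```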
```python
    dropout: Dropout to use in the transformer layers (see code for details).
    kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
        is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
    compile_model: Enables compiling with `torch.compile` for faster policy training and inference.
```
**Copilot AI** commented on Dec 14, 2025:
The documentation states this parameter enables torch.compile for faster training and inference, but the implementation in modeling_act.py shows that it compiles both forward() and select_action() methods. The docstring should clarify which specific methods are compiled to avoid confusion.
Suggested change:

```diff
-    compile_model: Enables compiling with `torch.compile` for faster policy training and inference.
+    compile_model: Enables compiling with `torch.compile` for faster policy training and inference. This
+        compiles both the `forward()` and `select_action()` methods.
```
…robot.py

Co-authored-by: Copilot <[email protected]>
Signed-off-by: HUANG TZU-CHUN <[email protected]>
1. What this does
1.1 Changes to make the ACT policy compile-compatible
1.1.1 Modifications to fix graph breaks
TorchDynamo emitted graph-break warnings due to the use of `Tensor.item()` inside the `forward()` method of `ACTPolicy`. Since `.item()` performs Python-side scalar extraction and disrupts graph capture, these conversions were removed from the `forward()` method. The model now returns loss tensors directly in `loss_dict`; scalar extraction is deferred to the training script (`lerobot_train.py`).
Corresponding code changes:
- `ACTPolicy.forward` in `src/lerobot/policies/act/modeling_act.py`
- `update_policy` in `src/lerobot/scripts/lerobot_train.py`

```diff
 with accelerator.autocast():
     loss, output_dict = policy.forward(batch)
+output_dict = {k: v.item() if isinstance(v, torch.Tensor) else v for k, v in output_dict.items()}
```

1.1.2 Add optional torch.compile support to ACTPolicy
This update introduces optional compilation support for `ACTPolicy` using PyTorch's `torch.compile`. Two new arguments for compilation control were added, consistent with the PI0 and PI0.5 policies:

- `--policy.compile_model`: Enables or disables compilation of the policy model.
- `--policy.compile_mode`: Specifies the Torch compile mode to use.

In `ACTConfig` (`src/lerobot/policies/act/configuration_act.py`), two new fields were added: `compile_model` and `compile_mode`. During initialization of `ACTPolicy` (`src/lerobot/policies/act/modeling_act.py`), compilation is applied conditionally based on the configuration.

1.2 Changes to the benchmark
The benchmark script `benchmarks/policies_compilation/benchmark_inference_compile_lerobot.py` was adapted from the script provided in issue #2061, with the following changes:

- `--compile-mode`: `["default", "reduce-overhead"]` Torch compile mode to use.
- `--fullgraph`: If set, compile the entire model as a single graph and raise an error if the graph breaks.
- `--disable-dropout`: If set, disable dropout layers by setting their dropout rate to 0.
- `--matmul-precision`: `["highest", "high", "medium"]` Set float32 matmul precision (only applies when the device is CUDA).
- `--disable-cudnn-tf32`: Disallow the use of TensorFloat-32 tensor cores in cuDNN convolutions (only applies when the device is CUDA).

2. How it was tested
2.1 Environment and testing command
The environment used for testing and benchmarking is as follows:
Tests and benchmarks were executed using the following command with different combinations of command-line arguments:
The benchmark was performed using the following combinations of command-line arguments. Both `default` and `reduce-overhead` compile modes were tested separately:

2.2 Baseline and final benchmark reports
2.2.1 Compile mode: `default`

2.2.2 Compile mode: `reduce-overhead`

3. How to checkout & try (for the reviewer)
Testing and benchmarking can be performed using the following command with different combinations of command-line arguments: