
Commit c24cbaa

Merge branch 'main' into feature/add-multitask-dit
2 parents 0e7bfa5 + 1d86c9b commit c24cbaa

File tree

15 files changed: +3214 additions, −5 deletions


docs/source/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -37,6 +37,8 @@
     title: SmolVLA
   - local: pi0
     title: π₀ (Pi0)
+  - local: pi0fast
+    title: π₀-FAST (Pi0Fast)
   - local: pi05
     title: π₀.₅ (Pi05)
   - local: groot

docs/source/pi0fast.mdx

Lines changed: 182 additions & 0 deletions
@@ -0,0 +1,182 @@
# π₀-FAST (Pi0-FAST)

π₀-FAST is a **Vision-Language-Action model for general robot control** that models continuous robot actions with autoregressive next-token prediction over discrete action tokens.

## Model Overview

π₀-FAST combines the power of Vision-Language Models with a novel action tokenization approach called **FAST (Frequency-space Action Sequence Tokenization)**. This enables training autoregressive VLAs on highly dexterous tasks where standard binning-based discretization breaks down, while training **up to 5x faster** than diffusion-based approaches like π₀.

### Why FAST?

Standard approaches to robot action tokenization use simple per-dimension, per-timestep binning schemes. While adequate for simple behaviors, this breaks down rapidly for complex, dexterous skills that require precision and high-frequency control.

FAST solves this by compressing action sequences with signal-processing techniques, producing a dense sequence of action tokens that can be predicted autoregressively, just like language tokens.

### How FAST Tokenization Works

The FAST tokenizer compresses action sequences through the following steps (a short code sketch follows the list):

1. **Normalize**: Take a continuous action chunk of shape `(H, D)`, where `H` is the horizon and `D` is the action dimension, and normalize it with one of the supported normalization methods (`QUANTILES` is recommended because it is robust to outliers).

2. **Discrete Cosine Transform (DCT)**: Apply the DCT (via scipy) to each action dimension separately. The DCT is a frequency transform widely used in image and audio codecs such as JPEG and MP3.

3. **Quantization**: Round the coefficients and drop insignificant ones in each action dimension, producing a sparse frequency matrix.

4. **Flatten**: Flatten the matrix into a 1D vector, with low-frequency components first.

5. **Byte Pair Encoding (BPE)**: Train a BPE tokenizer to compress the quantized DCT coefficients into dense action tokens, typically achieving **10x compression** over prior tokenization approaches.

This approach can transform **any existing VLM** into a VLA by training it to predict these FAST tokens.
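
As a rough illustration of steps 1–4, the snippet below compresses a dummy action chunk with scipy's DCT and simple rounding. This is a minimal sketch, not the actual `physical-intelligence/fast` implementation; the horizon, dimension, and scale values are arbitrary assumptions.

```python
# Minimal sketch of FAST-style compression (steps 1-4). Illustrative only;
# the real tokenizer additionally trains a BPE model over these integers (step 5).
import numpy as np
from scipy.fft import dct

H, D = 10, 6                                     # assumed horizon and action dimension
chunk = np.random.uniform(-1, 1, size=(H, D))    # stand-in for a normalized action chunk
scale = 10.0                                     # quantization scale (cf. --scale below)

coeffs = dct(chunk, axis=0, norm="ortho")        # step 2: DCT along time, per dimension
quantized = np.rint(coeffs * scale).astype(int)  # step 3: round; small coefficients become 0
flat = quantized.flatten()                       # step 4: row-major flatten = low frequencies first

print(f"{np.count_nonzero(flat)} nonzero of {flat.size} quantized coefficients")
```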

## Installation Requirements

1. Install LeRobot by following our [Installation Guide](./installation).
2. Install the π₀-FAST dependencies by running:

```bash
pip install -e ".[pi]"
```

> [!NOTE]
> For lerobot 0.4.0, if you want to install the `pi` extra, you will have to run: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
>
> This will be fixed in the next patch release.

## Training a Custom FAST Tokenizer

You have two options for the FAST tokenizer:

1. **Use the pre-trained tokenizer**: The `physical-intelligence/fast` tokenizer was trained on 1M+ real robot action sequences and works as a general-purpose tokenizer.

2. **Train your own tokenizer**: For maximum performance on your specific dataset, you can fine-tune the tokenizer on your own data.

### Training Your Own Tokenizer

```bash
python src/lerobot/policies/pi0_fast/train_fast_tokenizer.py \
    --repo_id "user/my-lerobot-dataset" \
    --action_horizon 10 \
    --encoded_dims "0:6" \
    --vocab_size 1024 \
    --scale 10.0 \
    --normalization_mode QUANTILES \
    --output_dir "./my_fast_tokenizer" \
    --push_to_hub \
    --hub_repo_id "username/my-action-tokenizer"
```

### Key Tokenizer Parameters

| Parameter              | Description                                                                        | Default      |
| ---------------------- | ---------------------------------------------------------------------------------- | ------------ |
| `--repo_id`            | LeRobot dataset repository ID                                                       | Required     |
| `--action_horizon`     | Number of future actions in each chunk                                              | `10`         |
| `--encoded_dims`       | Comma-separated dimension ranges to encode (e.g., `"0:6,7:23"`)                     | `"0:6,7:23"` |
| `--vocab_size`         | BPE vocabulary size                                                                 | `1024`       |
| `--scale`              | DCT scaling factor for quantization                                                 | `10.0`       |
| `--normalization_mode` | Normalization mode (`MEAN_STD`, `MIN_MAX`, `QUANTILES`, `QUANTILE10`, `IDENTITY`)   | `QUANTILES`  |
| `--sample_fraction`    | Fraction of chunks to sample per episode                                            | `0.1`        |


## Usage

To use π₀-FAST in LeRobot, set the policy type in your command:

```bash
--policy.type=pi0_fast
```

## Training

For training π₀-FAST, you can use the LeRobot training script:

```bash
python src/lerobot/scripts/lerobot_train.py \
    --dataset.repo_id=your_dataset \
    --policy.type=pi0_fast \
    --output_dir=./outputs/pi0fast_training \
    --job_name=pi0fast_training \
    --policy.pretrained_path=lerobot/pi0_fast_base \
    --policy.dtype=bfloat16 \
    --policy.gradient_checkpointing=true \
    --policy.chunk_size=10 \
    --policy.n_action_steps=10 \
    --policy.max_action_tokens=256 \
    --steps=100000 \
    --batch_size=4 \
    --policy.device=cuda
```

### Key Training Parameters

| Parameter                               | Description                                          | Default                      |
| --------------------------------------- | ---------------------------------------------------- | ---------------------------- |
| `--policy.gradient_checkpointing=true`  | Reduces memory usage significantly during training   | `false`                      |
| `--policy.dtype=bfloat16`               | Use mixed precision training for efficiency          | `float32`                    |
| `--policy.chunk_size`                   | Number of action steps to predict (action horizon)   | `50`                         |
| `--policy.n_action_steps`               | Number of action steps to execute                    | `50`                         |
| `--policy.max_action_tokens`            | Maximum number of FAST tokens per action chunk       | `256`                        |
| `--policy.action_tokenizer_name`        | FAST tokenizer to use                                 | `physical-intelligence/fast` |
| `--policy.compile_model=true`           | Enable torch.compile for faster training              | `false`                      |

## Inference

### KV-Caching for Fast Inference

π₀-FAST supports **KV-caching**, a widely used optimization in LLM inference. It caches the attention keys and values of already-processed tokens, so they are not recomputed at every step of autoregressive decoding.

```bash
# KV-caching is enabled by default
--policy.use_kv_cache=true
```
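
For intuition, here is a toy sketch of the KV-cache idea (not LeRobot's implementation, and the tensor sizes are arbitrary assumptions): at each decoding step only the newest token's key and value are computed and appended to a cache, and attention runs over the cached history.

```python
# Toy illustration of KV-caching in autoregressive decoding (not LeRobot's code).
import torch

def attend(q_new, k_cache, v_cache):
    # q_new: (1, d); k_cache, v_cache: (t, d) hold keys/values of all previous tokens.
    scores = (q_new @ k_cache.T) / k_cache.shape[-1] ** 0.5
    return torch.softmax(scores, dim=-1) @ v_cache

d = 8
k_cache = torch.empty(0, d)
v_cache = torch.empty(0, d)
for step in range(4):                          # pretend to decode 4 action tokens
    q, k, v = (torch.randn(1, d) for _ in range(3))
    k_cache = torch.cat([k_cache, k])          # append instead of recomputing all keys
    v_cache = torch.cat([v_cache, v])          # same for values
    out = attend(q, k_cache, v_cache)          # attention over the cached history only
```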

### Inference Example

```python
from lerobot.policies.pi0_fast import PI0FastPolicy

# Load the policy
policy = PI0FastPolicy.from_pretrained("your-model-path")

# During inference: `batch` is a dict of model inputs
# (e.g. camera images, robot state, and the task instruction)
actions = policy.predict_action_chunk(batch)
```

## Model Architecture

π₀-FAST uses a PaliGemma-based architecture:

- **Vision Encoder**: SigLIP vision tower for image understanding
- **Language Model**: Gemma 2B for processing language instructions and predicting action tokens

The model takes images, text instructions, and robot state as input, and outputs discrete FAST tokens that are decoded back to continuous actions.
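
As a sketch of that token interface, the snippet below loads the pre-trained `physical-intelligence/fast` tokenizer from the Hub and round-trips a dummy action chunk. It is based on the tokenizer's Hub page (see References); the call signatures, especially for `decode`, are assumptions and should be checked against that page.

```python
import numpy as np
from transformers import AutoProcessor

# Pre-trained universal FAST tokenizer (ships custom code on the Hub).
fast = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True)

# One normalized action chunk: (batch, horizon, action_dim), values roughly in [-1, 1].
actions = np.random.uniform(-1, 1, size=(1, 10, 6))

tokens = fast(actions)           # encode: continuous actions -> FAST token ids
recovered = fast.decode(tokens)  # decode back to (approximate) continuous actions
```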

## Configuration Options

| Parameter            | Description                                      | Default    |
| -------------------- | ------------------------------------------------ | ---------- |
| `paligemma_variant`  | VLM backbone variant (`gemma_300m`, `gemma_2b`)  | `gemma_2b` |
| `max_state_dim`      | Maximum state vector dimension (padded)          | `32`       |
| `max_action_dim`     | Maximum action vector dimension (padded)         | `32`       |
| `temperature`        | Sampling temperature (0.0 for greedy)            | `0.0`      |
| `max_decoding_steps` | Maximum decoding steps                           | `256`      |
| `use_kv_cache`       | Enable KV caching for faster inference           | `true`     |
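
A minimal sketch of overriding these options programmatically, assuming they map one-to-one onto `PI0FastConfig` fields (the field names below mirror the table above and are not verified against the implementation):

```python
from lerobot.policies.pi0_fast import PI0FastConfig

# Assumed field names, mirroring the configuration table above.
config = PI0FastConfig(
    paligemma_variant="gemma_2b",
    max_state_dim=32,
    max_action_dim=32,
    temperature=0.0,
    max_decoding_steps=256,
    use_kv_cache=True,
)
```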

## Comparison with π₀

| Feature               | π₀                        | π₀-FAST                      |
| --------------------- | ------------------------- | ---------------------------- |
| Action Representation | Flow Matching (Diffusion) | Autoregressive Tokens (FAST) |
| Training Speed        | 1x                        | **Up to 5x faster**          |
| Dexterity             | High                      | High                         |
| Inference Method      | Iterative Denoising       | Autoregressive Decoding      |
| KV-Caching            | N/A                       | Supported                    |

## License

This model follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).

## References

- [FAST: Efficient Robot Action Tokenization](https://www.physicalintelligence.company/research/fast) - Physical Intelligence Blog
- [OpenPI Repository](https://github.com/Physical-Intelligence/openpi) - Original implementation
- [FAST Tokenizer on Hugging Face](https://huggingface.co/physical-intelligence/fast) - Pre-trained tokenizer

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -127,7 +127,7 @@ wallx = [
     "torchdiffeq==0.2.5",
     "qwen_vl_utils==0.0.11"
 ]
-pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"]
+pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi", "scipy>=1.10.1,<1.15"]
 smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
 multi_task_dit = ["lerobot[transformers-dep]"]
 groot = [

src/lerobot/policies/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -17,6 +17,7 @@
 from .groot.configuration_groot import GrootConfig as GrootConfig
 from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
 from .pi0.configuration_pi0 import PI0Config as PI0Config
+from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
 from .pi05.configuration_pi05 import PI05Config as PI05Config
 from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
 from .smolvla.processor_smolvla import SmolVLANewLineProcessor
@@ -31,6 +32,7 @@
     "MultiTaskDiTConfig",
     "PI0Config",
     "PI05Config",
+    "PI0FastConfig",
     "SmolVLAConfig",
     "SARMConfig",
     "TDMPCConfig",

src/lerobot/policies/factory.py

Lines changed: 4 additions & 0 deletions
@@ -95,6 +95,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
         from lerobot.policies.pi0.modeling_pi0 import PI0Policy

         return PI0Policy
+    elif name == "pi0_fast":
+        from lerobot.policies.pi0_fast.modeling_pi0_fast import PI0FastPolicy
+
+        return PI0FastPolicy
     elif name == "pi05":
         from lerobot.policies.pi05.modeling_pi05 import PI05Policy
src/lerobot/policies/pi0_fast/__init__.py

Lines changed: 21 additions & 0 deletions

@@ -0,0 +1,21 @@
#!/usr/bin/env python

# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .configuration_pi0_fast import PI0FastConfig
from .modeling_pi0_fast import PI0FastPolicy
from .processor_pi0_fast import make_pi0_fast_pre_post_processors

__all__ = ["PI0FastConfig", "PI0FastPolicy", "make_pi0_fast_pre_post_processors"]
