Feature/add multitask diffusion transformer policy implementation #2545
Open: brysonjones wants to merge 54 commits into huggingface:main from brysonjones:feature/add-multitask-dit
Commits (54), all authored by brysonjones:

- `14a7a4d` Add multitask diffusion transformer policy
- `ab97d5c` Merge branch 'main' into feature/add-multitask-dit
- `8b9fada` expand the observation encoder to support differnt size encoders for …
- `34499cb` Merge branch 'main' into feature/add-multitask-dit
- `a0d5a08` Merge branch 'main' into feature/add-multitask-dit
- `46ebcc2` add RoPE attention module as this is shown to help training dynamics …
- `22714af` Merge branch 'main' into feature/add-multitask-dit
- `55e19ff` update readme and citations for multitask dit policy
- `adabb37` remove dino vision encoder and simplify text and vision encoders by r…
- `6f85601` adjust factory comment
- `cdacc09` update docstring for multitask dit policy processor file
- `103230c` simplify config for multitask dit by merging and flattening everythin…
- `b92dc82` add references to the modeling file comments
- `3b2a4f5` merge all modules files into the main modeling file
- `3a16a00` add torch.no_grad decorators
- `5524a0d` split up select action return statement
- `10cfc17` remove redundant asserts
- `f1ac454` add tutorial to training with multi_task_dit
- `d49d339` Merge branch 'main' into feature/add-multitask-dit
- `ba968e8` fix bugs when testing on hardware
- `86e0ee7` remove environment state conditioning
- `67b1a9e` update typo in test instruction comment
- `56dbeed` add processor tests to multitask dit tests
- `9b47c5f` move policy to top of file
- `c398a14` use constants for indexing into batches and remove env state references
- `f3823e8` remove the base classes since we don't need to be able to extend
- `dd4ef13` fix nit formatting in generate actions fcn
- `43c335d` reformat and clean up tutorial for multitask dit policy
- `8e3a1e8` add more descriptions and depth to multitask dit tutorial
- `1f74982` note origins of each training objective
- `51dfee4` rename config param for multiple vision encoders
- `71f359c` refactor code to perform task tokenization in the processor instead o…
- `e4a1b27` Merge branch 'main' into feature/add-multitask-dit
- `a632dd3` Merge branch 'main' into feature/add-multitask-dit
- `4eda54c` Merge branch 'main' into feature/add-multitask-dit
- `23382c0` add multitask dit to toc for docs
- `534e143` add conditional transformers import to match all other policies that …
- `afe2c4d` Merge branch 'main' into feature/add-multitask-dit
- `1e049fb` add test handling for multitask dit when transformers isnt available
- `25ecd16` skip tests without transformers
- `8a2f5aa` remove cropping of images smaller than the crop size
- `b575632` Merge branch 'main' into feature/add-multitask-dit
- `2128dec` Merge branch 'main' into feature/add-multitask-dit
- `5b9f981` Merge branch 'main' into feature/add-multitask-dit
- `77dbc95` Merge branch 'main' into feature/add-multitask-dit
- `632c778` Merge branch 'main' into feature/add-multitask-dit
- `2b90763` Merge branch 'main' into feature/add-multitask-dit
- `d653f96` Merge branch 'main' into feature/add-multitask-dit
- `3e5f31e` Merge branch 'main' into feature/add-multitask-dit
- `2a3444a` Merge branch 'main' into feature/add-multitask-dit
- `e2b47a1` Merge branch 'main' into feature/add-multitask-dit
- `f5f9833` add kwargs arg to multitask dit constructor
- `634e392` Merge branch 'main' into feature/add-multitask-dit
- `8755bd0` Merge branch 'main' into feature/add-multitask-dit
# Multi-Task DiT Policy

The Multi-Task Diffusion Transformer (DiT) Policy is an evolution of the original Diffusion Policy architecture. It leverages a large DiT with text and vision conditioning for multi-task robot learning. This implementation supports both diffusion and flow matching objectives for action generation, enabling robots to perform diverse manipulation tasks conditioned on language instructions.

## Model Overview

The model uses:

- **CLIP Vision Encoder**: Processes RGB images from multiple camera views
- **CLIP Text Encoder**: Encodes language task instructions (frozen weights with learnable projection)
- **Diffusion Transformer**: Predicts action sequences conditioned on observations and language
- **Two Objectives**: Supports both diffusion (DDPM/DDIM) and flow matching for action generation

This model is exciting because you can achieve extremely high dexterity, competitive with multi-billion-parameter VLAs, with only ~450M parameters and significantly less training.
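For intuition, here is a minimal, hypothetical sketch of how this conditioning could be assembled with the HuggingFace Transformers CLIP classes. The variable names, the 512-wide projection, and the freezing choices mirror the description above, but they are illustrative rather than the policy's actual internals:

```python
import torch
import torch.nn as nn
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel

# Illustrative conditioning path (not the exact implementation): CLIP encoders
# produce per-camera image features and a text feature, and a small learnable
# projection maps the text feature into the DiT width.
vision_encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch16")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch16")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch16")

# The text encoder stays frozen; a learnable projection adapts its output.
# (The vision encoder is trained, at a reduced LR, per the settings below.)
for p in text_encoder.parameters():
    p.requires_grad_(False)

hidden_dim = 512  # DiT width, matching the default config shown later
text_proj = nn.Linear(text_encoder.config.hidden_size, hidden_dim)

images = torch.randn(2, 3, 224, 224)  # one camera view, batch of 2 (dummy data)
tokens = tokenizer(["pick up the red block", "open the drawer"],
                   padding=True, return_tensors="pt")

img_feat = vision_encoder(pixel_values=images).pooler_output   # (2, 768)
txt_feat = text_proj(text_encoder(**tokens).pooler_output)     # (2, 512)

# These features, together with proprioceptive state, would condition the DiT
# that denoises an action sequence of shape (batch, horizon, action_dim).
print(img_feat.shape, txt_feat.shape)
```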
## Installation Requirements

Multi-Task DiT Policy has additional dependencies. Install them with:

```bash
pip install lerobot[multi_task_dit]
```

This will install all necessary dependencies, including the HuggingFace Transformers library for the CLIP models.

## Usage

To use Multi-Task DiT in your LeRobot configuration, specify the policy type as:

```python
policy.type=multi_task_dit
```
## Training

### Basic Training Command

Here's a complete training command for Multi-Task DiT on your dataset:

```bash
lerobot-train \
  --dataset.repo_id=YOUR_DATASET \
  --output_dir=./outputs/multitask_dit_training \
  --batch_size=32 \
  --steps=5000 \
  --save_freq=500 \
  --log_freq=100 \
  --policy.type=multi_task_dit \
  --policy.device=cuda \
  --policy.repo_id="HF_USER/multitask-dit-your-robot" \
  --wandb.enable=true
```
### Recommended Hyperparameters and Dataset Details (30 Hz Control Frequency)

For reliable performance, start with these suggested default hyperparameters:

```bash
lerobot-train \
  --dataset.repo_id=YOUR_DATASET \
  --output_dir=./outputs/multitask_dit_training \
  --batch_size=320 \
  --steps=30000 \
  --policy.type=multi_task_dit \
  --policy.device=cuda \
  --policy.horizon=32 \
  --policy.n_action_steps=24 \
  --policy.objective=diffusion \
  --policy.noise_scheduler_type=DDPM \
  --policy.num_train_timesteps=100 \
  --policy.repo_id="HF_USER/multitask-dit-your-robot" \
  --wandb.enable=true
```

**Key Parameters:**

- **Batch Size**: 192-320 - if you have access to a GPU that can support this, you will get the best training dynamics
- **Horizon**: 32 - number of action steps to predict, ~1.0 s at 30 Hz
- **n_action_steps**: 24 - ~0.8 s at 30 Hz
- **Objective**: `diffusion` - start with diffusion and experiment with flow matching if generation quality is poor
- **Training Steps**: >30k steps recommended for a single task
### Training Configuration Parameters

#### Objective Selection

Choose between diffusion and flow matching:

```bash
# Diffusion objective (default)
--policy.objective=diffusion \
--policy.noise_scheduler_type=DDPM \  # or "DDIM"
--policy.num_train_timesteps=100 \
--policy.num_inference_steps=10 \  # For faster inference
--policy.beta_schedule=squaredcos_cap_v2 \  # Noise schedule type
--policy.prediction_type=epsilon \  # "epsilon" (predict noise) or "sample" (predict clean)
--policy.clip_sample=true \  # Clip samples during denoising
--policy.clip_sample_range=1.0  # Clipping range [-x, x]

# Flow matching objective
--policy.objective=flow_matching \
--policy.timestep_sampling_strategy=beta \  # or "uniform"; beta sampling appears much better in practice
--policy.num_integration_steps=100 \
--policy.integration_method=euler \  # or "rk4"
--policy.sigma_min=0.0  # Minimum noise in the flow interpolation path
```
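To make the two objectives concrete, here is a hedged PyTorch sketch of the regression targets each one produces, following the standard DDPM epsilon-prediction and rectified-flow conventions. The function names and the stand-in noise schedule are illustrative and not the policy's internal API:

```python
import torch

def diffusion_target(actions, noise, alphas_cumprod, t):
    """DDPM-style 'epsilon' objective: corrupt the action sequence and
    train the network to predict the noise that was added."""
    a_bar = alphas_cumprod[t].view(-1, 1, 1)                    # (B, 1, 1)
    noisy_actions = a_bar.sqrt() * actions + (1 - a_bar).sqrt() * noise
    return noisy_actions, noise                                 # model input, regression target

def flow_matching_target(actions, noise, t, sigma_min=0.0):
    """Flow matching with a linear (rectified-flow style) path from noise
    (t=0) to data (t=1); the target is the constant velocity along the path."""
    t = t.view(-1, 1, 1)                                        # (B, 1, 1), t in [0, 1]
    x_t = (1 - (1 - sigma_min) * t) * noise + t * actions
    velocity = actions - (1 - sigma_min) * noise
    return x_t, velocity

# Toy usage: batch of 4 sequences, horizon 32, 6-dim actions.
actions = torch.randn(4, 32, 6)
noise = torch.randn_like(actions)
alphas_cumprod = torch.linspace(0.999, 0.01, 100)               # stand-in schedule
x_noisy, eps_target = diffusion_target(actions, noise, alphas_cumprod,
                                        torch.randint(0, 100, (4,)))
x_t, v_target = flow_matching_target(actions, noise, torch.rand(4))
```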
#### Transformer Architecture

Adjust model capacity based on dataset size:

```bash
# Small datasets (< 100 examples)
--policy.num_layers=4 \
--policy.hidden_dim=512 \
--policy.num_heads=8  # should ideally be hidden_dim // 64

# Medium datasets (100-5k examples) - default
--policy.num_layers=6 \
--policy.hidden_dim=512 \
--policy.num_heads=8  # should ideally be hidden_dim // 64

# Large datasets (> 5k examples)
--policy.num_layers=8 \
--policy.hidden_dim=512 \
--policy.num_heads=8  # should ideally be hidden_dim // 64
```
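As a rough sanity check when picking these values, the DiT core's size can be estimated with the usual transformer rule of thumb of about 12 * hidden_dim^2 parameters per layer, assuming standard blocks with a 4x MLP expansion. This deliberately excludes the CLIP encoders, projections, and timestep conditioning, so treat it as an order-of-magnitude estimate only:

```python
# Back-of-the-envelope DiT parameter count (excluding encoders, projections,
# and timestep/conditioning machinery), assuming standard transformer blocks
# with a 4x MLP expansion: roughly 12 * hidden_dim^2 parameters per layer.
def approx_dit_params(num_layers: int, hidden_dim: int) -> int:
    attn = 4 * hidden_dim * hidden_dim   # q, k, v, and output projections
    mlp = 8 * hidden_dim * hidden_dim    # two linear layers with 4x expansion
    return num_layers * (attn + mlp)

for layers in (4, 6, 8):
    print(layers, f"{approx_dit_params(layers, 512) / 1e6:.1f}M")  # 12.6M, 18.9M, 25.2M
```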
**Positional Encoding Options:**

The model supports two positional encoding methods for action sequences:

```bash
# Rotary Position Embedding (RoPE) - default, recommended
--policy.use_rope=true \
--policy.rope_base=10000.0  # Base frequency for RoPE

# Absolute positional encoding
--policy.use_positional_encoding=true  # Disables RoPE when true
```
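For reference on what `rope_base` controls, here is a compact, generic rotary-embedding sketch (not necessarily the exact module used in this policy); the base sets the geometric spacing of the rotation frequencies applied across the action sequence:

```python
import torch

def rope_rotate(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Apply a generic rotary position embedding to x of shape (batch, seq, dim).
    Channel pairs are rotated by position-dependent angles whose frequencies
    are geometrically spaced according to `base`."""
    b, s, d = x.shape
    half = d // 2
    freqs = base ** (-torch.arange(0, half, dtype=torch.float32) / half)      # (half,)
    angles = torch.arange(s, dtype=torch.float32)[:, None] * freqs[None, :]   # (seq, half)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

q = torch.randn(2, 32, 64)   # e.g. one attention head over a 32-step horizon
q_rot = rope_rotate(q)       # queries and keys are rotated the same way
print(q_rot.shape)
```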
**Other Transformer Parameters:**

```bash
--policy.dropout=0.1  # Dropout rate for DiT blocks (0.0-1.0)
--policy.timestep_embed_dim=256  # Timestep embedding dimension
```

#### Vision Encoder Configuration

```bash
# Use a different CLIP model for more expressivity at the cost of inference time.
# Experiment with larger or smaller models depending on the complexity of your tasks and the size of your dataset.
--policy.vision_encoder_name=openai/clip-vit-large-patch14

# Use a separate vision encoder per camera.
# This may be useful when cameras have significantly different characteristics, but
# be wary of the increased VRAM footprint.
--policy.use_separate_rgb_encoder_per_camera=true

# Image preprocessing
--policy.image_resize_shape=[XXX,YYY] \  # you may need to resize your images for inference speed-ups
--policy.image_crop_shape=[224,224] \
--policy.image_crop_is_random=true  # Random crop during training, center crop at inference
```
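A hedged torchvision sketch of the resize-then-crop pipeline these options describe; the resize shape below is a placeholder (the original leaves it as `[XXX,YYY]`), and the real policy also applies its own normalization:

```python
import torch
import torchvision.transforms as T

# Illustrative resize + crop pipeline. Random cropping acts as augmentation
# during training; at inference a deterministic center crop is used instead.
resize_hw = (240, 320)   # placeholder stand-in for --policy.image_resize_shape
crop_hw = (224, 224)     # --policy.image_crop_shape

train_tf = T.Compose([T.Resize(resize_hw), T.RandomCrop(crop_hw)])
eval_tf = T.Compose([T.Resize(resize_hw), T.CenterCrop(crop_hw)])

frame = torch.rand(3, 480, 640)                       # dummy camera frame (C, H, W)
print(train_tf(frame).shape, eval_tf(frame).shape)    # both torch.Size([3, 224, 224])
```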
#### Text Encoder Configuration

```bash
# Use a different CLIP text encoder model.
# As with the vision encoder: experiment with larger or smaller models depending on the
# complexity of your tasks and the size of your dataset.
--policy.text_encoder_name=openai/clip-vit-large-patch14
```

#### Learning Rate Configuration

The vision encoder uses a separate learning rate multiplier, with 1/10th of the base learning rate suggested as the ideal starting point:

```bash
--policy.optimizer_lr=2e-5 \
--policy.vision_encoder_lr_multiplier=0.1  # Vision encoder LR = 0.1 * optimizer_lr
```
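This kind of per-module multiplier is typically implemented with optimizer parameter groups; the sketch below uses stand-in modules rather than the policy's real attributes:

```python
import torch
import torch.nn as nn

# Hypothetical sketch of a vision-encoder LR multiplier: one AdamW parameter
# group per module, with the encoder group's LR scaled down. The two nn.Linear
# stand-ins take the place of the real CLIP encoder and DiT backbone.
vision_encoder = nn.Linear(768, 512)   # stand-in for the CLIP vision encoder
transformer = nn.Linear(512, 512)      # stand-in for the DiT backbone

base_lr = 2e-5                         # --policy.optimizer_lr
multiplier = 0.1                       # --policy.vision_encoder_lr_multiplier

optimizer = torch.optim.AdamW([
    {"params": transformer.parameters(), "lr": base_lr},
    {"params": vision_encoder.parameters(), "lr": base_lr * multiplier},
])
print([g["lr"] for g in optimizer.param_groups])   # [2e-05, 2e-06]
```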
### Training Tuning Guidelines

#### 1. Flow Matching with Beta Sampling

The original diffusion implementation here is based on the work described in [TRI's LBM paper](https://arxiv.org/abs/2507.05331).

Additionally, we have implemented a flow-matching objective, which is described at a high level in this [Boston Dynamics blog post](https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/).

Consider testing the flow-matching objective and evaluating the performance differences for your task:

```bash
--policy.objective=flow_matching \
--policy.timestep_sampling_strategy=beta \
--policy.timestep_sampling_alpha=1.5 \
--policy.timestep_sampling_beta=1.0 \
--policy.timestep_sampling_s=0.999
```

This hasn't been shown to be a silver bullet across every use case, but it occasionally results in smoother and more consistent actions.
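For intuition, here is one common way such a beta sampling strategy is realized with these parameters; the exact formulation (and the orientation of the timestep axis) inside the policy may differ:

```python
import torch
from torch.distributions import Beta

# Hedged sketch of a "beta" timestep-sampling strategy: instead of drawing t
# uniformly in [0, 1), draw it from a Beta distribution and cap it at s. With
# alpha=1.5, beta=1.0 the density is proportional to sqrt(t), so training
# signal is concentrated near one end of the interpolation path rather than
# spread uniformly. (The policy's exact convention for t may differ.)
alpha, beta, s = 1.5, 1.0, 0.999

def sample_timesteps(batch_size: int) -> torch.Tensor:
    return s * Beta(alpha, beta).sample((batch_size,))

print(sample_timesteps(8))   # values in (0, s), skewed toward s
```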
#### 2. Number of Transformer Layers

Match model capacity to your dataset size:

- **Small datasets** (< 100 examples): Reduce to 4 layers
- **Large datasets** (> 5k examples): Increase to 8 layers

#### 3. `horizon` Tuning

The model can be sensitive to the horizon you choose. Start with around a 1-second horizon based on your control frequency:

- **30 Hz frequency**: `horizon=30`
- **10 Hz frequency**: `horizon=10`

Then experiment with increasing it from there. The horizon determines how far into the future the model predicts actions.

#### 4. `n_action_steps` Sensitivity

The model can also be very sensitive to `n_action_steps`. Start with roughly 0.8 seconds based on your control frequency and tune from there (a quick way to compute both values is shown after the list):

- **Lower values**: More reactive, but potentially less stable for long-horizon tasks
- **Higher values**: Better for long-horizon execution, but recovery from open-loop failures is limited
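A quick way to turn your control frequency into starting values, matching the ~1.0 s / ~0.8 s rules of thumb above:

```python
# Derive starting values for horizon and n_action_steps from the control
# frequency, following the ~1.0 s and ~0.8 s rules of thumb.
def starting_values(control_hz: float) -> tuple[int, int]:
    horizon = round(1.0 * control_hz)          # ~1.0 s of predicted actions
    n_action_steps = round(0.8 * control_hz)   # ~0.8 s executed before replanning
    return horizon, n_action_steps

print(starting_values(30))   # (30, 24)
print(starting_values(10))   # (10, 8)
```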
### Inference Tuning

For faster inference, use DDIM with fewer sampling steps:

```bash
--policy.noise_scheduler_type=DDIM \
--policy.num_inference_steps=10
```
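As a rough illustration of what these two flags change, here is a standalone sketch using the diffusers `DDIMScheduler`; whether Multi-Task DiT uses this exact class internally is an assumption here, and the model call below is a stand-in, not the policy's API:

```python
import torch
from diffusers import DDIMScheduler

# Standalone sketch of DDIM with few inference steps: train with 100
# timesteps, then denoise in only 10 strides at inference time.
scheduler = DDIMScheduler(num_train_timesteps=100,
                          beta_schedule="squaredcos_cap_v2",
                          prediction_type="epsilon",
                          clip_sample=True)
scheduler.set_timesteps(num_inference_steps=10)

sample = torch.randn(1, 32, 6)                 # (batch, horizon, action_dim)
for t in scheduler.timesteps:                  # 10 denoising steps instead of 100
    noise_pred = torch.zeros_like(sample)      # stand-in for the DiT's noise prediction
    sample = scheduler.step(noise_pred, t, sample).prev_sample
print(sample.shape)
```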
### Resuming Training

To resume training from a checkpoint:

```bash
lerobot-train \
  --config_path=./outputs/multitask_dit_training/checkpoints/last/pretrained_model/train_config.json \
  --resume=true
```

The checkpoint directory should contain `model.safetensors` and `config.json` files (saved automatically during training). When resuming, the configuration is loaded from the checkpoint, so you don't need to specify other parameters.
## Common Failure Modes and Debugging

Training these models can be finicky. Here are common failure modes and debugging approaches:

### Idling / No Motion

The model may "collapse" during inference, resulting in static or no motion. This can occur when:

1. **Insufficient training data**: If you only have 20-50 examples, try to roughly double your dataset size. If you're still seeing this once you have more than 300 examples, the task may be too complex.

2. **Multiple similar tasks**: When your dataset contains multiple similar tasks (e.g., picking up 2 different objects), the model may rely too heavily on language conditioning, which might not be rich enough.

**Debugging tips:**

- Increase dataset size (double it until you get over 300 examples)
- Train for longer, up to 100k steps, even when the loss flatlines
- Check that the model is receiving proper language instructions, or increase the diversity of instructions

### Executing the Wrong Task

Sometimes the robot will completely ignore your instruction and perform some other task. This generally only happens if you have trained on multiple tasks.

**Potential causes:**

- Language instruction ambiguity
- Insufficient task-specific training data
- Model confusion between similar tasks in the multitask dataset

**Debugging tips:**

- Verify language instruction specificity, especially if descriptions are similar between multiple tasks
- Check the task distribution in your training dataset and add weighting to the failing/ignored task
- Consider task-specific fine-tuning

### Training Instability

If training loss is unstable or diverging:

- Try adjusting the learning rate between `1e-5` and `3e-4`
- Increase the batch size if possible
- Check that your dataset normalization is correct
- Verify that image preprocessing is working correctly
## Performance Considerations

### GPU Requirements

- **Inference**: At least an RTX 5070 Ti (or equivalent GPU) is recommended for reasonable inference speed
- **Training**: A GPU with enough VRAM for batch sizes >64 is ideal; the exact requirement will vary with the number of image observations and other factors

### Batch Size Recommendations

- **Minimum**: 64 (less than this may result in unstable training)
- **Recommended**: 256-320 (best performance, requires a larger GPU)
## Example: Training on a Custom Dataset

Here's a complete example of training on a custom dataset:

```bash
lerobot-train \
  --dataset.repo_id=YOUR_DATASET \
  --output_dir=./outputs/multitask_dit_training \
  --batch_size=320 \
  --steps=30000 \
  --save_freq=1000 \
  --log_freq=100 \
  --eval_freq=1000 \
  --policy.type=multi_task_dit \
  --policy.device=cuda \
  --policy.horizon=32 \
  --policy.n_action_steps=24 \
  --policy.objective=diffusion \
  --policy.noise_scheduler_type=DDPM \
  --policy.num_layers=6 \
  --policy.hidden_dim=512 \
  --policy.vision_encoder_name=openai/clip-vit-base-patch16 \
  --policy.image_resize_shape=[320,240] \
  --policy.image_crop_shape=[224,224] \
  --policy.repo_id="HF_USER/multitask-dit-your-robot" \
  --wandb.enable=true \
  --wandb.project=multitask_dit
```
## References

For more details on the technical implementation and architecture, see:

- [A Careful Examination of Large Behavior Models for Multitask Dexterous Manipulation](https://arxiv.org/abs/2507.05331)
- [Large Behavior Models and Atlas Find New Footing](https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/)
- [Dissecting and Open-Sourcing Multitask Diffusion Transformer Policy](https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy)