Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
14a7a4d
Add multitask diffusion transformer policy
brysonjones Nov 13, 2025
ab97d5c
Merge branch 'main' into feature/add-multitask-dit
brysonjones Nov 13, 2025
8b9fada
expand the observation encoder to support differnt size encoders for …
brysonjones Nov 21, 2025
34499cb
Merge branch 'main' into feature/add-multitask-dit
brysonjones Nov 29, 2025
a0d5a08
Merge branch 'main' into feature/add-multitask-dit
brysonjones Dec 9, 2025
46ebcc2
add RoPE attention module as this is shown to help training dynamics …
brysonjones Dec 9, 2025
22714af
Merge branch 'main' into feature/add-multitask-dit
brysonjones Dec 10, 2025
55e19ff
update readme and citations for multitask dit policy
brysonjones Dec 10, 2025
adabb37
remove dino vision encoder and simplify text and vision encoders by r…
brysonjones Dec 10, 2025
6f85601
adjust factory comment
brysonjones Dec 10, 2025
cdacc09
update docstring for multitask dit policy processor file
brysonjones Dec 10, 2025
103230c
simplify config for multitask dit by merging and flattening everythin…
brysonjones Dec 10, 2025
b92dc82
add references to the modeling file comments
brysonjones Dec 10, 2025
3b2a4f5
merge all modules files into the main modeling file
brysonjones Dec 10, 2025
3a16a00
add torch.no_grad decorators
brysonjones Dec 10, 2025
5524a0d
split up select action return statement
brysonjones Dec 10, 2025
10cfc17
remove redundant asserts
brysonjones Dec 10, 2025
f1ac454
add tutorial to training with multi_task_dit
brysonjones Dec 10, 2025
d49d339
Merge branch 'main' into feature/add-multitask-dit
brysonjones Dec 10, 2025
ba968e8
fix bugs when testing on hardware
brysonjones Dec 11, 2025
86e0ee7
remove environment state conditioning
brysonjones Dec 11, 2025
67b1a9e
update typo in test instruction comment
brysonjones Dec 11, 2025
56dbeed
add processor tests to multitask dit tests
brysonjones Dec 11, 2025
9b47c5f
move policy to top of file
brysonjones Dec 11, 2025
c398a14
use constants for indexing into batches and remove env state references
brysonjones Dec 11, 2025
f3823e8
remove the base classes since we don't need to be able to extend
brysonjones Dec 11, 2025
dd4ef13
fix nit formatting in generate actions fcn
brysonjones Dec 11, 2025
43c335d
reformat and clean up tutorial for multitask dit policy
brysonjones Dec 11, 2025
8e3a1e8
add more descriptions and depth to multitask dit tutorial
brysonjones Dec 11, 2025
1f74982
note origins of each training objective
brysonjones Dec 11, 2025
51dfee4
rename config param for multiple vision encoders
brysonjones Dec 11, 2025
71f359c
refactor code to perform task tokenization in the processor instead o…
brysonjones Dec 11, 2025
e4a1b27
Merge branch 'main' into feature/add-multitask-dit
brysonjones Dec 14, 2025
a632dd3
Merge branch 'main' into feature/add-multitask-dit
brysonjones Dec 15, 2025
4eda54c
Merge branch 'main' into feature/add-multitask-dit
brysonjones Dec 15, 2025
23382c0
add multitask dit to toc for docs
brysonjones Dec 15, 2025
534e143
add conditional transformers import to match all other policies that …
brysonjones Dec 15, 2025
afe2c4d
Merge branch 'main' into feature/add-multitask-dit
brysonjones Dec 16, 2025
1e049fb
add test handling for multitask dit when transformers isnt available
brysonjones Dec 16, 2025
25ecd16
skip tests without transformers
brysonjones Dec 16, 2025
8a2f5aa
remove cropping of images smaller than the crop size
brysonjones Dec 16, 2025
b575632
Merge branch 'main' into feature/add-multitask-dit
brysonjones Dec 17, 2025
2128dec
Merge branch 'main' into feature/add-multitask-dit
brysonjones Dec 17, 2025
5b9f981
Merge branch 'main' into feature/add-multitask-dit
brysonjones Dec 18, 2025
77dbc95
Merge branch 'main' into feature/add-multitask-dit
brysonjones Dec 18, 2025
632c778
Merge branch 'main' into feature/add-multitask-dit
brysonjones Dec 18, 2025
2b90763
Merge branch 'main' into feature/add-multitask-dit
brysonjones Dec 20, 2025
d653f96
Merge branch 'main' into feature/add-multitask-dit
brysonjones Dec 22, 2025
3e5f31e
Merge branch 'main' into feature/add-multitask-dit
brysonjones Dec 23, 2025
2a3444a
Merge branch 'main' into feature/add-multitask-dit
brysonjones Dec 23, 2025
e2b47a1
Merge branch 'main' into feature/add-multitask-dit
brysonjones Dec 23, 2025
f5f9833
add kwargs arg to multitask dit constructor
brysonjones Dec 25, 2025
634e392
Merge branch 'main' into feature/add-multitask-dit
brysonjones Dec 25, 2025
8755bd0
Merge branch 'main' into feature/add-multitask-dit
brysonjones Dec 28, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
title: NVIDIA GR00T N1.5
- local: xvla
title: X-VLA
- local: multitask_dit
title: Multi-Task DiT
- local: walloss
title: WALL-OSS
title: "Policies"
Expand Down
340 changes: 340 additions & 0 deletions docs/source/multitask_dit.mdx
Original file line number Diff line number Diff line change
@@ -0,0 +1,340 @@
# Multi-Task DiT Policy

Multi-Task Diffusion Transformer (DiT) Policy is an evolution of the original Diffusion Policy architecture, which leverages a large DiT with text and vision conditioning for multi-task robot learning. This implementation supports both diffusion and flow matching objectives for action generation, enabling robots to perform diverse manipulation tasks conditioned on language instructions.

## Model Overview

The model uses:

- **CLIP Vision Encoder**: Processes RGB images from multiple camera views
- **CLIP Text Encoder**: Encodes language task instructions (frozen weights with learnable projection)
- **Diffusion Transformer**: Predicts action sequences conditioned on observations and language
- **Two Objectives**: Supports both diffusion (DDPM/DDIM) and flow matching for action generation

This model is exciting because you can achieve extremely high dexterity, competitive with multi-billion parameter
VLAs, with only ~450M parameters and significantly less training.

## Installation Requirements

Multi-Task DiT Policy has additional dependencies. Install it with:

```bash
pip install lerobot[multi_task_dit]
```

This will install all necessary dependencies including the HuggingFace Transformers library for CLIP models.

## Usage

To use Multi-Task DiT in your LeRobot configuration, specify the policy type as:

```python
policy.type=multi_task_dit
```

## Training

### Basic Training Command

Here's a complete training command for training Multi-Task DiT on your dataset:

```bash
lerobot-train \
--dataset.repo_id=YOUR_DATASET \
--output_dir=./outputs/multitask_dit_training \
--batch_size=32 \
--steps=5000 \
--save_freq=500 \
--log_freq=100 \
--policy.type=multi_task_dit \
--policy.device=cuda \
--policy.repo_id="HF_USER/multitask-dit-your-robot" \
--wandb.enable=true
```

### Recommended Hyperparameters and Dataset Details (30Hz Control Frequency)

For reliable performance, start with these suggested default hyperparameters:

```bash
lerobot-train \
--dataset.repo_id=YOUR_DATASET \
--output_dir=./outputs/mutitask_dit_training \
--batch_size=320 \
--steps=30000 \
--policy.type=multi_task_dit \
--policy.device=cuda \
--policy.horizon=32 \
--policy.n_action_steps=24 \
--policy.objective=diffusion \
--policy.noise_scheduler_type=DDPM \
--policy.num_train_timesteps=100 \
--policy.repo_id="HF_USER/multitask-dit-your-robot" \
--wandb.enable=true
```

**Key Parameters:**

- **Batch Size**: 192-320 - If you have access to a GPU that can support this, you will get the best training dynamics
- **Horizon**: 32 - number of action steps to predict, ~1.0 sec at 30Hz
- **n_action_steps**: 24 - ~0.8 seconds at 30Hz
- **Objective**: `diffusion` - start with diffusion and experiment with flow matching if generation quality is poor
- **Training Steps**: >30k steps recommended for a single task

### Training Configuration Parameters

#### Objective Selection

Choose between diffusion and flow matching:

```bash
# Diffusion objective (default)
--policy.objective=diffusion \
--policy.noise_scheduler_type=DDPM \ # or "DDIM"
--policy.num_train_timesteps=100 \
--policy.num_inference_steps=10 \ # For faster inference
--policy.beta_schedule=squaredcos_cap_v2 \ # Noise schedule type
--policy.prediction_type=epsilon \ # "epsilon" (predict noise) or "sample" (predict clean)
--policy.clip_sample=true \ # Clip samples during denoising
--policy.clip_sample_range=1.0 # Clipping range [-x, x]

# Flow matching objective
--policy.objective=flow_matching \
--policy.timestep_sampling_strategy=beta \ # or "uniform" | the beta sampling strategy performance appears much better in practice
--policy.num_integration_steps=100 \
--policy.integration_method=euler \ # or "rk4"
--policy.sigma_min=0.0 # Minimum noise in flow interpolation path
```

#### Transformer Architecture

Adjust model capacity based on dataset size:

```bash
# Small datasets (< 100 examples)
--policy.num_layers=4 \
--policy.hidden_dim=512 \
--policy.num_heads=8 # should ideally be hidden_dim // 64

# Medium datasets (100-5k examples) - default
--policy.num_layers=6 \
--policy.hidden_dim=512 \
--policy.num_heads=8 # should ideally be hidden_dim // 64

# Large datasets (> 5k examples)
--policy.num_layers=8 \
--policy.hidden_dim=512 \
--policy.num_heads=8 # should ideally be hidden_dim // 64
```

**Positional Encoding Options:**

The model supports two positional encoding methods for action sequences:

```bash
# Rotary Position Embedding (RoPE) - default, recommended
--policy.use_rope=true \
--policy.rope_base=10000.0 # Base frequency for RoPE

# Absolute positional encoding
--policy.use_positional_encoding=true # Disables RoPE when true
```

**Other Transformer Parameters:**

```bash
--policy.dropout=0.1 # Dropout rate for DiT blocks (0.0-1.0)
--policy.timestep_embed_dim=256 # Timestep embedding dimension
```

#### Vision Encoder Configuration

```bash
# Use different CLIP model for more expressivity at the cost of inference time
# experiment with larger or smaller models depending on the complexity of your tasks and size of dataset
--policy.vision_encoder_name=openai/clip-vit-large-patch14

# Use separate vision encoder per camera
# This may be useful when cameras have significantly different characteristics, but
# be wary of increased VRAM footprint.
--policy.use_separate_rgb_encoder_per_camera=true

# Image preprocessing
--policy.image_resize_shape=[XXX,YYY] \ # you may need to resize your images for inference speed ups
--policy.image_crop_shape=[224,224] \
--policy.image_crop_is_random=true # Random during training, center at inference
```

#### Text Encoder Configuration

```bash
# Use different CLIP text encoder model
# same as vision: experiment with larger or smaller models depending on the
# complexity of your tasks and size of dataset
--policy.text_encoder_name=openai/clip-vit-large-patch14
```

#### Learning Rate Configuration

The vision encoder uses a separate learning rate multiplier, where 1/10th is suggested to be the ideal staritng point:

```bash
--policy.optimizer_lr=2e-5 \
--policy.vision_encoder_lr_multiplier=0.1 # Vision encoder LR = 0.1 * optimizer_lr
```

### Training Tuning Guidelines

#### 1. Flow Matching with Beta Sampling

The original diffusion implementation here is based on the work described in [TRI's LBM paper](https://arxiv.org/abs/2507.05331)

Additionally, we have implemented a flow-matching objective, which is described at a high-level in [Boston Dynamics blog post](https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/).

Consider testing the flow-matching objective and evaluating performance differences for your task:

```bash
--policy.objective=flow_matching \
--policy.timestep_sampling_strategy=beta \
--policy.timestep_sampling_alpha=1.5 \
--policy.timestep_sampling_beta=1.0 \
--policy.timestep_sampling_s=0.999
```

This hasn't been shown to be a silver bullet across every user case, but it occasionally results in smoother and more consistent actions.

#### 2. Number of Transformer Layers

Match model capacity to your dataset size:

- **Small datasets** (< 100 examples): Reduce to 4 layers
- **Large datasets** (> 5k examples): Increase to 8 layers

#### 3. `horizon` Tuning

The model can be sensitive to the horizon you choose. Start with around a 1 second horizon based on your control frequency:

- **30 Hz frequency**: `horizon=30`
- **10 Hz frequency**: `horizon=10`

Then experiment with increasing from there. The horizon determines how far into the future the model predicts actions.

#### 4. `n_action_steps` Sensitivity

The model can also be very sensitive to `n_action_steps`. Start with it being around 0.8 seconds based on your control frequency and tune from there:

- **Lower values**: More reactive but potentially less stable for long-horizon tasks
- **Higher values**: Better for long-horizon execution but open-loop failures are limited in their recovery

### Inference Tuning

For faster inference, use DDIM with fewer sampling steps:

```bash
--policy.noise_scheduler_type=DDIM \
--policy.num_inference_steps=10
```

### Resuming Training

To resume training from a checkpoint:

```bash
lerobot-train \
--config_path=./outputs/mutitask_dit_training/checkpoints/last/pretrained_model/train_config.json \
--resume=true
```

The checkpoint directory should contain `model.safetensors` and `config.json` files (saved automatically during training). When resuming, the configuration is loaded from the checkpoint, so you don't need to specify other parameters.

## Common Failure Modes and Debugging

Training these models can be finicky. Here are common failure modes and debugging approaches:

### Idling / No Motion

The model may "collapse" during inference, resulting in static or no motion. This can occur when:

1. **Insufficient training data**: If you only have 20-50 examples, try to roughly double your dataset size. Once you have above 300 examples, if you're still seeing this, the task may be too complex.

2. **Multiple similar tasks**: When your dataset contains multiple similar tasks (e.g., picking up 2 different objects), the model may rely too heavily on language conditioning which might not be rich enough.

**Debugging tips:**

- Increase dataset size (double until you get to over 300 examples)
- Train for longer, up to 100k steps, even when the loss flatlines
- Check if the model is receiving proper language instructions or increase diversity of instruction

### Executing the Wrong Task

Sometimes the robot will completely ignore your instruction and perform some other task. This generally only happens if you have trained on multiple tasks.

**Potential causes:**

- Language instruction ambiguity
- Insufficient task-specific training data
- Model confusion between similar tasks in the multitask dataset

**Debugging tips:**

- Verify language instruction specificity, especially if descriptions are similar between multiple tasks
- Check task distribution in your training dataset and add weighting to the failing/ignored task
- Consider task-specific fine-tuning

### Training Instability

If training loss is unstable or diverging:

- Try adjusting learning rate between `1e-5` and `3e-4`
- Increase batch size if possible
- Check that your dataset normalization is correct
- Verify image preprocessing is working correctly

## Performance Considerations

### GPU Requirements

- **Inference**: At least an RTX 5070 Ti (or equivalent GPU) is recommended for reasonable speed performance
- **Training**: A GPU with enough VRAM to load batch sizes of >64 is ideal, which will vary depending on the number of image observations, etc

### Batch Size Recommendations

- **Minimum**: 64 (less than this may result in unstable training)
- **Recommended**: 256-320 (best performance, requires larger GPU)

## Example: Training on Custom Dataset

Here's a complete example training on a custom dataset:

```bash
lerobot-train \
--dataset.repo_id=YOUR_DATASET \
--output_dir=./outputs/mutitask_dit_training \
--batch_size=320 \
--steps=30000 \
--save_freq=1000 \
--log_freq=100 \
--eval_freq=1000 \
--policy.type=multi_task_dit \
--policy.device=cuda \
--policy.horizon=32 \
--policy.n_action_steps=24 \
--policy.objective=diffusion \
--policy.noise_scheduler_type=DDPM \
--policy.num_layers=6 \
--policy.hidden_dim=512 \
--policy.vision_encoder_name=openai/clip-vit-base-patch16 \
--policy.image_resize_shape=[320,240] \
--policy.image_crop_shape=[224,224] \
--policy.repo_id="HF_USER/multitask-dit-your-robot" \
--wandb.enable=true \
--wandb.project=multitask_dit
```

## References

For more details on the technical implementation and architecture, see:

- [A Careful Examination of Large Behavior Models for Multitask Dexterous Manipulation](https://arxiv.org/abs/2507.05331)
- [Large Behavior Models and Atlas Find New Footing](https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/)
- [Dissecting and Open-Sourcing Multitask Diffusion Transformer Policy](https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ wallx = [
]
pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
multi_task_dit = ["lerobot[transformers-dep]"]
groot = [
"lerobot[transformers-dep]",
"peft>=0.13.0,<1.0.0",
Expand Down
2 changes: 2 additions & 0 deletions src/lerobot/policies/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .groot.configuration_groot import GrootConfig as GrootConfig
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .pi05.configuration_pi05 import PI05Config as PI05Config
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
Expand All @@ -27,6 +28,7 @@
__all__ = [
"ACTConfig",
"DiffusionConfig",
"MultiTaskDiTConfig",
"PI0Config",
"PI05Config",
"SmolVLAConfig",
Expand Down
Loading