diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index 7766b34725..b59d15734c 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -59,6 +59,8 @@
title: Environments from the Hub
- local: envhub_leisaac
title: Control & Train Robots in Sim (LeIsaac)
+ - local: envhub_isaaclab_arena
+ title: NVIDIA IsaacLab Arena Environments
- local: libero
title: Using Libero
- local: metaworld
diff --git a/docs/source/envhub_isaaclab_arena.mdx b/docs/source/envhub_isaaclab_arena.mdx
new file mode 100644
index 0000000000..f0f1c33ad0
--- /dev/null
+++ b/docs/source/envhub_isaaclab_arena.mdx
@@ -0,0 +1,395 @@
+# NVIDIA IsaacLab Arena & LeRobot
+
+LeRobot EnvHub now supports **GPU-accelerated simulation** with IsaacLab Arena for policy evaluation at scale.
+Train and evaluate imitation learning policies with high-fidelity simulation — all integrated into the LeRobot ecosystem.
+
+
+
+[IsaacLab Arena](https://github.com/isaac-sim/IsaacLab-Arena) integrates with NVIDIA IsaacLab to provide:
+
+- 🤖 **Humanoid embodiments**: GR1, G1, Galileo with various configurations
+- 🎯 **Manipulation & loco-manipulation tasks**: Microwave opening, pick-and-place, button pressing
+- ⚡ **GPU-accelerated rollouts**: Parallel environment execution on NVIDIA GPUs
+- 📦 **LeRobot-compatible datasets**: Ready for training with PI0, SmolVLA, ACT, Diffusion policies
+- 🔄 **EnvHub integration**: Load environments from HuggingFace Hub with one line
+
+## Available Environments
+
+The following environments are currently available in IsaacLab Arena:
+
+| Preview | Environment | Description |
+| :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :----------------- | :------------------------------------------------------------------------------------------------------------------- |
+|
| `gr1_microwave` | Reach out to the microwave and open it. |
+|
| `galileo_pnp` | Pick objects and place in target location |
+|
| `g1_locomanip_pnp` | Pick up the brown box from the shelf, and place it into the blue bin on the table located at the right of the shelf. |
+|
| `kitchen_pnp` | Kitchen object manipulation tasks |
+|
| `press_button` | Locate and press button |
+
+## Quick Start
+
+### Load Environment from Hub
+
+```python
+from lerobot.envs.factory import make_env
+
+# Load IsaacLab Arena environment from the Hub
+envs_dict = make_env(
+ "nvkartik/isaaclab-arena-envs",
+ n_envs=4,
+ trust_remote_code=True,
+ environment="gr1_microwave",
+ embodiment="gr1_pink",
+ headless=True,
+ enable_cameras=True,
+)
+
+# Access the environment
+suite_name = next(iter(envs_dict))
+env = envs_dict[suite_name][0]
+
+# Run a simple episode
+obs, info = env.reset()
+for _ in range(100):
+ action = env.action_space.sample()
+ obs, reward, terminated, truncated, info = env.step(action)
+ if terminated.any() or truncated.any():
+ obs, info = env.reset()
+
+env.close()
+```
+
+### Using the Native Configuration
+
+For CLI integration and streamlined configuration, use the `isaaclab_arena` environment type:
+
+```bash
+# Evaluate a trained policy
+lerobot-eval \
+ --policy.path=nvkartik/smolvla-arena-gr1-microwave \
+ --env.type=isaaclab_arena \
+ --env.environment=gr1_microwave \
+ --env.embodiment=gr1_pink \
+ --env.num_envs=1 \
+ --env.headless=true \
+ --env.enable_cameras=true \
+ --policy.device=cuda
+```
+
+## Installation
+
+### Prerequisites
+
+- NVIDIA GPU with CUDA support
+- NVIDIA driver compatible with IsaacSim 5.1.0
+- Linux (Ubuntu 22.04 recommended)
+
+### Setup
+
+```bash
+# 1. Create conda environment
+conda create -y -n lerobot-arena python=3.11
+conda activate lerobot-arena
+conda install -y -c conda-forge ffmpeg=7.1.1
+
+# 2. Install Isaac Sim 5.1.0
+pip install "isaacsim[all,extscache]==5.1.0" --extra-index-url https://pypi.nvidia.com
+
+# Accept NVIDIA EULA (required)
+export ACCEPT_EULA=Y
+export PRIVACY_CONSENT=Y
+
+# 3. Install IsaacLab 2.3.0
+git clone https://github.com/isaac-sim/IsaacLab.git
+cd IsaacLab
+git checkout v2.3.0
+./isaaclab.sh -i
+cd ..
+
+# 4. Install LeRobot
+git clone https://github.com/huggingface/lerobot.git
+cd lerobot
+pip install -e ".[pi]"
+cd ..
+
+# 5. Install IsaacLab Arena
+git clone git@github.com:isaac-sim/IsaacLab-Arena.git
+cd IsaacLab-Arena
+pip install -e .
+cd ..
+
+# 6. Install additional dependencies
+pip install qpsolvers==4.8.1 numpy==1.26.0
+
+# 7. (Optional) Setup Weights & Biases
+pip install wandb
+wandb login
+```
+
+## Training Policies
+
+### Using Pre-collected Datasets
+
+IsaacLab Arena datasets are available on HuggingFace Hub:
+
+| Dataset | Description | Frames |
+| :---------------------------------------------------------------------------------------------------------- | :------------------------- | :----- |
+| [Arena-GR1-Manipulation-Task](https://huggingface.co/datasets/nvkartik/Arena-GR1-Manipulation-Task) | GR1 microwave manipulation | ~4K |
+| [Arena-G1-Loco-Manipulation-Task](https://huggingface.co/datasets/nvkartik/Arena-G1-Loco-Manipulation-Task) | G1 loco-manipulation | ~4K |
+
+### Train PI0.5 Policy
+
+```bash
+lerobot-train \
+ --policy.type=pi05 \
+ --dataset.repo_id=nvkartik/Arena-GR1-Manipulation-Task \
+ --rename_map='{"observation.images.robot_pov_cam":"observation.images.camera1"}' \
+ --policy.empty_cameras=2 \
+ --policy.max_state_dim=64 \
+ --policy.max_action_dim=64 \
+ --batch_size=16 \
+ --steps=20000 \
+ --output_dir=outputs/train/pi05-arena \
+ --policy.device=cuda \
+ --wandb.enable=true \
+ --save_freq=500 \
+ --log_freq=50
+```
+
+### Train SmolVLA Policy
+
+```bash
+lerobot-train \
+ --policy.type=smolvla \
+ --dataset.repo_id=nvkartik/Arena-GR1-Manipulation-Task \
+ --rename_map='{"observation.images.robot_pov_cam":"observation.images.camera1"}' \
+ --batch_size=8 \
+ --steps=10000 \
+ --output_dir=outputs/train/smolvla-arena \
+ --policy.device=cuda
+```
+
+## Evaluating Policies
+
+### Pre-trained Policies
+
+The following trained policies are available:
+
+| Policy | Architecture | Task | Link |
+| :-------------------------- | :----------- | :------------ | :------------------------------------------------------------------------- |
+| pi05-arena-gr1-microwave | PI0.5 | GR1 Microwave | [HuggingFace](https://huggingface.co/nvkartik/pi05-arena-gr1-microwave) |
+| smolvla-arena-gr1-microwave | SmolVLA | GR1 Microwave | [HuggingFace](https://huggingface.co/nvkartik/smolvla-arena-gr1-microwave) |
+
+### Evaluate SmolVLA
+
+```bash
+lerobot-eval \
+ --policy.path=nvkartik/smolvla-arena-gr1-microwave \
+ --env.type=isaaclab_arena \
+ --rename_map='{"observation.images.robot_pov_cam_rgb": "observation.images.robot_pov_cam"}' \
+ --env.environment=gr1_microwave \
+ --env.embodiment=gr1_pink \
+ --env.object=mustard_bottle \
+ --env.num_envs=1 \
+ --env.headless=true \
+ --policy.device=cuda \
+ --env.enable_cameras=true \
+ --env.video=true \
+ --env.video_length=10 \
+ --env.video_interval=15
+```
+
+### Evaluate PI0.5
+
+PI0.5 requires disabling torch compile for evaluation:
+
+```bash
+TORCH_COMPILE_DISABLE=1 TORCHINDUCTOR_DISABLE=1 lerobot-eval \
+ --policy.path=nvkartik/pi05-arena-gr1-microwave \
+ --env.type=isaaclab_arena \
+ --rename_map='{"observation.images.robot_pov_cam_rgb": "observation.images.robot_pov_cam"}' \
+ --env.environment=gr1_microwave \
+ --env.embodiment=gr1_pink \
+ --env.object=mustard_bottle \
+ --env.num_envs=1 \
+ --env.headless=true \
+ --policy.device=cuda \
+ --env.enable_cameras=true \
+ --env.video=true \
+ --env.video_length=15 \
+ --env.video_interval=15
+```
+
+## Environment Configuration
+
+### Full Configuration Options
+
+```python
+from lerobot.envs.configs import IsaaclabArenaEnv
+
+config = IsaaclabArenaEnv(
+ # Environment selection
+ environment="gr1_microwave", # Task environment
+ embodiment="gr1_pink", # Robot embodiment
+ object="power_drill", # Object to manipulate
+
+ # Simulation settings
+ num_envs=4, # Number of parallel environments
+ episode_length=300, # Max steps per episode
+ headless=True, # Run without GUI
+ device="cuda:0", # GPU device
+ seed=42, # Random seed
+
+ # Observation configuration
+ state_keys="robot_joint_pos", # State observation keys (comma-separated)
+ camera_keys="robot_pov_cam_rgb", # Camera observation keys (comma-separated)
+ state_dim=54, # Expected state dimension
+ action_dim=36, # Expected action dimension
+ camera_height=512, # Camera image height
+ camera_width=512, # Camera image width
+ enable_cameras=True, # Enable camera observations
+
+ # Video recording
+ video=False, # Enable video recording
+ video_length=100, # Frames per video
+ video_interval=200, # Steps between recordings
+
+ # Advanced
+ mimic=False, # Enable mimic mode
+ teleop_device=None, # Teleoperation device
+ disable_fabric=False, # Disable fabric optimization
+ enable_pinocchio=True, # Enable Pinocchio for IK
+)
+```
+
+## Zero-Agent Environment Test
+
+Test environment loading without a trained policy:
+
+```python
+import logging
+from dataclasses import asdict, dataclass
+from pprint import pformat
+import torch
+import tqdm
+from lerobot import envs
+from lerobot.configs import parser
+from lerobot.envs.configs import IsaaclabArenaEnv
+
+@dataclass
+class ArenaConfig:
+ env: envs.EnvConfig
+
+@parser.wrap()
+def main(cfg: ArenaConfig):
+ """Run zero action rollout for IsaacLab Arena environment."""
+ logging.info(pformat(asdict(cfg)))
+
+ from lerobot.envs.factory import make_env
+
+ env_kwargs = asdict(cfg.env)
+ env_kwargs.pop("features", None)
+ env_kwargs.pop("features_map", None)
+
+ env_dict = make_env(
+ "nvkartik/isaaclab-arena-envs",
+ n_envs=cfg.env.num_envs,
+ trust_remote_code=True,
+ **env_kwargs,
+ )
+ env = next(iter(env_dict.values()))[0]
+ env.reset()
+
+ for _ in tqdm.tqdm(range(cfg.env.episode_length)):
+ with torch.inference_mode():
+ action_dim = env.action_space.shape[-1]
+ actions = torch.zeros((env.num_envs, action_dim), device=env.device)
+ obs, rewards, terminated, truncated, info = env.step(actions)
+
+if __name__ == "__main__":
+ main()
+```
+
+Run with:
+
+```bash
+python test_env.py \
+ --env.type=isaaclab_arena \
+ --env.environment=galileo_pnp \
+ --env.embodiment=gr1_pink \
+ --env.object=cracker_box \
+ --env.num_envs=4 \
+ --env.enable_cameras=true \
+ --env.seed=1000
+```
+
+### Vector Environment Wrapper
+
+IsaacLab uses GPU-batched execution (all environments run on GPU simultaneously). The `IsaacLabVectorEnvWrapper` provides VectorEnv compatibility:
+
+```python
+class IsaacLabVectorEnvWrapper:
+ """Wrapper adapting IsaacLab batched GPU env to VectorEnv interface."""
+
+ @property
+ def num_envs(self) -> int:
+ return self._num_envs
+
+ def reset(self, *, seed=None, options=None):
+ # Handle seed list → single seed for IsaacLab
+ ...
+
+ def step(self, actions):
+ # Convert actions to GPU tensors, execute, return numpy
+ ...
+
+ def render_all(self) -> list[np.ndarray]:
+ # Return list of RGB frames for video recording
+ ...
+
+ ...
+```
+
+## Troubleshooting
+
+### "CUDA out of memory"
+
+Reduce `num_envs` or use a GPU with more VRAM:
+
+```bash
+--env.num_envs=1
+```
+
+### "EULA not accepted"
+
+Set environment variables before running:
+
+```bash
+export ACCEPT_EULA=Y
+export PRIVACY_CONSENT=Y
+```
+
+### Video recording not working
+
+Enable cameras when running headless:
+
+```bash
+--env.video=true --env.enable_cameras=true --env.headless=true
+```
+
+### Policy output dimension mismatch
+
+E.g. ensure `action_dim` matches your policy:
+
+```bash
+--env.action_dim=36
+```
+
+## See Also
+
+- [EnvHub Documentation](./envhub.mdx) - General EnvHub usage
+- [IsaacLab Arena GitHub](https://github.com/isaac-sim/IsaacLab-Arena)
+- [IsaacLab Documentation](https://isaac-sim.github.io/IsaacLab/)
diff --git a/src/lerobot/configs/eval.py b/src/lerobot/configs/eval.py
index 2f085da560..64ef154ff5 100644
--- a/src/lerobot/configs/eval.py
+++ b/src/lerobot/configs/eval.py
@@ -16,6 +16,7 @@
from dataclasses import dataclass, field
from logging import getLogger
from pathlib import Path
+from typing import Any
from lerobot import envs, policies # noqa: F401
from lerobot.configs import parser
@@ -30,7 +31,7 @@ class EvalPipelineConfig:
# Either the repo ID of a model hosted on the Hub or a path to a directory containing weights
# saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch
# (useful for debugging). This argument is mutually exclusive with `--config`.
- env: envs.EnvConfig
+ env: envs.EnvConfig # | str
eval: EvalConfig = field(default_factory=EvalConfig)
policy: PreTrainedConfig | None = None
output_dir: Path | None = None
@@ -38,6 +39,10 @@ class EvalPipelineConfig:
seed: int | None = 1000
# Rename map for the observation to override the image and state keys
rename_map: dict[str, str] = field(default_factory=dict)
+ # Additional kwargs to pass to hub environments (e.g., config_path, config_overrides, custom params)
+ env_kwargs: dict[str, Any] = field(default_factory=dict)
+ # Explicit consent to execute remote code from the Hub (required for hub environments).
+ trust_remote_code: bool = False
def __post_init__(self) -> None:
# HACK: We parse again the cli args here to get the pretrained path if there was one.
@@ -52,13 +57,21 @@ def __post_init__(self) -> None:
"No pretrained path was provided, evaluated policy will be built from scratch (random weights)."
)
+ # Parse env_kwargs from CLI (e.g., --env_kwargs.headless=true)
+ env_kwargs_overrides = parser.get_cli_overrides("env_kwargs")
+ if env_kwargs_overrides:
+ for arg in env_kwargs_overrides:
+ # arg format: "--key=value"
+ key, value = arg.removeprefix("--").split("=", 1)
+ self.env_kwargs[key] = value
+
if not self.job_name:
if self.env is None:
self.job_name = f"{self.policy.type if self.policy is not None else 'scratch'}"
else:
- self.job_name = (
- f"{self.env.type}_{self.policy.type if self.policy is not None else 'scratch'}"
- )
+ # Added for hub environments
+ env_type = self.env.type if isinstance(self.env, envs.EnvConfig) else self.env.split("/")[-1]
+ self.job_name = f"{env_type}_{self.policy.type if self.policy is not None else 'scratch'}"
logger.warning(f"No job name provided, using '{self.job_name}' as job name.")
diff --git a/src/lerobot/envs/__init__.py b/src/lerobot/envs/__init__.py
index d767b6e8cc..42407649f6 100644
--- a/src/lerobot/envs/__init__.py
+++ b/src/lerobot/envs/__init__.py
@@ -13,3 +13,4 @@
# limitations under the License.
from .configs import AlohaEnv, EnvConfig, PushtEnv # noqa: F401
+from .isaaclab import IsaacLabEnvWrapper # noqa: F401
diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py
index 4323f33166..b119608796 100644
--- a/src/lerobot/envs/configs.py
+++ b/src/lerobot/envs/configs.py
@@ -47,6 +47,7 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
features_map: dict[str, str] = field(default_factory=dict)
max_parallel_tasks: int = 1
disable_env_checker: bool = True
+ hub_path: str | None = None
@property
def type(self) -> str:
@@ -368,3 +369,61 @@ def gym_kwargs(self) -> dict:
"obs_type": self.obs_type,
"render_mode": self.render_mode,
}
+
+
+@EnvConfig.register_subclass("isaaclab_arena")
+@dataclass
+class IsaaclabArenaEnv(EnvConfig):
+ hub_path: str = "nvkartik/isaaclab-arena-envs"
+ episode_length: int = 300
+ # num_envs: int = 1
+ embodiment: str | None = "gr1_pink"
+ object: str | None = "power_drill"
+ mimic: bool = False
+ teleop_device: str | None = None
+ seed: int | None = 42
+ device: str | None = "cuda:0"
+ disable_fabric: bool = False
+ enable_cameras: bool = False
+ headless: bool = False
+ enable_pinocchio: bool = True
+ environment: str | None = "gr1_microwave"
+ task: str | None = "Reach out to the microwave and open it."
+ state_dim: int = 54
+ action_dim: int = 36
+ camera_height: int = 512
+ camera_width: int = 512
+ video: bool = False
+ video_length: int = 100
+ video_interval: int = 200
+ # Comma-separated keys, e.g., "robot_joint_pos,left_eef_pos"
+ state_keys: str = "robot_joint_pos"
+ # Comma-separated keys, e.g., "robot_pov_cam_rgb,front_cam_rgb"
+ # Set to None or "" for environments without cameras
+ camera_keys: str | None = None
+ features: dict[str, PolicyFeature] = field(default_factory=dict)
+ features_map: dict[str, str] = field(default_factory=dict)
+
+ def __post_init__(self):
+ # Set action feature
+ self.features[ACTION] = PolicyFeature(type=FeatureType.ACTION, shape=(self.action_dim,))
+ self.features_map[ACTION] = ACTION
+
+ # Set state feature
+ self.features[OBS_STATE] = PolicyFeature(type=FeatureType.STATE, shape=(self.state_dim,))
+ self.features_map[OBS_STATE] = OBS_STATE
+
+ # Add camera features for each camera key
+ if self.enable_cameras and self.camera_keys:
+ for cam_key in self.camera_keys.split(","):
+ cam_key = cam_key.strip()
+ if cam_key:
+ self.features[cam_key] = PolicyFeature(
+ type=FeatureType.VISUAL,
+ shape=(self.camera_height, self.camera_width, 3),
+ )
+ self.features_map[cam_key] = f"{OBS_IMAGES}.{cam_key}"
+
+ @property
+ def gym_kwargs(self) -> dict:
+ return {}
diff --git a/src/lerobot/envs/factory.py b/src/lerobot/envs/factory.py
index b39cfee718..c83005a9ec 100644
--- a/src/lerobot/envs/factory.py
+++ b/src/lerobot/envs/factory.py
@@ -20,11 +20,11 @@
from gymnasium.envs.registration import registry as gym_registry
from lerobot.configs.policies import PreTrainedConfig
-from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv
+from lerobot.envs.configs import AlohaEnv, EnvConfig, IsaaclabArenaEnv, LiberoEnv, PushtEnv
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.processor import ProcessorStep
-from lerobot.processor.env_processor import LiberoProcessorStep
+from lerobot.processor.env_processor import IsaaclabArenaProcessorStep, LiberoProcessorStep
from lerobot.processor.pipeline import PolicyProcessorPipeline
@@ -73,6 +73,26 @@ def make_env_pre_post_processors(
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
preprocessor_steps.append(LiberoProcessorStep())
+ # For Isaaclab Arena environments, add the IsaaclabArenaProcessorStep
+ if isinstance(env_cfg, IsaaclabArenaEnv) or "isaaclab_arena" in env_cfg.type:
+ # Parse comma-separated keys (handle None for state-based policies)
+ if env_cfg.state_keys:
+ state_keys = tuple(k.strip() for k in env_cfg.state_keys.split(",") if k.strip())
+ else:
+ state_keys = ()
+ if env_cfg.camera_keys:
+ camera_keys = tuple(k.strip() for k in env_cfg.camera_keys.split(",") if k.strip())
+ else:
+ camera_keys = ()
+ if not state_keys and not camera_keys:
+ raise ValueError("At least one of state_keys or camera_keys must be specified.")
+ preprocessor_steps.append(
+ IsaaclabArenaProcessorStep(
+ state_keys=state_keys,
+ camera_keys=camera_keys,
+ )
+ )
+
preprocessor = PolicyProcessorPipeline(steps=preprocessor_steps)
postprocessor = PolicyProcessorPipeline(steps=postprocessor_steps)
@@ -85,6 +105,7 @@ def make_env(
use_async_envs: bool = False,
hub_cache_dir: str | None = None,
trust_remote_code: bool = False,
+ **kwargs,
) -> dict[str, dict[int, gym.vector.VectorEnv]]:
"""Makes a gym vector environment according to the config or Hub reference.
@@ -98,7 +119,8 @@ def make_env(
hub_cache_dir (str | None): Optional cache path for downloaded hub files.
trust_remote_code (bool): **Explicit consent** to execute remote code from the Hub.
Default False — must be set to True to import/exec hub `env.py`.
-
+ **kwargs: Additional keyword arguments passed to the hub environment's `make_env` function.
+ Useful for passing custom configurations like `config_path`, `config_overrides`, etc.
Raises:
ValueError: if n_envs < 1
ModuleNotFoundError: If the requested env package is not installed
@@ -113,18 +135,32 @@ def make_env(
# if user passed a hub id string (e.g., "username/repo", "username/repo@main:env.py")
# simplified: only support hub-provided `make_env`
if isinstance(cfg, str):
+ hub_path: str | None = cfg
+ else:
+ hub_path = cfg.hub_path
+
+ # If hub_path is set, download and call hub-provided `make_env`
+ if hub_path:
# _download_hub_file will raise the same RuntimeError if trust_remote_code is False
- repo_id, file_path, local_file, revision = _download_hub_file(cfg, trust_remote_code, hub_cache_dir)
+ repo_id, file_path, local_file, revision = _download_hub_file(
+ hub_path, trust_remote_code, hub_cache_dir
+ )
# import and surface clear import errors
module = _import_hub_module(local_file, repo_id)
# call the hub-provided make_env
- raw_result = _call_make_env(module, n_envs=n_envs, use_async_envs=use_async_envs)
+ env_cfg = None if isinstance(cfg, str) else cfg
+ raw_result = _call_make_env(
+ module, n_envs=n_envs, use_async_envs=use_async_envs, cfg=env_cfg, **kwargs
+ )
# normalize the return into {suite: {task_id: vec_env}}
return _normalize_hub_result(raw_result)
+ # At this point, cfg must be an EnvConfig (not a string) since hub_path would have been set otherwise
+ assert not isinstance(cfg, str), "cfg should be EnvConfig at this point"
+
if n_envs < 1:
raise ValueError("`n_envs` must be at least 1")
diff --git a/src/lerobot/envs/isaaclab.py b/src/lerobot/envs/isaaclab.py
new file mode 100644
index 0000000000..8d34fda410
--- /dev/null
+++ b/src/lerobot/envs/isaaclab.py
@@ -0,0 +1,223 @@
+from __future__ import annotations
+
+import atexit
+import logging
+import os
+import signal
+from contextlib import suppress
+from typing import Any
+
+import gymnasium as gym
+import numpy as np
+import torch
+
+from lerobot.utils.errors import IsaacLabArenaError
+
+
+def cleanup_isaaclab(env, simulation_app) -> None:
+ """Cleanup IsaacLab env and simulation app resources."""
+ # Ignore signals during cleanup to prevent interruption
+ old_sigint = signal.signal(signal.SIGINT, signal.SIG_IGN)
+ old_sigterm = signal.signal(signal.SIGTERM, signal.SIG_IGN)
+ try:
+ with suppress(Exception):
+ if env is not None:
+ env.close()
+ with suppress(Exception):
+ if simulation_app is not None:
+ simulation_app.app.close()
+ finally:
+ # Restore signal handlers
+ signal.signal(signal.SIGINT, old_sigint)
+ signal.signal(signal.SIGTERM, old_sigterm)
+
+
+class IsaacLabEnvWrapper(gym.vector.AsyncVectorEnv):
+ """Wrapper adapting IsaacLab batched GPU env to AsyncVectorEnv.
+ IsaacLab handles vectorization internally on GPU. We inherit from
+ AsyncVectorEnv for compatibility with LeRobot."""
+
+ metadata = {"render_modes": ["rgb_array"], "render_fps": 30}
+ _cleanup_in_progress = False # Class-level flag for re-entrant protection
+
+ def __init__(
+ self,
+ env,
+ episode_length: int = 500,
+ task: str | None = None,
+ render_mode: str | None = "rgb_array",
+ simulation_app=None,
+ ):
+ self._env = env
+ self._num_envs = env.num_envs
+ self._episode_length = episode_length
+ self._closed = False
+ self.render_mode = render_mode
+ self._simulation_app = simulation_app
+
+ self.observation_space = env.observation_space
+ self.action_space = env.action_space
+ self.single_observation_space = env.observation_space
+ self.single_action_space = env.action_space
+ self.task = task
+
+ if hasattr(env, "metadata") and env.metadata:
+ self.metadata = {**self.metadata, **env.metadata}
+
+ # Register cleanup handlers
+ atexit.register(self._cleanup)
+ signal.signal(signal.SIGINT, self._signal_handler)
+ signal.signal(signal.SIGTERM, self._signal_handler)
+
+ def _signal_handler(self, signum, frame):
+ if IsaacLabEnvWrapper._cleanup_in_progress:
+ return # Prevent re-entrant cleanup
+ IsaacLabEnvWrapper._cleanup_in_progress = True
+ logging.info(f"Received signal {signum}, cleaning up...")
+ self._cleanup()
+ # Exit without raising to avoid propagating through callbacks
+ os._exit(0)
+
+ def _check_closed(self):
+ if self._closed:
+ raise IsaacLabArenaError()
+
+ @property
+ def unwrapped(self):
+ return self
+
+ @property
+ def num_envs(self) -> int:
+ return self._num_envs
+
+ @property
+ def _max_episode_steps(self) -> int:
+ return self._episode_length
+
+ @property
+ def device(self) -> str:
+ return getattr(self._env, "device", "cpu")
+
+ def reset(
+ self,
+ *,
+ seed: int | list[int] | None = None,
+ options: dict[str, Any] | None = None,
+ ) -> tuple[dict[str, Any], dict[str, Any]]:
+ self._check_closed()
+ if isinstance(seed, (list, tuple, range)):
+ seed = seed[0] if len(seed) > 0 else None
+
+ obs, info = self._env.reset(seed=seed, options=options)
+
+ if "final_info" not in info:
+ zeros = np.zeros(self._num_envs, dtype=bool)
+ info["final_info"] = {"is_success": zeros}
+
+ return obs, info
+
+ def step(
+ self, actions: np.ndarray | torch.Tensor
+ ) -> tuple[dict, np.ndarray, np.ndarray, np.ndarray, dict]:
+ self._check_closed()
+ if isinstance(actions, np.ndarray):
+ actions = torch.from_numpy(actions).to(self._env.device)
+
+ obs, reward, terminated, truncated, info = self._env.step(actions)
+
+ # Convert to numpy for gym compatibility
+ reward = reward.cpu().numpy().astype(np.float32)
+ terminated = terminated.cpu().numpy().astype(bool)
+ truncated = truncated.cpu().numpy().astype(bool)
+
+ is_success = self._get_success(terminated, truncated)
+ info["final_info"] = {"is_success": is_success}
+
+ return obs, reward, terminated, truncated, info
+
+ def _get_success(self, terminated: np.ndarray, truncated: np.ndarray) -> np.ndarray:
+ is_success = np.zeros(self._num_envs, dtype=bool)
+
+ if not hasattr(self._env, "termination_manager"):
+ return is_success & (terminated | truncated)
+
+ term_manager = self._env.termination_manager
+ if not hasattr(term_manager, "get_term"):
+ return is_success & (terminated | truncated)
+
+ success_tensor = term_manager.get_term("success")
+ if success_tensor is None:
+ return is_success & (terminated | truncated)
+
+ is_success = success_tensor.cpu().numpy().astype(bool)
+
+ return is_success & (terminated | truncated)
+
+ def call(self, method_name: str, *args, **kwargs) -> list[Any]:
+ if method_name == "_max_episode_steps":
+ return [self._episode_length] * self._num_envs
+ if method_name == "task":
+ return [self.task] * self._num_envs
+ if method_name == "render":
+ return self.render_all()
+
+ if hasattr(self._env, method_name):
+ attr = getattr(self._env, method_name)
+ result = attr(*args, **kwargs) if callable(attr) else attr
+ if isinstance(result, list):
+ return result
+ return [result] * self._num_envs
+
+ raise AttributeError(f"IsaacLab-Arena has no method/attribute '{method_name}'")
+
+ def render_all(self) -> list[np.ndarray]:
+ self._check_closed()
+ frames = self.render()
+ if frames is None:
+ placeholder = np.zeros((480, 640, 3), dtype=np.uint8)
+ return [placeholder] * self._num_envs
+
+ if frames.ndim == 4:
+ return [frames[i] for i in range(min(len(frames), self._num_envs))]
+
+ return [np.zeros((480, 640, 3), dtype=np.uint8)] * self._num_envs
+
+ def render(self) -> np.ndarray | None:
+ """Render all environments and return list of frames."""
+ self._check_closed()
+ if self.render_mode != "rgb_array":
+ return None
+
+ frames = self._env.render() if hasattr(self._env, "render") else None
+ if frames is None:
+ return None
+
+ if isinstance(frames, torch.Tensor):
+ frames = frames.cpu().numpy()
+
+ return frames[0] if frames.ndim == 4 else frames
+
+ def _cleanup(self) -> None:
+ if self._closed:
+ return
+ self._closed = True
+ IsaacLabEnvWrapper._cleanup_in_progress = True
+ logging.info("Cleaning up IsaacLab Arena environment...")
+ cleanup_isaaclab(self._env, self._simulation_app)
+
+ def close(self) -> None:
+ self._cleanup()
+
+ @property
+ def envs(self) -> list[IsaacLabEnvWrapper]:
+ return [self] * self._num_envs
+
+ def __del__(self):
+ self._cleanup()
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self._cleanup()
+ return False
diff --git a/src/lerobot/envs/utils.py b/src/lerobot/envs/utils.py
index 8d0f249221..082df83319 100644
--- a/src/lerobot/envs/utils.py
+++ b/src/lerobot/envs/utils.py
@@ -98,6 +98,14 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
if "robot_state" in observations:
return_observations[f"{OBS_STR}.robot_state"] = _convert_nested_dict(observations["robot_state"])
+
+ # Handle IsaacLab Arena format: observations have 'policy' and 'camera_obs' keys
+ if "policy" in observations:
+ return_observations[f"{OBS_STR}.policy"] = observations["policy"]
+
+ if "camera_obs" in observations:
+ return_observations[f"{OBS_STR}.camera_obs"] = observations["camera_obs"]
+
return return_observations
@@ -302,16 +310,22 @@ def _import_hub_module(local_file: str, repo_id: str) -> Any:
return module
-def _call_make_env(module: Any, n_envs: int, use_async_envs: bool) -> Any:
+def _call_make_env(
+ module: Any, n_envs: int, use_async_envs: bool, cfg: EnvConfig | None, **kwargs: Any
+) -> Any:
"""
Ensure module exposes make_env and call it.
"""
if not hasattr(module, "make_env"):
raise AttributeError(
- f"The hub module {getattr(module, '__name__', 'hub_module')} must expose `make_env(n_envs=int, use_async_envs=bool)`."
+ f"The hub module {getattr(module, '__name__', 'hub_module')} must expose `make_env(n_envs=int, use_async_envs=bool, cfg: EnvConfig | None, **kwargs)`."
)
entry_fn = module.make_env
- return entry_fn(n_envs=n_envs, use_async_envs=use_async_envs)
+ # Only pass cfg if it's not None (i.e., when an EnvConfig was provided, not a string hub ID)
+ if cfg is not None:
+ return entry_fn(n_envs=n_envs, use_async_envs=use_async_envs, cfg=cfg, **kwargs)
+ else:
+ return entry_fn(n_envs=n_envs, use_async_envs=use_async_envs, **kwargs)
def _normalize_hub_result(result: Any) -> dict[str, dict[int, gym.vector.VectorEnv]]:
diff --git a/src/lerobot/processor/env_processor.py b/src/lerobot/processor/env_processor.py
index b1872b0328..bc69d9f95a 100644
--- a/src/lerobot/processor/env_processor.py
+++ b/src/lerobot/processor/env_processor.py
@@ -18,7 +18,7 @@
import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
-from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
+from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
@@ -152,3 +152,78 @@ def _quat2axisangle(self, quat: torch.Tensor) -> torch.Tensor:
result[mask] = axis * angle.unsqueeze(1)
return result
+
+
+@dataclass
+@ProcessorStepRegistry.register(name="isaaclab_arena_processor")
+class IsaaclabArenaProcessorStep(ObservationProcessorStep):
+ """
+ Processes IsaacLab Arena observations into LeRobot format.
+
+ **State Processing:**
+ - Extracts state components from obs["policy"] based on `state_keys`.
+ - Concatenates into a flat vector mapped to "observation.state".
+
+ **Image Processing:**
+ - Extracts images from obs["camera_obs"] based on `camera_keys`.
+ - Converts from (B, H, W, C) uint8 to (B, C, H, W) float32 [0, 1].
+ - Maps to "observation.images.".
+ """
+
+ # Configurable from IsaacLabEnv config / cli args: --env.state_keys="robot_joint_pos,left_eef_pos"
+ state_keys: tuple[str, ...] = ("robot_joint_pos",)
+
+ # Configurable from IsaacLabEnv config / cli args: --env.camera_keys="robot_pov_cam_rgb"
+ camera_keys: tuple[str, ...] = ("robot_pov_cam_rgb",)
+
+ def _process_observation(self, observation):
+ """
+ Processes both image and policy state observations from IsaacLab Arena.
+ """
+ processed_obs = {}
+
+ if f"{OBS_STR}.camera_obs" in observation:
+ camera_obs = observation[f"{OBS_STR}.camera_obs"]
+
+ for cam_name, img in camera_obs.items():
+ if cam_name not in self.camera_keys:
+ continue
+
+ img = img.permute(0, 3, 1, 2).contiguous()
+ if img.dtype == torch.uint8:
+ img = img.float() / 255.0
+ elif img.dtype != torch.float32:
+ img = img.float()
+
+ processed_obs[f"{OBS_IMAGES}.{cam_name}"] = img
+
+ # Process policy state -> observation.state
+ if f"{OBS_STR}.policy" in observation:
+ policy_obs = observation[f"{OBS_STR}.policy"]
+
+ # Collect state components in order
+ state_components = []
+ for key in self.state_keys:
+ if key in policy_obs:
+ component = policy_obs[key]
+ # Flatten extra dims: (B, N, M) -> (B, N*M)
+ if component.dim() > 2:
+ batch_size = component.shape[0]
+ component = component.view(batch_size, -1)
+ state_components.append(component)
+
+ if state_components:
+ state = torch.cat(state_components, dim=-1)
+ state = state.float()
+ processed_obs[OBS_STATE] = state
+
+ return processed_obs
+
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ """Not used for policy evaluation."""
+ return features
+
+ def observation(self, observation):
+ return self._process_observation(observation)
diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py
index d23b9d083a..0014c9694d 100644
--- a/src/lerobot/scripts/lerobot_eval.py
+++ b/src/lerobot/scripts/lerobot_eval.py
@@ -43,6 +43,17 @@
Note that in both examples, the repo/folder should contain at least `config.json` and `model.safetensors` files.
+You can also evaluate a model on a Hub environment with custom kwargs:
+```
+lerobot-eval \
+ --policy.path=HF_USER/HF_REPO \
+ --env=HF_USER/HF_REPO \
+ --eval.batch_size=1 \
+ --eval.n_episodes=10 \
+ --env_kwargs.environment=env_A \
+ --env_kwargs.embodiment=emb_B \
+```
+
You can learn about the CLI options for this script in the `EvalPipelineConfig` in lerobot/configs/eval.py
"""
@@ -509,7 +520,13 @@ def eval_main(cfg: EvalPipelineConfig):
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
logging.info("Making environment.")
- envs = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
+ envs = make_env(
+ cfg.env,
+ n_envs=cfg.eval.batch_size,
+ use_async_envs=cfg.eval.use_async_envs,
+ trust_remote_code=cfg.trust_remote_code,
+ **cfg.env_kwargs,
+ )
logging.info("Making policy.")
diff --git a/src/lerobot/utils/errors.py b/src/lerobot/utils/errors.py
index 31b73eacab..6d04accfac 100644
--- a/src/lerobot/utils/errors.py
+++ b/src/lerobot/utils/errors.py
@@ -30,3 +30,35 @@ def __init__(
):
self.message = message
super().__init__(self.message)
+
+
+class IsaacLabArenaError(RuntimeError):
+ """Base exception for IsaacLab Arena environment errors."""
+
+ def __init__(self, message: str = "IsaacLab Arena error"):
+ self.message = message
+ super().__init__(self.message)
+
+
+class IsaacLabArenaConfigError(IsaacLabArenaError):
+ """Exception raised for invalid environment configuration."""
+
+ def __init__(self, invalid: list, available: list, key_type: str = "keys"):
+ msg = f"Invalid {key_type}: {invalid}. Available: {sorted(available)}"
+ super().__init__(msg)
+ self.invalid = invalid
+ self.available = available
+
+
+class IsaacLabArenaCameraKeyError(IsaacLabArenaConfigError):
+ """Exception raised when camera_keys don't match available cameras."""
+
+ def __init__(self, invalid: list, available: list):
+ super().__init__(invalid, available, "camera_keys")
+
+
+class IsaacLabArenaStateKeyError(IsaacLabArenaConfigError):
+ """Exception raised when state_keys don't match available state terms."""
+
+ def __init__(self, invalid: list, available: list):
+ super().__init__(invalid, available, "state_keys")
diff --git a/tests/envs/test_envs.py b/tests/envs/test_envs.py
index 910c275eb4..853fdb3240 100644
--- a/tests/envs/test_envs.py
+++ b/tests/envs/test_envs.py
@@ -15,6 +15,7 @@
# limitations under the License.
import importlib
from dataclasses import dataclass, field
+from unittest.mock import MagicMock
import gymnasium as gym
import numpy as np
@@ -27,6 +28,7 @@
from lerobot.configs.types import PolicyFeature
from lerobot.envs.configs import EnvConfig
from lerobot.envs.factory import make_env, make_env_config
+from lerobot.envs.isaaclab import IsaacLabEnvWrapper
from lerobot.envs.utils import (
_normalize_hub_result,
_parse_hub_url,
@@ -266,3 +268,152 @@ def test_make_env_from_hub_async():
# clean up
env.close()
+
+
+# IsaacLabEnvWrapper tests (mock-based without installing IsaacLab)
+
+
+def _create_mock_isaaclab_env(num_envs: int = 2, device: str = "cpu"):
+ """Create a mock IsaacLab environment for testing."""
+ mock_env = MagicMock()
+ mock_env.num_envs = num_envs
+ mock_env.device = device
+ mock_env.observation_space = gym.spaces.Dict(
+ {"policy": gym.spaces.Box(low=-1, high=1, shape=(num_envs, 54), dtype=np.float32)}
+ )
+ mock_env.action_space = gym.spaces.Box(low=-1, high=1, shape=(36,), dtype=np.float32)
+ mock_env.metadata = {}
+ return mock_env
+
+
+def test_isaaclab_wrapper_init():
+ """Test IsaacLabEnvWrapper initialization."""
+ mock_env = _create_mock_isaaclab_env(num_envs=4)
+
+ wrapper = IsaacLabEnvWrapper(
+ mock_env,
+ episode_length=300,
+ task="Test task",
+ render_mode="rgb_array",
+ )
+
+ assert wrapper.num_envs == 4
+ assert wrapper._max_episode_steps == 300
+ assert wrapper.task == "Test task"
+ assert wrapper.render_mode == "rgb_array"
+ assert wrapper.device == "cpu"
+ assert len(wrapper.envs) == 4
+
+
+def test_isaaclab_wrapper_reset():
+ """Test IsaacLabEnvWrapper reset."""
+ mock_env = _create_mock_isaaclab_env(num_envs=2)
+ mock_obs = {"policy": torch.randn(2, 54)}
+ mock_env.reset.return_value = (mock_obs, {})
+
+ wrapper = IsaacLabEnvWrapper(mock_env, episode_length=100)
+ obs, info = wrapper.reset(seed=42)
+
+ mock_env.reset.assert_called_once_with(seed=42, options=None)
+ assert "final_info" in info
+ assert "is_success" in info["final_info"]
+ assert len(info["final_info"]["is_success"]) == 2
+
+
+def test_isaaclab_wrapper_reset_with_seed_list():
+ """Test that seed list is handled correctly (IsaacLab expects single seed)."""
+ mock_env = _create_mock_isaaclab_env(num_envs=2)
+ mock_env.reset.return_value = ({"policy": torch.randn(2, 54)}, {})
+
+ wrapper = IsaacLabEnvWrapper(mock_env)
+ wrapper.reset(seed=[42, 43, 44])
+
+ # Should extract first seed
+ mock_env.reset.assert_called_once_with(seed=42, options=None)
+
+
+def test_isaaclab_wrapper_step():
+ """Test IsaacLabEnvWrapper step."""
+ mock_env = _create_mock_isaaclab_env(num_envs=2)
+ mock_env.step.return_value = (
+ {"policy": torch.randn(2, 54)},
+ torch.tensor([0.5, 0.3]),
+ torch.tensor([False, False]),
+ torch.tensor([False, True]),
+ {},
+ )
+ # Mock termination manager
+ mock_env.termination_manager.get_term.return_value = torch.tensor([False, True])
+
+ wrapper = IsaacLabEnvWrapper(mock_env)
+ actions = np.random.randn(2, 36).astype(np.float32)
+ obs, reward, terminated, truncated, info = wrapper.step(actions)
+
+ assert reward.dtype == np.float32
+ assert terminated.dtype == bool
+ assert truncated.dtype == bool
+ assert len(reward) == 2
+ assert "final_info" in info
+ assert "is_success" in info["final_info"]
+
+
+def test_isaaclab_wrapper_call_method():
+ """Test IsaacLabEnvWrapper call method."""
+ mock_env = _create_mock_isaaclab_env(num_envs=3)
+
+ wrapper = IsaacLabEnvWrapper(mock_env, episode_length=200, task="My task")
+
+ # Test _max_episode_steps
+ result = wrapper.call("_max_episode_steps")
+ assert result == [200, 200, 200]
+
+ # Test task
+ result = wrapper.call("task")
+ assert result == ["My task", "My task", "My task"]
+
+
+def test_isaaclab_wrapper_render():
+ """Test IsaacLabEnvWrapper render."""
+ mock_env = _create_mock_isaaclab_env(num_envs=2)
+ mock_frames = torch.randint(0, 255, (2, 480, 640, 3), dtype=torch.uint8)
+ mock_env.render.return_value = mock_frames
+
+ wrapper = IsaacLabEnvWrapper(mock_env, render_mode="rgb_array")
+ frame = wrapper.render()
+
+ assert frame is not None
+ assert frame.shape == (480, 640, 3) # Returns first env frame
+
+
+def test_isaaclab_wrapper_render_all():
+ """Test IsaacLabEnvWrapper render_all."""
+ mock_env = _create_mock_isaaclab_env(num_envs=2)
+ mock_frames = torch.randint(0, 255, (2, 480, 640, 3), dtype=torch.uint8)
+ mock_env.render.return_value = mock_frames
+
+ wrapper = IsaacLabEnvWrapper(mock_env, render_mode="rgb_array")
+ frames = wrapper.render_all()
+
+ assert len(frames) == 2
+ assert all(f.shape == (480, 640, 3) for f in frames)
+
+
+def test_isaaclab_wrapper_render_none():
+ """Test render returns None when render_mode is not rgb_array."""
+ mock_env = _create_mock_isaaclab_env()
+
+ wrapper = IsaacLabEnvWrapper(mock_env, render_mode=None)
+ assert wrapper.render() is None
+
+
+def test_isaaclab_wrapper_close():
+ """Test IsaacLabEnvWrapper close."""
+ mock_env = _create_mock_isaaclab_env()
+ mock_app = MagicMock()
+
+ wrapper = IsaacLabEnvWrapper(mock_env, simulation_app=mock_app)
+ wrapper.close()
+
+ mock_env.close.assert_called_once()
+ mock_app.app.close.assert_called_once()
+ assert wrapper._closed
diff --git a/tests/processor/test_arena_processor.py b/tests/processor/test_arena_processor.py
new file mode 100644
index 0000000000..c68eb1fb43
--- /dev/null
+++ b/tests/processor/test_arena_processor.py
@@ -0,0 +1,407 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+import torch
+
+from lerobot.configs.types import (
+ FeatureType,
+ PipelineFeatureType,
+ PolicyFeature,
+)
+from lerobot.processor.env_processor import IsaaclabArenaProcessorStep
+from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR
+
+# Test constants
+BATCH_SIZE = 2
+STATE_DIM = 16
+IMG_HEIGHT = 64
+IMG_WIDTH = 64
+
+# Generic test keys (not real robot keys)
+TEST_STATE_KEY = "test_state_obs"
+TEST_CAMERA_KEY = "test_rgb_cam"
+
+
+@pytest.fixture
+def processor():
+ """Default processor with test keys."""
+ return IsaaclabArenaProcessorStep(
+ state_keys=(TEST_STATE_KEY,),
+ camera_keys=(TEST_CAMERA_KEY,),
+ )
+
+
+@pytest.fixture
+def sample_observation():
+ """Sample IsaacLab Arena observation with state and camera data."""
+ return {
+ f"{OBS_STR}.policy": {
+ TEST_STATE_KEY: torch.randn(BATCH_SIZE, STATE_DIM),
+ },
+ f"{OBS_STR}.camera_obs": {
+ TEST_CAMERA_KEY: torch.randint(0, 255, (BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, 3), dtype=torch.uint8),
+ },
+ }
+
+
+# =============================================================================
+# State Processing Tests
+# =============================================================================
+
+
+def test_state_extraction(processor, sample_observation):
+ """Test that state is extracted and converted to float32."""
+ processed = processor.observation(sample_observation)
+
+ assert OBS_STATE in processed
+ assert processed[OBS_STATE].shape == (BATCH_SIZE, STATE_DIM)
+ assert processed[OBS_STATE].dtype == torch.float32
+
+
+def test_state_concatenation_multiple_keys():
+ """Test that multiple state keys are concatenated in order."""
+ dim1, dim2 = 10, 6
+ processor = IsaaclabArenaProcessorStep(
+ state_keys=("state_alpha", "state_beta"),
+ camera_keys=(),
+ )
+
+ obs = {
+ f"{OBS_STR}.policy": {
+ "state_alpha": torch.ones(BATCH_SIZE, dim1),
+ "state_beta": torch.ones(BATCH_SIZE, dim2) * 2,
+ },
+ }
+
+ processed = processor.observation(obs)
+
+ state = processed[OBS_STATE]
+ assert state.shape == (BATCH_SIZE, dim1 + dim2)
+ # Verify ordering: first dim1 elements are 1s, last dim2 are 2s
+ assert torch.all(state[:, :dim1] == 1.0)
+ assert torch.all(state[:, dim1:] == 2.0)
+
+
+def test_state_flattening_higher_dims():
+ """Test that state with dim > 2 is flattened to (B, -1)."""
+ processor = IsaaclabArenaProcessorStep(
+ state_keys=("multidim_state",),
+ camera_keys=(),
+ )
+
+ # Shape (B, 4, 4) -> should flatten to (B, 16)
+ obs = {
+ f"{OBS_STR}.policy": {
+ "multidim_state": torch.randn(BATCH_SIZE, 4, 4),
+ },
+ }
+
+ processed = processor.observation(obs)
+
+ assert processed[OBS_STATE].shape == (BATCH_SIZE, 16)
+
+
+def test_state_filters_to_configured_keys():
+ """Test that only configured state_keys are extracted."""
+ processor = IsaaclabArenaProcessorStep(
+ state_keys=("included_key",),
+ camera_keys=(),
+ )
+
+ obs = {
+ f"{OBS_STR}.policy": {
+ "included_key": torch.randn(BATCH_SIZE, 10),
+ "excluded_key": torch.randn(BATCH_SIZE, 6), # Should be ignored
+ },
+ }
+
+ processed = processor.observation(obs)
+
+ # Only included_key (dim 10) should be included
+ assert processed[OBS_STATE].shape == (BATCH_SIZE, 10)
+
+
+def test_missing_state_key_skipped():
+ """Test that missing state keys in observation are skipped."""
+ processor = IsaaclabArenaProcessorStep(
+ state_keys=("present_key", "missing_key"),
+ camera_keys=(),
+ )
+
+ obs = {
+ f"{OBS_STR}.policy": {
+ "present_key": torch.randn(BATCH_SIZE, 10),
+ # missing_key not present
+ },
+ }
+
+ processed = processor.observation(obs)
+
+ # Should only have present_key
+ assert processed[OBS_STATE].shape == (BATCH_SIZE, 10)
+
+
+# =============================================================================
+# Camera/Image Processing Tests
+# =============================================================================
+
+
+def test_camera_permutation_bhwc_to_bchw(processor, sample_observation):
+ """Test images are permuted from (B, H, W, C) to (B, C, H, W)."""
+ processed = processor.observation(sample_observation)
+
+ img_key = f"{OBS_IMAGES}.{TEST_CAMERA_KEY}"
+ assert img_key in processed
+ img = processed[img_key]
+ assert img.shape == (BATCH_SIZE, 3, IMG_HEIGHT, IMG_WIDTH)
+
+
+def test_camera_uint8_to_normalized_float32(processor):
+ """Test that uint8 images are normalized to float32 [0, 1]."""
+ obs = {
+ f"{OBS_STR}.camera_obs": {
+ TEST_CAMERA_KEY: torch.full((BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, 3), 255, dtype=torch.uint8),
+ },
+ }
+
+ processed = processor.observation(obs)
+
+ img = processed[f"{OBS_IMAGES}.{TEST_CAMERA_KEY}"]
+ assert img.dtype == torch.float32
+ assert torch.allclose(img, torch.ones_like(img))
+
+
+def test_camera_float32_passthrough(processor):
+ """Test that float32 images are kept as float32."""
+ original_img = torch.rand(BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, 3, dtype=torch.float32)
+ obs = {
+ f"{OBS_STR}.camera_obs": {
+ TEST_CAMERA_KEY: original_img.clone(),
+ },
+ }
+
+ processed = processor.observation(obs)
+
+ img = processed[f"{OBS_IMAGES}.{TEST_CAMERA_KEY}"]
+ assert img.dtype == torch.float32
+ # Values should be same (just permuted)
+ expected = original_img.permute(0, 3, 1, 2)
+ assert torch.allclose(img, expected)
+
+
+def test_camera_other_dtype_converted_to_float(processor):
+ """Test that non-uint8, non-float32 dtypes are converted to float."""
+ obs = {
+ f"{OBS_STR}.camera_obs": {
+ TEST_CAMERA_KEY: torch.randint(0, 255, (BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, 3), dtype=torch.int32),
+ },
+ }
+
+ processed = processor.observation(obs)
+
+ img = processed[f"{OBS_IMAGES}.{TEST_CAMERA_KEY}"]
+ assert img.dtype == torch.float32
+
+
+def test_camera_filters_to_configured_keys():
+ """Test that only configured camera_keys are extracted."""
+ processor = IsaaclabArenaProcessorStep(
+ state_keys=(),
+ camera_keys=("included_cam",),
+ )
+
+ obs = {
+ f"{OBS_STR}.camera_obs": {
+ "included_cam": torch.randint(0, 255, (BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, 3), dtype=torch.uint8),
+ "excluded_cam": torch.randint(0, 255, (BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, 3), dtype=torch.uint8),
+ },
+ }
+
+ processed = processor.observation(obs)
+
+ assert f"{OBS_IMAGES}.included_cam" in processed
+ assert f"{OBS_IMAGES}.excluded_cam" not in processed
+
+
+def test_camera_key_preserved_exactly():
+ """Test that camera key name is used exactly (no suffix stripping)."""
+ processor = IsaaclabArenaProcessorStep(
+ state_keys=(),
+ camera_keys=("my_cam_rgb",),
+ )
+
+ obs = {
+ f"{OBS_STR}.camera_obs": {
+ "my_cam_rgb": torch.randint(0, 255, (BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, 3), dtype=torch.uint8),
+ },
+ }
+
+ processed = processor.observation(obs)
+
+ # Key should be exactly as configured, with _rgb suffix intact
+ assert f"{OBS_IMAGES}.my_cam_rgb" in processed
+ assert f"{OBS_IMAGES}.my_cam" not in processed
+
+
+# =============================================================================
+# Edge Cases & Missing Data Tests
+# =============================================================================
+
+
+def test_missing_camera_obs_section(processor):
+ """Test processor handles observation without camera_obs section."""
+ obs = {
+ f"{OBS_STR}.policy": {
+ TEST_STATE_KEY: torch.randn(BATCH_SIZE, STATE_DIM),
+ },
+ }
+
+ processed = processor.observation(obs)
+
+ assert OBS_STATE in processed
+ assert not any(k.startswith(OBS_IMAGES) for k in processed)
+
+
+def test_missing_policy_obs_section(processor):
+ """Test processor handles observation without policy section."""
+ obs = {
+ f"{OBS_STR}.camera_obs": {
+ TEST_CAMERA_KEY: torch.randint(0, 255, (BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, 3), dtype=torch.uint8),
+ },
+ }
+
+ processed = processor.observation(obs)
+
+ assert f"{OBS_IMAGES}.{TEST_CAMERA_KEY}" in processed
+ assert OBS_STATE not in processed
+
+
+def test_empty_observation(processor):
+ """Test processor handles empty observation dict."""
+ processed = processor.observation({})
+
+ assert OBS_STATE not in processed
+ assert not any(k.startswith(OBS_IMAGES) for k in processed)
+
+
+def test_no_matching_state_keys():
+ """Test processor when no state keys match observation."""
+ processor = IsaaclabArenaProcessorStep(
+ state_keys=("nonexistent_key",),
+ camera_keys=(),
+ )
+
+ obs = {
+ f"{OBS_STR}.policy": {
+ "some_other_key": torch.randn(BATCH_SIZE, STATE_DIM),
+ },
+ }
+
+ processed = processor.observation(obs)
+
+ # No state because no keys matched
+ assert OBS_STATE not in processed
+
+
+def test_no_matching_camera_keys():
+ """Test processor when no camera keys match observation."""
+ processor = IsaaclabArenaProcessorStep(
+ state_keys=(),
+ camera_keys=("nonexistent_cam",),
+ )
+
+ obs = {
+ f"{OBS_STR}.camera_obs": {
+ "some_other_cam": torch.randint(
+ 0, 255, (BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, 3), dtype=torch.uint8
+ ),
+ },
+ }
+
+ processed = processor.observation(obs)
+
+ assert not any(k.startswith(OBS_IMAGES) for k in processed)
+
+
+# =============================================================================
+# Configuration Tests
+# =============================================================================
+
+
+def test_default_keys():
+ """Test default state_keys and camera_keys values."""
+ processor = IsaaclabArenaProcessorStep()
+
+ assert processor.state_keys == ("robot_joint_pos",)
+ assert processor.camera_keys == ("robot_pov_cam_rgb",)
+
+
+def test_custom_keys_configuration():
+ """Test processor with custom state and camera keys."""
+ processor = IsaaclabArenaProcessorStep(
+ state_keys=("pos_xyz", "quat_wxyz", "grip_val"),
+ camera_keys=("front_view", "wrist_view"),
+ )
+
+ obs = {
+ f"{OBS_STR}.policy": {
+ "pos_xyz": torch.randn(BATCH_SIZE, 3),
+ "quat_wxyz": torch.randn(BATCH_SIZE, 4),
+ "grip_val": torch.randn(BATCH_SIZE, 1),
+ },
+ f"{OBS_STR}.camera_obs": {
+ "front_view": torch.randint(0, 255, (BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, 3), dtype=torch.uint8),
+ "wrist_view": torch.randint(0, 255, (BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, 3), dtype=torch.uint8),
+ },
+ }
+
+ processed = processor.observation(obs)
+
+ # State should be concatenated: 3 + 4 + 1 = 8
+ assert processed[OBS_STATE].shape == (BATCH_SIZE, 8)
+ # Both cameras should be present
+ assert f"{OBS_IMAGES}.front_view" in processed
+ assert f"{OBS_IMAGES}.wrist_view" in processed
+
+
+# =============================================================================
+# transform_features Tests
+# =============================================================================
+
+
+def test_transform_features_passthrough(processor):
+ """Test that transform_features returns features unchanged."""
+ input_features = {
+ PipelineFeatureType.OBSERVATION: {
+ "observation.state": PolicyFeature(
+ type=FeatureType.STATE,
+ shape=(16,),
+ ),
+ "observation.images.cam": PolicyFeature(
+ type=FeatureType.VISUAL,
+ shape=(3, 64, 64),
+ ),
+ },
+ PipelineFeatureType.ACTION: {
+ "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
+ },
+ }
+
+ output_features = processor.transform_features(input_features)
+
+ # Should be unchanged
+ assert output_features == input_features
diff --git a/tests/processor/test_pipeline.py b/tests/processor/test_pipeline.py
index 134228c05c..58a83fe69b 100644
--- a/tests/processor/test_pipeline.py
+++ b/tests/processor/test_pipeline.py
@@ -17,7 +17,7 @@
import json
import tempfile
from collections.abc import Callable
-from dataclasses import dataclass
+from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
@@ -1884,7 +1884,7 @@ class FeatureContractAddStep(ProcessorStep):
"""Adds a PolicyFeature"""
key: str = "a"
- value: PolicyFeature = PolicyFeature(type=FeatureType.STATE, shape=(1,))
+ value: PolicyFeature = field(default_factory=lambda: PolicyFeature(type=FeatureType.STATE, shape=(1,)))
def __call__(self, transition: EnvTransition) -> EnvTransition:
return transition