diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 7766b34725..b59d15734c 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -59,6 +59,8 @@ title: Environments from the Hub - local: envhub_leisaac title: Control & Train Robots in Sim (LeIsaac) + - local: envhub_isaaclab_arena + title: NVIDIA IsaacLab Arena Environments - local: libero title: Using Libero - local: metaworld diff --git a/docs/source/envhub_isaaclab_arena.mdx b/docs/source/envhub_isaaclab_arena.mdx new file mode 100644 index 0000000000..f0f1c33ad0 --- /dev/null +++ b/docs/source/envhub_isaaclab_arena.mdx @@ -0,0 +1,395 @@ +# NVIDIA IsaacLab Arena & LeRobot + +LeRobot EnvHub now supports **GPU-accelerated simulation** with IsaacLab Arena for policy evaluation at scale. +Train and evaluate imitation learning policies with high-fidelity simulation — all integrated into the LeRobot ecosystem. + +IsaacLab Arena - GR1 Microwave Environment + +[IsaacLab Arena](https://github.com/isaac-sim/IsaacLab-Arena) integrates with NVIDIA IsaacLab to provide: + +- 🤖 **Humanoid embodiments**: GR1, G1, Galileo with various configurations +- 🎯 **Manipulation & loco-manipulation tasks**: Microwave opening, pick-and-place, button pressing +- ⚡ **GPU-accelerated rollouts**: Parallel environment execution on NVIDIA GPUs +- 📦 **LeRobot-compatible datasets**: Ready for training with PI0, SmolVLA, ACT, Diffusion policies +- 🔄 **EnvHub integration**: Load environments from HuggingFace Hub with one line + +## Available Environments + +The following environments are currently available in IsaacLab Arena: + +| Preview | Environment | Description | +| :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :----------------- | :------------------------------------------------------------------------------------------------------------------- | +| GR1 Microwave | `gr1_microwave` | Reach out to the microwave and open it. | +| Galileo Pick and Place | `galileo_pnp` | Pick objects and place in target location | +| G1 Loco-manipulation | `g1_locomanip_pnp` | Pick up the brown box from the shelf, and place it into the blue bin on the table located at the right of the shelf. 
| +| Kitchen Pick and Place | `kitchen_pnp` | Kitchen object manipulation tasks | +| Press Button | `press_button` | Locate and press button | + +## Quick Start + +### Load Environment from Hub + +```python +from lerobot.envs.factory import make_env + +# Load IsaacLab Arena environment from the Hub +envs_dict = make_env( + "nvkartik/isaaclab-arena-envs", + n_envs=4, + trust_remote_code=True, + environment="gr1_microwave", + embodiment="gr1_pink", + headless=True, + enable_cameras=True, +) + +# Access the environment +suite_name = next(iter(envs_dict)) +env = envs_dict[suite_name][0] + +# Run a simple episode +obs, info = env.reset() +for _ in range(100): + action = env.action_space.sample() + obs, reward, terminated, truncated, info = env.step(action) + if terminated.any() or truncated.any(): + obs, info = env.reset() + +env.close() +``` + +### Using the Native Configuration + +For CLI integration and streamlined configuration, use the `isaaclab_arena` environment type: + +```bash +# Evaluate a trained policy +lerobot-eval \ + --policy.path=nvkartik/smolvla-arena-gr1-microwave \ + --env.type=isaaclab_arena \ + --env.environment=gr1_microwave \ + --env.embodiment=gr1_pink \ + --env.num_envs=1 \ + --env.headless=true \ + --env.enable_cameras=true \ + --policy.device=cuda +``` + +## Installation + +### Prerequisites + +- NVIDIA GPU with CUDA support +- NVIDIA driver compatible with IsaacSim 5.1.0 +- Linux (Ubuntu 22.04 recommended) + +### Setup + +```bash +# 1. Create conda environment +conda create -y -n lerobot-arena python=3.11 +conda activate lerobot-arena +conda install -y -c conda-forge ffmpeg=7.1.1 + +# 2. Install Isaac Sim 5.1.0 +pip install "isaacsim[all,extscache]==5.1.0" --extra-index-url https://pypi.nvidia.com + +# Accept NVIDIA EULA (required) +export ACCEPT_EULA=Y +export PRIVACY_CONSENT=Y + +# 3. Install IsaacLab 2.3.0 +git clone https://github.com/isaac-sim/IsaacLab.git +cd IsaacLab +git checkout v2.3.0 +./isaaclab.sh -i +cd .. + +# 4. Install LeRobot +git clone https://github.com/huggingface/lerobot.git +cd lerobot +pip install -e ".[pi]" +cd .. + +# 5. Install IsaacLab Arena +git clone git@github.com:isaac-sim/IsaacLab-Arena.git +cd IsaacLab-Arena +pip install -e . +cd .. + +# 6. Install additional dependencies +pip install qpsolvers==4.8.1 numpy==1.26.0 + +# 7. 
(Optional) Setup Weights & Biases +pip install wandb +wandb login +``` + +## Training Policies + +### Using Pre-collected Datasets + +IsaacLab Arena datasets are available on HuggingFace Hub: + +| Dataset | Description | Frames | +| :---------------------------------------------------------------------------------------------------------- | :------------------------- | :----- | +| [Arena-GR1-Manipulation-Task](https://huggingface.co/datasets/nvkartik/Arena-GR1-Manipulation-Task) | GR1 microwave manipulation | ~4K | +| [Arena-G1-Loco-Manipulation-Task](https://huggingface.co/datasets/nvkartik/Arena-G1-Loco-Manipulation-Task) | G1 loco-manipulation | ~4K | + +### Train PI0.5 Policy + +```bash +lerobot-train \ + --policy.type=pi05 \ + --dataset.repo_id=nvkartik/Arena-GR1-Manipulation-Task \ + --rename_map='{"observation.images.robot_pov_cam":"observation.images.camera1"}' \ + --policy.empty_cameras=2 \ + --policy.max_state_dim=64 \ + --policy.max_action_dim=64 \ + --batch_size=16 \ + --steps=20000 \ + --output_dir=outputs/train/pi05-arena \ + --policy.device=cuda \ + --wandb.enable=true \ + --save_freq=500 \ + --log_freq=50 +``` + +### Train SmolVLA Policy + +```bash +lerobot-train \ + --policy.type=smolvla \ + --dataset.repo_id=nvkartik/Arena-GR1-Manipulation-Task \ + --rename_map='{"observation.images.robot_pov_cam":"observation.images.camera1"}' \ + --batch_size=8 \ + --steps=10000 \ + --output_dir=outputs/train/smolvla-arena \ + --policy.device=cuda +``` + +## Evaluating Policies + +### Pre-trained Policies + +The following trained policies are available: + +| Policy | Architecture | Task | Link | +| :-------------------------- | :----------- | :------------ | :------------------------------------------------------------------------- | +| pi05-arena-gr1-microwave | PI0.5 | GR1 Microwave | [HuggingFace](https://huggingface.co/nvkartik/pi05-arena-gr1-microwave) | +| smolvla-arena-gr1-microwave | SmolVLA | GR1 Microwave | [HuggingFace](https://huggingface.co/nvkartik/smolvla-arena-gr1-microwave) | + +### Evaluate SmolVLA + +```bash +lerobot-eval \ + --policy.path=nvkartik/smolvla-arena-gr1-microwave \ + --env.type=isaaclab_arena \ + --rename_map='{"observation.images.robot_pov_cam_rgb": "observation.images.robot_pov_cam"}' \ + --env.environment=gr1_microwave \ + --env.embodiment=gr1_pink \ + --env.object=mustard_bottle \ + --env.num_envs=1 \ + --env.headless=true \ + --policy.device=cuda \ + --env.enable_cameras=true \ + --env.video=true \ + --env.video_length=10 \ + --env.video_interval=15 +``` + +### Evaluate PI0.5 + +PI0.5 requires disabling torch compile for evaluation: + +```bash +TORCH_COMPILE_DISABLE=1 TORCHINDUCTOR_DISABLE=1 lerobot-eval \ + --policy.path=nvkartik/pi05-arena-gr1-microwave \ + --env.type=isaaclab_arena \ + --rename_map='{"observation.images.robot_pov_cam_rgb": "observation.images.robot_pov_cam"}' \ + --env.environment=gr1_microwave \ + --env.embodiment=gr1_pink \ + --env.object=mustard_bottle \ + --env.num_envs=1 \ + --env.headless=true \ + --policy.device=cuda \ + --env.enable_cameras=true \ + --env.video=true \ + --env.video_length=15 \ + --env.video_interval=15 +``` + +## Environment Configuration + +### Full Configuration Options + +```python +from lerobot.envs.configs import IsaaclabArenaEnv + +config = IsaaclabArenaEnv( + # Environment selection + environment="gr1_microwave", # Task environment + embodiment="gr1_pink", # Robot embodiment + object="power_drill", # Object to manipulate + + # Simulation settings + num_envs=4, # Number of parallel environments + 
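+    # Assumption (not part of the documented IsaaclabArenaEnv API): each extra parallel
+    # environment increases GPU memory use, so lower num_envs if you hit
+    # "CUDA out of memory" (see Troubleshooting below).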
episode_length=300, # Max steps per episode + headless=True, # Run without GUI + device="cuda:0", # GPU device + seed=42, # Random seed + + # Observation configuration + state_keys="robot_joint_pos", # State observation keys (comma-separated) + camera_keys="robot_pov_cam_rgb", # Camera observation keys (comma-separated) + state_dim=54, # Expected state dimension + action_dim=36, # Expected action dimension + camera_height=512, # Camera image height + camera_width=512, # Camera image width + enable_cameras=True, # Enable camera observations + + # Video recording + video=False, # Enable video recording + video_length=100, # Frames per video + video_interval=200, # Steps between recordings + + # Advanced + mimic=False, # Enable mimic mode + teleop_device=None, # Teleoperation device + disable_fabric=False, # Disable fabric optimization + enable_pinocchio=True, # Enable Pinocchio for IK +) +``` + +## Zero-Agent Environment Test + +Test environment loading without a trained policy: + +```python +import logging +from dataclasses import asdict, dataclass +from pprint import pformat +import torch +import tqdm +from lerobot import envs +from lerobot.configs import parser +from lerobot.envs.configs import IsaaclabArenaEnv + +@dataclass +class ArenaConfig: + env: envs.EnvConfig + +@parser.wrap() +def main(cfg: ArenaConfig): + """Run zero action rollout for IsaacLab Arena environment.""" + logging.info(pformat(asdict(cfg))) + + from lerobot.envs.factory import make_env + + env_kwargs = asdict(cfg.env) + env_kwargs.pop("features", None) + env_kwargs.pop("features_map", None) + + env_dict = make_env( + "nvkartik/isaaclab-arena-envs", + n_envs=cfg.env.num_envs, + trust_remote_code=True, + **env_kwargs, + ) + env = next(iter(env_dict.values()))[0] + env.reset() + + for _ in tqdm.tqdm(range(cfg.env.episode_length)): + with torch.inference_mode(): + action_dim = env.action_space.shape[-1] + actions = torch.zeros((env.num_envs, action_dim), device=env.device) + obs, rewards, terminated, truncated, info = env.step(actions) + +if __name__ == "__main__": + main() +``` + +Run with: + +```bash +python test_env.py \ + --env.type=isaaclab_arena \ + --env.environment=galileo_pnp \ + --env.embodiment=gr1_pink \ + --env.object=cracker_box \ + --env.num_envs=4 \ + --env.enable_cameras=true \ + --env.seed=1000 +``` + +### Vector Environment Wrapper + +IsaacLab uses GPU-batched execution (all environments run on GPU simultaneously). The `IsaacLabVectorEnvWrapper` provides VectorEnv compatibility: + +```python +class IsaacLabVectorEnvWrapper: + """Wrapper adapting IsaacLab batched GPU env to VectorEnv interface.""" + + @property + def num_envs(self) -> int: + return self._num_envs + + def reset(self, *, seed=None, options=None): + # Handle seed list → single seed for IsaacLab + ... + + def step(self, actions): + # Convert actions to GPU tensors, execute, return numpy + ... + + def render_all(self) -> list[np.ndarray]: + # Return list of RGB frames for video recording + ... + + ... +``` + +## Troubleshooting + +### "CUDA out of memory" + +Reduce `num_envs` or use a GPU with more VRAM: + +```bash +--env.num_envs=1 +``` + +### "EULA not accepted" + +Set environment variables before running: + +```bash +export ACCEPT_EULA=Y +export PRIVACY_CONSENT=Y +``` + +### Video recording not working + +Enable cameras when running headless: + +```bash +--env.video=true --env.enable_cameras=true --env.headless=true +``` + +### Policy output dimension mismatch + +E.g. 
ensure `action_dim` matches your policy: + +```bash +--env.action_dim=36 +``` + +## See Also + +- [EnvHub Documentation](./envhub.mdx) - General EnvHub usage +- [IsaacLab Arena GitHub](https://github.com/isaac-sim/IsaacLab-Arena) +- [IsaacLab Documentation](https://isaac-sim.github.io/IsaacLab/) diff --git a/src/lerobot/configs/eval.py b/src/lerobot/configs/eval.py index 2f085da560..64ef154ff5 100644 --- a/src/lerobot/configs/eval.py +++ b/src/lerobot/configs/eval.py @@ -16,6 +16,7 @@ from dataclasses import dataclass, field from logging import getLogger from pathlib import Path +from typing import Any from lerobot import envs, policies # noqa: F401 from lerobot.configs import parser @@ -30,7 +31,7 @@ class EvalPipelineConfig: # Either the repo ID of a model hosted on the Hub or a path to a directory containing weights # saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch # (useful for debugging). This argument is mutually exclusive with `--config`. - env: envs.EnvConfig + env: envs.EnvConfig # | str eval: EvalConfig = field(default_factory=EvalConfig) policy: PreTrainedConfig | None = None output_dir: Path | None = None @@ -38,6 +39,10 @@ class EvalPipelineConfig: seed: int | None = 1000 # Rename map for the observation to override the image and state keys rename_map: dict[str, str] = field(default_factory=dict) + # Additional kwargs to pass to hub environments (e.g., config_path, config_overrides, custom params) + env_kwargs: dict[str, Any] = field(default_factory=dict) + # Explicit consent to execute remote code from the Hub (required for hub environments). + trust_remote_code: bool = False def __post_init__(self) -> None: # HACK: We parse again the cli args here to get the pretrained path if there was one. @@ -52,13 +57,21 @@ def __post_init__(self) -> None: "No pretrained path was provided, evaluated policy will be built from scratch (random weights)." ) + # Parse env_kwargs from CLI (e.g., --env_kwargs.headless=true) + env_kwargs_overrides = parser.get_cli_overrides("env_kwargs") + if env_kwargs_overrides: + for arg in env_kwargs_overrides: + # arg format: "--key=value" + key, value = arg.removeprefix("--").split("=", 1) + self.env_kwargs[key] = value + if not self.job_name: if self.env is None: self.job_name = f"{self.policy.type if self.policy is not None else 'scratch'}" else: - self.job_name = ( - f"{self.env.type}_{self.policy.type if self.policy is not None else 'scratch'}" - ) + # Added for hub environments + env_type = self.env.type if isinstance(self.env, envs.EnvConfig) else self.env.split("/")[-1] + self.job_name = f"{env_type}_{self.policy.type if self.policy is not None else 'scratch'}" logger.warning(f"No job name provided, using '{self.job_name}' as job name.") diff --git a/src/lerobot/envs/__init__.py b/src/lerobot/envs/__init__.py index d767b6e8cc..42407649f6 100644 --- a/src/lerobot/envs/__init__.py +++ b/src/lerobot/envs/__init__.py @@ -13,3 +13,4 @@ # limitations under the License. 
from .configs import AlohaEnv, EnvConfig, PushtEnv # noqa: F401 +from .isaaclab import IsaacLabEnvWrapper # noqa: F401 diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py index 4323f33166..b119608796 100644 --- a/src/lerobot/envs/configs.py +++ b/src/lerobot/envs/configs.py @@ -47,6 +47,7 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC): features_map: dict[str, str] = field(default_factory=dict) max_parallel_tasks: int = 1 disable_env_checker: bool = True + hub_path: str | None = None @property def type(self) -> str: @@ -368,3 +369,61 @@ def gym_kwargs(self) -> dict: "obs_type": self.obs_type, "render_mode": self.render_mode, } + + +@EnvConfig.register_subclass("isaaclab_arena") +@dataclass +class IsaaclabArenaEnv(EnvConfig): + hub_path: str = "nvkartik/isaaclab-arena-envs" + episode_length: int = 300 + # num_envs: int = 1 + embodiment: str | None = "gr1_pink" + object: str | None = "power_drill" + mimic: bool = False + teleop_device: str | None = None + seed: int | None = 42 + device: str | None = "cuda:0" + disable_fabric: bool = False + enable_cameras: bool = False + headless: bool = False + enable_pinocchio: bool = True + environment: str | None = "gr1_microwave" + task: str | None = "Reach out to the microwave and open it." + state_dim: int = 54 + action_dim: int = 36 + camera_height: int = 512 + camera_width: int = 512 + video: bool = False + video_length: int = 100 + video_interval: int = 200 + # Comma-separated keys, e.g., "robot_joint_pos,left_eef_pos" + state_keys: str = "robot_joint_pos" + # Comma-separated keys, e.g., "robot_pov_cam_rgb,front_cam_rgb" + # Set to None or "" for environments without cameras + camera_keys: str | None = None + features: dict[str, PolicyFeature] = field(default_factory=dict) + features_map: dict[str, str] = field(default_factory=dict) + + def __post_init__(self): + # Set action feature + self.features[ACTION] = PolicyFeature(type=FeatureType.ACTION, shape=(self.action_dim,)) + self.features_map[ACTION] = ACTION + + # Set state feature + self.features[OBS_STATE] = PolicyFeature(type=FeatureType.STATE, shape=(self.state_dim,)) + self.features_map[OBS_STATE] = OBS_STATE + + # Add camera features for each camera key + if self.enable_cameras and self.camera_keys: + for cam_key in self.camera_keys.split(","): + cam_key = cam_key.strip() + if cam_key: + self.features[cam_key] = PolicyFeature( + type=FeatureType.VISUAL, + shape=(self.camera_height, self.camera_width, 3), + ) + self.features_map[cam_key] = f"{OBS_IMAGES}.{cam_key}" + + @property + def gym_kwargs(self) -> dict: + return {} diff --git a/src/lerobot/envs/factory.py b/src/lerobot/envs/factory.py index b39cfee718..c83005a9ec 100644 --- a/src/lerobot/envs/factory.py +++ b/src/lerobot/envs/factory.py @@ -20,11 +20,11 @@ from gymnasium.envs.registration import registry as gym_registry from lerobot.configs.policies import PreTrainedConfig -from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv +from lerobot.envs.configs import AlohaEnv, EnvConfig, IsaaclabArenaEnv, LiberoEnv, PushtEnv from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result from lerobot.policies.xvla.configuration_xvla import XVLAConfig from lerobot.processor import ProcessorStep -from lerobot.processor.env_processor import LiberoProcessorStep +from lerobot.processor.env_processor import IsaaclabArenaProcessorStep, LiberoProcessorStep from lerobot.processor.pipeline import PolicyProcessorPipeline @@ -73,6 +73,26 @@ def 
make_env_pre_post_processors( if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type: preprocessor_steps.append(LiberoProcessorStep()) + # For Isaaclab Arena environments, add the IsaaclabArenaProcessorStep + if isinstance(env_cfg, IsaaclabArenaEnv) or "isaaclab_arena" in env_cfg.type: + # Parse comma-separated keys (handle None for state-based policies) + if env_cfg.state_keys: + state_keys = tuple(k.strip() for k in env_cfg.state_keys.split(",") if k.strip()) + else: + state_keys = () + if env_cfg.camera_keys: + camera_keys = tuple(k.strip() for k in env_cfg.camera_keys.split(",") if k.strip()) + else: + camera_keys = () + if not state_keys and not camera_keys: + raise ValueError("At least one of state_keys or camera_keys must be specified.") + preprocessor_steps.append( + IsaaclabArenaProcessorStep( + state_keys=state_keys, + camera_keys=camera_keys, + ) + ) + preprocessor = PolicyProcessorPipeline(steps=preprocessor_steps) postprocessor = PolicyProcessorPipeline(steps=postprocessor_steps) @@ -85,6 +105,7 @@ def make_env( use_async_envs: bool = False, hub_cache_dir: str | None = None, trust_remote_code: bool = False, + **kwargs, ) -> dict[str, dict[int, gym.vector.VectorEnv]]: """Makes a gym vector environment according to the config or Hub reference. @@ -98,7 +119,8 @@ def make_env( hub_cache_dir (str | None): Optional cache path for downloaded hub files. trust_remote_code (bool): **Explicit consent** to execute remote code from the Hub. Default False — must be set to True to import/exec hub `env.py`. - + **kwargs: Additional keyword arguments passed to the hub environment's `make_env` function. + Useful for passing custom configurations like `config_path`, `config_overrides`, etc. Raises: ValueError: if n_envs < 1 ModuleNotFoundError: If the requested env package is not installed @@ -113,18 +135,32 @@ def make_env( # if user passed a hub id string (e.g., "username/repo", "username/repo@main:env.py") # simplified: only support hub-provided `make_env` if isinstance(cfg, str): + hub_path: str | None = cfg + else: + hub_path = cfg.hub_path + + # If hub_path is set, download and call hub-provided `make_env` + if hub_path: # _download_hub_file will raise the same RuntimeError if trust_remote_code is False - repo_id, file_path, local_file, revision = _download_hub_file(cfg, trust_remote_code, hub_cache_dir) + repo_id, file_path, local_file, revision = _download_hub_file( + hub_path, trust_remote_code, hub_cache_dir + ) # import and surface clear import errors module = _import_hub_module(local_file, repo_id) # call the hub-provided make_env - raw_result = _call_make_env(module, n_envs=n_envs, use_async_envs=use_async_envs) + env_cfg = None if isinstance(cfg, str) else cfg + raw_result = _call_make_env( + module, n_envs=n_envs, use_async_envs=use_async_envs, cfg=env_cfg, **kwargs + ) # normalize the return into {suite: {task_id: vec_env}} return _normalize_hub_result(raw_result) + # At this point, cfg must be an EnvConfig (not a string) since hub_path would have been set otherwise + assert not isinstance(cfg, str), "cfg should be EnvConfig at this point" + if n_envs < 1: raise ValueError("`n_envs` must be at least 1") diff --git a/src/lerobot/envs/isaaclab.py b/src/lerobot/envs/isaaclab.py new file mode 100644 index 0000000000..8d34fda410 --- /dev/null +++ b/src/lerobot/envs/isaaclab.py @@ -0,0 +1,223 @@ +from __future__ import annotations + +import atexit +import logging +import os +import signal +from contextlib import suppress +from typing import Any + +import gymnasium as gym 
+import numpy as np +import torch + +from lerobot.utils.errors import IsaacLabArenaError + + +def cleanup_isaaclab(env, simulation_app) -> None: + """Cleanup IsaacLab env and simulation app resources.""" + # Ignore signals during cleanup to prevent interruption + old_sigint = signal.signal(signal.SIGINT, signal.SIG_IGN) + old_sigterm = signal.signal(signal.SIGTERM, signal.SIG_IGN) + try: + with suppress(Exception): + if env is not None: + env.close() + with suppress(Exception): + if simulation_app is not None: + simulation_app.app.close() + finally: + # Restore signal handlers + signal.signal(signal.SIGINT, old_sigint) + signal.signal(signal.SIGTERM, old_sigterm) + + +class IsaacLabEnvWrapper(gym.vector.AsyncVectorEnv): + """Wrapper adapting IsaacLab batched GPU env to AsyncVectorEnv. + IsaacLab handles vectorization internally on GPU. We inherit from + AsyncVectorEnv for compatibility with LeRobot.""" + + metadata = {"render_modes": ["rgb_array"], "render_fps": 30} + _cleanup_in_progress = False # Class-level flag for re-entrant protection + + def __init__( + self, + env, + episode_length: int = 500, + task: str | None = None, + render_mode: str | None = "rgb_array", + simulation_app=None, + ): + self._env = env + self._num_envs = env.num_envs + self._episode_length = episode_length + self._closed = False + self.render_mode = render_mode + self._simulation_app = simulation_app + + self.observation_space = env.observation_space + self.action_space = env.action_space + self.single_observation_space = env.observation_space + self.single_action_space = env.action_space + self.task = task + + if hasattr(env, "metadata") and env.metadata: + self.metadata = {**self.metadata, **env.metadata} + + # Register cleanup handlers + atexit.register(self._cleanup) + signal.signal(signal.SIGINT, self._signal_handler) + signal.signal(signal.SIGTERM, self._signal_handler) + + def _signal_handler(self, signum, frame): + if IsaacLabEnvWrapper._cleanup_in_progress: + return # Prevent re-entrant cleanup + IsaacLabEnvWrapper._cleanup_in_progress = True + logging.info(f"Received signal {signum}, cleaning up...") + self._cleanup() + # Exit without raising to avoid propagating through callbacks + os._exit(0) + + def _check_closed(self): + if self._closed: + raise IsaacLabArenaError() + + @property + def unwrapped(self): + return self + + @property + def num_envs(self) -> int: + return self._num_envs + + @property + def _max_episode_steps(self) -> int: + return self._episode_length + + @property + def device(self) -> str: + return getattr(self._env, "device", "cpu") + + def reset( + self, + *, + seed: int | list[int] | None = None, + options: dict[str, Any] | None = None, + ) -> tuple[dict[str, Any], dict[str, Any]]: + self._check_closed() + if isinstance(seed, (list, tuple, range)): + seed = seed[0] if len(seed) > 0 else None + + obs, info = self._env.reset(seed=seed, options=options) + + if "final_info" not in info: + zeros = np.zeros(self._num_envs, dtype=bool) + info["final_info"] = {"is_success": zeros} + + return obs, info + + def step( + self, actions: np.ndarray | torch.Tensor + ) -> tuple[dict, np.ndarray, np.ndarray, np.ndarray, dict]: + self._check_closed() + if isinstance(actions, np.ndarray): + actions = torch.from_numpy(actions).to(self._env.device) + + obs, reward, terminated, truncated, info = self._env.step(actions) + + # Convert to numpy for gym compatibility + reward = reward.cpu().numpy().astype(np.float32) + terminated = terminated.cpu().numpy().astype(bool) + truncated = 
truncated.cpu().numpy().astype(bool) + + is_success = self._get_success(terminated, truncated) + info["final_info"] = {"is_success": is_success} + + return obs, reward, terminated, truncated, info + + def _get_success(self, terminated: np.ndarray, truncated: np.ndarray) -> np.ndarray: + is_success = np.zeros(self._num_envs, dtype=bool) + + if not hasattr(self._env, "termination_manager"): + return is_success & (terminated | truncated) + + term_manager = self._env.termination_manager + if not hasattr(term_manager, "get_term"): + return is_success & (terminated | truncated) + + success_tensor = term_manager.get_term("success") + if success_tensor is None: + return is_success & (terminated | truncated) + + is_success = success_tensor.cpu().numpy().astype(bool) + + return is_success & (terminated | truncated) + + def call(self, method_name: str, *args, **kwargs) -> list[Any]: + if method_name == "_max_episode_steps": + return [self._episode_length] * self._num_envs + if method_name == "task": + return [self.task] * self._num_envs + if method_name == "render": + return self.render_all() + + if hasattr(self._env, method_name): + attr = getattr(self._env, method_name) + result = attr(*args, **kwargs) if callable(attr) else attr + if isinstance(result, list): + return result + return [result] * self._num_envs + + raise AttributeError(f"IsaacLab-Arena has no method/attribute '{method_name}'") + + def render_all(self) -> list[np.ndarray]: + self._check_closed() + frames = self.render() + if frames is None: + placeholder = np.zeros((480, 640, 3), dtype=np.uint8) + return [placeholder] * self._num_envs + + if frames.ndim == 4: + return [frames[i] for i in range(min(len(frames), self._num_envs))] + + return [np.zeros((480, 640, 3), dtype=np.uint8)] * self._num_envs + + def render(self) -> np.ndarray | None: + """Render all environments and return list of frames.""" + self._check_closed() + if self.render_mode != "rgb_array": + return None + + frames = self._env.render() if hasattr(self._env, "render") else None + if frames is None: + return None + + if isinstance(frames, torch.Tensor): + frames = frames.cpu().numpy() + + return frames[0] if frames.ndim == 4 else frames + + def _cleanup(self) -> None: + if self._closed: + return + self._closed = True + IsaacLabEnvWrapper._cleanup_in_progress = True + logging.info("Cleaning up IsaacLab Arena environment...") + cleanup_isaaclab(self._env, self._simulation_app) + + def close(self) -> None: + self._cleanup() + + @property + def envs(self) -> list[IsaacLabEnvWrapper]: + return [self] * self._num_envs + + def __del__(self): + self._cleanup() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self._cleanup() + return False diff --git a/src/lerobot/envs/utils.py b/src/lerobot/envs/utils.py index 8d0f249221..082df83319 100644 --- a/src/lerobot/envs/utils.py +++ b/src/lerobot/envs/utils.py @@ -98,6 +98,14 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten if "robot_state" in observations: return_observations[f"{OBS_STR}.robot_state"] = _convert_nested_dict(observations["robot_state"]) + + # Handle IsaacLab Arena format: observations have 'policy' and 'camera_obs' keys + if "policy" in observations: + return_observations[f"{OBS_STR}.policy"] = observations["policy"] + + if "camera_obs" in observations: + return_observations[f"{OBS_STR}.camera_obs"] = observations["camera_obs"] + return return_observations @@ -302,16 +310,22 @@ def _import_hub_module(local_file: str, repo_id: str) -> Any: 
     return module


-def _call_make_env(module: Any, n_envs: int, use_async_envs: bool) -> Any:
+def _call_make_env(
+    module: Any, n_envs: int, use_async_envs: bool, cfg: EnvConfig | None, **kwargs: Any
+) -> Any:
     """
     Ensure module exposes make_env and call it.
     """
     if not hasattr(module, "make_env"):
         raise AttributeError(
-            f"The hub module {getattr(module, '__name__', 'hub_module')} must expose `make_env(n_envs=int, use_async_envs=bool)`."
+            f"The hub module {getattr(module, '__name__', 'hub_module')} must expose `make_env(n_envs=int, use_async_envs=bool, cfg: EnvConfig | None, **kwargs)`."
         )
     entry_fn = module.make_env
-    return entry_fn(n_envs=n_envs, use_async_envs=use_async_envs)
+    # Only pass cfg if it's not None (i.e., when an EnvConfig was provided, not a string hub ID)
+    if cfg is not None:
+        return entry_fn(n_envs=n_envs, use_async_envs=use_async_envs, cfg=cfg, **kwargs)
+    else:
+        return entry_fn(n_envs=n_envs, use_async_envs=use_async_envs, **kwargs)


 def _normalize_hub_result(result: Any) -> dict[str, dict[int, gym.vector.VectorEnv]]:
diff --git a/src/lerobot/processor/env_processor.py b/src/lerobot/processor/env_processor.py
index b1872b0328..bc69d9f95a 100644
--- a/src/lerobot/processor/env_processor.py
+++ b/src/lerobot/processor/env_processor.py
@@ -18,7 +18,7 @@
 import torch

 from lerobot.configs.types import PipelineFeatureType, PolicyFeature
-from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
+from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR

 from .pipeline import ObservationProcessorStep, ProcessorStepRegistry

@@ -152,3 +152,78 @@ def _quat2axisangle(self, quat: torch.Tensor) -> torch.Tensor:
         result[mask] = axis * angle.unsqueeze(1)

         return result
+
+
+@dataclass
+@ProcessorStepRegistry.register(name="isaaclab_arena_processor")
+class IsaaclabArenaProcessorStep(ObservationProcessorStep):
+    """
+    Processes IsaacLab Arena observations into LeRobot format.
+
+    **State Processing:**
+    - Extracts state components from obs["policy"] based on `state_keys`.
+    - Concatenates into a flat vector mapped to "observation.state".
+
+    **Image Processing:**
+    - Extracts images from obs["camera_obs"] based on `camera_keys`.
+    - Converts from (B, H, W, C) uint8 to (B, C, H, W) float32 [0, 1].
+    - Maps to "observation.images.<camera_key>".
+    """
+
+    # Configurable from IsaacLabEnv config / cli args: --env.state_keys="robot_joint_pos,left_eef_pos"
+    state_keys: tuple[str, ...] = ("robot_joint_pos",)
+
+    # Configurable from IsaacLabEnv config / cli args: --env.camera_keys="robot_pov_cam_rgb"
+    camera_keys: tuple[str, ...] = ("robot_pov_cam_rgb",)
+
+    def _process_observation(self, observation):
+        """
+        Processes both image and policy state observations from IsaacLab Arena.
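+
+        Illustrative mapping, assuming the default keys and dims above (not fixed by the env):
+            {"observation.policy": {"robot_joint_pos": (B, 54)},
+             "observation.camera_obs": {"robot_pov_cam_rgb": (B, H, W, 3) uint8}}
+            -> {"observation.state": (B, 54) float32,
+                "observation.images.robot_pov_cam_rgb": (B, 3, H, W) float32 in [0, 1]}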
+ """ + processed_obs = {} + + if f"{OBS_STR}.camera_obs" in observation: + camera_obs = observation[f"{OBS_STR}.camera_obs"] + + for cam_name, img in camera_obs.items(): + if cam_name not in self.camera_keys: + continue + + img = img.permute(0, 3, 1, 2).contiguous() + if img.dtype == torch.uint8: + img = img.float() / 255.0 + elif img.dtype != torch.float32: + img = img.float() + + processed_obs[f"{OBS_IMAGES}.{cam_name}"] = img + + # Process policy state -> observation.state + if f"{OBS_STR}.policy" in observation: + policy_obs = observation[f"{OBS_STR}.policy"] + + # Collect state components in order + state_components = [] + for key in self.state_keys: + if key in policy_obs: + component = policy_obs[key] + # Flatten extra dims: (B, N, M) -> (B, N*M) + if component.dim() > 2: + batch_size = component.shape[0] + component = component.view(batch_size, -1) + state_components.append(component) + + if state_components: + state = torch.cat(state_components, dim=-1) + state = state.float() + processed_obs[OBS_STATE] = state + + return processed_obs + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """Not used for policy evaluation.""" + return features + + def observation(self, observation): + return self._process_observation(observation) diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py index d23b9d083a..0014c9694d 100644 --- a/src/lerobot/scripts/lerobot_eval.py +++ b/src/lerobot/scripts/lerobot_eval.py @@ -43,6 +43,17 @@ Note that in both examples, the repo/folder should contain at least `config.json` and `model.safetensors` files. +You can also evaluate a model on a Hub environment with custom kwargs: +``` +lerobot-eval \ + --policy.path=HF_USER/HF_REPO \ + --env=HF_USER/HF_REPO \ + --eval.batch_size=1 \ + --eval.n_episodes=10 \ + --env_kwargs.environment=env_A \ + --env_kwargs.embodiment=emb_B \ +``` + You can learn about the CLI options for this script in the `EvalPipelineConfig` in lerobot/configs/eval.py """ @@ -509,7 +520,13 @@ def eval_main(cfg: EvalPipelineConfig): logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}") logging.info("Making environment.") - envs = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs) + envs = make_env( + cfg.env, + n_envs=cfg.eval.batch_size, + use_async_envs=cfg.eval.use_async_envs, + trust_remote_code=cfg.trust_remote_code, + **cfg.env_kwargs, + ) logging.info("Making policy.") diff --git a/src/lerobot/utils/errors.py b/src/lerobot/utils/errors.py index 31b73eacab..6d04accfac 100644 --- a/src/lerobot/utils/errors.py +++ b/src/lerobot/utils/errors.py @@ -30,3 +30,35 @@ def __init__( ): self.message = message super().__init__(self.message) + + +class IsaacLabArenaError(RuntimeError): + """Base exception for IsaacLab Arena environment errors.""" + + def __init__(self, message: str = "IsaacLab Arena error"): + self.message = message + super().__init__(self.message) + + +class IsaacLabArenaConfigError(IsaacLabArenaError): + """Exception raised for invalid environment configuration.""" + + def __init__(self, invalid: list, available: list, key_type: str = "keys"): + msg = f"Invalid {key_type}: {invalid}. 
Available: {sorted(available)}" + super().__init__(msg) + self.invalid = invalid + self.available = available + + +class IsaacLabArenaCameraKeyError(IsaacLabArenaConfigError): + """Exception raised when camera_keys don't match available cameras.""" + + def __init__(self, invalid: list, available: list): + super().__init__(invalid, available, "camera_keys") + + +class IsaacLabArenaStateKeyError(IsaacLabArenaConfigError): + """Exception raised when state_keys don't match available state terms.""" + + def __init__(self, invalid: list, available: list): + super().__init__(invalid, available, "state_keys") diff --git a/tests/envs/test_envs.py b/tests/envs/test_envs.py index 910c275eb4..853fdb3240 100644 --- a/tests/envs/test_envs.py +++ b/tests/envs/test_envs.py @@ -15,6 +15,7 @@ # limitations under the License. import importlib from dataclasses import dataclass, field +from unittest.mock import MagicMock import gymnasium as gym import numpy as np @@ -27,6 +28,7 @@ from lerobot.configs.types import PolicyFeature from lerobot.envs.configs import EnvConfig from lerobot.envs.factory import make_env, make_env_config +from lerobot.envs.isaaclab import IsaacLabEnvWrapper from lerobot.envs.utils import ( _normalize_hub_result, _parse_hub_url, @@ -266,3 +268,152 @@ def test_make_env_from_hub_async(): # clean up env.close() + + +# IsaacLabEnvWrapper tests (mock-based without installing IsaacLab) + + +def _create_mock_isaaclab_env(num_envs: int = 2, device: str = "cpu"): + """Create a mock IsaacLab environment for testing.""" + mock_env = MagicMock() + mock_env.num_envs = num_envs + mock_env.device = device + mock_env.observation_space = gym.spaces.Dict( + {"policy": gym.spaces.Box(low=-1, high=1, shape=(num_envs, 54), dtype=np.float32)} + ) + mock_env.action_space = gym.spaces.Box(low=-1, high=1, shape=(36,), dtype=np.float32) + mock_env.metadata = {} + return mock_env + + +def test_isaaclab_wrapper_init(): + """Test IsaacLabEnvWrapper initialization.""" + mock_env = _create_mock_isaaclab_env(num_envs=4) + + wrapper = IsaacLabEnvWrapper( + mock_env, + episode_length=300, + task="Test task", + render_mode="rgb_array", + ) + + assert wrapper.num_envs == 4 + assert wrapper._max_episode_steps == 300 + assert wrapper.task == "Test task" + assert wrapper.render_mode == "rgb_array" + assert wrapper.device == "cpu" + assert len(wrapper.envs) == 4 + + +def test_isaaclab_wrapper_reset(): + """Test IsaacLabEnvWrapper reset.""" + mock_env = _create_mock_isaaclab_env(num_envs=2) + mock_obs = {"policy": torch.randn(2, 54)} + mock_env.reset.return_value = (mock_obs, {}) + + wrapper = IsaacLabEnvWrapper(mock_env, episode_length=100) + obs, info = wrapper.reset(seed=42) + + mock_env.reset.assert_called_once_with(seed=42, options=None) + assert "final_info" in info + assert "is_success" in info["final_info"] + assert len(info["final_info"]["is_success"]) == 2 + + +def test_isaaclab_wrapper_reset_with_seed_list(): + """Test that seed list is handled correctly (IsaacLab expects single seed).""" + mock_env = _create_mock_isaaclab_env(num_envs=2) + mock_env.reset.return_value = ({"policy": torch.randn(2, 54)}, {}) + + wrapper = IsaacLabEnvWrapper(mock_env) + wrapper.reset(seed=[42, 43, 44]) + + # Should extract first seed + mock_env.reset.assert_called_once_with(seed=42, options=None) + + +def test_isaaclab_wrapper_step(): + """Test IsaacLabEnvWrapper step.""" + mock_env = _create_mock_isaaclab_env(num_envs=2) + mock_env.step.return_value = ( + {"policy": torch.randn(2, 54)}, + torch.tensor([0.5, 0.3]), + 
torch.tensor([False, False]), + torch.tensor([False, True]), + {}, + ) + # Mock termination manager + mock_env.termination_manager.get_term.return_value = torch.tensor([False, True]) + + wrapper = IsaacLabEnvWrapper(mock_env) + actions = np.random.randn(2, 36).astype(np.float32) + obs, reward, terminated, truncated, info = wrapper.step(actions) + + assert reward.dtype == np.float32 + assert terminated.dtype == bool + assert truncated.dtype == bool + assert len(reward) == 2 + assert "final_info" in info + assert "is_success" in info["final_info"] + + +def test_isaaclab_wrapper_call_method(): + """Test IsaacLabEnvWrapper call method.""" + mock_env = _create_mock_isaaclab_env(num_envs=3) + + wrapper = IsaacLabEnvWrapper(mock_env, episode_length=200, task="My task") + + # Test _max_episode_steps + result = wrapper.call("_max_episode_steps") + assert result == [200, 200, 200] + + # Test task + result = wrapper.call("task") + assert result == ["My task", "My task", "My task"] + + +def test_isaaclab_wrapper_render(): + """Test IsaacLabEnvWrapper render.""" + mock_env = _create_mock_isaaclab_env(num_envs=2) + mock_frames = torch.randint(0, 255, (2, 480, 640, 3), dtype=torch.uint8) + mock_env.render.return_value = mock_frames + + wrapper = IsaacLabEnvWrapper(mock_env, render_mode="rgb_array") + frame = wrapper.render() + + assert frame is not None + assert frame.shape == (480, 640, 3) # Returns first env frame + + +def test_isaaclab_wrapper_render_all(): + """Test IsaacLabEnvWrapper render_all.""" + mock_env = _create_mock_isaaclab_env(num_envs=2) + mock_frames = torch.randint(0, 255, (2, 480, 640, 3), dtype=torch.uint8) + mock_env.render.return_value = mock_frames + + wrapper = IsaacLabEnvWrapper(mock_env, render_mode="rgb_array") + frames = wrapper.render_all() + + assert len(frames) == 2 + assert all(f.shape == (480, 640, 3) for f in frames) + + +def test_isaaclab_wrapper_render_none(): + """Test render returns None when render_mode is not rgb_array.""" + mock_env = _create_mock_isaaclab_env() + + wrapper = IsaacLabEnvWrapper(mock_env, render_mode=None) + assert wrapper.render() is None + + +def test_isaaclab_wrapper_close(): + """Test IsaacLabEnvWrapper close.""" + mock_env = _create_mock_isaaclab_env() + mock_app = MagicMock() + + wrapper = IsaacLabEnvWrapper(mock_env, simulation_app=mock_app) + wrapper.close() + + mock_env.close.assert_called_once() + mock_app.app.close.assert_called_once() + assert wrapper._closed diff --git a/tests/processor/test_arena_processor.py b/tests/processor/test_arena_processor.py new file mode 100644 index 0000000000..c68eb1fb43 --- /dev/null +++ b/tests/processor/test_arena_processor.py @@ -0,0 +1,407 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
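+
+"""Unit tests for IsaaclabArenaProcessorStep: state-key extraction/concatenation and camera image conversion."""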
+ +import pytest +import torch + +from lerobot.configs.types import ( + FeatureType, + PipelineFeatureType, + PolicyFeature, +) +from lerobot.processor.env_processor import IsaaclabArenaProcessorStep +from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR + +# Test constants +BATCH_SIZE = 2 +STATE_DIM = 16 +IMG_HEIGHT = 64 +IMG_WIDTH = 64 + +# Generic test keys (not real robot keys) +TEST_STATE_KEY = "test_state_obs" +TEST_CAMERA_KEY = "test_rgb_cam" + + +@pytest.fixture +def processor(): + """Default processor with test keys.""" + return IsaaclabArenaProcessorStep( + state_keys=(TEST_STATE_KEY,), + camera_keys=(TEST_CAMERA_KEY,), + ) + + +@pytest.fixture +def sample_observation(): + """Sample IsaacLab Arena observation with state and camera data.""" + return { + f"{OBS_STR}.policy": { + TEST_STATE_KEY: torch.randn(BATCH_SIZE, STATE_DIM), + }, + f"{OBS_STR}.camera_obs": { + TEST_CAMERA_KEY: torch.randint(0, 255, (BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, 3), dtype=torch.uint8), + }, + } + + +# ============================================================================= +# State Processing Tests +# ============================================================================= + + +def test_state_extraction(processor, sample_observation): + """Test that state is extracted and converted to float32.""" + processed = processor.observation(sample_observation) + + assert OBS_STATE in processed + assert processed[OBS_STATE].shape == (BATCH_SIZE, STATE_DIM) + assert processed[OBS_STATE].dtype == torch.float32 + + +def test_state_concatenation_multiple_keys(): + """Test that multiple state keys are concatenated in order.""" + dim1, dim2 = 10, 6 + processor = IsaaclabArenaProcessorStep( + state_keys=("state_alpha", "state_beta"), + camera_keys=(), + ) + + obs = { + f"{OBS_STR}.policy": { + "state_alpha": torch.ones(BATCH_SIZE, dim1), + "state_beta": torch.ones(BATCH_SIZE, dim2) * 2, + }, + } + + processed = processor.observation(obs) + + state = processed[OBS_STATE] + assert state.shape == (BATCH_SIZE, dim1 + dim2) + # Verify ordering: first dim1 elements are 1s, last dim2 are 2s + assert torch.all(state[:, :dim1] == 1.0) + assert torch.all(state[:, dim1:] == 2.0) + + +def test_state_flattening_higher_dims(): + """Test that state with dim > 2 is flattened to (B, -1).""" + processor = IsaaclabArenaProcessorStep( + state_keys=("multidim_state",), + camera_keys=(), + ) + + # Shape (B, 4, 4) -> should flatten to (B, 16) + obs = { + f"{OBS_STR}.policy": { + "multidim_state": torch.randn(BATCH_SIZE, 4, 4), + }, + } + + processed = processor.observation(obs) + + assert processed[OBS_STATE].shape == (BATCH_SIZE, 16) + + +def test_state_filters_to_configured_keys(): + """Test that only configured state_keys are extracted.""" + processor = IsaaclabArenaProcessorStep( + state_keys=("included_key",), + camera_keys=(), + ) + + obs = { + f"{OBS_STR}.policy": { + "included_key": torch.randn(BATCH_SIZE, 10), + "excluded_key": torch.randn(BATCH_SIZE, 6), # Should be ignored + }, + } + + processed = processor.observation(obs) + + # Only included_key (dim 10) should be included + assert processed[OBS_STATE].shape == (BATCH_SIZE, 10) + + +def test_missing_state_key_skipped(): + """Test that missing state keys in observation are skipped.""" + processor = IsaaclabArenaProcessorStep( + state_keys=("present_key", "missing_key"), + camera_keys=(), + ) + + obs = { + f"{OBS_STR}.policy": { + "present_key": torch.randn(BATCH_SIZE, 10), + # missing_key not present + }, + } + + processed = processor.observation(obs) + + # 
Should only have present_key + assert processed[OBS_STATE].shape == (BATCH_SIZE, 10) + + +# ============================================================================= +# Camera/Image Processing Tests +# ============================================================================= + + +def test_camera_permutation_bhwc_to_bchw(processor, sample_observation): + """Test images are permuted from (B, H, W, C) to (B, C, H, W).""" + processed = processor.observation(sample_observation) + + img_key = f"{OBS_IMAGES}.{TEST_CAMERA_KEY}" + assert img_key in processed + img = processed[img_key] + assert img.shape == (BATCH_SIZE, 3, IMG_HEIGHT, IMG_WIDTH) + + +def test_camera_uint8_to_normalized_float32(processor): + """Test that uint8 images are normalized to float32 [0, 1].""" + obs = { + f"{OBS_STR}.camera_obs": { + TEST_CAMERA_KEY: torch.full((BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, 3), 255, dtype=torch.uint8), + }, + } + + processed = processor.observation(obs) + + img = processed[f"{OBS_IMAGES}.{TEST_CAMERA_KEY}"] + assert img.dtype == torch.float32 + assert torch.allclose(img, torch.ones_like(img)) + + +def test_camera_float32_passthrough(processor): + """Test that float32 images are kept as float32.""" + original_img = torch.rand(BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, 3, dtype=torch.float32) + obs = { + f"{OBS_STR}.camera_obs": { + TEST_CAMERA_KEY: original_img.clone(), + }, + } + + processed = processor.observation(obs) + + img = processed[f"{OBS_IMAGES}.{TEST_CAMERA_KEY}"] + assert img.dtype == torch.float32 + # Values should be same (just permuted) + expected = original_img.permute(0, 3, 1, 2) + assert torch.allclose(img, expected) + + +def test_camera_other_dtype_converted_to_float(processor): + """Test that non-uint8, non-float32 dtypes are converted to float.""" + obs = { + f"{OBS_STR}.camera_obs": { + TEST_CAMERA_KEY: torch.randint(0, 255, (BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, 3), dtype=torch.int32), + }, + } + + processed = processor.observation(obs) + + img = processed[f"{OBS_IMAGES}.{TEST_CAMERA_KEY}"] + assert img.dtype == torch.float32 + + +def test_camera_filters_to_configured_keys(): + """Test that only configured camera_keys are extracted.""" + processor = IsaaclabArenaProcessorStep( + state_keys=(), + camera_keys=("included_cam",), + ) + + obs = { + f"{OBS_STR}.camera_obs": { + "included_cam": torch.randint(0, 255, (BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, 3), dtype=torch.uint8), + "excluded_cam": torch.randint(0, 255, (BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, 3), dtype=torch.uint8), + }, + } + + processed = processor.observation(obs) + + assert f"{OBS_IMAGES}.included_cam" in processed + assert f"{OBS_IMAGES}.excluded_cam" not in processed + + +def test_camera_key_preserved_exactly(): + """Test that camera key name is used exactly (no suffix stripping).""" + processor = IsaaclabArenaProcessorStep( + state_keys=(), + camera_keys=("my_cam_rgb",), + ) + + obs = { + f"{OBS_STR}.camera_obs": { + "my_cam_rgb": torch.randint(0, 255, (BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, 3), dtype=torch.uint8), + }, + } + + processed = processor.observation(obs) + + # Key should be exactly as configured, with _rgb suffix intact + assert f"{OBS_IMAGES}.my_cam_rgb" in processed + assert f"{OBS_IMAGES}.my_cam" not in processed + + +# ============================================================================= +# Edge Cases & Missing Data Tests +# ============================================================================= + + +def test_missing_camera_obs_section(processor): + """Test processor handles observation without 
camera_obs section.""" + obs = { + f"{OBS_STR}.policy": { + TEST_STATE_KEY: torch.randn(BATCH_SIZE, STATE_DIM), + }, + } + + processed = processor.observation(obs) + + assert OBS_STATE in processed + assert not any(k.startswith(OBS_IMAGES) for k in processed) + + +def test_missing_policy_obs_section(processor): + """Test processor handles observation without policy section.""" + obs = { + f"{OBS_STR}.camera_obs": { + TEST_CAMERA_KEY: torch.randint(0, 255, (BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, 3), dtype=torch.uint8), + }, + } + + processed = processor.observation(obs) + + assert f"{OBS_IMAGES}.{TEST_CAMERA_KEY}" in processed + assert OBS_STATE not in processed + + +def test_empty_observation(processor): + """Test processor handles empty observation dict.""" + processed = processor.observation({}) + + assert OBS_STATE not in processed + assert not any(k.startswith(OBS_IMAGES) for k in processed) + + +def test_no_matching_state_keys(): + """Test processor when no state keys match observation.""" + processor = IsaaclabArenaProcessorStep( + state_keys=("nonexistent_key",), + camera_keys=(), + ) + + obs = { + f"{OBS_STR}.policy": { + "some_other_key": torch.randn(BATCH_SIZE, STATE_DIM), + }, + } + + processed = processor.observation(obs) + + # No state because no keys matched + assert OBS_STATE not in processed + + +def test_no_matching_camera_keys(): + """Test processor when no camera keys match observation.""" + processor = IsaaclabArenaProcessorStep( + state_keys=(), + camera_keys=("nonexistent_cam",), + ) + + obs = { + f"{OBS_STR}.camera_obs": { + "some_other_cam": torch.randint( + 0, 255, (BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, 3), dtype=torch.uint8 + ), + }, + } + + processed = processor.observation(obs) + + assert not any(k.startswith(OBS_IMAGES) for k in processed) + + +# ============================================================================= +# Configuration Tests +# ============================================================================= + + +def test_default_keys(): + """Test default state_keys and camera_keys values.""" + processor = IsaaclabArenaProcessorStep() + + assert processor.state_keys == ("robot_joint_pos",) + assert processor.camera_keys == ("robot_pov_cam_rgb",) + + +def test_custom_keys_configuration(): + """Test processor with custom state and camera keys.""" + processor = IsaaclabArenaProcessorStep( + state_keys=("pos_xyz", "quat_wxyz", "grip_val"), + camera_keys=("front_view", "wrist_view"), + ) + + obs = { + f"{OBS_STR}.policy": { + "pos_xyz": torch.randn(BATCH_SIZE, 3), + "quat_wxyz": torch.randn(BATCH_SIZE, 4), + "grip_val": torch.randn(BATCH_SIZE, 1), + }, + f"{OBS_STR}.camera_obs": { + "front_view": torch.randint(0, 255, (BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, 3), dtype=torch.uint8), + "wrist_view": torch.randint(0, 255, (BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, 3), dtype=torch.uint8), + }, + } + + processed = processor.observation(obs) + + # State should be concatenated: 3 + 4 + 1 = 8 + assert processed[OBS_STATE].shape == (BATCH_SIZE, 8) + # Both cameras should be present + assert f"{OBS_IMAGES}.front_view" in processed + assert f"{OBS_IMAGES}.wrist_view" in processed + + +# ============================================================================= +# transform_features Tests +# ============================================================================= + + +def test_transform_features_passthrough(processor): + """Test that transform_features returns features unchanged.""" + input_features = { + PipelineFeatureType.OBSERVATION: { + "observation.state": 
PolicyFeature( + type=FeatureType.STATE, + shape=(16,), + ), + "observation.images.cam": PolicyFeature( + type=FeatureType.VISUAL, + shape=(3, 64, 64), + ), + }, + PipelineFeatureType.ACTION: { + "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)), + }, + } + + output_features = processor.transform_features(input_features) + + # Should be unchanged + assert output_features == input_features diff --git a/tests/processor/test_pipeline.py b/tests/processor/test_pipeline.py index 134228c05c..58a83fe69b 100644 --- a/tests/processor/test_pipeline.py +++ b/tests/processor/test_pipeline.py @@ -17,7 +17,7 @@ import json import tempfile from collections.abc import Callable -from dataclasses import dataclass +from dataclasses import dataclass, field from pathlib import Path from typing import Any @@ -1884,7 +1884,7 @@ class FeatureContractAddStep(ProcessorStep): """Adds a PolicyFeature""" key: str = "a" - value: PolicyFeature = PolicyFeature(type=FeatureType.STATE, shape=(1,)) + value: PolicyFeature = field(default_factory=lambda: PolicyFeature(type=FeatureType.STATE, shape=(1,))) def __call__(self, transition: EnvTransition) -> EnvTransition: return transition