diff --git a/examples/jaxsim_walking.ipynb b/examples/jaxsim_walking.ipynb index 39b1677..6c57ab4 100644 --- a/examples/jaxsim_walking.ipynb +++ b/examples/jaxsim_walking.ipynb @@ -59,7 +59,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -115,14 +115,14 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# ==== Set simulation parameters ====\n", "\n", "T = 6.0\n", - "js_dt = 0.000_5" + "js_dt = 0.001" ] }, { @@ -148,7 +148,7 @@ "source": [ "# ==== Define JaxSim simulator and set initial position ====\n", "\n", - "js = JaxsimSimulator(dt=js_dt, contact_model_type=JaxsimContactModelEnum.RIGID)\n", + "js = JaxsimSimulator(dt=js_dt, contact_model_type=JaxsimContactModelEnum.RELAXED_RIGID)\n", "js.load_model(\n", " robot_model=robot_model_init,\n", " s=s_0,\n", @@ -165,7 +165,7 @@ "print(f\"Contact model in use: {js._model.contact_model}\")\n", "print(f\"Link names:\\n{js.link_names}\")\n", "print(f\"Frame names:\\n{js.frame_names}\")\n", - "print(f\"Mass: {js.total_mass*js._data.standard_gravity()} N\")" + "print(f\"Mass: {-js.total_mass*js._model.gravity} N\")" ] }, { @@ -197,11 +197,30 @@ "# Set desired quantities\n", "mpc.configure(s_init=s_0, H_b_init=H_b_0)\n", "tsid.compute_com_position()\n", - "mpc.define_test_com_traj(tsid.COM.toNumPy())\n", - "\n", + "mpc.define_test_com_traj(tsid.COM.toNumPy())" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ "# Set initial robot state and plan trajectories\n", + "\n", + "tic = time.perf_counter()\n", + "\n", "js.step(dry_run=True)\n", "\n", + "step_compilation_time_s = time.perf_counter() - tic" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ "# Reading the state\n", "s_js, ds_js, tau_js = js.get_state()\n", "H_b = js.base_transform\n", @@ -215,7 +234,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -245,6 +264,7 @@ " tau_tsid_log = []\n", " W_p_CoM_tsid_log = []\n", " t_log = []\n", + " wall_time_step_log = []\n", "\n", " # Define number of steps\n", " n_step_tsid_js = int(tsid.frequency / js_dt)\n", @@ -309,16 +329,20 @@ "\n", " # Step the simulator\n", " js.set_input(tau_tsid)\n", + "\n", + " tic = time.perf_counter()\n", " js.step(n_step=n_step_tsid_js)\n", + " toc = time.perf_counter() - tic\n", + "\n", " counter = counter + 1\n", "\n", " if counter == n_step_mpc_tsid:\n", " counter = 0\n", "\n", " # Stop the simulation if the robot fell down\n", - " if js._data.base_position()[2] < 0.5:\n", - " print(f\"Robot fell down at t={t:.4f}s.\")\n", - " break\n", + " # if js._data.base_position()[2] < 0.5:\n", + " # print(f\"Robot fell down at t={t:.4f}s.\")\n", + " # break\n", "\n", " # Log data\n", " # TODO transform mpc contact forces to wrenches to be compared with jaxsim ones\n", @@ -336,6 +360,7 @@ " W_p_lf_sfp_log.append(lf_sfp.transform.translation())\n", " W_p_rf_sfp_log.append(rf_sfp.transform.translation())\n", " W_p_CoM_tsid_log.append(tsid.COM.toNumPy())\n", + " wall_time_step_log.append(toc * 1e3)\n", " if contact_model_type != JaxsimContactModelEnum.VISCO_ELASTIC:\n", " f_lf_js, f_rf_js = js.feet_wrench\n", " f_lf_js_log.append(f_lf_js)\n", @@ -360,6 +385,7 @@ " \"W_p_lf_sfp\": np.array(W_p_lf_sfp_log),\n", " \"W_p_rf_sfp\": np.array(W_p_rf_sfp_log),\n", " \"W_p_CoM_tsid\": np.array(W_p_CoM_tsid_log),\n", + " \"wall_time_step\": np.array(wall_time_step_log),\n", " }\n", " if contact_model_type != JaxsimContactModelEnum.VISCO_ELASTIC:\n", " logs[\"f_lf_js\"] = np.array(f_lf_js_log)\n", @@ -385,19 +411,16 @@ "avg_iter_time_ms = (wall_time / (T / js_dt)) * 1000\n", "\n", "print(\n", - " f\"\\nRunning simulation took {wall_time:.2f}s for {T:.3f}s simulated time. \\nIteration avg time of {avg_iter_time_ms:.1f} ms.\"\n", - ")\n", - "print(f\"RTF: {T / wall_time * 100:.2f}%\")" + " f\"\\nSimulation done.\\nRunning simulation took {wall_time:.2f}s for {T:.3f}s simulated time.\"\n", + ")" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ - "# ==== Plot the results ====\n", - "\n", "# Extract logged variables\n", "t = logs[\"t\"]\n", "s_js = logs[\"s_js\"]\n", @@ -412,9 +435,45 @@ "W_p_lf_sfp = logs[\"W_p_lf_sfp\"]\n", "W_p_rf_sfp = logs[\"W_p_rf_sfp\"]\n", "W_p_CoM_tsid = logs[\"W_p_CoM_tsid\"]\n", + "wall_time_step = logs[\"wall_time_step\"]\n", "if js.contact_model_type != JaxsimContactModelEnum.VISCO_ELASTIC:\n", " f_lf_js = logs[\"f_lf_js\"]\n", - " f_rf_js = logs[\"f_rf_js\"]\n", + " f_rf_js = logs[\"f_rf_js\"]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Compute simulator step runtime statistics and RTF\n", + "\n", + "min_step_time = np.min(wall_time_step)\n", + "max_step_time = np.max(wall_time_step)\n", + "avg_step_time = np.mean(wall_time_step)\n", + "std_step_time = np.std(wall_time_step)\n", + "total_step_time = np.sum(wall_time_step)\n", + "rtf = (T * 1e3) / total_step_time * 100\n", + "\n", + "print(\"===========================================\")\n", + "print(f\"Step compilation time: {step_compilation_time_s:.2f} s\")\n", + "print(f\"Min step time: {min_step_time:.2f} ms\")\n", + "print(f\"Max step time: {max_step_time:.2f} ms\")\n", + "print(f\"Average step time: {avg_step_time:.2f} ms\")\n", + "print(f\"Std deviation step time: {std_step_time:.2f} ms\")\n", + "print(f\"RTF: {rtf:.1f}%\")\n", + "print(\"===========================================\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# ==== Plot the results ====\n", + "\n", "\n", "n_sim_steps = s_js.shape[0]\n", "s_0 = np.full_like(a=s_js, fill_value=s_0)\n", @@ -577,7 +636,7 @@ ], "metadata": { "kernelspec": { - "display_name": "comodo", + "display_name": "comododev", "language": "python", "name": "python3" }, @@ -591,7 +650,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.10.15" } }, "nbformat": 4, diff --git a/examples/mjx_walking.ipynb b/examples/mjx_walking.ipynb new file mode 100644 index 0000000..2bf2379 --- /dev/null +++ b/examples/mjx_walking.ipynb @@ -0,0 +1,327 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Mujoco MJX with TSID and MPC Example \n", + "This examples, load a basic robot model (i.e. composed only of basic shapes), modifies the links of such a robot model by elongating the legs, define instances of the TSID (Task Based Inverse Dynamics) and Centroidal MPC controller and simulate the behavior of the robot using MJX. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Comodo import\n", + "from comodo.mujocoSimulator.mjxSimulator import MJXSimulator\n", + "from comodo.robotModel.robotModel import RobotModel\n", + "from comodo.robotModel.createUrdf import createUrdf\n", + "from comodo.centroidalMPC.centroidalMPC import CentroidalMPC\n", + "from comodo.centroidalMPC.mpcParameterTuning import MPCParameterTuning\n", + "from comodo.TSIDController.TSIDParameterTuning import TSIDParameterTuning\n", + "from comodo.TSIDController.TSIDController import TSIDController\n", + "\n", + "import jax\n", + "\n", + "# Force JAX to use CPU\n", + "jax.config.update(\"jax_default_device\", jax.devices(\"cpu\")[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# General import\n", + "import xml.etree.ElementTree as ET\n", + "import numpy as np\n", + "import tempfile\n", + "import urllib.request\n", + "import os\n", + "import pathlib" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Getting stickbot urdf file\n", + "\n", + "path = pathlib.Path.cwd() / \"models\" / \"stickbot_mjx.urdf\"\n", + "# Load the URDF file\n", + "# tree = ET.parse(urdf_robot_file.name)\n", + "tree = ET.parse(path)\n", + "root = tree.getroot()\n", + "\n", + "# Convert the XML tree to a string\n", + "robot_urdf_string_original = ET.tostring(root)\n", + "\n", + "# create_urdf_instance = createUrdf(\n", + "# original_urdf_path=urdf_robot_file.name, save_gazebo_plugin=False\n", + "# )\n", + "create_urdf_instance = createUrdf(original_urdf_path=path, save_gazebo_plugin=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# Define parametric links and controlled joints\n", + "legs_link_names = [\"hip_3\", \"lower_leg\"]\n", + "joint_name_list = [\n", + " \"r_shoulder_pitch\",\n", + " \"r_shoulder_roll\",\n", + " \"r_shoulder_yaw\",\n", + " \"r_elbow\",\n", + " \"l_shoulder_pitch\",\n", + " \"l_shoulder_roll\",\n", + " \"l_shoulder_yaw\",\n", + " \"l_elbow\",\n", + " \"r_hip_pitch\",\n", + " \"r_hip_roll\",\n", + " \"r_hip_yaw\",\n", + " \"r_knee\",\n", + " \"r_ankle_pitch\",\n", + " \"r_ankle_roll\",\n", + " \"l_hip_pitch\",\n", + " \"l_hip_roll\",\n", + " \"l_hip_yaw\",\n", + " \"l_knee\",\n", + " \"l_ankle_pitch\",\n", + " \"l_ankle_roll\",\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "# Define the robot modifications\n", + "modifications = {}\n", + "for item in legs_link_names:\n", + " left_leg_item = \"l_\" + item\n", + " right_leg_item = \"r_\" + item\n", + " modifications.update({left_leg_item: 1.2})\n", + " modifications.update({right_leg_item: 1.2})\n", + "# Motors Parameters\n", + "Im_arms = 1e-3 * np.ones(4) # from 0-4\n", + "Im_legs = 1e-3 * np.ones(6) # from 5-10\n", + "kv_arms = 0.001 * np.ones(4) # from 11-14\n", + "kv_legs = 0.001 * np.ones(6) # from 20\n", + "\n", + "Im = np.concatenate((Im_arms, Im_arms, Im_legs, Im_legs))\n", + "kv = np.concatenate((kv_arms, kv_arms, kv_legs, kv_legs))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Modify the robot model and initialize\n", + "create_urdf_instance.modify_lengths(modifications)\n", + "urdf_robot_string = create_urdf_instance.write_urdf_to_file()\n", + "create_urdf_instance.reset_modifications()\n", + "robot_model_init = RobotModel(urdf_robot_string, \"stickBot\", joint_name_list)\n", + "s_des, xyz_rpy, H_b = robot_model_init.compute_desired_position_walking()\n", + "robot_model_init.set_foot_corner(\n", + " np.asarray([0.1, 0.05, 0.0]),\n", + " np.asarray([0.1, -0.05, 0.0]),\n", + " np.asarray([-0.1, -0.05, 0.0]),\n", + " np.asarray([-0.1, 0.05, 0.0]),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "# Define simulator and set initial position\n", + "mujoco_instance = MJXSimulator()\n", + "mujoco_instance.load_model(\n", + " robot_model_init, s=s_des, xyz_rpy=xyz_rpy, kv_motors=kv, Im=Im\n", + ")\n", + "s, ds, tau = mujoco_instance.get_state()\n", + "t = mujoco_instance.get_simulation_time()\n", + "H_b = mujoco_instance.get_base()\n", + "w_b = mujoco_instance.get_base_velocity()\n", + "mujoco_instance.set_visualize_robot_flag(True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define the controller parameters and instantiate the controller\n", + "# Controller Parameters\n", + "tsid_parameter = TSIDParameterTuning()\n", + "mpc_parameters = MPCParameterTuning()\n", + "\n", + "\n", + "## Controller gains\n", + "# Controller Parameters\n", + "tsid_parameter = TSIDParameterTuning()\n", + "tsid_parameter.foot_tracking_task_kp_lin = 150.0\n", + "tsid_parameter.foot_tracking_task_kd_lin = 40.0\n", + "tsid_parameter.root_tracking_task_weight = np.ones(3) * 50.0\n", + "\n", + "# TSID Instance\n", + "TSID_controller_instance = TSIDController(frequency=0.01, robot_model=robot_model_init)\n", + "TSID_controller_instance.define_tasks(tsid_parameter)\n", + "TSID_controller_instance.set_state_with_base(s, ds, H_b, w_b, t)\n", + "\n", + "# MPC Instance\n", + "step_lenght = 0.1\n", + "mpc = CentroidalMPC(robot_model=robot_model_init, step_length=step_lenght)\n", + "mpc.intialize_mpc(mpc_parameters=mpc_parameters)\n", + "\n", + "# Set desired quantities\n", + "mpc.configure(s_init=s_des, H_b_init=H_b)\n", + "TSID_controller_instance.compute_com_position()\n", + "mpc.define_test_com_traj(TSID_controller_instance.COM.toNumPy())\n", + "\n", + "# Set initial robot state and plan trajectories\n", + "mujoco_instance.step(1)\n", + "\n", + "# Reading the state\n", + "s, ds, tau = mujoco_instance.get_state()\n", + "H_b = mujoco_instance.get_base()\n", + "w_b = mujoco_instance.get_base_velocity()\n", + "t = mujoco_instance.get_simulation_time()\n", + "\n", + "# MPC\n", + "mpc.set_state_with_base(s=s, s_dot=ds, H_b=H_b, w_b=w_b, t=t)\n", + "mpc.initialize_centroidal_integrator(s=s, s_dot=ds, H_b=H_b, w_b=w_b, t=t)\n", + "mpc_output = mpc.plan_trajectory()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "# Set loop variables\n", + "TIME_TH = 6.0\n", + "\n", + "# Define number of steps\n", + "n_step = int(\n", + " TSID_controller_instance.frequency / mujoco_instance.get_simulation_frequency()\n", + ")\n", + "n_step_mpc_tsid = int(mpc.get_frequency_seconds() / TSID_controller_instance.frequency)\n", + "\n", + "counter = 0\n", + "mpc_success = True\n", + "energy_tot = 0.0\n", + "succeded_controller = True" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "# Simulation-control loop\n", + "while t < TIME_TH:\n", + " # Reading robot state from simulator\n", + " s, ds, tau = mujoco_instance.get_state()\n", + " energy_i = np.linalg.norm(tau)\n", + " H_b = mujoco_instance.get_base()\n", + " w_b = mujoco_instance.get_base_velocity()\n", + " t = mujoco_instance.get_simulation_time()\n", + "\n", + " # Update TSID\n", + " TSID_controller_instance.set_state_with_base(s=s, s_dot=ds, H_b=H_b, w_b=w_b, t=t)\n", + "\n", + " # MPC plan\n", + " if counter == 0:\n", + " mpc.set_state_with_base(s=s, s_dot=ds, H_b=H_b, w_b=w_b, t=t)\n", + " mpc.update_references()\n", + " mpc_success = mpc.plan_trajectory()\n", + " mpc.contact_planner.advance_swing_foot_planner()\n", + " if not (mpc_success):\n", + " print(\"MPC failed\")\n", + " break\n", + "\n", + " # Reading new references\n", + " com, dcom, forces_left, forces_right, ang_mom = mpc.get_references()\n", + " left_foot, right_foot = mpc.contact_planner.get_references_swing_foot_planner()\n", + "\n", + " # Update references TSID\n", + " TSID_controller_instance.update_task_references_mpc(\n", + " com=com,\n", + " dcom=dcom,\n", + " ddcom=np.zeros(3),\n", + " left_foot_desired=left_foot,\n", + " right_foot_desired=right_foot,\n", + " s_desired=np.array(s_des),\n", + " wrenches_left=np.hstack([forces_left, np.zeros(3)]),\n", + " wrenches_right=np.hstack([forces_right, np.zeros(3)]),\n", + " )\n", + "\n", + " # Run control\n", + " succeded_controller = TSID_controller_instance.run()\n", + "\n", + " if not (succeded_controller):\n", + " print(\"Controller failed\")\n", + " break\n", + "\n", + " tau = TSID_controller_instance.get_torque()\n", + "\n", + " # Step the simulator\n", + " mujoco_instance.set_input(tau)\n", + " mujoco_instance.step_with_motors(n_step=n_step, torque=tau)\n", + " counter = counter + 1\n", + "\n", + " if counter == n_step_mpc_tsid:\n", + " counter = 0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Closing visualization\n", + "mujoco_instance.close_visualization()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "comododev", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.15" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/models/stickbot_mjx.urdf b/examples/models/stickbot_mjx.urdf new file mode 100644 index 0000000..02799da --- /dev/null +++ b/examples/models/stickbot_mjx.urdf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f0893697d5654c1a54fc16586c34fb8f16cdb0d3ab1ba3d649fd183050185b7b +size 59343 diff --git a/src/comodo/jaxsimSimulator/jaxsimSimulator.py b/src/comodo/jaxsimSimulator/jaxsimSimulator.py index 4122635..b78dfe4 100644 --- a/src/comodo/jaxsimSimulator/jaxsimSimulator.py +++ b/src/comodo/jaxsimSimulator/jaxsimSimulator.py @@ -12,7 +12,7 @@ import jaxsim.rbda.contacts import numpy as np import numpy.typing as npt -from jaxsim import VelRepr, integrators +from jaxsim import VelRepr from jaxsim.mujoco import MujocoVideoRecorder from jaxsim.mujoco.loaders import UrdfToMjcf from jaxsim.mujoco.model import MujocoModelHelper @@ -66,7 +66,7 @@ def __init__( # Step aux dict. # This is used only for variable-step integrators. - self._step_aux_dict: dict[str, Any]= {} + self._step_aux_dict: dict[str, Any] = {} # Time step for the simulation self._dt: float = dt @@ -172,18 +172,6 @@ def load_model( f"Invalid contact model type: {self._contact_model_type}" ) - model = js.model.JaxSimModel.build_from_model_description( - model_description=robot_model.urdf_string, - model_name=robot_model.robot_name, - contact_model=contact_model, - time_step=self._dt, - ) - - self._model = js.model.reduce( - model=model, - considered_joints=tuple(robot_model.joint_name_list), - ) - if contact_params is None: match self._contact_model_type: case JaxsimContactModelEnum.RIGID: @@ -207,6 +195,19 @@ def load_model( f"Invalid contact model type: {self._contact_model_type}" ) + model = js.model.JaxSimModel.build_from_model_description( + model_description=robot_model.urdf_string, + model_name=robot_model.robot_name, + contact_model=contact_model, + contact_params=contact_params, + time_step=self._dt, + ) + + self._model = js.model.reduce( + model=model, + considered_joints=tuple(robot_model.joint_name_list), + ) + # Find mapping between user provided joint name list and JaxSim one user_joint_names = robot_model.joint_name_list js_joint_names = self._model.joint_names() @@ -223,7 +224,6 @@ def load_model( base_position=jnp.array(xyz_rpy[:3]), base_quaternion=jnp.array(JaxsimSimulator._RPY_to_quat(*xyz_rpy[3:])), joint_positions=jnp.array(s), - contacts_params=contact_params, ) # Initialize tau to zero @@ -306,10 +306,9 @@ def step(self, n_step: int = 1, *, dry_run=False) -> None: if self._contact_model_type is JaxsimContactModelEnum.VISCO_ELASTIC: - self._data, _ = jaxsim.rbda.contacts.visco_elastic.step( + self._data = jaxsim.rbda.contacts.visco_elastic.step( model=self._model, data=self._data, - dt=self._dt, link_forces=None, joint_force_references=self._tau, ) @@ -317,15 +316,11 @@ def step(self, n_step: int = 1, *, dry_run=False) -> None: else: # All other contact models - self._data, self._step_aux_dict = js.model.step( + self._data = js.model.step( model=self._model, data=self._data, - dt=self._dt, - link_forces=None, + # link_forces=None, joint_force_references=self._tau, - integrator_metadata=self._step_aux_dict.get( - "integrator_metadata", None - ), ) if not dry_run: @@ -351,10 +346,10 @@ def step(self, n_step: int = 1, *, dry_run=False) -> None: with self._data.switch_velocity_representation(VelRepr.Mixed): - self._link_contact_forces = js.model.link_contact_forces( + self._link_contact_forces = js.contact_model.link_contact_forces( model=self._model, data=self._data, - joint_force_references=self._tau, + joint_torques=self._tau, ) def get_state(self) -> tuple[npt.NDArray, npt.NDArray, npt.NDArray]: @@ -374,8 +369,8 @@ def get_state(self) -> tuple[npt.NDArray, npt.NDArray, npt.NDArray]: self._is_initialized ), "Simulator is not initialized, call load_model first." - s = np.array(self._data.joint_positions())[self._to_user] - s_dot = np.array(self._data.joint_velocities())[self._to_user] + s = np.array(self._data.joint_positions)[self._to_user] + s_dot = np.array(self._data.joint_velocities)[self._to_user] tau = np.array(self._tau)[self._to_user] return s, s_dot, tau @@ -405,7 +400,7 @@ def base_transform(self) -> npt.NDArray: assert ( self._is_initialized ), "Simulator is not initialized, call load_model first." - return np.array(self._data.base_transform()) + return np.array(self._data.base_transform) @property def base_velocity(self) -> npt.NDArray: @@ -413,7 +408,7 @@ def base_velocity(self) -> npt.NDArray: self._is_initialized ), "Simulator is not initialized, call load_model first." with self._data.switch_velocity_representation(VelRepr.Mixed): - return np.array(self._data.base_velocity()) + return np.array(self._data.base_velocity) @property def simulation_time(self) -> float: @@ -492,7 +487,7 @@ def update_contact_model_parameters( self._is_initialized ), "Simulator is not initialized, call load_model first." - self._data = self._data.replace(contacts_params=params) + self._model = self._model.replace(contacts_params=params) # ==== Private methods ==== @@ -528,26 +523,26 @@ def _render(self) -> None: self._handle = self._viz.open_viewer() self._mj_model_helper.set_base_position( - position=np.array(self._data.base_position()), + position=np.array(self._data.base_position), ) self._mj_model_helper.set_base_orientation( - orientation=np.array(self._data.base_orientation()), + orientation=np.array(self._data.base_orientation), ) self._mj_model_helper.set_joint_positions( - positions=np.array(self._data.joint_positions()), + positions=np.array(self._data.joint_positions), joint_names=self._model.joint_names(), ) self._viz.sync(viewer=self._handle) def _record_frame(self) -> None: self._mj_model_helper.set_base_position( - position=np.array(self._data.base_position()), + position=np.array(self._data.base_position), ) self._mj_model_helper.set_base_orientation( - orientation=np.array(self._data.base_orientation()), + orientation=np.array(self._data.base_orientation), ) self._mj_model_helper.set_joint_positions( - positions=np.array(self._data.joint_positions()), + positions=np.array(self._data.joint_positions), joint_names=self._model.joint_names(), ) diff --git a/src/comodo/mujocoSimulator/mjxSimulator.py b/src/comodo/mujocoSimulator/mjxSimulator.py new file mode 100644 index 0000000..3c2f120 --- /dev/null +++ b/src/comodo/mujocoSimulator/mjxSimulator.py @@ -0,0 +1,341 @@ +import copy +import math + +import casadi as cs +import jax +import mediapy as media +import mujoco +import mujoco_viewer +import numpy as np +from mujoco import mjx + +from comodo.abstractClasses.simulator import Simulator + + +class MJXSimulator(Simulator): + def __init__(self) -> None: + self.desired_pos = None + self.postion_control = False + self.compute_misalignment_gravity_fun() + self.jit_step = jax.jit(mjx.step) + self.framerate = 30 + super().__init__() + + def load_model(self, robot_model, s, xyz_rpy, kv_motors=None, Im=None): + self.robot_model = robot_model + + mujoco_xml = robot_model.get_mujoco_model() + + # Load mujoco model and data + mujoco_model = mujoco.MjModel.from_xml_string(mujoco_xml) + mujoco_data = mujoco.MjData(mujoco_model) + self.mj_model = mujoco_model + self.mj_data = mujoco_data + + # Put the model and data on the accelerator device(s) and get the corresponding MJX model and data + self.model = mjx.put_model(mujoco_model) + self.data = mjx.put_data(mujoco_model, mujoco_data) + + self.create_mapping_vector_from_mujoco() + self.create_mapping_vector_to_mujoco() + # mjx.mj_forward(self.model, self.data) + self.set_joint_vector_in_mujoco(s) + self.set_base_pose_in_mujoco(xyz_rpy=xyz_rpy) + # mjx.mj_forward(self.model, self.data) + self.visualize_robot_flag = False + + self.Im = Im if Im is not None else np.zeros(self.robot_model.NDoF) + self.kv_motors = ( + kv_motors if kv_motors is not None else np.zeros(self.robot_model.NDoF) + ) + self.H_left_foot = copy.deepcopy(self.robot_model.H_left_foot) + self.H_right_foot = copy.deepcopy(self.robot_model.H_right_foot) + self.H_left_foot_num = None + self.H_right_foot_num = None + self.mass = self.robot_model.get_total_mass() + + def get_contact_status(self): + left_wrench, rigth_wrench = self.get_feet_wrench() + left_foot_contact = left_wrench[2] > 0.1 * self.mass + right_foot_contact = rigth_wrench[2] > 0.1 * self.mass + return left_foot_contact, right_foot_contact + + def set_visualize_robot_flag(self, visualize_robot): + self.visualize_robot_flag = visualize_robot + if self.visualize_robot_flag: + # self.viewer = mujoco_viewer.MujocoViewer(self.model, self.data) + self.renderer = mujoco.Renderer(self.mj_model) + self.frames = [] + + def set_base_pose_in_mujoco(self, xyz_rpy): + base_xyz_quat = np.zeros(7) + base_xyz_quat[:3] = xyz_rpy[:3] + base_xyz_quat[3:] = self.RPY_to_quat(xyz_rpy[3], xyz_rpy[4], xyz_rpy[5]) + base_xyz_quat[2] = base_xyz_quat[2] + # self.data.qpos[:7] = base_xyz_quat + self.data = self.data.replace(qpos=self.data.qpos.at[:7].set(base_xyz_quat)) + + def set_joint_vector_in_mujoco(self, pos): + pos_muj = self.convert_vector_to_mujoco(pos) + indexes_joint = self.model.jnt_qposadr[1:] + for i in range(self.robot_model.NDoF): + self.data = self.data.replace( + qpos=self.data.qpos.at[indexes_joint[i]].set(pos_muj[i]) + ) + + def set_input(self, input): + input_muj = self.convert_vector_to_mujoco(input) + self.data = self.data.replace(ctrl=input_muj) + np.copyto(self.data.ctrl, input_muj) + + def set_position_input(self, pos): + pos_muj = self.convert_vector_to_mujoco(pos) + self.desired_pos = pos_muj + self.postion_control = True + + def create_mapping_vector_to_mujoco(self): + # This function creates the to_mujoco map + self.to_mujoco = [] + for mujoco_joint in self.robot_model.mujoco_joint_order: + try: + index = self.robot_model.joint_name_list.index(mujoco_joint) + self.to_mujoco.append(index) + except ValueError: + raise ValueError( + f"Mujoco joint '{mujoco_joint}' not found in joint list." + ) + + def create_mapping_vector_from_mujoco(self): + # This function creates the to_mujoco map + self.from_mujoco = [] + for joint in self.robot_model.joint_name_list: + try: + index = self.robot_model.mujoco_joint_order.index(joint) + self.from_mujoco.append(index) + except ValueError: + raise ValueError( + f"Joint name list joint '{joint}' not found in mujoco list." + ) + + def convert_vector_to_mujoco(self, array_in): + out_muj = np.asarray( + [array_in[self.to_mujoco[item]] for item in range(self.robot_model.NDoF)] + ) + return out_muj + + def convert_from_mujoco(self, array_muj): + out_classic = np.asarray( + [array_muj[self.from_mujoco[item]] for item in range(self.robot_model.NDoF)] + ) + return out_classic + + def step(self, n_step=1, visualize=True): + if self.postion_control: + for _ in range(n_step): + s, s_dot, tau = self.get_state(use_mujoco_convention=True) + kp_muj = self.convert_vector_to_mujoco( + self.robot_model.kp_position_control + ) + kd_muj = self.convert_vector_to_mujoco( + self.robot_model.kd_position_control + ) + ctrl = kp_muj * (self.desired_pos - s) - kd_muj * s_dot + self.data.ctrl = ctrl + np.copyto(self.data.ctrl, ctrl) + self.data = self.jit_step(self.model, self.data) + # mjx.mj_step1(self.model, self.data) + # mjx.mj_forward(self.model, self.data) + else: + # Step the simulation of n_step iterations + for _ in range(n_step): + self.data = self.jit_step(self.model, self.data) + # mjx.mj_step1(self.model, self.data) + # mjx.mj_forward(self.model, self.data) + + if len(self.frames) < self.data.time * self.framerate: + self.visualize_robot() + # if self.visualize_robot_flag: + # self.viewer.render() + + def step_with_motors(self, n_step, torque): + indexes_joint_acceleration = self.model.jnt_dofadr[1:] + s_dot_dot = self.data.qacc[indexes_joint_acceleration[0] :] + for _ in range(n_step): + indexes_joint_acceleration = self.model.jnt_dofadr[1:] + s_dot_dot = self.data.qacc[indexes_joint_acceleration[0] :] + s_dot = self.data.qvel[indexes_joint_acceleration[0] :] + input = np.asarray( + [ + self.Im[self.to_mujoco[item]] * s_dot_dot[item] + + self.kv_motors[self.to_mujoco[item]] * s_dot[item] + + torque[self.to_mujoco[item]] + for item in range(self.robot_model.NDoF) + ] + ) + + self.set_input(input) + self.step(n_step=1, visualize=False) + # if self.visualize_robot_flag: + # self.viewer.render() + + def compute_misalignment_gravity_fun(self): + H = cs.SX.sym("H", 4, 4) + theta = cs.SX.sym("theta") + theta = cs.dot([0, 0, 1], H[:3, 2]) - 1 + error = cs.Function("error", [H], [theta]) + self.error_mis = error + + def check_feet_status(self, s, H_b): + left_foot_pose = self.robot_model.H_left_foot(H_b, s) + rigth_foot_pose = self.robot_model.H_right_foot(H_b, s) + left_foot_z = left_foot_pose[2, 3] + rigth_foot_z = rigth_foot_pose[2, 3] + left_foot_contact = not (left_foot_z > 0.1) + rigth_foot_contact = not (rigth_foot_z > 0.1) + misalignment_left = self.error_mis(left_foot_pose) + misalignment_rigth = self.error_mis(rigth_foot_pose) + left_foot_condition = abs(left_foot_contact * misalignment_left) + rigth_foot_condition = abs(rigth_foot_contact * misalignment_rigth) + misalignment_error = left_foot_condition + rigth_foot_condition + if ( + abs(left_foot_contact * misalignment_left) > 0.02 + or abs(rigth_foot_contact * misalignment_rigth) > 0.02 + ): + return False, misalignment_error + + return True, misalignment_error + + def get_feet_wrench(self): + left_foot_wrench = np.zeros(6) + rigth_foot_wrench = np.zeros(6) + s, s_dot, tau = self.get_state() + H_b = self.get_base() + self.H_left_foot_num = np.array(self.H_left_foot(H_b, s)) + self.H_right_foot_num = np.array(self.H_right_foot(H_b, s)) + for i in range(self.data.ncon): + contact = self.data.contact[i] + c_array = np.zeros(6, dtype=np.float64) + mujoco.mj_contactForce(self.model, self.data, i, c_array) + name_contact = mujoco.mj_id2name( + self.model, mujoco.mjtObj.mjOBJ_GEOM, int(contact.geom[1]) + ) + w_H_contact = np.eye(4) + w_H_contact[:3, :3] = contact.frame.reshape(3, 3) + w_H_contact[:3, 3] = contact.pos + if ( + name_contact == self.robot_model.right_foot_rear_ct + or name_contact == self.robot_model.right_foot_front_ct + ): + RF_H_contact = np.linalg.inv(self.H_right_foot_num) @ w_H_contact + wrench_RF = self.compute_resulting_wrench(RF_H_contact, c_array) + rigth_foot_wrench[:] += wrench_RF.reshape(6) + elif ( + name_contact == self.robot_model.left_foot_front_ct + or name_contact == self.robot_model.left_foot_rear_ct + ): + LF_H_contact = np.linalg.inv(self.H_left_foot_num) @ w_H_contact + wrench_LF = self.compute_resulting_wrench(LF_H_contact, c_array) + left_foot_wrench[:] += wrench_LF.reshape(6) + return (left_foot_wrench, rigth_foot_wrench) + + def compute_resulting_wrench(self, b_H_a, force_torque_a): + p = b_H_a[:3, 3] + R = b_H_a[:3, :3] + adjoint_matrix = np.zeros([6, 6]) + adjoint_matrix[:3, :3] = R + adjoint_matrix[3:, :3] = np.cross(p, R) + adjoint_matrix[3:, 3:] = R + force_torque_b = adjoint_matrix @ force_torque_a.reshape(6, 1) + return force_torque_b + + # note that for mujoco the ordering is w,x,y,z + def get_base(self): + indexes_joint = self.model.jnt_qposadr[1:] + # Extract quaternion components + w, x, y, z = self.data.qpos[3 : indexes_joint[0]] + + # Calculate rotation matrix + rot_mat = np.array( + [ + [ + 1 - 2 * y * y - 2 * z * z, + 2 * x * y - 2 * z * w, + 2 * x * z + 2 * y * w, + 0, + ], + [ + 2 * x * y + 2 * z * w, + 1 - 2 * x * x - 2 * z * z, + 2 * y * z - 2 * x * w, + 0, + ], + [ + 2 * x * z - 2 * y * w, + 2 * y * z + 2 * x * w, + 1 - 2 * x * x - 2 * y * y, + 0, + ], + [0, 0, 0, 1], + ] + ) + + # Set up transformation matrix + trans_mat = np.eye(4) + trans_mat[:3, :3] = rot_mat[:3, :3] + trans_mat[:3, 3] = self.data.qpos[:3] + # Return transformation matrix + return trans_mat + + def get_base_velocity(self): + indexes_joint_velocities = self.model.jnt_dofadr[1:] + return self.data.qvel[: indexes_joint_velocities[0]] + + def get_state(self, use_mujoco_convention=False): + indexes_joint = self.model.jnt_qposadr[1:] + indexes_joint_velocities = self.model.jnt_dofadr[1:] + s = self.data.qpos[indexes_joint[0] :] + s_dot = self.data.qvel[indexes_joint_velocities[0] :] + tau = self.data.ctrl + if use_mujoco_convention: + return s, s_dot, tau + s_out = self.convert_from_mujoco(s) + s_dot_out = self.convert_from_mujoco(s_dot) + tau_out = self.convert_from_mujoco(tau) + return s_out, s_dot_out, tau_out + + def close(self): + if self.visualize_robot_flag: + self.viewer.close() + + def visualize_robot(self): + # self.viewer.render() + mj_data = mjx.get_data(self.mj_model, self.data) + self.renderer.update_scene(mj_data) + pixels = self.renderer.render() + self.frames.append(pixels) + + def get_simulation_time(self): + return self.data.time + + def get_simulation_frequency(self): + return self.model.opt.timestep + + def RPY_to_quat(self, roll, pitch, yaw): + cr = math.cos(roll / 2) + cp = math.cos(pitch / 2) + cy = math.cos(yaw / 2) + sr = math.sin(roll / 2) + sp = math.sin(pitch / 2) + sy = math.sin(yaw / 2) + + qw = cr * cp * cy + sr * sp * sy + qx = sr * cp * cy - cr * sp * sy + qy = cr * sp * cy + sr * cp * sy + qz = cr * cp * sy - sr * sp * cy + + return [qw, qx, qy, qz] + + def close_visualization(self): + if self.visualize_robot_flag: + # self.viewer.close() + media.show_video(self.frames, fps=self.framerate)