diff --git a/examples/jaxsim_walking.ipynb b/examples/jaxsim_walking.ipynb
index 39b1677..6c57ab4 100644
--- a/examples/jaxsim_walking.ipynb
+++ b/examples/jaxsim_walking.ipynb
@@ -59,7 +59,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -115,14 +115,14 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 3,
    "metadata": {},
    "outputs": [],
    "source": [
     "# ==== Set simulation parameters ====\n",
     "\n",
     "T = 6.0\n",
-    "js_dt = 0.000_5"
+    "js_dt = 0.001"
    ]
   },
   {
@@ -148,7 +148,7 @@
    "source": [
     "# ==== Define JaxSim simulator and set initial position ====\n",
     "\n",
-    "js = JaxsimSimulator(dt=js_dt, contact_model_type=JaxsimContactModelEnum.RIGID)\n",
+    "js = JaxsimSimulator(dt=js_dt, contact_model_type=JaxsimContactModelEnum.RELAXED_RIGID)\n",
     "js.load_model(\n",
     "    robot_model=robot_model_init,\n",
     "    s=s_0,\n",
@@ -165,7 +165,7 @@
     "print(f\"Contact model in use: {js._model.contact_model}\")\n",
     "print(f\"Link names:\\n{js.link_names}\")\n",
     "print(f\"Frame names:\\n{js.frame_names}\")\n",
-    "print(f\"Mass: {js.total_mass*js._data.standard_gravity()} N\")"
+    "print(f\"Mass: {-js.total_mass*js._model.gravity} N\")"
    ]
   },
   {
@@ -197,11 +197,30 @@
     "# Set desired quantities\n",
     "mpc.configure(s_init=s_0, H_b_init=H_b_0)\n",
     "tsid.compute_com_position()\n",
-    "mpc.define_test_com_traj(tsid.COM.toNumPy())\n",
-    "\n",
+    "mpc.define_test_com_traj(tsid.COM.toNumPy())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [],
+   "source": [
     "# Set initial robot state  and plan trajectories\n",
+    "\n",
+    "tic = time.perf_counter()\n",
+    "\n",
     "js.step(dry_run=True)\n",
     "\n",
+    "step_compilation_time_s = time.perf_counter() - tic"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [],
+   "source": [
     "# Reading the state\n",
     "s_js, ds_js, tau_js = js.get_state()\n",
     "H_b = js.base_transform\n",
@@ -215,7 +234,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 9,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -245,6 +264,7 @@
     "    tau_tsid_log = []\n",
     "    W_p_CoM_tsid_log = []\n",
     "    t_log = []\n",
+    "    wall_time_step_log = []\n",
     "\n",
     "    # Define number of steps\n",
     "    n_step_tsid_js = int(tsid.frequency / js_dt)\n",
@@ -309,16 +329,20 @@
     "\n",
     "            # Step the simulator\n",
     "            js.set_input(tau_tsid)\n",
+    "\n",
+    "            tic = time.perf_counter()\n",
     "            js.step(n_step=n_step_tsid_js)\n",
+    "            toc = time.perf_counter() - tic\n",
+    "\n",
     "            counter = counter + 1\n",
     "\n",
     "            if counter == n_step_mpc_tsid:\n",
     "                counter = 0\n",
     "\n",
     "            # Stop the simulation if the robot fell down\n",
-    "            if js._data.base_position()[2] < 0.5:\n",
-    "                print(f\"Robot fell down at t={t:.4f}s.\")\n",
-    "                break\n",
+    "            # if js._data.base_position()[2] < 0.5:\n",
+    "            #     print(f\"Robot fell down at t={t:.4f}s.\")\n",
+    "            #     break\n",
     "\n",
     "            # Log data\n",
     "            # TODO transform mpc contact forces to wrenches to be compared with jaxsim ones\n",
@@ -336,6 +360,7 @@
     "            W_p_lf_sfp_log.append(lf_sfp.transform.translation())\n",
     "            W_p_rf_sfp_log.append(rf_sfp.transform.translation())\n",
     "            W_p_CoM_tsid_log.append(tsid.COM.toNumPy())\n",
+    "            wall_time_step_log.append(toc * 1e3)\n",
     "            if contact_model_type != JaxsimContactModelEnum.VISCO_ELASTIC:\n",
     "                f_lf_js, f_rf_js = js.feet_wrench\n",
     "                f_lf_js_log.append(f_lf_js)\n",
@@ -360,6 +385,7 @@
     "        \"W_p_lf_sfp\": np.array(W_p_lf_sfp_log),\n",
     "        \"W_p_rf_sfp\": np.array(W_p_rf_sfp_log),\n",
     "        \"W_p_CoM_tsid\": np.array(W_p_CoM_tsid_log),\n",
+    "        \"wall_time_step\": np.array(wall_time_step_log),\n",
     "    }\n",
     "    if contact_model_type != JaxsimContactModelEnum.VISCO_ELASTIC:\n",
     "        logs[\"f_lf_js\"] = np.array(f_lf_js_log)\n",
@@ -385,19 +411,16 @@
     "avg_iter_time_ms = (wall_time / (T / js_dt)) * 1000\n",
     "\n",
     "print(\n",
-    "    f\"\\nRunning simulation took {wall_time:.2f}s for {T:.3f}s simulated time. \\nIteration avg time of {avg_iter_time_ms:.1f} ms.\"\n",
-    ")\n",
-    "print(f\"RTF: {T / wall_time * 100:.2f}%\")"
+    "    f\"\\nSimulation done.\\nRunning simulation took {wall_time:.2f}s for {T:.3f}s simulated time.\"\n",
+    ")"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 11,
    "metadata": {},
    "outputs": [],
    "source": [
-    "# ==== Plot the results ====\n",
-    "\n",
     "# Extract logged variables\n",
     "t = logs[\"t\"]\n",
     "s_js = logs[\"s_js\"]\n",
@@ -412,9 +435,45 @@
     "W_p_lf_sfp = logs[\"W_p_lf_sfp\"]\n",
     "W_p_rf_sfp = logs[\"W_p_rf_sfp\"]\n",
     "W_p_CoM_tsid = logs[\"W_p_CoM_tsid\"]\n",
+    "wall_time_step = logs[\"wall_time_step\"]\n",
     "if js.contact_model_type != JaxsimContactModelEnum.VISCO_ELASTIC:\n",
     "    f_lf_js = logs[\"f_lf_js\"]\n",
-    "    f_rf_js = logs[\"f_rf_js\"]\n",
+    "    f_rf_js = logs[\"f_rf_js\"]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Compute simulator step runtime statistics and RTF\n",
+    "\n",
+    "min_step_time = np.min(wall_time_step)\n",
+    "max_step_time = np.max(wall_time_step)\n",
+    "avg_step_time = np.mean(wall_time_step)\n",
+    "std_step_time = np.std(wall_time_step)\n",
+    "total_step_time = np.sum(wall_time_step)\n",
+    "rtf = (T * 1e3) / total_step_time * 100\n",
+    "\n",
+    "print(\"===========================================\")\n",
+    "print(f\"Step compilation time: {step_compilation_time_s:.2f} s\")\n",
+    "print(f\"Min step time: {min_step_time:.2f} ms\")\n",
+    "print(f\"Max step time: {max_step_time:.2f} ms\")\n",
+    "print(f\"Average step time: {avg_step_time:.2f} ms\")\n",
+    "print(f\"Std deviation step time: {std_step_time:.2f} ms\")\n",
+    "print(f\"RTF: {rtf:.1f}%\")\n",
+    "print(\"===========================================\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# ==== Plot the results ====\n",
+    "\n",
     "\n",
     "n_sim_steps = s_js.shape[0]\n",
     "s_0 = np.full_like(a=s_js, fill_value=s_0)\n",
@@ -577,7 +636,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "comodo",
+   "display_name": "comododev",
    "language": "python",
    "name": "python3"
   },
@@ -591,7 +650,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.13"
+   "version": "3.10.15"
   }
  },
  "nbformat": 4,
diff --git a/examples/mjx_walking.ipynb b/examples/mjx_walking.ipynb
new file mode 100644
index 0000000..2bf2379
--- /dev/null
+++ b/examples/mjx_walking.ipynb
@@ -0,0 +1,327 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Mujoco MJX with TSID and MPC Example \n",
+    "This examples, load a basic robot model (i.e. composed only of basic shapes), modifies the links of such a robot model by elongating the legs, define instances of the TSID (Task Based Inverse Dynamics) and Centroidal MPC  controller and simulate the behavior of the robot using MJX.  "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Comodo import\n",
+    "from comodo.mujocoSimulator.mjxSimulator import MJXSimulator\n",
+    "from comodo.robotModel.robotModel import RobotModel\n",
+    "from comodo.robotModel.createUrdf import createUrdf\n",
+    "from comodo.centroidalMPC.centroidalMPC import CentroidalMPC\n",
+    "from comodo.centroidalMPC.mpcParameterTuning import MPCParameterTuning\n",
+    "from comodo.TSIDController.TSIDParameterTuning import TSIDParameterTuning\n",
+    "from comodo.TSIDController.TSIDController import TSIDController\n",
+    "\n",
+    "import jax\n",
+    "\n",
+    "# Force JAX to use CPU\n",
+    "jax.config.update(\"jax_default_device\", jax.devices(\"cpu\")[0])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# General  import\n",
+    "import xml.etree.ElementTree as ET\n",
+    "import numpy as np\n",
+    "import tempfile\n",
+    "import urllib.request\n",
+    "import os\n",
+    "import pathlib"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Getting stickbot urdf file\n",
+    "\n",
+    "path = pathlib.Path.cwd() / \"models\" / \"stickbot_mjx.urdf\"\n",
+    "# Load the URDF file\n",
+    "# tree = ET.parse(urdf_robot_file.name)\n",
+    "tree = ET.parse(path)\n",
+    "root = tree.getroot()\n",
+    "\n",
+    "# Convert the XML tree to a string\n",
+    "robot_urdf_string_original = ET.tostring(root)\n",
+    "\n",
+    "# create_urdf_instance = createUrdf(\n",
+    "#     original_urdf_path=urdf_robot_file.name, save_gazebo_plugin=False\n",
+    "# )\n",
+    "create_urdf_instance = createUrdf(original_urdf_path=path, save_gazebo_plugin=False)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Define parametric links and controlled joints\n",
+    "legs_link_names = [\"hip_3\", \"lower_leg\"]\n",
+    "joint_name_list = [\n",
+    "    \"r_shoulder_pitch\",\n",
+    "    \"r_shoulder_roll\",\n",
+    "    \"r_shoulder_yaw\",\n",
+    "    \"r_elbow\",\n",
+    "    \"l_shoulder_pitch\",\n",
+    "    \"l_shoulder_roll\",\n",
+    "    \"l_shoulder_yaw\",\n",
+    "    \"l_elbow\",\n",
+    "    \"r_hip_pitch\",\n",
+    "    \"r_hip_roll\",\n",
+    "    \"r_hip_yaw\",\n",
+    "    \"r_knee\",\n",
+    "    \"r_ankle_pitch\",\n",
+    "    \"r_ankle_roll\",\n",
+    "    \"l_hip_pitch\",\n",
+    "    \"l_hip_roll\",\n",
+    "    \"l_hip_yaw\",\n",
+    "    \"l_knee\",\n",
+    "    \"l_ankle_pitch\",\n",
+    "    \"l_ankle_roll\",\n",
+    "]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Define the robot modifications\n",
+    "modifications = {}\n",
+    "for item in legs_link_names:\n",
+    "    left_leg_item = \"l_\" + item\n",
+    "    right_leg_item = \"r_\" + item\n",
+    "    modifications.update({left_leg_item: 1.2})\n",
+    "    modifications.update({right_leg_item: 1.2})\n",
+    "# Motors Parameters\n",
+    "Im_arms = 1e-3 * np.ones(4)  # from 0-4\n",
+    "Im_legs = 1e-3 * np.ones(6)  # from 5-10\n",
+    "kv_arms = 0.001 * np.ones(4)  # from 11-14\n",
+    "kv_legs = 0.001 * np.ones(6)  # from 20\n",
+    "\n",
+    "Im = np.concatenate((Im_arms, Im_arms, Im_legs, Im_legs))\n",
+    "kv = np.concatenate((kv_arms, kv_arms, kv_legs, kv_legs))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Modify the robot model and initialize\n",
+    "create_urdf_instance.modify_lengths(modifications)\n",
+    "urdf_robot_string = create_urdf_instance.write_urdf_to_file()\n",
+    "create_urdf_instance.reset_modifications()\n",
+    "robot_model_init = RobotModel(urdf_robot_string, \"stickBot\", joint_name_list)\n",
+    "s_des, xyz_rpy, H_b = robot_model_init.compute_desired_position_walking()\n",
+    "robot_model_init.set_foot_corner(\n",
+    "    np.asarray([0.1, 0.05, 0.0]),\n",
+    "    np.asarray([0.1, -0.05, 0.0]),\n",
+    "    np.asarray([-0.1, -0.05, 0.0]),\n",
+    "    np.asarray([-0.1, 0.05, 0.0]),\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Define simulator and set initial position\n",
+    "mujoco_instance = MJXSimulator()\n",
+    "mujoco_instance.load_model(\n",
+    "    robot_model_init, s=s_des, xyz_rpy=xyz_rpy, kv_motors=kv, Im=Im\n",
+    ")\n",
+    "s, ds, tau = mujoco_instance.get_state()\n",
+    "t = mujoco_instance.get_simulation_time()\n",
+    "H_b = mujoco_instance.get_base()\n",
+    "w_b = mujoco_instance.get_base_velocity()\n",
+    "mujoco_instance.set_visualize_robot_flag(True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Define the controller parameters  and instantiate the controller\n",
+    "# Controller Parameters\n",
+    "tsid_parameter = TSIDParameterTuning()\n",
+    "mpc_parameters = MPCParameterTuning()\n",
+    "\n",
+    "\n",
+    "## Controller gains\n",
+    "# Controller Parameters\n",
+    "tsid_parameter = TSIDParameterTuning()\n",
+    "tsid_parameter.foot_tracking_task_kp_lin = 150.0\n",
+    "tsid_parameter.foot_tracking_task_kd_lin = 40.0\n",
+    "tsid_parameter.root_tracking_task_weight = np.ones(3) * 50.0\n",
+    "\n",
+    "# TSID Instance\n",
+    "TSID_controller_instance = TSIDController(frequency=0.01, robot_model=robot_model_init)\n",
+    "TSID_controller_instance.define_tasks(tsid_parameter)\n",
+    "TSID_controller_instance.set_state_with_base(s, ds, H_b, w_b, t)\n",
+    "\n",
+    "# MPC Instance\n",
+    "step_lenght = 0.1\n",
+    "mpc = CentroidalMPC(robot_model=robot_model_init, step_length=step_lenght)\n",
+    "mpc.intialize_mpc(mpc_parameters=mpc_parameters)\n",
+    "\n",
+    "# Set desired quantities\n",
+    "mpc.configure(s_init=s_des, H_b_init=H_b)\n",
+    "TSID_controller_instance.compute_com_position()\n",
+    "mpc.define_test_com_traj(TSID_controller_instance.COM.toNumPy())\n",
+    "\n",
+    "# Set initial robot state  and plan trajectories\n",
+    "mujoco_instance.step(1)\n",
+    "\n",
+    "# Reading the state\n",
+    "s, ds, tau = mujoco_instance.get_state()\n",
+    "H_b = mujoco_instance.get_base()\n",
+    "w_b = mujoco_instance.get_base_velocity()\n",
+    "t = mujoco_instance.get_simulation_time()\n",
+    "\n",
+    "# MPC\n",
+    "mpc.set_state_with_base(s=s, s_dot=ds, H_b=H_b, w_b=w_b, t=t)\n",
+    "mpc.initialize_centroidal_integrator(s=s, s_dot=ds, H_b=H_b, w_b=w_b, t=t)\n",
+    "mpc_output = mpc.plan_trajectory()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Set loop variables\n",
+    "TIME_TH = 6.0\n",
+    "\n",
+    "# Define number of steps\n",
+    "n_step = int(\n",
+    "    TSID_controller_instance.frequency / mujoco_instance.get_simulation_frequency()\n",
+    ")\n",
+    "n_step_mpc_tsid = int(mpc.get_frequency_seconds() / TSID_controller_instance.frequency)\n",
+    "\n",
+    "counter = 0\n",
+    "mpc_success = True\n",
+    "energy_tot = 0.0\n",
+    "succeded_controller = True"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Simulation-control loop\n",
+    "while t < TIME_TH:\n",
+    "    # Reading robot state from simulator\n",
+    "    s, ds, tau = mujoco_instance.get_state()\n",
+    "    energy_i = np.linalg.norm(tau)\n",
+    "    H_b = mujoco_instance.get_base()\n",
+    "    w_b = mujoco_instance.get_base_velocity()\n",
+    "    t = mujoco_instance.get_simulation_time()\n",
+    "\n",
+    "    # Update TSID\n",
+    "    TSID_controller_instance.set_state_with_base(s=s, s_dot=ds, H_b=H_b, w_b=w_b, t=t)\n",
+    "\n",
+    "    # MPC plan\n",
+    "    if counter == 0:\n",
+    "        mpc.set_state_with_base(s=s, s_dot=ds, H_b=H_b, w_b=w_b, t=t)\n",
+    "        mpc.update_references()\n",
+    "        mpc_success = mpc.plan_trajectory()\n",
+    "        mpc.contact_planner.advance_swing_foot_planner()\n",
+    "        if not (mpc_success):\n",
+    "            print(\"MPC failed\")\n",
+    "            break\n",
+    "\n",
+    "    # Reading new references\n",
+    "    com, dcom, forces_left, forces_right, ang_mom = mpc.get_references()\n",
+    "    left_foot, right_foot = mpc.contact_planner.get_references_swing_foot_planner()\n",
+    "\n",
+    "    # Update references TSID\n",
+    "    TSID_controller_instance.update_task_references_mpc(\n",
+    "        com=com,\n",
+    "        dcom=dcom,\n",
+    "        ddcom=np.zeros(3),\n",
+    "        left_foot_desired=left_foot,\n",
+    "        right_foot_desired=right_foot,\n",
+    "        s_desired=np.array(s_des),\n",
+    "        wrenches_left=np.hstack([forces_left, np.zeros(3)]),\n",
+    "        wrenches_right=np.hstack([forces_right, np.zeros(3)]),\n",
+    "    )\n",
+    "\n",
+    "    # Run control\n",
+    "    succeded_controller = TSID_controller_instance.run()\n",
+    "\n",
+    "    if not (succeded_controller):\n",
+    "        print(\"Controller failed\")\n",
+    "        break\n",
+    "\n",
+    "    tau = TSID_controller_instance.get_torque()\n",
+    "\n",
+    "    # Step the simulator\n",
+    "    mujoco_instance.set_input(tau)\n",
+    "    mujoco_instance.step_with_motors(n_step=n_step, torque=tau)\n",
+    "    counter = counter + 1\n",
+    "\n",
+    "    if counter == n_step_mpc_tsid:\n",
+    "        counter = 0"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Closing visualization\n",
+    "mujoco_instance.close_visualization()"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "comododev",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.15"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/examples/models/stickbot_mjx.urdf b/examples/models/stickbot_mjx.urdf
new file mode 100644
index 0000000..02799da
--- /dev/null
+++ b/examples/models/stickbot_mjx.urdf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f0893697d5654c1a54fc16586c34fb8f16cdb0d3ab1ba3d649fd183050185b7b
+size 59343
diff --git a/src/comodo/jaxsimSimulator/jaxsimSimulator.py b/src/comodo/jaxsimSimulator/jaxsimSimulator.py
index 4122635..b78dfe4 100644
--- a/src/comodo/jaxsimSimulator/jaxsimSimulator.py
+++ b/src/comodo/jaxsimSimulator/jaxsimSimulator.py
@@ -12,7 +12,7 @@
 import jaxsim.rbda.contacts
 import numpy as np
 import numpy.typing as npt
-from jaxsim import VelRepr, integrators
+from jaxsim import VelRepr
 from jaxsim.mujoco import MujocoVideoRecorder
 from jaxsim.mujoco.loaders import UrdfToMjcf
 from jaxsim.mujoco.model import MujocoModelHelper
@@ -66,7 +66,7 @@ def __init__(
 
         # Step aux dict.
         # This is used only for variable-step integrators.
-        self._step_aux_dict: dict[str, Any]= {}
+        self._step_aux_dict: dict[str, Any] = {}
 
         # Time step for the simulation
         self._dt: float = dt
@@ -172,18 +172,6 @@ def load_model(
                     f"Invalid contact model type: {self._contact_model_type}"
                 )
 
-        model = js.model.JaxSimModel.build_from_model_description(
-            model_description=robot_model.urdf_string,
-            model_name=robot_model.robot_name,
-            contact_model=contact_model,
-            time_step=self._dt,
-        )
-
-        self._model = js.model.reduce(
-            model=model,
-            considered_joints=tuple(robot_model.joint_name_list),
-        )
-
         if contact_params is None:
             match self._contact_model_type:
                 case JaxsimContactModelEnum.RIGID:
@@ -207,6 +195,19 @@ def load_model(
                         f"Invalid contact model type: {self._contact_model_type}"
                     )
 
+        model = js.model.JaxSimModel.build_from_model_description(
+            model_description=robot_model.urdf_string,
+            model_name=robot_model.robot_name,
+            contact_model=contact_model,
+            contact_params=contact_params,
+            time_step=self._dt,
+        )
+
+        self._model = js.model.reduce(
+            model=model,
+            considered_joints=tuple(robot_model.joint_name_list),
+        )
+
         # Find mapping between user provided joint name list and JaxSim one
         user_joint_names = robot_model.joint_name_list
         js_joint_names = self._model.joint_names()
@@ -223,7 +224,6 @@ def load_model(
             base_position=jnp.array(xyz_rpy[:3]),
             base_quaternion=jnp.array(JaxsimSimulator._RPY_to_quat(*xyz_rpy[3:])),
             joint_positions=jnp.array(s),
-            contacts_params=contact_params,
         )
 
         # Initialize tau to zero
@@ -306,10 +306,9 @@ def step(self, n_step: int = 1, *, dry_run=False) -> None:
 
             if self._contact_model_type is JaxsimContactModelEnum.VISCO_ELASTIC:
 
-                self._data, _ = jaxsim.rbda.contacts.visco_elastic.step(
+                self._data = jaxsim.rbda.contacts.visco_elastic.step(
                     model=self._model,
                     data=self._data,
-                    dt=self._dt,
                     link_forces=None,
                     joint_force_references=self._tau,
                 )
@@ -317,15 +316,11 @@ def step(self, n_step: int = 1, *, dry_run=False) -> None:
             else:
 
                 # All other contact models
-                self._data, self._step_aux_dict = js.model.step(
+                self._data = js.model.step(
                     model=self._model,
                     data=self._data,
-                    dt=self._dt,
-                    link_forces=None,
+                    # link_forces=None,
                     joint_force_references=self._tau,
-                    integrator_metadata=self._step_aux_dict.get(
-                        "integrator_metadata", None
-                    ),
                 )
 
             if not dry_run:
@@ -351,10 +346,10 @@ def step(self, n_step: int = 1, *, dry_run=False) -> None:
 
             with self._data.switch_velocity_representation(VelRepr.Mixed):
 
-                self._link_contact_forces = js.model.link_contact_forces(
+                self._link_contact_forces = js.contact_model.link_contact_forces(
                     model=self._model,
                     data=self._data,
-                    joint_force_references=self._tau,
+                    joint_torques=self._tau,
                 )
 
     def get_state(self) -> tuple[npt.NDArray, npt.NDArray, npt.NDArray]:
@@ -374,8 +369,8 @@ def get_state(self) -> tuple[npt.NDArray, npt.NDArray, npt.NDArray]:
             self._is_initialized
         ), "Simulator is not initialized, call load_model first."
 
-        s = np.array(self._data.joint_positions())[self._to_user]
-        s_dot = np.array(self._data.joint_velocities())[self._to_user]
+        s = np.array(self._data.joint_positions)[self._to_user]
+        s_dot = np.array(self._data.joint_velocities)[self._to_user]
         tau = np.array(self._tau)[self._to_user]
 
         return s, s_dot, tau
@@ -405,7 +400,7 @@ def base_transform(self) -> npt.NDArray:
         assert (
             self._is_initialized
         ), "Simulator is not initialized, call load_model first."
-        return np.array(self._data.base_transform())
+        return np.array(self._data.base_transform)
 
     @property
     def base_velocity(self) -> npt.NDArray:
@@ -413,7 +408,7 @@ def base_velocity(self) -> npt.NDArray:
             self._is_initialized
         ), "Simulator is not initialized, call load_model first."
         with self._data.switch_velocity_representation(VelRepr.Mixed):
-            return np.array(self._data.base_velocity())
+            return np.array(self._data.base_velocity)
 
     @property
     def simulation_time(self) -> float:
@@ -492,7 +487,7 @@ def update_contact_model_parameters(
             self._is_initialized
         ), "Simulator is not initialized, call load_model first."
 
-        self._data = self._data.replace(contacts_params=params)
+        self._model = self._model.replace(contacts_params=params)
 
     # ==== Private methods ====
 
@@ -528,26 +523,26 @@ def _render(self) -> None:
             self._handle = self._viz.open_viewer()
 
         self._mj_model_helper.set_base_position(
-            position=np.array(self._data.base_position()),
+            position=np.array(self._data.base_position),
         )
         self._mj_model_helper.set_base_orientation(
-            orientation=np.array(self._data.base_orientation()),
+            orientation=np.array(self._data.base_orientation),
         )
         self._mj_model_helper.set_joint_positions(
-            positions=np.array(self._data.joint_positions()),
+            positions=np.array(self._data.joint_positions),
             joint_names=self._model.joint_names(),
         )
         self._viz.sync(viewer=self._handle)
 
     def _record_frame(self) -> None:
         self._mj_model_helper.set_base_position(
-            position=np.array(self._data.base_position()),
+            position=np.array(self._data.base_position),
         )
         self._mj_model_helper.set_base_orientation(
-            orientation=np.array(self._data.base_orientation()),
+            orientation=np.array(self._data.base_orientation),
         )
         self._mj_model_helper.set_joint_positions(
-            positions=np.array(self._data.joint_positions()),
+            positions=np.array(self._data.joint_positions),
             joint_names=self._model.joint_names(),
         )
 
diff --git a/src/comodo/mujocoSimulator/mjxSimulator.py b/src/comodo/mujocoSimulator/mjxSimulator.py
new file mode 100644
index 0000000..3c2f120
--- /dev/null
+++ b/src/comodo/mujocoSimulator/mjxSimulator.py
@@ -0,0 +1,341 @@
+import copy
+import math
+
+import casadi as cs
+import jax
+import mediapy as media
+import mujoco
+import mujoco_viewer
+import numpy as np
+from mujoco import mjx
+
+from comodo.abstractClasses.simulator import Simulator
+
+
+class MJXSimulator(Simulator):
+    def __init__(self) -> None:
+        self.desired_pos = None
+        self.postion_control = False
+        self.compute_misalignment_gravity_fun()
+        self.jit_step = jax.jit(mjx.step)
+        self.framerate = 30
+        super().__init__()
+
+    def load_model(self, robot_model, s, xyz_rpy, kv_motors=None, Im=None):
+        self.robot_model = robot_model
+
+        mujoco_xml = robot_model.get_mujoco_model()
+
+        # Load mujoco model and data
+        mujoco_model = mujoco.MjModel.from_xml_string(mujoco_xml)
+        mujoco_data = mujoco.MjData(mujoco_model)
+        self.mj_model = mujoco_model
+        self.mj_data = mujoco_data
+
+        # Put the model and data on the accelerator device(s) and get the corresponding MJX model and data
+        self.model = mjx.put_model(mujoco_model)
+        self.data = mjx.put_data(mujoco_model, mujoco_data)
+
+        self.create_mapping_vector_from_mujoco()
+        self.create_mapping_vector_to_mujoco()
+        # mjx.mj_forward(self.model, self.data)
+        self.set_joint_vector_in_mujoco(s)
+        self.set_base_pose_in_mujoco(xyz_rpy=xyz_rpy)
+        # mjx.mj_forward(self.model, self.data)
+        self.visualize_robot_flag = False
+
+        self.Im = Im if Im is not None else np.zeros(self.robot_model.NDoF)
+        self.kv_motors = (
+            kv_motors if kv_motors is not None else np.zeros(self.robot_model.NDoF)
+        )
+        self.H_left_foot = copy.deepcopy(self.robot_model.H_left_foot)
+        self.H_right_foot = copy.deepcopy(self.robot_model.H_right_foot)
+        self.H_left_foot_num = None
+        self.H_right_foot_num = None
+        self.mass = self.robot_model.get_total_mass()
+
+    def get_contact_status(self):
+        left_wrench, rigth_wrench = self.get_feet_wrench()
+        left_foot_contact = left_wrench[2] > 0.1 * self.mass
+        right_foot_contact = rigth_wrench[2] > 0.1 * self.mass
+        return left_foot_contact, right_foot_contact
+
+    def set_visualize_robot_flag(self, visualize_robot):
+        self.visualize_robot_flag = visualize_robot
+        if self.visualize_robot_flag:
+            # self.viewer = mujoco_viewer.MujocoViewer(self.model, self.data)
+            self.renderer = mujoco.Renderer(self.mj_model)
+            self.frames = []
+
+    def set_base_pose_in_mujoco(self, xyz_rpy):
+        base_xyz_quat = np.zeros(7)
+        base_xyz_quat[:3] = xyz_rpy[:3]
+        base_xyz_quat[3:] = self.RPY_to_quat(xyz_rpy[3], xyz_rpy[4], xyz_rpy[5])
+        base_xyz_quat[2] = base_xyz_quat[2]
+        # self.data.qpos[:7] = base_xyz_quat
+        self.data = self.data.replace(qpos=self.data.qpos.at[:7].set(base_xyz_quat))
+
+    def set_joint_vector_in_mujoco(self, pos):
+        pos_muj = self.convert_vector_to_mujoco(pos)
+        indexes_joint = self.model.jnt_qposadr[1:]
+        for i in range(self.robot_model.NDoF):
+            self.data = self.data.replace(
+                qpos=self.data.qpos.at[indexes_joint[i]].set(pos_muj[i])
+            )
+
+    def set_input(self, input):
+        input_muj = self.convert_vector_to_mujoco(input)
+        self.data = self.data.replace(ctrl=input_muj)
+        np.copyto(self.data.ctrl, input_muj)
+
+    def set_position_input(self, pos):
+        pos_muj = self.convert_vector_to_mujoco(pos)
+        self.desired_pos = pos_muj
+        self.postion_control = True
+
+    def create_mapping_vector_to_mujoco(self):
+        # This function creates the to_mujoco map
+        self.to_mujoco = []
+        for mujoco_joint in self.robot_model.mujoco_joint_order:
+            try:
+                index = self.robot_model.joint_name_list.index(mujoco_joint)
+                self.to_mujoco.append(index)
+            except ValueError:
+                raise ValueError(
+                    f"Mujoco joint '{mujoco_joint}' not found in joint list."
+                )
+
+    def create_mapping_vector_from_mujoco(self):
+        # This function creates the to_mujoco map
+        self.from_mujoco = []
+        for joint in self.robot_model.joint_name_list:
+            try:
+                index = self.robot_model.mujoco_joint_order.index(joint)
+                self.from_mujoco.append(index)
+            except ValueError:
+                raise ValueError(
+                    f"Joint name list  joint '{joint}' not found in mujoco list."
+                )
+
+    def convert_vector_to_mujoco(self, array_in):
+        out_muj = np.asarray(
+            [array_in[self.to_mujoco[item]] for item in range(self.robot_model.NDoF)]
+        )
+        return out_muj
+
+    def convert_from_mujoco(self, array_muj):
+        out_classic = np.asarray(
+            [array_muj[self.from_mujoco[item]] for item in range(self.robot_model.NDoF)]
+        )
+        return out_classic
+
+    def step(self, n_step=1, visualize=True):
+        if self.postion_control:
+            for _ in range(n_step):
+                s, s_dot, tau = self.get_state(use_mujoco_convention=True)
+                kp_muj = self.convert_vector_to_mujoco(
+                    self.robot_model.kp_position_control
+                )
+                kd_muj = self.convert_vector_to_mujoco(
+                    self.robot_model.kd_position_control
+                )
+                ctrl = kp_muj * (self.desired_pos - s) - kd_muj * s_dot
+                self.data.ctrl = ctrl
+                np.copyto(self.data.ctrl, ctrl)
+                self.data = self.jit_step(self.model, self.data)
+                # mjx.mj_step1(self.model, self.data)
+                # mjx.mj_forward(self.model, self.data)
+        else:
+            # Step the simulation of n_step iterations
+            for _ in range(n_step):
+                self.data = self.jit_step(self.model, self.data)
+                # mjx.mj_step1(self.model, self.data)
+                # mjx.mj_forward(self.model, self.data)
+
+        if len(self.frames) < self.data.time * self.framerate:
+            self.visualize_robot()
+        # if self.visualize_robot_flag:
+        #     self.viewer.render()
+
+    def step_with_motors(self, n_step, torque):
+        indexes_joint_acceleration = self.model.jnt_dofadr[1:]
+        s_dot_dot = self.data.qacc[indexes_joint_acceleration[0] :]
+        for _ in range(n_step):
+            indexes_joint_acceleration = self.model.jnt_dofadr[1:]
+            s_dot_dot = self.data.qacc[indexes_joint_acceleration[0] :]
+            s_dot = self.data.qvel[indexes_joint_acceleration[0] :]
+            input = np.asarray(
+                [
+                    self.Im[self.to_mujoco[item]] * s_dot_dot[item]
+                    + self.kv_motors[self.to_mujoco[item]] * s_dot[item]
+                    + torque[self.to_mujoco[item]]
+                    for item in range(self.robot_model.NDoF)
+                ]
+            )
+
+            self.set_input(input)
+            self.step(n_step=1, visualize=False)
+        # if self.visualize_robot_flag:
+        #     self.viewer.render()
+
+    def compute_misalignment_gravity_fun(self):
+        H = cs.SX.sym("H", 4, 4)
+        theta = cs.SX.sym("theta")
+        theta = cs.dot([0, 0, 1], H[:3, 2]) - 1
+        error = cs.Function("error", [H], [theta])
+        self.error_mis = error
+
+    def check_feet_status(self, s, H_b):
+        left_foot_pose = self.robot_model.H_left_foot(H_b, s)
+        rigth_foot_pose = self.robot_model.H_right_foot(H_b, s)
+        left_foot_z = left_foot_pose[2, 3]
+        rigth_foot_z = rigth_foot_pose[2, 3]
+        left_foot_contact = not (left_foot_z > 0.1)
+        rigth_foot_contact = not (rigth_foot_z > 0.1)
+        misalignment_left = self.error_mis(left_foot_pose)
+        misalignment_rigth = self.error_mis(rigth_foot_pose)
+        left_foot_condition = abs(left_foot_contact * misalignment_left)
+        rigth_foot_condition = abs(rigth_foot_contact * misalignment_rigth)
+        misalignment_error = left_foot_condition + rigth_foot_condition
+        if (
+            abs(left_foot_contact * misalignment_left) > 0.02
+            or abs(rigth_foot_contact * misalignment_rigth) > 0.02
+        ):
+            return False, misalignment_error
+
+        return True, misalignment_error
+
+    def get_feet_wrench(self):
+        left_foot_wrench = np.zeros(6)
+        rigth_foot_wrench = np.zeros(6)
+        s, s_dot, tau = self.get_state()
+        H_b = self.get_base()
+        self.H_left_foot_num = np.array(self.H_left_foot(H_b, s))
+        self.H_right_foot_num = np.array(self.H_right_foot(H_b, s))
+        for i in range(self.data.ncon):
+            contact = self.data.contact[i]
+            c_array = np.zeros(6, dtype=np.float64)
+            mujoco.mj_contactForce(self.model, self.data, i, c_array)
+            name_contact = mujoco.mj_id2name(
+                self.model, mujoco.mjtObj.mjOBJ_GEOM, int(contact.geom[1])
+            )
+            w_H_contact = np.eye(4)
+            w_H_contact[:3, :3] = contact.frame.reshape(3, 3)
+            w_H_contact[:3, 3] = contact.pos
+            if (
+                name_contact == self.robot_model.right_foot_rear_ct
+                or name_contact == self.robot_model.right_foot_front_ct
+            ):
+                RF_H_contact = np.linalg.inv(self.H_right_foot_num) @ w_H_contact
+                wrench_RF = self.compute_resulting_wrench(RF_H_contact, c_array)
+                rigth_foot_wrench[:] += wrench_RF.reshape(6)
+            elif (
+                name_contact == self.robot_model.left_foot_front_ct
+                or name_contact == self.robot_model.left_foot_rear_ct
+            ):
+                LF_H_contact = np.linalg.inv(self.H_left_foot_num) @ w_H_contact
+                wrench_LF = self.compute_resulting_wrench(LF_H_contact, c_array)
+                left_foot_wrench[:] += wrench_LF.reshape(6)
+        return (left_foot_wrench, rigth_foot_wrench)
+
+    def compute_resulting_wrench(self, b_H_a, force_torque_a):
+        p = b_H_a[:3, 3]
+        R = b_H_a[:3, :3]
+        adjoint_matrix = np.zeros([6, 6])
+        adjoint_matrix[:3, :3] = R
+        adjoint_matrix[3:, :3] = np.cross(p, R)
+        adjoint_matrix[3:, 3:] = R
+        force_torque_b = adjoint_matrix @ force_torque_a.reshape(6, 1)
+        return force_torque_b
+
+    # note that for mujoco the ordering is w,x,y,z
+    def get_base(self):
+        indexes_joint = self.model.jnt_qposadr[1:]
+        # Extract quaternion components
+        w, x, y, z = self.data.qpos[3 : indexes_joint[0]]
+
+        # Calculate rotation matrix
+        rot_mat = np.array(
+            [
+                [
+                    1 - 2 * y * y - 2 * z * z,
+                    2 * x * y - 2 * z * w,
+                    2 * x * z + 2 * y * w,
+                    0,
+                ],
+                [
+                    2 * x * y + 2 * z * w,
+                    1 - 2 * x * x - 2 * z * z,
+                    2 * y * z - 2 * x * w,
+                    0,
+                ],
+                [
+                    2 * x * z - 2 * y * w,
+                    2 * y * z + 2 * x * w,
+                    1 - 2 * x * x - 2 * y * y,
+                    0,
+                ],
+                [0, 0, 0, 1],
+            ]
+        )
+
+        # Set up transformation matrix
+        trans_mat = np.eye(4)
+        trans_mat[:3, :3] = rot_mat[:3, :3]
+        trans_mat[:3, 3] = self.data.qpos[:3]
+        # Return transformation matrix
+        return trans_mat
+
+    def get_base_velocity(self):
+        indexes_joint_velocities = self.model.jnt_dofadr[1:]
+        return self.data.qvel[: indexes_joint_velocities[0]]
+
+    def get_state(self, use_mujoco_convention=False):
+        indexes_joint = self.model.jnt_qposadr[1:]
+        indexes_joint_velocities = self.model.jnt_dofadr[1:]
+        s = self.data.qpos[indexes_joint[0] :]
+        s_dot = self.data.qvel[indexes_joint_velocities[0] :]
+        tau = self.data.ctrl
+        if use_mujoco_convention:
+            return s, s_dot, tau
+        s_out = self.convert_from_mujoco(s)
+        s_dot_out = self.convert_from_mujoco(s_dot)
+        tau_out = self.convert_from_mujoco(tau)
+        return s_out, s_dot_out, tau_out
+
+    def close(self):
+        if self.visualize_robot_flag:
+            self.viewer.close()
+
+    def visualize_robot(self):
+        # self.viewer.render()
+        mj_data = mjx.get_data(self.mj_model, self.data)
+        self.renderer.update_scene(mj_data)
+        pixels = self.renderer.render()
+        self.frames.append(pixels)
+
+    def get_simulation_time(self):
+        return self.data.time
+
+    def get_simulation_frequency(self):
+        return self.model.opt.timestep
+
+    def RPY_to_quat(self, roll, pitch, yaw):
+        cr = math.cos(roll / 2)
+        cp = math.cos(pitch / 2)
+        cy = math.cos(yaw / 2)
+        sr = math.sin(roll / 2)
+        sp = math.sin(pitch / 2)
+        sy = math.sin(yaw / 2)
+
+        qw = cr * cp * cy + sr * sp * sy
+        qx = sr * cp * cy - cr * sp * sy
+        qy = cr * sp * cy + sr * cp * sy
+        qz = cr * cp * sy - sr * sp * cy
+
+        return [qw, qx, qy, qz]
+
+    def close_visualization(self):
+        if self.visualize_robot_flag:
+            # self.viewer.close()
+            media.show_video(self.frames, fps=self.framerate)