
Go2 support #110

Open

aatb-ch wants to merge 8 commits into google-deepmind:main from aatb-ch:go2

Conversation

@aatb-ch

@aatb-ch aatb-ch commented Apr 11, 2025

This PR adds Unitree Go2 support, based on the existing Go1 support. It uses the Menagerie Go2 MJX model, adjusted to add the correct sensors, collisions, etc.

  • Joystick task tested with flat/rough terrain and feet collision only

TODO: adjust the full-collision MJX model. I'm not 100% sure it's complete; some things seem to be missing. I'll go through the Go1 mesh and compare, then test Getup/Handstand before adding those tasks.

@google-cla

google-cla bot commented Apr 11, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@aatb-ch
Author

aatb-ch commented Apr 16, 2025

So I've worked on the full collision mesh and examples, and I've successfully trained Joystick, Handstand, Footstand, and Getup. The policies need some reward tuning, but training works.

Let me know if I need to do anything else.

@aatb-ch
Author

aatb-ch commented Apr 23, 2025

Note: the actuator order in the MJX model for Go2 does not follow the Unitree leg ordering, which is FR/FL/RR/RL; in the MJX model it's FL/FR/RL/RR. Just a note, since simply forwarding the actions in the default order to LowCmd mixes up the joints.

Should this be fixed in the MJX model, or is it an implementation detail left to the driver?
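For anyone forwarding actions to LowCmd, the remap can be expressed as a fixed index permutation. A minimal sketch, assuming 12 actuators in per-leg groups of three and the two leg orderings named above; the constant names are illustrative, not taken from the model files or the SDK:

```python
import numpy as np

# MJX model actuator order (three joints per leg): FL, FR, RL, RR
# Unitree LowCmd motor order:                      FR, FL, RR, RL
MJX_LEGS = ["FL", "FR", "RL", "RR"]
SDK_LEGS = ["FR", "FL", "RR", "RL"]

# Index map such that sdk_actions[i] = mjx_actions[MJX_TO_SDK[i]].
MJX_TO_SDK = np.concatenate(
    [3 * MJX_LEGS.index(leg) + np.arange(3) for leg in SDK_LEGS]
)

def mjx_to_lowcmd(mjx_actions: np.ndarray) -> np.ndarray:
    """Reorder a 12-dim MJX action vector into Unitree LowCmd motor order."""
    return np.asarray(mjx_actions)[MJX_TO_SDK]
```

Since the permutation only swaps leg pairs, it is its own inverse, so the same table also maps LowState joint readings back into MJX order.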

@xander-2077

Hello! Have you successfully trained Go2Getup and transferred it to a real robot? I found that a single training run like this does not work:

python train_jax_ppo.py --env_name=Go2Getup

@aatb-ch
Author

aatb-ch commented May 5, 2025

Yes, I trained the Joystick policy and transferred it to a real Go2 successfully. The Getup and Handstand are copied straight from Go1, but from quick tests they did result in successful policies in sim; I didn't transfer these to the real Go2.

@aatb-ch
Author

aatb-ch commented May 5, 2025

For Getup there's an issue about it for Go1, so you might have a look at #65.

@xander-2077

xander-2077 commented May 5, 2025

For Getup there's an issue about it for Go1, so you might have a look at #65.

That's the key point! It's mentioned in #65 that 50M timesteps is not enough for training Go1Getup. But should I train 750M timesteps in one run, or train 50M timesteps at a time and repeatedly load checkpoints? More importantly, the paper mentions:

  1. Train with a power termination cutoff of 400 W.
  2. Finetune with a joint velocity cost.

How should these two tricks be added to the training process?
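One plausible reading of those two tricks, sketched below. This is a guess at the mechanics, not the paper's or mujoco_playground's actual code: the function names, the weight value, and the exact hook points into the env's termination and reward functions are all assumptions that would need to be checked against the Getup env.

```python
import numpy as np

POWER_CUTOFF_W = 400.0  # power termination threshold mentioned in the paper

def mechanical_power(torques: np.ndarray, joint_vels: np.ndarray) -> float:
    # Instantaneous mechanical power, summed over actuated joints.
    return float(np.sum(np.abs(torques * joint_vels)))

def power_terminated(torques: np.ndarray, joint_vels: np.ndarray) -> bool:
    # Trick 1: OR this flag into the env's usual `done` signal during the
    # initial training phase, so high-power rollouts are cut short.
    return mechanical_power(torques, joint_vels) > POWER_CUTOFF_W

def joint_velocity_cost(joint_vels: np.ndarray, weight: float = 1e-3) -> float:
    # Trick 2: a quadratic joint-velocity penalty, subtracted from the reward
    # only during the finetuning phase (the weight here is a placeholder).
    return weight * float(np.sum(np.square(joint_vels)))
```

In that reading, you would train once with the termination active, then resume from the checkpoint with the velocity cost enabled rather than restarting from scratch.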

@aatb-ch
Author

aatb-ch commented May 5, 2025

That's a bit off-topic for this PR; I'd suggest you ask directly in that issue, since it's pretty much the same problem, and the Go1/Go2 architectures are very similar.

@kevinzakka
Collaborator

Hi @aatb-ch thanks for the PR! I'll try to get to this after the CoRL supplemental deadline (probs end of week).

@xander-2077

That's a bit off-topic for this PR; I'd suggest you ask directly in that issue, since it's pretty much the same problem, and the Go1/Go2 architectures are very similar.

Thank you! I have reproduced it after 750M timesteps training.

@aatb-ch
Author

aatb-ch commented May 6, 2025

@kevinzakka super, no stress; just let me know once you have time whether I need to change anything.

@DerSimi

DerSimi commented Sep 2, 2025

Hi,
I got the joystick environment working on the real robot using this PR, with the policy trained under domain randomization.

https://github.com/DerSimi/unitree_go2_sim2real

But note: when the Go2 is in low state mode, which is necessary for low-level control, the SportModeState is not published by the robot. This means the linear velocity used as an observation here is not available. In my code you can see that I circumvented this by setting 'linvel' in the observation to the current command. It's a wonder it worked at all.
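That workaround can be sketched as follows. The observation ordering below mirrors the inference snippet posted later in this thread and is an assumption; verify it against the env's observation code before deploying.

```python
import numpy as np

def build_obs(gyro, gravity, joint_angles, joint_vels, last_action, command):
    # On the real Go2 in low-level mode there is no linvel estimate, so the
    # command (vx, vy, wz) is fed into linvel's slot instead.
    linvel_stand_in = np.asarray(command)
    return np.concatenate([
        np.asarray(gyro),          # 3
        np.asarray(gravity),       # 3
        np.asarray(joint_angles),  # 12
        np.asarray(joint_vels),    # 12
        linvel_stand_in,           # 3  (command standing in for linvel)
        np.asarray(last_action),   # 12
        np.asarray(command),       # 3
    ])                             # 48 total
```

The command then appears twice in the vector, once in linvel's slot and once in its own, which is consistent with the observation below that zeroing linvel entirely also works.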

@aatb-ch
Author

aatb-ch commented Sep 2, 2025

Even more: try zeroing out linvel, gyro, and gravity, and it still works! But yes, you don't get any state estimation from the internal SportModeState; you have to estimate it some other way. I did the same as you initially, but then realized it's the same data as the command already passed as an observation, so it doesn't really matter to pass it again in place of linvel.

@YilangLiu

@DerSimi May I know if you still have sim-to-real code available?

@aatb-ch
Author

aatb-ch commented Nov 3, 2025

Hi @kevinzakka, was anything still missing for merging?

@bhenriquezsoto

bhenriquezsoto commented Dec 23, 2025

@aatb-ch do you think you could share the code you used to train and then run inference with the policy, please? I've been trying to do precisely the same thing, modifying the Go1 files for Go2 (before knowing this PR existed), and I've arrived at a good result when recording video (see https://drive.google.com/file/d/1aaLDFLu4hyKNuKUTYwuNz3HxN5GG7BK5/view?usp=sharing).

The problem is that, for some reason, when I run it separately using MuJoCo's passive viewer instead of just recording a video, the policy performs pretty badly. Below is the code I'm currently using to try to replicate the behavior. Could you tell me if you see any error (or difference) compared to what you used?

This is how the robot is behaving right now: https://drive.google.com/file/d/1J84Kv9RD5Pw7A9nEjK52qyiqF20p-2wE/view?usp=sharing

import os
from brax.training.agents.ppo import networks as ppo_networks
from brax.io import model
from mujoco_playground.config import locomotion_params
from mujoco_playground import registry
import jax
from mujoco_playground._src.mjx_env import get_sensor_data
import mujoco
import time
import mujoco.viewer
import logging
from mujoco import mjx
from typing import List
from mujoco_playground._src.locomotion.go2 import go2_constants as consts


os.environ['JAX_PLATFORMS'] = 'gpu'
jax.config.update('jax_platform_name', 'gpu')
jax.config.update('jax_disable_jit', False)
# Remove jax debug logging
logging.getLogger('jax').setLevel(logging.WARNING)

def get_obs_sensors(model: mujoco.MjModel, data: mujoco.MjData) -> List:
	gyro = get_sensor_data(model, data, consts.GYRO_SENSOR)
	resized_matrix = jax.numpy.array(data.site_xmat[model.site("imu").id]).reshape((3, 3))
	gravity = resized_matrix @ jax.numpy.array([0, 0, -1])
	joint_angles = data.qpos[7:]
	joint_velocities = data.qvel[6:]
	linvel = get_sensor_data(model, data, consts.LOCAL_LINVEL_SENSOR)

	return [gyro, gravity, joint_angles, joint_velocities, linvel]


def import_model(env_name='Go2JoystickFlatTerrain', env_cfg=None):
	"""
	This function imports the PPO trained model for the task 'Go2JoystickFlatTerrain'.
	"""
	
	
	ppo_params = locomotion_params.brax_ppo_config(env_name)
	model_params = model.load_params('ppo_go2joystick_flatterrain_params_v1_ctrl_002_sim_0004_impratio_100')
	ppo = ppo_networks.make_ppo_networks(action_size=env_cfg.action_size, observation_size=env_cfg.observation_size, **ppo_params.network_factory)
	make_inference = ppo_networks.make_inference_fn(ppo)
	inference_fnTEST = make_inference(model_params, deterministic=True)

	return inference_fnTEST

if __name__ == "__main__":
	rng = jax.random.PRNGKey(0)
	env_name = 'Go2JoystickFlatTerrain'
	env_cfg = registry.get_default_config(env_name)

	m = mujoco.MjModel.from_xml_path('./xmls/scene_mjx_feetonly_flat_terrain.xml')
	d = mujoco.MjData(m)

	inference = import_model(env_name, env_cfg)
	last_action = jax.numpy.zeros(env_cfg.action_size)
	inference = jax.jit(inference)
	command = jax.numpy.array([1.0, 0.0, 0.0]) 
	counter_control = 0
	counter_init = 0

	with mujoco.viewer.launch_passive(m, d) as viewer:
		# Obtain observations
		mujoco.mj_resetData(m, d)
		# act_rng, rng = jax.random.split(rng)
		# obs_sensors = get_obs_sensors(m, d)
		# obs = jax.numpy.concatenate(obs_sensors + [last_action, command])
		# # Inference from command -> angle off set for each of the joints
		# ctrl, _ = inference(obs, act_rng)
		timer_control = time.time()
		motors_targets = m.keyframe("home").qpos[7:]
		while viewer.is_running():
			# print(f"Control dt: {env_cfg.ctrl_dt}s, Sim dt: {env_cfg.sim_dt}s")
			step_start = time.time()
			if (counter_control % int(env_cfg.ctrl_dt / env_cfg.sim_dt)) == 0:
				act_rng, rng = jax.random.split(rng)
				# print(f"Time elapsed for random split: {time.time() - step_start:.4f}s")
				obs_sensors = get_obs_sensors(m, d)
				# print(f"Time elapsed for getting sensors: {time.time() - step_start:.4f}s")
				obs = jax.numpy.concatenate(obs_sensors + [last_action, command])
				# print(f"Time elapsed for concatenating obs: {time.time() - step_start:.4f}s")
				# Inference from command -> angle off set for each of the joints
				ctrl, _ = inference(obs, act_rng)
				# print(f"Time elapsed for inference: {time.time() - step_start:.4f}s")
				motors_targets = m.keyframe("home").qpos[7:] + ctrl * env_cfg.action_scale
				# print(f"Time elapsed for getting motor targets: {time.time() - step_start:.4f}s")
				timer_control = time.time()
				last_action = ctrl
				print("Time used in control loop: {:.4f}s".format(time.time() - step_start))
			
			d.ctrl = motors_targets	

			time_until_next_step = env_cfg.sim_dt - (time.time() - step_start)
			# if time_until_next_step > 0:
			# 	time.sleep(time_until_next_step)
			# else:
			# 	print(f"Step took longer ({env_cfg.sim_dt - time_until_next_step:.4f}s) than sim_dt of {env_cfg.sim_dt}s")
			# 	pass

			mujoco.mj_step(m, d)
			viewer.sync()
			counter_control += 1

Thanks in advance!

@bhenriquezsoto

Yes, I trained the Joystick policy and transferred it to a real Go2 successfully. The Getup and Handstand are copied straight from Go1, but from quick tests they did result in successful policies in sim; I didn't transfer these to the real Go2.

Can I ask what values of KP and KD you used for the joystick policy in sim2real, please? For some reason my policy, when I use the same KP and KD

