Skip to content

Commit

Permalink
Cleanup MJX tests and migrate them to put_model/put_data.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 588304729
Change-Id: I58075383e6eed64ae00cea305a568eba6465b0b0
  • Loading branch information
erikfrey authored and copybara-github committed Dec 6, 2023
1 parent 899ba4b commit b3ccf67
Show file tree
Hide file tree
Showing 24 changed files with 456 additions and 878 deletions.
8 changes: 8 additions & 0 deletions mjx/mujoco/mjx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,17 @@

# pylint:disable=g-importing-member
from mujoco.mjx._src.collision_driver import collision
from mujoco.mjx._src.constraint import count_constraints
from mujoco.mjx._src.constraint import make_constraint
from mujoco.mjx._src.device import device_get_into
from mujoco.mjx._src.device import device_put
from mujoco.mjx._src.forward import euler
from mujoco.mjx._src.forward import forward
from mujoco.mjx._src.forward import fwd_acceleration
from mujoco.mjx._src.forward import fwd_actuation
from mujoco.mjx._src.forward import fwd_position
from mujoco.mjx._src.forward import fwd_velocity
from mujoco.mjx._src.forward import rungekutta4
from mujoco.mjx._src.forward import step
from mujoco.mjx._src.io import get_data
from mujoco.mjx._src.io import make_data
Expand All @@ -34,4 +41,5 @@
from mujoco.mjx._src.smooth import mul_m
from mujoco.mjx._src.smooth import rne
from mujoco.mjx._src.smooth import transmission
from mujoco.mjx._src.solver import solve
from mujoco.mjx._src.types import *
36 changes: 14 additions & 22 deletions mjx/mujoco/mjx/_src/collision_driver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ def _collide(
mjcf: str, assets: Optional[Dict[str, str]] = None
) -> Tuple[mujoco.MjModel, mujoco.MjData, Model, Data]:
m = mujoco.MjModel.from_xml_string(mjcf, assets or {})
mx = mjx.device_put(m)
mx = mjx.put_model(m)
d = mujoco.MjData(m)
dx = mjx.device_put(d)
dx = mjx.put_data(m, d)

mujoco.mj_step(m, d)
collision_jit_fn = jax.jit(mjx.collision)
Expand Down Expand Up @@ -418,9 +418,9 @@ def test_filter_self_collision(self):
def test_filter_parent_child(self):
"""Tests that parent-child collisions get filtered."""
m = mujoco.MjModel.from_xml_string(self._PARENT_CHILD)
mx = mjx.device_put(m)
mx = mjx.put_model(m)
d = mujoco.MjData(m)
dx = mjx.device_put(d)
dx = mjx.put_data(m, d)

mujoco.mj_step(m, d)
collision_jit_fn = jax.jit(mjx.collision)
Expand All @@ -435,9 +435,9 @@ def test_disable_filter_parent_child(self):
"""Tests that filterparent flag disables parent-child filtering."""
m = mujoco.MjModel.from_xml_string(self._PARENT_CHILD)
m.opt.disableflags |= mujoco.mjtDisableBit.mjDSBL_FILTERPARENT
mx = mjx.device_put(m)
mx = mjx.put_model(m)
d = mujoco.MjData(m)
dx = mjx.device_put(d)
dx = mjx.put_data(m, d)

mujoco.mj_step(m, d)
collision_jit_fn = jax.jit(mjx.collision)
Expand All @@ -454,22 +454,14 @@ class NconTest(parameterized.TestCase):
"""Tests ncon."""

def test_ncon(self):
m = test_util.load_test_file('ant.xml')
d = mujoco.MjData(m)
d.qpos[2] = 0.0

mx = mjx.device_put(m)
ncon = collision_driver.ncon(mx)
self.assertEqual(ncon, 4)
m = test_util.load_test_file('constraints.xml')
ncon = collision_driver.ncon(m)
self.assertEqual(ncon, 16)

def test_disable_contact(self):
m = test_util.load_test_file('ant.xml')
d = mujoco.MjData(m)
d.qpos[2] = 0.0

m.opt.disableflags = m.opt.disableflags | DisableBit.CONTACT
mx = mjx.device_put(m)
ncon = collision_driver.ncon(mx)
m = test_util.load_test_file('constraints.xml')
m.opt.disableflags |= DisableBit.CONTACT
ncon = collision_driver.ncon(m)
self.assertEqual(ncon, 0)


Expand Down Expand Up @@ -500,12 +492,12 @@ class TopKContactTest(absltest.TestCase):

def test_top_k_contacts(self):
m = mujoco.MjModel.from_xml_string(self._CAPSULES)
mx_top_k = mjx.device_put(m)
mx_top_k = mjx.put_model(m)
mx_all = mx_top_k.replace(
nnumeric=0, name_numericadr=np.array([]), numeric_data=np.array([])
)
d = mujoco.MjData(m)
dx = mjx.device_put(d)
dx = mjx.put_data(m, d)

collision_jit_fn = jax.jit(mjx.collision)
kinematics_jit_fn = jax.jit(mjx.kinematics)
Expand Down
173 changes: 53 additions & 120 deletions mjx/mujoco/mjx/_src/constraint_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,159 +15,92 @@
"""Tests for constraint functions."""

from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import numpy as jp
import mujoco
from mujoco import mjx
from mujoco.mjx._src import constraint
from mujoco.mjx._src import test_util
# pylint: disable=g-importing-member
from mujoco.mjx._src.types import DisableBit
from mujoco.mjx._src.types import SolverType
# pylint: enable=g-importing-member
import numpy as np


def _assert_eq(a, b, name, step, fname, atol=5e-3, rtol=5e-3):
err_msg = f'mismatch: {name} at step {step} in {fname}'
np.testing.assert_allclose(a, b, err_msg=err_msg, atol=atol, rtol=rtol)
# tolerance for difference between MuJoCo and MJX constraint calculations,
# mostly due to float precision
_TOLERANCE = 5e-5


class ConstraintTest(parameterized.TestCase):
def _assert_eq(a, b, name):
tol = _TOLERANCE * 10 # avoid test noise
err_msg = f'mismatch: {name}'
np.testing.assert_allclose(a, b, err_msg=err_msg, atol=tol, rtol=tol)

@parameterized.parameters(enumerate(test_util.TEST_FILES))
def test_constraints(self, seed, fname):
"""Test constraints."""
np.random.seed(seed)

# exclude convex.xml since convex contacts are not exactly equivalent
if fname == 'convex.xml':
return
def _assert_attr_eq(a, b, attr):
_assert_eq(getattr(a, attr), getattr(b, attr), attr)

m = test_util.load_test_file(fname)
d = mujoco.MjData(m)
mx = mjx.device_put(m)
dx = mjx.make_data(mx)

forward_jit_fn = jax.jit(mjx.forward)

# give the system a little kick to ensure we have non-identity rotations
d.qvel = np.random.random(m.nv)
for i in range(100):
dx = dx.replace(qpos=jax.device_put(d.qpos), qvel=jax.device_put(d.qvel))
mujoco.mj_step(m, d)
dx = forward_jit_fn(mx, dx)

nnz_filter = dx.efc_J.any(axis=1)

mj_efc_j = d.efc_J.reshape((-1, m.nv))
mjx_efc_j = dx.efc_J[nnz_filter]
_assert_eq(mj_efc_j, mjx_efc_j, 'efc_J', i, fname)

mjx_efc_d = dx.efc_D[nnz_filter]
_assert_eq(d.efc_D, mjx_efc_d, 'efc_D', i, fname)

mjx_efc_aref = dx.efc_aref[nnz_filter]
_assert_eq(d.efc_aref, mjx_efc_aref, 'efc_aref', i, fname)

mjx_efc_frictionloss = dx.efc_frictionloss[nnz_filter]
_assert_eq(
d.efc_frictionloss,
mjx_efc_frictionloss,
'efc_frictionloss',
i,
fname,
)

_JNT_RANGE = """
<mujoco>
<worldbody>
<body pos="0 0 1">
<joint type="slide" axis="1 0 0" range="-1.8 1.8" solreflimit=".08 1"
damping="5e-4"/>
<geom type="box" size="0.2 0.15 0.1" mass="1"/>
<body>
<joint axis="0 1 0" damping="2e-6"/>
<geom type="capsule" fromto="0 0 0 0 0 1" size="0.045" mass=".1"/>
</body>
</body>
</worldbody>
</mujoco>
"""

def test_jnt_range(self):
"""Tests that mixed joint ranges are respected."""
# TODO(robotics-simulation): also test ball
m = mujoco.MjModel.from_xml_string(self._JNT_RANGE)
m.opt.solver = SolverType.CG.value
d = mujoco.MjData(m)
d.qpos = np.array([2.0, 15.0])

mx = mjx.device_put(m)
dx = mjx.device_put(d)
efc = jax.jit(constraint._instantiate_limit_slide_hinge)(mx, dx)
class ConstraintTest(absltest.TestCase):

# first joint is outside the joint range
np.testing.assert_array_almost_equal(efc.J[0, 0], -1.0)
def test_constraints(self):
"""Test constraints."""
m = test_util.load_test_file('constraints.xml')
d = mujoco.MjData(m)
mujoco.mj_step(m, d, 100) # at 100 steps mix of active/inactive constraints
mujoco.mj_forward(m, d)
mx = mjx.put_model(m)
dx = mjx.put_data(m, d)

# second joint has no range, so only one efc row
self.assertEqual(efc.J.shape[0], 1)
dx = mjx.make_constraint(mx, dx)
nnz = dx.efc_J.any(axis=1)
_assert_eq(d.efc_J, dx.efc_J[nnz].reshape(-1), 'efc_J')
_assert_eq(d.efc_D, dx.efc_D[nnz], 'efc_D')
_assert_eq(d.efc_aref, dx.efc_aref[nnz], 'efc_aref')
_assert_eq(d.efc_frictionloss, dx.efc_frictionloss[nnz], 'efc_frictionloss')

def test_disable_refsafe(self):
m = test_util.load_test_file('ant.xml')
m = test_util.load_test_file('constraints.xml')

timeconst = m.opt.timestep / 4.0 # timeconst < 2 * timestep
solimp = jp.array([timeconst, 1.0])
solref = jp.array([0.8, 0.99, 0.001, 0.2, 2])
pos = jp.ones(3)

m.opt.disableflags = m.opt.disableflags | DisableBit.REFSAFE
m.opt.disableflags = m.opt.disableflags | mjx.DisableBit.REFSAFE
mx = mjx.device_put(m)
k, *_ = constraint._kbi(mx, solimp, solref, pos)
self.assertEqual(k, 1 / (0.99**2 * timeconst**2))

m.opt.disableflags = m.opt.disableflags & ~DisableBit.REFSAFE
mx = mjx.device_put(m)
k, *_ = constraint._kbi(mx, solimp, solref, pos)
self.assertEqual(k, 1 / (0.99**2 * (2 * m.opt.timestep) ** 2))

def test_disableconstraint(self):
m = test_util.load_test_file('ant.xml')
d = mujoco.MjData(m)

m.opt.disableflags = m.opt.disableflags | DisableBit.CONSTRAINT
mx, dx = mjx.device_put(m), mjx.device_put(d)
dx = constraint.make_constraint(mx, dx)
def test_disable_constraint(self):
m = test_util.load_test_file('constraints.xml')
m.opt.disableflags = m.opt.disableflags | mjx.DisableBit.CONSTRAINT
ne, nf, nl, nc = mjx.count_constraints(m)
self.assertEqual(ne, 0)
self.assertEqual(nf, 0)
self.assertEqual(nl, 0)
self.assertEqual(nc, 0)
dx = constraint.make_constraint(mjx.put_model(m), mjx.make_data(m))
self.assertEqual(dx.efc_J.shape[0], 0)

def test_disable_equality(self):
m = test_util.load_test_file('equality.xml')
d = mujoco.MjData(m)

m.opt.disableflags = m.opt.disableflags | DisableBit.EQUALITY
mx, dx = mjx.device_put(m), mjx.device_put(d)
dx = constraint.make_constraint(mx, dx)
self.assertEqual(dx.efc_J.shape[0], 0)
m = test_util.load_test_file('constraints.xml')
m.opt.disableflags = m.opt.disableflags | mjx.DisableBit.EQUALITY
ne, nf, nl, nc = mjx.count_constraints(m)
self.assertEqual(ne, 0)
self.assertEqual(nf, 0)
self.assertEqual(nl, 2)
self.assertEqual(nc, 64)
dx = constraint.make_constraint(mjx.put_model(m), mjx.make_data(m))
self.assertEqual(dx.efc_J.shape[0], 66) # only joint range, contact

def test_disable_contact(self):
m = test_util.load_test_file('ant.xml')
d = mujoco.MjData(m)
d.qpos[2] = 0.0
mujoco.mj_forward(m, d)

m.opt.disableflags = m.opt.disableflags & ~DisableBit.CONTACT
mx, dx = mjx.device_put(m), mjx.device_put(d)
dx = dx.tree_replace(
{'contact.frame': dx.contact.frame.reshape((-1, 3, 3))}
)
efc = constraint._instantiate_contact(mx, dx)
self.assertIsNotNone(efc)

m.opt.disableflags = m.opt.disableflags | DisableBit.CONTACT
mx, dx = mjx.device_put(m), mjx.device_put(d)
efc = constraint._instantiate_contact(mx, dx)
self.assertIsNone(efc)
m = test_util.load_test_file('constraints.xml')
m.opt.disableflags = m.opt.disableflags | mjx.DisableBit.CONTACT
ne, nf, nl, nc = mjx.count_constraints(m)
self.assertEqual(ne, 10)
self.assertEqual(nf, 0)
self.assertEqual(nl, 2)
self.assertEqual(nc, 0)
dx = constraint.make_constraint(mjx.put_model(m), mjx.make_data(m))
self.assertEqual(dx.efc_J.shape[0], 12) # only joint range, limit


if __name__ == '__main__':
Expand Down
9 changes: 9 additions & 0 deletions mjx/mujoco/mjx/_src/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,11 @@ def device_put(value):
Returns:
on-device MJX struct reflecting the input value
"""
warnings.warn(
'device_put is deprecated, use put_model and put_data instead',
category=DeprecationWarning,
)

clz = _TYPE_MAP.get(type(value))
if clz is None:
raise NotImplementedError(f'{type(value)} is not supported for device_put.')
Expand Down Expand Up @@ -242,6 +247,10 @@ def device_get_into(result, value):
Raises:
RuntimeError: if result length doesn't match data batch size
"""
warnings.warn(
'device_get_into is deprecated, use get_data instead',
category=DeprecationWarning,
)

value = jax.device_get(value)

Expand Down
10 changes: 5 additions & 5 deletions mjx/mujoco/mjx/_src/device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,31 +130,31 @@ def test_cone(self):
mjx.device_put(m)

def test_trn(self):
m = test_util.load_test_file('ant.xml')
m = test_util.load_test_file('pendula.xml')
m.actuator_trntype[0] = mujoco.mjtTrn.mjTRN_SITE
with self.assertRaises(NotImplementedError):
mjx.device_put(m)

def test_dyn(self):
m = test_util.load_test_file('ant.xml')
m = test_util.load_test_file('pendula.xml')
m.actuator_dyntype[0] = mujoco.mjtDyn.mjDYN_MUSCLE
with self.assertRaises(NotImplementedError):
mjx.device_put(m)

def test_gain(self):
m = test_util.load_test_file('ant.xml')
m = test_util.load_test_file('pendula.xml')
m.actuator_gaintype[0] = mujoco.mjtGain.mjGAIN_MUSCLE
with self.assertRaises(NotImplementedError):
mjx.device_put(m)

def test_bias(self):
m = test_util.load_test_file('ant.xml')
m = test_util.load_test_file('pendula.xml')
m.actuator_gaintype[0] = mujoco.mjtGain.mjGAIN_MUSCLE
with self.assertRaises(NotImplementedError):
mjx.device_put(m)

def test_condim(self):
m = test_util.load_test_file('ant.xml')
m = test_util.load_test_file('constraints.xml')
for i in [1, 4, 6]:
m.geom_condim[0] = i
with self.assertRaises(NotImplementedError):
Expand Down
Loading

0 comments on commit b3ccf67

Please sign in to comment.