Skip to content

Commit b3ccf67

Browse files
erikfreycopybara-github
authored andcommitted
Cleanup MJX tests and migrate them to put_model/put_data.
PiperOrigin-RevId: 588304729 Change-Id: I58075383e6eed64ae00cea305a568eba6465b0b0
1 parent 899ba4b commit b3ccf67

24 files changed

+456
-878
lines changed

mjx/mujoco/mjx/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,17 @@
1616

1717
# pylint:disable=g-importing-member
1818
from mujoco.mjx._src.collision_driver import collision
19+
from mujoco.mjx._src.constraint import count_constraints
1920
from mujoco.mjx._src.constraint import make_constraint
2021
from mujoco.mjx._src.device import device_get_into
2122
from mujoco.mjx._src.device import device_put
23+
from mujoco.mjx._src.forward import euler
2224
from mujoco.mjx._src.forward import forward
25+
from mujoco.mjx._src.forward import fwd_acceleration
26+
from mujoco.mjx._src.forward import fwd_actuation
27+
from mujoco.mjx._src.forward import fwd_position
28+
from mujoco.mjx._src.forward import fwd_velocity
29+
from mujoco.mjx._src.forward import rungekutta4
2330
from mujoco.mjx._src.forward import step
2431
from mujoco.mjx._src.io import get_data
2532
from mujoco.mjx._src.io import make_data
@@ -34,4 +41,5 @@
3441
from mujoco.mjx._src.smooth import mul_m
3542
from mujoco.mjx._src.smooth import rne
3643
from mujoco.mjx._src.smooth import transmission
44+
from mujoco.mjx._src.solver import solve
3745
from mujoco.mjx._src.types import *

mjx/mujoco/mjx/_src/collision_driver_test.py

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@ def _collide(
5252
mjcf: str, assets: Optional[Dict[str, str]] = None
5353
) -> Tuple[mujoco.MjModel, mujoco.MjData, Model, Data]:
5454
m = mujoco.MjModel.from_xml_string(mjcf, assets or {})
55-
mx = mjx.device_put(m)
55+
mx = mjx.put_model(m)
5656
d = mujoco.MjData(m)
57-
dx = mjx.device_put(d)
57+
dx = mjx.put_data(m, d)
5858

5959
mujoco.mj_step(m, d)
6060
collision_jit_fn = jax.jit(mjx.collision)
@@ -418,9 +418,9 @@ def test_filter_self_collision(self):
418418
def test_filter_parent_child(self):
419419
"""Tests that parent-child collisions get filtered."""
420420
m = mujoco.MjModel.from_xml_string(self._PARENT_CHILD)
421-
mx = mjx.device_put(m)
421+
mx = mjx.put_model(m)
422422
d = mujoco.MjData(m)
423-
dx = mjx.device_put(d)
423+
dx = mjx.put_data(m, d)
424424

425425
mujoco.mj_step(m, d)
426426
collision_jit_fn = jax.jit(mjx.collision)
@@ -435,9 +435,9 @@ def test_disable_filter_parent_child(self):
435435
"""Tests that filterparent flag disables parent-child filtering."""
436436
m = mujoco.MjModel.from_xml_string(self._PARENT_CHILD)
437437
m.opt.disableflags |= mujoco.mjtDisableBit.mjDSBL_FILTERPARENT
438-
mx = mjx.device_put(m)
438+
mx = mjx.put_model(m)
439439
d = mujoco.MjData(m)
440-
dx = mjx.device_put(d)
440+
dx = mjx.put_data(m, d)
441441

442442
mujoco.mj_step(m, d)
443443
collision_jit_fn = jax.jit(mjx.collision)
@@ -454,22 +454,14 @@ class NconTest(parameterized.TestCase):
454454
"""Tests ncon."""
455455

456456
def test_ncon(self):
457-
m = test_util.load_test_file('ant.xml')
458-
d = mujoco.MjData(m)
459-
d.qpos[2] = 0.0
460-
461-
mx = mjx.device_put(m)
462-
ncon = collision_driver.ncon(mx)
463-
self.assertEqual(ncon, 4)
457+
m = test_util.load_test_file('constraints.xml')
458+
ncon = collision_driver.ncon(m)
459+
self.assertEqual(ncon, 16)
464460

465461
def test_disable_contact(self):
466-
m = test_util.load_test_file('ant.xml')
467-
d = mujoco.MjData(m)
468-
d.qpos[2] = 0.0
469-
470-
m.opt.disableflags = m.opt.disableflags | DisableBit.CONTACT
471-
mx = mjx.device_put(m)
472-
ncon = collision_driver.ncon(mx)
462+
m = test_util.load_test_file('constraints.xml')
463+
m.opt.disableflags |= DisableBit.CONTACT
464+
ncon = collision_driver.ncon(m)
473465
self.assertEqual(ncon, 0)
474466

475467

@@ -500,12 +492,12 @@ class TopKContactTest(absltest.TestCase):
500492

501493
def test_top_k_contacts(self):
502494
m = mujoco.MjModel.from_xml_string(self._CAPSULES)
503-
mx_top_k = mjx.device_put(m)
495+
mx_top_k = mjx.put_model(m)
504496
mx_all = mx_top_k.replace(
505497
nnumeric=0, name_numericadr=np.array([]), numeric_data=np.array([])
506498
)
507499
d = mujoco.MjData(m)
508-
dx = mjx.device_put(d)
500+
dx = mjx.put_data(m, d)
509501

510502
collision_jit_fn = jax.jit(mjx.collision)
511503
kinematics_jit_fn = jax.jit(mjx.kinematics)

mjx/mujoco/mjx/_src/constraint_test.py

Lines changed: 53 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -15,159 +15,92 @@
1515
"""Tests for constraint functions."""
1616

1717
from absl.testing import absltest
18-
from absl.testing import parameterized
19-
import jax
2018
from jax import numpy as jp
2119
import mujoco
2220
from mujoco import mjx
2321
from mujoco.mjx._src import constraint
2422
from mujoco.mjx._src import test_util
25-
# pylint: disable=g-importing-member
26-
from mujoco.mjx._src.types import DisableBit
27-
from mujoco.mjx._src.types import SolverType
28-
# pylint: enable=g-importing-member
2923
import numpy as np
3024

3125

32-
def _assert_eq(a, b, name, step, fname, atol=5e-3, rtol=5e-3):
33-
err_msg = f'mismatch: {name} at step {step} in {fname}'
34-
np.testing.assert_allclose(a, b, err_msg=err_msg, atol=atol, rtol=rtol)
26+
# tolerance for difference between MuJoCo and MJX constraint calculations,
27+
# mostly due to float precision
28+
_TOLERANCE = 5e-5
3529

3630

37-
class ConstraintTest(parameterized.TestCase):
31+
def _assert_eq(a, b, name):
32+
tol = _TOLERANCE * 10 # avoid test noise
33+
err_msg = f'mismatch: {name}'
34+
np.testing.assert_allclose(a, b, err_msg=err_msg, atol=tol, rtol=tol)
3835

39-
@parameterized.parameters(enumerate(test_util.TEST_FILES))
40-
def test_constraints(self, seed, fname):
41-
"""Test constraints."""
42-
np.random.seed(seed)
4336

44-
# exclude convex.xml since convex contacts are not exactly equivalent
45-
if fname == 'convex.xml':
46-
return
37+
def _assert_attr_eq(a, b, attr):
38+
_assert_eq(getattr(a, attr), getattr(b, attr), attr)
4739

48-
m = test_util.load_test_file(fname)
49-
d = mujoco.MjData(m)
50-
mx = mjx.device_put(m)
51-
dx = mjx.make_data(mx)
52-
53-
forward_jit_fn = jax.jit(mjx.forward)
54-
55-
# give the system a little kick to ensure we have non-identity rotations
56-
d.qvel = np.random.random(m.nv)
57-
for i in range(100):
58-
dx = dx.replace(qpos=jax.device_put(d.qpos), qvel=jax.device_put(d.qvel))
59-
mujoco.mj_step(m, d)
60-
dx = forward_jit_fn(mx, dx)
61-
62-
nnz_filter = dx.efc_J.any(axis=1)
63-
64-
mj_efc_j = d.efc_J.reshape((-1, m.nv))
65-
mjx_efc_j = dx.efc_J[nnz_filter]
66-
_assert_eq(mj_efc_j, mjx_efc_j, 'efc_J', i, fname)
67-
68-
mjx_efc_d = dx.efc_D[nnz_filter]
69-
_assert_eq(d.efc_D, mjx_efc_d, 'efc_D', i, fname)
70-
71-
mjx_efc_aref = dx.efc_aref[nnz_filter]
72-
_assert_eq(d.efc_aref, mjx_efc_aref, 'efc_aref', i, fname)
73-
74-
mjx_efc_frictionloss = dx.efc_frictionloss[nnz_filter]
75-
_assert_eq(
76-
d.efc_frictionloss,
77-
mjx_efc_frictionloss,
78-
'efc_frictionloss',
79-
i,
80-
fname,
81-
)
82-
83-
_JNT_RANGE = """
84-
<mujoco>
85-
<worldbody>
86-
<body pos="0 0 1">
87-
<joint type="slide" axis="1 0 0" range="-1.8 1.8" solreflimit=".08 1"
88-
damping="5e-4"/>
89-
<geom type="box" size="0.2 0.15 0.1" mass="1"/>
90-
<body>
91-
<joint axis="0 1 0" damping="2e-6"/>
92-
<geom type="capsule" fromto="0 0 0 0 0 1" size="0.045" mass=".1"/>
93-
</body>
94-
</body>
95-
</worldbody>
96-
</mujoco>
97-
"""
98-
99-
def test_jnt_range(self):
100-
"""Tests that mixed joint ranges are respected."""
101-
# TODO(robotics-simulation): also test ball
102-
m = mujoco.MjModel.from_xml_string(self._JNT_RANGE)
103-
m.opt.solver = SolverType.CG.value
104-
d = mujoco.MjData(m)
105-
d.qpos = np.array([2.0, 15.0])
10640

107-
mx = mjx.device_put(m)
108-
dx = mjx.device_put(d)
109-
efc = jax.jit(constraint._instantiate_limit_slide_hinge)(mx, dx)
41+
class ConstraintTest(absltest.TestCase):
11042

111-
# first joint is outside the joint range
112-
np.testing.assert_array_almost_equal(efc.J[0, 0], -1.0)
43+
def test_constraints(self):
44+
"""Test constraints."""
45+
m = test_util.load_test_file('constraints.xml')
46+
d = mujoco.MjData(m)
47+
mujoco.mj_step(m, d, 100) # at 100 steps mix of active/inactive constraints
48+
mujoco.mj_forward(m, d)
49+
mx = mjx.put_model(m)
50+
dx = mjx.put_data(m, d)
11351

114-
# second joint has no range, so only one efc row
115-
self.assertEqual(efc.J.shape[0], 1)
52+
dx = mjx.make_constraint(mx, dx)
53+
nnz = dx.efc_J.any(axis=1)
54+
_assert_eq(d.efc_J, dx.efc_J[nnz].reshape(-1), 'efc_J')
55+
_assert_eq(d.efc_D, dx.efc_D[nnz], 'efc_D')
56+
_assert_eq(d.efc_aref, dx.efc_aref[nnz], 'efc_aref')
57+
_assert_eq(d.efc_frictionloss, dx.efc_frictionloss[nnz], 'efc_frictionloss')
11658

11759
def test_disable_refsafe(self):
118-
m = test_util.load_test_file('ant.xml')
60+
m = test_util.load_test_file('constraints.xml')
11961

12062
timeconst = m.opt.timestep / 4.0 # timeconst < 2 * timestep
12163
solimp = jp.array([timeconst, 1.0])
12264
solref = jp.array([0.8, 0.99, 0.001, 0.2, 2])
12365
pos = jp.ones(3)
12466

125-
m.opt.disableflags = m.opt.disableflags | DisableBit.REFSAFE
67+
m.opt.disableflags = m.opt.disableflags | mjx.DisableBit.REFSAFE
12668
mx = mjx.device_put(m)
12769
k, *_ = constraint._kbi(mx, solimp, solref, pos)
12870
self.assertEqual(k, 1 / (0.99**2 * timeconst**2))
12971

130-
m.opt.disableflags = m.opt.disableflags & ~DisableBit.REFSAFE
131-
mx = mjx.device_put(m)
132-
k, *_ = constraint._kbi(mx, solimp, solref, pos)
133-
self.assertEqual(k, 1 / (0.99**2 * (2 * m.opt.timestep) ** 2))
134-
135-
def test_disableconstraint(self):
136-
m = test_util.load_test_file('ant.xml')
137-
d = mujoco.MjData(m)
138-
139-
m.opt.disableflags = m.opt.disableflags | DisableBit.CONSTRAINT
140-
mx, dx = mjx.device_put(m), mjx.device_put(d)
141-
dx = constraint.make_constraint(mx, dx)
72+
def test_disable_constraint(self):
73+
m = test_util.load_test_file('constraints.xml')
74+
m.opt.disableflags = m.opt.disableflags | mjx.DisableBit.CONSTRAINT
75+
ne, nf, nl, nc = mjx.count_constraints(m)
76+
self.assertEqual(ne, 0)
77+
self.assertEqual(nf, 0)
78+
self.assertEqual(nl, 0)
79+
self.assertEqual(nc, 0)
80+
dx = constraint.make_constraint(mjx.put_model(m), mjx.make_data(m))
14281
self.assertEqual(dx.efc_J.shape[0], 0)
14382

14483
def test_disable_equality(self):
145-
m = test_util.load_test_file('equality.xml')
146-
d = mujoco.MjData(m)
147-
148-
m.opt.disableflags = m.opt.disableflags | DisableBit.EQUALITY
149-
mx, dx = mjx.device_put(m), mjx.device_put(d)
150-
dx = constraint.make_constraint(mx, dx)
151-
self.assertEqual(dx.efc_J.shape[0], 0)
84+
m = test_util.load_test_file('constraints.xml')
85+
m.opt.disableflags = m.opt.disableflags | mjx.DisableBit.EQUALITY
86+
ne, nf, nl, nc = mjx.count_constraints(m)
87+
self.assertEqual(ne, 0)
88+
self.assertEqual(nf, 0)
89+
self.assertEqual(nl, 2)
90+
self.assertEqual(nc, 64)
91+
dx = constraint.make_constraint(mjx.put_model(m), mjx.make_data(m))
92+
self.assertEqual(dx.efc_J.shape[0], 66) # only joint range, contact
15293

15394
def test_disable_contact(self):
154-
m = test_util.load_test_file('ant.xml')
155-
d = mujoco.MjData(m)
156-
d.qpos[2] = 0.0
157-
mujoco.mj_forward(m, d)
158-
159-
m.opt.disableflags = m.opt.disableflags & ~DisableBit.CONTACT
160-
mx, dx = mjx.device_put(m), mjx.device_put(d)
161-
dx = dx.tree_replace(
162-
{'contact.frame': dx.contact.frame.reshape((-1, 3, 3))}
163-
)
164-
efc = constraint._instantiate_contact(mx, dx)
165-
self.assertIsNotNone(efc)
166-
167-
m.opt.disableflags = m.opt.disableflags | DisableBit.CONTACT
168-
mx, dx = mjx.device_put(m), mjx.device_put(d)
169-
efc = constraint._instantiate_contact(mx, dx)
170-
self.assertIsNone(efc)
95+
m = test_util.load_test_file('constraints.xml')
96+
m.opt.disableflags = m.opt.disableflags | mjx.DisableBit.CONTACT
97+
ne, nf, nl, nc = mjx.count_constraints(m)
98+
self.assertEqual(ne, 10)
99+
self.assertEqual(nf, 0)
100+
self.assertEqual(nl, 2)
101+
self.assertEqual(nc, 0)
102+
dx = constraint.make_constraint(mjx.put_model(m), mjx.make_data(m))
103+
self.assertEqual(dx.efc_J.shape[0], 12) # only joint range, limit
171104

172105

173106
if __name__ == '__main__':

mjx/mujoco/mjx/_src/device.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,11 @@ def device_put(value):
184184
Returns:
185185
on-device MJX struct reflecting the input value
186186
"""
187+
warnings.warn(
188+
'device_put is deprecated, use put_model and put_data instead',
189+
category=DeprecationWarning,
190+
)
191+
187192
clz = _TYPE_MAP.get(type(value))
188193
if clz is None:
189194
raise NotImplementedError(f'{type(value)} is not supported for device_put.')
@@ -242,6 +247,10 @@ def device_get_into(result, value):
242247
Raises:
243248
RuntimeError: if result length doesn't match data batch size
244249
"""
250+
warnings.warn(
251+
'device_get_into is deprecated, use get_data instead',
252+
category=DeprecationWarning,
253+
)
245254

246255
value = jax.device_get(value)
247256

mjx/mujoco/mjx/_src/device_test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -130,31 +130,31 @@ def test_cone(self):
130130
mjx.device_put(m)
131131

132132
def test_trn(self):
133-
m = test_util.load_test_file('ant.xml')
133+
m = test_util.load_test_file('pendula.xml')
134134
m.actuator_trntype[0] = mujoco.mjtTrn.mjTRN_SITE
135135
with self.assertRaises(NotImplementedError):
136136
mjx.device_put(m)
137137

138138
def test_dyn(self):
139-
m = test_util.load_test_file('ant.xml')
139+
m = test_util.load_test_file('pendula.xml')
140140
m.actuator_dyntype[0] = mujoco.mjtDyn.mjDYN_MUSCLE
141141
with self.assertRaises(NotImplementedError):
142142
mjx.device_put(m)
143143

144144
def test_gain(self):
145-
m = test_util.load_test_file('ant.xml')
145+
m = test_util.load_test_file('pendula.xml')
146146
m.actuator_gaintype[0] = mujoco.mjtGain.mjGAIN_MUSCLE
147147
with self.assertRaises(NotImplementedError):
148148
mjx.device_put(m)
149149

150150
def test_bias(self):
151-
m = test_util.load_test_file('ant.xml')
151+
m = test_util.load_test_file('pendula.xml')
152152
m.actuator_gaintype[0] = mujoco.mjtGain.mjGAIN_MUSCLE
153153
with self.assertRaises(NotImplementedError):
154154
mjx.device_put(m)
155155

156156
def test_condim(self):
157-
m = test_util.load_test_file('ant.xml')
157+
m = test_util.load_test_file('constraints.xml')
158158
for i in [1, 4, 6]:
159159
m.geom_condim[0] = i
160160
with self.assertRaises(NotImplementedError):

0 commit comments

Comments
 (0)