|
15 | 15 | """Tests for constraint functions."""
|
16 | 16 |
|
17 | 17 | from absl.testing import absltest
|
18 |
| -from absl.testing import parameterized |
19 |
| -import jax |
20 | 18 | from jax import numpy as jp
|
21 | 19 | import mujoco
|
22 | 20 | from mujoco import mjx
|
23 | 21 | from mujoco.mjx._src import constraint
|
24 | 22 | from mujoco.mjx._src import test_util
|
25 |
| -# pylint: disable=g-importing-member |
26 |
| -from mujoco.mjx._src.types import DisableBit |
27 |
| -from mujoco.mjx._src.types import SolverType |
28 |
| -# pylint: enable=g-importing-member |
29 | 23 | import numpy as np
|
30 | 24 |
|
31 | 25 |
|
32 |
| -def _assert_eq(a, b, name, step, fname, atol=5e-3, rtol=5e-3): |
33 |
| - err_msg = f'mismatch: {name} at step {step} in {fname}' |
34 |
| - np.testing.assert_allclose(a, b, err_msg=err_msg, atol=atol, rtol=rtol) |
| 26 | +# tolerance for difference between MuJoCo and MJX constraint calculations, |
| 27 | +# mostly due to float precision |
| 28 | +_TOLERANCE = 5e-5 |
35 | 29 |
|
36 | 30 |
|
37 |
| -class ConstraintTest(parameterized.TestCase): |
| 31 | +def _assert_eq(a, b, name): |
| 32 | + tol = _TOLERANCE * 10 # avoid test noise |
| 33 | + err_msg = f'mismatch: {name}' |
| 34 | + np.testing.assert_allclose(a, b, err_msg=err_msg, atol=tol, rtol=tol) |
38 | 35 |
|
39 |
| - @parameterized.parameters(enumerate(test_util.TEST_FILES)) |
40 |
| - def test_constraints(self, seed, fname): |
41 |
| - """Test constraints.""" |
42 |
| - np.random.seed(seed) |
43 | 36 |
|
44 |
| - # exclude convex.xml since convex contacts are not exactly equivalent |
45 |
| - if fname == 'convex.xml': |
46 |
| - return |
| 37 | +def _assert_attr_eq(a, b, attr): |
| 38 | + _assert_eq(getattr(a, attr), getattr(b, attr), attr) |
47 | 39 |
|
48 |
| - m = test_util.load_test_file(fname) |
49 |
| - d = mujoco.MjData(m) |
50 |
| - mx = mjx.device_put(m) |
51 |
| - dx = mjx.make_data(mx) |
52 |
| - |
53 |
| - forward_jit_fn = jax.jit(mjx.forward) |
54 |
| - |
55 |
| - # give the system a little kick to ensure we have non-identity rotations |
56 |
| - d.qvel = np.random.random(m.nv) |
57 |
| - for i in range(100): |
58 |
| - dx = dx.replace(qpos=jax.device_put(d.qpos), qvel=jax.device_put(d.qvel)) |
59 |
| - mujoco.mj_step(m, d) |
60 |
| - dx = forward_jit_fn(mx, dx) |
61 |
| - |
62 |
| - nnz_filter = dx.efc_J.any(axis=1) |
63 |
| - |
64 |
| - mj_efc_j = d.efc_J.reshape((-1, m.nv)) |
65 |
| - mjx_efc_j = dx.efc_J[nnz_filter] |
66 |
| - _assert_eq(mj_efc_j, mjx_efc_j, 'efc_J', i, fname) |
67 |
| - |
68 |
| - mjx_efc_d = dx.efc_D[nnz_filter] |
69 |
| - _assert_eq(d.efc_D, mjx_efc_d, 'efc_D', i, fname) |
70 |
| - |
71 |
| - mjx_efc_aref = dx.efc_aref[nnz_filter] |
72 |
| - _assert_eq(d.efc_aref, mjx_efc_aref, 'efc_aref', i, fname) |
73 |
| - |
74 |
| - mjx_efc_frictionloss = dx.efc_frictionloss[nnz_filter] |
75 |
| - _assert_eq( |
76 |
| - d.efc_frictionloss, |
77 |
| - mjx_efc_frictionloss, |
78 |
| - 'efc_frictionloss', |
79 |
| - i, |
80 |
| - fname, |
81 |
| - ) |
82 |
| - |
83 |
| - _JNT_RANGE = """ |
84 |
| - <mujoco> |
85 |
| - <worldbody> |
86 |
| - <body pos="0 0 1"> |
87 |
| - <joint type="slide" axis="1 0 0" range="-1.8 1.8" solreflimit=".08 1" |
88 |
| - damping="5e-4"/> |
89 |
| - <geom type="box" size="0.2 0.15 0.1" mass="1"/> |
90 |
| - <body> |
91 |
| - <joint axis="0 1 0" damping="2e-6"/> |
92 |
| - <geom type="capsule" fromto="0 0 0 0 0 1" size="0.045" mass=".1"/> |
93 |
| - </body> |
94 |
| - </body> |
95 |
| - </worldbody> |
96 |
| - </mujoco> |
97 |
| - """ |
98 |
| - |
99 |
| - def test_jnt_range(self): |
100 |
| - """Tests that mixed joint ranges are respected.""" |
101 |
| - # TODO(robotics-simulation): also test ball |
102 |
| - m = mujoco.MjModel.from_xml_string(self._JNT_RANGE) |
103 |
| - m.opt.solver = SolverType.CG.value |
104 |
| - d = mujoco.MjData(m) |
105 |
| - d.qpos = np.array([2.0, 15.0]) |
106 | 40 |
|
107 |
| - mx = mjx.device_put(m) |
108 |
| - dx = mjx.device_put(d) |
109 |
| - efc = jax.jit(constraint._instantiate_limit_slide_hinge)(mx, dx) |
| 41 | +class ConstraintTest(absltest.TestCase): |
110 | 42 |
|
111 |
| - # first joint is outside the joint range |
112 |
| - np.testing.assert_array_almost_equal(efc.J[0, 0], -1.0) |
| 43 | + def test_constraints(self): |
| 44 | + """Test constraints.""" |
| 45 | + m = test_util.load_test_file('constraints.xml') |
| 46 | + d = mujoco.MjData(m) |
| 47 | + mujoco.mj_step(m, d, 100) # at 100 steps mix of active/inactive constraints |
| 48 | + mujoco.mj_forward(m, d) |
| 49 | + mx = mjx.put_model(m) |
| 50 | + dx = mjx.put_data(m, d) |
113 | 51 |
|
114 |
| - # second joint has no range, so only one efc row |
115 |
| - self.assertEqual(efc.J.shape[0], 1) |
| 52 | + dx = mjx.make_constraint(mx, dx) |
| 53 | + nnz = dx.efc_J.any(axis=1) |
| 54 | + _assert_eq(d.efc_J, dx.efc_J[nnz].reshape(-1), 'efc_J') |
| 55 | + _assert_eq(d.efc_D, dx.efc_D[nnz], 'efc_D') |
| 56 | + _assert_eq(d.efc_aref, dx.efc_aref[nnz], 'efc_aref') |
| 57 | + _assert_eq(d.efc_frictionloss, dx.efc_frictionloss[nnz], 'efc_frictionloss') |
116 | 58 |
|
117 | 59 | def test_disable_refsafe(self):
|
118 |
| - m = test_util.load_test_file('ant.xml') |
| 60 | + m = test_util.load_test_file('constraints.xml') |
119 | 61 |
|
120 | 62 | timeconst = m.opt.timestep / 4.0 # timeconst < 2 * timestep
|
121 | 63 | solimp = jp.array([timeconst, 1.0])
|
122 | 64 | solref = jp.array([0.8, 0.99, 0.001, 0.2, 2])
|
123 | 65 | pos = jp.ones(3)
|
124 | 66 |
|
125 |
| - m.opt.disableflags = m.opt.disableflags | DisableBit.REFSAFE |
| 67 | + m.opt.disableflags = m.opt.disableflags | mjx.DisableBit.REFSAFE |
126 | 68 | mx = mjx.device_put(m)
|
127 | 69 | k, *_ = constraint._kbi(mx, solimp, solref, pos)
|
128 | 70 | self.assertEqual(k, 1 / (0.99**2 * timeconst**2))
|
129 | 71 |
|
130 |
| - m.opt.disableflags = m.opt.disableflags & ~DisableBit.REFSAFE |
131 |
| - mx = mjx.device_put(m) |
132 |
| - k, *_ = constraint._kbi(mx, solimp, solref, pos) |
133 |
| - self.assertEqual(k, 1 / (0.99**2 * (2 * m.opt.timestep) ** 2)) |
134 |
| - |
135 |
| - def test_disableconstraint(self): |
136 |
| - m = test_util.load_test_file('ant.xml') |
137 |
| - d = mujoco.MjData(m) |
138 |
| - |
139 |
| - m.opt.disableflags = m.opt.disableflags | DisableBit.CONSTRAINT |
140 |
| - mx, dx = mjx.device_put(m), mjx.device_put(d) |
141 |
| - dx = constraint.make_constraint(mx, dx) |
| 72 | + def test_disable_constraint(self): |
| 73 | + m = test_util.load_test_file('constraints.xml') |
| 74 | + m.opt.disableflags = m.opt.disableflags | mjx.DisableBit.CONSTRAINT |
| 75 | + ne, nf, nl, nc = mjx.count_constraints(m) |
| 76 | + self.assertEqual(ne, 0) |
| 77 | + self.assertEqual(nf, 0) |
| 78 | + self.assertEqual(nl, 0) |
| 79 | + self.assertEqual(nc, 0) |
| 80 | + dx = constraint.make_constraint(mjx.put_model(m), mjx.make_data(m)) |
142 | 81 | self.assertEqual(dx.efc_J.shape[0], 0)
|
143 | 82 |
|
144 | 83 | def test_disable_equality(self):
|
145 |
| - m = test_util.load_test_file('equality.xml') |
146 |
| - d = mujoco.MjData(m) |
147 |
| - |
148 |
| - m.opt.disableflags = m.opt.disableflags | DisableBit.EQUALITY |
149 |
| - mx, dx = mjx.device_put(m), mjx.device_put(d) |
150 |
| - dx = constraint.make_constraint(mx, dx) |
151 |
| - self.assertEqual(dx.efc_J.shape[0], 0) |
| 84 | + m = test_util.load_test_file('constraints.xml') |
| 85 | + m.opt.disableflags = m.opt.disableflags | mjx.DisableBit.EQUALITY |
| 86 | + ne, nf, nl, nc = mjx.count_constraints(m) |
| 87 | + self.assertEqual(ne, 0) |
| 88 | + self.assertEqual(nf, 0) |
| 89 | + self.assertEqual(nl, 2) |
| 90 | + self.assertEqual(nc, 64) |
| 91 | + dx = constraint.make_constraint(mjx.put_model(m), mjx.make_data(m)) |
| 92 | + self.assertEqual(dx.efc_J.shape[0], 66) # only joint range, contact |
152 | 93 |
|
153 | 94 | def test_disable_contact(self):
|
154 |
| - m = test_util.load_test_file('ant.xml') |
155 |
| - d = mujoco.MjData(m) |
156 |
| - d.qpos[2] = 0.0 |
157 |
| - mujoco.mj_forward(m, d) |
158 |
| - |
159 |
| - m.opt.disableflags = m.opt.disableflags & ~DisableBit.CONTACT |
160 |
| - mx, dx = mjx.device_put(m), mjx.device_put(d) |
161 |
| - dx = dx.tree_replace( |
162 |
| - {'contact.frame': dx.contact.frame.reshape((-1, 3, 3))} |
163 |
| - ) |
164 |
| - efc = constraint._instantiate_contact(mx, dx) |
165 |
| - self.assertIsNotNone(efc) |
166 |
| - |
167 |
| - m.opt.disableflags = m.opt.disableflags | DisableBit.CONTACT |
168 |
| - mx, dx = mjx.device_put(m), mjx.device_put(d) |
169 |
| - efc = constraint._instantiate_contact(mx, dx) |
170 |
| - self.assertIsNone(efc) |
| 95 | + m = test_util.load_test_file('constraints.xml') |
| 96 | + m.opt.disableflags = m.opt.disableflags | mjx.DisableBit.CONTACT |
| 97 | + ne, nf, nl, nc = mjx.count_constraints(m) |
| 98 | + self.assertEqual(ne, 10) |
| 99 | + self.assertEqual(nf, 0) |
| 100 | + self.assertEqual(nl, 2) |
| 101 | + self.assertEqual(nc, 0) |
| 102 | + dx = constraint.make_constraint(mjx.put_model(m), mjx.make_data(m)) |
| 103 | + self.assertEqual(dx.efc_J.shape[0], 12) # only joint range, limit |
171 | 104 |
|
172 | 105 |
|
173 | 106 | if __name__ == '__main__':
|
|
0 commit comments