|
15 | 15 | """Tests for constraint functions.""" |
16 | 16 |
|
17 | 17 | from absl.testing import absltest |
18 | | -from absl.testing import parameterized |
19 | | -import jax |
20 | 18 | from jax import numpy as jp |
21 | 19 | import mujoco |
22 | 20 | from mujoco import mjx |
23 | 21 | from mujoco.mjx._src import constraint |
24 | 22 | from mujoco.mjx._src import test_util |
25 | | -# pylint: disable=g-importing-member |
26 | | -from mujoco.mjx._src.types import DisableBit |
27 | | -from mujoco.mjx._src.types import SolverType |
28 | | -# pylint: enable=g-importing-member |
29 | 23 | import numpy as np |
30 | 24 |
|
31 | 25 |
|
32 | | -def _assert_eq(a, b, name, step, fname, atol=5e-3, rtol=5e-3): |
33 | | - err_msg = f'mismatch: {name} at step {step} in {fname}' |
34 | | - np.testing.assert_allclose(a, b, err_msg=err_msg, atol=atol, rtol=rtol) |
| 26 | +# tolerance for difference between MuJoCo and MJX constraint calculations, |
| 27 | +# mostly due to float precision |
| 28 | +_TOLERANCE = 5e-5 |
35 | 29 |
|
36 | 30 |
|
37 | | -class ConstraintTest(parameterized.TestCase): |
| 31 | +def _assert_eq(a, b, name): |
| 32 | + tol = _TOLERANCE * 10 # avoid test noise |
| 33 | + err_msg = f'mismatch: {name}' |
| 34 | + np.testing.assert_allclose(a, b, err_msg=err_msg, atol=tol, rtol=tol) |
38 | 35 |
|
39 | | - @parameterized.parameters(enumerate(test_util.TEST_FILES)) |
40 | | - def test_constraints(self, seed, fname): |
41 | | - """Test constraints.""" |
42 | | - np.random.seed(seed) |
43 | 36 |
|
44 | | - # exclude convex.xml since convex contacts are not exactly equivalent |
45 | | - if fname == 'convex.xml': |
46 | | - return |
| 37 | +def _assert_attr_eq(a, b, attr): |
| 38 | + _assert_eq(getattr(a, attr), getattr(b, attr), attr) |
47 | 39 |
|
48 | | - m = test_util.load_test_file(fname) |
49 | | - d = mujoco.MjData(m) |
50 | | - mx = mjx.device_put(m) |
51 | | - dx = mjx.make_data(mx) |
52 | | - |
53 | | - forward_jit_fn = jax.jit(mjx.forward) |
54 | | - |
55 | | - # give the system a little kick to ensure we have non-identity rotations |
56 | | - d.qvel = np.random.random(m.nv) |
57 | | - for i in range(100): |
58 | | - dx = dx.replace(qpos=jax.device_put(d.qpos), qvel=jax.device_put(d.qvel)) |
59 | | - mujoco.mj_step(m, d) |
60 | | - dx = forward_jit_fn(mx, dx) |
61 | | - |
62 | | - nnz_filter = dx.efc_J.any(axis=1) |
63 | | - |
64 | | - mj_efc_j = d.efc_J.reshape((-1, m.nv)) |
65 | | - mjx_efc_j = dx.efc_J[nnz_filter] |
66 | | - _assert_eq(mj_efc_j, mjx_efc_j, 'efc_J', i, fname) |
67 | | - |
68 | | - mjx_efc_d = dx.efc_D[nnz_filter] |
69 | | - _assert_eq(d.efc_D, mjx_efc_d, 'efc_D', i, fname) |
70 | | - |
71 | | - mjx_efc_aref = dx.efc_aref[nnz_filter] |
72 | | - _assert_eq(d.efc_aref, mjx_efc_aref, 'efc_aref', i, fname) |
73 | | - |
74 | | - mjx_efc_frictionloss = dx.efc_frictionloss[nnz_filter] |
75 | | - _assert_eq( |
76 | | - d.efc_frictionloss, |
77 | | - mjx_efc_frictionloss, |
78 | | - 'efc_frictionloss', |
79 | | - i, |
80 | | - fname, |
81 | | - ) |
82 | | - |
83 | | - _JNT_RANGE = """ |
84 | | - <mujoco> |
85 | | - <worldbody> |
86 | | - <body pos="0 0 1"> |
87 | | - <joint type="slide" axis="1 0 0" range="-1.8 1.8" solreflimit=".08 1" |
88 | | - damping="5e-4"/> |
89 | | - <geom type="box" size="0.2 0.15 0.1" mass="1"/> |
90 | | - <body> |
91 | | - <joint axis="0 1 0" damping="2e-6"/> |
92 | | - <geom type="capsule" fromto="0 0 0 0 0 1" size="0.045" mass=".1"/> |
93 | | - </body> |
94 | | - </body> |
95 | | - </worldbody> |
96 | | - </mujoco> |
97 | | - """ |
98 | | - |
99 | | - def test_jnt_range(self): |
100 | | - """Tests that mixed joint ranges are respected.""" |
101 | | - # TODO(robotics-simulation): also test ball |
102 | | - m = mujoco.MjModel.from_xml_string(self._JNT_RANGE) |
103 | | - m.opt.solver = SolverType.CG.value |
104 | | - d = mujoco.MjData(m) |
105 | | - d.qpos = np.array([2.0, 15.0]) |
106 | 40 |
|
107 | | - mx = mjx.device_put(m) |
108 | | - dx = mjx.device_put(d) |
109 | | - efc = jax.jit(constraint._instantiate_limit_slide_hinge)(mx, dx) |
| 41 | +class ConstraintTest(absltest.TestCase): |
110 | 42 |
|
111 | | - # first joint is outside the joint range |
112 | | - np.testing.assert_array_almost_equal(efc.J[0, 0], -1.0) |
| 43 | + def test_constraints(self): |
| 44 | + """Test constraints.""" |
| 45 | + m = test_util.load_test_file('constraints.xml') |
| 46 | + d = mujoco.MjData(m) |
| 47 | + mujoco.mj_step(m, d, 100) # at 100 steps mix of active/inactive constraints |
| 48 | + mujoco.mj_forward(m, d) |
| 49 | + mx = mjx.put_model(m) |
| 50 | + dx = mjx.put_data(m, d) |
113 | 51 |
|
114 | | - # second joint has no range, so only one efc row |
115 | | - self.assertEqual(efc.J.shape[0], 1) |
| 52 | + dx = mjx.make_constraint(mx, dx) |
| 53 | + nnz = dx.efc_J.any(axis=1) |
| 54 | + _assert_eq(d.efc_J, dx.efc_J[nnz].reshape(-1), 'efc_J') |
| 55 | + _assert_eq(d.efc_D, dx.efc_D[nnz], 'efc_D') |
| 56 | + _assert_eq(d.efc_aref, dx.efc_aref[nnz], 'efc_aref') |
| 57 | + _assert_eq(d.efc_frictionloss, dx.efc_frictionloss[nnz], 'efc_frictionloss') |
116 | 58 |
|
117 | 59 | def test_disable_refsafe(self): |
118 | | - m = test_util.load_test_file('ant.xml') |
| 60 | + m = test_util.load_test_file('constraints.xml') |
119 | 61 |
|
120 | 62 | timeconst = m.opt.timestep / 4.0 # timeconst < 2 * timestep |
121 | 63 | solimp = jp.array([timeconst, 1.0]) |
122 | 64 | solref = jp.array([0.8, 0.99, 0.001, 0.2, 2]) |
123 | 65 | pos = jp.ones(3) |
124 | 66 |
|
125 | | - m.opt.disableflags = m.opt.disableflags | DisableBit.REFSAFE |
| 67 | + m.opt.disableflags = m.opt.disableflags | mjx.DisableBit.REFSAFE |
126 | 68 | mx = mjx.device_put(m) |
127 | 69 | k, *_ = constraint._kbi(mx, solimp, solref, pos) |
128 | 70 | self.assertEqual(k, 1 / (0.99**2 * timeconst**2)) |
129 | 71 |
|
130 | | - m.opt.disableflags = m.opt.disableflags & ~DisableBit.REFSAFE |
131 | | - mx = mjx.device_put(m) |
132 | | - k, *_ = constraint._kbi(mx, solimp, solref, pos) |
133 | | - self.assertEqual(k, 1 / (0.99**2 * (2 * m.opt.timestep) ** 2)) |
134 | | - |
135 | | - def test_disableconstraint(self): |
136 | | - m = test_util.load_test_file('ant.xml') |
137 | | - d = mujoco.MjData(m) |
138 | | - |
139 | | - m.opt.disableflags = m.opt.disableflags | DisableBit.CONSTRAINT |
140 | | - mx, dx = mjx.device_put(m), mjx.device_put(d) |
141 | | - dx = constraint.make_constraint(mx, dx) |
| 72 | + def test_disable_constraint(self): |
| 73 | + m = test_util.load_test_file('constraints.xml') |
| 74 | + m.opt.disableflags = m.opt.disableflags | mjx.DisableBit.CONSTRAINT |
| 75 | + ne, nf, nl, nc = mjx.count_constraints(m) |
| 76 | + self.assertEqual(ne, 0) |
| 77 | + self.assertEqual(nf, 0) |
| 78 | + self.assertEqual(nl, 0) |
| 79 | + self.assertEqual(nc, 0) |
| 80 | + dx = constraint.make_constraint(mjx.put_model(m), mjx.make_data(m)) |
142 | 81 | self.assertEqual(dx.efc_J.shape[0], 0) |
143 | 82 |
|
144 | 83 | def test_disable_equality(self): |
145 | | - m = test_util.load_test_file('equality.xml') |
146 | | - d = mujoco.MjData(m) |
147 | | - |
148 | | - m.opt.disableflags = m.opt.disableflags | DisableBit.EQUALITY |
149 | | - mx, dx = mjx.device_put(m), mjx.device_put(d) |
150 | | - dx = constraint.make_constraint(mx, dx) |
151 | | - self.assertEqual(dx.efc_J.shape[0], 0) |
| 84 | + m = test_util.load_test_file('constraints.xml') |
| 85 | + m.opt.disableflags = m.opt.disableflags | mjx.DisableBit.EQUALITY |
| 86 | + ne, nf, nl, nc = mjx.count_constraints(m) |
| 87 | + self.assertEqual(ne, 0) |
| 88 | + self.assertEqual(nf, 0) |
| 89 | + self.assertEqual(nl, 2) |
| 90 | + self.assertEqual(nc, 64) |
| 91 | + dx = constraint.make_constraint(mjx.put_model(m), mjx.make_data(m)) |
| 92 | + self.assertEqual(dx.efc_J.shape[0], 66) # only joint range, contact |
152 | 93 |
|
153 | 94 | def test_disable_contact(self): |
154 | | - m = test_util.load_test_file('ant.xml') |
155 | | - d = mujoco.MjData(m) |
156 | | - d.qpos[2] = 0.0 |
157 | | - mujoco.mj_forward(m, d) |
158 | | - |
159 | | - m.opt.disableflags = m.opt.disableflags & ~DisableBit.CONTACT |
160 | | - mx, dx = mjx.device_put(m), mjx.device_put(d) |
161 | | - dx = dx.tree_replace( |
162 | | - {'contact.frame': dx.contact.frame.reshape((-1, 3, 3))} |
163 | | - ) |
164 | | - efc = constraint._instantiate_contact(mx, dx) |
165 | | - self.assertIsNotNone(efc) |
166 | | - |
167 | | - m.opt.disableflags = m.opt.disableflags | DisableBit.CONTACT |
168 | | - mx, dx = mjx.device_put(m), mjx.device_put(d) |
169 | | - efc = constraint._instantiate_contact(mx, dx) |
170 | | - self.assertIsNone(efc) |
| 95 | + m = test_util.load_test_file('constraints.xml') |
| 96 | + m.opt.disableflags = m.opt.disableflags | mjx.DisableBit.CONTACT |
| 97 | + ne, nf, nl, nc = mjx.count_constraints(m) |
| 98 | + self.assertEqual(ne, 10) |
| 99 | + self.assertEqual(nf, 0) |
| 100 | + self.assertEqual(nl, 2) |
| 101 | + self.assertEqual(nc, 0) |
| 102 | + dx = constraint.make_constraint(mjx.put_model(m), mjx.make_data(m)) |
| 103 | + self.assertEqual(dx.efc_J.shape[0], 12) # only joint range, limit |
171 | 104 |
|
172 | 105 |
|
173 | 106 | if __name__ == '__main__': |
|
0 commit comments