Skip to content

Commit

Permalink
Fixed quaternion multiply for batch and support batched quaternion2rpy (
Browse files Browse the repository at this point in the history
#351)

* Fixed quaternion multiply for batch
  • Loading branch information
iory authored Feb 8, 2024
1 parent 1004463 commit 2b72f84
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 16 deletions.
54 changes: 38 additions & 16 deletions skrobot/coordinates/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,15 @@ def _wrap_axis(axis):
return convert_to_axis_vector(axis)


def to_numpy_array(arr):
if isinstance(arr, (list, tuple)):
return np.array(arr)
elif isinstance(arr, np.ndarray):
return arr
else:
raise TypeError("Input must be a list, tuple, or numpy.ndarray.")


def _check_valid_rotation(rotation):
"""Checks that the given rotation matrix is valid."""
rotation = np.array(rotation)
Expand Down Expand Up @@ -898,15 +907,28 @@ def quaternion2rpy(q):
(array([ 0. , -0. , 3.14159265]),
array([3.14159265, 3.14159265, 0. ]))
"""
roll = atan2(
2 * q[2] * q[3] + 2 * q[0] * q[1],
q[3] ** 2 - q[2] ** 2 - q[1] ** 2 + q[0] ** 2)
pitch = -asin(
2 * q[1] * q[3] - 2 * q[0] * q[2])
yaw = atan2(
2 * q[1] * q[2] + 2 * q[0] * q[3],
q[1] ** 2 + q[0] ** 2 - q[3] ** 2 - q[2] ** 2)
rpy = np.array([yaw, pitch, roll])
q = to_numpy_array(q)
if q.ndim == 1:
roll = atan2(
2 * q[2] * q[3] + 2 * q[0] * q[1],
q[3] ** 2 - q[2] ** 2 - q[1] ** 2 + q[0] ** 2)
pitch = -asin(
2 * q[1] * q[3] - 2 * q[0] * q[2])
yaw = atan2(
2 * q[1] * q[2] + 2 * q[0] * q[3],
q[1] ** 2 + q[0] ** 2 - q[3] ** 2 - q[2] ** 2)
rpy = np.array([yaw, pitch, roll])
elif q.ndim == 2:
roll = np.arctan2(
2 * q[:, 2] * q[:, 3] + 2 * q[:, 0] * q[:, 1],
q[:, 3] ** 2 - q[:, 2] ** 2 - q[:, 1] ** 2 + q[:, 0] ** 2)
pitch = -np.sin(
2 * q[:, 1] * q[:, 3] - 2 * q[:, 0] * q[:, 2])
yaw = np.arctan2(
2 * q[:, 1] * q[:, 2] + 2 * q[:, 0] * q[:, 3],
q[:, 1] ** 2 + q[:, 0] ** 2 - q[:, 3] ** 2 - q[:, 2] ** 2)
rpy = np.concatenate([yaw[:, None], pitch[:, None], roll[:, None]],
axis=1)
return rpy, np.pi - rpy


Expand Down Expand Up @@ -1201,7 +1223,8 @@ def quaternion_multiply(quaternion1, quaternion0):
x1 * w0 + y1 * z0 - z1 * y0 + w1 * x0,
-x1 * z0 + y1 * w0 + z1 * x0 + w1 * y0,
x1 * y0 - y1 * x0 + z1 * w0 + w1 * z0), dtype=np.float64)
return np.transpose(np.squeeze(result))
return np.transpose(np.squeeze(result)).reshape(
quaternion1.shape[0], 4)
else:
raise ValueError

Expand Down Expand Up @@ -1351,13 +1374,12 @@ def quaternion_distance(q1, q2, absolute=False):
>>> np.rad2deg(distance)
60.00000021683236
"""
q = quaternion_multiply(
quaternion_inverse(q1), q2)
w = q[0]
q1 = to_numpy_array(q1)
q2 = to_numpy_array(q2)
dot = np.clip(np.sum(q1 * q2, q1.ndim - 1), -1, 1)
diff_theta = 2 * np.arccos(dot)
if absolute is True:
w = abs(q[0])
diff_theta = 2.0 * np.arctan2(
np.linalg.norm(q[1:]), w)
diff_theta = np.abs(diff_theta)
return diff_theta


Expand Down
5 changes: 5 additions & 0 deletions tests/skrobot_tests/coordinates_tests/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,11 @@ def test_quaternion_distance(self):
self.assertEqual(quaternion_distance(np.ones(4), np.ones(4)),
0.0)

# batch
testing.assert_equal(
quaternion_distance(np.ones((10, 4)),
np.ones((10, 4))), np.zeros(10))

def test_quaternion_norm(self):
q = np.array([1, 0, 0, 0])
self.assertEqual(quaternion_norm(q), 1.0)
Expand Down

0 comments on commit 2b72f84

Please sign in to comment.