Skip to content

Commit 6b9b676

Browse files
committed
STY: Use norm(), matmul and list comprehensions
1 parent 3f30ab5 commit 6b9b676

File tree

1 file changed

+19
-26
lines changed

1 file changed

+19
-26
lines changed

nibabel/tests/test_quaternions.py

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -29,34 +29,27 @@ def gen_vec(dtype):
2929

3030

3131
# Example rotations
32-
eg_rots = []
33-
params = (-pi, pi, pi / 2)
34-
zs = np.arange(*params)
35-
ys = np.arange(*params)
36-
xs = np.arange(*params)
37-
for z in zs:
38-
for y in ys:
39-
for x in xs:
40-
eg_rots.append(nea.euler2mat(z, y, x))
32+
eg_rots = [
33+
nea.euler2mat(z, y, x)
34+
for z in np.arange(-pi, pi, pi / 2)
35+
for y in np.arange(-pi, pi, pi / 2)
36+
for x in np.arange(-pi, pi, pi / 2)
37+
]
38+
4139
# Example quaternions (from rotations)
42-
eg_quats = []
43-
for M in eg_rots:
44-
eg_quats.append(nq.mat2quat(M))
40+
eg_quats = [nq.mat2quat(M) for M in eg_rots]
4541
# M, quaternion pairs
4642
eg_pairs = list(zip(eg_rots, eg_quats))
4743

4844
# Set of arbitrary unit quaternions
49-
unit_quats = set()
50-
params = range(-2, 3)
51-
for w in params:
52-
for x in params:
53-
for y in params:
54-
for z in params:
55-
q = (w, x, y, z)
56-
Nq = np.sqrt(np.dot(q, q))
57-
if not Nq == 0:
58-
q = tuple([e / Nq for e in q])
59-
unit_quats.add(q)
45+
unit_quats = set(
46+
tuple(norm(np.r_[w, x, y, z]))
47+
for w in range(-2, 3)
48+
for x in range(-2, 3)
49+
for y in range(-2, 3)
50+
for z in range(-2, 3)
51+
if (w, x, y, z) != (0, 0, 0, 0)
52+
)
6053

6154

6255
def test_fillpos():
@@ -184,7 +177,7 @@ def test_norm():
184177
def test_mult(M1, q1, M2, q2):
185178
# Test that quaternion * same as matrix *
186179
q21 = nq.mult(q2, q1)
187-
assert_array_almost_equal, np.dot(M2, M1), nq.quat2mat(q21)
180+
assert_array_almost_equal, M2 @ M1, nq.quat2mat(q21)
188181

189182

190183
@pytest.mark.parametrize('M, q', eg_pairs)
@@ -205,7 +198,7 @@ def test_eye():
205198
@pytest.mark.parametrize('M, q', eg_pairs)
206199
def test_qrotate(vec, M, q):
207200
vdash = nq.rotate_vector(vec, q)
208-
vM = np.dot(M, vec)
201+
vM = M @ vec
209202
assert_array_almost_equal(vdash, vM)
210203

211204

@@ -238,6 +231,6 @@ def test_angle_axis():
238231
nq.nearly_equivalent(q, q2)
239232
aa_mat = nq.angle_axis2mat(theta, vec)
240233
assert_array_almost_equal(aa_mat, M)
241-
unit_vec = vec / np.sqrt(vec.dot(vec))
234+
unit_vec = norm(vec)
242235
aa_mat2 = nq.angle_axis2mat(theta, unit_vec, is_normalized=True)
243236
assert_array_almost_equal(aa_mat2, M)

0 commit comments

Comments
 (0)