Skip to content

Commit 41ce88c

Browse files
committed
TEST: Check that quaternions.fillpositive does not augment unit vectors
1 parent fc9a1c1 commit 41ce88c

File tree

1 file changed

+59
-0
lines changed

1 file changed

+59
-0
lines changed

nibabel/tests/test_quaternions.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,18 @@
1616
from .. import eulerangles as nea
1717
from .. import quaternions as nq
1818

19+
20+
def norm(vec):
21+
# Return unit vector with same orientation as input vector
22+
return vec / np.sqrt(vec @ vec)
23+
24+
25+
def gen_vec(dtype):
26+
# Generate random 3-vector in [-1, 1]^3
27+
rand = np.random.default_rng()
28+
return rand.uniform(low=-1.0, high=1.0, size=(3,)).astype(dtype)
29+
30+
1931
# Example rotations
2032
eg_rots = []
2133
params = (-pi, pi, pi / 2)
@@ -69,6 +81,53 @@ def test_fillpos():
6981
assert wxyz[0] == 0.0
7082

7183

84+
@pytest.mark.parametrize('dtype', ('f4', 'f8'))
85+
def test_fillpositive_plus_minus_epsilon(dtype):
86+
# Deterministic test for fillpositive threshold
87+
# We are trying to fill (x, y, z) with a w such that |(w, x, y, z)| == 1
88+
# If |(x, y, z)| is slightly off one, w should still be 0
89+
nptype = np.dtype(dtype).type
90+
91+
# Obviously, |(x, y, z)| == 1
92+
baseline = np.array([0, 0, 1], dtype=dtype)
93+
94+
# Obviously, |(x, y, z)| ~ 1
95+
plus = baseline * nptype(1 + np.finfo(dtype).eps)
96+
minus = baseline * nptype(1 - np.finfo(dtype).eps)
97+
98+
assert nq.fillpositive(plus)[0] == 0.0
99+
assert nq.fillpositive(minus)[0] == 0.0
100+
101+
102+
@pytest.mark.parametrize('dtype', ('f4', 'f8'))
103+
def test_fillpositive_simulated_error(dtype):
104+
# Nondeterministic test for fillpositive threshold
105+
# Create random vectors, normalize to unit length, and count on floating point
106+
# error to result in magnitudes larger/smaller than one
107+
# This is to simulate cases where a unit quaternion with w == 0 would be encoded
108+
# as xyz with small error, and we want to recover the w of 0
109+
110+
# Permit 1 epsilon per value (default, but make explicit here)
111+
w2_thresh = 3 * -np.finfo(dtype).eps
112+
113+
pos_error = neg_error = False
114+
for _ in range(50):
115+
xyz = norm(gen_vec(dtype))
116+
117+
wxyz = nq.fillpositive(xyz, w2_thresh)
118+
assert wxyz[0] == 0.0
119+
120+
# Verify that we exercise the threshold
121+
magnitude = xyz @ xyz
122+
if magnitude < 1:
123+
pos_error = True
124+
elif magnitude > 1:
125+
neg_error = True
126+
127+
assert pos_error, 'Did not encounter a case where 1 - |xyz| > 0'
128+
assert neg_error, 'Did not encounter a case where 1 - |xyz| < 0'
129+
130+
72131
def test_conjugate():
73132
# Takes sequence
74133
cq = nq.conjugate((1, 0, 0, 0))

0 commit comments

Comments
 (0)