Skip to content

Commit 8820be8

Browse files
committed
Add plot vector tests
1 parent 35ccf82 commit 8820be8

File tree

1 file changed

+258
-0
lines changed

1 file changed

+258
-0
lines changed

tests/test_unit/test_plot.py

Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
import numpy as np
2+
import pytest
3+
import xarray as xr
4+
5+
from movement.plot import vector
6+
7+
8+
@pytest.fixture
9+
def sample_data():
10+
"""Sample data for plot testing.
11+
12+
Data has three keypoints (left, centre, right) for one
13+
individual that moves in a straight line along the y-axis with a
14+
constant x-coordinate.
15+
16+
"""
17+
time_steps = 4
18+
individuals = ["individual_0"]
19+
keypoints = ["left", "centre", "right"]
20+
space = ["x", "y"]
21+
positions = {
22+
"left": {"x": -1, "y": np.arange(time_steps)},
23+
"centre": {"x": 0, "y": np.arange(time_steps)},
24+
"right": {"x": 1, "y": np.arange(time_steps)},
25+
}
26+
27+
time = np.arange(time_steps)
28+
position_data = np.zeros(
29+
(time_steps, len(space), len(keypoints), len(individuals))
30+
)
31+
32+
# Create x and y coordinates arrays
33+
x_coords = np.array([positions[key]["x"] for key in keypoints])
34+
y_coords = np.array([positions[key]["y"] for key in keypoints])
35+
36+
for i, _ in enumerate(keypoints):
37+
position_data[:, 0, i, 0] = x_coords[i] # x-coordinates
38+
position_data[:, 1, i, 0] = y_coords[i] # y-coordinates
39+
40+
confidence_data = np.full(
41+
(time_steps, len(keypoints), len(individuals)), 0.90
42+
)
43+
44+
ds = xr.Dataset(
45+
{
46+
"position": (
47+
["time", "space", "keypoints", "individuals"],
48+
position_data,
49+
),
50+
"confidence": (
51+
["time", "keypoints", "individuals"],
52+
confidence_data,
53+
),
54+
},
55+
coords={
56+
"time": time,
57+
"space": space,
58+
"keypoints": keypoints,
59+
"individuals": individuals,
60+
},
61+
)
62+
return ds
63+
64+
65+
@pytest.fixture
66+
def sample_data_quiver1():
67+
"""Sample data for plot testing.
68+
69+
Data has three keypoints (left, centre, right) for one
70+
individual that moves in a straight line along the y-axis with a
71+
constant x-coordinate.
72+
73+
"""
74+
time_steps = 4
75+
individuals = ["individual_0"]
76+
keypoints = ["left", "centre", "right"]
77+
space = ["x", "y"]
78+
positions = {
79+
"left": {"x": -1, "y": np.arange(time_steps)},
80+
"centre": {"x": 0, "y": np.arange(time_steps) + 1},
81+
"right": {"x": 1, "y": np.arange(time_steps)},
82+
}
83+
84+
time = np.arange(time_steps)
85+
position_data = np.zeros(
86+
(time_steps, len(space), len(keypoints), len(individuals))
87+
)
88+
89+
# Create x and y coordinates arrays
90+
x_coords = np.array([positions[key]["x"] for key in keypoints])
91+
y_coords = np.array([positions[key]["y"] for key in keypoints])
92+
93+
for i, _ in enumerate(keypoints):
94+
position_data[:, 0, i, 0] = x_coords[i] # x-coordinates
95+
position_data[:, 1, i, 0] = y_coords[i] # y-coordinates
96+
97+
confidence_data = np.full(
98+
(time_steps, len(keypoints), len(individuals)), 0.90
99+
)
100+
101+
ds = xr.Dataset(
102+
{
103+
"position": (
104+
["time", "space", "keypoints", "individuals"],
105+
position_data,
106+
),
107+
"confidence": (
108+
["time", "keypoints", "individuals"],
109+
confidence_data,
110+
),
111+
},
112+
coords={
113+
"time": time,
114+
"space": space,
115+
"keypoints": keypoints,
116+
"individuals": individuals,
117+
},
118+
)
119+
return ds
120+
121+
122+
@pytest.fixture
123+
def sample_data_quiver2():
124+
"""Sample data for plot testing.
125+
126+
Data has three keypoints (left, centre, right) for one
127+
individual that moves in a straight line along the y-axis with a
128+
constant x-coordinate.
129+
130+
"""
131+
time_steps = 4
132+
individuals = ["individual_0"]
133+
keypoints = ["left1", "right1", "left2", "right2"]
134+
space = ["x", "y"]
135+
positions = {
136+
"left1": {"x": -1, "y": np.arange(time_steps)},
137+
"right1": {"x": 1, "y": np.arange(time_steps)},
138+
"left2": {"x": -1, "y": np.arange(time_steps)},
139+
"right2": {"x": 1, "y": np.arange(time_steps)},
140+
}
141+
142+
time = np.arange(time_steps)
143+
position_data = np.zeros(
144+
(time_steps, len(space), len(keypoints), len(individuals))
145+
)
146+
147+
# Create x and y coordinates arrays
148+
x_coords = np.array([positions[key]["x"] for key in keypoints])
149+
y_coords = np.array([positions[key]["y"] for key in keypoints])
150+
151+
for i, _ in enumerate(keypoints):
152+
position_data[:, 0, i, 0] = x_coords[i] # x-coordinates
153+
position_data[:, 1, i, 0] = y_coords[i] # y-coordinates
154+
155+
confidence_data = np.full(
156+
(time_steps, len(keypoints), len(individuals)), 0.90
157+
)
158+
159+
ds = xr.Dataset(
160+
{
161+
"position": (
162+
["time", "space", "keypoints", "individuals"],
163+
position_data,
164+
),
165+
"confidence": (
166+
["time", "keypoints", "individuals"],
167+
confidence_data,
168+
),
169+
},
170+
coords={
171+
"time": time,
172+
"space": space,
173+
"keypoints": keypoints,
174+
"individuals": individuals,
175+
},
176+
)
177+
return ds
178+
179+
180+
def test_vector_no_quiver(sample_data):
181+
"""Test midpoint between left and right keypoints."""
182+
vector_fig = vector(
183+
sample_data,
184+
reference_points=["left", "right"],
185+
vector_point="centre",
186+
)
187+
188+
quiver = vector_fig.axes[0].collections[-1]
189+
190+
# Extract the X, Y, U, V data
191+
x = quiver.X
192+
y = quiver.Y
193+
u = quiver.U
194+
v = quiver.V
195+
196+
expected_x = np.array([0.0, 0.0, 0.0, 0.0])
197+
expected_y = np.array([0.0, 1.0, 2.0, 3.0])
198+
expected_u = np.array([0.0, 0.0, 0.0, 0.0])
199+
expected_v = np.array([0.0, 0.0, 0.0, 0.0])
200+
201+
assert np.allclose(x, expected_x)
202+
assert np.allclose(y, expected_y)
203+
assert np.allclose(u, expected_u)
204+
assert np.allclose(v, expected_v)
205+
206+
207+
def test_vector_quiver2(sample_data_quiver2):
208+
"""Test midpoint between left and right keypoints."""
209+
vector_fig = vector(
210+
sample_data_quiver2,
211+
reference_points=["left1", "right1"],
212+
vector_point="right2",
213+
)
214+
215+
quiver = vector_fig.axes[0].collections[-1]
216+
217+
# Extract the X, Y, U, V data
218+
x = quiver.X
219+
y = quiver.Y
220+
u = quiver.U
221+
v = quiver.V
222+
223+
expected_x = np.array([0.0, 0.0, 0.0, 0.0])
224+
expected_y = np.array([0.0, 1.0, 2.0, 3.0])
225+
expected_u = np.array([1.0, 1.0, 1.0, 1.0])
226+
expected_v = np.array([0.0, 0.0, 0.0, 0.0])
227+
228+
assert np.allclose(x, expected_x)
229+
assert np.allclose(y, expected_y)
230+
assert np.allclose(u, expected_u)
231+
assert np.allclose(v, expected_v)
232+
233+
234+
def test_vector_quiver1(sample_data_quiver1):
235+
"""Test midpoint between left and right keypoints."""
236+
vector_fig = vector(
237+
sample_data_quiver1,
238+
reference_points=["left", "right"],
239+
vector_point="centre",
240+
)
241+
242+
quiver = vector_fig.axes[0].collections[-1]
243+
244+
# Extract the X, Y, U, V data
245+
x = quiver.X
246+
y = quiver.Y
247+
u = quiver.U
248+
v = quiver.V
249+
250+
expected_x = np.array([0.0, 0.0, 0.0, 0.0])
251+
expected_y = np.array([0.0, 1.0, 2.0, 3.0])
252+
expected_u = np.array([0.0, 0.0, 0.0, 0.0])
253+
expected_v = np.array([1.0, 1.0, 1.0, 1.0])
254+
255+
assert np.allclose(x, expected_x)
256+
assert np.allclose(y, expected_y)
257+
assert np.allclose(u, expected_u)
258+
assert np.allclose(v, expected_v)

0 commit comments

Comments
 (0)