|
| 1 | +"""Unit tests for the plot module.""" |
| 2 | + |
1 | 3 | import numpy as np
|
2 | 4 | import pytest
|
3 | 5 | import xarray as xr
|
4 | 6 |
|
5 | 7 | from movement.plot import vector
|
6 | 8 |
|
7 | 9 |
|
8 |
| -@pytest.fixture |
9 |
| -def sample_data(): |
10 |
| - """Sample data for plot testing. |
11 |
| -
|
12 |
| - Data has three keypoints (left, centre, right) for one |
13 |
| - individual that moves in a straight line along the y-axis with a |
14 |
| - constant x-coordinate. |
15 |
| -
|
16 |
| - """ |
| 10 | +def create_sample_data(keypoints, positions): |
| 11 | + """Create sample data for testing.""" |
17 | 12 | time_steps = 4
|
18 | 13 | individuals = ["individual_0"]
|
19 |
| - keypoints = ["left", "centre", "right"] |
20 | 14 | space = ["x", "y"]
|
21 |
| - positions = { |
22 |
| - "left": {"x": -1, "y": np.arange(time_steps)}, |
23 |
| - "centre": {"x": 0, "y": np.arange(time_steps)}, |
24 |
| - "right": {"x": 1, "y": np.arange(time_steps)}, |
25 |
| - } |
26 | 15 |
|
27 | 16 | time = np.arange(time_steps)
|
28 | 17 | position_data = np.zeros(
|
29 | 18 | (time_steps, len(space), len(keypoints), len(individuals))
|
30 | 19 | )
|
31 | 20 |
|
32 |
| - # Create x and y coordinates arrays |
33 | 21 | x_coords = np.array([positions[key]["x"] for key in keypoints])
|
34 | 22 | y_coords = np.array([positions[key]["y"] for key in keypoints])
|
35 | 23 |
|
@@ -63,194 +51,86 @@ def sample_data():
|
63 | 51 |
|
64 | 52 |
|
65 | 53 | @pytest.fixture
|
66 |
| -def sample_data_quiver1(): |
| 54 | +def sample_data(): |
67 | 55 | """Sample data for plot testing.
|
68 | 56 |
|
69 |
| - Data has three keypoints (left, centre, right) for one |
70 |
| - individual that moves in a straight line along the y-axis with a |
71 |
| - constant x-coordinate. |
| 57 | + Data has six keypoints for one individual that moves in a straight line |
| 58 | + along the y-axis. All keypoints have a constant x-coordinate and move in |
| 59 | + steps of 1 along the y-axis. |
72 | 60 |
|
73 |
| - """ |
74 |
| - time_steps = 4 |
75 |
| - individuals = ["individual_0"] |
76 |
| - keypoints = ["left", "centre", "right"] |
77 |
| - space = ["x", "y"] |
78 |
| - positions = { |
79 |
| - "left": {"x": -1, "y": np.arange(time_steps)}, |
80 |
| - "centre": {"x": 0, "y": np.arange(time_steps) + 1}, |
81 |
| - "right": {"x": 1, "y": np.arange(time_steps)}, |
82 |
| - } |
83 |
| - |
84 |
| - time = np.arange(time_steps) |
85 |
| - position_data = np.zeros( |
86 |
| - (time_steps, len(space), len(keypoints), len(individuals)) |
87 |
| - ) |
88 |
| - |
89 |
| - # Create x and y coordinates arrays |
90 |
| - x_coords = np.array([positions[key]["x"] for key in keypoints]) |
91 |
| - y_coords = np.array([positions[key]["y"] for key in keypoints]) |
92 |
| - |
93 |
| - for i, _ in enumerate(keypoints): |
94 |
| - position_data[:, 0, i, 0] = x_coords[i] # x-coordinates |
95 |
| - position_data[:, 1, i, 0] = y_coords[i] # y-coordinates |
96 |
| - |
97 |
| - confidence_data = np.full( |
98 |
| - (time_steps, len(keypoints), len(individuals)), 0.90 |
99 |
| - ) |
100 |
| - |
101 |
| - ds = xr.Dataset( |
102 |
| - { |
103 |
| - "position": ( |
104 |
| - ["time", "space", "keypoints", "individuals"], |
105 |
| - position_data, |
106 |
| - ), |
107 |
| - "confidence": ( |
108 |
| - ["time", "keypoints", "individuals"], |
109 |
| - confidence_data, |
110 |
| - ), |
111 |
| - }, |
112 |
| - coords={ |
113 |
| - "time": time, |
114 |
| - "space": space, |
115 |
| - "keypoints": keypoints, |
116 |
| - "individuals": individuals, |
117 |
| - }, |
118 |
| - ) |
119 |
| - return ds |
120 |
| - |
121 |
| - |
122 |
| -@pytest.fixture |
123 |
| -def sample_data_quiver2(): |
124 |
| - """Sample data for plot testing. |
125 |
| -
|
126 |
| - Data has three keypoints (left, centre, right) for one |
127 |
| - individual that moves in a straight line along the y-axis with a |
128 |
| - constant x-coordinate. |
| 61 | + Keypoint starting positions: |
| 62 | + - left1: (-1, 0) |
| 63 | + - right1: (1, 0) |
| 64 | + - left2: (-2, 0) |
| 65 | + - right2: (2, 0) |
| 66 | + - centre0: (0, 0) |
| 67 | + - centre1: (0, 1) |
129 | 68 |
|
130 | 69 | """
|
131 |
| - time_steps = 4 |
132 |
| - individuals = ["individual_0"] |
133 |
| - keypoints = ["left1", "right1", "left2", "right2"] |
134 |
| - space = ["x", "y"] |
| 70 | + keypoints = ["left1", "right1", "left2", "right2", "centre0", "centre1"] |
135 | 71 | positions = {
|
136 |
| - "left1": {"x": -1, "y": np.arange(time_steps)}, |
137 |
| - "right1": {"x": 1, "y": np.arange(time_steps)}, |
138 |
| - "left2": {"x": -1, "y": np.arange(time_steps)}, |
139 |
| - "right2": {"x": 1, "y": np.arange(time_steps)}, |
| 72 | + "left1": {"x": -1, "y": np.arange(4)}, |
| 73 | + "right1": {"x": 1, "y": np.arange(4)}, |
| 74 | + "left2": {"x": -2, "y": np.arange(4)}, |
| 75 | + "right2": {"x": 2, "y": np.arange(4)}, |
| 76 | + "centre0": {"x": 0, "y": np.arange(4)}, |
| 77 | + "centre1": {"x": 0, "y": np.arange(4) + 1}, |
140 | 78 | }
|
141 |
| - |
142 |
| - time = np.arange(time_steps) |
143 |
| - position_data = np.zeros( |
144 |
| - (time_steps, len(space), len(keypoints), len(individuals)) |
145 |
| - ) |
146 |
| - |
147 |
| - # Create x and y coordinates arrays |
148 |
| - x_coords = np.array([positions[key]["x"] for key in keypoints]) |
149 |
| - y_coords = np.array([positions[key]["y"] for key in keypoints]) |
150 |
| - |
151 |
| - for i, _ in enumerate(keypoints): |
152 |
| - position_data[:, 0, i, 0] = x_coords[i] # x-coordinates |
153 |
| - position_data[:, 1, i, 0] = y_coords[i] # y-coordinates |
154 |
| - |
155 |
| - confidence_data = np.full( |
156 |
| - (time_steps, len(keypoints), len(individuals)), 0.90 |
157 |
| - ) |
158 |
| - |
159 |
| - ds = xr.Dataset( |
160 |
| - { |
161 |
| - "position": ( |
162 |
| - ["time", "space", "keypoints", "individuals"], |
163 |
| - position_data, |
164 |
| - ), |
165 |
| - "confidence": ( |
166 |
| - ["time", "keypoints", "individuals"], |
167 |
| - confidence_data, |
168 |
| - ), |
169 |
| - }, |
170 |
| - coords={ |
171 |
| - "time": time, |
172 |
| - "space": space, |
173 |
| - "keypoints": keypoints, |
174 |
| - "individuals": individuals, |
175 |
| - }, |
176 |
| - ) |
177 |
| - return ds |
178 |
| - |
179 |
| - |
180 |
| -def test_vector_no_quiver(sample_data): |
181 |
| - """Test midpoint between left and right keypoints.""" |
| 79 | + return create_sample_data(keypoints, positions) |
| 80 | + |
| 81 | + |
| 82 | +@pytest.mark.parametrize( |
| 83 | + ["vector_point", "expected_u", "expected_v"], |
| 84 | + [ |
| 85 | + pytest.param( |
| 86 | + "centre0", |
| 87 | + [0.0, 0.0, 0.0, 0.0], |
| 88 | + [0.0, 0.0, 0.0, 0.0], |
| 89 | + id="u = 0, v = 0", |
| 90 | + ), |
| 91 | + pytest.param( |
| 92 | + "centre1", |
| 93 | + [0.0, 0.0, 0.0, 0.0], |
| 94 | + [1.0, 1.0, 1.0, 1.0], |
| 95 | + id="u = 0, v = 1", |
| 96 | + ), |
| 97 | + pytest.param( |
| 98 | + "right2", |
| 99 | + [2.0, 2.0, 2.0, 2.0], |
| 100 | + [0.0, 0.0, 0.0, 0.0], |
| 101 | + id="u = 2, v = 0", |
| 102 | + ), |
| 103 | + pytest.param( |
| 104 | + "left2", |
| 105 | + [-2.0, -2.0, -2.0, -2.0], |
| 106 | + [0.0, 0.0, 0.0, 0.0], |
| 107 | + id="u = -2, v = 0", |
| 108 | + ), |
| 109 | + ], |
| 110 | +) |
| 111 | +def test_vector(sample_data, vector_point, expected_u, expected_v): |
| 112 | + """Test vector plot. |
| 113 | +
|
| 114 | + Test the vector plot for different vector points. The U and V values |
| 115 | + represent the horizontal (x) and vertical (y) displacement of the vector |
| 116 | +
|
| 117 | + The reference points are "left1" and "right1". |
| 118 | + """ |
182 | 119 | vector_fig = vector(
|
183 | 120 | sample_data,
|
184 |
| - reference_points=["left", "right"], |
185 |
| - vector_point="centre", |
186 |
| - ) |
187 |
| - |
188 |
| - quiver = vector_fig.axes[0].collections[-1] |
189 |
| - |
190 |
| - # Extract the X, Y, U, V data |
191 |
| - x = quiver.X |
192 |
| - y = quiver.Y |
193 |
| - u = quiver.U |
194 |
| - v = quiver.V |
195 |
| - |
196 |
| - expected_x = np.array([0.0, 0.0, 0.0, 0.0]) |
197 |
| - expected_y = np.array([0.0, 1.0, 2.0, 3.0]) |
198 |
| - expected_u = np.array([0.0, 0.0, 0.0, 0.0]) |
199 |
| - expected_v = np.array([0.0, 0.0, 0.0, 0.0]) |
200 |
| - |
201 |
| - assert np.allclose(x, expected_x) |
202 |
| - assert np.allclose(y, expected_y) |
203 |
| - assert np.allclose(u, expected_u) |
204 |
| - assert np.allclose(v, expected_v) |
205 |
| - |
206 |
| - |
207 |
| -def test_vector_quiver2(sample_data_quiver2): |
208 |
| - """Test midpoint between left and right keypoints.""" |
209 |
| - vector_fig = vector( |
210 |
| - sample_data_quiver2, |
211 | 121 | reference_points=["left1", "right1"],
|
212 |
| - vector_point="right2", |
213 |
| - ) |
214 |
| - |
215 |
| - quiver = vector_fig.axes[0].collections[-1] |
216 |
| - |
217 |
| - # Extract the X, Y, U, V data |
218 |
| - x = quiver.X |
219 |
| - y = quiver.Y |
220 |
| - u = quiver.U |
221 |
| - v = quiver.V |
222 |
| - |
223 |
| - expected_x = np.array([0.0, 0.0, 0.0, 0.0]) |
224 |
| - expected_y = np.array([0.0, 1.0, 2.0, 3.0]) |
225 |
| - expected_u = np.array([1.0, 1.0, 1.0, 1.0]) |
226 |
| - expected_v = np.array([0.0, 0.0, 0.0, 0.0]) |
227 |
| - |
228 |
| - assert np.allclose(x, expected_x) |
229 |
| - assert np.allclose(y, expected_y) |
230 |
| - assert np.allclose(u, expected_u) |
231 |
| - assert np.allclose(v, expected_v) |
232 |
| - |
233 |
| - |
234 |
| -def test_vector_quiver1(sample_data_quiver1): |
235 |
| - """Test midpoint between left and right keypoints.""" |
236 |
| - vector_fig = vector( |
237 |
| - sample_data_quiver1, |
238 |
| - reference_points=["left", "right"], |
239 |
| - vector_point="centre", |
| 122 | + vector_point=vector_point, |
240 | 123 | )
|
241 | 124 |
|
242 | 125 | quiver = vector_fig.axes[0].collections[-1]
|
243 | 126 |
|
244 |
| - # Extract the X, Y, U, V data |
245 | 127 | x = quiver.X
|
246 | 128 | y = quiver.Y
|
247 | 129 | u = quiver.U
|
248 | 130 | v = quiver.V
|
249 | 131 |
|
250 | 132 | expected_x = np.array([0.0, 0.0, 0.0, 0.0])
|
251 | 133 | expected_y = np.array([0.0, 1.0, 2.0, 3.0])
|
252 |
| - expected_u = np.array([0.0, 0.0, 0.0, 0.0]) |
253 |
| - expected_v = np.array([1.0, 1.0, 1.0, 1.0]) |
254 | 134 |
|
255 | 135 | assert np.allclose(x, expected_x)
|
256 | 136 | assert np.allclose(y, expected_y)
|
|
0 commit comments