|
| 1 | +"""Unit tests for the plot module.""" |
| 2 | + |
1 | 3 | import numpy as np |
2 | 4 | import pytest |
3 | 5 | import xarray as xr |
4 | 6 |
|
5 | 7 | from movement.plot import vector |
6 | 8 |
|
7 | 9 |
|
8 | | -@pytest.fixture |
9 | | -def sample_data(): |
10 | | - """Sample data for plot testing. |
11 | | -
|
12 | | - Data has three keypoints (left, centre, right) for one |
13 | | - individual that moves in a straight line along the y-axis with a |
14 | | - constant x-coordinate. |
15 | | -
|
16 | | - """ |
| 10 | +def create_sample_data(keypoints, positions): |
| 11 | + """Create sample data for testing.""" |
17 | 12 | time_steps = 4 |
18 | 13 | individuals = ["individual_0"] |
19 | | - keypoints = ["left", "centre", "right"] |
20 | 14 | space = ["x", "y"] |
21 | | - positions = { |
22 | | - "left": {"x": -1, "y": np.arange(time_steps)}, |
23 | | - "centre": {"x": 0, "y": np.arange(time_steps)}, |
24 | | - "right": {"x": 1, "y": np.arange(time_steps)}, |
25 | | - } |
26 | 15 |
|
27 | 16 | time = np.arange(time_steps) |
28 | 17 | position_data = np.zeros( |
29 | 18 | (time_steps, len(space), len(keypoints), len(individuals)) |
30 | 19 | ) |
31 | 20 |
|
32 | | - # Create x and y coordinates arrays |
33 | 21 | x_coords = np.array([positions[key]["x"] for key in keypoints]) |
34 | 22 | y_coords = np.array([positions[key]["y"] for key in keypoints]) |
35 | 23 |
|
@@ -63,194 +51,86 @@ def sample_data(): |
63 | 51 |
|
64 | 52 |
|
65 | 53 | @pytest.fixture |
66 | | -def sample_data_quiver1(): |
| 54 | +def sample_data(): |
67 | 55 | """Sample data for plot testing. |
68 | 56 |
|
69 | | - Data has three keypoints (left, centre, right) for one |
70 | | - individual that moves in a straight line along the y-axis with a |
71 | | - constant x-coordinate. |
| 57 | + Data has six keypoints for one individual that moves in a straight line |
| 58 | + along the y-axis. All keypoints have a constant x-coordinate and move in |
| 59 | + steps of 1 along the y-axis. |
72 | 60 |
|
73 | | - """ |
74 | | - time_steps = 4 |
75 | | - individuals = ["individual_0"] |
76 | | - keypoints = ["left", "centre", "right"] |
77 | | - space = ["x", "y"] |
78 | | - positions = { |
79 | | - "left": {"x": -1, "y": np.arange(time_steps)}, |
80 | | - "centre": {"x": 0, "y": np.arange(time_steps) + 1}, |
81 | | - "right": {"x": 1, "y": np.arange(time_steps)}, |
82 | | - } |
83 | | - |
84 | | - time = np.arange(time_steps) |
85 | | - position_data = np.zeros( |
86 | | - (time_steps, len(space), len(keypoints), len(individuals)) |
87 | | - ) |
88 | | - |
89 | | - # Create x and y coordinates arrays |
90 | | - x_coords = np.array([positions[key]["x"] for key in keypoints]) |
91 | | - y_coords = np.array([positions[key]["y"] for key in keypoints]) |
92 | | - |
93 | | - for i, _ in enumerate(keypoints): |
94 | | - position_data[:, 0, i, 0] = x_coords[i] # x-coordinates |
95 | | - position_data[:, 1, i, 0] = y_coords[i] # y-coordinates |
96 | | - |
97 | | - confidence_data = np.full( |
98 | | - (time_steps, len(keypoints), len(individuals)), 0.90 |
99 | | - ) |
100 | | - |
101 | | - ds = xr.Dataset( |
102 | | - { |
103 | | - "position": ( |
104 | | - ["time", "space", "keypoints", "individuals"], |
105 | | - position_data, |
106 | | - ), |
107 | | - "confidence": ( |
108 | | - ["time", "keypoints", "individuals"], |
109 | | - confidence_data, |
110 | | - ), |
111 | | - }, |
112 | | - coords={ |
113 | | - "time": time, |
114 | | - "space": space, |
115 | | - "keypoints": keypoints, |
116 | | - "individuals": individuals, |
117 | | - }, |
118 | | - ) |
119 | | - return ds |
120 | | - |
121 | | - |
122 | | -@pytest.fixture |
123 | | -def sample_data_quiver2(): |
124 | | - """Sample data for plot testing. |
125 | | -
|
126 | | - Data has three keypoints (left, centre, right) for one |
127 | | - individual that moves in a straight line along the y-axis with a |
128 | | - constant x-coordinate. |
| 61 | + Keypoint starting positions: |
| 62 | + - left1: (-1, 0) |
| 63 | + - right1: (1, 0) |
| 64 | + - left2: (-2, 0) |
| 65 | + - right2: (2, 0) |
| 66 | + - centre0: (0, 0) |
| 67 | + - centre1: (0, 1) |
129 | 68 |
|
130 | 69 | """ |
131 | | - time_steps = 4 |
132 | | - individuals = ["individual_0"] |
133 | | - keypoints = ["left1", "right1", "left2", "right2"] |
134 | | - space = ["x", "y"] |
| 70 | + keypoints = ["left1", "right1", "left2", "right2", "centre0", "centre1"] |
135 | 71 | positions = { |
136 | | - "left1": {"x": -1, "y": np.arange(time_steps)}, |
137 | | - "right1": {"x": 1, "y": np.arange(time_steps)}, |
138 | | - "left2": {"x": -1, "y": np.arange(time_steps)}, |
139 | | - "right2": {"x": 1, "y": np.arange(time_steps)}, |
| 72 | + "left1": {"x": -1, "y": np.arange(4)}, |
| 73 | + "right1": {"x": 1, "y": np.arange(4)}, |
| 74 | + "left2": {"x": -2, "y": np.arange(4)}, |
| 75 | + "right2": {"x": 2, "y": np.arange(4)}, |
| 76 | + "centre0": {"x": 0, "y": np.arange(4)}, |
| 77 | + "centre1": {"x": 0, "y": np.arange(4) + 1}, |
140 | 78 | } |
141 | | - |
142 | | - time = np.arange(time_steps) |
143 | | - position_data = np.zeros( |
144 | | - (time_steps, len(space), len(keypoints), len(individuals)) |
145 | | - ) |
146 | | - |
147 | | - # Create x and y coordinates arrays |
148 | | - x_coords = np.array([positions[key]["x"] for key in keypoints]) |
149 | | - y_coords = np.array([positions[key]["y"] for key in keypoints]) |
150 | | - |
151 | | - for i, _ in enumerate(keypoints): |
152 | | - position_data[:, 0, i, 0] = x_coords[i] # x-coordinates |
153 | | - position_data[:, 1, i, 0] = y_coords[i] # y-coordinates |
154 | | - |
155 | | - confidence_data = np.full( |
156 | | - (time_steps, len(keypoints), len(individuals)), 0.90 |
157 | | - ) |
158 | | - |
159 | | - ds = xr.Dataset( |
160 | | - { |
161 | | - "position": ( |
162 | | - ["time", "space", "keypoints", "individuals"], |
163 | | - position_data, |
164 | | - ), |
165 | | - "confidence": ( |
166 | | - ["time", "keypoints", "individuals"], |
167 | | - confidence_data, |
168 | | - ), |
169 | | - }, |
170 | | - coords={ |
171 | | - "time": time, |
172 | | - "space": space, |
173 | | - "keypoints": keypoints, |
174 | | - "individuals": individuals, |
175 | | - }, |
176 | | - ) |
177 | | - return ds |
178 | | - |
179 | | - |
180 | | -def test_vector_no_quiver(sample_data): |
181 | | - """Test midpoint between left and right keypoints.""" |
| 79 | + return create_sample_data(keypoints, positions) |
| 80 | + |
| 81 | + |
| 82 | +@pytest.mark.parametrize( |
| 83 | + ["vector_point", "expected_u", "expected_v"], |
| 84 | + [ |
| 85 | + pytest.param( |
| 86 | + "centre0", |
| 87 | + [0.0, 0.0, 0.0, 0.0], |
| 88 | + [0.0, 0.0, 0.0, 0.0], |
| 89 | + id="u = 0, v = 0", |
| 90 | + ), |
| 91 | + pytest.param( |
| 92 | + "centre1", |
| 93 | + [0.0, 0.0, 0.0, 0.0], |
| 94 | + [1.0, 1.0, 1.0, 1.0], |
| 95 | + id="u = 0, v = 1", |
| 96 | + ), |
| 97 | + pytest.param( |
| 98 | + "right2", |
| 99 | + [2.0, 2.0, 2.0, 2.0], |
| 100 | + [0.0, 0.0, 0.0, 0.0], |
| 101 | + id="u = 2, v = 0", |
| 102 | + ), |
| 103 | + pytest.param( |
| 104 | + "left2", |
| 105 | + [-2.0, -2.0, -2.0, -2.0], |
| 106 | + [0.0, 0.0, 0.0, 0.0], |
| 107 | + id="u = -2, v = 0", |
| 108 | + ), |
| 109 | + ], |
| 110 | +) |
| 111 | +def test_vector(sample_data, vector_point, expected_u, expected_v): |
| 112 | + """Test vector plot. |
| 113 | +
|
| 114 | + Test the vector plot for different vector points. The U and V values |
| 115 | + represent the horizontal (x) and vertical (y) displacement of the vector |
| 116 | +
|
| 117 | + The reference points are "left1" and "right1". |
| 118 | + """ |
182 | 119 | vector_fig = vector( |
183 | 120 | sample_data, |
184 | | - reference_points=["left", "right"], |
185 | | - vector_point="centre", |
186 | | - ) |
187 | | - |
188 | | - quiver = vector_fig.axes[0].collections[-1] |
189 | | - |
190 | | - # Extract the X, Y, U, V data |
191 | | - x = quiver.X |
192 | | - y = quiver.Y |
193 | | - u = quiver.U |
194 | | - v = quiver.V |
195 | | - |
196 | | - expected_x = np.array([0.0, 0.0, 0.0, 0.0]) |
197 | | - expected_y = np.array([0.0, 1.0, 2.0, 3.0]) |
198 | | - expected_u = np.array([0.0, 0.0, 0.0, 0.0]) |
199 | | - expected_v = np.array([0.0, 0.0, 0.0, 0.0]) |
200 | | - |
201 | | - assert np.allclose(x, expected_x) |
202 | | - assert np.allclose(y, expected_y) |
203 | | - assert np.allclose(u, expected_u) |
204 | | - assert np.allclose(v, expected_v) |
205 | | - |
206 | | - |
207 | | -def test_vector_quiver2(sample_data_quiver2): |
208 | | - """Test midpoint between left and right keypoints.""" |
209 | | - vector_fig = vector( |
210 | | - sample_data_quiver2, |
211 | 121 | reference_points=["left1", "right1"], |
212 | | - vector_point="right2", |
213 | | - ) |
214 | | - |
215 | | - quiver = vector_fig.axes[0].collections[-1] |
216 | | - |
217 | | - # Extract the X, Y, U, V data |
218 | | - x = quiver.X |
219 | | - y = quiver.Y |
220 | | - u = quiver.U |
221 | | - v = quiver.V |
222 | | - |
223 | | - expected_x = np.array([0.0, 0.0, 0.0, 0.0]) |
224 | | - expected_y = np.array([0.0, 1.0, 2.0, 3.0]) |
225 | | - expected_u = np.array([1.0, 1.0, 1.0, 1.0]) |
226 | | - expected_v = np.array([0.0, 0.0, 0.0, 0.0]) |
227 | | - |
228 | | - assert np.allclose(x, expected_x) |
229 | | - assert np.allclose(y, expected_y) |
230 | | - assert np.allclose(u, expected_u) |
231 | | - assert np.allclose(v, expected_v) |
232 | | - |
233 | | - |
234 | | -def test_vector_quiver1(sample_data_quiver1): |
235 | | - """Test midpoint between left and right keypoints.""" |
236 | | - vector_fig = vector( |
237 | | - sample_data_quiver1, |
238 | | - reference_points=["left", "right"], |
239 | | - vector_point="centre", |
| 122 | + vector_point=vector_point, |
240 | 123 | ) |
241 | 124 |
|
242 | 125 | quiver = vector_fig.axes[0].collections[-1] |
243 | 126 |
|
244 | | - # Extract the X, Y, U, V data |
245 | 127 | x = quiver.X |
246 | 128 | y = quiver.Y |
247 | 129 | u = quiver.U |
248 | 130 | v = quiver.V |
249 | 131 |
|
250 | 132 | expected_x = np.array([0.0, 0.0, 0.0, 0.0]) |
251 | 133 | expected_y = np.array([0.0, 1.0, 2.0, 3.0]) |
252 | | - expected_u = np.array([0.0, 0.0, 0.0, 0.0]) |
253 | | - expected_v = np.array([1.0, 1.0, 1.0, 1.0]) |
254 | 134 |
|
255 | 135 | assert np.allclose(x, expected_x) |
256 | 136 | assert np.allclose(y, expected_y) |
|
0 commit comments