Skip to content

Commit 6c18705

Browse files
committed
parametrize test_vector, refactor sample_data fixture
1 parent 8820be8 commit 6c18705

File tree

2 files changed

+63
-234
lines changed

2 files changed

+63
-234
lines changed

movement/plot.py

Lines changed: 0 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -174,54 +174,3 @@ def vector(
174174
)
175175

176176
return fig
177-
178-
179-
# # # FOR TESTING
180-
# from movement import sample_data
181-
# from movement.io import load_poses
182-
183-
# ds_path = sample_data.fetch_dataset_paths(
184-
# "SLEAP_single-mouse_EPM.analysis.h5"
185-
# )["poses"]
186-
# ds = load_poses.from_sleap_file(ds_path, fps=None)
187-
# # force time_unit = frames
188-
# frame_path = sample_data.fetch_dataset_paths(
189-
# "SLEAP_single-mouse_EPM.analysis.h5"
190-
# )["frame"]
191-
192-
# # head_trajectory = vector(
193-
# # ds,
194-
# # reference_points=("left_ear", "right_ear"),
195-
# # vector_point="snout",
196-
# # time_points=None,
197-
# # x_lim=None,
198-
# # y_lim=None,
199-
# # individual=0,
200-
# # )
201-
202-
# # plt.ion()
203-
# # head_trajectory.show()
204-
# # # user input to close window
205-
# # input("Press Enter to continue...")
206-
207-
208-
# # area of interest
209-
# xmin, ymin = 600.0, 665.0 # pixels
210-
# x_delta, y_delta = 125.0, 100.0 # pixels
211-
212-
# # time window
213-
# time_window = tuple(range(1650, 1671)) # tuple of frame numbers
214-
215-
# fig_head_vector = vector(
216-
# ds,
217-
# reference_points=["left_ear", "right_ear"],
218-
# vector_point="snout",
219-
# individual="individual_0",
220-
# x_lim=(xmin, xmin + x_delta),
221-
# y_lim=(ymin, ymin + y_delta),
222-
# time_points=time_window,
223-
# title="Zoomed in head vector (individual_0)",
224-
# )
225-
226-
# fig_head_vector.show()
227-
# input("Press Enter to continue...")

tests/test_unit/test_plot.py

Lines changed: 63 additions & 183 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,23 @@
1+
"""Unit tests for the plot module."""
2+
13
import numpy as np
24
import pytest
35
import xarray as xr
46

57
from movement.plot import vector
68

79

8-
@pytest.fixture
9-
def sample_data():
10-
"""Sample data for plot testing.
11-
12-
Data has three keypoints (left, centre, right) for one
13-
individual that moves in a straight line along the y-axis with a
14-
constant x-coordinate.
15-
16-
"""
10+
def create_sample_data(keypoints, positions):
11+
"""Create sample data for testing."""
1712
time_steps = 4
1813
individuals = ["individual_0"]
19-
keypoints = ["left", "centre", "right"]
2014
space = ["x", "y"]
21-
positions = {
22-
"left": {"x": -1, "y": np.arange(time_steps)},
23-
"centre": {"x": 0, "y": np.arange(time_steps)},
24-
"right": {"x": 1, "y": np.arange(time_steps)},
25-
}
2615

2716
time = np.arange(time_steps)
2817
position_data = np.zeros(
2918
(time_steps, len(space), len(keypoints), len(individuals))
3019
)
3120

32-
# Create x and y coordinates arrays
3321
x_coords = np.array([positions[key]["x"] for key in keypoints])
3422
y_coords = np.array([positions[key]["y"] for key in keypoints])
3523

@@ -63,194 +51,86 @@ def sample_data():
6351

6452

6553
@pytest.fixture
66-
def sample_data_quiver1():
54+
def sample_data():
6755
"""Sample data for plot testing.
6856
69-
Data has three keypoints (left, centre, right) for one
70-
individual that moves in a straight line along the y-axis with a
71-
constant x-coordinate.
57+
Data has six keypoints for one individual that moves in a straight line
58+
along the y-axis. All keypoints have a constant x-coordinate and move in
59+
steps of 1 along the y-axis.
7260
73-
"""
74-
time_steps = 4
75-
individuals = ["individual_0"]
76-
keypoints = ["left", "centre", "right"]
77-
space = ["x", "y"]
78-
positions = {
79-
"left": {"x": -1, "y": np.arange(time_steps)},
80-
"centre": {"x": 0, "y": np.arange(time_steps) + 1},
81-
"right": {"x": 1, "y": np.arange(time_steps)},
82-
}
83-
84-
time = np.arange(time_steps)
85-
position_data = np.zeros(
86-
(time_steps, len(space), len(keypoints), len(individuals))
87-
)
88-
89-
# Create x and y coordinates arrays
90-
x_coords = np.array([positions[key]["x"] for key in keypoints])
91-
y_coords = np.array([positions[key]["y"] for key in keypoints])
92-
93-
for i, _ in enumerate(keypoints):
94-
position_data[:, 0, i, 0] = x_coords[i] # x-coordinates
95-
position_data[:, 1, i, 0] = y_coords[i] # y-coordinates
96-
97-
confidence_data = np.full(
98-
(time_steps, len(keypoints), len(individuals)), 0.90
99-
)
100-
101-
ds = xr.Dataset(
102-
{
103-
"position": (
104-
["time", "space", "keypoints", "individuals"],
105-
position_data,
106-
),
107-
"confidence": (
108-
["time", "keypoints", "individuals"],
109-
confidence_data,
110-
),
111-
},
112-
coords={
113-
"time": time,
114-
"space": space,
115-
"keypoints": keypoints,
116-
"individuals": individuals,
117-
},
118-
)
119-
return ds
120-
121-
122-
@pytest.fixture
123-
def sample_data_quiver2():
124-
"""Sample data for plot testing.
125-
126-
Data has three keypoints (left, centre, right) for one
127-
individual that moves in a straight line along the y-axis with a
128-
constant x-coordinate.
61+
Keypoint starting positions:
62+
- left1: (-1, 0)
63+
- right1: (1, 0)
64+
- left2: (-2, 0)
65+
- right2: (2, 0)
66+
- centre0: (0, 0)
67+
- centre1: (0, 1)
12968
13069
"""
131-
time_steps = 4
132-
individuals = ["individual_0"]
133-
keypoints = ["left1", "right1", "left2", "right2"]
134-
space = ["x", "y"]
70+
keypoints = ["left1", "right1", "left2", "right2", "centre0", "centre1"]
13571
positions = {
136-
"left1": {"x": -1, "y": np.arange(time_steps)},
137-
"right1": {"x": 1, "y": np.arange(time_steps)},
138-
"left2": {"x": -1, "y": np.arange(time_steps)},
139-
"right2": {"x": 1, "y": np.arange(time_steps)},
72+
"left1": {"x": -1, "y": np.arange(4)},
73+
"right1": {"x": 1, "y": np.arange(4)},
74+
"left2": {"x": -2, "y": np.arange(4)},
75+
"right2": {"x": 2, "y": np.arange(4)},
76+
"centre0": {"x": 0, "y": np.arange(4)},
77+
"centre1": {"x": 0, "y": np.arange(4) + 1},
14078
}
141-
142-
time = np.arange(time_steps)
143-
position_data = np.zeros(
144-
(time_steps, len(space), len(keypoints), len(individuals))
145-
)
146-
147-
# Create x and y coordinates arrays
148-
x_coords = np.array([positions[key]["x"] for key in keypoints])
149-
y_coords = np.array([positions[key]["y"] for key in keypoints])
150-
151-
for i, _ in enumerate(keypoints):
152-
position_data[:, 0, i, 0] = x_coords[i] # x-coordinates
153-
position_data[:, 1, i, 0] = y_coords[i] # y-coordinates
154-
155-
confidence_data = np.full(
156-
(time_steps, len(keypoints), len(individuals)), 0.90
157-
)
158-
159-
ds = xr.Dataset(
160-
{
161-
"position": (
162-
["time", "space", "keypoints", "individuals"],
163-
position_data,
164-
),
165-
"confidence": (
166-
["time", "keypoints", "individuals"],
167-
confidence_data,
168-
),
169-
},
170-
coords={
171-
"time": time,
172-
"space": space,
173-
"keypoints": keypoints,
174-
"individuals": individuals,
175-
},
176-
)
177-
return ds
178-
179-
180-
def test_vector_no_quiver(sample_data):
181-
"""Test midpoint between left and right keypoints."""
79+
return create_sample_data(keypoints, positions)
80+
81+
82+
@pytest.mark.parametrize(
83+
["vector_point", "expected_u", "expected_v"],
84+
[
85+
pytest.param(
86+
"centre0",
87+
[0.0, 0.0, 0.0, 0.0],
88+
[0.0, 0.0, 0.0, 0.0],
89+
id="u = 0, v = 0",
90+
),
91+
pytest.param(
92+
"centre1",
93+
[0.0, 0.0, 0.0, 0.0],
94+
[1.0, 1.0, 1.0, 1.0],
95+
id="u = 0, v = 1",
96+
),
97+
pytest.param(
98+
"right2",
99+
[2.0, 2.0, 2.0, 2.0],
100+
[0.0, 0.0, 0.0, 0.0],
101+
id="u = 2, v = 0",
102+
),
103+
pytest.param(
104+
"left2",
105+
[-2.0, -2.0, -2.0, -2.0],
106+
[0.0, 0.0, 0.0, 0.0],
107+
id="u = -2, v = 0",
108+
),
109+
],
110+
)
111+
def test_vector(sample_data, vector_point, expected_u, expected_v):
112+
"""Test vector plot.
113+
114+
Test the vector plot for different vector points. The U and V values
115+
represent the horizontal (x) and vertical (y) displacement of the vector
116+
117+
The reference points are "left1" and "right1".
118+
"""
182119
vector_fig = vector(
183120
sample_data,
184-
reference_points=["left", "right"],
185-
vector_point="centre",
186-
)
187-
188-
quiver = vector_fig.axes[0].collections[-1]
189-
190-
# Extract the X, Y, U, V data
191-
x = quiver.X
192-
y = quiver.Y
193-
u = quiver.U
194-
v = quiver.V
195-
196-
expected_x = np.array([0.0, 0.0, 0.0, 0.0])
197-
expected_y = np.array([0.0, 1.0, 2.0, 3.0])
198-
expected_u = np.array([0.0, 0.0, 0.0, 0.0])
199-
expected_v = np.array([0.0, 0.0, 0.0, 0.0])
200-
201-
assert np.allclose(x, expected_x)
202-
assert np.allclose(y, expected_y)
203-
assert np.allclose(u, expected_u)
204-
assert np.allclose(v, expected_v)
205-
206-
207-
def test_vector_quiver2(sample_data_quiver2):
208-
"""Test midpoint between left and right keypoints."""
209-
vector_fig = vector(
210-
sample_data_quiver2,
211121
reference_points=["left1", "right1"],
212-
vector_point="right2",
213-
)
214-
215-
quiver = vector_fig.axes[0].collections[-1]
216-
217-
# Extract the X, Y, U, V data
218-
x = quiver.X
219-
y = quiver.Y
220-
u = quiver.U
221-
v = quiver.V
222-
223-
expected_x = np.array([0.0, 0.0, 0.0, 0.0])
224-
expected_y = np.array([0.0, 1.0, 2.0, 3.0])
225-
expected_u = np.array([1.0, 1.0, 1.0, 1.0])
226-
expected_v = np.array([0.0, 0.0, 0.0, 0.0])
227-
228-
assert np.allclose(x, expected_x)
229-
assert np.allclose(y, expected_y)
230-
assert np.allclose(u, expected_u)
231-
assert np.allclose(v, expected_v)
232-
233-
234-
def test_vector_quiver1(sample_data_quiver1):
235-
"""Test midpoint between left and right keypoints."""
236-
vector_fig = vector(
237-
sample_data_quiver1,
238-
reference_points=["left", "right"],
239-
vector_point="centre",
122+
vector_point=vector_point,
240123
)
241124

242125
quiver = vector_fig.axes[0].collections[-1]
243126

244-
# Extract the X, Y, U, V data
245127
x = quiver.X
246128
y = quiver.Y
247129
u = quiver.U
248130
v = quiver.V
249131

250132
expected_x = np.array([0.0, 0.0, 0.0, 0.0])
251133
expected_y = np.array([0.0, 1.0, 2.0, 3.0])
252-
expected_u = np.array([0.0, 0.0, 0.0, 0.0])
253-
expected_v = np.array([1.0, 1.0, 1.0, 1.0])
254134

255135
assert np.allclose(x, expected_x)
256136
assert np.allclose(y, expected_y)

0 commit comments

Comments
 (0)