|
9 | 9 | from napari_matplotlib import (
|
10 | 10 | FeaturesScatterWidget,
|
11 | 11 | HistogramWidget,
|
| 12 | + ScatterWidget, |
12 | 13 | SliceWidget,
|
13 | 14 | )
|
14 | 15 | from napari_matplotlib.base import NapariMPLWidget
|
|
18 | 19 | )
|
19 | 20 |
|
20 | 21 |
|
21 |
| -@pytest.mark.parametrize("widget_cls", [HistogramWidget, SliceWidget]) |
| 22 | +@pytest.mark.parametrize( |
| 23 | + "widget_cls, n_layers", |
| 24 | + [(HistogramWidget, 1), (SliceWidget, 1), (ScatterWidget, 2)], |
| 25 | +) |
22 | 26 | def test_change_one_layer(
|
23 |
| - make_napari_viewer, brain_data, astronaut_data, widget_cls |
| 27 | + make_napari_viewer, |
| 28 | + brain_data, |
| 29 | + astronaut_data, |
| 30 | + widget_cls, |
| 31 | + n_layers, |
24 | 32 | ):
|
25 | 33 | """
|
26 | 34 | Test all widgets that take one layer as input to make sure the plot changes
|
27 | 35 | when the napari layer selection changes.
|
28 | 36 | """
|
29 | 37 | viewer = make_napari_viewer()
|
30 |
| - assert_one_layer_plot_changes( |
31 |
| - viewer, widget_cls, brain_data, astronaut_data |
32 |
| - ) |
33 |
| - |
34 | 38 |
|
35 |
| -def assert_one_layer_plot_changes( |
36 |
| - viewer: Viewer, |
37 |
| - widget_cls: Type[NapariMPLWidget], |
38 |
| - data1: Tuple[npt.NDArray[np.generic], Dict[str, Any]], |
39 |
| - data2: Tuple[npt.NDArray[np.generic], Dict[str, Any]], |
40 |
| -) -> None: |
41 |
| - """ |
42 |
| - When the selected layer is changed, make sure the plot generated |
43 |
| - by `widget_cls` also changes. |
44 |
| - """ |
45 | 39 | widget = widget_cls(viewer)
|
46 |
| - viewer.add_image(data1[0], **data1[1]) |
47 |
| - viewer.add_image(data2[0], **data2[1]) |
48 |
| - assert_plot_changes(viewer, widget) |
| 40 | + # Add n copies of two different datasets |
| 41 | + for _ in range(n_layers): |
| 42 | + viewer.add_image(brain_data[0], **brain_data[1]) |
| 43 | + for _ in range(n_layers): |
| 44 | + viewer.add_image(astronaut_data[0], **astronaut_data[1]) |
| 45 | + |
| 46 | + assert len(viewer.layers) == 2 * n_layers |
| 47 | + assert_plot_changes(viewer, widget, n_layers=n_layers) |
49 | 48 |
|
50 | 49 |
|
51 | 50 | @pytest.mark.parametrize("widget_cls", [FeaturesScatterWidget])
|
@@ -76,26 +75,35 @@ def assert_features_plot_changes(
|
76 | 75 | name: data + 1 for name, data in data[1]["features"].items()
|
77 | 76 | }
|
78 | 77 | viewer.add_points(data[0], **data[1])
|
79 |
| - assert_plot_changes(viewer, widget) |
| 78 | + assert_plot_changes(viewer, widget, n_layers=1) |
80 | 79 |
|
81 | 80 |
|
82 |
| -def assert_plot_changes(viewer: Viewer, widget: NapariMPLWidget) -> None: |
| 81 | +def assert_plot_changes( |
| 82 | + viewer: Viewer, widget: NapariMPLWidget, *, n_layers: int |
| 83 | +) -> None: |
83 | 84 | """
|
84 | 85 | Assert that a widget plot changes when the layer selection
|
85 |
| - is changed. The passed viewer must already have two layers |
| 86 | + is changed. The passed viewer must already have (2 * n_layers) layers |
86 | 87 | loaded.
|
87 | 88 | """
|
88 |
| - # Select first layer |
| 89 | + # Select first layer(s) |
89 | 90 | viewer.layers.selection.clear()
|
90 |
| - viewer.layers.selection.add(viewer.layers[0]) |
| 91 | + |
| 92 | + for i in range(n_layers): |
| 93 | + viewer.layers.selection.add(viewer.layers[i]) |
| 94 | + assert len(viewer.layers.selection) == n_layers |
91 | 95 | fig1 = deepcopy(widget.figure)
|
92 | 96 |
|
93 |
| - # Re-selecting first layer should produce identical plot |
| 97 | + # Re-selecting first layer(s) should produce identical plot |
94 | 98 | viewer.layers.selection.clear()
|
95 |
| - viewer.layers.selection.add(viewer.layers[0]) |
| 99 | + for i in range(n_layers): |
| 100 | + viewer.layers.selection.add(viewer.layers[i]) |
| 101 | + assert len(viewer.layers.selection) == n_layers |
96 | 102 | assert_figures_equal(widget.figure, fig1)
|
97 | 103 |
|
98 |
| - # Plotting the second layer should produce a different plot |
| 104 | + # Plotting the second layer(s) should produce a different plot |
99 | 105 | viewer.layers.selection.clear()
|
100 |
| - viewer.layers.selection.add(viewer.layers[1]) |
| 106 | + for i in range(n_layers): |
| 107 | + viewer.layers.selection.add(viewer.layers[n_layers + i]) |
| 108 | + assert len(viewer.layers.selection) == n_layers |
101 | 109 | assert_figures_not_equal(widget.figure, fig1)
|
0 commit comments