Skip to content

Commit 043ffb3

Browse files
committed
Refactor callback handling
1 parent 234227d commit 043ffb3

File tree

3 files changed

+74
-55
lines changed

3 files changed

+74
-55
lines changed

Diff for: src/napari_matplotlib/base.py

+61-4
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
FigureCanvas,
55
NavigationToolbar2QT,
66
)
7-
from matplotlib.figure import Figure
87
from qtpy.QtWidgets import QVBoxLayout, QWidget
98

109
mpl.rc("axes", edgecolor="white")
@@ -15,6 +14,7 @@
1514

1615
mpl.rc("xtick", color="white")
1716
mpl.rc("ytick", color="white")
17+
1818
__all__ = ["NapariMPLWidget"]
1919

2020

@@ -23,8 +23,12 @@ class NapariMPLWidget(QWidget):
2323
Base widget that can be embedded as a napari widget and contains a
2424
Matplotlib canvas.
2525
26-
This creates a single Figure, and sub-classes should implement logic for
27-
drawing on that Figure.
26+
This creates a single FigureCanvas, which contains a single Figure.
27+
28+
This class also handles callbacks to automatically update figures when
29+
the layer selection or z-step is changed in the napari viewer. To take
30+
advantage of this sub-classes should implement the ``clear()`` and
31+
``draw()`` methods.
2832
2933
Attributes
3034
----------
@@ -34,13 +38,14 @@ class NapariMPLWidget(QWidget):
3438
Matplotlib figure.
3539
canvas : matplotlib.backends.backend_qt5agg.FigureCanvas
3640
Matplotlib canvas.
41+
layers : `list`
42+
List of currently selected napari layers.
3743
"""
3844

3945
def __init__(self, napari_viewer: napari.viewer.Viewer):
4046
super().__init__()
4147

4248
self.viewer = napari_viewer
43-
self.figure = Figure(figsize=(5, 3), tight_layout=True)
4449
self.canvas = FigureCanvas()
4550
self.canvas.figure.patch.set_facecolor("#262930")
4651
self.toolbar = NavigationToolbar2QT(self.canvas, self)
@@ -49,9 +54,61 @@ def __init__(self, napari_viewer: napari.viewer.Viewer):
4954
self.layout().addWidget(self.toolbar)
5055
self.layout().addWidget(self.canvas)
5156

57+
self.setup_callbacks()
58+
59+
@property
60+
def n_selected_layers(self) -> int:
61+
"""
62+
Number of currently selected layers.
63+
"""
64+
return len(self.layers)
65+
5266
@property
5367
def current_z(self) -> int:
5468
"""
5569
Current z-step of the viewer.
5670
"""
5771
return self.viewer.dims.current_step[0]
72+
73+
def setup_callbacks(self) -> None:
74+
"""
75+
Setup callbacks for:
76+
- Layer selection changing
77+
- z-step changing
78+
"""
79+
# z-step changed in viewer
80+
self.viewer.dims.events.current_step.connect(self._draw)
81+
# Layer selection changed in viewer
82+
self.viewer.layers.selection.events.active.connect(self.update_layers)
83+
84+
def update_layers(self, event: napari.utils.events.Event) -> None:
85+
"""
86+
Update the currently selected layers and re-draw.
87+
"""
88+
self.layers = list(self.viewer.layers.selection)
89+
self._draw()
90+
91+
def _draw(self) -> None:
92+
"""
93+
Clear current figure, check selected layers are correct, and draw new
94+
figure if so.
95+
"""
96+
self.clear()
97+
if self.n_selected_layers != self.n_layers_input:
98+
return
99+
self.draw()
100+
self.canvas.draw()
101+
102+
def clear(self) -> None:
103+
"""
104+
Clear any previously drawn figures.
105+
106+
This is a no-op, and is intended for derived classes to override.
107+
"""
108+
109+
def draw(self) -> None:
110+
"""
111+
Re-draw any figures.
112+
113+
This is a no-op, and is intended for derived classes to override.
114+
"""

Diff for: src/napari_matplotlib/histogram.py

+9-25
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,34 @@
1-
import napari
21
import numpy as np
32

43
from .base import NapariMPLWidget
54

65
__all__ = ["HistogramWidget"]
76

7+
import napari
88

99
_COLORS = {"r": "tab:red", "g": "tab:green", "b": "tab:blue"}
1010

1111

1212
class HistogramWidget(NapariMPLWidget):
1313
"""
14-
Widget to display a histogram of the currently selected layer.
15-
16-
Attributes
17-
----------
18-
layer : `napari.layers.Layer`
19-
Current layer being histogrammed.
14+
Display a histogram of the currently selected layer.
2015
"""
2116

17+
n_layers_input = 1
18+
2219
def __init__(self, napari_viewer: napari.viewer.Viewer):
2320
super().__init__(napari_viewer)
2421
self.axes = self.canvas.figure.subplots()
25-
self.layer = self.viewer.layers[-1]
26-
27-
self.viewer.dims.events.current_step.connect(self.hist_current_layer)
28-
self.viewer.layers.selection.events.active.connect(self.update_layer)
22+
self.update_layers(None)
2923

30-
self.hist_current_layer()
31-
32-
def update_layer(self, event: napari.utils.events.Event) -> None:
33-
"""
34-
Update the currently selected layer.
35-
"""
36-
# Update current layer when selection changed in viewer
37-
if event.value:
38-
self.layer = event.value
39-
self.hist_current_layer()
24+
def clear(self) -> None:
25+
self.axes.clear()
4026

41-
def hist_current_layer(self) -> None:
27+
def draw(self) -> None:
4228
"""
4329
Clear the axes and histogram the currently selected layer/slice.
4430
"""
45-
self.axes.clear()
46-
layer = self.layer
31+
layer = self.layers[0]
4732
bins = np.linspace(np.min(layer.data), np.max(layer.data), 100)
4833

4934
if layer.data.ndim - layer.rgb == 3:
@@ -67,4 +52,3 @@ def hist_current_layer(self) -> None:
6752
self.axes.hist(data.ravel(), bins=bins, label=layer.name)
6853

6954
self.axes.legend()
70-
self.canvas.draw()

Diff for: src/napari_matplotlib/scatter.py

+4-26
Original file line numberDiff line numberDiff line change
@@ -12,40 +12,19 @@ class ScatterWidget(NapariMPLWidget):
1212
1313
If there are more than 500 data points, a 2D histogram is displayed instead
1414
of a scatter plot, to avoid too many scatter points.
15-
16-
Attributes
17-
----------
18-
layers : list[`napari.layers.Layer`]
19-
Current two layers being scattered.
2015
"""
2116

17+
n_layers_input = 2
18+
2219
def __init__(self, napari_viewer: napari.viewer.Viewer):
2320
super().__init__(napari_viewer)
2421
self.axes = self.canvas.figure.subplots()
25-
self.layers = self.viewer.layers[-2:]
26-
27-
self.viewer.dims.events.current_step.connect(
28-
self.scatter_current_layers
29-
)
30-
self.viewer.layers.selection.events.changed.connect(self.update_layers)
31-
32-
self.scatter_current_layers()
33-
34-
def update_layers(self, event: napari.utils.events.Event) -> None:
35-
"""
36-
Update the currently selected layers.
37-
"""
38-
# Update current layer when selection changed in viewer
39-
layers = self.viewer.layers.selection
40-
if len(layers) == 2:
41-
self.layers = list(layers)
42-
self.scatter_current_layers()
22+
self.update_layers(None)
4323

44-
def scatter_current_layers(self) -> None:
24+
def draw(self) -> None:
4525
"""
4626
Clear the axes and scatter the currently selected layers.
4727
"""
48-
self.axes.clear()
4928
data = [layer.data[self.current_z] for layer in self.layers]
5029
if data[0].size < 500:
5130
self.axes.scatter(data[0], data[1], alpha=0.5)
@@ -58,4 +37,3 @@ def scatter_current_layers(self) -> None:
5837
)
5938
self.axes.set_xlabel(self.layers[0].name)
6039
self.axes.set_ylabel(self.layers[1].name)
61-
self.canvas.draw()

0 commit comments

Comments
 (0)