forked from matplotlib/napari-matplotlib
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathscatter.py
112 lines (88 loc) · 3.14 KB
/
scatter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from typing import Any, Optional, Tuple
import napari
import numpy.typing as npt
from qtpy.QtWidgets import QWidget
from .base import SingleAxesWidget
from .features import FeaturesMixin
from .util import Interval
# Public widgets exported by this module via ``from ... import *``.
__all__ = [
    "ScatterBaseWidget",
    "ScatterWidget",
    "FeaturesScatterWidget",
]
class ScatterBaseWidget(SingleAxesWidget):
    """
    Base class for widgets that scatter two datasets against each other.

    Subclasses supply the data and axis labels by implementing ``_get_data``.
    When ``x`` has more than ``_threshold_to_switch_to_histogram`` elements,
    the data is drawn as a 2D histogram with ``_histogram_bins`` bins instead
    of as individual scatter points.
    """

    # if the number of points is greater than this value,
    # the scatter is plotted as a 2D histogram
    _threshold_to_switch_to_histogram = 500
    # number of bins passed to ``Axes.hist2d`` when the histogram is used;
    # a class attribute (like the threshold above) so subclasses can tune it
    _histogram_bins = 100

    def draw(self) -> None:
        """
        Scatter the currently selected layers.

        Data and axis labels come from ``_get_data``. If ``x`` has more than
        ``_threshold_to_switch_to_histogram`` elements, a 2D histogram is
        drawn; otherwise the points are scattered with ``alpha=0.5``.
        """
        x, y, x_axis_name, y_axis_name = self._get_data()

        if x.size > self._threshold_to_switch_to_histogram:
            # hist2d takes 1D sequences, so flatten possibly N-D arrays first
            self.axes.hist2d(
                x.ravel(),
                y.ravel(),
                bins=self._histogram_bins,
            )
        else:
            self.axes.scatter(x, y, alpha=0.5)

        self.axes.set_xlabel(x_axis_name)
        self.axes.set_ylabel(y_axis_name)

    def _get_data(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any], str, str]:
        """
        Get the plot data.

        This must be implemented on the subclass.

        Returns
        -------
        x, y : np.ndarray
            x and y values of plot data.
        x_axis_name, y_axis_name : str
            Label to display on the x/y axis

        Raises
        ------
        NotImplementedError
            Always, when called on this base class.
        """
        raise NotImplementedError
class ScatterWidget(ScatterBaseWidget):
    """
    Scatter data in two similarly shaped layers.

    If there are more than ``_threshold_to_switch_to_histogram`` data points
    (500 by default), a 2D histogram is displayed instead of a scatter plot,
    to avoid too many scatter points.
    """

    # NOTE(review): presumably requires exactly two selected layers
    # (min == max == 2) — confirm against ``Interval`` in .util
    n_layers_input = Interval(2, 2)
    input_layer_types = (napari.layers.Image,)

    def _get_data(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any], str, str]:
        """
        Get the plot data.

        Returns
        -------
        x, y : np.ndarray
            The in-view slice (``data[current_z]``) of the first and second
            selected image layers, respectively.
        x_axis_name, y_axis_name : str
            Names of the first and second layers, used as the x/y axis
            titles.
        """
        # NOTE(review): assumes ``current_z`` is the index of the currently
        # displayed slice along the first data axis — confirm in the base class.
        x = self.layers[0].data[self.current_z]
        y = self.layers[1].data[self.current_z]
        x_axis_name = self.layers[0].name
        y_axis_name = self.layers[1].name
        return x, y, x_axis_name, y_axis_name
class FeaturesScatterWidget(ScatterBaseWidget, FeaturesMixin):
    """
    Widget to scatter data stored in two layer feature attributes.

    Feature selection is delegated to ``FeaturesMixin`` (configured for two
    features); plotting is inherited from ``ScatterBaseWidget``.
    """

    def __init__(
        self,
        napari_viewer: napari.viewer.Viewer,
        parent: Optional[QWidget] = None,
    ):
        # Each base is initialised explicitly, rather than via one cooperative
        # ``super().__init__``, because they take different arguments.
        ScatterBaseWidget.__init__(self, napari_viewer, parent=parent)
        FeaturesMixin.__init__(self, ndim=2)
        # Initial layer update. NOTE(review): ``None`` presumably stands in
        # for the event argument a signal would normally pass — confirm.
        self._update_layers(None)

    def draw(self) -> None:
        """
        Scatter two features from the currently selected layer.

        Nothing is drawn unless ``_ready_to_plot`` returns a truthy value.
        """
        if not self._ready_to_plot():
            return
        super().draw()

    def _get_data(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any], str, str]:
        """
        Get the two selected features and their names.

        Returns
        -------
        x, y : np.ndarray
            First and second data entries from ``_get_data_names``.
        x_axis_name, y_axis_name : str
            Corresponding names, used as the x/y axis labels.
        """
        values, labels = self._get_data_names()
        x_values, y_values = values[0], values[1]
        x_label, y_label = labels[0], labels[1]
        return x_values, y_values, x_label, y_label