Skip to content

Commit 029cdb0

Browse files
committed
Move key setting/getting
1 parent 5bd30bd commit 029cdb0

File tree

3 files changed

+33
-51
lines changed

3 files changed

+33
-51
lines changed

Diff for: src/napari_matplotlib/features.py

+21-9
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, List
1+
from typing import Dict, List, Optional
22

33
import napari
44
import napari.layers
@@ -12,12 +12,6 @@ class FeaturesMixin(NapariMPLWidget):
1212
"""
1313
Mixin to help with widgets that plot data from a features table stored
1414
on a single layer.
15-
16-
Notes
17-
-----
18-
This currently only works for widgets that plot two quatities against each other
19-
e.g., scatter plots. It is intended to be generalised in the future for widgets
20-
that plot one quantity e.g., histograms.
2115
"""
2216

2317
n_layers_input = Interval(1, 1)
@@ -30,19 +24,37 @@ class FeaturesMixin(NapariMPLWidget):
3024
napari.layers.Vectors,
3125
)
3226

33-
def __init__(self) -> None:
27+
def __init__(self, *, ndim: int) -> None:
28+
assert ndim in [1, 2]
29+
self.dims = ["x", "y"][:ndim]
3430
# Set up selection boxes
3531
self.layout().addLayout(QVBoxLayout())
3632

3733
self._selectors: Dict[str, QComboBox] = {}
38-
for dim in ["x", "y"]:
34+
for dim in self.dims:
3935
self._selectors[dim] = QComboBox()
4036
# Re-draw when combo boxes are updated
4137
self._selectors[dim].currentTextChanged.connect(self._draw)
4238

4339
self.layout().addWidget(QLabel(f"{dim}-axis:"))
4440
self.layout().addWidget(self._selectors[dim])
4541

42+
def get_key(self, dim: str) -> Optional[str]:
43+
"""
44+
Get key for a given dimension.
45+
"""
46+
if self._selectors[dim].count() == 0:
47+
return None
48+
else:
49+
return self._selectors[dim].currentText()
50+
51+
def set_key(self, dim: str, value: str) -> None:
52+
"""
53+
Set key for a given dimension.
54+
"""
55+
self._selectors[dim].setCurrentText(value)
56+
self._draw()
57+
4658
def _get_valid_axis_keys(self) -> List[str]:
4759
"""
4860
Get the valid axis keys from the layer FeatureTable.

Diff for: src/napari_matplotlib/scatter.py

+8-38
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Optional, Tuple, Union
1+
from typing import Any, Optional, Tuple
22

33
import napari
44
import numpy.typing as npt
@@ -97,39 +97,9 @@ def __init__(
9797
parent: Optional[QWidget] = None,
9898
):
9999
ScatterBaseWidget.__init__(self, napari_viewer, parent=parent)
100-
FeaturesMixin.__init__(self)
100+
FeaturesMixin.__init__(self, ndim=2)
101101
self._update_layers(None)
102102

103-
@property
104-
def x_axis_key(self) -> Union[str, None]:
105-
"""
106-
Key for the x-axis data.
107-
"""
108-
if self._selectors["x"].count() == 0:
109-
return None
110-
else:
111-
return self._selectors["x"].currentText()
112-
113-
@x_axis_key.setter
114-
def x_axis_key(self, key: str) -> None:
115-
self._selectors["x"].setCurrentText(key)
116-
self._draw()
117-
118-
@property
119-
def y_axis_key(self) -> Union[str, None]:
120-
"""
121-
Key for the y-axis data.
122-
"""
123-
if self._selectors["y"].count() == 0:
124-
return None
125-
else:
126-
return self._selectors["y"].currentText()
127-
128-
@y_axis_key.setter
129-
def y_axis_key(self, key: str) -> None:
130-
self._selectors["y"].setCurrentText(key)
131-
self._draw()
132-
133103
def _ready_to_scatter(self) -> bool:
134104
"""
135105
Return True if selected layer has a feature table we can scatter with,
@@ -143,8 +113,8 @@ def _ready_to_scatter(self) -> bool:
143113
return (
144114
feature_table is not None
145115
and len(feature_table) > 0
146-
and self.x_axis_key in valid_keys
147-
and self.y_axis_key in valid_keys
116+
and self.get_key("x") in valid_keys
117+
and self.get_key("y") in valid_keys
148118
)
149119

150120
def draw(self) -> None:
@@ -173,11 +143,11 @@ def _get_data(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any], str, str]:
173143
"""
174144
feature_table = self.layers[0].features
175145

176-
x = feature_table[self.x_axis_key]
177-
y = feature_table[self.y_axis_key]
146+
x = feature_table[self.get_key("x")]
147+
y = feature_table[self.get_key("y")]
178148

179-
x_axis_name = str(self.x_axis_key)
180-
y_axis_name = str(self.y_axis_key)
149+
x_axis_name = str(self.get_key("x"))
150+
y_axis_name = str(self.get_key("y"))
181151

182152
return x, y, x_axis_name, y_axis_name
183153

Diff for: src/napari_matplotlib/tests/scatter/test_scatter_features.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ def test_features_scatter_widget_2D(
2525

2626
# Select points data and chosen features
2727
viewer.layers.selection.add(viewer.layers[0]) # images need to be selected
28-
widget.x_axis_key = "feature_0"
29-
widget.y_axis_key = "feature_1"
28+
widget.set_key("x", "feature_0")
29+
widget.set_key("y", "feature_1")
3030

3131
fig = widget.figure
3232

@@ -64,9 +64,9 @@ def test_features_scatter_get_data(make_napari_viewer):
6464
viewer.layers.selection = [labels_layer]
6565

6666
x_column = "feature_0"
67-
scatter_widget.x_axis_key = x_column
6867
y_column = "feature_2"
69-
scatter_widget.y_axis_key = y_column
68+
scatter_widget.set_key("x", x_column)
69+
scatter_widget.set_key("y", y_column)
7070

7171
x, y, x_axis_name, y_axis_name = scatter_widget._get_data()
7272
np.testing.assert_allclose(x, feature_table[x_column])

0 commit comments

Comments
 (0)