Skip to content

Commit a9d41db

Browse files
committed
Fix slicing selection
1 parent d7e88a9 commit a9d41db

File tree

2 files changed

+70
-39
lines changed

2 files changed

+70
-39
lines changed

Diff for: src/napari_matplotlib/slice.py

+46-39
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,23 @@
1-
from typing import Any, Dict, List, Optional, Tuple
1+
from typing import Any, List, Optional, Tuple
22

33
import matplotlib.ticker as mticker
44
import napari
55
import numpy as np
66
import numpy.typing as npt
7-
from qtpy.QtWidgets import QComboBox, QHBoxLayout, QLabel, QSpinBox, QWidget
7+
from qtpy.QtCore import Qt
8+
from qtpy.QtWidgets import (
9+
QComboBox,
10+
QLabel,
11+
QSlider,
12+
QVBoxLayout,
13+
QWidget,
14+
)
815

916
from .base import SingleAxesWidget
1017
from .util import Interval
1118

1219
__all__ = ["SliceWidget"]
1320

14-
_dims_sel = ["x", "y"]
15-
1621

1722
class SliceWidget(SingleAxesWidget):
1823
"""
@@ -30,28 +35,44 @@ def __init__(
3035
# Setup figure/axes
3136
super().__init__(napari_viewer, parent=parent)
3237

33-
button_layout = QHBoxLayout()
34-
self.layout().addLayout(button_layout)
35-
3638
self.dim_selector = QComboBox()
39+
self.dim_selector.addItems(["x", "y"])
40+
41+
self.slice_selector = QSlider(orientation=Qt.Orientation.Horizontal)
42+
43+
# Create widget layout
44+
button_layout = QVBoxLayout()
3745
button_layout.addWidget(QLabel("Slice axis:"))
3846
button_layout.addWidget(self.dim_selector)
39-
self.dim_selector.addItems(["x", "y", "z"])
40-
41-
self.slice_selectors = {}
42-
for d in _dims_sel:
43-
self.slice_selectors[d] = QSpinBox()
44-
button_layout.addWidget(QLabel(f"{d}:"))
45-
button_layout.addWidget(self.slice_selectors[d])
47+
button_layout.addWidget(self.slice_selector)
48+
self.layout().addLayout(button_layout)
4649

4750
# Setup callbacks
48-
# Re-draw when any of the combon/spin boxes are updated
51+
# Re-draw when any of the combo/slider is updated
4952
self.dim_selector.currentTextChanged.connect(self._draw)
50-
for d in _dims_sel:
51-
self.slice_selectors[d].textChanged.connect(self._draw)
53+
self.slice_selector.valueChanged.connect(self._draw)
5254

5355
self._update_layers(None)
5456

57+
def on_update_layers(self) -> None:
58+
"""
59+
Called when layer selection is updated.
60+
"""
61+
if self.current_dim_name == "x":
62+
max = self._layer.data.shape[-2]
63+
elif self.current_dim_name == "y":
64+
max = self._layer.data.shape[-1]
65+
else:
66+
raise RuntimeError("dim name must be x or y")
67+
self.slice_selector.setRange(0, max)
68+
69+
@property
70+
def _slice_width(self) -> int:
71+
"""
72+
Width of the slice being plotted.
73+
"""
74+
return self._layer.data.shape[self.current_dim_index] - 1
75+
5576
@property
5677
def _layer(self) -> napari.layers.Layer:
5778
"""
@@ -73,7 +94,7 @@ def current_dim_index(self) -> int:
7394
"""
7495
# Note the reversed list because in napari the z-axis is the first
7596
# numpy axis
76-
return self._dim_names[::-1].index(self.current_dim_name)
97+
return self._dim_names.index(self.current_dim_name)
7798

7899
@property
79100
def _dim_names(self) -> List[str]:
@@ -82,45 +103,31 @@ def _dim_names(self) -> List[str]:
82103
dimensionality of the currently selected data.
83104
"""
84105
if self._layer.data.ndim == 2:
85-
return ["x", "y"]
106+
return ["y", "x"]
86107
elif self._layer.data.ndim == 3:
87-
return ["x", "y", "z"]
108+
return ["z", "y", "x"]
88109
else:
89110
raise RuntimeError("Don't know how to handle ndim != 2 or 3")
90111

91-
@property
92-
def _selector_values(self) -> Dict[str, int]:
93-
"""
94-
Values of the slice selectors.
95-
96-
Mapping from dimension name to value.
97-
"""
98-
return {d: self.slice_selectors[d].value() for d in _dims_sel}
99-
100112
def _get_xy(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any]]:
101113
"""
102114
Get data for plotting.
103115
"""
104-
dim_index = self.current_dim_index
105-
if self._layer.data.ndim == 2:
106-
dim_index -= 1
107-
x = np.arange(self._layer.data.shape[dim_index])
108-
109-
vals = self._selector_values
110-
vals.update({"z": self.current_z})
116+
val = self.slice_selector.value()
111117

112118
slices = []
113119
for dim_name in self._dim_names:
114120
if dim_name == self.current_dim_name:
115121
# Select all data along this axis
116122
slices.append(slice(None))
123+
elif dim_name == "z":
124+
# Only select the currently viewed z-index
125+
slices.append(slice(self.current_z, self.current_z + 1))
117126
else:
118127
# Select specific index
119-
val = vals[dim_name]
120128
slices.append(slice(val, val + 1))
121129

122-
# Reverse since z is the first axis in napari
123-
slices = slices[::-1]
130+
x = np.arange(self._slice_width)
124131
y = self._layer.data[tuple(slices)].ravel()
125132

126133
return x, y

Diff for: src/napari_matplotlib/tests/test_slice.py

+24
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,27 @@ def test_slice_2D(make_napari_viewer, astronaut_data):
3737
# Need to return a copy, as original figure is too eagerley garbage
3838
# collected by the widget
3939
return deepcopy(fig)
40+
41+
42+
def test_slice_axes(make_napari_viewer, astronaut_data):
43+
viewer = make_napari_viewer()
44+
viewer.theme = "light"
45+
46+
# Take first RGB channel
47+
data = astronaut_data[0][:256, :, 0]
48+
# Shape:
49+
# x: 0 > 512
50+
# y: 0 > 256
51+
assert data.ndim == 2, data.shape
52+
# Make sure data isn't square for later tests
53+
assert data.shape[0] != data.shape[1]
54+
viewer.add_image(data)
55+
56+
widget = SliceWidget(viewer)
57+
assert widget._dim_names == ["y", "x"]
58+
assert widget.current_dim_name == "x"
59+
assert widget.slice_selector.value() == 0
60+
assert widget.slice_selector.minimum() == 0
61+
assert widget.slice_selector.maximum() == data.shape[0]
62+
# x/y are flipped in napari
63+
assert widget._slice_width == data.shape[1]

0 commit comments

Comments
 (0)