Skip to content

Commit 0a31b25

Browse files
authored
Merge pull request #41 from ksunden/picking
Picking and Contains for new artists
2 parents a33566d + f7c8a07 commit 0a31b25

File tree

9 files changed

+243
-38
lines changed

9 files changed

+243
-38
lines changed

data_prototype/artist.py

+94-12
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,12 @@
33

44
import numpy as np
55

6+
from matplotlib.backend_bases import PickEvent
7+
import matplotlib.artist as martist
8+
69
from .containers import DataContainer, ArrayContainer, DataUnion
710
from .description import Desc, desc_like
8-
from .conversion_edge import Edge, Graph, TransformEdge
11+
from .conversion_edge import Edge, FuncEdge, Graph, TransformEdge
912

1013

1114
class Artist:
@@ -18,6 +21,9 @@ def __init__(
1821
kwargs_cont = ArrayContainer(**kwargs)
1922
self._container = DataUnion(container, kwargs_cont)
2023

24+
self._children: list[tuple[float, Artist]] = []
25+
self._picker = None
26+
2127
edges = edges or []
2228
self._visible = True
2329
self._graph = Graph(edges)
@@ -41,6 +47,77 @@ def get_visible(self):
4147
def set_visible(self, visible):
4248
self._visible = visible
4349

50+
def pickable(self) -> bool:
51+
return self._picker is not None
52+
53+
def get_picker(self):
54+
return self._picker
55+
56+
def set_picker(self, picker):
57+
self._picker = picker
58+
59+
def contains(self, mouseevent, graph=None):
60+
"""
61+
Test whether the artist contains the mouse event.
62+
63+
Parameters
64+
----------
65+
mouseevent : `~matplotlib.backend_bases.MouseEvent`
66+
67+
Returns
68+
-------
69+
contains : bool
70+
Whether any values are within the radius.
71+
details : dict
72+
An artist-specific dictionary of details of the event context,
73+
such as which points are contained in the pick radius. See the
74+
individual Artist subclasses for details.
75+
"""
76+
return False, {}
77+
78+
def get_children(self):
79+
return [a[1] for a in self._children]
80+
81+
def pick(self, mouseevent, graph: Graph | None = None):
82+
"""
83+
Process a pick event.
84+
85+
Each child artist will fire a pick event if *mouseevent* is over
86+
the artist and the artist has picker set.
87+
88+
See Also
89+
--------
90+
set_picker, get_picker, pickable
91+
"""
92+
if graph is None:
93+
graph = self._graph
94+
else:
95+
graph = graph + self._graph
96+
# Pick self
97+
if self.pickable():
98+
picker = self.get_picker()
99+
if callable(picker):
100+
inside, prop = picker(self, mouseevent)
101+
else:
102+
inside, prop = self.contains(mouseevent, graph)
103+
if inside:
104+
PickEvent(
105+
"pick_event", mouseevent.canvas, mouseevent, self, **prop
106+
)._process()
107+
108+
# Pick children
109+
for a in self.get_children():
110+
# make sure the event happened in the same Axes
111+
ax = getattr(a, "axes", None)
112+
if mouseevent.inaxes is None or ax is None or mouseevent.inaxes == ax:
113+
# we need to check if mouseevent.inaxes is None
114+
# because some objects associated with an Axes (e.g., a
115+
# tick label) can be outside the bounding box of the
116+
# Axes and inaxes will be None
117+
# also check that ax is None so that it traverse objects
118+
# which do not have an axes property but children might
119+
a.pick(mouseevent, graph)
120+
44121

45122
class CompatibilityArtist:
46123
"""A compatibility shim to ducktype as a classic Matplotlib Artist.
@@ -59,7 +136,7 @@ class CompatibilityArtist:
59136
useful for avoiding accidental dependency.
60137
"""
61138

62-
def __init__(self, artist: Artist):
139+
def __init__(self, artist: martist.Artist):
63140
self._artist = artist
64141

65142
self._axes = None
@@ -134,7 +211,7 @@ def draw(self, renderer, graph=None):
134211
self._artist.draw(renderer, graph + self._graph)
135212

136213

137-
class CompatibilityAxes:
214+
class CompatibilityAxes(Artist):
138215
"""A compatibility shim to add to traditional matplotlib axes.
139216
140217
At this time features are implemented on an "as needed" basis, and many
@@ -152,12 +229,11 @@ class CompatibilityAxes:
152229
"""
153230

154231
def __init__(self, axes):
232+
super().__init__(ArrayContainer())
155233
self._axes = axes
156234
self.figure = None
157235
self._clippath = None
158-
self._visible = True
159236
self.zorder = 2
160-
self._children: list[tuple[float, Artist]] = []
161237

162238
@property
163239
def axes(self):
@@ -187,6 +263,18 @@ def axes(self, ax):
187263
desc_like(xy, coordinates="display"),
188264
transform=self._axes.transAxes,
189265
),
266+
FuncEdge.from_func(
267+
"xunits",
268+
lambda: self._axes.xaxis.units,
269+
{},
270+
{"xunits": Desc((), "units")},
271+
),
272+
FuncEdge.from_func(
273+
"yunits",
274+
lambda: self._axes.yaxis.units,
275+
{},
276+
{"yunits": Desc((), "units")},
277+
),
190278
],
191279
aliases=(("parent", "axes"),),
192280
)
@@ -210,7 +298,7 @@ def get_animated(self):
210298
return False
211299

212300
def draw(self, renderer, graph=None):
213-
if not self.visible:
301+
if not self.get_visible():
214302
return
215303
if graph is None:
216304
graph = Graph([])
@@ -228,9 +316,3 @@ def set_xlim(self, min_=None, max_=None):
228316

229317
def set_ylim(self, min_=None, max_=None):
230318
self.axes.set_ylim(min_, max_)
231-
232-
def get_visible(self):
233-
return self._visible
234-
235-
def set_visible(self, visible):
236-
self._visible = visible

data_prototype/containers.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,8 @@ def __init__(self, coordinates: dict[str, str] | None = None, /, **data):
8989
self._desc = {
9090
k: (
9191
Desc(v.shape, coordinates.get(k, "auto"))
92-
if isinstance(v, np.ndarray)
93-
else Desc(())
92+
if hasattr(v, "shape")
93+
else Desc((), coordinates.get(k, "auto"))
9494
)
9595
for k, v in data.items()
9696
}

data_prototype/conversion_edge.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,12 @@ def __ge__(self, other):
295295
def __gt__(self, other):
296296
return self.weight > other.weight
297297

298+
@property
299+
def edges(self):
300+
if self.prev_node is None:
301+
return [self.edge]
302+
return self.prev_node.edges + [self.edge]
303+
298304
q: PriorityQueue[Node] = PriorityQueue()
299305
q.put(Node(0, input))
300306

@@ -308,6 +314,8 @@ def __gt__(self, other):
308314
best = n
309315
continue
310316
for e in sub_edges:
317+
if e in n.edges:
318+
continue
311319
if Desc.compatible(n.desc, e.input, aliases=self._aliases):
312320
d = n.desc | e.output
313321
w = n.weight + e.weight
@@ -397,7 +405,7 @@ def node_format(x):
397405
)
398406

399407
try:
400-
pos = nx.planar_layout(G)
408+
pos = nx.shell_layout(G)
401409
except Exception:
402410
pos = nx.circular_layout(G)
403411
plt.figure()

data_prototype/image.py

+30-4
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@ def _interpolate_nearest(image, x, y):
1414
l, r = x
1515
width = int(((round(r) + 0.5) - (round(l) - 0.5)) * magnification)
1616

17-
xpix = np.digitize(np.arange(width), np.linspace(0, r - l, image.shape[1] + 1))
17+
xpix = np.digitize(np.arange(width), np.linspace(0, r - l, image.shape[1]))
1818

1919
b, t = y
2020
height = int(((round(t) + 0.5) - (round(b) - 0.5)) * magnification)
21-
ypix = np.digitize(np.arange(height), np.linspace(0, t - b, image.shape[0] + 1))
21+
ypix = np.digitize(np.arange(height), np.linspace(0, t - b, image.shape[0]))
2222

2323
out = np.empty((height, width, 4))
2424

@@ -53,7 +53,7 @@ def __init__(self, container, edges=None, norm=None, cmap=None, **kwargs):
5353
{"image": Desc(("O", "P", 4), coordinates="rgba_resampled")},
5454
)
5555

56-
self._edges += [
56+
edges = [
5757
CoordinateEdge.from_coords("xycoords", {"x": "auto", "y": "auto"}, "data"),
5858
CoordinateEdge.from_coords(
5959
"image_coords", {"image": Desc(("M", "N"), "auto")}, "data"
@@ -79,7 +79,7 @@ def __init__(self, container, edges=None, norm=None, cmap=None, **kwargs):
7979
self._interpolation_edge,
8080
]
8181

82-
self._graph = Graph(self._edges, (("data", "data_resampled"),))
82+
self._graph = self._graph + Graph(edges, (("data", "data_resampled"),))
8383

8484
def draw(self, renderer, graph: Graph) -> None:
8585
if not self.get_visible():
@@ -111,3 +111,29 @@ def draw(self, renderer, graph: Graph) -> None:
111111
mtransforms.Bbox.from_extents(clipx[0], clipy[0], clipx[1], clipy[1])
112112
)
113113
renderer.draw_image(gc, x[0], y[0], image) # TODO vector backend transforms
114+
115+
def contains(self, mouseevent, graph=None):
116+
if graph is None:
117+
return False, {}
118+
g = graph + self._graph
119+
conv = g.evaluator(
120+
self._container.describe(),
121+
{
122+
"x": Desc(("X",), "display"),
123+
"y": Desc(("Y",), "display"),
124+
},
125+
).inverse
126+
query, _ = self._container.query(g)
127+
xmin, xmax = query["x"]
128+
ymin, ymax = query["y"]
129+
x, y = conv.evaluate({"x": mouseevent.x, "y": mouseevent.y}).values()
130+
131+
# This checks xmin <= x <= xmax *or* xmax <= x <= xmin.
132+
inside = (
133+
x is not None
134+
and (x - xmin) * (x - xmax) <= 0
135+
and y is not None
136+
and (y - ymin) * (y - ymax) <= 0
137+
)
138+
139+
return inside, {}

data_prototype/line.py

+64
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from .description import Desc
1010
from .conversion_edge import Graph, CoordinateEdge, DefaultEdge
1111

12+
segment_hits = mlines.segment_hits
13+
1214

1315
class Line(Artist):
1416
def __init__(self, container, edges=None, **kwargs):
@@ -57,6 +59,68 @@ def __init__(self, container, edges=None, **kwargs):
5759
# - non-str markers
5860
# Each individually pretty easy, but relatively rare features, focusing on common cases
5961

62+
def contains(self, mouseevent, graph=None):
63+
"""
64+
Test whether *mouseevent* occurred on the line.
65+
66+
An event is deemed to have occurred "on" the line if it is less
67+
than ``self.pickradius`` (default: 5 points) away from it. Use
68+
`~.Line2D.get_pickradius` or `~.Line2D.set_pickradius` to get or set
69+
the pick radius.
70+
71+
Parameters
72+
----------
73+
mouseevent : `~matplotlib.backend_bases.MouseEvent`
74+
75+
Returns
76+
-------
77+
contains : bool
78+
Whether any values are within the radius.
79+
details : dict
80+
A dictionary ``{'ind': pointlist}``, where *pointlist* is a
81+
list of points of the line that are within the pickradius around
82+
the event position.
83+
84+
TODO: sort returned indices by distance
85+
"""
86+
if graph is None:
87+
return False, {}
88+
89+
g = graph + self._graph
90+
desc = Desc(("N",), "display")
91+
scalar = Desc((), "display") # ... this needs thinking...
92+
# Convert points to pixels
93+
require = {
94+
"x": desc,
95+
"y": desc,
96+
"linestyle": scalar,
97+
}
98+
conv = g.evaluator(self._container.describe(), require)
99+
query, _ = self._container.query(g)
100+
xt, yt, linestyle = conv.evaluate(query).values()
101+
102+
# Convert pick radius from points to pixels
103+
pixels = 5 # self._pickradius # TODO
104+
105+
# The math involved in checking for containment (here and inside of
106+
# segment_hits) assumes that it is OK to overflow, so temporarily set
107+
# the error flags accordingly.
108+
with np.errstate(all="ignore"):
109+
# Check for collision
110+
if linestyle in ["None", None]:
111+
# If no line, return the nearby point(s)
112+
(ind,) = np.nonzero(
113+
(xt - mouseevent.x) ** 2 + (yt - mouseevent.y) ** 2 <= pixels**2
114+
)
115+
else:
116+
# If line, return the nearby segment(s)
117+
ind = segment_hits(mouseevent.x, mouseevent.y, xt, yt, pixels)
118+
# if self._drawstyle.startswith("steps"):
119+
# ind //= 2
120+
121+
# Return the point(s) within radius
122+
return len(ind) > 0, dict(ind=ind)
123+
60124
def draw(self, renderer, graph: Graph) -> None:
61125
if not self.get_visible():
62126
return

data_prototype/tests/test_containers.py

-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ def _verify_describe(container):
2222
assert set(data) == set(desc)
2323
for k, v in data.items():
2424
assert v.shape == desc[k].shape
25-
assert v.dtype == desc[k].dtype
2625

2726

2827
def test_array_describe(ac):

examples/first.py

+1
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,5 @@
3838
ax.add_artist(lw2, 2)
3939
ax.set_xlim(0, np.pi * 4)
4040
ax.set_ylim(-1.1, 1.1)
41+
4142
plt.show()

0 commit comments

Comments
 (0)