Skip to content

Commit 8dd1058

Browse files
authored
Merge pull request matplotlib#22314 from anntzer/aatoxy
Add a helper to generate xy coordinates for AxisArtistHelper.
2 parents 7a0ea31 + 758649e commit 8dd1058

File tree

2 files changed

+43
-40
lines changed

2 files changed

+43
-40
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
``passthru_pt``
2+
~~~~~~~~~~~~~~~
3+
This attribute of ``AxisArtistHelper``\s is deprecated.

lib/mpl_toolkits/axisartist/axislines.py

+40-40
Original file line numberDiff line numberDiff line change
@@ -95,38 +95,46 @@ def update_lim(self, axes):
9595
delta2 = _api.deprecated("3.6")(
9696
property(lambda self: 0.00001, lambda self, value: None))
9797

98+
def _to_xy(self, values, const):
99+
"""
100+
Create a (*values.shape, 2)-shape array representing (x, y) pairs.
101+
102+
*values* go into the coordinate determined by ``self.nth_coord``.
103+
The other coordinate is filled with the constant *const*.
104+
105+
Example::
106+
107+
>>> self.nth_coord = 0
108+
>>> self._to_xy([1, 2, 3], const=0)
109+
array([[1, 0],
110+
[2, 0],
111+
[3, 0]])
112+
"""
113+
if self.nth_coord == 0:
114+
return np.stack(np.broadcast_arrays(values, const), axis=-1)
115+
elif self.nth_coord == 1:
116+
return np.stack(np.broadcast_arrays(const, values), axis=-1)
117+
else:
118+
raise ValueError("Unexpected nth_coord")
119+
98120
class Fixed(_Base):
99121
"""Helper class for a fixed (in the axes coordinate) axis."""
100122

101-
_default_passthru_pt = dict(left=(0, 0),
102-
right=(1, 0),
103-
bottom=(0, 0),
104-
top=(0, 1))
123+
passthru_pt = _api.deprecated("3.7")(property(
124+
lambda self: {"left": (0, 0), "right": (1, 0),
125+
"bottom": (0, 0), "top": (0, 1)}[self._loc]))
105126

106127
def __init__(self, loc, nth_coord=None):
107128
"""``nth_coord = 0``: x-axis; ``nth_coord = 1``: y-axis."""
108129
_api.check_in_list(["left", "right", "bottom", "top"], loc=loc)
109130
self._loc = loc
110-
111-
if nth_coord is None:
112-
if loc in ["left", "right"]:
113-
nth_coord = 1
114-
else: # "bottom", "top"
115-
nth_coord = 0
116-
117-
self.nth_coord = nth_coord
118-
131+
self._pos = {"bottom": 0, "top": 1, "left": 0, "right": 1}[loc]
132+
self.nth_coord = (
133+
nth_coord if nth_coord is not None else
134+
{"bottom": 0, "top": 0, "left": 1, "right": 1}[loc])
119135
super().__init__()
120-
121-
self.passthru_pt = self._default_passthru_pt[loc]
122-
123-
_verts = np.array([[0., 0.],
124-
[1., 1.]])
125-
fixed_coord = 1 - nth_coord
126-
_verts[:, fixed_coord] = self.passthru_pt[fixed_coord]
127-
128136
# axis line in transAxes
129-
self._path = Path(_verts)
137+
self._path = Path(self._to_xy((0, 1), const=self._pos))
130138

131139
def get_nth_coord(self):
132140
return self.nth_coord
@@ -208,14 +216,13 @@ def get_tick_iterators(self, axes):
208216
tick_to_axes = self.get_tick_transform(axes) - axes.transAxes
209217

210218
def _f(locs, labels):
211-
for x, l in zip(locs, labels):
212-
c = list(self.passthru_pt) # copy
213-
c[self.nth_coord] = x
219+
for loc, label in zip(locs, labels):
220+
c = self._to_xy(loc, const=self._pos)
214221
# check if the tick point is inside axes
215222
c2 = tick_to_axes.transform(c)
216223
if mpl.transforms._interval_contains_close(
217224
(0, 1), c2[self.nth_coord]):
218-
yield c, angle_normal, angle_tangent, l
225+
yield c, angle_normal, angle_tangent, label
219226

220227
return _f(major_locs, major_labels), _f(minor_locs, minor_labels)
221228

@@ -227,15 +234,10 @@ def __init__(self, axes, nth_coord,
227234
self.axis = [axes.xaxis, axes.yaxis][self.nth_coord]
228235

229236
def get_line(self, axes):
230-
_verts = np.array([[0., 0.],
231-
[1., 1.]])
232-
233237
fixed_coord = 1 - self.nth_coord
234238
data_to_axes = axes.transData - axes.transAxes
235239
p = data_to_axes.transform([self._value, self._value])
236-
_verts[:, fixed_coord] = p[fixed_coord]
237-
238-
return Path(_verts)
240+
return Path(self._to_xy((0, 1), const=p[fixed_coord]))
239241

240242
def get_line_transform(self, axes):
241243
return axes.transAxes
@@ -250,13 +252,12 @@ def get_axislabel_pos_angle(self, axes):
250252
get_label_transform() returns a transform of (transAxes+offset)
251253
"""
252254
angle = [0, 90][self.nth_coord]
253-
_verts = [0.5, 0.5]
254255
fixed_coord = 1 - self.nth_coord
255256
data_to_axes = axes.transData - axes.transAxes
256257
p = data_to_axes.transform([self._value, self._value])
257-
_verts[fixed_coord] = p[fixed_coord]
258-
if 0 <= _verts[fixed_coord] <= 1:
259-
return _verts, angle
258+
verts = self._to_xy(0.5, const=p[fixed_coord])
259+
if 0 <= verts[fixed_coord] <= 1:
260+
return verts, angle
260261
else:
261262
return None, None
262263

@@ -281,12 +282,11 @@ def get_tick_iterators(self, axes):
281282
data_to_axes = axes.transData - axes.transAxes
282283

283284
def _f(locs, labels):
284-
for x, l in zip(locs, labels):
285-
c = [self._value, self._value]
286-
c[self.nth_coord] = x
285+
for loc, label in zip(locs, labels):
286+
c = self._to_xy(loc, const=self._value)
287287
c1, c2 = data_to_axes.transform(c)
288288
if 0 <= c1 <= 1 and 0 <= c2 <= 1:
289-
yield c, angle_normal, angle_tangent, l
289+
yield c, angle_normal, angle_tangent, label
290290

291291
return _f(major_locs, major_labels), _f(minor_locs, minor_labels)
292292

0 commit comments

Comments
 (0)