@@ -95,38 +95,46 @@ def update_lim(self, axes):
95
95
delta2 = _api .deprecated ("3.6" )(
96
96
property (lambda self : 0.00001 , lambda self , value : None ))
97
97
98
+ def _to_xy (self , values , const ):
99
+ """
100
+ Create a (*values.shape, 2)-shape array representing (x, y) pairs.
101
+
102
+ *values* go into the coordinate determined by ``self.nth_coord``.
103
+ The other coordinate is filled with the constant *const*.
104
+
105
+ Example::
106
+
107
+ >>> self.nth_coord = 0
108
+ >>> self._to_xy([1, 2, 3], const=0)
109
+ array([[1, 0],
110
+ [2, 0],
111
+ [3, 0]])
112
+ """
113
+ if self .nth_coord == 0 :
114
+ return np .stack (np .broadcast_arrays (values , const ), axis = - 1 )
115
+ elif self .nth_coord == 1 :
116
+ return np .stack (np .broadcast_arrays (const , values ), axis = - 1 )
117
+ else :
118
+ raise ValueError ("Unexpected nth_coord" )
119
+
98
120
class Fixed (_Base ):
99
121
"""Helper class for a fixed (in the axes coordinate) axis."""
100
122
101
- _default_passthru_pt = dict (left = (0 , 0 ),
102
- right = (1 , 0 ),
103
- bottom = (0 , 0 ),
104
- top = (0 , 1 ))
123
+ passthru_pt = _api .deprecated ("3.7" )(property (
124
+ lambda self : {"left" : (0 , 0 ), "right" : (1 , 0 ),
125
+ "bottom" : (0 , 0 ), "top" : (0 , 1 )}[self ._loc ]))
105
126
106
127
def __init__ (self , loc , nth_coord = None ):
107
128
"""``nth_coord = 0``: x-axis; ``nth_coord = 1``: y-axis."""
108
129
_api .check_in_list (["left" , "right" , "bottom" , "top" ], loc = loc )
109
130
self ._loc = loc
110
-
111
- if nth_coord is None :
112
- if loc in ["left" , "right" ]:
113
- nth_coord = 1
114
- else : # "bottom", "top"
115
- nth_coord = 0
116
-
117
- self .nth_coord = nth_coord
118
-
131
+ self ._pos = {"bottom" : 0 , "top" : 1 , "left" : 0 , "right" : 1 }[loc ]
132
+ self .nth_coord = (
133
+ nth_coord if nth_coord is not None else
134
+ {"bottom" : 0 , "top" : 0 , "left" : 1 , "right" : 1 }[loc ])
119
135
super ().__init__ ()
120
-
121
- self .passthru_pt = self ._default_passthru_pt [loc ]
122
-
123
- _verts = np .array ([[0. , 0. ],
124
- [1. , 1. ]])
125
- fixed_coord = 1 - nth_coord
126
- _verts [:, fixed_coord ] = self .passthru_pt [fixed_coord ]
127
-
128
136
# axis line in transAxes
129
- self ._path = Path (_verts )
137
+ self ._path = Path (self . _to_xy (( 0 , 1 ), const = self . _pos ) )
130
138
131
139
def get_nth_coord (self ):
132
140
return self .nth_coord
@@ -208,14 +216,13 @@ def get_tick_iterators(self, axes):
208
216
tick_to_axes = self .get_tick_transform (axes ) - axes .transAxes
209
217
210
218
def _f (locs , labels ):
211
- for x , l in zip (locs , labels ):
212
- c = list (self .passthru_pt ) # copy
213
- c [self .nth_coord ] = x
219
+ for loc , label in zip (locs , labels ):
220
+ c = self ._to_xy (loc , const = self ._pos )
214
221
# check if the tick point is inside axes
215
222
c2 = tick_to_axes .transform (c )
216
223
if mpl .transforms ._interval_contains_close (
217
224
(0 , 1 ), c2 [self .nth_coord ]):
218
- yield c , angle_normal , angle_tangent , l
225
+ yield c , angle_normal , angle_tangent , label
219
226
220
227
return _f (major_locs , major_labels ), _f (minor_locs , minor_labels )
221
228
@@ -227,15 +234,10 @@ def __init__(self, axes, nth_coord,
227
234
self .axis = [axes .xaxis , axes .yaxis ][self .nth_coord ]
228
235
229
236
def get_line (self , axes ):
230
- _verts = np .array ([[0. , 0. ],
231
- [1. , 1. ]])
232
-
233
237
fixed_coord = 1 - self .nth_coord
234
238
data_to_axes = axes .transData - axes .transAxes
235
239
p = data_to_axes .transform ([self ._value , self ._value ])
236
- _verts [:, fixed_coord ] = p [fixed_coord ]
237
-
238
- return Path (_verts )
240
+ return Path (self ._to_xy ((0 , 1 ), const = p [fixed_coord ]))
239
241
240
242
def get_line_transform (self , axes ):
241
243
return axes .transAxes
@@ -250,13 +252,12 @@ def get_axislabel_pos_angle(self, axes):
250
252
get_label_transform() returns a transform of (transAxes+offset)
251
253
"""
252
254
angle = [0 , 90 ][self .nth_coord ]
253
- _verts = [0.5 , 0.5 ]
254
255
fixed_coord = 1 - self .nth_coord
255
256
data_to_axes = axes .transData - axes .transAxes
256
257
p = data_to_axes .transform ([self ._value , self ._value ])
257
- _verts [ fixed_coord ] = p [fixed_coord ]
258
- if 0 <= _verts [fixed_coord ] <= 1 :
259
- return _verts , angle
258
+ verts = self . _to_xy ( 0.5 , const = p [fixed_coord ])
259
+ if 0 <= verts [fixed_coord ] <= 1 :
260
+ return verts , angle
260
261
else :
261
262
return None , None
262
263
@@ -281,12 +282,11 @@ def get_tick_iterators(self, axes):
281
282
data_to_axes = axes .transData - axes .transAxes
282
283
283
284
def _f (locs , labels ):
284
- for x , l in zip (locs , labels ):
285
- c = [self ._value , self ._value ]
286
- c [self .nth_coord ] = x
285
+ for loc , label in zip (locs , labels ):
286
+ c = self ._to_xy (loc , const = self ._value )
287
287
c1 , c2 = data_to_axes .transform (c )
288
288
if 0 <= c1 <= 1 and 0 <= c2 <= 1 :
289
- yield c , angle_normal , angle_tangent , l
289
+ yield c , angle_normal , angle_tangent , label
290
290
291
291
return _f (major_locs , major_labels ), _f (minor_locs , minor_labels )
292
292
0 commit comments