Skip to content

Commit e853d34

Browse files
committed
ENH: Implemented y-axis-based labeling
Should close cphyc#121
1 parent 11a8466 commit e853d34

File tree

4 files changed

+306
-195
lines changed

4 files changed

+306
-195
lines changed

example/matplotlib_label_lines.ipynb

+172-104
Large diffs are not rendered by default.

labellines/core.py

+75-51
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@
1313
# Label line with line2D label data
1414
def labelLine(
1515
line,
16-
x,
16+
val,
17+
axis="x",
1718
label=None,
1819
align=True,
1920
drop_label=False,
20-
yoffset=0,
21-
yoffset_logspace=False,
21+
offset=0,
22+
offset_logspace=False,
2223
outline_color="auto",
2324
outline_width=8,
2425
**kwargs,
@@ -30,17 +31,19 @@ def labelLine(
3031
----------
3132
line : matplotlib.lines.Line
3233
The line holding the label
33-
x : number
34+
val : number
3435
The location in data unit of the label
36+
axis : "x" | "y"
37+
Reference axis for `val`.
3538
label : string, optional
3639
The label to set. This is inferred from the line by default
3740
drop_label : bool, optional
3841
If True, the label is consumed by the function so that subsequent
3942
calls to e.g. legend do not use it anymore.
40-
yoffset : double, optional
43+
offset : double, optional
4144
Space to add to label's y position
42-
yoffset_logspace : bool, optional
43-
If True, then yoffset will be added to the label's y position in
45+
offset_logspace : bool, optional
46+
If True, then offset will be added to the label's y position in
4447
log10 space
4548
outline_color : None | "auto" | color
4649
Colour of the outline. If set to "auto", use the background color.
@@ -54,11 +57,12 @@ def labelLine(
5457
try:
5558
txt = LineLabel(
5659
line,
57-
x,
60+
val,
61+
axis,
5862
label=label,
5963
align=align,
60-
yoffset=yoffset,
61-
yoffset_logspace=yoffset_logspace,
64+
offset=offset,
65+
offset_logspace=offset_logspace,
6266
outline_color=outline_color,
6367
outline_width=outline_width,
6468
**kwargs,
@@ -86,10 +90,11 @@ def labelLine(
8690
def labelLines(
8791
lines=None,
8892
align=True,
89-
xvals=None,
93+
vals=None,
94+
axis=None,
9095
drop_label=False,
9196
shrink_factor=0.05,
92-
yoffsets=0,
97+
offsets=0,
9398
outline_color="auto",
9499
outline_width=5,
95100
**kwargs,
@@ -103,17 +108,19 @@ def labelLines(
103108
align : boolean, optional
104109
If True, the label will be aligned with the slope of the line
105110
at the location of the label. If False, they will be horizontal.
106-
xvals : (xfirst, xlast) or array of float, optional
111+
vals : (first, last) or array of float, optional
107112
The location of the labels. If a tuple, the labels will be
108-
evenly spaced between xfirst and xlast (in the axis units).
113+
evenly spaced between first and last (in the axis units).
114+
axis : None | "x" | "y", optional
115+
Reference axis for the `vals`.
109116
drop_label : bool, optional
110117
If True, the label is consumed by the function so that subsequent
111118
calls to e.g. legend do not use it anymore.
112119
shrink_factor : double, optional
113120
Relative distance from the edges to place closest labels. Defaults to 0.05.
114-
yoffsets : number or list, optional.
121+
offsets : number or list, optional.
115122
Distance relative to the line when positioning the labels. If given a number,
116-
the same value is used for all lines.
123+
the same value is used for all lines. It refers to the *other* axis (i.e. to y if axis=="x")
117124
outline_color : None | "auto" | color
118125
Colour of the outline. If set to "auto", use the background color.
119126
If set to None, do not draw an outline.
@@ -122,11 +129,18 @@ def labelLines(
122129
kwargs : dict, optional
123130
Optional arguments passed to ax.text
124131
"""
132+
125133
if lines:
126134
ax = lines[0].axes
127135
else:
128136
ax = plt.gca()
129137

138+
if axis == "y":
139+
yaxis = True
140+
else:
141+
axis = "x"
142+
yaxis = False
143+
130144
handles, labels_of_handles = ax.get_legend_handles_labels()
131145

132146
all_lines, all_labels = [], []
@@ -156,32 +170,38 @@ def labelLines(
156170

157171
# In case no x location was provided, we need to use some heuristics
158172
# to generate them.
159-
if xvals is None:
160-
xvals = ax.get_xlim()
161-
xvals_rng = xvals[1] - xvals[0]
162-
shrinkage = xvals_rng * shrink_factor
163-
xvals = (xvals[0] + shrinkage, xvals[1] - shrinkage)
164-
165-
if isinstance(xvals, tuple) and len(xvals) == 2:
166-
xmin, xmax = xvals
173+
if vals is None:
174+
if yaxis:
175+
vals = ax.get_ylim()
176+
else:
177+
vals = ax.get_xlim()
178+
vals_rng = vals[1] - vals[0]
179+
shrinkage = vals_rng * shrink_factor
180+
vals = (vals[0] + shrinkage, vals[1] - shrinkage)
181+
182+
if isinstance(vals, tuple) and len(vals) == 2:
183+
vmin, vmax = vals
167184
xscale = ax.get_xscale()
168185
if xscale == "log":
169-
xvals = np.logspace(np.log10(xmin), np.log10(xmax), len(all_lines) + 2)[
170-
1:-1
171-
]
186+
vals = np.logspace(np.log10(vmin), np.log10(vmax), len(all_lines) + 2)[
187+
1:-1
188+
]
172189
else:
173-
xvals = np.linspace(xmin, xmax, len(all_lines) + 2)[1:-1]
190+
vals = np.linspace(vmin, vmax, len(all_lines) + 2)[1:-1]
174191

175-
# Build matrix line -> xvalue
192+
# Build matrix line -> value
176193
ok_matrix = np.zeros((len(all_lines), len(all_lines)), dtype=bool)
177194

178195
for i, line in enumerate(all_lines):
179-
xdata, _ = normalize_xydata(line)
180-
minx, maxx = min(xdata), max(xdata)
181-
for j, xv in enumerate(xvals):
182-
ok_matrix[i, j] = minx < xv < maxx
196+
if yaxis:
197+
_, data = normalize_xydata(line)
198+
else:
199+
data, _ = normalize_xydata(line)
200+
minv, maxv = min(data), max(data)
201+
for j, val in enumerate(vals):
202+
ok_matrix[i, j] = minv < val < maxv
183203

184-
# If some xvals do not fall in their corresponding line,
204+
# If some vals do not fall in their corresponding line,
185205
# find a better matching using maximum bipartite matching.
186206
if not np.all(np.diag(ok_matrix)):
187207
order = maximum_bipartite_matching(ok_matrix)
@@ -190,51 +210,55 @@ def labelLines(
190210
order[order < 0] = np.setdiff1d(np.arange(len(order)), order[order >= 0])
191211

192212
# Now reorder the xvalues
193-
old_xvals = xvals.copy()
194-
xvals[order] = old_xvals
213+
old_xvals = vals.copy()
214+
vals[order] = old_xvals
195215
else:
196-
xvals = list(always_iterable(xvals)) # force the creation of a copy
216+
vals = list(always_iterable(vals)) # force the creation of a copy
197217

198218
lab_lines, labels = [], []
199219
# Take only the lines which have labels other than the default ones
200-
for i, (line, xv) in enumerate(zip(all_lines, xvals)):
220+
for i, (line, val) in enumerate(zip(all_lines, vals)):
201221
label = all_labels[all_lines.index(line)]
202222
lab_lines.append(line)
203223
labels.append(label)
204224

205-
# Move xlabel if it is outside valid range
206-
xdata, _ = normalize_xydata(line)
207-
if not (min(xdata) <= xv <= max(xdata)):
225+
# Move xlabel/ylabel if it is outside valid range
226+
if yaxis:
227+
_, data = normalize_xydata(line)
228+
else:
229+
data, _ = normalize_xydata(line)
230+
if not (min(data) <= val <= max(data)):
208231
warnings.warn(
209232
(
210-
"The value at position {} in `xvals` is outside the range of its "
211-
"associated line (xmin={}, xmax={}, xval={}). Clipping it "
233+
"The value at position {} in `vals` is outside the range of its "
234+
"associated line (vmin={}, vmax={}, xval={}). Clipping it "
212235
"into the allowed range."
213-
).format(i, min(xdata), max(xdata), xv),
236+
).format(i, min(data), max(data), val),
214237
UserWarning,
215238
stacklevel=1,
216239
)
217-
new_xv = min(xdata) + (max(xdata) - min(xdata)) * 0.9
218-
xvals[i] = new_xv
240+
new_val = min(data) + (max(data) - min(data)) * 0.9
241+
vals[i] = new_val
219242

220243
# Convert float values back to datetime in case of datetime axis
221244
if isinstance(ax.xaxis.converter, DateConverter):
222-
xvals = [num2date(x).replace(tzinfo=ax.xaxis.get_units()) for x in xvals]
245+
vals = [num2date(x).replace(tzinfo=ax.xaxis.get_units()) for x in vals]
223246

224247
txts = []
225248
try:
226-
yoffsets = [float(yoffsets)] * len(all_lines)
249+
offsets = [float(offsets)] * len(all_lines)
227250
except TypeError:
228251
pass
229-
for line, x, yoffset, label in zip(lab_lines, xvals, yoffsets, labels):
252+
for line, val, offset, label in zip(lab_lines, vals, offsets, labels):
230253
txts.append(
231254
labelLine(
232255
line,
233-
x,
256+
val,
257+
axis,
234258
label=label,
235259
align=align,
236260
drop_label=drop_label,
237-
yoffset=yoffset,
261+
offset=offset,
238262
outline_color=outline_color,
239263
outline_width=outline_width,
240264
**kwargs,

0 commit comments

Comments
 (0)