-
Notifications
You must be signed in to change notification settings - Fork 28
/
Copy pathcore.py
272 lines (241 loc) · 9.04 KB
/
core.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
import warnings
from collections.abc import Iterable
from typing import List, Literal, Optional, Tuple, Union
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.container import ErrorbarContainer
from matplotlib.dates import DateConverter, num2date
from more_itertools import always_iterable
from .line_label import CurvedLineLabel, LineLabel
from .utils import ensure_float, maximum_bipartite_matching
# Label line with line2D label data
def labelLine(
    line: plt.Line2D,
    x: float,
    curved_text: bool = False,
    label: Optional[str] = None,
    align: bool = True,
    drop_label: bool = False,
    yoffset: float = 0,
    yoffset_logspace: bool = False,
    outline_color: Union[Literal["auto"], None, "str"] = "auto",
    outline_width: float = 8,
    **kwargs,
):
    """
    Label a single matplotlib line at position x

    Parameters
    ----------
    line : matplotlib.lines.Line2D
        The line holding the label
    x : number
        The location in data unit of the label. Ignored when
        ``curved_text`` is True.
    curved_text : bool, optional
        If True, the label will be curved to follow the line.
    label : string, optional
        The label to set. This is inferred from the line by default
        (also when an empty string is given).
    align : bool, optional
        If True, the label will be aligned with the slope of the line at
        the location of the label. If False, it will be horizontal.
        Only used when ``curved_text`` is False.
    drop_label : bool, optional
        If True, the label is consumed by the function so that subsequent
        calls to e.g. legend do not use it anymore.
    yoffset : double, optional
        Space to add to label's y position
    yoffset_logspace : bool, optional
        If True, then yoffset will be added to the label's y position in
        log10 space
    outline_color : None | "auto" | color
        Colour of the outline. If set to "auto", use the background color.
        If set to None, do not draw an outline.
    outline_width : number
        Width of the outline
    kwargs : dict, optional
        Optional arguments passed to ax.text

    Returns
    -------
    LineLabel or CurvedLineLabel or None
        The created label artist, or None when the line could not be
        annotated because its value at the requested location is not well
        defined (a warning is emitted in that case).

    Raises
    ------
    ValueError
        Any other ValueError raised while building the label is re-raised.
    """
    # `or` (not `is None`) so that an empty-string label also falls back
    # to the line's own label.
    label = label or line.get_label()
    try:
        if curved_text:
            # Curved labels follow the whole line: `x` and `align` are not used.
            txt = CurvedLineLabel(
                line,
                label=label,
                axes=line.axes,
                yoffset=yoffset,
                yoffset_logspace=yoffset_logspace,
                outline_color=outline_color,
                outline_width=outline_width,
                **kwargs,
            )
        else:
            txt = LineLabel(
                line,
                x,
                label=label,
                align=align,
                yoffset=yoffset,
                yoffset_logspace=yoffset_logspace,
                outline_color=outline_color,
                outline_width=outline_width,
                **kwargs,
            )
    except ValueError as err:
        # Only the "undefined value at x" failure is downgraded to a warning;
        # every other ValueError propagates unchanged.
        if "does not have a well defined value" in str(err):
            warnings.warn(
                (
                    "%s could not be annotated due to `nans` values. "
                    "Consider using another location via the `x` argument."
                )
                % line,
                UserWarning,
                # Attribute the warning to the caller, not to this function.
                stacklevel=2,
            )
            return None
        raise

    if drop_label:
        # Clearing the label keeps the line out of later legend() calls.
        line.set_label(None)
    return txt
def labelLines(
    lines: Optional[List[plt.Line2D]] = None,
    align: bool = True,
    xvals: Union[None, Tuple[float, float], Iterable[float]] = None,
    curved_text: bool = False,
    drop_label: bool = False,
    shrink_factor: float = 0.05,
    yoffsets: Union[float, Iterable[float]] = 0,
    outline_color: Union[Literal["auto"], None, "str"] = "auto",
    outline_width: float = 5,
    **kwargs,
):
    """Label all lines with their respective legends.

    Parameters
    ----------
    lines : list of matplotlib lines, optional.
        Lines to label. If None (the default), label every line of the
        current axes that has a legend label. Requested lines for which no
        legend label can be found are skipped with a warning. Note that an
        empty list labels nothing.
    align : boolean, optional
        If True, the label will be aligned with the slope of the line
        at the location of the label. If False, they will be horizontal.
    xvals : (xfirst, xlast) or array of float, optional
        The location of the labels. If a tuple, the labels will be
        evenly spaced between xfirst and xlast (in the axis units;
        logarithmically spaced on a log-scaled x axis). If None, the axis
        limits shrunk by ``shrink_factor`` are used as (xfirst, xlast).
    curved_text : bool, optional
        If True, the labels will be curved to follow the line.
    drop_label : bool, optional
        If True, the label is consumed by the function so that subsequent
        calls to e.g. legend do not use it anymore.
    shrink_factor : double, optional
        Relative distance from the edges to place closest labels when
        ``xvals`` is None. Defaults to 0.05.
    yoffsets : number or list, optional.
        Distance relative to the line when positioning the labels. If given a number,
        the same value is used for all lines. If a list shorter than the
        number of lines is given, the remaining lines are not labelled.
    outline_color : None | "auto" | color
        Colour of the outline. If set to "auto", use the background color.
        If set to None, do not draw an outline.
    outline_width : number
        Width of the outline
    kwargs : dict, optional
        Optional arguments forwarded to `labelLine`.

    Returns
    -------
    list
        The value returned by `labelLine` for each labelled line (None for a
        line that could not be annotated).
    """
    if lines:
        ax = lines[0].axes
    else:
        ax = plt.gca()

    handles, all_labels = ax.get_legend_handles_labels()

    # Collect the lines to label together with their legend label. The label
    # must be paired with its handle here: once `lines` filters out some
    # handles, positions in `all_lines` no longer match positions in
    # `all_labels`, so looking the label up by index would mislabel lines.
    all_lines = []
    line_labels = []
    for handle, handle_label in zip(handles, all_labels):
        line = handle.lines[0] if isinstance(handle, ErrorbarContainer) else handle
        # If the user provided a list of lines to label, only label those
        if (lines is not None) and (line not in lines):
            continue
        all_lines.append(line)
        line_labels.append(handle_label)

    # Check that the lines passed to the function all have a label
    if lines is not None:
        for line in lines:
            if line in all_lines:
                continue
            warnings.warn(
                "Tried to label line %s, but could not find a label for it." % line,
                UserWarning,
                stacklevel=2,
            )

    # In case no x location was provided, we need to use some heuristics
    # to generate them.
    if xvals is None:
        xvals = ax.get_xlim()
        xvals_rng = xvals[1] - xvals[0]  # type: ignore
        shrinkage = xvals_rng * shrink_factor
        xvals = (xvals[0] + shrinkage, xvals[1] - shrinkage)  # type: ignore

    if isinstance(xvals, tuple) and len(xvals) == 2:
        xmin, xmax = xvals
        n_lines = len(all_lines)
        # Drop both endpoints so no label sits exactly on xmin / xmax.
        if ax.get_xscale() == "log":
            xvals = np.logspace(np.log10(xmin), np.log10(xmax), n_lines + 2)[1:-1]
        else:
            xvals = np.linspace(xmin, xmax, n_lines + 2)[1:-1]

        # ok_matrix[i, j] is True when x-value j lies strictly inside the
        # x-range of line i.
        ok_matrix = np.zeros((n_lines, n_lines), dtype=bool)
        for i, line in enumerate(all_lines):
            xdata = ensure_float(line.get_xdata())
            minx, maxx = min(xdata), max(xdata)
            ok_matrix[i] = (minx < xvals) & (xvals < maxx)

        # If some xvals do not fall in their corresponding line,
        # find a better matching using maximum bipartite matching.
        if not np.all(np.diag(ok_matrix)):
            order = maximum_bipartite_matching(ok_matrix)
            # The maximum match may miss a few points (negative entries);
            # hand them the x-values that were left unassigned.
            order[order < 0] = np.setdiff1d(np.arange(len(order)), order[order >= 0])
            # Permute the x-values according to the matching (index
            # convention defined by `maximum_bipartite_matching`).
            old_xvals = xvals.copy()
            xvals[order] = old_xvals
    else:
        xvals = list(always_iterable(xvals))  # force the creation of a copy

    lab_lines, labels = [], []
    # zip truncates to the shorter of lines / xvals, as before.
    for i, (line, label, xv) in enumerate(zip(all_lines, line_labels, xvals)):  # type: ignore
        lab_lines.append(line)
        labels.append(label)

        # Move xlabel if it is outside valid range
        xdata = ensure_float(line.get_xdata())
        x_lo, x_hi = min(xdata), max(xdata)
        if not (x_lo <= xv <= x_hi):
            warnings.warn(
                (
                    "The value at position %s in `xvals` is outside the range of its "
                    "associated line (xmin=%s, xmax=%s, xval=%s). Clipping it "
                    "into the allowed range."
                )
                % (i, x_lo, x_hi, xv),
                UserWarning,
                stacklevel=2,
            )
            # The value is moved to 90% of the line's x-extent rather than
            # to the nearest edge, keeping the label away from the line end.
            xvals[i] = x_lo + (x_hi - x_lo) * 0.9  # type: ignore

    # Convert float values back to datetime in case of datetime axis
    if isinstance(ax.xaxis.converter, DateConverter):
        tz = ax.xaxis.get_units()
        xvals = [num2date(x).replace(tzinfo=tz) for x in xvals]  # type: ignore

    if not isinstance(yoffsets, Iterable):
        yoffsets = [float(yoffsets)] * len(all_lines)

    txts = []
    for line, x, yoffset, label in zip(
        lab_lines,
        xvals,  # type: ignore
        yoffsets,
        labels,
    ):
        txts.append(
            labelLine(
                line,
                x,
                curved_text=curved_text,
                label=label,
                align=align,
                drop_label=drop_label,
                yoffset=yoffset,
                outline_color=outline_color,
                outline_width=outline_width,
                **kwargs,
            )
        )

    return txts