Skip to content

Commit 99dcf17

Browse files
Fix xvals for datetime axis
1 parent 7860fb3 commit 99dcf17

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

labellines/core.py

+15
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import matplotlib.pyplot as plt
55
import numpy as np
6+
from datetime import datetime
67
from matplotlib.container import ErrorbarContainer
78
from matplotlib.dates import DateConverter, num2date
89
from matplotlib.lines import Line2D
@@ -186,18 +187,30 @@ def labelLines(
186187
if isinstance(xvals, tuple) and len(xvals) == 2:
187188
xmin, xmax = xvals
188189
xscale = ax.get_xscale()
190+
191+
# Convert datetime objects to numeric values for linspace/geomspace
192+
x_is_datetime = isinstance(xmin, datetime) or isinstance(xmax, datetime)
193+
if x_is_datetime:
194+
xmin = plt.matplotlib.dates.date2num(xmin)
195+
xmax = plt.matplotlib.dates.date2num(xmax)
196+
189197
if xscale == "log":
190198
xvals = np.geomspace(xmin, xmax, len(all_lines) + 2)[1:-1]
191199
else:
192200
xvals = np.linspace(xmin, xmax, len(all_lines) + 2)[1:-1]
193201

202+
# Convert numeric values back to datetime objects
203+
if x_is_datetime:
204+
xvals = plt.matplotlib.dates.num2date(xvals)
205+
194206
# Build matrix line -> xvalue
195207
ok_matrix = np.zeros((len(all_lines), len(all_lines)), dtype=bool)
196208

197209
for i, line in enumerate(all_lines):
198210
xdata, _ = normalize_xydata(line)
199211
minx, maxx = min(xdata), max(xdata)
200212
for j, xv in enumerate(xvals): # type: ignore
213+
xv = line.convert_xunits(xv)
201214
ok_matrix[i, j] = minx < xv < maxx
202215

203216
# If some xvals do not fall in their corresponding line,
@@ -224,6 +237,8 @@ def labelLines(
224237
# Move xlabel if it is outside valid range
225238
xdata, _ = normalize_xydata(line)
226239
xmin, xmax = min(xdata), max(xdata)
240+
xv = line.convert_xunits(xv)
241+
227242
if not (xmin <= xv <= xmax):
228243
warnings.warn(
229244
(

0 commit comments

Comments
 (0)