diff --git a/labellines/baseline/test_dateaxis_advanced.png b/labellines/baseline/test_dateaxis_advanced.png index aebf763..879f157 100644 Binary files a/labellines/baseline/test_dateaxis_advanced.png and b/labellines/baseline/test_dateaxis_advanced.png differ diff --git a/labellines/core.py b/labellines/core.py index 0fa4588..c9940fb 100644 --- a/labellines/core.py +++ b/labellines/core.py @@ -3,6 +3,7 @@ import matplotlib.pyplot as plt import numpy as np +from datetime import datetime from matplotlib.container import ErrorbarContainer from matplotlib.dates import DateConverter, num2date from matplotlib.lines import Line2D @@ -186,11 +187,26 @@ def labelLines( if isinstance(xvals, tuple) and len(xvals) == 2: xmin, xmax = xvals xscale = ax.get_xscale() + + # Convert datetime objects to numeric values for linspace/geomspace + x_is_datetime = isinstance(xmin, datetime) or isinstance(xmax, datetime) + if x_is_datetime: + if not isinstance(xmin, datetime) or not isinstance(xmax, datetime): + raise ValueError( + f"Cannot mix datetime and numeric values in xvals: {xvals}" + ) + xmin = plt.matplotlib.dates.date2num(xmin) + xmax = plt.matplotlib.dates.date2num(xmax) + if xscale == "log": xvals = np.geomspace(xmin, xmax, len(all_lines) + 2)[1:-1] else: xvals = np.linspace(xmin, xmax, len(all_lines) + 2)[1:-1] + # Convert numeric values back to datetime objects + if x_is_datetime: + xvals = plt.matplotlib.dates.num2date(xvals) + # Build matrix line -> xvalue ok_matrix = np.zeros((len(all_lines), len(all_lines)), dtype=bool) @@ -198,6 +214,7 @@ def labelLines( xdata, _ = normalize_xydata(line) minx, maxx = min(xdata), max(xdata) for j, xv in enumerate(xvals): # type: ignore + xv = line.convert_xunits(xv) ok_matrix[i, j] = minx < xv < maxx # If some xvals do not fall in their corresponding line, @@ -224,6 +241,8 @@ def labelLines( # Move xlabel if it is outside valid range xdata, _ = normalize_xydata(line) xmin, xmax = min(xdata), max(xdata) + xv = line.convert_xunits(xv) + if not (xmin <= xv <= xmax): warnings.warn( ( diff --git a/labellines/test.py b/labellines/test.py index e317148..70f10cf 100644 --- a/labellines/test.py +++ b/labellines/test.py @@ -164,7 +164,7 @@ def test_dateaxis_advanced(setup_mpl): ax.xaxis.set_major_locator(DayLocator()) ax.xaxis.set_major_formatter(DateFormatter("%Y-%m-%d")) - labelLines(ax.get_lines()) + labelLines(ax.get_lines(), xvals=(dates[0], dates[-1])) return plt.gcf()