diff --git a/labellines/core.py b/labellines/core.py index 3fb3555..9f222fb 100644 --- a/labellines/core.py +++ b/labellines/core.py @@ -1,4 +1,6 @@ import warnings +from collections.abc import Iterable +from typing import List, Literal, Optional, Tuple, Union import matplotlib.pyplot as plt import numpy as np @@ -6,21 +8,22 @@ from matplotlib.dates import DateConverter, num2date from more_itertools import always_iterable -from .line_label import LineLabel +from .line_label import CurvedLineLabel, LineLabel from .utils import ensure_float, maximum_bipartite_matching # Label line with line2D label data def labelLine( - line, - x, - label=None, - align=True, - drop_label=False, - yoffset=0, - yoffset_logspace=False, - outline_color="auto", - outline_width=8, + line: plt.Line2D, + x: float, + curved_text: bool = False, + label: Optional[str] = None, + align: bool = True, + drop_label: bool = False, + yoffset: float = 0, + yoffset_logspace: bool = False, + outline_color: Union[Literal["auto"], None, "str"] = "auto", + outline_width: float = 8, **kwargs, ): """ @@ -32,6 +35,8 @@ def labelLine( The line holding the label x : number The location in data unit of the label + curved_text : bool, optional + If True, the label will be curved to follow the line. label : string, optional The label to set. This is inferred from the line by default drop_label : bool, optional @@ -51,18 +56,32 @@ def labelLine( Optional arguments passed to ax.text """ + label = label or line.get_label() + try: - txt = LineLabel( - line, - x, - label=label, - align=align, - yoffset=yoffset, - yoffset_logspace=yoffset_logspace, - outline_color=outline_color, - outline_width=outline_width, - **kwargs, - ) + if curved_text: + txt = CurvedLineLabel( + line, + label=label, + axes=line.axes, + yoffset=yoffset, + yoffset_logspace=yoffset_logspace, + outline_color=outline_color, + outline_width=outline_width, + **kwargs, + ) + else: + txt = LineLabel( + line, + x, + label=label, + align=align, + yoffset=yoffset, + yoffset_logspace=yoffset_logspace, + outline_color=outline_color, + outline_width=outline_width, + **kwargs, + ) except ValueError as err: if "does not have a well defined value" in str(err): warnings.warn( @@ -84,14 +103,15 @@ def labelLine( def labelLines( - lines=None, - align=True, - xvals=None, - drop_label=False, - shrink_factor=0.05, - yoffsets=0, - outline_color="auto", - outline_width=5, + lines: Optional[List[plt.Line2D]] = None, + align: bool = True, + xvals: Union[None, Tuple[float, float], Iterable[float]] = None, + curved_text: bool = False, + drop_label: bool = False, + shrink_factor: float = 0.05, + yoffsets: Union[float, Iterable[float]] = 0, + outline_color: Union[Literal["auto"], None, "str"] = "auto", + outline_width: float = 5, **kwargs, ): """Label all lines with their respective legends. @@ -106,6 +126,8 @@ def labelLines( xvals : (xfirst, xlast) or array of float, optional The location of the labels. If a tuple, the labels will be evenly spaced between xfirst and xlast (in the axis units). + curved_text : bool, optional + If True, the labels will be curved to follow the line. drop_label : bool, optional If True, the label is consumed by the function so that subsequent calls to e.g. legend do not use it anymore. @@ -157,9 +179,9 @@ def labelLines( # to generate them. if xvals is None: xvals = ax.get_xlim() - xvals_rng = xvals[1] - xvals[0] + xvals_rng = xvals[1] - xvals[0] # type: ignore shrinkage = xvals_rng * shrink_factor - xvals = (xvals[0] + shrinkage, xvals[1] - shrinkage) + xvals = (xvals[0] + shrinkage, xvals[1] - shrinkage) # type: ignore if isinstance(xvals, tuple) and len(xvals) == 2: xmin, xmax = xvals @@ -177,7 +199,7 @@ def labelLines( for i, line in enumerate(all_lines): xdata = ensure_float(line.get_xdata()) minx, maxx = min(xdata), max(xdata) - for j, xv in enumerate(xvals): + for j, xv in enumerate(xvals): # type: ignore ok_matrix[i, j] = minx < xv < maxx # If some xvals do not fall in their corresponding line, @@ -189,14 +211,14 @@ def labelLines( order[order < 0] = np.setdiff1d(np.arange(len(order)), order[order >= 0]) # Now reorder the xvalues - old_xvals = xvals.copy() - xvals[order] = old_xvals + old_xvals = xvals.copy() # type: ignore + xvals[order] = old_xvals # type: ignore else: xvals = list(always_iterable(xvals)) # force the creation of a copy lab_lines, labels = [], [] # Take only the lines which have labels other than the default ones - for i, (line, xv) in enumerate(zip(all_lines, xvals)): + for i, (line, xv) in enumerate(zip(all_lines, xvals)): # type: ignore label = all_labels[all_lines.index(line)] lab_lines.append(line) labels.append(label) @@ -215,18 +237,24 @@ def labelLines( stacklevel=1, ) new_xv = min(xdata) + (max(xdata) - min(xdata)) * 0.9 - xvals[i] = new_xv + xvals[i] = new_xv # type: ignore # Convert float values back to datetime in case of datetime axis if isinstance(ax.xaxis.converter, DateConverter): - xvals = [num2date(x).replace(tzinfo=ax.xaxis.get_units()) for x in xvals] + tz = ax.xaxis.get_units() + xvals = [num2date(x).replace(tzinfo=tz) for x in xvals] # type: ignore txts = [] - try: + + if not isinstance(yoffsets, Iterable): yoffsets = [float(yoffsets)] * len(all_lines) - except TypeError: - pass - for line, x, yoffset, label in zip(lab_lines, xvals, yoffsets, labels): + + for line, x, yoffset, label in zip( + lab_lines, + xvals, # type: ignore + yoffsets, + labels, + ): txts.append( labelLine( line, diff --git a/labellines/line_label.py b/labellines/line_label.py index 79b1d12..d244d2e 100644 --- a/labellines/line_label.py +++ b/labellines/line_label.py @@ -1,5 +1,7 @@ from __future__ import annotations +import re +from itertools import repeat from typing import TYPE_CHECKING import matplotlib.patheffects as patheffects @@ -17,6 +19,9 @@ ColorLike = Any # mpl has no type annotations so this is just a crutch AutoLiteral = Literal["auto"] +# This matches a dollar sign that is not preceded by a backslash +VALID_MATH_RE = re.compile(r"(? l_fig[-1]: + for t in both_t: + t.set_alpha(0.0) + rel_pos += w + continue + + elif c != " ": + for t in both_t: + t.set_alpha(1.0) + + # finding the two data points between which the horizontal + # center point of the character will be situated + # left and right indices: + il = np.where(rel_pos + w / 2 >= l_fig)[0][-1] + ir = np.where(rel_pos + w / 2 <= l_fig)[0][0] + + # if we exactly hit a data point: + if ir == il: + ir += 1 + + # how much of the letter width was needed to find il: + used = l_fig[il] - rel_pos + rel_pos = l_fig[il] + + # relative distance between il and ir where the center + # of the character will be + fraction = (w / 2 - used) / r_fig_dist[il] + + ## setting the character position in data coordinates: + ## interpolate between the two points: + x = self._x_data[il] + fraction * (self._x_data[ir] - self._x_data[il]) + y = self._y_data[il] + fraction * (self._y_data[ir] - self._y_data[il]) + + # getting the offset when setting correct vertical alignment + # in data coordinates + for t in both_t: + t.set_va(self.get_va()) + bbox2 = t_char.get_window_extent(renderer=renderer) + + bbox1d = self._ax.transData.inverted().transform(bbox1) + bbox2d = self._ax.transData.inverted().transform(bbox2) + dr = np.array(bbox2d[0] - bbox1d[0]) + + # the rotation/stretch matrix + rad = rads[il] + rot_mat = np.array( + [ + [np.math.cos(rad), np.math.sin(rad) * aspect], + [-np.math.sin(rad) / aspect, np.math.cos(rad)], + ] + ) + + ## computing the offset vector of the rotated character + drp = np.dot(dr, rot_mat) + + # setting final position and rotation: + for t in both_t: + t.set_position(np.array([x, y]) + drp) + t.set_rotation(degs[il]) + + t.set_va("center") + t.set_ha("center") + + # updating rel_pos to right edge of character + rel_pos += w - used + + @staticmethod + def tokenize_string(text: str) -> list[str]: + # Make sure the string has only valid math (i.e. there is an even number of `$`) + valid_math = len(re.findall(VALID_MATH_RE, text)) % 2 == 0 + + if not valid_math: + return list(text) + + math_mode = False + tokens = [] + prev_c = None + current_token: str = "" + for _i, c in enumerate(text): + if c == "$" and prev_c != "\\": + if math_mode: + tokens.append("$" + current_token + "$") + math_mode = False + else: + math_mode = True + current_token = "" + elif math_mode: + current_token += c + else: + tokens.append(c) + prev_c = c + return tokens