API: Don't add extra attributes to matplotlib axes #54485

Open
mroeschke opened this issue Aug 10, 2023 · 5 comments
@mroeschke
Member

Currently there are a few areas where pandas adds extra attributes to a matplotlib axis:

def _ts_plot(self, ax: Axes, x, data, style=None, **kwds):

orig_ax.right_ax, new_ax.left_ax = new_ax, orig_ax

Since axes are stateful and there is no way to clear these attributes via matplotlib public APIs, these attributes can cause issues when the axes are reused (discovered by running tests via pytest-randomly).
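A minimal sketch of the statefulness problem, assuming only that matplotlib is installed. The attribute name `_demo_state` is hypothetical and stands in for attributes such as `right_ax`; it is not something pandas sets.

```python
from matplotlib.figure import Figure

fig = Figure()
ax = fig.subplots()

# Hypothetical attribute, pinned onto the axes the same way the issue describes.
ax._demo_state = ["leftover"]
ax.plot([0, 1], [0, 1])

# clear() resets what matplotlib itself manages (artists, limits, labels) ...
ax.clear()

# ... but it knows nothing about attributes added from outside.
lines_after_clear = len(ax.lines)
state_survived = hasattr(ax, "_demo_state")
print(lines_after_clear, state_survived)
```

Any later code that checks for the attribute on a reused axes will therefore see the value left behind by an earlier plot.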

@rsm-23
Contributor

rsm-23 commented Aug 17, 2023

@mroeschke for clarification, do we just stop adding this?
orig_ax.right_ax, new_ax.left_ax = new_ax, orig_ax
I am trying to understand what changes we are looking for. Thanks in advance.

@mroeschke
Member Author

Ideally all existing functionality should still work without adding extra attributes to a matplotlib axis, so these attributes still need to be passed along somehow (IMO this is probably a nontrivial change).

@jbrockmendel
Member

I've been looking at this and am currently skeptical we can get rid of all of this ugly pattern. We have some tests that seem to pretty directly rely on storing state in ax, e.g. from test_from_resampling_area_line_mixed_high_to_low

kind1 = "line"
kind2 = "area"

idxh = date_range("1/1/1999", periods=52, freq="W")
idxl = date_range("1/1/1999", periods=12, freq="ME")
high = DataFrame(
    np.random.default_rng(2).random((len(idxh), 3)),
    index=idxh,
    columns=[0, 1, 2],
)
low = DataFrame(
    np.random.default_rng(2).random((len(idxl), 3)),
    index=idxl,
    columns=[0, 1, 2],
)
_, ax = mpl.pyplot.subplots()
high.plot(kind=kind1, stacked=True, ax=ax)
low.plot(kind=kind2, stacked=True, ax=ax)

In the last line here, the ax object is the only thing that can be storing the state. IIUC it is detecting that something is already plotted on ax and resampling low to match the freq of the existing x-axis. Or something. Honestly, I know resampling is happening but I'm still trying to figure out why.

Some ways that come to mind to avoid this:

  1. Just don't support this multiple-call usage
  2. Try to back out the relevant state from whatever state variables matplotlib is using to store the information
  3. Provide some other API to plot low and high in a single call
  4. Make our own object to hold the state, maybe return it from plot (xref API: decide what to return from plotting functions #4020)
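Option 4 could be sketched as a side table keyed weakly by the axes, so nothing is written onto the axes object itself. This is only an illustration, not how pandas does or will do it: `PlotState`, `get_state`, `reset_state`, and the `FakeAxes` stand-in are all hypothetical names, and the approach assumes the key object is hashable and weak-referenceable.

```python
import gc
import weakref
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class PlotState:
    """Hypothetical container for state that is currently pinned onto the axes."""
    plot_data: list = field(default_factory=list)
    freq: Optional[str] = None


class FakeAxes:
    """Stand-in for a matplotlib Axes, used to keep the sketch dependency-free."""


# Side table: an entry disappears once its axes object is garbage collected.
_STATE = weakref.WeakKeyDictionary()


def get_state(ax) -> PlotState:
    """Return the state for ``ax``, creating an empty one on first use."""
    return _STATE.setdefault(ax, PlotState())


def reset_state(ax) -> None:
    """Explicit reset hook, which bare attributes on the axes do not offer."""
    _STATE.pop(ax, None)


ax = FakeAxes()
get_state(ax).freq = "W"            # first plot call records its frequency
freq_seen_by_second_call = get_state(ax).freq
reset_state(ax)                     # e.g. called when the axes is cleared
freq_after_reset = get_state(ax).freq

del ax
gc.collect()
entries_left = len(_STATE)
```

The multiple-call pattern still works because the second call looks the state up by axes, while a reused axes can be reset explicitly and a discarded axes leaves nothing behind.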

@jbrockmendel
Member

Disabling the place where we pin "right_ax" and "left_ax" breaks 30 tests. Of those, 16 have the test itself directly trying to access one of these attributes. It tentatively looks like many of these have the pattern where we make multiple .plot calls with a reused ax.
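For the "back out the state from matplotlib" idea, a rough sketch of recovering a twin axes through public matplotlib API, assuming the second axes was created with `twinx()`. The variable names are illustrative, and the lookup is not a full replacement for `right_ax`: the shared-x group also contains any axes joined via `sharex`, so it cannot tell a twin apart from an ordinary shared subplot.

```python
from matplotlib.figure import Figure

fig = Figure()
left = fig.subplots()
right = left.twinx()  # twin axes sharing the x-axis with ``left``

# Axes that share an x-axis are grouped together; get_siblings includes ``left``.
siblings = left.get_shared_x_axes().get_siblings(left)
twins = [a for a in siblings if a is not left]

found_right = len(twins) == 1 and twins[0] is right
print(found_right)
```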

@williambdean

I believe that this example is related:

import pandas as pd
import numpy as np

import matplotlib.pyplot as plt

n_dates = 52 * 3
dates = pd.date_range("2022-01-01", periods=n_dates, freq="W-MON")

seed = sum(map(ord, "Order matters"))
rng = np.random.default_rng(seed)
data = rng.normal(size=n_dates).cumsum()

ser = pd.Series(data, index=dates)
padding = 15

fig, axes = plt.subplots(nrows=2, ncols=2, sharex=False, sharey=True)

def plot_time_series(pandas: bool, ax: plt.Axes): 
    if pandas: 
        ser.plot(ax=ax)
    else: 
        ax.plot(dates, data)

def plot_fill_between(ax: plt.Axes): 
    ax.fill_between(dates, data - padding, data + padding, alpha=0.25)

ax = axes[0, 0]
plot_time_series(pandas=True, ax=ax)
plot_fill_between(ax)
ax.set(title="time-series first", ylabel="pandas.Series.plot")

ax = axes[0, 1]
plot_fill_between(ax)
plot_time_series(pandas=True, ax=ax)
ax.set(title="time-series second")

ax = axes[1, 0]
plot_time_series(pandas=False, ax=ax)
plot_fill_between(ax)
ax.set(title="", ylabel="plt.plot")

ax = axes[1, 1]
plot_fill_between(ax)
plot_time_series(pandas=False, ax=ax)
ax.set(title="")

plt.show()

[Figure: order-matters]

Coming from here:

# clear current axes and data
# TODO #54485
ax._plot_data = [] # type: ignore[attr-defined]
ax.clear()

and brought from here: matplotlib/matplotlib#28505
