diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8985395f77b..9ddd0bd04c8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,9 @@ This project adheres to [Semantic Versioning](http://semver.org/).
### Added
- Extra flags were added to the `gapminder` and `stocks` dataset to facilitate testing, documentation and demos [#3305](https://github.com/plotly/plotly.py/issues/3305)
- All line-like Plotly Express functions now accept `markers` argument to display markers, and all but `line_mapbox` accept `symbol` to map a field to the symbol attribute, similar to scatter-like functions [#3326](https://github.com/plotly/plotly.py/issues/3326)
+ - `px.scatter` and `px.density_contours` now support new `trendline` types `'rolling'`, `'expanding'` and `'ewm'` [#2997](https://github.com/plotly/plotly.py/pull/2997)
+ - `px.scatter` and `px.density_contours` now support new `trendline_options` argument to parameterize trendlines, with support for constant control and log-scaling in `'ols'` and specification of the fraction used for `'lowess'`, as well as pass-through to Pandas for `'rolling'`, `'expanding'` and `'ewm'` [#2997](https://github.com/plotly/plotly.py/pull/2997)
+ - `px.scatter` and `px.density_contours` now support new `trendline_scope` argument that accepts the value `'overall'` to request a single trendline for all traces, including across facets and animation frames [#2997](https://github.com/plotly/plotly.py/pull/2997)
### Fixed
- Fixed regression introduced in version 5.0.0 where pandas/numpy arrays with `dtype` of Object were being converted to `list` values when added to a Figure ([#3292](https://github.com/plotly/plotly.py/issues/3292), [#3293](https://github.com/plotly/plotly.py/pull/3293))
diff --git a/doc/apidoc/plotly.express.rst b/doc/apidoc/plotly.express.rst
index cd252158cb4..bff238d6684 100644
--- a/doc/apidoc/plotly.express.rst
+++ b/doc/apidoc/plotly.express.rst
@@ -49,6 +49,8 @@ plotly's high-level API for rapid figure generation. ::
density_heatmap
density_mapbox
imshow
+ set_mapbox_access_token
+ get_trendline_results
`plotly.express` subpackages
@@ -60,3 +62,4 @@ plotly's high-level API for rapid figure generation. ::
generated/plotly.express.data.rst
generated/plotly.express.colors.rst
+ generated/plotly.express.trendline_functions.rst
diff --git a/doc/python/linear-fits.md b/doc/python/linear-fits.md
index 7f1f1a2971f..0029be6c4be 100644
--- a/doc/python/linear-fits.md
+++ b/doc/python/linear-fits.md
@@ -5,8 +5,8 @@ jupyter:
text_representation:
extension: .md
format_name: markdown
- format_version: '1.1'
- jupytext_version: 1.1.1
+ format_version: '1.2'
+ jupytext_version: 1.4.2
kernelspec:
display_name: Python 3
language: python
@@ -20,11 +20,12 @@ jupyter:
name: python
nbconvert_exporter: python
pygments_lexer: ipython3
- version: 3.6.8
+ version: 3.7.7
plotly:
description: Add linear Ordinary Least Squares (OLS) regression trendlines or
non-linear Locally Weighted Scatterplot Smoothing (LOWESS) trendlines to scatterplots
- in Python.
+ in Python. Options for moving averages (rolling means) as well as exponentially-weighted
+ and expanding functions.
display_as: statistical
language: python
layout: base
@@ -39,7 +40,7 @@ jupyter:
[Plotly Express](/python/plotly-express/) is the easy-to-use, high-level interface to Plotly, which [operates on a variety of types of data](/python/px-arguments/) and produces [easy-to-style figures](/python/styling-plotly-express/).
-Plotly Express allows you to add [Ordinary Least](https://en.wikipedia.org/wiki/Ordinary_least_squares) Squares regression trendline to scatterplots with the `trendline` argument. In order to do so, you will need to install `statsmodels` and its dependencies. Hovering over the trendline will show the equation of the line and its R-squared value.
+Plotly Express allows you to add [Ordinary Least Squares](https://en.wikipedia.org/wiki/Ordinary_least_squares) regression trendline to scatterplots with the `trendline` argument. In order to do so, you will need to [install `statsmodels` and its dependencies](https://www.statsmodels.org/stable/install.html). Hovering over the trendline will show the equation of the line and its R-squared value.
```python
import plotly.express as px
@@ -66,14 +67,160 @@ print(results)
results.query("sex == 'Male' and smoker == 'Yes'").px_fit_results.iloc[0].summary()
```
-### Non-Linear Trendlines
+### Displaying a single trendline with multiple traces
-Plotly Express also supports non-linear [LOWESS](https://en.wikipedia.org/wiki/Local_regression) trendlines.
+_new in v5.2_
+
+To display a single trendline using the entire dataset, set the `trendline_scope` argument to `"overall"`. The same trendline will be overlaid on all facets and animation frames. The trendline color can be overridden with `trendline_color_override`.
+
+```python
+import plotly.express as px
+
+df = px.data.tips()
+fig = px.scatter(df, x="total_bill", y="tip", symbol="smoker", color="sex", trendline="ols", trendline_scope="overall")
+fig.show()
+```
+
+```python
+import plotly.express as px
+
+df = px.data.tips()
+fig = px.scatter(df, x="total_bill", y="tip", facet_col="smoker", color="sex",
+ trendline="ols", trendline_scope="overall", trendline_color_override="black")
+fig.show()
+```
+
+### OLS Parameters
+
+_new in v5.2_
+
+OLS trendlines can be fit with log transformations to both X or Y data using the `trendline_options` argument, independently of whether or not the plot has [logarithmic axes](https://plotly.com/python/log-plot/).
+
+```python
+import plotly.express as px
+
+df = px.data.gapminder(year=2007)
+fig = px.scatter(df, x="gdpPercap", y="lifeExp",
+ trendline="ols", trendline_options=dict(log_x=True),
+ title="Log-transformed fit on linear axes")
+fig.show()
+```
+
+```python
+import plotly.express as px
+
+df = px.data.gapminder(year=2007)
+fig = px.scatter(df, x="gdpPercap", y="lifeExp", log_x=True,
+ trendline="ols", trendline_options=dict(log_x=True),
+ title="Log-scaled X axis and log-transformed fit")
+fig.show()
+```
+
+### Locally WEighted Scatterplot Smoothing (LOWESS)
+
+Plotly Express also supports non-linear [LOWESS](https://en.wikipedia.org/wiki/Local_regression) trendlines. In order use this feature, you will need to [install `statsmodels` and its dependencies](https://www.statsmodels.org/stable/install.html).
+
+```python
+import plotly.express as px
+
+df = px.data.stocks(datetimes=True)
+fig = px.scatter(df, x="date", y="GOOG", trendline="lowess")
+fig.show()
+```
+
+_new in v5.2_
+
+The level of smoothing can be controlled via the `frac` trendline option, which indicates the fraction of the data that the LOWESS smoother should include. The default is a fairly smooth line with `frac=0.6666` and lowering this fraction will give a line that more closely follows the data.
+
+```python
+import plotly.express as px
+
+df = px.data.stocks(datetimes=True)
+fig = px.scatter(df, x="date", y="GOOG", trendline="lowess", trendline_options=dict(frac=0.1))
+fig.show()
+```
+
+### Moving Averages
+
+_new in v5.2_
+
+Plotly Express can leverage Pandas' [`rolling`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.rolling.html), [`ewm`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.ewm.html) and [`expanding`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.expanding.html) functions in trendlines as well, for example to display moving averages. Values passed to `trendline_options` are passed directly to the underlying Pandas function (with the exception of the `function` and `function_options` keys, see below).
+
+```python
+import plotly.express as px
+
+df = px.data.stocks(datetimes=True)
+fig = px.scatter(df, x="date", y="GOOG", trendline="rolling", trendline_options=dict(window=5),
+ title="5-point moving average")
+fig.show()
+```
+
+```python
+import plotly.express as px
+
+df = px.data.stocks(datetimes=True)
+fig = px.scatter(df, x="date", y="GOOG", trendline="ewm", trendline_options=dict(halflife=2),
+ title="Exponentially-weighted moving average (halflife of 2 points)")
+fig.show()
+```
+
+```python
+import plotly.express as px
+
+df = px.data.stocks(datetimes=True)
+fig = px.scatter(df, x="date", y="GOOG", trendline="expanding", title="Expanding mean")
+fig.show()
+```
+
+### Other Functions
+
+The `rolling`, `expanding` and `ewm` trendlines support other functions than the default `mean`, enabling, for example, a moving-median trendline, or an expanding-max trendline.
```python
import plotly.express as px
-df = px.data.gapminder().query("year == 2007")
-fig = px.scatter(df, x="gdpPercap", y="lifeExp", color="continent", trendline="lowess")
+df = px.data.stocks(datetimes=True)
+fig = px.scatter(df, x="date", y="GOOG", trendline="rolling", trendline_options=dict(function="median", window=5),
+ title="Rolling Median")
fig.show()
```
+
+```python
+import plotly.express as px
+
+df = px.data.stocks(datetimes=True)
+fig = px.scatter(df, x="date", y="GOOG", trendline="expanding", trendline_options=dict(function="max"),
+ title="Expanding Maximum")
+fig.show()
+```
+
+In some cases, it is necessary to pass options into the underying Pandas function, for example the `std` parameter must be provided if the `win_type` argument to `rolling` is `"gaussian"`. This is possible with the `function_args` trendline option.
+
+```python
+import plotly.express as px
+
+df = px.data.stocks(datetimes=True)
+fig = px.scatter(df, x="date", y="GOOG", trendline="rolling",
+ trendline_options=dict(window=5, win_type="gaussian", function_args=dict(std=2)),
+ title="Rolling Mean with Gaussian Window")
+fig.show()
+```
+
+### Displaying only the trendlines
+
+In some cases, it may be desirable to show only the trendlines, by removing the scatter points.
+
+```python
+import plotly.express as px
+
+df = px.data.stocks(indexed=True, datetimes=True)
+fig = px.scatter(df, trendline="rolling", trendline_options=dict(window=5),
+ title="5-point moving average")
+fig.data = [t for t in fig.data if t.mode == "lines"]
+fig.update_traces(showlegend=True) #trendlines have showlegend=False by default
+fig.show()
+```
+
+```python
+
+```
diff --git a/packages/python/plotly/plotly/express/__init__.py b/packages/python/plotly/plotly/express/__init__.py
index 140a0fbe814..8bc5da53910 100644
--- a/packages/python/plotly/plotly/express/__init__.py
+++ b/packages/python/plotly/plotly/express/__init__.py
@@ -60,7 +60,7 @@
from ._special_inputs import IdentityMap, Constant, Range # noqa: F401
-from . import data, colors # noqa: F401
+from . import data, colors, trendline_functions # noqa: F401
__all__ = [
"scatter",
@@ -100,6 +100,7 @@
"imshow",
"data",
"colors",
+ "trendline_functions",
"set_mapbox_access_token",
"get_trendline_results",
"IdentityMap",
diff --git a/packages/python/plotly/plotly/express/_chart_types.py b/packages/python/plotly/plotly/express/_chart_types.py
index 6cfb6a90367..f335e78de34 100644
--- a/packages/python/plotly/plotly/express/_chart_types.py
+++ b/packages/python/plotly/plotly/express/_chart_types.py
@@ -46,7 +46,9 @@ def scatter(
marginal_x=None,
marginal_y=None,
trendline=None,
+ trendline_options=None,
trendline_color_override=None,
+ trendline_scope="trace",
log_x=False,
log_y=False,
range_x=None,
@@ -90,7 +92,9 @@ def density_contour(
marginal_x=None,
marginal_y=None,
trendline=None,
+ trendline_options=None,
trendline_color_override=None,
+ trendline_scope="trace",
log_x=False,
log_y=False,
range_x=None,
diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py
index f8e391053b9..cc0e98375b2 100644
--- a/packages/python/plotly/plotly/express/_core.py
+++ b/packages/python/plotly/plotly/express/_core.py
@@ -2,6 +2,7 @@
import plotly.io as pio
from collections import namedtuple, OrderedDict
from ._special_inputs import IdentityMap, Constant, Range
+from .trendline_functions import ols, lowess, rolling, expanding, ewm
from _plotly_utils.basevalidators import ColorscaleValidator
from plotly.colors import qualitative, sequential
@@ -16,6 +17,9 @@
)
NO_COLOR = "px_no_color_constant"
+trendline_functions = dict(
+ lowess=lowess, rolling=rolling, ewm=ewm, expanding=expanding, ols=ols
+)
# Declare all supported attributes, across all plot types
direct_attrables = (
@@ -313,12 +317,10 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
mapping_labels["count"] = "%{x}"
elif attr_name == "trendline":
if (
- attr_value in ["ols", "lowess"]
- and args["x"]
+ args["x"]
and args["y"]
and len(trace_data[[args["x"], args["y"]]].dropna()) > 1
):
- import statsmodels.api as sm
# sorting is bad but trace_specs with "trendline" have no other attrs
sorted_trace_data = trace_data.sort_values(by=args["x"])
@@ -345,37 +347,27 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
)
# preserve original values of "x" in case they're dates
- trace_patch["x"] = sorted_trace_data[args["x"]][
- np.logical_not(np.logical_or(np.isnan(y), np.isnan(x)))
- ]
-
- if attr_value == "lowess":
- # missing ='drop' is the default value for lowess but not for OLS (None)
- # we force it here in case statsmodels change their defaults
- trendline = sm.nonparametric.lowess(y, x, missing="drop")
- trace_patch["y"] = trendline[:, 1]
- hover_header = "LOWESS trendline
"
- elif attr_value == "ols":
- fit_results = sm.OLS(
- y, sm.add_constant(x), missing="drop"
- ).fit()
- trace_patch["y"] = fit_results.predict()
- hover_header = "OLS trendline
"
- if len(fit_results.params) == 2:
- hover_header += "%s = %g * %s + %g
" % (
- args["y"],
- fit_results.params[1],
- args["x"],
- fit_results.params[0],
- )
- else:
- hover_header += "%s = %g
" % (
- args["y"],
- fit_results.params[0],
- )
- hover_header += (
- "R2=%f
" % fit_results.rsquared
- )
+ # otherwise numpy/pandas can mess with the timezones
+ # NB this means trendline functions must output one-to-one with the input series
+ # i.e. we can't do resampling, because then the X values might not line up!
+ non_missing = np.logical_not(
+ np.logical_or(np.isnan(y), np.isnan(x))
+ )
+ trace_patch["x"] = sorted_trace_data[args["x"]][non_missing]
+ trendline_function = trendline_functions[attr_value]
+ y_out, hover_header, fit_results = trendline_function(
+ args["trendline_options"],
+ sorted_trace_data[args["x"]],
+ x,
+ y,
+ args["x"],
+ args["y"],
+ non_missing,
+ )
+ assert len(y_out) == len(
+ trace_patch["x"]
+ ), "missing-data-handling failure in trendline code"
+ trace_patch["y"] = y_out
mapping_labels[get_label(args, args["x"])] = "%{x}"
mapping_labels[get_label(args, args["y"])] = "%{y} (trend)"
elif attr_name.startswith("error"):
@@ -878,21 +870,25 @@ def make_trace_spec(args, constructor, attrs, trace_patch):
result.append(trace_spec)
# Add trendline trace specifications
- if "trendline" in args and args["trendline"]:
- trace_spec = TraceSpec(
- constructor=go.Scattergl if constructor == go.Scattergl else go.Scatter,
- attrs=["trendline"],
- trace_patch=dict(mode="lines"),
- marginal=None,
- )
- if args["trendline_color_override"]:
- trace_spec.trace_patch["line"] = dict(
- color=args["trendline_color_override"]
- )
- result.append(trace_spec)
+ if args.get("trendline") and args.get("trendline_scope", "trace") == "trace":
+ result.append(make_trendline_spec(args, constructor))
return result
+def make_trendline_spec(args, constructor):
+ trace_spec = TraceSpec(
+ constructor=go.Scattergl
+ if constructor == go.Scattergl # could be contour
+ else go.Scatter,
+ attrs=["trendline"],
+ trace_patch=dict(mode="lines"),
+ marginal=None,
+ )
+ if args["trendline_color_override"]:
+ trace_spec.trace_patch["line"] = dict(color=args["trendline_color_override"])
+ return trace_spec
+
+
def one_group(x):
return ""
@@ -1827,6 +1823,16 @@ def infer_config(args, constructor, trace_patch, layout_patch):
):
args["facet_col_wrap"] = 0
+ if "trendline" in args and args["trendline"] is not None:
+ if args["trendline"] not in trendline_functions:
+ raise ValueError(
+ "Value '%s' for `trendline` must be one of %s"
+ % (args["trendline"], trendline_functions.keys())
+ )
+
+ if "trendline_options" in args and args["trendline_options"] is None:
+ args["trendline_options"] = dict()
+
# Compute applicable grouping attributes
for k in group_attrables:
if k in args:
@@ -2126,6 +2132,27 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
fig.update_layout(template=args["template"], overwrite=True)
fig.frames = frame_list if len(frames) > 1 else []
+ if args.get("trendline") and args.get("trendline_scope", "trace") == "overall":
+ trendline_spec = make_trendline_spec(args, constructor)
+ trendline_trace = trendline_spec.constructor(
+ name="Overall Trendline", legendgroup="Overall Trendline", showlegend=False
+ )
+ if "line" not in trendline_spec.trace_patch: # no color override
+ for m in grouped_mappings:
+ if m.variable == "color":
+ next_color = m.sequence[len(m.val_map) % len(m.sequence)]
+ trendline_spec.trace_patch["line"] = dict(color=next_color)
+ patch, fit_results = make_trace_kwargs(
+ args, trendline_spec, args["data_frame"], {}, sizeref
+ )
+ trendline_trace.update(patch)
+ fig.add_trace(
+ trendline_trace, row="all", col="all", exclude_empty_subplots=True
+ )
+ fig.update_traces(selector=-1, showlegend=True)
+ if fit_results is not None:
+ trendline_rows.append(dict(px_fit_results=fit_results))
+
fig._px_trendlines = pd.DataFrame(trendline_rows)
configure_axes(args, constructor, fig, orders)
diff --git a/packages/python/plotly/plotly/express/_doc.py b/packages/python/plotly/plotly/express/_doc.py
index f2f2ab0544d..65d9f0588ff 100644
--- a/packages/python/plotly/plotly/express/_doc.py
+++ b/packages/python/plotly/plotly/express/_doc.py
@@ -402,14 +402,28 @@
],
trendline=[
"str",
- "One of `'ols'` or `'lowess'`.",
+ "One of `'ols'`, `'lowess'`, `'rolling'`, `'expanding'` or `'ewm'`.",
"If `'ols'`, an Ordinary Least Squares regression line will be drawn for each discrete-color/symbol group.",
"If `'lowess`', a Locally Weighted Scatterplot Smoothing line will be drawn for each discrete-color/symbol group.",
+ "If `'rolling`', a Rolling (e.g. rolling average, rolling median) line will be drawn for each discrete-color/symbol group.",
+ "If `'expanding`', an Expanding (e.g. expanding average, expanding sum) line will be drawn for each discrete-color/symbol group.",
+ "If `'ewm`', an Exponentially Weighted Moment (e.g. exponentially-weighted moving average) line will be drawn for each discrete-color/symbol group.",
+ "See the docstrings for the functions in `plotly.express.trendline_functions` for more details on these functions and how",
+ "to configure them with the `trendline_options` argument.",
+ ],
+ trendline_options=[
+ "dict",
+ "Options passed as the first argument to the function from `plotly.express.trendline_functions` ",
+ "named in the `trendline` argument.",
],
trendline_color_override=[
"str",
"Valid CSS color.",
- "If provided, and if `trendline` is set, all trendlines will be drawn in this color.",
+ "If provided, and if `trendline` is set, all trendlines will be drawn in this color rather than in the same color as the traces from which they draw their inputs.",
+ ],
+ trendline_scope=[
+ "str (one of `'trace'` or `'overall'`, default `'trace'`)",
+ "If `'trace'`, then one trendline is drawn per trace (i.e. per color, symbol, facet, animation frame etc) and if `'overall'` then one trendline is computed for the entire dataset, and replicated across all facets.",
],
render_mode=[
"str",
diff --git a/packages/python/plotly/plotly/express/trendline_functions/__init__.py b/packages/python/plotly/plotly/express/trendline_functions/__init__.py
new file mode 100644
index 00000000000..f0fc29cee4b
--- /dev/null
+++ b/packages/python/plotly/plotly/express/trendline_functions/__init__.py
@@ -0,0 +1,155 @@
+"""
+The `trendline_functions` module contains functions which are called by Plotly Express
+when the `trendline` argument is used. Valid values for `trendline` are the names of the
+functions in this module, and the value of the `trendline_options` argument to PX
+functions is passed in as the first argument to these functions when called.
+
+Note that the functions in this module are not meant to be called directly, and are
+exposed as part of the public API for documentation purposes.
+"""
+
+import pandas as pd
+import numpy as np
+
+
+def ols(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
+ """Ordinary Least Squares (OLS) trendline function
+
+ Requires `statsmodels` to be installed.
+
+ This trendline function causes fit results to be stored within the figure,
+ accessible via the `plotly.express.get_trendline_results` function. The fit results
+ are the output of the `statsmodels.api.OLS` function.
+
+ Valid keys for the `trendline_options` dict are:
+
+ - `add_constant` (`bool`, default `True`): if `False`, the trendline passes through
+ the origin but if `True` a y-intercept is fitted.
+
+ - `log_x` and `log_y` (`bool`, default `False`): if `True` the OLS is computed with
+ respect to the base 10 logarithm of the input. Note that this means no zeros can
+ be present in the input.
+ """
+ valid_options = ["add_constant", "log_x", "log_y"]
+ for k in trendline_options.keys():
+ if k not in valid_options:
+ raise ValueError(
+ "OLS trendline_options keys must be one of [%s] but got '%s'"
+ % (", ".join(valid_options), k)
+ )
+
+ import statsmodels.api as sm
+
+ add_constant = trendline_options.get("add_constant", True)
+ log_x = trendline_options.get("log_x", False)
+ log_y = trendline_options.get("log_y", False)
+
+ if log_y:
+ if np.any(y <= 0):
+ raise ValueError(
+ "Can't do OLS trendline with `log_y=True` when `y` contains non-positive values."
+ )
+ y = np.log10(y)
+ y_label = "log10(%s)" % y_label
+ if log_x:
+ if np.any(x <= 0):
+ raise ValueError(
+ "Can't do OLS trendline with `log_x=True` when `x` contains non-positive values."
+ )
+ x = np.log10(x)
+ x_label = "log10(%s)" % x_label
+ if add_constant:
+ x = sm.add_constant(x)
+ fit_results = sm.OLS(y, x, missing="drop").fit()
+ y_out = fit_results.predict()
+ if log_y:
+ y_out = np.power(10, y_out)
+ hover_header = "OLS trendline
"
+ if len(fit_results.params) == 2:
+ hover_header += "%s = %g * %s + %g
" % (
+ y_label,
+ fit_results.params[1],
+ x_label,
+ fit_results.params[0],
+ )
+ elif not add_constant:
+ hover_header += "%s = %g * %s
" % (y_label, fit_results.params[0], x_label)
+ else:
+ hover_header += "%s = %g
" % (y_label, fit_results.params[0])
+ hover_header += "R2=%f
" % fit_results.rsquared
+ return y_out, hover_header, fit_results
+
+
+def lowess(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
+ """LOcally WEighted Scatterplot Smoothing (LOWESS) trendline function
+
+ Requires `statsmodels` to be installed.
+
+ Valid keys for the `trendline_options` dict are:
+
+ - `frac` (`float`, default `0.6666666`): the `frac` parameter from the
+ `statsmodels.api.nonparametric.lowess` function
+ """
+
+ valid_options = ["frac"]
+ for k in trendline_options.keys():
+ if k not in valid_options:
+ raise ValueError(
+ "LOWESS trendline_options keys must be one of [%s] but got '%s'"
+ % (", ".join(valid_options), k)
+ )
+
+ import statsmodels.api as sm
+
+ frac = trendline_options.get("frac", 0.6666666)
+ y_out = sm.nonparametric.lowess(y, x, missing="drop", frac=frac)[:, 1]
+ hover_header = "LOWESS trendline
"
+ return y_out, hover_header, None
+
+
+def _pandas(mode, trendline_options, x_raw, y, non_missing):
+ modes = dict(rolling="Rolling", ewm="Exponentially Weighted", expanding="Expanding")
+ trendline_options = trendline_options.copy()
+ function_name = trendline_options.pop("function", "mean")
+ function_args = trendline_options.pop("function_args", dict())
+ series = pd.Series(y, index=x_raw)
+ agg = getattr(series, mode) # e.g. series.rolling
+ agg_obj = agg(**trendline_options) # e.g. series.rolling(**opts)
+ function = getattr(agg_obj, function_name) # e.g. series.rolling(**opts).mean
+ y_out = function(**function_args) # e.g. series.rolling(**opts).mean(**opts)
+ y_out = y_out[non_missing]
+ hover_header = "%s %s trendline
" % (modes[mode], function_name)
+ return y_out, hover_header, None
+
+
+def rolling(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
+ """Rolling trendline function
+
+ The value of the `function` key of the `trendline_options` dict is the function to
+ use (defaults to `mean`) and the value of the `function_args` key are taken to be
+ its arguments as a dict. The remainder of the `trendline_options` dict is passed as
+ keyword arguments into the `pandas.Series.rolling` function.
+ """
+ return _pandas("rolling", trendline_options, x_raw, y, non_missing)
+
+
+def expanding(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
+ """Expanding trendline function
+
+ The value of the `function` key of the `trendline_options` dict is the function to
+ use (defaults to `mean`) and the value of the `function_args` key are taken to be
+ its arguments as a dict. The remainder of the `trendline_options` dict is passed as
+ keyword arguments into the `pandas.Series.expanding` function.
+ """
+ return _pandas("expanding", trendline_options, x_raw, y, non_missing)
+
+
+def ewm(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
+ """Exponentially Weighted Moment (EWM) trendline function
+
+ The value of the `function` key of the `trendline_options` dict is the function to
+ use (defaults to `mean`) and the value of the `function_args` key are taken to be
+ its arguments as a dict. The remainder of the `trendline_options` dict is passed as
+ keyword arguments into the `pandas.Series.ewm` function.
+ """
+ return _pandas("ewm", trendline_options, x_raw, y, non_missing)
diff --git a/packages/python/plotly/plotly/tests/test_optional/test_px/test_trendline.py b/packages/python/plotly/plotly/tests/test_optional/test_px/test_trendline.py
index 41064bd19df..66046981eff 100644
--- a/packages/python/plotly/plotly/tests/test_optional/test_px/test_trendline.py
+++ b/packages/python/plotly/plotly/tests/test_optional/test_px/test_trendline.py
@@ -5,10 +5,27 @@
from datetime import datetime
-@pytest.mark.parametrize("mode", ["ols", "lowess"])
-def test_trendline_results_passthrough(mode):
+@pytest.mark.parametrize(
+ "mode,options",
+ [
+ ("ols", None),
+ ("lowess", None),
+ ("lowess", dict(frac=0.3)),
+ ("rolling", dict(window=2)),
+ ("expanding", None),
+ ("ewm", dict(alpha=0.5)),
+ ],
+)
+def test_trendline_results_passthrough(mode, options):
df = px.data.gapminder().query("continent == 'Oceania'")
- fig = px.scatter(df, x="year", y="pop", color="country", trendline=mode)
+ fig = px.scatter(
+ df,
+ x="year",
+ y="pop",
+ color="country",
+ trendline=mode,
+ trendline_options=options,
+ )
assert len(fig.data) == 4
for trace in fig["data"][0::2]:
assert "trendline" not in trace.hovertemplate
@@ -20,93 +37,205 @@ def test_trendline_results_passthrough(mode):
if mode == "ols":
assert len(results) == 2
assert results["country"].values[0] == "Australia"
- assert results["country"].values[0] == "Australia"
au_result = results["px_fit_results"].values[0]
assert len(au_result.params) == 2
else:
assert len(results) == 0
-@pytest.mark.parametrize("mode", ["ols", "lowess"])
-def test_trendline_enough_values(mode):
- fig = px.scatter(x=[0, 1], y=[0, 1], trendline=mode)
+@pytest.mark.parametrize(
+ "mode,options",
+ [
+ ("ols", None),
+ ("lowess", None),
+ ("lowess", dict(frac=0.3)),
+ ("rolling", dict(window=2)),
+ ("expanding", None),
+ ("ewm", dict(alpha=0.5)),
+ ],
+)
+def test_trendline_enough_values(mode, options):
+ fig = px.scatter(x=[0, 1], y=[0, 1], trendline=mode, trendline_options=options)
assert len(fig.data) == 2
assert len(fig.data[1].x) == 2
- fig = px.scatter(x=[0], y=[0], trendline=mode)
+ fig = px.scatter(x=[0], y=[0], trendline=mode, trendline_options=options)
assert len(fig.data) == 2
assert fig.data[1].x is None
- fig = px.scatter(x=[0, 1], y=[0, None], trendline=mode)
+ fig = px.scatter(x=[0, 1], y=[0, None], trendline=mode, trendline_options=options)
assert len(fig.data) == 2
assert fig.data[1].x is None
- fig = px.scatter(x=[0, 1], y=np.array([0, np.nan]), trendline=mode)
+ fig = px.scatter(
+ x=[0, 1], y=np.array([0, np.nan]), trendline=mode, trendline_options=options
+ )
assert len(fig.data) == 2
assert fig.data[1].x is None
- fig = px.scatter(x=[0, 1, None], y=[0, None, 1], trendline=mode)
+ fig = px.scatter(
+ x=[0, 1, None], y=[0, None, 1], trendline=mode, trendline_options=options
+ )
assert len(fig.data) == 2
assert fig.data[1].x is None
fig = px.scatter(
- x=np.array([0, 1, np.nan]), y=np.array([0, np.nan, 1]), trendline=mode
+ x=np.array([0, 1, np.nan]),
+ y=np.array([0, np.nan, 1]),
+ trendline=mode,
+ trendline_options=options,
)
assert len(fig.data) == 2
assert fig.data[1].x is None
- fig = px.scatter(x=[0, 1, None, 2], y=[1, None, 1, 2], trendline=mode)
+ fig = px.scatter(
+ x=[0, 1, None, 2], y=[1, None, 1, 2], trendline=mode, trendline_options=options
+ )
assert len(fig.data) == 2
assert len(fig.data[1].x) == 2
fig = px.scatter(
- x=np.array([0, 1, np.nan, 2]), y=np.array([1, np.nan, 1, 2]), trendline=mode
+ x=np.array([0, 1, np.nan, 2]),
+ y=np.array([1, np.nan, 1, 2]),
+ trendline=mode,
+ trendline_options=options,
)
assert len(fig.data) == 2
assert len(fig.data[1].x) == 2
-@pytest.mark.parametrize("mode", ["ols", "lowess"])
-def test_trendline_nan_values(mode):
+@pytest.mark.parametrize(
+ "mode,options",
+ [
+ ("ols", None),
+ ("ols", dict(add_constant=False, log_x=True, log_y=True)),
+ ("lowess", None),
+ ("lowess", dict(frac=0.3)),
+ ("rolling", dict(window=2)),
+ ("expanding", None),
+ ("ewm", dict(alpha=0.5)),
+ ],
+)
+def test_trendline_nan_values(mode, options):
df = px.data.gapminder().query("continent == 'Oceania'")
start_date = 1970
df["pop"][df["year"] < start_date] = np.nan
- fig = px.scatter(df, x="year", y="pop", color="country", trendline=mode)
+ fig = px.scatter(
+ df,
+ x="year",
+ y="pop",
+ color="country",
+ trendline=mode,
+ trendline_options=options,
+ )
for trendline in fig["data"][1::2]:
assert trendline.x[0] >= start_date
assert len(trendline.x) == len(trendline.y)
-def test_no_slope_ols_trendline():
+def test_ols_trendline_slopes():
fig = px.scatter(x=[0, 1], y=[0, 1], trendline="ols")
- assert "y = 1" in fig.data[1].hovertemplate # then + x*(some small number)
+ # should be "y = 1 * x + 0" but sometimes is some tiny number instead
+ assert "y = 1 * x + " in fig.data[1].hovertemplate
results = px.get_trendline_results(fig)
params = results["px_fit_results"].iloc[0].params
assert np.all(np.isclose(params, [0, 1]))
+ fig = px.scatter(x=[0, 1], y=[1, 2], trendline="ols")
+ assert "y = 1 * x + 1
" in fig.data[1].hovertemplate
+ results = px.get_trendline_results(fig)
+ params = results["px_fit_results"].iloc[0].params
+ assert np.all(np.isclose(params, [1, 1]))
+
+ fig = px.scatter(
+ x=[0, 1], y=[1, 2], trendline="ols", trendline_options=dict(add_constant=False)
+ )
+ assert "y = 2 * x
" in fig.data[1].hovertemplate
+ results = px.get_trendline_results(fig)
+ params = results["px_fit_results"].iloc[0].params
+ assert np.all(np.isclose(params, [2]))
+
+ fig = px.scatter(
+ x=[1, 1], y=[0, 0], trendline="ols", trendline_options=dict(add_constant=False)
+ )
+ assert "y = 0 * x
" in fig.data[1].hovertemplate
+ results = px.get_trendline_results(fig)
+ params = results["px_fit_results"].iloc[0].params
+ assert np.all(np.isclose(params, [0]))
+
fig = px.scatter(x=[1, 1], y=[0, 0], trendline="ols")
- assert "y = 0" in fig.data[1].hovertemplate
+ assert "y = 0
" in fig.data[1].hovertemplate
results = px.get_trendline_results(fig)
params = results["px_fit_results"].iloc[0].params
assert np.all(np.isclose(params, [0]))
fig = px.scatter(x=[1, 2], y=[0, 0], trendline="ols")
- assert "y = 0" in fig.data[1].hovertemplate
+ assert "y = 0 * x + 0
" in fig.data[1].hovertemplate
fig = px.scatter(x=[0, 0], y=[1, 1], trendline="ols")
- assert "y = 0 * x + 1" in fig.data[1].hovertemplate
+ assert "y = 0 * x + 1
" in fig.data[1].hovertemplate
fig = px.scatter(x=[0, 0], y=[1, 2], trendline="ols")
- assert "y = 0 * x + 1.5" in fig.data[1].hovertemplate
+ assert "y = 0 * x + 1.5
" in fig.data[1].hovertemplate
-@pytest.mark.parametrize("mode", ["ols", "lowess"])
-def test_trendline_on_timeseries(mode):
+@pytest.mark.parametrize(
+ "mode,options",
+ [
+ ("ols", None),
+ ("lowess", None),
+ ("lowess", dict(frac=0.3)),
+ ("rolling", dict(window=2)),
+ ("rolling", dict(window="10d")),
+ ("expanding", None),
+ ("ewm", dict(alpha=0.5)),
+ ],
+)
+def test_trendline_on_timeseries(mode, options):
df = px.data.stocks()
with pytest.raises(ValueError) as err_msg:
- px.scatter(df, x="date", y="GOOG", trendline=mode)
+ px.scatter(df, x="date", y="GOOG", trendline=mode, trendline_options=options)
assert "Could not convert value of 'x' ('date') into a numeric type." in str(
err_msg.value
)
df["date"] = pd.to_datetime(df["date"])
df["date"] = df["date"].dt.tz_localize("CET") # force a timezone
- fig = px.scatter(df, x="date", y="GOOG", trendline=mode)
+ fig = px.scatter(df, x="date", y="GOOG", trendline=mode, trendline_options=options)
assert len(fig.data) == 2
assert len(fig.data[0].x) == len(fig.data[1].x)
assert type(fig.data[0].x[0]) == datetime
assert type(fig.data[1].x[0]) == datetime
assert np.all(fig.data[0].x == fig.data[1].x)
assert str(fig.data[0].x[0]) == str(fig.data[1].x[0])
+
+
+def test_overall_trendline():
+ df = px.data.tips()
+ fig1 = px.scatter(df, x="total_bill", y="tip", trendline="ols")
+ assert len(fig1.data) == 2
+ assert "trendline" in fig1.data[1].hovertemplate
+ results1 = px.get_trendline_results(fig1)
+ params1 = results1["px_fit_results"].iloc[0].params
+
+ fig2 = px.scatter(
+ df,
+ x="total_bill",
+ y="tip",
+ color="sex",
+ trendline="ols",
+ trendline_scope="overall",
+ )
+ assert len(fig2.data) == 3
+ assert "trendline" in fig2.data[2].hovertemplate
+ results2 = px.get_trendline_results(fig2)
+ params2 = results2["px_fit_results"].iloc[0].params
+
+ assert np.all(np.array_equal(params1, params2))
+
+ fig3 = px.scatter(
+ df,
+ x="total_bill",
+ y="tip",
+ facet_row="sex",
+ trendline="ols",
+ trendline_scope="overall",
+ )
+ assert len(fig3.data) == 4
+ assert "trendline" in fig3.data[3].hovertemplate
+ results3 = px.get_trendline_results(fig3)
+ params3 = results3["px_fit_results"].iloc[0].params
+
+ assert np.all(np.array_equal(params1, params3))