From 44d437d73791756f5adc6c30c3c6b845167ed0e6 Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Mon, 26 Oct 2020 15:16:58 -0400 Subject: [PATCH 01/24] ma and ewm trendlines --- .../python/plotly/plotly/express/_core.py | 43 +++++++++++++++---- 1 file changed, 35 insertions(+), 8 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index f8e391053b9..6cf453ef430 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -313,7 +313,7 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref): mapping_labels["count"] = "%{x}" elif attr_name == "trendline": if ( - attr_value in ["ols", "lowess"] + attr_value[0] in ["ols", "lowess", "ma", "ewm"] and args["x"] and args["y"] and len(trace_data[[args["x"], args["y"]]].dropna()) > 1 @@ -345,19 +345,36 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref): ) # preserve original values of "x" in case they're dates - trace_patch["x"] = sorted_trace_data[args["x"]][ - np.logical_not(np.logical_or(np.isnan(y), np.isnan(x))) - ] + non_missing = np.logical_not( + np.logical_or(np.isnan(y), np.isnan(x)) + ) + trace_patch["x"] = sorted_trace_data[args["x"]][non_missing] - if attr_value == "lowess": + if attr_value[0] == "lowess": + alpha = attr_value[1] or 0.6666666 # missing ='drop' is the default value for lowess but not for OLS (None) # we force it here in case statsmodels change their defaults - trendline = sm.nonparametric.lowess(y, x, missing="drop") + trendline = sm.nonparametric.lowess( + y, x, missing="drop", frac=alpha + ) trace_patch["y"] = trendline[:, 1] hover_header = "LOWESS trendline

" - elif attr_value == "ols": + elif attr_value[0] == "ma": + trace_patch["y"] = ( + pd.Series(y[non_missing]) + .rolling(window=attr_value[1] or 3) + .mean() + ) + elif attr_value[0] == "ewm": + trace_patch["y"] = ( + pd.Series(y[non_missing]) + .ewm(alpha=attr_value[1] or 0.5) + .mean() + ) + elif attr_value[0] == "ols": + add_constant = attr_value[1] is not False fit_results = sm.OLS( - y, sm.add_constant(x), missing="drop" + y, sm.add_constant(x) if add_constant else x, missing="drop" ).fit() trace_patch["y"] = fit_results.predict() hover_header = "OLS trendline
" @@ -368,6 +385,12 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref): args["x"], fit_results.params[0], ) + elif not add_constant: + hover_header += "%s = %g* %s
" % ( + args["y"], + fit_results.params[0], + args["x"], + ) else: hover_header += "%s = %g
" % ( args["y"], @@ -1827,6 +1850,10 @@ def infer_config(args, constructor, trace_patch, layout_patch): ): args["facet_col_wrap"] = 0 + if args.get("trendline", None) is not None: + if isinstance(args["trendline"], str): + args["trendline"] = (args["trendline"], None) + # Compute applicable grouping attributes for k in group_attrables: if k in args: From 10199859135384a50b4c62988b8cd0630d45162f Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Fri, 25 Dec 2020 10:43:13 -0500 Subject: [PATCH 02/24] extract trendline function API --- .../plotly/plotly/express/_chart_types.py | 2 + .../python/plotly/plotly/express/_core.py | 121 ++++++++++-------- packages/python/plotly/plotly/express/_doc.py | 4 + 3 files changed, 72 insertions(+), 55 deletions(-) diff --git a/packages/python/plotly/plotly/express/_chart_types.py b/packages/python/plotly/plotly/express/_chart_types.py index 6cfb6a90367..5051c8a5367 100644 --- a/packages/python/plotly/plotly/express/_chart_types.py +++ b/packages/python/plotly/plotly/express/_chart_types.py @@ -46,6 +46,7 @@ def scatter( marginal_x=None, marginal_y=None, trendline=None, + trendline_options=None, trendline_color_override=None, log_x=False, log_y=False, @@ -90,6 +91,7 @@ def density_contour( marginal_x=None, marginal_y=None, trendline=None, + trendline_options=None, trendline_color_override=None, log_x=False, log_y=False, diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 6cf453ef430..8567afb909e 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -239,6 +239,56 @@ def make_mapping(args, variable): ) +def lowess(options, x, y, x_label, y_label, non_missing): + import statsmodels.api as sm + + frac = options.get("frac", 0.6666666) + # missing ='drop' is the default value for lowess but not for OLS (None) + # we force it here in case statsmodels change their defaults + y_out = sm.nonparametric.lowess(y, x, missing="drop", frac=frac)[:, 1] + hover_header = "LOWESS trendline

" + return y_out, hover_header, None + + +def ma(options, x, y, x_label, y_label, non_missing): + y_out = pd.Series(y, index=x).rolling(**options).mean()[non_missing] + hover_header = "Moving Average trendline

" + return y_out, hover_header, None + + +def ewm(options, x, y, x_label, y_label, non_missing): + y_out = pd.Series(y, index=x).ewm(**options).mean()[non_missing] + hover_header = "EWM trendline

" + return y_out, hover_header, None + + +def ols(options, x, y, x_label, y_label, non_missing): + import statsmodels.api as sm + + add_constant = options.get("add_constant", True) + fit_results = sm.OLS( + y, sm.add_constant(x) if add_constant else x, missing="drop" + ).fit() + y_out = fit_results.predict() + hover_header = "OLS trendline
" + if len(fit_results.params) == 2: + hover_header += "%s = %g * %s + %g
" % ( + y_label, + fit_results.params[1], + x_label, + fit_results.params[0], + ) + elif not add_constant: + hover_header += "%s = %g* %s
" % (y_label, fit_results.params[0], x_label,) + else: + hover_header += "%s = %g
" % (y_label, fit_results.params[0],) + hover_header += "R2=%f

" % fit_results.rsquared + return y_out, hover_header, fit_results + + +trendline_functions = dict(lowess=lowess, ma=ma, ewm=ewm, ols=ols) + + def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref): """Populates a dict with arguments to update trace @@ -313,12 +363,11 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref): mapping_labels["count"] = "%{x}" elif attr_name == "trendline": if ( - attr_value[0] in ["ols", "lowess", "ma", "ewm"] + attr_value in trendline_functions and args["x"] and args["y"] and len(trace_data[[args["x"], args["y"]]].dropna()) > 1 ): - import statsmodels.api as sm # sorting is bad but trace_specs with "trendline" have no other attrs sorted_trace_data = trace_data.sort_values(by=args["x"]) @@ -349,56 +398,19 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref): np.logical_or(np.isnan(y), np.isnan(x)) ) trace_patch["x"] = sorted_trace_data[args["x"]][non_missing] - - if attr_value[0] == "lowess": - alpha = attr_value[1] or 0.6666666 - # missing ='drop' is the default value for lowess but not for OLS (None) - # we force it here in case statsmodels change their defaults - trendline = sm.nonparametric.lowess( - y, x, missing="drop", frac=alpha - ) - trace_patch["y"] = trendline[:, 1] - hover_header = "LOWESS trendline

" - elif attr_value[0] == "ma": - trace_patch["y"] = ( - pd.Series(y[non_missing]) - .rolling(window=attr_value[1] or 3) - .mean() - ) - elif attr_value[0] == "ewm": - trace_patch["y"] = ( - pd.Series(y[non_missing]) - .ewm(alpha=attr_value[1] or 0.5) - .mean() - ) - elif attr_value[0] == "ols": - add_constant = attr_value[1] is not False - fit_results = sm.OLS( - y, sm.add_constant(x) if add_constant else x, missing="drop" - ).fit() - trace_patch["y"] = fit_results.predict() - hover_header = "OLS trendline
" - if len(fit_results.params) == 2: - hover_header += "%s = %g * %s + %g
" % ( - args["y"], - fit_results.params[1], - args["x"], - fit_results.params[0], - ) - elif not add_constant: - hover_header += "%s = %g* %s
" % ( - args["y"], - fit_results.params[0], - args["x"], - ) - else: - hover_header += "%s = %g
" % ( - args["y"], - fit_results.params[0], - ) - hover_header += ( - "R2=%f

" % fit_results.rsquared - ) + trendline_function = trendline_functions[attr_value] + y_out, hover_header, fit_results = trendline_function( + args["trendline_options"], + x, + y, + args["x"], + args["y"], + non_missing, + ) + assert len(y_out) == len( + trace_patch["x"] + ), "missing-data-handling failure in trendline code" + trace_patch["y"] = y_out mapping_labels[get_label(args, args["x"])] = "%{x}" mapping_labels[get_label(args, args["y"])] = "%{y} (trend)" elif attr_name.startswith("error"): @@ -1850,9 +1862,8 @@ def infer_config(args, constructor, trace_patch, layout_patch): ): args["facet_col_wrap"] = 0 - if args.get("trendline", None) is not None: - if isinstance(args["trendline"], str): - args["trendline"] = (args["trendline"], None) + if "trendline_options" in args and args["trendline_options"] is None: + args["trendline_options"] = dict() # Compute applicable grouping attributes for k in group_attrables: diff --git a/packages/python/plotly/plotly/express/_doc.py b/packages/python/plotly/plotly/express/_doc.py index f2f2ab0544d..1a218a9edbb 100644 --- a/packages/python/plotly/plotly/express/_doc.py +++ b/packages/python/plotly/plotly/express/_doc.py @@ -406,6 +406,10 @@ "If `'ols'`, an Ordinary Least Squares regression line will be drawn for each discrete-color/symbol group.", "If `'lowess`', a Locally Weighted Scatterplot Smoothing line will be drawn for each discrete-color/symbol group.", ], + trendline_options=[ + "dict", + "Options passed to the function named in the `trendline` argument.", + ], trendline_color_override=[ "str", "Valid CSS color.", From 560fad72594a101f0b9bc3e07ea21895aa541853 Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Fri, 25 Dec 2020 10:43:28 -0500 Subject: [PATCH 03/24] ols log options checkpoint --- packages/python/plotly/plotly/express/_core.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 8567afb909e..d46c0354d4b 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -266,10 +266,19 @@ def ols(options, x, y, x_label, y_label, non_missing): import statsmodels.api as sm add_constant = options.get("add_constant", True) - fit_results = sm.OLS( - y, sm.add_constant(x) if add_constant else x, missing="drop" - ).fit() + log_x = options.get("log_x", False) + log_y = options.get("log_y", False) + + if log_y: + y = np.log(y) + if log_x: + x = np.log(x) + if add_constant: + x = sm.add_constant(x) + fit_results = sm.OLS(y, x, missing="drop").fit() y_out = fit_results.predict() + if log_y: + y_out = np.exp(y_out) hover_header = "OLS trendline
" if len(fit_results.params) == 2: hover_header += "%s = %g * %s + %g
" % ( From 50572d51a176a4cde2dd591654e748cab3e7f5ad Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Tue, 29 Dec 2020 13:30:32 -0500 Subject: [PATCH 04/24] move trendline code to own module --- doc/apidoc/plotly.express.rst | 1 + .../python/plotly/plotly/express/__init__.py | 2 +- .../python/plotly/plotly/express/_core.py | 61 +------------------ .../plotly/express/trendline_functions.py | 58 ++++++++++++++++++ 4 files changed, 62 insertions(+), 60 deletions(-) create mode 100644 packages/python/plotly/plotly/express/trendline_functions.py diff --git a/doc/apidoc/plotly.express.rst b/doc/apidoc/plotly.express.rst index cd252158cb4..36550a0a0fb 100644 --- a/doc/apidoc/plotly.express.rst +++ b/doc/apidoc/plotly.express.rst @@ -60,3 +60,4 @@ plotly's high-level API for rapid figure generation. :: generated/plotly.express.data.rst generated/plotly.express.colors.rst + generated/plotly.express.trendline_functions.rst diff --git a/packages/python/plotly/plotly/express/__init__.py b/packages/python/plotly/plotly/express/__init__.py index 140a0fbe814..f0c2d4e154c 100644 --- a/packages/python/plotly/plotly/express/__init__.py +++ b/packages/python/plotly/plotly/express/__init__.py @@ -60,7 +60,7 @@ from ._special_inputs import IdentityMap, Constant, Range # noqa: F401 -from . import data, colors # noqa: F401 +from . import data, colors, trendline_functions # noqa: F401 __all__ = [ "scatter", diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index d46c0354d4b..d0b00b5437d 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -2,6 +2,7 @@ import plotly.io as pio from collections import namedtuple, OrderedDict from ._special_inputs import IdentityMap, Constant, Range +from .trendline_functions import ols, lowess, ma, ewm from _plotly_utils.basevalidators import ColorscaleValidator from plotly.colors import qualitative, sequential @@ -239,65 +240,6 @@ def make_mapping(args, variable): ) -def lowess(options, x, y, x_label, y_label, non_missing): - import statsmodels.api as sm - - frac = options.get("frac", 0.6666666) - # missing ='drop' is the default value for lowess but not for OLS (None) - # we force it here in case statsmodels change their defaults - y_out = sm.nonparametric.lowess(y, x, missing="drop", frac=frac)[:, 1] - hover_header = "LOWESS trendline

" - return y_out, hover_header, None - - -def ma(options, x, y, x_label, y_label, non_missing): - y_out = pd.Series(y, index=x).rolling(**options).mean()[non_missing] - hover_header = "Moving Average trendline

" - return y_out, hover_header, None - - -def ewm(options, x, y, x_label, y_label, non_missing): - y_out = pd.Series(y, index=x).ewm(**options).mean()[non_missing] - hover_header = "EWM trendline

" - return y_out, hover_header, None - - -def ols(options, x, y, x_label, y_label, non_missing): - import statsmodels.api as sm - - add_constant = options.get("add_constant", True) - log_x = options.get("log_x", False) - log_y = options.get("log_y", False) - - if log_y: - y = np.log(y) - if log_x: - x = np.log(x) - if add_constant: - x = sm.add_constant(x) - fit_results = sm.OLS(y, x, missing="drop").fit() - y_out = fit_results.predict() - if log_y: - y_out = np.exp(y_out) - hover_header = "OLS trendline
" - if len(fit_results.params) == 2: - hover_header += "%s = %g * %s + %g
" % ( - y_label, - fit_results.params[1], - x_label, - fit_results.params[0], - ) - elif not add_constant: - hover_header += "%s = %g* %s
" % (y_label, fit_results.params[0], x_label,) - else: - hover_header += "%s = %g
" % (y_label, fit_results.params[0],) - hover_header += "R2=%f

" % fit_results.rsquared - return y_out, hover_header, fit_results - - -trendline_functions = dict(lowess=lowess, ma=ma, ewm=ewm, ols=ols) - - def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref): """Populates a dict with arguments to update trace @@ -371,6 +313,7 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref): if trace_spec.constructor == go.Histogram: mapping_labels["count"] = "%{x}" elif attr_name == "trendline": + trendline_functions = dict(lowess=lowess, ma=ma, ewm=ewm, ols=ols) if ( attr_value in trendline_functions and args["x"] diff --git a/packages/python/plotly/plotly/express/trendline_functions.py b/packages/python/plotly/plotly/express/trendline_functions.py new file mode 100644 index 00000000000..611067ce8a8 --- /dev/null +++ b/packages/python/plotly/plotly/express/trendline_functions.py @@ -0,0 +1,58 @@ +import pandas as pd +import numpy as np + + +def ols(options, x, y, x_label, y_label, non_missing): + import statsmodels.api as sm + + add_constant = options.get("add_constant", True) + log_x = options.get("log_x", False) + log_y = options.get("log_y", False) + + if log_y: + y = np.log(y) + y_label = "log(%s)" % y_label + if log_x: + x = np.log(x) + x_label = "log(%s)" % x_label + if add_constant: + x = sm.add_constant(x) + fit_results = sm.OLS(y, x, missing="drop").fit() + y_out = fit_results.predict() + if log_y: + y_out = np.exp(y_out) + hover_header = "OLS trendline
" + if len(fit_results.params) == 2: + hover_header += "%s = %g * %s + %g
" % ( + y_label, + fit_results.params[1], + x_label, + fit_results.params[0], + ) + elif not add_constant: + hover_header += "%s = %g* %s
" % (y_label, fit_results.params[0], x_label,) + else: + hover_header += "%s = %g
" % (y_label, fit_results.params[0],) + hover_header += "R2=%f

" % fit_results.rsquared + return y_out, hover_header, fit_results + + +def lowess(options, x, y, x_label, y_label, non_missing): + import statsmodels.api as sm + + frac = options.get("frac", 0.6666666) + y_out = sm.nonparametric.lowess(y, x, missing="drop", frac=frac)[:, 1] + hover_header = "LOWESS trendline

" + return y_out, hover_header, None + + +def ma(options, x, y, x_label, y_label, non_missing): + y_out = pd.Series(y, index=x).rolling(**options).mean()[non_missing] + hover_header = "Moving Average trendline

" + return y_out, hover_header, None + + +def ewm(options, x, y, x_label, y_label, non_missing): + y_out = pd.Series(y, index=x).ewm(**options).mean()[non_missing] + hover_header = "EWM trendline

" + return y_out, hover_header, None From d1001c6d1d3a2a41cf9fbe3c80cc2559b0d3f730 Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Tue, 29 Dec 2020 19:35:09 -0500 Subject: [PATCH 05/24] get trendline_functions into apidoc --- packages/python/plotly/plotly/express/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/packages/python/plotly/plotly/express/__init__.py b/packages/python/plotly/plotly/express/__init__.py index f0c2d4e154c..8bc5da53910 100644 --- a/packages/python/plotly/plotly/express/__init__.py +++ b/packages/python/plotly/plotly/express/__init__.py @@ -100,6 +100,7 @@ "imshow", "data", "colors", + "trendline_functions", "set_mapbox_access_token", "get_trendline_results", "IdentityMap", From 2d3e8b0b9aec836b9c45b2adbfebaeea97933a54 Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Tue, 29 Dec 2020 21:20:10 -0500 Subject: [PATCH 06/24] tests for new trendlines --- .../python/plotly/plotly/express/_core.py | 5 +- .../plotly/express/trendline_functions.py | 18 +-- .../test_optional/test_px/test_trendline.py | 142 ++++++++++++++---- 3 files changed, 127 insertions(+), 38 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index d0b00b5437d..fd395d739ee 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -2,7 +2,7 @@ import plotly.io as pio from collections import namedtuple, OrderedDict from ._special_inputs import IdentityMap, Constant, Range -from .trendline_functions import ols, lowess, ma, ewm +from .trendline_functions import ols, lowess, ma, ewma from _plotly_utils.basevalidators import ColorscaleValidator from plotly.colors import qualitative, sequential @@ -313,7 +313,7 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref): if trace_spec.constructor == go.Histogram: mapping_labels["count"] = "%{x}" elif attr_name == "trendline": - trendline_functions = dict(lowess=lowess, ma=ma, ewm=ewm, ols=ols) + trendline_functions = dict(lowess=lowess, ma=ma, ewma=ewma, ols=ols) if ( attr_value in trendline_functions and args["x"] @@ -353,6 +353,7 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref): trendline_function = trendline_functions[attr_value] y_out, hover_header, fit_results = trendline_function( args["trendline_options"], + sorted_trace_data[args["x"]], x, y, args["x"], diff --git a/packages/python/plotly/plotly/express/trendline_functions.py b/packages/python/plotly/plotly/express/trendline_functions.py index 611067ce8a8..78d447353d6 100644 --- a/packages/python/plotly/plotly/express/trendline_functions.py +++ b/packages/python/plotly/plotly/express/trendline_functions.py @@ -2,7 +2,7 @@ import numpy as np -def ols(options, x, y, x_label, y_label, non_missing): +def ols(options, x_raw, x, y, x_label, y_label, non_missing): import statsmodels.api as sm add_constant = options.get("add_constant", True) @@ -30,14 +30,14 @@ def ols(options, x, y, x_label, y_label, non_missing): fit_results.params[0], ) elif not add_constant: - hover_header += "%s = %g* %s
" % (y_label, fit_results.params[0], x_label,) + hover_header += "%s = %g * %s
" % (y_label, fit_results.params[0], x_label,) else: hover_header += "%s = %g
" % (y_label, fit_results.params[0],) hover_header += "R2=%f

" % fit_results.rsquared return y_out, hover_header, fit_results -def lowess(options, x, y, x_label, y_label, non_missing): +def lowess(options, x_raw, x, y, x_label, y_label, non_missing): import statsmodels.api as sm frac = options.get("frac", 0.6666666) @@ -46,13 +46,13 @@ def lowess(options, x, y, x_label, y_label, non_missing): return y_out, hover_header, None -def ma(options, x, y, x_label, y_label, non_missing): - y_out = pd.Series(y, index=x).rolling(**options).mean()[non_missing] - hover_header = "Moving Average trendline

" +def ma(options, x_raw, x, y, x_label, y_label, non_missing): + y_out = pd.Series(y, index=x_raw).rolling(**options).mean()[non_missing] + hover_header = "MA trendline

" return y_out, hover_header, None -def ewm(options, x, y, x_label, y_label, non_missing): - y_out = pd.Series(y, index=x).ewm(**options).mean()[non_missing] - hover_header = "EWM trendline

" +def ewma(options, x_raw, x, y, x_label, y_label, non_missing): + y_out = pd.Series(y, index=x_raw).ewm(**options).mean()[non_missing] + hover_header = "EWMA trendline

" return y_out, hover_header, None diff --git a/packages/python/plotly/plotly/tests/test_optional/test_px/test_trendline.py b/packages/python/plotly/plotly/tests/test_optional/test_px/test_trendline.py index 41064bd19df..19d4db4efdf 100644 --- a/packages/python/plotly/plotly/tests/test_optional/test_px/test_trendline.py +++ b/packages/python/plotly/plotly/tests/test_optional/test_px/test_trendline.py @@ -5,10 +5,27 @@ from datetime import datetime -@pytest.mark.parametrize("mode", ["ols", "lowess"]) -def test_trendline_results_passthrough(mode): +@pytest.mark.parametrize( + "mode,options", + [ + ("ols", None), + ("ols", dict(log_x=True, log_y=True)), + ("lowess", None), + ("lowess", dict(frac=0.3)), + ("ma", dict(window=2)), + ("ewma", dict(alpha=0.5)), + ], +) +def test_trendline_results_passthrough(mode, options): df = px.data.gapminder().query("continent == 'Oceania'") - fig = px.scatter(df, x="year", y="pop", color="country", trendline=mode) + fig = px.scatter( + df, + x="year", + y="pop", + color="country", + trendline=mode, + trendline_options=options, + ) assert len(fig.data) == 4 for trace in fig["data"][0::2]: assert "trendline" not in trace.hovertemplate @@ -20,90 +37,161 @@ def test_trendline_results_passthrough(mode): if mode == "ols": assert len(results) == 2 assert results["country"].values[0] == "Australia" - assert results["country"].values[0] == "Australia" au_result = results["px_fit_results"].values[0] assert len(au_result.params) == 2 else: assert len(results) == 0 -@pytest.mark.parametrize("mode", ["ols", "lowess"]) -def test_trendline_enough_values(mode): - fig = px.scatter(x=[0, 1], y=[0, 1], trendline=mode) +@pytest.mark.parametrize( + "mode,options", + [ + ("ols", None), + ("ols", dict(add_constant=False, log_x=True, log_y=True)), + ("lowess", None), + ("lowess", dict(frac=0.3)), + ("ma", dict(window=2)), + ("ewma", dict(alpha=0.5)), + ], +) +def test_trendline_enough_values(mode, options): + fig = px.scatter(x=[0, 1], y=[0, 1], trendline=mode, trendline_options=options) assert len(fig.data) == 2 assert len(fig.data[1].x) == 2 - fig = px.scatter(x=[0], y=[0], trendline=mode) + fig = px.scatter(x=[0], y=[0], trendline=mode, trendline_options=options) assert len(fig.data) == 2 assert fig.data[1].x is None - fig = px.scatter(x=[0, 1], y=[0, None], trendline=mode) + fig = px.scatter(x=[0, 1], y=[0, None], trendline=mode, trendline_options=options) assert len(fig.data) == 2 assert fig.data[1].x is None - fig = px.scatter(x=[0, 1], y=np.array([0, np.nan]), trendline=mode) + fig = px.scatter( + x=[0, 1], y=np.array([0, np.nan]), trendline=mode, trendline_options=options + ) assert len(fig.data) == 2 assert fig.data[1].x is None - fig = px.scatter(x=[0, 1, None], y=[0, None, 1], trendline=mode) + fig = px.scatter( + x=[0, 1, None], y=[0, None, 1], trendline=mode, trendline_options=options + ) assert len(fig.data) == 2 assert fig.data[1].x is None fig = px.scatter( - x=np.array([0, 1, np.nan]), y=np.array([0, np.nan, 1]), trendline=mode + x=np.array([0, 1, np.nan]), + y=np.array([0, np.nan, 1]), + trendline=mode, + trendline_options=options, ) assert len(fig.data) == 2 assert fig.data[1].x is None - fig = px.scatter(x=[0, 1, None, 2], y=[1, None, 1, 2], trendline=mode) + fig = px.scatter( + x=[0, 1, None, 2], y=[1, None, 1, 2], trendline=mode, trendline_options=options + ) assert len(fig.data) == 2 assert len(fig.data[1].x) == 2 fig = px.scatter( - x=np.array([0, 1, np.nan, 2]), y=np.array([1, np.nan, 1, 2]), trendline=mode + x=np.array([0, 1, np.nan, 2]), + y=np.array([1, np.nan, 1, 2]), + trendline=mode, + trendline_options=options, ) assert len(fig.data) == 2 assert len(fig.data[1].x) == 2 -@pytest.mark.parametrize("mode", ["ols", "lowess"]) -def test_trendline_nan_values(mode): +@pytest.mark.parametrize( + "mode,options", + [ + ("ols", None), + ("ols", dict(add_constant=False, log_x=True, log_y=True)), + ("lowess", None), + ("lowess", dict(frac=0.3)), + ("ma", dict(window=2)), + ("ewma", dict(alpha=0.5)), + ], +) +def test_trendline_nan_values(mode, options): df = px.data.gapminder().query("continent == 'Oceania'") start_date = 1970 df["pop"][df["year"] < start_date] = np.nan - fig = px.scatter(df, x="year", y="pop", color="country", trendline=mode) + fig = px.scatter( + df, + x="year", + y="pop", + color="country", + trendline=mode, + trendline_options=options, + ) for trendline in fig["data"][1::2]: assert trendline.x[0] >= start_date assert len(trendline.x) == len(trendline.y) -def test_no_slope_ols_trendline(): +def test_ols_trendline_slopes(): fig = px.scatter(x=[0, 1], y=[0, 1], trendline="ols") - assert "y = 1" in fig.data[1].hovertemplate # then + x*(some small number) + assert "y = 1 * x + 0
" in fig.data[1].hovertemplate results = px.get_trendline_results(fig) params = results["px_fit_results"].iloc[0].params assert np.all(np.isclose(params, [0, 1])) + fig = px.scatter(x=[0, 1], y=[1, 2], trendline="ols") + assert "y = 1 * x + 1
" in fig.data[1].hovertemplate + results = px.get_trendline_results(fig) + params = results["px_fit_results"].iloc[0].params + assert np.all(np.isclose(params, [1, 1])) + + fig = px.scatter( + x=[0, 1], y=[1, 2], trendline="ols", trendline_options=dict(add_constant=False) + ) + assert "y = 2 * x
" in fig.data[1].hovertemplate + results = px.get_trendline_results(fig) + params = results["px_fit_results"].iloc[0].params + assert np.all(np.isclose(params, [2])) + + fig = px.scatter( + x=[1, 1], y=[0, 0], trendline="ols", trendline_options=dict(add_constant=False) + ) + assert "y = 0 * x
" in fig.data[1].hovertemplate + results = px.get_trendline_results(fig) + params = results["px_fit_results"].iloc[0].params + assert np.all(np.isclose(params, [0])) + fig = px.scatter(x=[1, 1], y=[0, 0], trendline="ols") - assert "y = 0" in fig.data[1].hovertemplate + assert "y = 0
" in fig.data[1].hovertemplate results = px.get_trendline_results(fig) params = results["px_fit_results"].iloc[0].params assert np.all(np.isclose(params, [0])) fig = px.scatter(x=[1, 2], y=[0, 0], trendline="ols") - assert "y = 0" in fig.data[1].hovertemplate + assert "y = 0 * x + 0
" in fig.data[1].hovertemplate fig = px.scatter(x=[0, 0], y=[1, 1], trendline="ols") - assert "y = 0 * x + 1" in fig.data[1].hovertemplate + assert "y = 0 * x + 1
" in fig.data[1].hovertemplate fig = px.scatter(x=[0, 0], y=[1, 2], trendline="ols") - assert "y = 0 * x + 1.5" in fig.data[1].hovertemplate + assert "y = 0 * x + 1.5
" in fig.data[1].hovertemplate -@pytest.mark.parametrize("mode", ["ols", "lowess"]) -def test_trendline_on_timeseries(mode): +@pytest.mark.parametrize( + "mode,options", + [ + ("ols", None), + ("ols", dict(add_constant=False, log_x=True, log_y=True)), + ("lowess", None), + ("lowess", dict(frac=0.3)), + ("ma", dict(window=2)), + ("ma", dict(window="10d")), + ("ewma", dict(alpha=0.5)), + ], +) +def test_trendline_on_timeseries(mode, options): df = px.data.stocks() with pytest.raises(ValueError) as err_msg: - px.scatter(df, x="date", y="GOOG", trendline=mode) + px.scatter(df, x="date", y="GOOG", trendline=mode, trendline_options=options) assert "Could not convert value of 'x' ('date') into a numeric type." in str( err_msg.value ) df["date"] = pd.to_datetime(df["date"]) df["date"] = df["date"].dt.tz_localize("CET") # force a timezone - fig = px.scatter(df, x="date", y="GOOG", trendline=mode) + fig = px.scatter(df, x="date", y="GOOG", trendline=mode, trendline_options=options) assert len(fig.data) == 2 assert len(fig.data[0].x) == len(fig.data[1].x) assert type(fig.data[0].x[0]) == datetime From 81a76c7a6859804c832773e86f65bb13c7cb4861 Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Tue, 29 Dec 2020 22:00:20 -0500 Subject: [PATCH 07/24] apidoc and tests --- .../__init__.py} | 0 .../plotly/tests/test_optional/test_px/test_trendline.py | 4 ---- 2 files changed, 4 deletions(-) rename packages/python/plotly/plotly/express/{trendline_functions.py => trendline_functions/__init__.py} (100%) diff --git a/packages/python/plotly/plotly/express/trendline_functions.py b/packages/python/plotly/plotly/express/trendline_functions/__init__.py similarity index 100% rename from packages/python/plotly/plotly/express/trendline_functions.py rename to packages/python/plotly/plotly/express/trendline_functions/__init__.py diff --git a/packages/python/plotly/plotly/tests/test_optional/test_px/test_trendline.py b/packages/python/plotly/plotly/tests/test_optional/test_px/test_trendline.py index 19d4db4efdf..661f46bba2a 100644 --- a/packages/python/plotly/plotly/tests/test_optional/test_px/test_trendline.py +++ b/packages/python/plotly/plotly/tests/test_optional/test_px/test_trendline.py @@ -9,7 +9,6 @@ "mode,options", [ ("ols", None), - ("ols", dict(log_x=True, log_y=True)), ("lowess", None), ("lowess", dict(frac=0.3)), ("ma", dict(window=2)), @@ -47,7 +46,6 @@ def test_trendline_results_passthrough(mode, options): "mode,options", [ ("ols", None), - ("ols", dict(add_constant=False, log_x=True, log_y=True)), ("lowess", None), ("lowess", dict(frac=0.3)), ("ma", dict(window=2)), @@ -101,7 +99,6 @@ def test_trendline_enough_values(mode, options): "mode,options", [ ("ols", None), - ("ols", dict(add_constant=False, log_x=True, log_y=True)), ("lowess", None), ("lowess", dict(frac=0.3)), ("ma", dict(window=2)), @@ -172,7 +169,6 @@ def test_ols_trendline_slopes(): "mode,options", [ ("ols", None), - ("ols", dict(add_constant=False, log_x=True, log_y=True)), ("lowess", None), ("lowess", dict(frac=0.3)), ("ma", dict(window=2)), From 16a9c63d3baf30019a23ea50b2e382d5e0dd12ee Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Wed, 30 Dec 2020 20:36:06 -0500 Subject: [PATCH 08/24] fix up tests --- packages/python/plotly/plotly/express/_core.py | 12 +++++++++--- .../plotly/express/trendline_functions/__init__.py | 8 ++++++++ .../tests/test_optional/test_px/test_trendline.py | 4 +++- 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index fd395d739ee..bd5befafb58 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -17,6 +17,7 @@ ) NO_COLOR = "px_no_color_constant" +trendline_functions = dict(lowess=lowess, ma=ma, ewma=ewma, ols=ols) # Declare all supported attributes, across all plot types direct_attrables = ( @@ -313,10 +314,8 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref): if trace_spec.constructor == go.Histogram: mapping_labels["count"] = "%{x}" elif attr_name == "trendline": - trendline_functions = dict(lowess=lowess, ma=ma, ewma=ewma, ols=ols) if ( - attr_value in trendline_functions - and args["x"] + args["x"] and args["y"] and len(trace_data[[args["x"], args["y"]]].dropna()) > 1 ): @@ -1815,6 +1814,13 @@ def infer_config(args, constructor, trace_patch, layout_patch): ): args["facet_col_wrap"] = 0 + if "trendline" in args and args["trendline"] is not None: + if args["trendline"] not in trendline_functions: + raise ValueError( + "Value '%s' for `trendline` must be one of %s" + % (args["trendline"], trendline_functions.keys()) + ) + if "trendline_options" in args and args["trendline_options"] is None: args["trendline_options"] = dict() diff --git a/packages/python/plotly/plotly/express/trendline_functions/__init__.py b/packages/python/plotly/plotly/express/trendline_functions/__init__.py index 78d447353d6..79bb2de125f 100644 --- a/packages/python/plotly/plotly/express/trendline_functions/__init__.py +++ b/packages/python/plotly/plotly/express/trendline_functions/__init__.py @@ -10,9 +10,17 @@ def ols(options, x_raw, x, y, x_label, y_label, non_missing): log_y = options.get("log_y", False) if log_y: + if np.any(y == 0): + raise ValueError( + "Can't do OLS trendline with `log_y=True` when `y` contains zeros." + ) y = np.log(y) y_label = "log(%s)" % y_label if log_x: + if np.any(x == 0): + raise ValueError( + "Can't do OLS trendline with `log_x=True` when `x` contains zeros." + ) x = np.log(x) x_label = "log(%s)" % x_label if add_constant: diff --git a/packages/python/plotly/plotly/tests/test_optional/test_px/test_trendline.py b/packages/python/plotly/plotly/tests/test_optional/test_px/test_trendline.py index 661f46bba2a..848e8f68003 100644 --- a/packages/python/plotly/plotly/tests/test_optional/test_px/test_trendline.py +++ b/packages/python/plotly/plotly/tests/test_optional/test_px/test_trendline.py @@ -99,6 +99,7 @@ def test_trendline_enough_values(mode, options): "mode,options", [ ("ols", None), + ("ols", dict(add_constant=False, log_x=True, log_y=True)), ("lowess", None), ("lowess", dict(frac=0.3)), ("ma", dict(window=2)), @@ -124,7 +125,8 @@ def test_trendline_nan_values(mode, options): def test_ols_trendline_slopes(): fig = px.scatter(x=[0, 1], y=[0, 1], trendline="ols") - assert "y = 1 * x + 0
" in fig.data[1].hovertemplate + # should be "y = 1 * x + 0" but sometimes is some tiny number instead + assert "y = 1 * x + " in fig.data[1].hovertemplate results = px.get_trendline_results(fig) params = results["px_fit_results"].iloc[0].params assert np.all(np.isclose(params, [0, 1])) From c413a31796c48f145aced35ce11ea1d9db9a2210 Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Thu, 31 Dec 2020 21:35:14 -0500 Subject: [PATCH 09/24] docstrings --- packages/python/plotly/plotly/express/_doc.py | 11 +++- .../express/trendline_functions/__init__.py | 60 ++++++++++++++----- 2 files changed, 53 insertions(+), 18 deletions(-) diff --git a/packages/python/plotly/plotly/express/_doc.py b/packages/python/plotly/plotly/express/_doc.py index 1a218a9edbb..ded3b039d55 100644 --- a/packages/python/plotly/plotly/express/_doc.py +++ b/packages/python/plotly/plotly/express/_doc.py @@ -402,18 +402,23 @@ ], trendline=[ "str", - "One of `'ols'` or `'lowess'`.", + "One of `'ols'`, `'lowess'`, `'ma'` or `'ewma'`.", "If `'ols'`, an Ordinary Least Squares regression line will be drawn for each discrete-color/symbol group.", "If `'lowess`', a Locally Weighted Scatterplot Smoothing line will be drawn for each discrete-color/symbol group.", + "If `'ma`', a Moving Average line will be drawn for each discrete-color/symbol group.", + "If `'ewma`', an Exponentially Weighted Moving Average line will be drawn for each discrete-color/symbol group.", + "See the docstrings for the functions in `plotly.express.trendline_functions` for more details on these functions and how", + "to configure them with the `trendline_options` argument.", ], trendline_options=[ "dict", - "Options passed to the function named in the `trendline` argument.", + "Options passed as the first argument to the function from `plotly.express.trendline_functions` ", + "named in the `trendline` argument.", ], trendline_color_override=[ "str", "Valid CSS color.", - "If provided, and if `trendline` is set, all trendlines will be drawn in this color.", + "If provided, and if `trendline` is set, all trendlines will be drawn in this color rather than in the same color as the traces from which they draw their inputs.", ], render_mode=[ "str", diff --git a/packages/python/plotly/plotly/express/trendline_functions/__init__.py b/packages/python/plotly/plotly/express/trendline_functions/__init__.py index 79bb2de125f..c36482ec8e7 100644 --- a/packages/python/plotly/plotly/express/trendline_functions/__init__.py +++ b/packages/python/plotly/plotly/express/trendline_functions/__init__.py @@ -2,33 +2,47 @@ import numpy as np -def ols(options, x_raw, x, y, x_label, y_label, non_missing): +def ols(trendline_options, x_raw, x, y, x_label, y_label, non_missing): + """Ordinary Least Squares trendline function + + Requires `statsmodels` to be installed. + + Valid keys for the `trendline_options` dict are: + + `add_constant` (`bool`, default `True`): if `False`, the trendline passes through + the origin but if `True` a y-intercept is fitted. + + `log_x` and `log_y` (`bool`, default `False`): if `True` the OLS is computed with + respect to the base 10 logarithm of the input. Note that this means no zeros can + be present in the input. + """ + import statsmodels.api as sm - add_constant = options.get("add_constant", True) - log_x = options.get("log_x", False) - log_y = options.get("log_y", False) + add_constant = trendline_options.get("add_constant", True) + log_x = trendline_options.get("log_x", False) + log_y = trendline_options.get("log_y", False) if log_y: if np.any(y == 0): raise ValueError( "Can't do OLS trendline with `log_y=True` when `y` contains zeros." ) - y = np.log(y) - y_label = "log(%s)" % y_label + y = np.log10(y) + y_label = "log10(%s)" % y_label if log_x: if np.any(x == 0): raise ValueError( "Can't do OLS trendline with `log_x=True` when `x` contains zeros." ) - x = np.log(x) - x_label = "log(%s)" % x_label + x = np.log10(x) + x_label = "log10(%s)" % x_label if add_constant: x = sm.add_constant(x) fit_results = sm.OLS(y, x, missing="drop").fit() y_out = fit_results.predict() if log_y: - y_out = np.exp(y_out) + y_out = np.power(10, y_out) hover_header = "OLS trendline
" if len(fit_results.params) == 2: hover_header += "%s = %g * %s + %g
" % ( @@ -45,22 +59,38 @@ def ols(options, x_raw, x, y, x_label, y_label, non_missing): return y_out, hover_header, fit_results -def lowess(options, x_raw, x, y, x_label, y_label, non_missing): +def lowess(trendline_options, x_raw, x, y, x_label, y_label, non_missing): + """Locally Weighted Scatterplot Smoothing trendline function + + Requires `statsmodels` to be installed. + + Valid keys for the `trendline_options` dict are: + + `frac` (`float`, default `0.6666666`): the `frac` parameter from `statsmodels.api.nonparametric.lowess` + """ import statsmodels.api as sm - frac = options.get("frac", 0.6666666) + frac = trendline_options.get("frac", 0.6666666) y_out = sm.nonparametric.lowess(y, x, missing="drop", frac=frac)[:, 1] hover_header = "LOWESS trendline

" return y_out, hover_header, None -def ma(options, x_raw, x, y, x_label, y_label, non_missing): - y_out = pd.Series(y, index=x_raw).rolling(**options).mean()[non_missing] +def ma(trendline_options, x_raw, x, y, x_label, y_label, non_missing): + """Moving Average trendline function + + The `trendline_options` dict is passed as keyword arguments into the `pandas.Series.rolling` function. + """ + y_out = pd.Series(y, index=x_raw).rolling(**trendline_options).mean()[non_missing] hover_header = "MA trendline

" return y_out, hover_header, None -def ewma(options, x_raw, x, y, x_label, y_label, non_missing): - y_out = pd.Series(y, index=x_raw).ewm(**options).mean()[non_missing] +def ewma(trendline_options, x_raw, x, y, x_label, y_label, non_missing): + """Exponentially Weighted Moving Average trendline function + + The `trendline_options` dict is passed as keyword arguments into the `pandas.Series.ewma` function. + """ + y_out = pd.Series(y, index=x_raw).ewm(**trendline_options).mean()[non_missing] hover_header = "EWMA trendline

" return y_out, hover_header, None From aad79b441fd405f339224da2ea1e683c77be38f9 Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Thu, 31 Dec 2020 21:46:30 -0500 Subject: [PATCH 10/24] docstrings --- .../express/trendline_functions/__init__.py | 29 +++++++++++++------ 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/packages/python/plotly/plotly/express/trendline_functions/__init__.py b/packages/python/plotly/plotly/express/trendline_functions/__init__.py index c36482ec8e7..f5175351cbf 100644 --- a/packages/python/plotly/plotly/express/trendline_functions/__init__.py +++ b/packages/python/plotly/plotly/express/trendline_functions/__init__.py @@ -3,16 +3,20 @@ def ols(trendline_options, x_raw, x, y, x_label, y_label, non_missing): - """Ordinary Least Squares trendline function + """Ordinary Least Squares (OLS) trendline function Requires `statsmodels` to be installed. + This trendline function causes fit results to be stored within the figure, + accessible via the `plotly.express.get_trendline_results` function. The fit results + are the output of the `statsmodels.api.OLS` function. + Valid keys for the `trendline_options` dict are: - `add_constant` (`bool`, default `True`): if `False`, the trendline passes through + - `add_constant` (`bool`, default `True`): if `False`, the trendline passes through the origin but if `True` a y-intercept is fitted. - `log_x` and `log_y` (`bool`, default `False`): if `True` the OLS is computed with + - `log_x` and `log_y` (`bool`, default `False`): if `True` the OLS is computed with respect to the base 10 logarithm of the input. Note that this means no zeros can be present in the input. """ @@ -60,13 +64,14 @@ def ols(trendline_options, x_raw, x, y, x_label, y_label, non_missing): def lowess(trendline_options, x_raw, x, y, x_label, y_label, non_missing): - """Locally Weighted Scatterplot Smoothing trendline function + """LOcally WEighted Scatterplot Smoothing (LOWESS) trendline function Requires `statsmodels` to be installed. Valid keys for the `trendline_options` dict are: - `frac` (`float`, default `0.6666666`): the `frac` parameter from `statsmodels.api.nonparametric.lowess` + - `frac` (`float`, default `0.6666666`): the `frac` parameter from the + `statsmodels.api.nonparametric.lowess` function """ import statsmodels.api as sm @@ -77,9 +82,12 @@ def lowess(trendline_options, x_raw, x, y, x_label, y_label, non_missing): def ma(trendline_options, x_raw, x, y, x_label, y_label, non_missing): - """Moving Average trendline function + """Moving Average (MA) trendline function + + Requires `pandas` to be installed. - The `trendline_options` dict is passed as keyword arguments into the `pandas.Series.rolling` function. + The `trendline_options` dict is passed as keyword arguments into the + `pandas.Series.rolling` function. """ y_out = pd.Series(y, index=x_raw).rolling(**trendline_options).mean()[non_missing] hover_header = "MA trendline

" @@ -87,9 +95,12 @@ def ma(trendline_options, x_raw, x, y, x_label, y_label, non_missing): def ewma(trendline_options, x_raw, x, y, x_label, y_label, non_missing): - """Exponentially Weighted Moving Average trendline function + """Exponentially Weighted Moving Average (EWMA) trendline function + + Requires `pandas` to be installed. - The `trendline_options` dict is passed as keyword arguments into the `pandas.Series.ewma` function. + The `trendline_options` dict is passed as keyword arguments into the + `pandas.Series.ewma` function. """ y_out = pd.Series(y, index=x_raw).ewm(**trendline_options).mean()[non_missing] hover_header = "EWMA trendline

" From ce73f31fa5a38ad0478cb5021059eda4d2f2a977 Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Fri, 1 Jan 2021 19:05:53 -0500 Subject: [PATCH 11/24] docstrings --- .../plotly/express/trendline_functions/__init__.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/packages/python/plotly/plotly/express/trendline_functions/__init__.py b/packages/python/plotly/plotly/express/trendline_functions/__init__.py index f5175351cbf..3872d3f81a3 100644 --- a/packages/python/plotly/plotly/express/trendline_functions/__init__.py +++ b/packages/python/plotly/plotly/express/trendline_functions/__init__.py @@ -1,3 +1,13 @@ +""" +The `trendline_functions` module contains functions which are called by Plotly Express +when the `trendline` argument is used. Valid values for `trendline` are the names of the +functions in this module, and the value of the `trendline_options` argument to PX +functions is passed in as the first argument to these functions when called. + +Note that the functions in this module are not meant to be called directly, and are +exposed as part of the public API for documentation purposes. +""" + import pandas as pd import numpy as np From 1fa9fac50a189b1dba8c7377ae9a59edaa5cb3d9 Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Mon, 4 Jan 2021 15:17:14 -0500 Subject: [PATCH 12/24] add PX utility functions to apidoc --- doc/apidoc/plotly.express.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doc/apidoc/plotly.express.rst b/doc/apidoc/plotly.express.rst index 36550a0a0fb..bff238d6684 100644 --- a/doc/apidoc/plotly.express.rst +++ b/doc/apidoc/plotly.express.rst @@ -49,6 +49,8 @@ plotly's high-level API for rapid figure generation. :: density_heatmap density_mapbox imshow + set_mapbox_access_token + get_trendline_results `plotly.express` subpackages From d0ab458536412ca45745ca0ce982f901299e07d5 Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Sun, 1 Aug 2021 20:39:12 -0400 Subject: [PATCH 13/24] refactor new trendlines --- .../python/plotly/plotly/express/_core.py | 6 +- .../express/trendline_functions/__init__.py | 57 +++++++++++++------ .../test_optional/test_px/test_trendline.py | 22 ++++--- 3 files changed, 56 insertions(+), 29 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index bd5befafb58..e55949a64bc 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -2,7 +2,7 @@ import plotly.io as pio from collections import namedtuple, OrderedDict from ._special_inputs import IdentityMap, Constant, Range -from .trendline_functions import ols, lowess, ma, ewma +from .trendline_functions import ols, lowess, rolling, expanding, ewm from _plotly_utils.basevalidators import ColorscaleValidator from plotly.colors import qualitative, sequential @@ -17,7 +17,9 @@ ) NO_COLOR = "px_no_color_constant" -trendline_functions = dict(lowess=lowess, ma=ma, ewma=ewma, ols=ols) +trendline_functions = dict( + lowess=lowess, rolling=rolling, ewm=ewm, expanding=expanding, ols=ols +) # Declare all supported attributes, across all plot types direct_attrables = ( diff --git a/packages/python/plotly/plotly/express/trendline_functions/__init__.py b/packages/python/plotly/plotly/express/trendline_functions/__init__.py index 3872d3f81a3..442280178c7 100644 --- a/packages/python/plotly/plotly/express/trendline_functions/__init__.py +++ b/packages/python/plotly/plotly/express/trendline_functions/__init__.py @@ -66,9 +66,9 @@ def ols(trendline_options, x_raw, x, y, x_label, y_label, non_missing): fit_results.params[0], ) elif not add_constant: - hover_header += "%s = %g * %s
" % (y_label, fit_results.params[0], x_label,) + hover_header += "%s = %g * %s
" % (y_label, fit_results.params[0], x_label) else: - hover_header += "%s = %g
" % (y_label, fit_results.params[0],) + hover_header += "%s = %g
" % (y_label, fit_results.params[0]) hover_header += "R2=%f

" % fit_results.rsquared return y_out, hover_header, fit_results @@ -91,27 +91,48 @@ def lowess(trendline_options, x_raw, x, y, x_label, y_label, non_missing): return y_out, hover_header, None -def ma(trendline_options, x_raw, x, y, x_label, y_label, non_missing): - """Moving Average (MA) trendline function +def _pandas(mode, trendline_options, x_raw, y, non_missing): + modes = dict(rolling="Rolling", ewm="Exponentially Weighted", expanding="Expanding") + function_name = trendline_options.pop("function", "mean") + function_args = trendline_options.pop("function_args", dict()) + series = pd.Series(y, index=x_raw) + agg = getattr(series, mode) # e.g. series.rolling + agg_obj = agg(**trendline_options) # e.g. series.rolling(**opts) + function = getattr(agg_obj, function_name) # e.g. series.rolling(**opts).mean + y_out = function(**function_args) # e.g. series.rolling(**opts).mean(**opts) + y_out = y_out[non_missing] + hover_header = "%s %s trendline

" % (modes[mode], function_name) + return y_out, hover_header, None + - Requires `pandas` to be installed. +def rolling(trendline_options, x_raw, x, y, x_label, y_label, non_missing): + """Rolling trendline function - The `trendline_options` dict is passed as keyword arguments into the - `pandas.Series.rolling` function. + The value of the `function` key of the `trendline_options` dict is the function to + use (defaults to `mean`) and the value of the `function_args` key are taken to be + its arguments as a dict. The remainder of the `trendline_options` dict is passed as + keyword arguments into the `pandas.Series.rolling` function. """ - y_out = pd.Series(y, index=x_raw).rolling(**trendline_options).mean()[non_missing] - hover_header = "MA trendline

" - return y_out, hover_header, None + return _pandas("rolling", trendline_options, x_raw, y, non_missing) -def ewma(trendline_options, x_raw, x, y, x_label, y_label, non_missing): - """Exponentially Weighted Moving Average (EWMA) trendline function +def expanding(trendline_options, x_raw, x, y, x_label, y_label, non_missing): + """Expanding trendline function - Requires `pandas` to be installed. + The value of the `function` key of the `trendline_options` dict is the function to + use (defaults to `mean`) and the value of the `function_args` key are taken to be + its arguments as a dict. The remainder of the `trendline_options` dict is passed as + keyword arguments into the `pandas.Series.expanding` function. + """ + return _pandas("expanding", trendline_options, x_raw, y, non_missing) - The `trendline_options` dict is passed as keyword arguments into the - `pandas.Series.ewma` function. + +def ewm(trendline_options, x_raw, x, y, x_label, y_label, non_missing): + """Exponentially weighted trendline function + + The value of the `function` key of the `trendline_options` dict is the function to + use (defaults to `mean`) and the value of the `function_args` key are taken to be + its arguments as a dict. The remainder of the `trendline_options` dict is passed as + keyword arguments into the `pandas.Series.ewm` function. """ - y_out = pd.Series(y, index=x_raw).ewm(**trendline_options).mean()[non_missing] - hover_header = "EWMA trendline

" - return y_out, hover_header, None + return _pandas("ewm", trendline_options, x_raw, y, non_missing) diff --git a/packages/python/plotly/plotly/tests/test_optional/test_px/test_trendline.py b/packages/python/plotly/plotly/tests/test_optional/test_px/test_trendline.py index 848e8f68003..5f47cad2f6f 100644 --- a/packages/python/plotly/plotly/tests/test_optional/test_px/test_trendline.py +++ b/packages/python/plotly/plotly/tests/test_optional/test_px/test_trendline.py @@ -11,8 +11,9 @@ ("ols", None), ("lowess", None), ("lowess", dict(frac=0.3)), - ("ma", dict(window=2)), - ("ewma", dict(alpha=0.5)), + ("rolling", dict(window=2)), + ("expanding", None), + ("ewm", dict(alpha=0.5)), ], ) def test_trendline_results_passthrough(mode, options): @@ -48,8 +49,9 @@ def test_trendline_results_passthrough(mode, options): ("ols", None), ("lowess", None), ("lowess", dict(frac=0.3)), - ("ma", dict(window=2)), - ("ewma", dict(alpha=0.5)), + ("rolling", dict(window=2)), + ("expanding", None), + ("ewm", dict(alpha=0.5)), ], ) def test_trendline_enough_values(mode, options): @@ -102,8 +104,9 @@ def test_trendline_enough_values(mode, options): ("ols", dict(add_constant=False, log_x=True, log_y=True)), ("lowess", None), ("lowess", dict(frac=0.3)), - ("ma", dict(window=2)), - ("ewma", dict(alpha=0.5)), + ("rolling", dict(window=2)), + ("expanding", None), + ("ewm", dict(alpha=0.5)), ], ) def test_trendline_nan_values(mode, options): @@ -173,9 +176,10 @@ def test_ols_trendline_slopes(): ("ols", None), ("lowess", None), ("lowess", dict(frac=0.3)), - ("ma", dict(window=2)), - ("ma", dict(window="10d")), - ("ewma", dict(alpha=0.5)), + ("rolling", dict(window=2)), + ("rolling", dict(window="10d")), + ("expanding", None), + ("ewm", dict(alpha=0.5)), ], ) def test_trendline_on_timeseries(mode, options): From 619c6938ec25bf441692e0f4c69f072d5eb947ac Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Mon, 2 Aug 2021 21:49:33 -0400 Subject: [PATCH 14/24] clarify --- packages/python/plotly/plotly/express/_core.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index e55949a64bc..975ae24e248 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -869,7 +869,9 @@ def make_trace_spec(args, constructor, attrs, trace_patch): # Add trendline trace specifications if "trendline" in args and args["trendline"]: trace_spec = TraceSpec( - constructor=go.Scattergl if constructor == go.Scattergl else go.Scatter, + constructor=go.Scattergl + if constructor == go.Scattergl # could be contour + else go.Scatter, attrs=["trendline"], trace_patch=dict(mode="lines"), marginal=None, From 744afb98e0ebf83f6e416591348a38bce9696c17 Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Mon, 2 Aug 2021 22:57:28 -0400 Subject: [PATCH 15/24] docstring --- .../plotly/plotly/express/trendline_functions/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/python/plotly/plotly/express/trendline_functions/__init__.py b/packages/python/plotly/plotly/express/trendline_functions/__init__.py index 442280178c7..ace0c15bc61 100644 --- a/packages/python/plotly/plotly/express/trendline_functions/__init__.py +++ b/packages/python/plotly/plotly/express/trendline_functions/__init__.py @@ -128,7 +128,7 @@ def expanding(trendline_options, x_raw, x, y, x_label, y_label, non_missing): def ewm(trendline_options, x_raw, x, y, x_label, y_label, non_missing): - """Exponentially weighted trendline function + """Exponentially Weighted Moment (EWM) trendline function The value of the `function` key of the `trendline_options` dict is the function to use (defaults to `mean`) and the value of the `function_args` key are taken to be From e7a2fbcb9c8768e58b9e973ee8f283939a73617e Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Tue, 3 Aug 2021 14:14:51 -0400 Subject: [PATCH 16/24] bugfix --- .../python/plotly/plotly/express/trendline_functions/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/packages/python/plotly/plotly/express/trendline_functions/__init__.py b/packages/python/plotly/plotly/express/trendline_functions/__init__.py index ace0c15bc61..67c0c540558 100644 --- a/packages/python/plotly/plotly/express/trendline_functions/__init__.py +++ b/packages/python/plotly/plotly/express/trendline_functions/__init__.py @@ -93,6 +93,7 @@ def lowess(trendline_options, x_raw, x, y, x_label, y_label, non_missing): def _pandas(mode, trendline_options, x_raw, y, non_missing): modes = dict(rolling="Rolling", ewm="Exponentially Weighted", expanding="Expanding") + trendline_options = trendline_options.copy() function_name = trendline_options.pop("function", "mean") function_args = trendline_options.pop("function_args", dict()) series = pd.Series(y, index=x_raw) From b6806ffc8137dcbe6399d4fcdc7282cf8f8cd266 Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Wed, 4 Aug 2021 23:12:37 -0400 Subject: [PATCH 17/24] trendline_scope --- .../plotly/plotly/express/_chart_types.py | 16 ++++-- .../python/plotly/plotly/express/_core.py | 54 ++++++++++++++----- packages/python/plotly/plotly/express/_doc.py | 9 +++- .../test_optional/test_px/test_trendline.py | 39 ++++++++++++++ 4 files changed, 100 insertions(+), 18 deletions(-) diff --git a/packages/python/plotly/plotly/express/_chart_types.py b/packages/python/plotly/plotly/express/_chart_types.py index 5051c8a5367..25e5991c9bc 100644 --- a/packages/python/plotly/plotly/express/_chart_types.py +++ b/packages/python/plotly/plotly/express/_chart_types.py @@ -48,6 +48,7 @@ def scatter( trendline=None, trendline_options=None, trendline_color_override=None, + trendline_scope="trace", log_x=False, log_y=False, range_x=None, @@ -93,6 +94,7 @@ def density_contour( trendline=None, trendline_options=None, trendline_color_override=None, + trendline_scope="trace", log_x=False, log_y=False, range_x=None, @@ -202,7 +204,9 @@ def density_heatmap( z=[ "For `density_heatmap` and `density_contour` these values are used as the inputs to `histfunc`.", ], - histfunc=["The arguments to this function are the values of `z`.",], + histfunc=[ + "The arguments to this function are the values of `z`.", + ], ), ) @@ -467,7 +471,9 @@ def histogram( args=locals(), constructor=go.Histogram, trace_patch=dict( - histnorm=histnorm, histfunc=histfunc, cumulative=dict(enabled=cumulative), + histnorm=histnorm, + histfunc=histfunc, + cumulative=dict(enabled=cumulative), ), layout_patch=dict(barmode=barmode, barnorm=barnorm), ) @@ -527,7 +533,11 @@ def violin( args=locals(), constructor=go.Violin, trace_patch=dict( - points=points, box=dict(visible=box), scalegroup=True, x0=" ", y0=" ", + points=points, + box=dict(visible=box), + scalegroup=True, + x0=" ", + y0=" ", ), layout_patch=dict(violinmode=violinmode), ) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 975ae24e248..cc0e98375b2 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -347,6 +347,9 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref): ) # preserve original values of "x" in case they're dates + # otherwise numpy/pandas can mess with the timezones + # NB this means trendline functions must output one-to-one with the input series + # i.e. we can't do resampling, because then the X values might not line up! non_missing = np.logical_not( np.logical_or(np.isnan(y), np.isnan(x)) ) @@ -867,23 +870,25 @@ def make_trace_spec(args, constructor, attrs, trace_patch): result.append(trace_spec) # Add trendline trace specifications - if "trendline" in args and args["trendline"]: - trace_spec = TraceSpec( - constructor=go.Scattergl - if constructor == go.Scattergl # could be contour - else go.Scatter, - attrs=["trendline"], - trace_patch=dict(mode="lines"), - marginal=None, - ) - if args["trendline_color_override"]: - trace_spec.trace_patch["line"] = dict( - color=args["trendline_color_override"] - ) - result.append(trace_spec) + if args.get("trendline") and args.get("trendline_scope", "trace") == "trace": + result.append(make_trendline_spec(args, constructor)) return result +def make_trendline_spec(args, constructor): + trace_spec = TraceSpec( + constructor=go.Scattergl + if constructor == go.Scattergl # could be contour + else go.Scatter, + attrs=["trendline"], + trace_patch=dict(mode="lines"), + marginal=None, + ) + if args["trendline_color_override"]: + trace_spec.trace_patch["line"] = dict(color=args["trendline_color_override"]) + return trace_spec + + def one_group(x): return "" @@ -2127,6 +2132,27 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None): fig.update_layout(template=args["template"], overwrite=True) fig.frames = frame_list if len(frames) > 1 else [] + if args.get("trendline") and args.get("trendline_scope", "trace") == "overall": + trendline_spec = make_trendline_spec(args, constructor) + trendline_trace = trendline_spec.constructor( + name="Overall Trendline", legendgroup="Overall Trendline", showlegend=False + ) + if "line" not in trendline_spec.trace_patch: # no color override + for m in grouped_mappings: + if m.variable == "color": + next_color = m.sequence[len(m.val_map) % len(m.sequence)] + trendline_spec.trace_patch["line"] = dict(color=next_color) + patch, fit_results = make_trace_kwargs( + args, trendline_spec, args["data_frame"], {}, sizeref + ) + trendline_trace.update(patch) + fig.add_trace( + trendline_trace, row="all", col="all", exclude_empty_subplots=True + ) + fig.update_traces(selector=-1, showlegend=True) + if fit_results is not None: + trendline_rows.append(dict(px_fit_results=fit_results)) + fig._px_trendlines = pd.DataFrame(trendline_rows) configure_axes(args, constructor, fig, orders) diff --git a/packages/python/plotly/plotly/express/_doc.py b/packages/python/plotly/plotly/express/_doc.py index ded3b039d55..60e474e7adb 100644 --- a/packages/python/plotly/plotly/express/_doc.py +++ b/packages/python/plotly/plotly/express/_doc.py @@ -325,7 +325,10 @@ "Setting this value is recommended when using `plotly.express.colors.diverging` color scales as the inputs to `color_continuous_scale`.", ], size_max=["int (default `20`)", "Set the maximum mark size when using `size`."], - markers=["boolean (default `False`)", "If `True`, markers are shown on lines.",], + markers=[ + "boolean (default `False`)", + "If `True`, markers are shown on lines.", + ], log_x=[ "boolean (default `False`)", "If `True`, the x-axis is log-scaled in cartesian coordinates.", @@ -420,6 +423,10 @@ "Valid CSS color.", "If provided, and if `trendline` is set, all trendlines will be drawn in this color rather than in the same color as the traces from which they draw their inputs.", ], + trendline_scope=[ + "str (one of `'trace'` or `'overall'`, default `'trace'`)", + "If `'trace'`, then one trendline is drawn per trace (i.e. per color, symbol, facet, animation frame etc) and if `'overall'` then one trendline is computed for the entire dataset, and replicated across all facets.", + ], render_mode=[ "str", "One of `'auto'`, `'svg'` or `'webgl'`, default `'auto'`", diff --git a/packages/python/plotly/plotly/tests/test_optional/test_px/test_trendline.py b/packages/python/plotly/plotly/tests/test_optional/test_px/test_trendline.py index 5f47cad2f6f..66046981eff 100644 --- a/packages/python/plotly/plotly/tests/test_optional/test_px/test_trendline.py +++ b/packages/python/plotly/plotly/tests/test_optional/test_px/test_trendline.py @@ -200,3 +200,42 @@ def test_trendline_on_timeseries(mode, options): assert type(fig.data[1].x[0]) == datetime assert np.all(fig.data[0].x == fig.data[1].x) assert str(fig.data[0].x[0]) == str(fig.data[1].x[0]) + + +def test_overall_trendline(): + df = px.data.tips() + fig1 = px.scatter(df, x="total_bill", y="tip", trendline="ols") + assert len(fig1.data) == 2 + assert "trendline" in fig1.data[1].hovertemplate + results1 = px.get_trendline_results(fig1) + params1 = results1["px_fit_results"].iloc[0].params + + fig2 = px.scatter( + df, + x="total_bill", + y="tip", + color="sex", + trendline="ols", + trendline_scope="overall", + ) + assert len(fig2.data) == 3 + assert "trendline" in fig2.data[2].hovertemplate + results2 = px.get_trendline_results(fig2) + params2 = results2["px_fit_results"].iloc[0].params + + assert np.all(np.array_equal(params1, params2)) + + fig3 = px.scatter( + df, + x="total_bill", + y="tip", + facet_row="sex", + trendline="ols", + trendline_scope="overall", + ) + assert len(fig3.data) == 4 + assert "trendline" in fig3.data[3].hovertemplate + results3 = px.get_trendline_results(fig3) + params3 = results3["px_fit_results"].iloc[0].params + + assert np.all(np.array_equal(params1, params3)) From 6ef46bfa2a7ea6613af77b8655f6ff36d7f9212d Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Wed, 4 Aug 2021 23:29:35 -0400 Subject: [PATCH 18/24] black --- .../python/plotly/plotly/express/_chart_types.py | 14 +++----------- packages/python/plotly/plotly/express/_doc.py | 5 +---- 2 files changed, 4 insertions(+), 15 deletions(-) diff --git a/packages/python/plotly/plotly/express/_chart_types.py b/packages/python/plotly/plotly/express/_chart_types.py index 25e5991c9bc..f335e78de34 100644 --- a/packages/python/plotly/plotly/express/_chart_types.py +++ b/packages/python/plotly/plotly/express/_chart_types.py @@ -204,9 +204,7 @@ def density_heatmap( z=[ "For `density_heatmap` and `density_contour` these values are used as the inputs to `histfunc`.", ], - histfunc=[ - "The arguments to this function are the values of `z`.", - ], + histfunc=["The arguments to this function are the values of `z`.",], ), ) @@ -471,9 +469,7 @@ def histogram( args=locals(), constructor=go.Histogram, trace_patch=dict( - histnorm=histnorm, - histfunc=histfunc, - cumulative=dict(enabled=cumulative), + histnorm=histnorm, histfunc=histfunc, cumulative=dict(enabled=cumulative), ), layout_patch=dict(barmode=barmode, barnorm=barnorm), ) @@ -533,11 +529,7 @@ def violin( args=locals(), constructor=go.Violin, trace_patch=dict( - points=points, - box=dict(visible=box), - scalegroup=True, - x0=" ", - y0=" ", + points=points, box=dict(visible=box), scalegroup=True, x0=" ", y0=" ", ), layout_patch=dict(violinmode=violinmode), ) diff --git a/packages/python/plotly/plotly/express/_doc.py b/packages/python/plotly/plotly/express/_doc.py index 60e474e7adb..37d16cd1e11 100644 --- a/packages/python/plotly/plotly/express/_doc.py +++ b/packages/python/plotly/plotly/express/_doc.py @@ -325,10 +325,7 @@ "Setting this value is recommended when using `plotly.express.colors.diverging` color scales as the inputs to `color_continuous_scale`.", ], size_max=["int (default `20`)", "Set the maximum mark size when using `size`."], - markers=[ - "boolean (default `False`)", - "If `True`, markers are shown on lines.", - ], + markers=["boolean (default `False`)", "If `True`, markers are shown on lines.",], log_x=[ "boolean (default `False`)", "If `True`, the x-axis is log-scaled in cartesian coordinates.", From f37eb06303e5873780b57e20cb6e46097cf0daad Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Wed, 4 Aug 2021 23:33:33 -0400 Subject: [PATCH 19/24] changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8985395f77b..1fd1096655d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,9 @@ This project adheres to [Semantic Versioning](http://semver.org/). ### Added - Extra flags were added to the `gapminder` and `stocks` dataset to facilitate testing, documentation and demos [#3305](https://github.com/plotly/plotly.py/issues/3305) - All line-like Plotly Express functions now accept `markers` argument to display markers, and all but `line_mapbox` accept `symbol` to map a field to the symbol attribute, similar to scatter-like functions [#3326](https://github.com/plotly/plotly.py/issues/3326) + - `px.scatter` and `px.density_contours` now support new `trendline` types `"rolling"`, `"expanding"` and `"ewm` + - `px.scatter` and `px.density_contours` now support new `trendline_options` argument to parameterize trendlines, with support for constant control and log-scaling in `'ols'` and specification of the fraction used for `'lowess'` + - `px.scatter` and `px.density_contours` now support new `trendline_scope` argument that accepts the value `'overall'` to request a single trendline for all traces, including across facets and animation frames ### Fixed - Fixed regression introduced in version 5.0.0 where pandas/numpy arrays with `dtype` of Object were being converted to `list` values when added to a Figure ([#3292](https://github.com/plotly/plotly.py/issues/3292), [#3293](https://github.com/plotly/plotly.py/pull/3293)) From 03c146cd52a922c043302fed4aa6c726a43a9c1a Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Wed, 4 Aug 2021 23:34:09 -0400 Subject: [PATCH 20/24] changelog --- CHANGELOG.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1fd1096655d..9ddd0bd04c8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,9 +8,9 @@ This project adheres to [Semantic Versioning](http://semver.org/). ### Added - Extra flags were added to the `gapminder` and `stocks` dataset to facilitate testing, documentation and demos [#3305](https://github.com/plotly/plotly.py/issues/3305) - All line-like Plotly Express functions now accept `markers` argument to display markers, and all but `line_mapbox` accept `symbol` to map a field to the symbol attribute, similar to scatter-like functions [#3326](https://github.com/plotly/plotly.py/issues/3326) - - `px.scatter` and `px.density_contours` now support new `trendline` types `"rolling"`, `"expanding"` and `"ewm` - - `px.scatter` and `px.density_contours` now support new `trendline_options` argument to parameterize trendlines, with support for constant control and log-scaling in `'ols'` and specification of the fraction used for `'lowess'` - - `px.scatter` and `px.density_contours` now support new `trendline_scope` argument that accepts the value `'overall'` to request a single trendline for all traces, including across facets and animation frames + - `px.scatter` and `px.density_contours` now support new `trendline` types `'rolling'`, `'expanding'` and `'ewm'` [#2997](https://github.com/plotly/plotly.py/pull/2997) + - `px.scatter` and `px.density_contours` now support new `trendline_options` argument to parameterize trendlines, with support for constant control and log-scaling in `'ols'` and specification of the fraction used for `'lowess'`, as well as pass-through to Pandas for `'rolling'`, `'expanding'` and `'ewm'` [#2997](https://github.com/plotly/plotly.py/pull/2997) + - `px.scatter` and `px.density_contours` now support new `trendline_scope` argument that accepts the value `'overall'` to request a single trendline for all traces, including across facets and animation frames [#2997](https://github.com/plotly/plotly.py/pull/2997) ### Fixed - Fixed regression introduced in version 5.0.0 where pandas/numpy arrays with `dtype` of Object were being converted to `list` values when added to a Figure ([#3292](https://github.com/plotly/plotly.py/issues/3292), [#3293](https://github.com/plotly/plotly.py/pull/3293)) From bd0d368bea5f110027c41510c6bdcf8313930e5b Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Wed, 4 Aug 2021 23:40:05 -0400 Subject: [PATCH 21/24] validate trendline_options --- .../express/trendline_functions/__init__.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/packages/python/plotly/plotly/express/trendline_functions/__init__.py b/packages/python/plotly/plotly/express/trendline_functions/__init__.py index 67c0c540558..9d1c5f5e534 100644 --- a/packages/python/plotly/plotly/express/trendline_functions/__init__.py +++ b/packages/python/plotly/plotly/express/trendline_functions/__init__.py @@ -30,6 +30,13 @@ def ols(trendline_options, x_raw, x, y, x_label, y_label, non_missing): respect to the base 10 logarithm of the input. Note that this means no zeros can be present in the input. """ + valid_options = ["add_constant", "log_x", "log_y"] + for k in trendline_options.keys(): + if k not in valid_options: + raise ValueError( + "OLS trendline_options keys must be one of [%s] but got '%s'" + % (", ".join(valid_options), k) + ) import statsmodels.api as sm @@ -83,6 +90,15 @@ def lowess(trendline_options, x_raw, x, y, x_label, y_label, non_missing): - `frac` (`float`, default `0.6666666`): the `frac` parameter from the `statsmodels.api.nonparametric.lowess` function """ + + valid_options = ["frac"] + for k in trendline_options.keys(): + if k not in valid_options: + raise ValueError( + "LOWESS trendline_options keys must be one of [%s] but got '%s'" + % (", ".join(valid_options), k) + ) + import statsmodels.api as sm frac = trendline_options.get("frac", 0.6666666) From 860a321aa6416cf06d0a539e9a6a3bb9d256b637 Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Thu, 5 Aug 2021 11:39:01 -0400 Subject: [PATCH 22/24] tweak docstrings --- packages/python/plotly/plotly/express/_doc.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/packages/python/plotly/plotly/express/_doc.py b/packages/python/plotly/plotly/express/_doc.py index 37d16cd1e11..65d9f0588ff 100644 --- a/packages/python/plotly/plotly/express/_doc.py +++ b/packages/python/plotly/plotly/express/_doc.py @@ -402,11 +402,12 @@ ], trendline=[ "str", - "One of `'ols'`, `'lowess'`, `'ma'` or `'ewma'`.", + "One of `'ols'`, `'lowess'`, `'rolling'`, `'expanding'` or `'ewm'`.", "If `'ols'`, an Ordinary Least Squares regression line will be drawn for each discrete-color/symbol group.", "If `'lowess`', a Locally Weighted Scatterplot Smoothing line will be drawn for each discrete-color/symbol group.", - "If `'ma`', a Moving Average line will be drawn for each discrete-color/symbol group.", - "If `'ewma`', an Exponentially Weighted Moving Average line will be drawn for each discrete-color/symbol group.", + "If `'rolling`', a Rolling (e.g. rolling average, rolling median) line will be drawn for each discrete-color/symbol group.", + "If `'expanding`', an Expanding (e.g. expanding average, expanding sum) line will be drawn for each discrete-color/symbol group.", + "If `'ewm`', an Exponentially Weighted Moment (e.g. exponentially-weighted moving average) line will be drawn for each discrete-color/symbol group.", "See the docstrings for the functions in `plotly.express.trendline_functions` for more details on these functions and how", "to configure them with the `trendline_options` argument.", ], From ab0e0aba3b809567b2a77855427aa1d0bfb15034 Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Wed, 11 Aug 2021 12:11:23 -0400 Subject: [PATCH 23/24] positive values only for logs --- .../plotly/plotly/express/trendline_functions/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/packages/python/plotly/plotly/express/trendline_functions/__init__.py b/packages/python/plotly/plotly/express/trendline_functions/__init__.py index 9d1c5f5e534..f0fc29cee4b 100644 --- a/packages/python/plotly/plotly/express/trendline_functions/__init__.py +++ b/packages/python/plotly/plotly/express/trendline_functions/__init__.py @@ -45,16 +45,16 @@ def ols(trendline_options, x_raw, x, y, x_label, y_label, non_missing): log_y = trendline_options.get("log_y", False) if log_y: - if np.any(y == 0): + if np.any(y <= 0): raise ValueError( - "Can't do OLS trendline with `log_y=True` when `y` contains zeros." + "Can't do OLS trendline with `log_y=True` when `y` contains non-positive values." ) y = np.log10(y) y_label = "log10(%s)" % y_label if log_x: - if np.any(x == 0): + if np.any(x <= 0): raise ValueError( - "Can't do OLS trendline with `log_x=True` when `x` contains zeros." + "Can't do OLS trendline with `log_x=True` when `x` contains non-positive values." ) x = np.log10(x) x_label = "log10(%s)" % x_label From f42c5afe5153f2394ed997df0a8e34d2dbd3521a Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Wed, 11 Aug 2021 20:32:43 -0400 Subject: [PATCH 24/24] docs for trendlines --- doc/python/linear-fits.md | 165 +++++++++++++++++++++++++++++++++++--- 1 file changed, 156 insertions(+), 9 deletions(-) diff --git a/doc/python/linear-fits.md b/doc/python/linear-fits.md index 7f1f1a2971f..0029be6c4be 100644 --- a/doc/python/linear-fits.md +++ b/doc/python/linear-fits.md @@ -5,8 +5,8 @@ jupyter: text_representation: extension: .md format_name: markdown - format_version: '1.1' - jupytext_version: 1.1.1 + format_version: '1.2' + jupytext_version: 1.4.2 kernelspec: display_name: Python 3 language: python @@ -20,11 +20,12 @@ jupyter: name: python nbconvert_exporter: python pygments_lexer: ipython3 - version: 3.6.8 + version: 3.7.7 plotly: description: Add linear Ordinary Least Squares (OLS) regression trendlines or non-linear Locally Weighted Scatterplot Smoothing (LOWESS) trendlines to scatterplots - in Python. + in Python. Options for moving averages (rolling means) as well as exponentially-weighted + and expanding functions. display_as: statistical language: python layout: base @@ -39,7 +40,7 @@ jupyter: [Plotly Express](/python/plotly-express/) is the easy-to-use, high-level interface to Plotly, which [operates on a variety of types of data](/python/px-arguments/) and produces [easy-to-style figures](/python/styling-plotly-express/). -Plotly Express allows you to add [Ordinary Least](https://en.wikipedia.org/wiki/Ordinary_least_squares) Squares regression trendline to scatterplots with the `trendline` argument. In order to do so, you will need to install `statsmodels` and its dependencies. Hovering over the trendline will show the equation of the line and its R-squared value. +Plotly Express allows you to add [Ordinary Least Squares](https://en.wikipedia.org/wiki/Ordinary_least_squares) regression trendline to scatterplots with the `trendline` argument. In order to do so, you will need to [install `statsmodels` and its dependencies](https://www.statsmodels.org/stable/install.html). Hovering over the trendline will show the equation of the line and its R-squared value. ```python import plotly.express as px @@ -66,14 +67,160 @@ print(results) results.query("sex == 'Male' and smoker == 'Yes'").px_fit_results.iloc[0].summary() ``` -### Non-Linear Trendlines +### Displaying a single trendline with multiple traces -Plotly Express also supports non-linear [LOWESS](https://en.wikipedia.org/wiki/Local_regression) trendlines. +_new in v5.2_ + +To display a single trendline using the entire dataset, set the `trendline_scope` argument to `"overall"`. The same trendline will be overlaid on all facets and animation frames. The trendline color can be overridden with `trendline_color_override`. + +```python +import plotly.express as px + +df = px.data.tips() +fig = px.scatter(df, x="total_bill", y="tip", symbol="smoker", color="sex", trendline="ols", trendline_scope="overall") +fig.show() +``` + +```python +import plotly.express as px + +df = px.data.tips() +fig = px.scatter(df, x="total_bill", y="tip", facet_col="smoker", color="sex", + trendline="ols", trendline_scope="overall", trendline_color_override="black") +fig.show() +``` + +### OLS Parameters + +_new in v5.2_ + +OLS trendlines can be fit with log transformations to both X or Y data using the `trendline_options` argument, independently of whether or not the plot has [logarithmic axes](https://plotly.com/python/log-plot/). + +```python +import plotly.express as px + +df = px.data.gapminder(year=2007) +fig = px.scatter(df, x="gdpPercap", y="lifeExp", + trendline="ols", trendline_options=dict(log_x=True), + title="Log-transformed fit on linear axes") +fig.show() +``` + +```python +import plotly.express as px + +df = px.data.gapminder(year=2007) +fig = px.scatter(df, x="gdpPercap", y="lifeExp", log_x=True, + trendline="ols", trendline_options=dict(log_x=True), + title="Log-scaled X axis and log-transformed fit") +fig.show() +``` + +### Locally WEighted Scatterplot Smoothing (LOWESS) + +Plotly Express also supports non-linear [LOWESS](https://en.wikipedia.org/wiki/Local_regression) trendlines. In order use this feature, you will need to [install `statsmodels` and its dependencies](https://www.statsmodels.org/stable/install.html). + +```python +import plotly.express as px + +df = px.data.stocks(datetimes=True) +fig = px.scatter(df, x="date", y="GOOG", trendline="lowess") +fig.show() +``` + +_new in v5.2_ + +The level of smoothing can be controlled via the `frac` trendline option, which indicates the fraction of the data that the LOWESS smoother should include. The default is a fairly smooth line with `frac=0.6666` and lowering this fraction will give a line that more closely follows the data. + +```python +import plotly.express as px + +df = px.data.stocks(datetimes=True) +fig = px.scatter(df, x="date", y="GOOG", trendline="lowess", trendline_options=dict(frac=0.1)) +fig.show() +``` + +### Moving Averages + +_new in v5.2_ + +Plotly Express can leverage Pandas' [`rolling`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.rolling.html), [`ewm`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.ewm.html) and [`expanding`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.expanding.html) functions in trendlines as well, for example to display moving averages. Values passed to `trendline_options` are passed directly to the underlying Pandas function (with the exception of the `function` and `function_options` keys, see below). + +```python +import plotly.express as px + +df = px.data.stocks(datetimes=True) +fig = px.scatter(df, x="date", y="GOOG", trendline="rolling", trendline_options=dict(window=5), + title="5-point moving average") +fig.show() +``` + +```python +import plotly.express as px + +df = px.data.stocks(datetimes=True) +fig = px.scatter(df, x="date", y="GOOG", trendline="ewm", trendline_options=dict(halflife=2), + title="Exponentially-weighted moving average (halflife of 2 points)") +fig.show() +``` + +```python +import plotly.express as px + +df = px.data.stocks(datetimes=True) +fig = px.scatter(df, x="date", y="GOOG", trendline="expanding", title="Expanding mean") +fig.show() +``` + +### Other Functions + +The `rolling`, `expanding` and `ewm` trendlines support other functions than the default `mean`, enabling, for example, a moving-median trendline, or an expanding-max trendline. ```python import plotly.express as px -df = px.data.gapminder().query("year == 2007") -fig = px.scatter(df, x="gdpPercap", y="lifeExp", color="continent", trendline="lowess") +df = px.data.stocks(datetimes=True) +fig = px.scatter(df, x="date", y="GOOG", trendline="rolling", trendline_options=dict(function="median", window=5), + title="Rolling Median") fig.show() ``` + +```python +import plotly.express as px + +df = px.data.stocks(datetimes=True) +fig = px.scatter(df, x="date", y="GOOG", trendline="expanding", trendline_options=dict(function="max"), + title="Expanding Maximum") +fig.show() +``` + +In some cases, it is necessary to pass options into the underying Pandas function, for example the `std` parameter must be provided if the `win_type` argument to `rolling` is `"gaussian"`. This is possible with the `function_args` trendline option. + +```python +import plotly.express as px + +df = px.data.stocks(datetimes=True) +fig = px.scatter(df, x="date", y="GOOG", trendline="rolling", + trendline_options=dict(window=5, win_type="gaussian", function_args=dict(std=2)), + title="Rolling Mean with Gaussian Window") +fig.show() +``` + +### Displaying only the trendlines + +In some cases, it may be desirable to show only the trendlines, by removing the scatter points. + +```python +import plotly.express as px + +df = px.data.stocks(indexed=True, datetimes=True) +fig = px.scatter(df, trendline="rolling", trendline_options=dict(window=5), + title="5-point moving average") +fig.data = [t for t in fig.data if t.mode == "lines"] +fig.update_traces(showlegend=True) #trendlines have showlegend=False by default +fig.show() +``` + +```python + +```