Skip to content
This repository was archived by the owner on Jun 3, 2024. It is now read-only.

Commit 9ddbb16

Browse files
Merge pull request #69 from plotly/trendline_results
expose OLS fit results
2 parents 9ec0e03 + 1c2c780 commit 9ddbb16

File tree

2 files changed

+46
-18
lines changed

2 files changed

+46
-18
lines changed

Diff for: plotly_express/__init__.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,12 @@
3131
density_contour,
3232
)
3333

34-
from ._core import ExpressFigure, set_mapbox_access_token, defaults # noqa: F401
34+
from ._core import ( # noqa: F401
35+
ExpressFigure,
36+
set_mapbox_access_token,
37+
defaults,
38+
get_trendline_results,
39+
)
3540

3641
from . import data, colors # noqa: F401
3742

@@ -61,5 +66,6 @@
6166
"data",
6267
"colors",
6368
"set_mapbox_access_token",
69+
"get_trendline_results",
6470
"ExpressFigure",
6571
]

Diff for: plotly_express/_core.py

+39-17
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import plotly.io as pio
44
from collections import namedtuple, OrderedDict
55
from .colors import qualitative, sequential
6-
import math
6+
import math, pandas
77

88

99
class PxDefaults(object):
@@ -56,6 +56,21 @@ def _ipython_display_(self):
5656
iplot(self, show_link=False, auto_play=False)
5757

5858

59+
def get_trendline_results(fig):
60+
"""
61+
Extracts fit statistics for trendlines (when applied to figures generated with
62+
the `trendline` argument set to `"ols"`).
63+
64+
Arguments:
65+
fig: the output of a `plotly_express` charting call
66+
Returns:
67+
A `pandas.DataFrame` with a column "px_fit_results" containing the `statsmodels`
68+
results objects, along with columns identifying the subset of the data the
69+
trendline was fit on.
70+
"""
71+
return fig._px_trendlines
72+
73+
5974
Mapping = namedtuple(
6075
"Mapping",
6176
["show_in_trace_name", "grouper", "val_map", "sequence", "updater", "variable"],
@@ -129,6 +144,7 @@ def make_trace_kwargs(
129144
if "line_close" in args and args["line_close"]:
130145
g = g.append(g.iloc[0])
131146
result = trace_spec.trace_patch.copy() or {}
147+
fit_results = None
132148
hover_header = ""
133149
for k in trace_spec.attrs:
134150
v = args[k]
@@ -186,16 +202,18 @@ def make_trace_kwargs(
186202
result["y"] = trendline[:, 1]
187203
hover_header = "<b>LOWESS trendline</b><br><br>"
188204
elif v == "ols":
189-
fitted = sm.OLS(y, sm.add_constant(x)).fit()
190-
result["y"] = fitted.predict()
205+
fit_results = sm.OLS(y, sm.add_constant(x)).fit()
206+
result["y"] = fit_results.predict()
191207
hover_header = "<b>OLS trendline</b><br>"
192208
hover_header += "%s = %f * %s + %f<br>" % (
193209
args["y"],
194-
fitted.params[1],
210+
fit_results.params[1],
195211
args["x"],
196-
fitted.params[0],
212+
fit_results.params[0],
213+
)
214+
hover_header += (
215+
"R<sup>2</sup>=%f<br><br>" % fit_results.rsquared
197216
)
198-
hover_header += "R<sup>2</sup>=%f<br><br>" % fitted.rsquared
199217
mapping_labels[get_label(args, args["x"])] = "%{x}"
200218
mapping_labels[get_label(args, args["y"])] = "%{y} <b>(trend)</b>"
201219

@@ -254,7 +272,7 @@ def make_trace_kwargs(
254272
if trace_spec.constructor not in [go.Histogram2dContour, go.Parcoords, go.Parcats]:
255273
hover_lines = [k + "=" + v for k, v in mapping_labels.items()]
256274
result["hovertemplate"] = hover_header + "<br>".join(hover_lines)
257-
return result
275+
return result, fit_results
258276

259277

260278
def configure_axes(args, constructor, fig, axes, orders):
@@ -766,6 +784,7 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
766784

767785
trace_names_by_frame = {}
768786
frames = OrderedDict()
787+
trendline_rows = []
769788
for group_name in group_names:
770789
group = grouped.get_group(group_name if len(group_name) > 1 else group_name[0])
771790
mapping_labels = OrderedDict()
@@ -841,17 +860,19 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
841860
):
842861
trace.update(marker=dict(color=trace.line.color))
843862

844-
trace.update(
845-
make_trace_kwargs(
846-
args,
847-
trace_spec,
848-
group,
849-
mapping_labels.copy(),
850-
sizeref,
851-
color_range=color_range,
852-
show_colorbar=(frame_name not in frames),
853-
)
863+
patch, fit_results = make_trace_kwargs(
864+
args,
865+
trace_spec,
866+
group,
867+
mapping_labels.copy(),
868+
sizeref,
869+
color_range=color_range,
870+
show_colorbar=(frame_name not in frames),
854871
)
872+
trace.update(patch)
873+
if fit_results is not None:
874+
trendline_rows.append(mapping_labels.copy())
875+
trendline_rows[-1]["px_fit_results"] = fit_results
855876
if frame_name not in frames:
856877
frames[frame_name] = dict(data=[], name=frame_name)
857878
frames[frame_name]["data"].append(trace)
@@ -870,6 +891,7 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
870891
layout=layout_patch,
871892
frames=frame_list if len(frames) > 1 else [],
872893
)
894+
fig._px_trendlines = pandas.DataFrame(trendline_rows)
873895
axes = {m.variable: m.val_map for m in grouped_mappings}
874896
configure_axes(args, constructor, fig, axes, orders)
875897
configure_animation_controls(args, constructor, fig)

0 commit comments

Comments
 (0)