Skip to content
This repository was archived by the owner on Jun 3, 2024. It is now read-only.

Commit 1c2c780

Browse files
expose OLS fit results
1 parent 0565a5e commit 1c2c780

File tree

2 files changed

+46
-18
lines changed

2 files changed

+46
-18
lines changed

plotly_express/__init__.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,12 @@
3131
density_contour,
3232
)
3333

34-
from ._core import ExpressFigure, set_mapbox_access_token, defaults # noqa: F401
34+
from ._core import ( # noqa: F401
35+
ExpressFigure,
36+
set_mapbox_access_token,
37+
defaults,
38+
get_trendline_results,
39+
)
3540

3641
from . import data, colors # noqa: F401
3742

@@ -61,5 +66,6 @@
6166
"data",
6267
"colors",
6368
"set_mapbox_access_token",
69+
"get_trendline_results",
6470
"ExpressFigure",
6571
]

plotly_express/_core.py

+39-17
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from plotly.offline import init_notebook_mode, iplot
33
from collections import namedtuple, OrderedDict
44
from .colors import qualitative, sequential
5-
import math
5+
import math, pandas
66

77

88
class PxDefaults(object):
@@ -55,6 +55,21 @@ def _ipython_display_(self):
5555
iplot(self, show_link=False, auto_play=False)
5656

5757

58+
def get_trendline_results(fig):
59+
"""
60+
Extracts fit statistics for trendlines (when applied to figures generated with
61+
the `trendline` argument set to `"ols"`).
62+
63+
Arguments:
64+
fig: the output of a `plotly_express` charting call
65+
Returns:
66+
A `pandas.DataFrame` with a column "px_fit_results" containing the `statsmodels`
67+
results objects, along with columns identifying the subset of the data the
68+
trendline was fit on.
69+
"""
70+
return fig._px_trendlines
71+
72+
5873
Mapping = namedtuple(
5974
"Mapping",
6075
["show_in_trace_name", "grouper", "val_map", "sequence", "updater", "variable"],
@@ -128,6 +143,7 @@ def make_trace_kwargs(
128143
if "line_close" in args and args["line_close"]:
129144
g = g.append(g.iloc[0])
130145
result = trace_spec.trace_patch.copy() or {}
146+
fit_results = None
131147
hover_header = ""
132148
for k in trace_spec.attrs:
133149
v = args[k]
@@ -185,16 +201,18 @@ def make_trace_kwargs(
185201
result["y"] = trendline[:, 1]
186202
hover_header = "<b>LOWESS trendline</b><br><br>"
187203
elif v == "ols":
188-
fitted = sm.OLS(y, sm.add_constant(x)).fit()
189-
result["y"] = fitted.predict()
204+
fit_results = sm.OLS(y, sm.add_constant(x)).fit()
205+
result["y"] = fit_results.predict()
190206
hover_header = "<b>OLS trendline</b><br>"
191207
hover_header += "%s = %f * %s + %f<br>" % (
192208
args["y"],
193-
fitted.params[1],
209+
fit_results.params[1],
194210
args["x"],
195-
fitted.params[0],
211+
fit_results.params[0],
212+
)
213+
hover_header += (
214+
"R<sup>2</sup>=%f<br><br>" % fit_results.rsquared
196215
)
197-
hover_header += "R<sup>2</sup>=%f<br><br>" % fitted.rsquared
198216
mapping_labels[get_label(args, args["x"])] = "%{x}"
199217
mapping_labels[get_label(args, args["y"])] = "%{y} <b>(trend)</b>"
200218

@@ -253,7 +271,7 @@ def make_trace_kwargs(
253271
if trace_spec.constructor not in [go.Histogram2dContour, go.Parcoords, go.Parcats]:
254272
hover_lines = [k + "=" + v for k, v in mapping_labels.items()]
255273
result["hovertemplate"] = hover_header + "<br>".join(hover_lines)
256-
return result
274+
return result, fit_results
257275

258276

259277
def configure_axes(args, constructor, fig, axes, orders):
@@ -728,6 +746,7 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
728746

729747
trace_names_by_frame = {}
730748
frames = OrderedDict()
749+
trendline_rows = []
731750
for group_name in group_names:
732751
group = grouped.get_group(group_name if len(group_name) > 1 else group_name[0])
733752
mapping_labels = OrderedDict()
@@ -803,17 +822,19 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
803822
):
804823
trace.update(marker=dict(color=trace.line.color))
805824

806-
trace.update(
807-
make_trace_kwargs(
808-
args,
809-
trace_spec,
810-
group,
811-
mapping_labels.copy(),
812-
sizeref,
813-
color_range=color_range,
814-
show_colorbar=(frame_name not in frames),
815-
)
825+
patch, fit_results = make_trace_kwargs(
826+
args,
827+
trace_spec,
828+
group,
829+
mapping_labels.copy(),
830+
sizeref,
831+
color_range=color_range,
832+
show_colorbar=(frame_name not in frames),
816833
)
834+
trace.update(patch)
835+
if fit_results is not None:
836+
trendline_rows.append(mapping_labels.copy())
837+
trendline_rows[-1]["px_fit_results"] = fit_results
817838
if frame_name not in frames:
818839
frames[frame_name] = dict(data=[], name=frame_name)
819840
frames[frame_name]["data"].append(trace)
@@ -832,6 +853,7 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
832853
layout=layout_patch,
833854
frames=frame_list if len(frames) > 1 else [],
834855
)
856+
fig._px_trendlines = pandas.DataFrame(trendline_rows)
835857
axes = {m.variable: m.val_map for m in grouped_mappings}
836858
configure_axes(args, constructor, fig, axes, orders)
837859
configure_animation_controls(args, constructor, fig)

0 commit comments

Comments
 (0)