2
2
from plotly .offline import init_notebook_mode , iplot
3
3
from collections import namedtuple , OrderedDict
4
4
from .colors import qualitative , sequential
5
- import math
5
+ import math , pandas
6
6
7
7
8
8
class PxDefaults (object ):
@@ -55,6 +55,21 @@ def _ipython_display_(self):
55
55
iplot (self , show_link = False , auto_play = False )
56
56
57
57
58
+ def get_trendline_results (fig ):
59
+ """
60
+ Extracts fit statistics for trendlines (when applied to figures generated with
61
+ the `trendline` argument set to `"ols"`).
62
+
63
+ Arguments:
64
+ fig: the output of a `plotly_express` charting call
65
+ Returns:
66
+ A `pandas.DataFrame` with a column "px_fit_results" containing the `statsmodels`
67
+ results objects, along with columns identifying the subset of the data the
68
+ trendline was fit on.
69
+ """
70
+ return fig ._px_trendlines
71
+
72
+
58
73
Mapping = namedtuple (
59
74
"Mapping" ,
60
75
["show_in_trace_name" , "grouper" , "val_map" , "sequence" , "updater" , "variable" ],
@@ -128,6 +143,7 @@ def make_trace_kwargs(
128
143
if "line_close" in args and args ["line_close" ]:
129
144
g = g .append (g .iloc [0 ])
130
145
result = trace_spec .trace_patch .copy () or {}
146
+ fit_results = None
131
147
hover_header = ""
132
148
for k in trace_spec .attrs :
133
149
v = args [k ]
@@ -185,16 +201,18 @@ def make_trace_kwargs(
185
201
result ["y" ] = trendline [:, 1 ]
186
202
hover_header = "<b>LOWESS trendline</b><br><br>"
187
203
elif v == "ols" :
188
- fitted = sm .OLS (y , sm .add_constant (x )).fit ()
189
- result ["y" ] = fitted .predict ()
204
+ fit_results = sm .OLS (y , sm .add_constant (x )).fit ()
205
+ result ["y" ] = fit_results .predict ()
190
206
hover_header = "<b>OLS trendline</b><br>"
191
207
hover_header += "%s = %f * %s + %f<br>" % (
192
208
args ["y" ],
193
- fitted .params [1 ],
209
+ fit_results .params [1 ],
194
210
args ["x" ],
195
- fitted .params [0 ],
211
+ fit_results .params [0 ],
212
+ )
213
+ hover_header += (
214
+ "R<sup>2</sup>=%f<br><br>" % fit_results .rsquared
196
215
)
197
- hover_header += "R<sup>2</sup>=%f<br><br>" % fitted .rsquared
198
216
mapping_labels [get_label (args , args ["x" ])] = "%{x}"
199
217
mapping_labels [get_label (args , args ["y" ])] = "%{y} <b>(trend)</b>"
200
218
@@ -253,7 +271,7 @@ def make_trace_kwargs(
253
271
if trace_spec .constructor not in [go .Histogram2dContour , go .Parcoords , go .Parcats ]:
254
272
hover_lines = [k + "=" + v for k , v in mapping_labels .items ()]
255
273
result ["hovertemplate" ] = hover_header + "<br>" .join (hover_lines )
256
- return result
274
+ return result , fit_results
257
275
258
276
259
277
def configure_axes (args , constructor , fig , axes , orders ):
@@ -728,6 +746,7 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
728
746
729
747
trace_names_by_frame = {}
730
748
frames = OrderedDict ()
749
+ trendline_rows = []
731
750
for group_name in group_names :
732
751
group = grouped .get_group (group_name if len (group_name ) > 1 else group_name [0 ])
733
752
mapping_labels = OrderedDict ()
@@ -803,17 +822,19 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
803
822
):
804
823
trace .update (marker = dict (color = trace .line .color ))
805
824
806
- trace .update (
807
- make_trace_kwargs (
808
- args ,
809
- trace_spec ,
810
- group ,
811
- mapping_labels .copy (),
812
- sizeref ,
813
- color_range = color_range ,
814
- show_colorbar = (frame_name not in frames ),
815
- )
825
+ patch , fit_results = make_trace_kwargs (
826
+ args ,
827
+ trace_spec ,
828
+ group ,
829
+ mapping_labels .copy (),
830
+ sizeref ,
831
+ color_range = color_range ,
832
+ show_colorbar = (frame_name not in frames ),
816
833
)
834
+ trace .update (patch )
835
+ if fit_results is not None :
836
+ trendline_rows .append (mapping_labels .copy ())
837
+ trendline_rows [- 1 ]["px_fit_results" ] = fit_results
817
838
if frame_name not in frames :
818
839
frames [frame_name ] = dict (data = [], name = frame_name )
819
840
frames [frame_name ]["data" ].append (trace )
@@ -832,6 +853,7 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
832
853
layout = layout_patch ,
833
854
frames = frame_list if len (frames ) > 1 else [],
834
855
)
856
+ fig ._px_trendlines = pandas .DataFrame (trendline_rows )
835
857
axes = {m .variable : m .val_map for m in grouped_mappings }
836
858
configure_axes (args , constructor , fig , axes , orders )
837
859
configure_animation_controls (args , constructor , fig )
0 commit comments