3
3
import plotly .io as pio
4
4
from collections import namedtuple , OrderedDict
5
5
from .colors import qualitative , sequential
6
- import math
6
+ import math , pandas
7
7
8
8
9
9
class PxDefaults (object ):
@@ -56,6 +56,21 @@ def _ipython_display_(self):
56
56
iplot (self , show_link = False , auto_play = False )
57
57
58
58
59
+ def get_trendline_results (fig ):
60
+ """
61
+ Extracts fit statistics for trendlines (when applied to figures generated with
62
+ the `trendline` argument set to `"ols"`).
63
+
64
+ Arguments:
65
+ fig: the output of a `plotly_express` charting call
66
+ Returns:
67
+ A `pandas.DataFrame` with a column "px_fit_results" containing the `statsmodels`
68
+ results objects, along with columns identifying the subset of the data the
69
+ trendline was fit on.
70
+ """
71
+ return fig ._px_trendlines
72
+
73
+
59
74
Mapping = namedtuple (
60
75
"Mapping" ,
61
76
["show_in_trace_name" , "grouper" , "val_map" , "sequence" , "updater" , "variable" ],
@@ -129,6 +144,7 @@ def make_trace_kwargs(
129
144
if "line_close" in args and args ["line_close" ]:
130
145
g = g .append (g .iloc [0 ])
131
146
result = trace_spec .trace_patch .copy () or {}
147
+ fit_results = None
132
148
hover_header = ""
133
149
for k in trace_spec .attrs :
134
150
v = args [k ]
@@ -186,16 +202,18 @@ def make_trace_kwargs(
186
202
result ["y" ] = trendline [:, 1 ]
187
203
hover_header = "<b>LOWESS trendline</b><br><br>"
188
204
elif v == "ols" :
189
- fitted = sm .OLS (y , sm .add_constant (x )).fit ()
190
- result ["y" ] = fitted .predict ()
205
+ fit_results = sm .OLS (y , sm .add_constant (x )).fit ()
206
+ result ["y" ] = fit_results .predict ()
191
207
hover_header = "<b>OLS trendline</b><br>"
192
208
hover_header += "%s = %f * %s + %f<br>" % (
193
209
args ["y" ],
194
- fitted .params [1 ],
210
+ fit_results .params [1 ],
195
211
args ["x" ],
196
- fitted .params [0 ],
212
+ fit_results .params [0 ],
213
+ )
214
+ hover_header += (
215
+ "R<sup>2</sup>=%f<br><br>" % fit_results .rsquared
197
216
)
198
- hover_header += "R<sup>2</sup>=%f<br><br>" % fitted .rsquared
199
217
mapping_labels [get_label (args , args ["x" ])] = "%{x}"
200
218
mapping_labels [get_label (args , args ["y" ])] = "%{y} <b>(trend)</b>"
201
219
@@ -254,7 +272,7 @@ def make_trace_kwargs(
254
272
if trace_spec .constructor not in [go .Histogram2dContour , go .Parcoords , go .Parcats ]:
255
273
hover_lines = [k + "=" + v for k , v in mapping_labels .items ()]
256
274
result ["hovertemplate" ] = hover_header + "<br>" .join (hover_lines )
257
- return result
275
+ return result , fit_results
258
276
259
277
260
278
def configure_axes (args , constructor , fig , axes , orders ):
@@ -766,6 +784,7 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
766
784
767
785
trace_names_by_frame = {}
768
786
frames = OrderedDict ()
787
+ trendline_rows = []
769
788
for group_name in group_names :
770
789
group = grouped .get_group (group_name if len (group_name ) > 1 else group_name [0 ])
771
790
mapping_labels = OrderedDict ()
@@ -841,17 +860,19 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
841
860
):
842
861
trace .update (marker = dict (color = trace .line .color ))
843
862
844
- trace .update (
845
- make_trace_kwargs (
846
- args ,
847
- trace_spec ,
848
- group ,
849
- mapping_labels .copy (),
850
- sizeref ,
851
- color_range = color_range ,
852
- show_colorbar = (frame_name not in frames ),
853
- )
863
+ patch , fit_results = make_trace_kwargs (
864
+ args ,
865
+ trace_spec ,
866
+ group ,
867
+ mapping_labels .copy (),
868
+ sizeref ,
869
+ color_range = color_range ,
870
+ show_colorbar = (frame_name not in frames ),
854
871
)
872
+ trace .update (patch )
873
+ if fit_results is not None :
874
+ trendline_rows .append (mapping_labels .copy ())
875
+ trendline_rows [- 1 ]["px_fit_results" ] = fit_results
855
876
if frame_name not in frames :
856
877
frames [frame_name ] = dict (data = [], name = frame_name )
857
878
frames [frame_name ]["data" ].append (trace )
@@ -870,6 +891,7 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
870
891
layout = layout_patch ,
871
892
frames = frame_list if len (frames ) > 1 else [],
872
893
)
894
+ fig ._px_trendlines = pandas .DataFrame (trendline_rows )
873
895
axes = {m .variable : m .val_map for m in grouped_mappings }
874
896
configure_axes (args , constructor , fig , axes , orders )
875
897
configure_animation_controls (args , constructor , fig )
0 commit comments