@@ -375,16 +375,18 @@ def configure_cartesian_marginal_axes(args, fig, orders):
375
375
376
376
# Configure axis ticks on marginal subplots
377
377
if args ["marginal_x" ]:
378
- fig .update_yaxes (
379
- showticklabels = False , showgrid = args ["marginal_x" ] == "histogram" , row = nrows
380
- )
381
- fig .update_xaxes (showgrid = True , row = nrows )
378
+ fig .update_yaxes (showticklabels = False , row = nrows )
379
+ if args ["template" ].layout .yaxis .showgrid is None :
380
+ fig .update_yaxes (showgrid = args ["marginal_x" ] == "histogram" , row = nrows )
381
+ if args ["template" ].layout .xaxis .showgrid is None :
382
+ fig .update_xaxes (showgrid = True , row = nrows )
382
383
383
384
if args ["marginal_y" ]:
384
- fig .update_xaxes (
385
- showticklabels = False , showgrid = args ["marginal_y" ] == "histogram" , col = ncols
386
- )
387
- fig .update_yaxes (showgrid = True , col = ncols )
385
+ fig .update_xaxes (showticklabels = False , col = ncols )
386
+ if args ["template" ].layout .xaxis .showgrid is None :
387
+ fig .update_xaxes (showgrid = args ["marginal_y" ] == "histogram" , col = ncols )
388
+ if args ["template" ].layout .yaxis .showgrid is None :
389
+ fig .update_yaxes (showgrid = True , col = ncols )
388
390
389
391
# Add axis titles to non-marginal subplots
390
392
y_title = get_decorated_label (args , args ["y" ], "y" )
@@ -687,55 +689,47 @@ def apply_default_cascade(args):
687
689
else :
688
690
args ["template" ] = "plotly"
689
691
690
- # retrieve the actual template if we were given a name
691
692
try :
692
- template = pio .templates [args ["template" ]]
693
+ # retrieve the actual template if we were given a name
694
+ args ["template" ] = pio .templates [args ["template" ]]
693
695
except Exception :
694
- template = args ["template" ]
696
+ # otherwise try to build a real template
697
+ args ["template" ] = go .layout .Template (args ["template" ])
695
698
696
699
# if colors not set explicitly or in px.defaults, defer to a template
697
700
# if the template doesn't have one, we set some final fallback defaults
698
701
if "color_continuous_scale" in args :
699
- if args [ "color_continuous_scale" ] is None :
700
- try :
701
- args ["color_continuous_scale" ] = [
702
- x [ 1 ] for x in template . layout . colorscale . sequential
703
- ]
704
- except ( AttributeError , TypeError ):
705
- pass
702
+ if (
703
+ args [ "color_continuous_scale" ] is None
704
+ and args ["template" ]. layout . colorscale . sequential
705
+ ):
706
+ args [ "color_continuous_scale" ] = [
707
+ x [ 1 ] for x in args [ "template" ]. layout . colorscale . sequential
708
+ ]
706
709
if args ["color_continuous_scale" ] is None :
707
710
args ["color_continuous_scale" ] = sequential .Viridis
708
711
709
712
if "color_discrete_sequence" in args :
710
- if args ["color_discrete_sequence" ] is None :
711
- try :
712
- args ["color_discrete_sequence" ] = template .layout .colorway
713
- except (AttributeError , TypeError ):
714
- pass
713
+ if args ["color_discrete_sequence" ] is None and args ["template" ].layout .colorway :
714
+ args ["color_discrete_sequence" ] = args ["template" ].layout .colorway
715
715
if args ["color_discrete_sequence" ] is None :
716
716
args ["color_discrete_sequence" ] = qualitative .D3
717
717
718
718
# if symbol_sequence/line_dash_sequence not set explicitly or in px.defaults,
719
719
# see if we can defer to template. If not, set reasonable defaults
720
720
if "symbol_sequence" in args :
721
- if args ["symbol_sequence" ] is None :
722
- try :
723
- args ["symbol_sequence" ] = [
724
- scatter .marker .symbol for scatter in template .data .scatter
725
- ]
726
- except (AttributeError , TypeError ):
727
- pass
721
+ if args ["symbol_sequence" ] is None and args ["template" ].data .scatter :
722
+ args ["symbol_sequence" ] = [
723
+ scatter .marker .symbol for scatter in args ["template" ].data .scatter
724
+ ]
728
725
if not args ["symbol_sequence" ] or not any (args ["symbol_sequence" ]):
729
726
args ["symbol_sequence" ] = ["circle" , "diamond" , "square" , "x" , "cross" ]
730
727
731
728
if "line_dash_sequence" in args :
732
- if args ["line_dash_sequence" ] is None :
733
- try :
734
- args ["line_dash_sequence" ] = [
735
- scatter .line .dash for scatter in template .data .scatter
736
- ]
737
- except (AttributeError , TypeError ):
738
- pass
729
+ if args ["line_dash_sequence" ] is None and args ["template" ].data .scatter :
730
+ args ["line_dash_sequence" ] = [
731
+ scatter .line .dash for scatter in args ["template" ].data .scatter
732
+ ]
739
733
if not args ["line_dash_sequence" ] or not any (args ["line_dash_sequence" ]):
740
734
args ["line_dash_sequence" ] = [
741
735
"solid" ,
@@ -1264,13 +1258,17 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
1264
1258
cmax = range_color [1 ],
1265
1259
colorbar = dict (title = get_decorated_label (args , args [colorvar ], colorvar )),
1266
1260
)
1267
- for v in ["title" , "height" , "width" , "template" ]:
1261
+ for v in ["title" , "height" , "width" ]:
1268
1262
if args [v ]:
1269
1263
layout_patch [v ] = args [v ]
1270
1264
layout_patch ["legend" ] = {"tracegroupgap" : 0 }
1271
- if "title" not in layout_patch :
1265
+ if "title" not in layout_patch and args [ "template" ]. layout . margin . t is None :
1272
1266
layout_patch ["margin" ] = {"t" : 60 }
1273
- if "size" in args and args ["size" ]:
1267
+ if (
1268
+ "size" in args
1269
+ and args ["size" ]
1270
+ and args ["template" ].layout .legend .itemsizing is None
1271
+ ):
1274
1272
layout_patch ["legend" ]["itemsizing" ] = "constant"
1275
1273
1276
1274
fig = init_figure (
@@ -1295,6 +1293,8 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
1295
1293
# Add traces, layout and frames to figure
1296
1294
fig .add_traces (frame_list [0 ]["data" ] if len (frame_list ) > 0 else [])
1297
1295
fig .layout .update (layout_patch )
1296
+ if "template" in args and args ["template" ] is not None :
1297
+ fig .update_layout (template = args ["template" ], overwrite = True )
1298
1298
fig .frames = frame_list if len (frames ) > 1 else []
1299
1299
1300
1300
fig ._px_trendlines = pd .DataFrame (trendline_rows )
0 commit comments