@@ -756,24 +756,27 @@ def infer_config(args, constructor, trace_patch):
756
756
return trace_specs , grouped_mappings , sizeref , color_range
757
757
758
758
759
- def make_figure (args , constructor , trace_patch = {}, layout_patch = {}):
760
- apply_default_cascade (args )
761
- trace_specs , grouped_mappings , sizeref , color_range = infer_config (
762
- args , constructor , trace_patch
763
- )
764
- grouper = [x .grouper or one_group for x in grouped_mappings ] or [one_group ]
765
- grouped = args ["data_frame" ].groupby (grouper , sort = False )
759
+ def get_orderings (args , grouper , grouped ):
760
+ """
761
+ `orders` is the user-supplied ordering (with the remaining data-frame-supplied
762
+ ordering appended if the column is used for grouping)
763
+ `group_names` is the set of groups, ordered by the order above
764
+ """
766
765
orders = {} if "category_orders" not in args else args ["category_orders" ].copy ()
767
766
group_names = []
768
767
for group_name in grouped .groups :
769
768
if len (grouper ) == 1 :
770
769
group_name = (group_name ,)
771
770
group_names .append (group_name )
772
- for col , val in zip (grouper , group_name ):
773
- if col not in orders :
774
- orders [col ] = []
775
- if val not in orders [col ]:
776
- orders [col ].append (val )
771
+ for col in grouper :
772
+ if col != one_group :
773
+ uniques = args ["data_frame" ][col ].unique ()
774
+ if col not in orders :
775
+ orders [col ] = list (uniques )
776
+ else :
777
+ for val in uniques :
778
+ if val not in orders [col ]:
779
+ orders [col ].append (val )
777
780
778
781
for i , col in reversed (list (enumerate (grouper ))):
779
782
if col != one_group :
@@ -782,10 +785,23 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
782
785
key = lambda g : orders [col ].index (g [i ]) if g [i ] in orders [col ] else - 1 ,
783
786
)
784
787
788
+ return orders , group_names
789
+
790
+
791
+ def make_figure (args , constructor , trace_patch = {}, layout_patch = {}):
792
+ apply_default_cascade (args )
793
+ trace_specs , grouped_mappings , sizeref , color_range = infer_config (
794
+ args , constructor , trace_patch
795
+ )
796
+ grouper = [x .grouper or one_group for x in grouped_mappings ] or [one_group ]
797
+ grouped = args ["data_frame" ].groupby (grouper , sort = False )
798
+
799
+ orders , sorted_group_names = get_orderings (args , grouper , grouped )
800
+
785
801
trace_names_by_frame = {}
786
802
frames = OrderedDict ()
787
803
trendline_rows = []
788
- for group_name in group_names :
804
+ for group_name in sorted_group_names :
789
805
group = grouped .get_group (group_name if len (group_name ) > 1 else group_name [0 ])
790
806
mapping_labels = OrderedDict ()
791
807
trace_name_labels = OrderedDict ()
0 commit comments