Skip to content
This repository was archived by the owner on Jun 3, 2024. It is now read-only.

Commit 2d327d9

Browse files
fix groupby ordering when multi-grouping, should address #23
1 parent 9a64419 commit 2d327d9

File tree

1 file changed

+29
-13
lines changed

1 file changed

+29
-13
lines changed

plotly_express/_core.py

+29-13
Original file line numberDiff line numberDiff line change
@@ -756,24 +756,27 @@ def infer_config(args, constructor, trace_patch):
756756
return trace_specs, grouped_mappings, sizeref, color_range
757757

758758

759-
def make_figure(args, constructor, trace_patch={}, layout_patch={}):
760-
apply_default_cascade(args)
761-
trace_specs, grouped_mappings, sizeref, color_range = infer_config(
762-
args, constructor, trace_patch
763-
)
764-
grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group]
765-
grouped = args["data_frame"].groupby(grouper, sort=False)
759+
def get_orderings(args, grouper, grouped):
760+
"""
761+
`orders` is the user-supplied ordering (with the remaining data-frame-supplied
762+
ordering appended if the column is used for grouping)
763+
`group_names` is the set of groups, ordered by the order above
764+
"""
766765
orders = {} if "category_orders" not in args else args["category_orders"].copy()
767766
group_names = []
768767
for group_name in grouped.groups:
769768
if len(grouper) == 1:
770769
group_name = (group_name,)
771770
group_names.append(group_name)
772-
for col, val in zip(grouper, group_name):
773-
if col not in orders:
774-
orders[col] = []
775-
if val not in orders[col]:
776-
orders[col].append(val)
771+
for col in grouper:
772+
if col != one_group:
773+
uniques = args["data_frame"][col].unique()
774+
if col not in orders:
775+
orders[col] = list(uniques)
776+
else:
777+
for val in uniques:
778+
if val not in orders[col]:
779+
orders[col].append(val)
777780

778781
for i, col in reversed(list(enumerate(grouper))):
779782
if col != one_group:
@@ -782,10 +785,23 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
782785
key=lambda g: orders[col].index(g[i]) if g[i] in orders[col] else -1,
783786
)
784787

788+
return orders, group_names
789+
790+
791+
def make_figure(args, constructor, trace_patch={}, layout_patch={}):
792+
apply_default_cascade(args)
793+
trace_specs, grouped_mappings, sizeref, color_range = infer_config(
794+
args, constructor, trace_patch
795+
)
796+
grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group]
797+
grouped = args["data_frame"].groupby(grouper, sort=False)
798+
799+
orders, sorted_group_names = get_orderings(args, grouper, grouped)
800+
785801
trace_names_by_frame = {}
786802
frames = OrderedDict()
787803
trendline_rows = []
788-
for group_name in group_names:
804+
for group_name in sorted_group_names:
789805
group = grouped.get_group(group_name if len(group_name) > 1 else group_name[0])
790806
mapping_labels = OrderedDict()
791807
trace_name_labels = OrderedDict()

0 commit comments

Comments
 (0)