Skip to content

Commit 3e35a6b

Browse files
Merge pull request #1909 from plotly/px-pie-etc
adding pie, sunburst, funnel and funnelarea to px
2 parents 3ca829c + 0dd2f39 commit 3e35a6b

File tree

5 files changed

+474
-5
lines changed

5 files changed

+474
-5
lines changed

packages/python/plotly/plotly/express/__init__.py

+10
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@
3939
choropleth,
4040
density_contour,
4141
density_heatmap,
42+
pie,
43+
sunburst,
44+
treemap,
45+
funnel,
46+
funnel_area,
4247
)
4348

4449
from ._imshow import imshow
@@ -77,6 +82,11 @@
7782
"strip",
7883
"histogram",
7984
"choropleth",
85+
"pie",
86+
"sunburst",
87+
"treemap",
88+
"funnel",
89+
"funnel_area",
8090
"imshow",
8191
"data",
8292
"colors",

packages/python/plotly/plotly/express/_chart_types.py

+205
Original file line numberDiff line numberDiff line change
@@ -1115,3 +1115,208 @@ def parallel_categories(
11151115

11161116

11171117
parallel_categories.__doc__ = make_docstring(parallel_categories)
1118+
1119+
1120+
def pie(
1121+
data_frame=None,
1122+
names=None,
1123+
values=None,
1124+
color=None,
1125+
color_discrete_sequence=None,
1126+
color_discrete_map={},
1127+
hover_name=None,
1128+
hover_data=None,
1129+
custom_data=None,
1130+
labels={},
1131+
title=None,
1132+
template=None,
1133+
width=None,
1134+
height=None,
1135+
opacity=None,
1136+
hole=None,
1137+
):
1138+
"""
1139+
In a pie plot, each row of `data_frame` is represented as a sector of a pie.
1140+
"""
1141+
if color_discrete_sequence is not None:
1142+
layout_patch = {"piecolorway": color_discrete_sequence}
1143+
else:
1144+
layout_patch = {}
1145+
return make_figure(
1146+
args=locals(),
1147+
constructor=go.Pie,
1148+
trace_patch=dict(showlegend=(names is not None), hole=hole),
1149+
layout_patch=layout_patch,
1150+
)
1151+
1152+
1153+
pie.__doc__ = make_docstring(
1154+
pie,
1155+
override_dict=dict(
1156+
hole=[
1157+
"float",
1158+
"Sets the fraction of the radius to cut out of the pie."
1159+
"Use this to make a donut chart.",
1160+
],
1161+
),
1162+
)
1163+
1164+
1165+
def sunburst(
1166+
data_frame=None,
1167+
names=None,
1168+
values=None,
1169+
parents=None,
1170+
ids=None,
1171+
color=None,
1172+
color_continuous_scale=None,
1173+
range_color=None,
1174+
color_continuous_midpoint=None,
1175+
color_discrete_sequence=None,
1176+
color_discrete_map={},
1177+
hover_name=None,
1178+
hover_data=None,
1179+
custom_data=None,
1180+
labels={},
1181+
title=None,
1182+
template=None,
1183+
width=None,
1184+
height=None,
1185+
branchvalues=None,
1186+
maxdepth=None,
1187+
):
1188+
"""
1189+
A sunburst plot represents hierarchial data as sectors laid out over
1190+
several levels of concentric rings.
1191+
"""
1192+
if color_discrete_sequence is not None:
1193+
layout_patch = {"sunburstcolorway": color_discrete_sequence}
1194+
else:
1195+
layout_patch = {}
1196+
return make_figure(
1197+
args=locals(),
1198+
constructor=go.Sunburst,
1199+
trace_patch=dict(branchvalues=branchvalues, maxdepth=maxdepth),
1200+
layout_patch=layout_patch,
1201+
)
1202+
1203+
1204+
sunburst.__doc__ = make_docstring(sunburst)
1205+
1206+
1207+
def treemap(
1208+
data_frame=None,
1209+
names=None,
1210+
values=None,
1211+
parents=None,
1212+
ids=None,
1213+
color=None,
1214+
color_continuous_scale=None,
1215+
range_color=None,
1216+
color_continuous_midpoint=None,
1217+
color_discrete_sequence=None,
1218+
color_discrete_map={},
1219+
hover_name=None,
1220+
hover_data=None,
1221+
custom_data=None,
1222+
labels={},
1223+
title=None,
1224+
template=None,
1225+
width=None,
1226+
height=None,
1227+
branchvalues=None,
1228+
maxdepth=None,
1229+
):
1230+
"""
1231+
A treemap plot represents hierarchial data as nested rectangular sectors.
1232+
"""
1233+
if color_discrete_sequence is not None:
1234+
layout_patch = {"treemapcolorway": color_discrete_sequence}
1235+
else:
1236+
layout_patch = {}
1237+
return make_figure(
1238+
args=locals(),
1239+
constructor=go.Treemap,
1240+
trace_patch=dict(branchvalues=branchvalues, maxdepth=maxdepth),
1241+
layout_patch=layout_patch,
1242+
)
1243+
1244+
1245+
treemap.__doc__ = make_docstring(treemap)
1246+
1247+
1248+
def funnel(
1249+
data_frame=None,
1250+
x=None,
1251+
y=None,
1252+
color=None,
1253+
facet_row=None,
1254+
facet_col=None,
1255+
facet_col_wrap=0,
1256+
hover_name=None,
1257+
hover_data=None,
1258+
custom_data=None,
1259+
text=None,
1260+
animation_frame=None,
1261+
animation_group=None,
1262+
category_orders={},
1263+
labels={},
1264+
color_discrete_sequence=None,
1265+
color_discrete_map={},
1266+
opacity=None,
1267+
orientation="h",
1268+
log_x=False,
1269+
log_y=False,
1270+
range_x=None,
1271+
range_y=None,
1272+
title=None,
1273+
template=None,
1274+
width=None,
1275+
height=None,
1276+
):
1277+
"""
1278+
In a funnel plot, each row of `data_frame` is represented as a rectangular sector of a funnel.
1279+
"""
1280+
return make_figure(
1281+
args=locals(),
1282+
constructor=go.Funnel,
1283+
trace_patch=dict(opacity=opacity, orientation=orientation),
1284+
)
1285+
1286+
1287+
funnel.__doc__ = make_docstring(funnel)
1288+
1289+
1290+
def funnel_area(
1291+
data_frame=None,
1292+
names=None,
1293+
values=None,
1294+
color=None,
1295+
color_discrete_sequence=None,
1296+
color_discrete_map={},
1297+
hover_name=None,
1298+
hover_data=None,
1299+
custom_data=None,
1300+
labels={},
1301+
title=None,
1302+
template=None,
1303+
width=None,
1304+
height=None,
1305+
opacity=None,
1306+
):
1307+
"""
1308+
In a funnel area plot, each row of `data_frame` is represented as a trapezoidal sector of a funnel.
1309+
"""
1310+
if color_discrete_sequence is not None:
1311+
layout_patch = {"funnelareacolorway": color_discrete_sequence}
1312+
else:
1313+
layout_patch = {}
1314+
return make_figure(
1315+
args=locals(),
1316+
constructor=go.Funnelarea,
1317+
trace_patch=dict(showlegend=(names is not None)),
1318+
layout_patch=layout_patch,
1319+
)
1320+
1321+
1322+
funnel_area.__doc__ = make_docstring(funnel_area)

packages/python/plotly/plotly/express/_core.py

+75-2
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,28 @@ def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref):
291291
result["z"] = g[v]
292292
result["coloraxis"] = "coloraxis1"
293293
mapping_labels[v_label] = "%{z}"
294+
elif trace_spec.constructor in [
295+
go.Sunburst,
296+
go.Treemap,
297+
go.Pie,
298+
go.Funnelarea,
299+
]:
300+
if "marker" not in result:
301+
result["marker"] = dict()
302+
303+
if args.get("color_is_continuous"):
304+
result["marker"]["colors"] = g[v]
305+
result["marker"]["coloraxis"] = "coloraxis1"
306+
mapping_labels[v_label] = "%{color}"
307+
else:
308+
result["marker"]["colors"] = []
309+
mapping = {}
310+
for cat in g[v]:
311+
if mapping.get(cat) is None:
312+
mapping[cat] = args["color_discrete_sequence"][
313+
len(mapping) % len(args["color_discrete_sequence"])
314+
]
315+
result["marker"]["colors"].append(mapping[cat])
294316
else:
295317
colorable = "marker"
296318
if trace_spec.constructor in [go.Parcats, go.Parcoords]:
@@ -305,11 +327,38 @@ def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref):
305327
elif k == "locations":
306328
result[k] = g[v]
307329
mapping_labels[v_label] = "%{location}"
330+
elif k == "values":
331+
result[k] = g[v]
332+
_label = "value" if v_label == "values" else v_label
333+
mapping_labels[_label] = "%{value}"
334+
elif k == "parents":
335+
result[k] = g[v]
336+
_label = "parent" if v_label == "parents" else v_label
337+
mapping_labels[_label] = "%{parent}"
338+
elif k == "ids":
339+
result[k] = g[v]
340+
_label = "id" if v_label == "ids" else v_label
341+
mapping_labels[_label] = "%{id}"
342+
elif k == "names":
343+
if trace_spec.constructor in [
344+
go.Sunburst,
345+
go.Treemap,
346+
go.Pie,
347+
go.Funnelarea,
348+
]:
349+
result["labels"] = g[v]
350+
_label = "label" if v_label == "names" else v_label
351+
mapping_labels[_label] = "%{label}"
352+
else:
353+
result[k] = g[v]
308354
else:
309355
if v:
310356
result[k] = g[v]
311357
mapping_labels[v_label] = "%%{%s}" % k
312-
if trace_spec.constructor not in [go.Parcoords, go.Parcats]:
358+
if trace_spec.constructor not in [
359+
go.Parcoords,
360+
go.Parcats,
361+
]:
313362
hover_lines = [k + "=" + v for k, v in mapping_labels.items()]
314363
result["hovertemplate"] = hover_header + "<br>".join(hover_lines)
315364
return result, fit_results
@@ -674,6 +723,7 @@ def one_group(x):
674723

675724
def apply_default_cascade(args):
676725
# first we apply px.defaults to unspecified args
726+
677727
for param in (
678728
["color_discrete_sequence", "color_continuous_scale"]
679729
+ ["symbol_sequence", "line_dash_sequence", "template"]
@@ -956,6 +1006,7 @@ def infer_config(args, constructor, trace_patch):
9561006
attrables = (
9571007
["x", "y", "z", "a", "b", "c", "r", "theta", "size", "dimensions"]
9581008
+ ["custom_data", "hover_name", "hover_data", "text"]
1009+
+ ["names", "values", "parents", "ids"]
9591010
+ ["error_x", "error_x_minus"]
9601011
+ ["error_y", "error_y_minus", "error_z", "error_z_minus"]
9611012
+ ["lat", "lon", "locations", "animation_group"]
@@ -989,14 +1040,34 @@ def infer_config(args, constructor, trace_patch):
9891040
and args["data_frame"][args["color"]].dtype.kind in "bifc"
9901041
):
9911042
attrs.append("color")
1043+
args["color_is_continuous"] = True
1044+
elif constructor in [go.Sunburst, go.Treemap]:
1045+
attrs.append("color")
1046+
args["color_is_continuous"] = False
9921047
else:
9931048
grouped_attrs.append("marker.color")
9941049
elif "line_group" in args or constructor == go.Histogram2dContour:
9951050
grouped_attrs.append("line.color")
1051+
elif constructor in [go.Pie, go.Funnelarea]:
1052+
attrs.append("color")
1053+
if args["color"]:
1054+
if args["hover_data"] is None:
1055+
args["hover_data"] = []
1056+
args["hover_data"].append(args["color"])
9961057
else:
9971058
grouped_attrs.append("marker.color")
9981059

999-
show_colorbar = bool("color" in attrs and args["color"])
1060+
show_colorbar = bool(
1061+
"color" in attrs
1062+
and args["color"]
1063+
and constructor not in [go.Pie, go.Funnelarea]
1064+
and (
1065+
constructor not in [go.Treemap, go.Sunburst]
1066+
or args.get("color_is_continuous")
1067+
)
1068+
)
1069+
else:
1070+
show_colorbar = False
10001071

10011072
# Compute line_dash grouping attribute
10021073
if "line_dash" in args:
@@ -1148,6 +1219,8 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
11481219
go.Parcoords,
11491220
go.Choropleth,
11501221
go.Histogram2d,
1222+
go.Sunburst,
1223+
go.Treemap,
11511224
]:
11521225
trace.update(
11531226
legendgroup=trace_name,

0 commit comments

Comments
 (0)