Skip to content

adding pie, treemap, sunburst, funnel and funnelarea to px #1909

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Nov 28, 2019
10 changes: 10 additions & 0 deletions packages/python/plotly/plotly/express/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@
choropleth,
density_contour,
density_heatmap,
pie,
sunburst,
treemap,
funnel,
funnel_area,
)

from ._imshow import imshow
Expand Down Expand Up @@ -77,6 +82,11 @@
"strip",
"histogram",
"choropleth",
"pie",
"sunburst",
"treemap",
"funnel",
"funnel_area",
"imshow",
"data",
"colors",
Expand Down
169 changes: 169 additions & 0 deletions packages/python/plotly/plotly/express/_chart_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1115,3 +1115,172 @@ def parallel_categories(


parallel_categories.__doc__ = make_docstring(parallel_categories)


def pie(
data_frame=None,
names=None,
values=None,
textinfo=None,
hover_name=None,
hover_data=None,
custom_data=None,
labels={},
title=None,
template=None,
width=None,
height=None,
opacity=None,
hole=None,
):
"""
In a pie plot, each row of `data_frame` is represented as a sector of a pie.
"""
return make_figure(
args=locals(), constructor=go.Pie, trace_patch=dict(showlegend=True, hole=hole)
)


pie.__doc__ = make_docstring(
pie,
override_dict=dict(
textinfo=["str", "Determines which trace information appear on the graph.",],
hole=[
"float",
"Sets the fraction of the radius to cut out of the pie."
"Use this to make a donut chart.",
],
),
)


def sunburst(
data_frame=None,
names=None,
values=None,
parents=None,
ids=None,
hover_name=None,
hover_data=None,
custom_data=None,
labels={},
title=None,
template=None,
width=None,
height=None,
branchvalues=None,
maxdepth=None,
):
"""
A sunburst plot represents hierarchial data as sectors laid out over
several levels of concentric rings.
"""
return make_figure(
args=locals(),
constructor=go.Sunburst,
trace_patch=dict(branchvalues=branchvalues, maxdepth=maxdepth),
)


sunburst.__doc__ = make_docstring(sunburst)


def treemap(
data_frame=None,
names=None,
values=None,
parents=None,
ids=None,
hover_name=None,
hover_data=None,
custom_data=None,
labels={},
title=None,
template=None,
width=None,
height=None,
branchvalues=None,
maxdepth=None,
):
"""
A treemap plot represents hierarchial data as nested rectangular sectors.
"""
return make_figure(
args=locals(),
constructor=go.Treemap,
trace_patch=dict(branchvalues=branchvalues, maxdepth=maxdepth),
)


treemap.__doc__ = make_docstring(treemap)


def funnel(
data_frame=None,
x=None,
y=None,
color=None,
color_discrete_sequence=None,
color_discrete_map={},
orientation=None,
textinfo=None,
hover_name=None,
hover_data=None,
custom_data=None,
labels={},
title=None,
template=None,
width=None,
height=None,
opacity=None,
):
"""
In a funnel plot, each row of `data_frame` is represented as a rectangular sector of a funnel.
"""
return make_figure(
args=locals(),
constructor=go.Funnel,
trace_patch=dict(opacity=opacity, orientation=orientation, textinfo=textinfo),
)


funnel.__doc__ = make_docstring(
funnel,
override_dict=dict(
textinfo=[
"str",
"Determines which trace information appear on the graph. In the case of having multiple funnels, percentages & totals are computed separately (per trace).",
]
),
)


def funnel_area(
data_frame=None,
values=None,
names=None,
textinfo=None,
hover_name=None,
hover_data=None,
custom_data=None,
labels={},
title=None,
template=None,
width=None,
height=None,
):
"""
In a funnel area plot, each row of `data_frame` is represented as a trapezoidal sector of a funnel.
"""
return make_figure(args=locals(), constructor=go.Funnelarea,)


funnel_area.__doc__ = make_docstring(
funnel_area,
override_dict=dict(
textinfo=[
"str",
"Determines which trace information appear on the graph. In the case of having multiple funnels, percentages & totals are computed separately (per trace).",
]
),
)
34 changes: 33 additions & 1 deletion packages/python/plotly/plotly/express/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,11 +305,38 @@ def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref):
elif k == "locations":
result[k] = g[v]
mapping_labels[v_label] = "%{location}"
elif k == "values":
result[k] = g[v]
_label = "value" if v_label == "values" else v_label
mapping_labels[_label] = "%{value}"
elif k == "parents":
result[k] = g[v]
_label = "parent" if v_label == "parents" else v_label
mapping_labels[_label] = "%{parent}"
elif k == "ids":
result[k] = g[v]
_label = "id" if v_label == "ids" else v_label
mapping_labels[_label] = "%{id}"
elif k == "names":
if trace_spec.constructor in [
go.Sunburst,
go.Treemap,
go.Pie,
go.Funnelarea,
]:
result["labels"] = g[v]
_label = "label" if v_label == "names" else v_label
mapping_labels[_label] = "%{label}"
else:
result[k] = g[v]
else:
if v:
result[k] = g[v]
mapping_labels[v_label] = "%%{%s}" % k
if trace_spec.constructor not in [go.Parcoords, go.Parcats]:
if trace_spec.constructor not in [
go.Parcoords,
go.Parcats,
]:
hover_lines = [k + "=" + v for k, v in mapping_labels.items()]
result["hovertemplate"] = hover_header + "<br>".join(hover_lines)
return result, fit_results
Expand Down Expand Up @@ -956,6 +983,7 @@ def infer_config(args, constructor, trace_patch):
attrables = (
["x", "y", "z", "a", "b", "c", "r", "theta", "size", "dimensions"]
+ ["custom_data", "hover_name", "hover_data", "text"]
+ ["names", "values", "parents", "ids"]
+ ["error_x", "error_x_minus"]
+ ["error_y", "error_y_minus", "error_z", "error_z_minus"]
+ ["lat", "lon", "locations", "animation_group"]
Expand Down Expand Up @@ -997,6 +1025,8 @@ def infer_config(args, constructor, trace_patch):
grouped_attrs.append("marker.color")

show_colorbar = bool("color" in attrs and args["color"])
else:
show_colorbar = False

# Compute line_dash grouping attribute
if "line_dash" in args:
Expand Down Expand Up @@ -1148,6 +1178,8 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
go.Parcoords,
go.Choropleth,
go.Histogram2d,
go.Sunburst,
go.Treemap,
]:
trace.update(
legendgroup=trace_name,
Expand Down
46 changes: 43 additions & 3 deletions packages/python/plotly/plotly/express/_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,21 @@
colref_desc,
"Values from this column or array_like are used to position marks along the angular axis in polar coordinates.",
],
values=[
colref_type,
colref_desc,
"Values from this column or array_like are used to set values associated to sectors.",
],
parents=[
colref_type,
colref_desc,
"Values from this column or array_like are used to set values associated to sectors.",
],
ids=[
colref_type,
colref_desc,
"Values from this column or array_like are used to set values associated to sectors.",
],
lat=[
colref_type,
colref_desc,
Expand Down Expand Up @@ -168,6 +183,11 @@
colref_desc,
"Values from this column or array_like appear in the figure as text labels.",
],
names=[
colref_type,
colref_desc,
"Values from this column or array_like are used as labels for sectors.",
],
locationmode=[
"str",
"One of 'ISO-3', 'USA-states', or 'country names'",
Expand Down Expand Up @@ -442,21 +462,41 @@
nbins=["int", "Positive integer.", "Sets the number of bins."],
nbinsx=["int", "Positive integer.", "Sets the number of bins along the x axis."],
nbinsy=["int", "Positive integer.", "Sets the number of bins along the y axis."],
branchvalues=[
"str",
"'total' or 'remainder'",
"Determines how the items in `values` are summed. When"
"set to 'total', items in `values` are taken to be value"
"of all its descendants. When set to 'remainder', items"
"in `values` corresponding to the root and the branches"
":sectors are taken to be the extra part not part of the"
"sum of the values at their leaves.",
],
maxdepth=[
"int",
"Positive integer",
"Sets the number of rendered sectors from any given `level`. Set `maxdepth` to -1 to render all the"
"levels in the hierarchy.",
],
)


def make_docstring(fn):
def make_docstring(fn, override_dict={}):
tw = TextWrapper(width=77, initial_indent=" ", subsequent_indent=" ")
result = (fn.__doc__ or "") + "\nParameters\n----------\n"
for param in inspect.getargspec(fn)[0]:
param_desc_list = docs[param][1:]
if override_dict.get(param):
param_doc = override_dict[param]
else:
param_doc = docs[param]
param_desc_list = param_doc[1:]
param_desc = (
tw.fill(" ".join(param_desc_list or ""))
if param in docs
else "(documentation missing from map)"
)

param_type = docs[param][0]
param_type = param_doc[0]
result += "%s: %s\n%s\n" % (param, param_type, param_desc)
result += "\nReturns\n-------\n"
result += " A `Figure` object."
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import plotly.express as px
import plotly.graph_objects as go
from numpy.testing import assert_array_equal

def _compare_figures(go_trace, px_fig):
"""Compare a figure created with a go trace and a figure created with
a px function call. Check that all values inside the go Figure are the
same in the px figure (which sets more parameters).
"""
go_fig = go.Figure(go_trace)
go_fig = go_fig.to_plotly_json()
px_fig = px_fig.to_plotly_json()
del go_fig["layout"]["template"]
del px_fig["layout"]["template"]
for key in go_fig['data'][0]:
assert_array_equal(go_fig['data'][0][key], px_fig['data'][0][key])
for key in go_fig['layout']:
assert go_fig['layout'][key] == px_fig['layout'][key]


def test_pie_like_px():
# Pie
labels = ['Oxygen','Hydrogen','Carbon_Dioxide','Nitrogen']
values = [4500, 2500, 1053, 500]

fig = px.pie(names=labels, values=values)
trace = go.Pie(labels=labels, values=values)
_compare_figures(trace, fig)

labels = ["Eve", "Cain", "Seth", "Enos", "Noam", "Abel", "Awan", "Enoch", "Azura"]
parents = ["", "Eve", "Eve", "Seth", "Seth", "Eve", "Eve", "Awan", "Eve" ]
values = [10, 14, 12, 10, 2, 6, 6, 4, 4]
# Sunburst
fig = px.sunburst(names=labels, parents=parents, values=values)
trace = go.Sunburst(labels=labels, parents=parents, values=values)
_compare_figures(trace, fig)
# Treemap
fig = px.treemap(names=labels, parents=parents, values=values)
trace = go.Treemap(labels=labels, parents=parents, values=values)
_compare_figures(trace, fig)

# Funnel
x = ['A', 'B', 'C']
y = [3, 2, 1]
fig = px.funnel(y=y, x=x)
trace = go.Funnel(y=y, x=x)
_compare_figures(trace, fig)
# Funnelarea
fig = px.funnel_area(values=y, names=x)
trace = go.Funnelarea(values=y, labels=x)
_compare_figures(trace, fig)