diff --git a/doc/python/sunburst-charts.md b/doc/python/sunburst-charts.md index c0b01c61895..db659f99092 100644 --- a/doc/python/sunburst-charts.md +++ b/doc/python/sunburst-charts.md @@ -62,6 +62,53 @@ fig =px.sunburst( fig.show() ``` +### Sunburst of a rectangular DataFrame with plotly.express + +Hierarchical data are often stored as a rectangular dataframe, with different columns corresponding to different levels of the hierarchy. `px.sunburst` can take a `path` parameter corresponding to a list of columns. Note that `id` and `parent` should not be provided if `path` is given. + +```python +import plotly.express as px +df = px.data.tips() +fig = px.sunburst(df, path=['day', 'time', 'sex'], values='total_bill') +fig.show() +``` + +### Sunburst of a rectangular DataFrame with continuous color argument in px.sunburst + +If a `color` argument is passed, the color of a node is computed as the average of the color values of its children, weighted by their values. + +```python +import plotly.express as px +import numpy as np +df = px.data.gapminder().query("year == 2007") +fig = px.sunburst(df, path=['continent', 'country'], values='pop', + color='lifeExp', hover_data=['iso_alpha'], + color_continuous_scale='RdBu', + color_continuous_midpoint=np.average(df['lifeExp'], weights=df['pop'])) +fig.show() +``` + +### Rectangular data with missing values + +If the dataset is not fully rectangular, missing values should be supplied as `None`. Note that the parents of `None` entries must be a leaf, i.e. it cannot have other children than `None` (otherwise a `ValueError` is raised). + +```python +import plotly.express as px +import pandas as pd +vendors = ["A", "B", "C", "D", None, "E", "F", "G", "H", None] +sectors = ["Tech", "Tech", "Finance", "Finance", "Other", + "Tech", "Tech", "Finance", "Finance", "Other"] +regions = ["North", "North", "North", "North", "North", + "South", "South", "South", "South", "South"] +sales = [1, 3, 2, 4, 1, 2, 2, 1, 4, 1] +df = pd.DataFrame( + dict(vendors=vendors, sectors=sectors, regions=regions, sales=sales) +) +print(df) +fig = px.sunburst(df, path=['regions', 'sectors', 'vendors'], values='sales') +fig.show() +``` + ### Basic Sunburst Plot with go.Sunburst If Plotly Express does not provide a good starting point, it is also possible to use the more generic `go.Sunburst` function from `plotly.graph_objects`. diff --git a/doc/python/treemaps.md b/doc/python/treemaps.md index 02167e3cdb8..8b1cc97d1f1 100644 --- a/doc/python/treemaps.md +++ b/doc/python/treemaps.md @@ -51,6 +51,52 @@ fig = px.treemap( fig.show() ``` +### Treemap of a rectangular DataFrame with plotly.express + +Hierarchical data are often stored as a rectangular dataframe, with different columns corresponding to different levels of the hierarchy. `px.treemap` can take a `path` parameter corresponding to a list of columns. Note that `id` and `parent` should not be provided if `path` is given. + +```python +import plotly.express as px +df = px.data.tips() +fig = px.treemap(df, path=['day', 'time', 'sex'], values='total_bill') +fig.show() +``` + +### Treemap of a rectangular DataFrame with continuous color argument in px.treemap + +If a `color` argument is passed, the color of a node is computed as the average of the color values of its children, weighted by their values. + +```python +import plotly.express as px +import numpy as np +df = px.data.gapminder().query("year == 2007") +fig = px.treemap(df, path=['continent', 'country'], values='pop', + color='lifeExp', hover_data=['iso_alpha'], + color_continuous_scale='RdBu', + color_continuous_midpoint=np.average(df['lifeExp'], weights=df['pop'])) +fig.show() +``` + +### Rectangular data with missing values + +If the dataset is not fully rectangular, missing values should be supplied as `None`. + +```python +import plotly.express as px +import pandas as pd +vendors = ["A", "B", "C", "D", None, "E", "F", "G", "H", None] +sectors = ["Tech", "Tech", "Finance", "Finance", "Other", + "Tech", "Tech", "Finance", "Finance", "Other"] +regions = ["North", "North", "North", "North", "North", + "South", "South", "South", "South", "South"] +sales = [1, 3, 2, 4, 1, 2, 2, 1, 4, 1] +df = pd.DataFrame( + dict(vendors=vendors, sectors=sectors, regions=regions, sales=sales) +) +print(df) +fig = px.treemap(df, path=['regions', 'sectors', 'vendors'], values='sales') +fig.show() +``` ### Basic Treemap with go.Treemap If Plotly Express does not provide a good starting point, it is also possible to use the more generic `go.Treemap` function from `plotly.graph_objects`. diff --git a/packages/python/plotly/plotly/express/_chart_types.py b/packages/python/plotly/plotly/express/_chart_types.py index 12cf439e2b3..7d4dd7e0df8 100644 --- a/packages/python/plotly/plotly/express/_chart_types.py +++ b/packages/python/plotly/plotly/express/_chart_types.py @@ -1269,6 +1269,7 @@ def sunburst( names=None, values=None, parents=None, + path=None, ids=None, color=None, color_continuous_scale=None, @@ -1295,6 +1296,13 @@ def sunburst( layout_patch = {"sunburstcolorway": color_discrete_sequence} else: layout_patch = {} + if path is not None and (ids is not None or parents is not None): + raise ValueError( + "Either `path` should be provided, or `ids` and `parents`." + "These parameters are mutually exclusive and cannot be passed together." + ) + if path is not None and branchvalues is None: + branchvalues = "total" return make_figure( args=locals(), constructor=go.Sunburst, @@ -1312,6 +1320,7 @@ def treemap( values=None, parents=None, ids=None, + path=None, color=None, color_continuous_scale=None, range_color=None, @@ -1337,6 +1346,13 @@ def treemap( layout_patch = {"treemapcolorway": color_discrete_sequence} else: layout_patch = {} + if path is not None and (ids is not None or parents is not None): + raise ValueError( + "Either `path` should be provided, or `ids` and `parents`." + "These parameters are mutually exclusive and cannot be passed together." + ) + if path is not None and branchvalues is None: + branchvalues = "total" return make_figure( args=locals(), constructor=go.Treemap, diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 21c0ca03cc1..e43b4beb76b 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1007,6 +1007,147 @@ def build_dataframe(args, attrables, array_attrables): return args +def _check_dataframe_all_leaves(df): + df_sorted = df.sort_values(by=list(df.columns)) + null_mask = df_sorted.isnull() + null_indices = np.nonzero(null_mask.any(axis=1).values)[0] + for null_row_index in null_indices: + row = null_mask.iloc[null_row_index] + indices = np.nonzero(row.values)[0] + if not row[indices[0] :].all(): + raise ValueError( + "None entries cannot have not-None children", + df_sorted.iloc[null_row_index], + ) + df_sorted[null_mask] = "" + row_strings = list(df_sorted.apply(lambda x: "".join(x), axis=1)) + for i, row in enumerate(row_strings[:-1]): + if row_strings[i + 1] in row and (i + 1) in null_indices: + raise ValueError( + "Non-leaves rows are not permitted in the dataframe \n", + df_sorted.iloc[i + 1], + "is not a leaf.", + ) + + +def process_dataframe_hierarchy(args): + """ + Build dataframe for sunburst or treemap when the path argument is provided. + """ + df = args["data_frame"] + path = args["path"][::-1] + _check_dataframe_all_leaves(df[path[::-1]]) + discrete_color = False + + if args["color"] and args["color"] in path: + series_to_copy = df[args["color"]] + args["color"] = str(args["color"]) + "additional_col_for_px" + df[args["color"]] = series_to_copy + if args["hover_data"]: + for col_name in args["hover_data"]: + if col_name == args["color"]: + series_to_copy = df[col_name] + new_col_name = str(args["color"]) + "additional_col_for_hover" + df[new_col_name] = series_to_copy + args["color"] = new_col_name + elif col_name in path: + series_to_copy = df[col_name] + new_col_name = col_name + "additional_col_for_hover" + path = [new_col_name if x == col_name else x for x in path] + df[new_col_name] = series_to_copy + # ------------ Define aggregation functions -------------------------------- + def aggfunc_discrete(x): + uniques = x.unique() + if len(uniques) == 1: + return uniques[0] + else: + return "(?)" + + agg_f = {} + aggfunc_color = None + if args["values"]: + try: + df[args["values"]] = pd.to_numeric(df[args["values"]]) + except ValueError: + raise ValueError( + "Column `%s` of `df` could not be converted to a numerical data type." + % args["values"] + ) + + if args["color"]: + if args["color"] == args["values"]: + aggfunc_color = "sum" + count_colname = args["values"] + else: + # we need a count column for the first groupby and the weighted mean of color + # trick to be sure the col name is unused: take the sum of existing names + count_colname = ( + "count" + if "count" not in df.columns + else "".join([str(el) for el in list(df.columns)]) + ) + # we can modify df because it's a copy of the px argument + df[count_colname] = 1 + args["values"] = count_colname + agg_f[count_colname] = "sum" + + if args["color"]: + if df[args["color"]].dtype.kind not in "bifc": + aggfunc_color = aggfunc_discrete + discrete_color = True + elif not aggfunc_color: + + def aggfunc_continuous(x): + return np.average(x, weights=df.loc[x.index, count_colname]) + + aggfunc_color = aggfunc_continuous + agg_f[args["color"]] = aggfunc_color + + # Other columns (for color, hover_data, custom_data etc.) + cols = list(set(df.columns).difference(path)) + for col in cols: # for hover_data, custom_data etc. + if col not in agg_f: + agg_f[col] = aggfunc_discrete + # ---------------------------------------------------------------------------- + + df_all_trees = pd.DataFrame(columns=["labels", "parent", "id"] + cols) + # Set column type here (useful for continuous vs discrete colorscale) + for col in cols: + df_all_trees[col] = df_all_trees[col].astype(df[col].dtype) + for i, level in enumerate(path): + df_tree = pd.DataFrame(columns=df_all_trees.columns) + dfg = df.groupby(path[i:]).agg(agg_f) + dfg = dfg.reset_index() + # Path label massaging + df_tree["labels"] = dfg[level].copy().astype(str) + df_tree["parent"] = "" + df_tree["id"] = dfg[level].copy().astype(str) + if i < len(path) - 1: + j = i + 1 + while j < len(path): + df_tree["parent"] = ( + dfg[path[j]].copy().astype(str) + "/" + df_tree["parent"] + ) + df_tree["id"] = dfg[path[j]].copy().astype(str) + "/" + df_tree["id"] + j += 1 + + df_tree["parent"] = df_tree["parent"].str.rstrip("/") + if cols: + df_tree[cols] = dfg[cols] + df_all_trees = df_all_trees.append(df_tree, ignore_index=True) + + if args["color"] and discrete_color: + df_all_trees = df_all_trees.sort_values(by=args["color"]) + + # Now modify arguments + args["data_frame"] = df_all_trees + args["path"] = None + args["ids"] = "id" + args["names"] = "labels" + args["parents"] = "parent" + return args + + def infer_config(args, constructor, trace_patch): # Declare all supported attributes, across all plot types attrables = ( @@ -1015,9 +1156,9 @@ def infer_config(args, constructor, trace_patch): + ["names", "values", "parents", "ids"] + ["error_x", "error_x_minus"] + ["error_y", "error_y_minus", "error_z", "error_z_minus"] - + ["lat", "lon", "locations", "animation_group"] + + ["lat", "lon", "locations", "animation_group", "path"] ) - array_attrables = ["dimensions", "custom_data", "hover_data"] + array_attrables = ["dimensions", "custom_data", "hover_data", "path"] group_attrables = ["animation_frame", "facet_row", "facet_col", "line_group"] all_attrables = attrables + group_attrables + ["color"] group_attrs = ["symbol", "line_dash"] @@ -1026,6 +1167,8 @@ def infer_config(args, constructor, trace_patch): all_attrables += [group_attr] args = build_dataframe(args, all_attrables, array_attrables) + if constructor in [go.Treemap, go.Sunburst] and args["path"] is not None: + args = process_dataframe_hierarchy(args) attrs = [k for k in attrables if k in args] grouped_attrs = [] diff --git a/packages/python/plotly/plotly/express/_doc.py b/packages/python/plotly/plotly/express/_doc.py index 3a5b9344e19..b3c6d39dc7a 100644 --- a/packages/python/plotly/plotly/express/_doc.py +++ b/packages/python/plotly/plotly/express/_doc.py @@ -86,6 +86,12 @@ colref_desc, "Values from this column or array_like are used to set ids of sectors", ], + path=[ + colref_list_type, + colref_list_desc, + "List of columns names or columns of a rectangular dataframe defining the hierarchy of sectors, from root to leaves.", + "An error is raised if path AND ids or parents is passed", + ], lat=[ colref_type, colref_desc, diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py index 339accf9d57..8ae1e9ea791 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py @@ -2,6 +2,8 @@ import plotly.graph_objects as go from numpy.testing import assert_array_equal import numpy as np +import pandas as pd +import pytest def _compare_figures(go_trace, px_fig): @@ -111,6 +113,157 @@ def test_sunburst_treemap_colorscales(): assert list(fig.layout[colorway]) == color_seq +def test_sunburst_treemap_with_path(): + vendors = ["A", "B", "C", "D", "E", "F", "G", "H"] + sectors = [ + "Tech", + "Tech", + "Finance", + "Finance", + "Tech", + "Tech", + "Finance", + "Finance", + ] + regions = ["North", "North", "North", "North", "South", "South", "South", "South"] + values = [1, 3, 2, 4, 2, 2, 1, 4] + total = ["total",] * 8 + df = pd.DataFrame( + dict( + vendors=vendors, + sectors=sectors, + regions=regions, + values=values, + total=total, + ) + ) + path = ["total", "regions", "sectors", "vendors"] + # No values + fig = px.sunburst(df, path=path) + assert fig.data[0].branchvalues == "total" + # Values passed + fig = px.sunburst(df, path=path, values="values") + assert fig.data[0].branchvalues == "total" + assert fig.data[0].values[-1] == np.sum(values) + # Values passed + fig = px.sunburst(df, path=path, values="values") + assert fig.data[0].branchvalues == "total" + assert fig.data[0].values[-1] == np.sum(values) + # Continuous colorscale + fig = px.sunburst(df, path=path, values="values", color="values") + assert "coloraxis" in fig.data[0].marker + assert np.all(np.array(fig.data[0].marker.colors) == np.array(fig.data[0].values)) + # Error when values cannot be converted to numerical data type + df["values"] = ["1 000", "3 000", "2", "4", "2", "2", "1 000", "4 000"] + msg = "Column `values` of `df` could not be converted to a numerical data type." + with pytest.raises(ValueError, match=msg): + fig = px.sunburst(df, path=path, values="values") + # path is a mixture of column names and array-like + path = [df.total, "regions", df.sectors, "vendors"] + fig = px.sunburst(df, path=path) + assert fig.data[0].branchvalues == "total" + + +def test_sunburst_treemap_with_path_and_hover(): + df = px.data.tips() + fig = px.sunburst( + df, path=["sex", "day", "time", "smoker"], color="smoker", hover_data=["smoker"] + ) + assert "smoker" in fig.data[0].hovertemplate + + +def test_sunburst_treemap_with_path_color(): + vendors = ["A", "B", "C", "D", "E", "F", "G", "H"] + sectors = [ + "Tech", + "Tech", + "Finance", + "Finance", + "Tech", + "Tech", + "Finance", + "Finance", + ] + regions = ["North", "North", "North", "North", "South", "South", "South", "South"] + values = [1, 3, 2, 4, 2, 2, 1, 4] + calls = [8, 2, 1, 3, 2, 2, 4, 1] + total = ["total",] * 8 + df = pd.DataFrame( + dict( + vendors=vendors, + sectors=sectors, + regions=regions, + values=values, + total=total, + calls=calls, + ) + ) + path = ["total", "regions", "sectors", "vendors"] + fig = px.sunburst(df, path=path, values="values", color="calls") + colors = fig.data[0].marker.colors + assert np.all(np.array(colors[:8]) == np.array(calls)) + fig = px.sunburst(df, path=path, color="calls") + colors = fig.data[0].marker.colors + assert np.all(np.array(colors[:8]) == np.array(calls)) + + # Hover info + df["hover"] = [el.lower() for el in vendors] + fig = px.sunburst(df, path=path, color="calls", hover_data=["hover"]) + custom = fig.data[0].customdata.ravel() + assert np.all(custom[:8] == df["hover"]) + assert np.all(custom[8:] == "(?)") + + # Discrete color + fig = px.sunburst(df, path=path, color="vendors") + assert len(np.unique(fig.data[0].marker.colors)) == 9 + + +def test_sunburst_treemap_with_path_non_rectangular(): + vendors = ["A", "B", "C", "D", None, "E", "F", "G", "H", None] + sectors = [ + "Tech", + "Tech", + "Finance", + "Finance", + None, + "Tech", + "Tech", + "Finance", + "Finance", + "Finance", + ] + regions = [ + "North", + "North", + "North", + "North", + "North", + "South", + "South", + "South", + "South", + "South", + ] + values = [1, 3, 2, 4, 1, 2, 2, 1, 4, 1] + total = ["total",] * 10 + df = pd.DataFrame( + dict( + vendors=vendors, + sectors=sectors, + regions=regions, + values=values, + total=total, + ) + ) + path = ["total", "regions", "sectors", "vendors"] + msg = "Non-leaves rows are not permitted in the dataframe" + with pytest.raises(ValueError, match=msg): + fig = px.sunburst(df, path=path, values="values") + df.loc[df["vendors"].isnull(), "sectors"] = "Other" + fig = px.sunburst(df, path=path, values="values") + assert fig.data[0].values[-1] == np.sum(values) + + def test_pie_funnelarea_colorscale(): labels = ["A", "B", "C", "D"] values = [3, 2, 1, 4]