From 440774a248a2c1beb18e85132e4ebc8985703c93 Mon Sep 17 00:00:00 2001 From: Nicholas Esterer Date: Fri, 30 Oct 2020 17:41:38 -0400 Subject: [PATCH 1/8] Prototype demonstrating possible px.combine --- proto/px_combine_proto/px_combine.py | 154 +++++++++++++++++++++++++++ proto/px_combine_proto/test_data.py | 17 +++ 2 files changed, 171 insertions(+) create mode 100644 proto/px_combine_proto/px_combine.py create mode 100644 proto/px_combine_proto/test_data.py diff --git a/proto/px_combine_proto/px_combine.py b/proto/px_combine_proto/px_combine.py new file mode 100644 index 00000000000..c0166334a98 --- /dev/null +++ b/proto/px_combine_proto/px_combine.py @@ -0,0 +1,154 @@ +# Prototype for px.combine +# Combine 2 figures containing subplots +# Run as +# python px_combine.py + +import plotly.express as px +import plotly.graph_objects as go +from plotly.subplots import make_subplots +import test_data +import json +from itertools import product + + +def multi_index(*kwargs): + return product(*[range(k) for k in kwargs]) + + +def extract_axes(layout): + ret = dict() + for k in dir(layout): + if k[1 : 1 + len("axis")] == "axis": + ret[k] = layout[k] + return ret + + +def fig_grid_ref_shape(fig): + grid_ref = fig._validate_get_grid_ref() + return (len(grid_ref), len(grid_ref[0])) + + +def fig_subplot_axes(fig, r, c): + grid_ref = fig._validate_get_grid_ref() + return [fig.layout[k] for k in grid_ref[r - 1][c - 1][0].layout_keys] + + +def extract_axis_titles(fig): + """ + Given figure created using make_subplots, with r rows and c columns, return + r titles from the x axes and y titles from the y axes. + """ + grid_ref_shape = fig_grid_ref_shape(fig) + r_titles = [ + fig_subplot_axes(fig, r + 1, 1)[1]["title"] for r in range(grid_ref_shape[0]) + ] + c_titles = [ + fig_subplot_axes(fig, 1, c + 1)[0]["title"] for c in range(grid_ref_shape[1]) + ] + return (r_titles, c_titles) + + +def px_simple_combine(fig0, fig1): + """ + Combines two figures by just using the layout of the first figure and + appending the data of the second figure. + """ + grid_ref_shape = fig_grid_ref_shape(fig0) + if grid_ref_shape != fig_grid_ref_shape(fig1): + raise ValueError( + "Only two figures with the same subplot geometry can be combined." + ) + if fig0.layout.annotations != fig1.layout.annotations: + raise ValueError( + "Only two figures created with Plotly Express with " + "identical faceting can be combined." + ) + fig = go.Figure(data=fig0.data + fig1.data, layout=fig0.layout) + return fig + + +def px_combine_secondary_y(fig0, fig1): + """ + Combines two figures that have the same faceting but whose y axes refer + to different data by referencing the second figure's y-data to secondary + y-axes. + """ + grid_ref_shape = fig_grid_ref_shape(fig0) + if grid_ref_shape != fig_grid_ref_shape(fig1): + raise ValueError( + "Only two figures with the same subplot geometry can be combined." + ) + if fig0.layout.annotations != fig1.layout.annotations: + raise ValueError( + "Only two figures created with Plotly Express with " + "identical faceting can be combined." + ) + specs = [ + [dict(secondary_y=True) for __ in range(grid_ref_shape[1])] + for _ in range(grid_ref_shape[0]) + ] + fig = make_subplots(*grid_ref_shape, specs=specs, start_cell="bottom-left") + fig0_ax_titles = extract_axis_titles(fig0) + fig1_ax_titles = extract_axis_titles(fig1) + # set primary y + for r, c in multi_index(*grid_ref_shape): + for tr in fig0.select_traces(row=r + 1, col=c + 1): + fig.add_trace(tr, row=r + 1, col=c + 1) + if r == 0: + t = fig0_ax_titles[1][c] + fig.update_xaxes(title=t, row=r + 1, col=c + 1) + if c == 0: + t = fig0_ax_titles[0][r] + fig.update_yaxes( + selector=lambda ax: ax["side"] != "right", title=t, row=r + 1, col=c + 1 + ) + # set secondary y + for r, c in multi_index(*grid_ref_shape): + for tr in fig1.select_traces(row=r + 1, col=c + 1): + # TODO: How to set meaningful color regardless of trace type? + tr["marker_color"] = "red" + fig.add_trace(tr, row=r + 1, col=c + 1, secondary_y=True) + t = fig1_ax_titles[0][r] + # TODO: How to best set the secondary y's title standoff? + t["standoff"] = 0 + fig.update_yaxes( + selector=lambda ax: ax["side"] == "right", title=t, row=r + 1, col=c + 1 + ) + fig.update_layout(annotations=fig0.layout.annotations) + return fig + + +df = test_data.aug_tips() + + +def simple_combine_example(): + fig0 = px.scatter(df, x="total_bill", y="tip", facet_row="sex", facet_col="smoker") + fig1 = px.histogram( + df, x="total_bill", y="tip", facet_row="sex", facet_col="smoker" + ) + fig1.update_traces(marker_color="red") + fig = px_simple_combine(fig0, fig1) + fig.update_layout(title="Simple figure combination") + return fig + + +def secondary_y_combine_example(): + fig0 = px.scatter(df, x="total_bill", y="tip", facet_row="sex", facet_col="smoker") + fig1 = px.scatter( + df, + x="total_bill", + y="calories_consumed", + facet_row="sex", + facet_col="smoker", + trendline="ols", + ) + fig1.update_traces(marker_size=3) + fig = px_combine_secondary_y(fig0, fig1) + fig.update_layout(title="Figure combination with secondary y-axis") + return fig + + +fig_simple = simple_combine_example() +fig_secondary_y = secondary_y_combine_example() +fig_simple.show() +fig_secondary_y.show() diff --git a/proto/px_combine_proto/test_data.py b/proto/px_combine_proto/test_data.py new file mode 100644 index 00000000000..a2320c4c6b7 --- /dev/null +++ b/proto/px_combine_proto/test_data.py @@ -0,0 +1,17 @@ +import numpy as np +import plotly.express as px +import pandas as pd + +# some made up data for demos + + +def aug_tips(): + """ The tips data buf with "calories consumed". """ + tips = px.data.tips() + calories = np.clip( + tips["total_bill"] * 30 + np.random.standard_normal(tips.shape[0]) * 100, + 100, + None, + ) + tips["calories_consumed"] = calories + return tips From fe866519a302f17cefb0c6bc1bbb72690269079e Mon Sep 17 00:00:00 2001 From: Nicholas Esterer Date: Tue, 3 Nov 2020 17:23:39 -0500 Subject: [PATCH 2/8] px_combine prototype now smarter reflows colors, changes adds figure titles to legend names to differentiate the source figures in the final plot, includes all annotation-like objects in the final plot. --- proto/px_combine/find_field.py | 43 ++++++++ proto/px_combine/multilayered_data_test.py | 36 +++++++ .../px_combine.py | 97 ++++++++++++++++--- proto/px_combine/run_px_simple_combine_demo | 2 + proto/px_combine/test_data.py | 66 +++++++++++++ proto/px_combine_proto/test_data.py | 17 ---- 6 files changed, 233 insertions(+), 28 deletions(-) create mode 100644 proto/px_combine/find_field.py create mode 100644 proto/px_combine/multilayered_data_test.py rename proto/{px_combine_proto => px_combine}/px_combine.py (61%) create mode 100755 proto/px_combine/run_px_simple_combine_demo create mode 100644 proto/px_combine/test_data.py delete mode 100644 proto/px_combine_proto/test_data.py diff --git a/proto/px_combine/find_field.py b/proto/px_combine/find_field.py new file mode 100644 index 00000000000..c173ae05107 --- /dev/null +++ b/proto/px_combine/find_field.py @@ -0,0 +1,43 @@ +import plotly.graph_objects as go +from plotly import basedatatypes + +# Search down an object's composition tree and find fields with a given name + + +def find_field(obj, field, basepath="", max_path_len=80, forbidden=["parent"]): + if obj is not None and len(basepath) < max_path_len: + for f in dir(obj): + joined_path = ".".join([basepath, f]) + if f == field: + print(joined_path) + if ( + (f not in forbidden) + and (not f.startswith("_")) + and (not f.endswith("_")) + ): + find_field(eval("obj.%s" % (f,)), field, joined_path) + + +def find_all_xy_traces(): + for field in dir(go): + call_str = "go.%s" % (field,) + call = eval(call_str) + try: + if issubclass(call, basedatatypes.BaseTraceType): + obj = call() + if "xaxis" in obj and "yaxis" in obj: + yield (call_str) + except TypeError: + pass + + +# s=go.Scatter() +# s=go.Bar() +# find_field(s,"color",basepath="scatter") +# print() +# find_field(s,"color",basepath="bar") + +for call_str in find_all_xy_traces(): + call = eval(call_str) + find_field(call(), "color", basepath=call_str) + print() diff --git a/proto/px_combine/multilayered_data_test.py b/proto/px_combine/multilayered_data_test.py new file mode 100644 index 00000000000..59d063aae96 --- /dev/null +++ b/proto/px_combine/multilayered_data_test.py @@ -0,0 +1,36 @@ +import test_data +import numpy as np +import plotly.express as px +from px_combine import px_combine_secondary_y, px_simple_combine + +df = test_data.multilayered_data(d_divs=[2, 3, 4, 2], rwalk=0.1) +print(df) +last_cat = df.columns[3] +figs = [] +for px_call, last_cat_0 in zip([px.line, px.bar], list(set(df[last_cat]))): + df_slice = df.loc[df[last_cat] == last_cat_0] + fig = px_call( + df_slice, + x="x", + y="y", + facet_row=df.columns[0], + facet_col=df.columns[1], + color=df.columns[2], + ) + fig.update_layout(title="%s=%s" % (last_cat, last_cat_0,)) + figs.append(fig) + +figs[0].add_hline(y=1, row=1, col="all") +figs[1].add_vline(x=10, row="all", col=2) +figs[0].add_annotation( + x=0.25, y=0.5, xref="x domain", yref="y domain", row=2, col=3, text="yo" +) +figs[1].add_annotation( + x=0.5, y=0.35, xref="x domain", yref="y domain", row=1, col=2, text="budday" +) +figs[0].layout.barmode = "group" +figs[1].layout.barmode = "relative" +final_fig = px_simple_combine(*figs) +for fig in figs: + fig.show() +final_fig.show() diff --git a/proto/px_combine_proto/px_combine.py b/proto/px_combine/px_combine.py similarity index 61% rename from proto/px_combine_proto/px_combine.py rename to proto/px_combine/px_combine.py index c0166334a98..b015e2742e1 100644 --- a/proto/px_combine_proto/px_combine.py +++ b/proto/px_combine/px_combine.py @@ -8,7 +8,8 @@ from plotly.subplots import make_subplots import test_data import json -from itertools import product +from itertools import product, cycle, chain +from functools import reduce def multi_index(*kwargs): @@ -58,21 +59,93 @@ def px_simple_combine(fig0, fig1): raise ValueError( "Only two figures with the same subplot geometry can be combined." ) - if fig0.layout.annotations != fig1.layout.annotations: - raise ValueError( - "Only two figures created with Plotly Express with " - "identical faceting can be combined." - ) - fig = go.Figure(data=fig0.data + fig1.data, layout=fig0.layout) + # reflow the colors + colorway = fig0.layout.template.layout.colorway + fig = make_subplots(*fig_grid_ref_shape(fig0)) + for r, c in multi_index(*fig_grid_ref_shape(fig)): + for (tr, title), color in zip( + chain( + *[ + zip( + f.select_traces(row=r + 1, col=c + 1), + cycle([f.layout.title.text]), + ) + for f in [fig0, fig1] + ] + ), + cycle(colorway), + ): + set_main_trace_color(tr, color) + # use figure title to differentiate the legend items + tr["name"] = "%s %s" % (title, tr["name"]) + # TODO: argument to group legend items? + tr["legendgroup"] = None + fig.add_trace(tr, row=r + 1, col=c + 1) + fig.update_layout(fig0.layout) + # title will be wrong + fig.layout.title = None + # preserve bar mode + # if both figures have barmode set, the first is taken, otherwise the set one is taken + # TODO argument to force barmode? or the user can just update it after + fig.layout.barmode = get_first_set_barmode([fig0, fig1]) + # also include annotations, shapes and layout images from fig1 + for kw in ["annotations", "shapes", "images"]: + fig.layout[kw] += fig1.layout[kw] return fig +def select_all_traces(figs): + traces = list( + reduce( + lambda a, b: a + b, + map(lambda t: list(go.Figure.select_traces(t)), figs), + [], + ) + ) + return traces + + +def check_trace_type_xy(tr): + return ("xaxis" in tr) and ("yaxis" in tr) + + +def check_figs_trace_types_xy(figs): + traces = select_all_traces(figs) + xy_traces = list(map(check_trace_type_xy, traces)) + return xy_traces + + +def set_main_trace_color(tr, color): + # Set the main color of a trace + if type(tr) == type(go.Scatter()): + if tr["mode"] == "lines": + tr["line_color"] = color + else: + tr["marker_color"] = color + elif type(tr) == type(go.Bar()): + tr["marker_color"] = color + + +def get_first_set_barmode(figs): + barmode = None + try: + barmode = list( + filter(lambda x: x is not None, [f.layout.barmode for f in figs]) + )[0] + except IndexError: + # if no figure sets barmode, then it is not set + pass + return barmode + + def px_combine_secondary_y(fig0, fig1): """ Combines two figures that have the same faceting but whose y axes refer to different data by referencing the second figure's y-data to secondary y-axes. """ + if not all(check_figs_trace_types_xy([fig0, fig1])): + raise ValueError('Only subplots containing "xy" trace types may be combined') grid_ref_shape = fig_grid_ref_shape(fig0) if grid_ref_shape != fig_grid_ref_shape(fig1): raise ValueError( @@ -148,7 +221,9 @@ def secondary_y_combine_example(): return fig -fig_simple = simple_combine_example() -fig_secondary_y = secondary_y_combine_example() -fig_simple.show() -fig_secondary_y.show() +if __name__ == "__main__": + fig_simple = simple_combine_example() + fig_secondary_y = secondary_y_combine_example() + fig_simple.show() + fig_secondary_y.show() + fig_secondary_y.write_json("/tmp/fig.json") diff --git a/proto/px_combine/run_px_simple_combine_demo b/proto/px_combine/run_px_simple_combine_demo new file mode 100755 index 00000000000..ea7891adde6 --- /dev/null +++ b/proto/px_combine/run_px_simple_combine_demo @@ -0,0 +1,2 @@ +#/bin/bash +PYTHONPATH=proto/px_combine python3 proto/px_combine/multilayered_data_test.py diff --git a/proto/px_combine/test_data.py b/proto/px_combine/test_data.py new file mode 100644 index 00000000000..d45b4428e1b --- /dev/null +++ b/proto/px_combine/test_data.py @@ -0,0 +1,66 @@ +import numpy as np +import plotly.express as px +import pandas as pd +from random import sample +from itertools import product +from functools import reduce + +# some made up data for demos + + +def words(remove_non_letters=True): + with open("/usr/share/dict/british-english", "r") as fd: + ws = fd.readlines() + return [w.strip().replace("'s", "") for w in ws] + + +def aug_tips(): + """ The tips data buf with "calories consumed". """ + tips = px.data.tips() + calories = np.clip( + tips["total_bill"] * 30 + np.random.standard_normal(tips.shape[0]) * 100, + 100, + None, + ) + tips["calories_consumed"] = calories + return tips + + +def take(it, N): + return [next(it) for n in range(N)] + + +def multilayered_data( + N=20, d_divs=[2, 3, 4], rseed=np.random.RandomState(seed=2), rwalk=0.1 +): + """ + Generate data that can be faceted in len(d_divs) ways (e.g., row, col and + trace color/linestyle. etc.) + """ + ws = words() + tot_divs = np.cumprod(d_divs)[-1] + sample_i = np.arange(len(ws), dtype="int") + rseed.shuffle(sample_i) + names = iter(ws[i] for i in sample_i[: tot_divs + len(d_divs)]) + x = np.arange(N) + cat_div_names = [] + for div in d_divs: + # generate category names + div_names = [next(names) for _ in range(div)] + cat_div_names.append(div_names) + cat_names = [next(names) for _ in d_divs] + dfs = [] + for cat_combo in product(*cat_div_names): + d = dict() + for cat_name, c in zip(cat_names, cat_combo): + d[cat_name] = c + d["x"] = x + if rwalk is not None: + y = np.cumsum(rseed.standard_normal(N)) * rwalk + else: + y = rseed.standard_normal(N) + d["y"] = y + dfs.append(pd.DataFrame(d)) + # combine all the dicts + df = reduce(lambda a, b: pd.concat([a, b]), dfs, pd.DataFrame()) + return df diff --git a/proto/px_combine_proto/test_data.py b/proto/px_combine_proto/test_data.py deleted file mode 100644 index a2320c4c6b7..00000000000 --- a/proto/px_combine_proto/test_data.py +++ /dev/null @@ -1,17 +0,0 @@ -import numpy as np -import plotly.express as px -import pandas as pd - -# some made up data for demos - - -def aug_tips(): - """ The tips data buf with "calories consumed". """ - tips = px.data.tips() - calories = np.clip( - tips["total_bill"] * 30 + np.random.standard_normal(tips.shape[0]) * 100, - 100, - None, - ) - tips["calories_consumed"] = calories - return tips From 29b2c8bccec41fcf000b1877b2ef6a31b9cff0de Mon Sep 17 00:00:00 2001 From: Nicholas Esterer Date: Tue, 3 Nov 2020 18:09:25 -0500 Subject: [PATCH 3/8] px_combine with secondary-y Argument and logic is there, but not yet tested. It doesn't break the case where px_combine was called with fig1_secondary_y=False. --- .../python/plotly/plotly/basedatatypes.py | 3 ++ .../python/plotly/plotly/express/_core.py | 2 ++ proto/px_combine/px_combine.py | 29 ++++++++++++++----- 3 files changed, 26 insertions(+), 8 deletions(-) diff --git a/packages/python/plotly/plotly/basedatatypes.py b/packages/python/plotly/plotly/basedatatypes.py index f2033cf359a..623e976dc2d 100644 --- a/packages/python/plotly/plotly/basedatatypes.py +++ b/packages/python/plotly/plotly/basedatatypes.py @@ -391,6 +391,9 @@ class is a subclass of both BaseFigure and widgets.DOMWidget. self._animation_duration_validator = animation.DurationValidator() self._animation_easing_validator = animation.EasingValidator() + # Space for auxiliary data + self._aux = dict() + # Template # -------- # ### Check for default template ### diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 7ad2fb4eb01..727ecd687a8 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -2057,6 +2057,8 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None): configure_axes(args, constructor, fig, orders) configure_animation_controls(args, constructor, fig) + # store args in figure metadata + fig._aux["px"] = dict(args=args) return fig diff --git a/proto/px_combine/px_combine.py b/proto/px_combine/px_combine.py index b015e2742e1..4fba14f3a06 100644 --- a/proto/px_combine/px_combine.py +++ b/proto/px_combine/px_combine.py @@ -49,11 +49,18 @@ def extract_axis_titles(fig): return (r_titles, c_titles) -def px_simple_combine(fig0, fig1): +def px_simple_combine(fig0, fig1, fig1_secondary_y=False): """ Combines two figures by just using the layout of the first figure and appending the data of the second figure. """ + if fig1_secondary_y and ( + ("px" not in fig0._aux.keys()) or ("px" not in fig0._aux.keys()) + ): + raise ValueError( + "To place fig1's traces on secondary y-axes, both figures must have " + "been made with Plotly Express." + ) grid_ref_shape = fig_grid_ref_shape(fig0) if grid_ref_shape != fig_grid_ref_shape(fig1): raise ValueError( @@ -61,26 +68,32 @@ def px_simple_combine(fig0, fig1): ) # reflow the colors colorway = fig0.layout.template.layout.colorway - fig = make_subplots(*fig_grid_ref_shape(fig0)) + specs = None + if fig1_secondary_y: + specs = [ + [dict(secondary_y=True) for __ in range(grid_ref_shape[1])] + for _ in range(grid_ref_shape[0]) + ] + fig = make_subplots(*fig_grid_ref_shape(fig0), specs=specs) for r, c in multi_index(*fig_grid_ref_shape(fig)): - for (tr, title), color in zip( + for (tr, f), color in zip( chain( *[ - zip( - f.select_traces(row=r + 1, col=c + 1), - cycle([f.layout.title.text]), - ) + zip(f.select_traces(row=r + 1, col=c + 1), cycle([f]),) for f in [fig0, fig1] ] ), cycle(colorway), ): + title = f.layout.title.text set_main_trace_color(tr, color) # use figure title to differentiate the legend items tr["name"] = "%s %s" % (title, tr["name"]) # TODO: argument to group legend items? tr["legendgroup"] = None - fig.add_trace(tr, row=r + 1, col=c + 1) + fig.add_trace( + tr, row=r + 1, col=c + 1, secondary_y=(fig1_secondary_y and (f == fig1)) + ) fig.update_layout(fig0.layout) # title will be wrong fig.layout.title = None From 9f4dff421146ce311310e34fc741a2a11fe6cb76 Mon Sep 17 00:00:00 2001 From: Nicholas Esterer Date: Wed, 4 Nov 2020 12:49:02 -0500 Subject: [PATCH 4/8] Secondary y working with px_combine, but Cannot yet copy annotation-like things to the final plot as we need to find out which row and column in the subplots they are referring to. --- proto/px_overlay/facet_col_wrap_test.py | 27 +++++++ .../{px_combine => px_overlay}/find_field.py | 0 .../multilayered_data_test.py | 0 .../{px_combine => px_overlay}/px_combine.py | 77 +------------------ .../run_px_simple_combine_demo | 0 proto/px_overlay/secondary_y_test.py | 1 + proto/{px_combine => px_overlay}/test_data.py | 0 7 files changed, 32 insertions(+), 73 deletions(-) create mode 100644 proto/px_overlay/facet_col_wrap_test.py rename proto/{px_combine => px_overlay}/find_field.py (100%) rename proto/{px_combine => px_overlay}/multilayered_data_test.py (100%) rename proto/{px_combine => px_overlay}/px_combine.py (64%) rename proto/{px_combine => px_overlay}/run_px_simple_combine_demo (100%) create mode 100644 proto/px_overlay/secondary_y_test.py rename proto/{px_combine => px_overlay}/test_data.py (100%) diff --git a/proto/px_overlay/facet_col_wrap_test.py b/proto/px_overlay/facet_col_wrap_test.py new file mode 100644 index 00000000000..ae41ae67c61 --- /dev/null +++ b/proto/px_overlay/facet_col_wrap_test.py @@ -0,0 +1,27 @@ +import plotly.express as px +import test_data +from px_combine import px_simple_combine + +df = test_data.multilayered_data(d_divs=[6, 3, 2], rwalk=0.1) +last_cat = df.columns[2] +last_cat_types = list(set(df[last_cat])) +fig0 = px.line( + df.loc[df[last_cat] == last_cat_types[0]], + x="x", + y="y", + facet_col=df.columns[0], + facet_col_wrap=3, + color=df.columns[1], +).update_layout(title="%s=%s" % (last_cat, last_cat_types[0])) +fig1 = px.line( + df.loc[df[last_cat] == last_cat_types[1]], + x="x", + y="y", + facet_col=df.columns[0], + facet_col_wrap=3, + color=df.columns[1], +).update_layout(title="%s=%s" % (last_cat, last_cat_types[1])) +fig = px_simple_combine(fig0, fig1, fig1_secondary_y=True) +fig0.show() +fig1.show() +fig.show() diff --git a/proto/px_combine/find_field.py b/proto/px_overlay/find_field.py similarity index 100% rename from proto/px_combine/find_field.py rename to proto/px_overlay/find_field.py diff --git a/proto/px_combine/multilayered_data_test.py b/proto/px_overlay/multilayered_data_test.py similarity index 100% rename from proto/px_combine/multilayered_data_test.py rename to proto/px_overlay/multilayered_data_test.py diff --git a/proto/px_combine/px_combine.py b/proto/px_overlay/px_combine.py similarity index 64% rename from proto/px_combine/px_combine.py rename to proto/px_overlay/px_combine.py index 4fba14f3a06..41955eac2e1 100644 --- a/proto/px_combine/px_combine.py +++ b/proto/px_overlay/px_combine.py @@ -76,6 +76,7 @@ def px_simple_combine(fig0, fig1, fig1_secondary_y=False): ] fig = make_subplots(*fig_grid_ref_shape(fig0), specs=specs) for r, c in multi_index(*fig_grid_ref_shape(fig)): + print("row,col", r + 1, c + 1) for (tr, f), color in zip( chain( *[ @@ -94,7 +95,9 @@ def px_simple_combine(fig0, fig1, fig1_secondary_y=False): fig.add_trace( tr, row=r + 1, col=c + 1, secondary_y=(fig1_secondary_y and (f == fig1)) ) - fig.update_layout(fig0.layout) + # TODO: How to preserve axis sizes when adding secondary y? + # TODO: How to put annotations on the correct subplot when using secondary y? + # fig.update_layout(fig0.layout) # title will be wrong fig.layout.title = None # preserve bar mode @@ -151,59 +154,6 @@ def get_first_set_barmode(figs): return barmode -def px_combine_secondary_y(fig0, fig1): - """ - Combines two figures that have the same faceting but whose y axes refer - to different data by referencing the second figure's y-data to secondary - y-axes. - """ - if not all(check_figs_trace_types_xy([fig0, fig1])): - raise ValueError('Only subplots containing "xy" trace types may be combined') - grid_ref_shape = fig_grid_ref_shape(fig0) - if grid_ref_shape != fig_grid_ref_shape(fig1): - raise ValueError( - "Only two figures with the same subplot geometry can be combined." - ) - if fig0.layout.annotations != fig1.layout.annotations: - raise ValueError( - "Only two figures created with Plotly Express with " - "identical faceting can be combined." - ) - specs = [ - [dict(secondary_y=True) for __ in range(grid_ref_shape[1])] - for _ in range(grid_ref_shape[0]) - ] - fig = make_subplots(*grid_ref_shape, specs=specs, start_cell="bottom-left") - fig0_ax_titles = extract_axis_titles(fig0) - fig1_ax_titles = extract_axis_titles(fig1) - # set primary y - for r, c in multi_index(*grid_ref_shape): - for tr in fig0.select_traces(row=r + 1, col=c + 1): - fig.add_trace(tr, row=r + 1, col=c + 1) - if r == 0: - t = fig0_ax_titles[1][c] - fig.update_xaxes(title=t, row=r + 1, col=c + 1) - if c == 0: - t = fig0_ax_titles[0][r] - fig.update_yaxes( - selector=lambda ax: ax["side"] != "right", title=t, row=r + 1, col=c + 1 - ) - # set secondary y - for r, c in multi_index(*grid_ref_shape): - for tr in fig1.select_traces(row=r + 1, col=c + 1): - # TODO: How to set meaningful color regardless of trace type? - tr["marker_color"] = "red" - fig.add_trace(tr, row=r + 1, col=c + 1, secondary_y=True) - t = fig1_ax_titles[0][r] - # TODO: How to best set the secondary y's title standoff? - t["standoff"] = 0 - fig.update_yaxes( - selector=lambda ax: ax["side"] == "right", title=t, row=r + 1, col=c + 1 - ) - fig.update_layout(annotations=fig0.layout.annotations) - return fig - - df = test_data.aug_tips() @@ -218,25 +168,6 @@ def simple_combine_example(): return fig -def secondary_y_combine_example(): - fig0 = px.scatter(df, x="total_bill", y="tip", facet_row="sex", facet_col="smoker") - fig1 = px.scatter( - df, - x="total_bill", - y="calories_consumed", - facet_row="sex", - facet_col="smoker", - trendline="ols", - ) - fig1.update_traces(marker_size=3) - fig = px_combine_secondary_y(fig0, fig1) - fig.update_layout(title="Figure combination with secondary y-axis") - return fig - - if __name__ == "__main__": fig_simple = simple_combine_example() - fig_secondary_y = secondary_y_combine_example() fig_simple.show() - fig_secondary_y.show() - fig_secondary_y.write_json("/tmp/fig.json") diff --git a/proto/px_combine/run_px_simple_combine_demo b/proto/px_overlay/run_px_simple_combine_demo similarity index 100% rename from proto/px_combine/run_px_simple_combine_demo rename to proto/px_overlay/run_px_simple_combine_demo diff --git a/proto/px_overlay/secondary_y_test.py b/proto/px_overlay/secondary_y_test.py new file mode 100644 index 00000000000..ae7b49edc36 --- /dev/null +++ b/proto/px_overlay/secondary_y_test.py @@ -0,0 +1 @@ +# Put the second plot's y data on a secondary y diff --git a/proto/px_combine/test_data.py b/proto/px_overlay/test_data.py similarity index 100% rename from proto/px_combine/test_data.py rename to proto/px_overlay/test_data.py From ec2332f44aa35653851996a3984ca933bfbec11a Mon Sep 17 00:00:00 2001 From: Nicholas Esterer Date: Wed, 4 Nov 2020 12:52:58 -0500 Subject: [PATCH 5/8] Renamed combine to overlay --- proto/px_overlay/{px_combine.py => px_overlay.py} | 0 .../{run_px_simple_combine_demo => run_px_simple_overlay_demo} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename proto/px_overlay/{px_combine.py => px_overlay.py} (100%) rename proto/px_overlay/{run_px_simple_combine_demo => run_px_simple_overlay_demo} (100%) diff --git a/proto/px_overlay/px_combine.py b/proto/px_overlay/px_overlay.py similarity index 100% rename from proto/px_overlay/px_combine.py rename to proto/px_overlay/px_overlay.py diff --git a/proto/px_overlay/run_px_simple_combine_demo b/proto/px_overlay/run_px_simple_overlay_demo similarity index 100% rename from proto/px_overlay/run_px_simple_combine_demo rename to proto/px_overlay/run_px_simple_overlay_demo From bf7cf9b5ee8a34b4e8b0afc8b687dc75369ae25a Mon Sep 17 00:00:00 2001 From: Nicholas Esterer Date: Fri, 6 Nov 2020 17:54:30 -0500 Subject: [PATCH 6/8] annotation-like objects axis reference mapping to a subplot on a new figure. This will allow us to copy the annotations to a new figure with more axes as part of a px.overlay command. This just needs to be generalized to work with shapes and images but the workings can be seen in proto/px_overlay/map_axis_pair_example.py. --- proto/px_overlay/map_axis_pair_example.py | 29 ++++++ proto/px_overlay/px_overlay.py | 113 +++++++++++++++++++++ proto/px_overlay/test_find_subplot_axes.py | 40 ++++++++ 3 files changed, 182 insertions(+) create mode 100644 proto/px_overlay/map_axis_pair_example.py create mode 100644 proto/px_overlay/test_find_subplot_axes.py diff --git a/proto/px_overlay/map_axis_pair_example.py b/proto/px_overlay/map_axis_pair_example.py new file mode 100644 index 00000000000..a22c26136dd --- /dev/null +++ b/proto/px_overlay/map_axis_pair_example.py @@ -0,0 +1,29 @@ +import plotly.graph_objects as go +from plotly.subplots import make_subplots +import px_overlay +import pytest + +fig0 = px_overlay.make_subplots_all_secondary_y(3, 4) +fig1 = px_overlay.make_subplots_all_secondary_y(4, 5) + +for dims, f in zip([(3, 4), (4, 5)], [fig0, fig1]): + for r, c in px_overlay.multi_index(*dims): + for sy in [False, True]: + f.add_trace(go.Scatter(x=[], y=[]), row=r + 1, col=c + 1, secondary_y=sy) + +fig0.add_annotation(row=2, col=3, text="hi", x=0.25, xref="x domain", y=3) +fig0.add_annotation( + row=3, col=4, text="hi", x=0.25, xref="x domain", y=2, secondary_y=True +) + +for an in fig0.layout.annotations: + oldaxpair = tuple([an[ref] for ref in ["xref", "yref"]]) + newaxpair = px_overlay.map_axis_pair(fig0, fig1, oldaxpair) + newan = go.layout.Annotation(an) + print(oldaxpair) + print(newaxpair) + newan["xref"], newan["yref"] = newaxpair + fig1.add_annotation(newan) + +fig0.show() +fig1.show() diff --git a/proto/px_overlay/px_overlay.py b/proto/px_overlay/px_overlay.py index 41955eac2e1..08b0593401d 100644 --- a/proto/px_overlay/px_overlay.py +++ b/proto/px_overlay/px_overlay.py @@ -10,6 +10,7 @@ import json from itertools import product, cycle, chain from functools import reduce +import re def multi_index(*kwargs): @@ -49,6 +50,118 @@ def extract_axis_titles(fig): return (r_titles, c_titles) +def make_subplots_all_secondary_y(rows, cols): + """ + Get subplots like make_subplots but all also have secondary y-axes. + """ + grid_ref_shape = [rows, cols] + specs = [ + [dict(secondary_y=True) for __ in range(grid_ref_shape[1])] + for _ in range(grid_ref_shape[0]) + ] + fig = make_subplots(*grid_ref_shape, specs=specs) + return fig + + +def parse_axis_ref(ax): + """ Find the axis letter, optional number, and domain of axis. """ + # TODO: can this be obtained via codegen? + pat = re.compile("([xy])(axis)?([0-9]*)( domain)?") + matches = pat.match(ax) + if matches is None: + raise ValueError('Axis "%s" cannot be parsed.' % (ax,)) + return (matches[1], matches[3], matches[4]) + + +def norm_axis_ref(ax): + """ normalize ax so it is in the format: yaxis, yaxis2, xaxis7 etc. """ + al, an, _ = parse_axis_ref(ax) + return al + "axis" + an + + +def axis_pair_to_row_col(fig, axpair): + """ + returns the row and column of the subplot having the axis pair and whether it is a + secondary y + """ + if "paper" in axpair: + raise ValueError('Cannot find row and column of "paper" axis reference.') + naxpair = tuple([norm_axis_ref(ax) for ax in axpair]) + nrows, ncols = fig_grid_ref_shape(fig) + row = None + col = None + for r, c in multi_index(nrows, ncols): + for sp in fig._grid_ref[r][c]: + if naxpair == sp.layout_keys: + row = r + 1 + col = c + 1 + if row is None or col is None: + raise ValueError("Could not find subplot containing axes (%s,%s)." % nax) + secondary_y = False + yax = naxpair[1] + if fig.layout[yax]["side"] == "right": + secondary_y = True + return (row, col, secondary_y) + + +def find_subplot_axes(fig, row, col, secondary_y=False): + """ + Returns 2-tuple containing (xaxis,yaxis) at specified row, col and secondary y-axis. + """ + nrows, ncols = fig_grid_ref_shape(fig) + try: + sps = fig._grid_ref[row - 1][col - 1] + except IndexError: + raise IndexError( + "Figure does not have a subplot at the requested row or column." + ) + + def _check_is_secondary_y(sp): + xax, yax = sp.layout_keys + # TODO: It may not be totally accurate to assume if an y-axis' "side" is + # "right" than it is a secondary y axis... + return fig.layout[yax]["side"] == "right" + + # find the secondary y axis + err_msg = ( + "Could not find a y-axis " "at the subplot in the requested row or column." + ) + filter_fun = lambda sp: not _check_is_secondary_y(sp) + if secondary_y: + err_msg = ( + "Could not find a secondary y-axis " + "at the subplot in the requested row or column." + ) + filter_fun = _check_is_secondary_y + try: + sp = list(filter(filter_fun, sps))[0] + except (IndexError, TypeError): + # Catch IndexError if the list is empty, catch TypeError if sps isn't + # iterable (e.g., is None) + raise IndexError(err_msg) + return sp.layout_keys + + +def map_axis_pair(old_fig, new_fig, axpair, make_axis_ref=True): + """ + Find the axes on the new figure that will give the same subplot and + possibly secondary y axis as on the old figure. This can only + work if the axis pair is ("paper","paper") or the axis pair corresponds to a + subplot on the old figure the new figure has corresponding rows, + columns and secondary y-axes. + if make_axis_ref is True, axis is removed from the resulting strings, e.g., xaxis2 -> x2 + """ + if axpair == ("paper", "paper"): + return ax + row, col, secondary_y = axis_pair_to_row_col(old_fig, axpair) + newaxpair = find_subplot_axes(new_fig, row, col, secondary_y) + axpair_extras = [" domain" if ax.endswith("domain") else "" for ax in axpair] + newaxpair = tuple(ax + extra for ax, extra in zip(newaxpair, axpair_extras)) + if make_axis_ref: + newaxpair = tuple(ax.replace("axis", "") for ax in newaxpair) + return newaxpair + + def px_simple_combine(fig0, fig1, fig1_secondary_y=False): """ Combines two figures by just using the layout of the first figure and diff --git a/proto/px_overlay/test_find_subplot_axes.py b/proto/px_overlay/test_find_subplot_axes.py new file mode 100644 index 00000000000..e9b763226bb --- /dev/null +++ b/proto/px_overlay/test_find_subplot_axes.py @@ -0,0 +1,40 @@ +from plotly.subplots import make_subplots +import px_overlay +import pytest + +fig = px_overlay.make_subplots_all_secondary_y(3, 4) +fig_no_sy = px_overlay.make_subplots(3, 4) +fig_custom = make_subplots( + rows=2, + cols=2, + specs=[[{}, {}], [{"colspan": 2}, None]], + subplot_titles=("First Subplot", "Second Subplot", "Third Subplot"), +) + + +def test_bad_row_col(): + with pytest.raises( + IndexError, + match=r"^Figure does not have a subplot at the requested row or column\.$", + ): + px_overlay.find_subplot_axes(fig, 4, 2, secondary_y=False) + with pytest.raises( + IndexError, + match=r"^Figure does not have a subplot at the requested row or column\.$", + ): + px_overlay.find_subplot_axes(fig, 4, 2, secondary_y=True) + + +def test_no_secondary_y(): + with pytest.raises( + IndexError, + match=r"^Could not find a secondary y-axis at the subplot in the requested row or column\.$", + ): + px_overlay.find_subplot_axes(fig_no_sy, 2, 2, secondary_y=True) + with pytest.raises( + IndexError, + match=r"^Could not find a y-axis at the subplot in the requested row or column\.$", + ): + px_overlay.find_subplot_axes(fig_custom, 2, 2, secondary_y=False) + axes = px_overlay.find_subplot_axes(fig_custom, 1, 2, secondary_y=False) + assert axes == ("xaxis2", "yaxis2") From 99195940c0e90806cd46ff759296b507cfcec4b7 Mon Sep 17 00:00:00 2001 From: Nicholas Esterer Date: Mon, 9 Nov 2020 12:26:52 -0500 Subject: [PATCH 7/8] Overlaying 2nd fig on secondary-y has problems It works but it seems the "start_cell" argument of make_subplots has not been taken into consideration. Also sometimes it adds annotations twice? --- proto/px_overlay/multilayered_data_test.py | 6 +- proto/px_overlay/px_overlay.py | 78 +++++++++++++++++++--- 2 files changed, 70 insertions(+), 14 deletions(-) diff --git a/proto/px_overlay/multilayered_data_test.py b/proto/px_overlay/multilayered_data_test.py index 59d063aae96..40aec929736 100644 --- a/proto/px_overlay/multilayered_data_test.py +++ b/proto/px_overlay/multilayered_data_test.py @@ -1,7 +1,7 @@ import test_data import numpy as np import plotly.express as px -from px_combine import px_combine_secondary_y, px_simple_combine +from px_overlay import px_simple_overlay df = test_data.multilayered_data(d_divs=[2, 3, 4, 2], rwalk=0.1) print(df) @@ -26,11 +26,11 @@ x=0.25, y=0.5, xref="x domain", yref="y domain", row=2, col=3, text="yo" ) figs[1].add_annotation( - x=0.5, y=0.35, xref="x domain", yref="y domain", row=1, col=2, text="budday" + x=0.5, y=0.35, xref="x domain", yref="y", row=1, col=2, text="budday" ) figs[0].layout.barmode = "group" figs[1].layout.barmode = "relative" -final_fig = px_simple_combine(*figs) +final_fig = px_simple_overlay(*figs, fig1_secondary_y=True) for fig in figs: fig.show() final_fig.show() diff --git a/proto/px_overlay/px_overlay.py b/proto/px_overlay/px_overlay.py index 08b0593401d..afaa5026fa0 100644 --- a/proto/px_overlay/px_overlay.py +++ b/proto/px_overlay/px_overlay.py @@ -1,7 +1,7 @@ -# Prototype for px.combine +# Prototype for px.overlay # Combine 2 figures containing subplots # Run as -# python px_combine.py +# python px_overlay.py import plotly.express as px import plotly.graph_objects as go @@ -111,7 +111,9 @@ def find_subplot_axes(fig, row, col, secondary_y=False): nrows, ncols = fig_grid_ref_shape(fig) try: sps = fig._grid_ref[row - 1][col - 1] - except IndexError: + except (IndexError, TypeError): + # IndexError if fig has _grid_ref but not requested row or column, + # TypeError if fig has no _grid_ref (it is None) raise IndexError( "Figure does not have a subplot at the requested row or column." ) @@ -142,7 +144,15 @@ def _check_is_secondary_y(sp): return sp.layout_keys -def map_axis_pair(old_fig, new_fig, axpair, make_axis_ref=True): +def map_axis_pair( + old_fig, + new_fig, + axpair, + new_row=None, + new_col=None, + new_secondary_y=None, + make_axis_ref=True, +): """ Find the axes on the new figure that will give the same subplot and possibly secondary y axis as on the old figure. This can only @@ -151,9 +161,14 @@ def map_axis_pair(old_fig, new_fig, axpair, make_axis_ref=True): columns and secondary y-axes. if make_axis_ref is True, axis is removed from the resulting strings, e.g., xaxis2 -> x2 """ + if None in axpair: + raise ValueError("Cannot map axis whose value is None.") if axpair == ("paper", "paper"): - return ax + return axpair row, col, secondary_y = axis_pair_to_row_col(old_fig, axpair) + row = new_row if new_row is not None else row + col = new_col if new_col is not None else col + secondary_y = new_secondary_y if new_secondary_y is not None else secondary_y newaxpair = find_subplot_axes(new_fig, row, col, secondary_y) axpair_extras = [" domain" if ax.endswith("domain") else "" for ax in axpair] newaxpair = tuple(ax + extra for ax, extra in zip(newaxpair, axpair_extras)) @@ -162,7 +177,27 @@ def map_axis_pair(old_fig, new_fig, axpair, make_axis_ref=True): return newaxpair -def px_simple_combine(fig0, fig1, fig1_secondary_y=False): +def map_annotation_like_obj_axis(oldfig, newfig, an, force_secondary_y=False): + """ + Take an annotation-like object with xref and yref referring to axes in oldfig + and map them to axes in newfig. This makes it possible to map an annotation + to the same subplot row, column or secondary y in a new plot even if they do + not have matching subplots. + If force_secondary_y is True, attempt is made to map the annotation to a + secondary y axis in the new figure. + Returns the new annotation. Note that it has not been added to newfig, the + caller must then do this if it wants it added to newfig. + """ + oldaxpair = tuple([an[ref] for ref in ["xref", "yref"]]) + newaxpair = map_axis_pair( + oldfig, newfig, oldaxpair, new_secondary_y=force_secondary_y + ) + newan = an.__class__(an) + newan["xref"], newan["yref"] = newaxpair + return newan + + +def px_simple_overlay(fig0, fig1, fig1_secondary_y=False): """ Combines two figures by just using the layout of the first figure and appending the data of the second figure. @@ -177,7 +212,7 @@ def px_simple_combine(fig0, fig1, fig1_secondary_y=False): grid_ref_shape = fig_grid_ref_shape(fig0) if grid_ref_shape != fig_grid_ref_shape(fig1): raise ValueError( - "Only two figures with the same subplot geometry can be combined." + "Only two figures with the same subplot geometry can be overlayd." ) # reflow the colors colorway = fig0.layout.template.layout.colorway @@ -209,7 +244,28 @@ def px_simple_combine(fig0, fig1, fig1_secondary_y=False): tr, row=r + 1, col=c + 1, secondary_y=(fig1_secondary_y and (f == fig1)) ) # TODO: How to preserve axis sizes when adding secondary y? - # TODO: How to put annotations on the correct subplot when using secondary y? + + # Map the axes of the annotation-like objects to the new figure. Map the + # fig1 objects to the secondary-y if requested. + selectors = product( + [fig0, fig1], + [ + go.Figure.select_annotations, + go.Figure.select_shapes, + go.Figure.select_layout_images, + ], + ) + adders = product( + [(fig, False), (fig, fig1_secondary_y)], + [go.Figure.add_annotation, go.Figure.add_shape, go.Figure.add_layout_image], + ) + for (oldfig, selector), ((newfig, secy), adder) in zip(selectors, adders): + for ann in selector(oldfig): + newann = map_annotation_like_obj_axis( + oldfig, newfig, ann, force_secondary_y=secy + ) + adder(newfig, newann) + # fig.update_layout(fig0.layout) # title will be wrong fig.layout.title = None @@ -270,17 +326,17 @@ def get_first_set_barmode(figs): df = test_data.aug_tips() -def simple_combine_example(): +def simple_overlay_example(): fig0 = px.scatter(df, x="total_bill", y="tip", facet_row="sex", facet_col="smoker") fig1 = px.histogram( df, x="total_bill", y="tip", facet_row="sex", facet_col="smoker" ) fig1.update_traces(marker_color="red") - fig = px_simple_combine(fig0, fig1) + fig = px_simple_overlay(fig0, fig1) fig.update_layout(title="Simple figure combination") return fig if __name__ == "__main__": - fig_simple = simple_combine_example() + fig_simple = simple_overlay_example() fig_simple.show() From 64d8a196c18dff47f40a7ba48b74bac06a2e648c Mon Sep 17 00:00:00 2001 From: Nicholas Esterer Date: Fri, 13 Nov 2020 18:27:21 -0500 Subject: [PATCH 8/8] Documented functionality and added README look in `proto/px_overlay/README.md` for more information. --- proto/px_overlay/README.md | 15 +++++++ proto/px_overlay/multilayered_data_test.py | 24 ++++++++++- proto/px_overlay/px_overlay.py | 48 +++++++++++++++------- 3 files changed, 70 insertions(+), 17 deletions(-) create mode 100644 proto/px_overlay/README.md diff --git a/proto/px_overlay/README.md b/proto/px_overlay/README.md new file mode 100644 index 00000000000..a64dc2ea733 --- /dev/null +++ b/proto/px_overlay/README.md @@ -0,0 +1,15 @@ +# `px.overlay` prototype + +This demonstrates one possible way of combining two figures into a single +figure. + +To see an example, run (from the root of the `plotly.py` repo): + +```bash +PYTHONPATH=proto/px_overlay python proto/px_overlay/multilayered_data_test.py +``` + +To see the code that does the overlaying, start with the `px_simple_overlay` +function in `proto/px_overlay/px_overlay.py`. In this function there are a few +comments marked with `TODO` that indicate places for improvement in the +functionality. diff --git a/proto/px_overlay/multilayered_data_test.py b/proto/px_overlay/multilayered_data_test.py index 40aec929736..6f59c2dcfaa 100644 --- a/proto/px_overlay/multilayered_data_test.py +++ b/proto/px_overlay/multilayered_data_test.py @@ -3,11 +3,19 @@ import plotly.express as px from px_overlay import px_simple_overlay +# Demonstrates px_overlay prototype. + +# Make some data that can be faceted by row, col and color, and split into 2 +# sets, which will go to the first and second figure respectively. df = test_data.multilayered_data(d_divs=[2, 3, 4, 2], rwalk=0.1) -print(df) + +# The titles of the figures use the last dimension in the data. The title is +# formatted "column_name=column_value", so here we extract the column name. last_cat = df.columns[3] figs = [] for px_call, last_cat_0 in zip([px.line, px.bar], list(set(df[last_cat]))): + # px_call is the chart type to make and last_cat_0 is the column_value for + # that figure which is used in forming the title. df_slice = df.loc[df[last_cat] == last_cat_0] fig = px_call( df_slice, @@ -17,20 +25,32 @@ facet_col=df.columns[1], color=df.columns[2], ) + fig.update_layout(title="%s=%s" % (last_cat, last_cat_0,)) figs.append(fig) +# Add some annotations to make sure they are copied to the final figure properly figs[0].add_hline(y=1, row=1, col="all") -figs[1].add_vline(x=10, row="all", col=2) figs[0].add_annotation( x=0.25, y=0.5, xref="x domain", yref="y domain", row=2, col=3, text="yo" ) +# Note that these annotations should be mapped to a secondary y axis (observe this +# in the final figure by dragging their corresponding secondary y axes). +figs[1].add_vline(x=10, row="all", col=2) figs[1].add_annotation( x=0.5, y=0.35, xref="x domain", yref="y", row=1, col=2, text="budday" ) +# Set the bar modes for both to see that the first figure that the barmode for +# the final figure will be taken from the figure that has bars. figs[0].layout.barmode = "group" figs[1].layout.barmode = "relative" + +# overlay the figures final_fig = px_simple_overlay(*figs, fig1_secondary_y=True) + +# Show the initial figures for fig in figs: fig.show() + +# Show the final figure final_fig.show() diff --git a/proto/px_overlay/px_overlay.py b/proto/px_overlay/px_overlay.py index afaa5026fa0..3ebe32330ed 100644 --- a/proto/px_overlay/px_overlay.py +++ b/proto/px_overlay/px_overlay.py @@ -199,8 +199,16 @@ def map_annotation_like_obj_axis(oldfig, newfig, an, force_secondary_y=False): def px_simple_overlay(fig0, fig1, fig1_secondary_y=False): """ - Combines two figures by just using the layout of the first figure and - appending the data of the second figure. + Combines two figures by putting all the traces from fig0 and fig1 on a new + figure (fig). Then the annotation-like objects are copied to fig (i.e., the + titles are not copied). + The colors are reassigned so each trace has a unique color until all the + colors in the colorway are exhausted and then loops through the colorway to + assign additional colors (this is referred to as "reflowing" below). + In order to differentiate the traces in the legend, if fig0 or fig1 have + titles, they are prepended to the trace name. + If fig1_secondary_y is True, then the yaxes from fig1 are placed on + secondary y axes in the new figure. """ if fig1_secondary_y and ( ("px" not in fig0._aux.keys()) or ("px" not in fig0._aux.keys()) @@ -212,9 +220,9 @@ def px_simple_overlay(fig0, fig1, fig1_secondary_y=False): grid_ref_shape = fig_grid_ref_shape(fig0) if grid_ref_shape != fig_grid_ref_shape(fig1): raise ValueError( - "Only two figures with the same subplot geometry can be overlayd." + "Only two figures with the same subplot geometry can be overlayed." ) - # reflow the colors + # get colors for reflowing colorway = fig0.layout.template.layout.colorway specs = None if fig1_secondary_y: @@ -222,7 +230,11 @@ def px_simple_overlay(fig0, fig1, fig1_secondary_y=False): [dict(secondary_y=True) for __ in range(grid_ref_shape[1])] for _ in range(grid_ref_shape[0]) ] - fig = make_subplots(*fig_grid_ref_shape(fig0), specs=specs) + # TODO: This needs to detect the start_cell of the input figures rather than + # assuming 'bottom-left', which is just the px default start_cell + fig = make_subplots( + *fig_grid_ref_shape(fig0), specs=specs, start_cell="bottom-left" + ) for r, c in multi_index(*fig_grid_ref_shape(fig)): print("row,col", r + 1, c + 1) for (tr, f), color in zip( @@ -232,6 +244,7 @@ def px_simple_overlay(fig0, fig1, fig1_secondary_y=False): for f in [fig0, fig1] ] ), + # reflow the colors cycle(colorway), ): title = f.layout.title.text @@ -261,6 +274,8 @@ def px_simple_overlay(fig0, fig1, fig1_secondary_y=False): ) for (oldfig, selector), ((newfig, secy), adder) in zip(selectors, adders): for ann in selector(oldfig): + # TODO this function needs to eventually take into consideration the + # start_cell arguments of the figures involved in the mapping. newann = map_annotation_like_obj_axis( oldfig, newfig, ann, force_secondary_y=secy ) @@ -270,12 +285,11 @@ def px_simple_overlay(fig0, fig1, fig1_secondary_y=False): # title will be wrong fig.layout.title = None # preserve bar mode - # if both figures have barmode set, the first is taken, otherwise the set one is taken + # if both figures have barmode set, the first is taken from the figure that + # has bars (so just the one from fig0 if both have bars), otherwise the set + # one is taken. # TODO argument to force barmode? or the user can just update it after fig.layout.barmode = get_first_set_barmode([fig0, fig1]) - # also include annotations, shapes and layout images from fig1 - for kw in ["annotations", "shapes", "images"]: - fig.layout[kw] += fig1.layout[kw] return fig @@ -312,21 +326,25 @@ def set_main_trace_color(tr, color): def get_first_set_barmode(figs): + """ Get first bar mode from the figure that has it set and has bar traces. """ + + def _bar_mode_filter(f): + return ( + any([type(tr) == type(go.Bar()) for tr in f.data]) + and f.layout.barmode is not None + ) + barmode = None try: - barmode = list( - filter(lambda x: x is not None, [f.layout.barmode for f in figs]) - )[0] + barmode = [f.layout.barmode for f in filter(_bar_mode_filter, figs)][0] except IndexError: # if no figure sets barmode, then it is not set pass return barmode -df = test_data.aug_tips() - - def simple_overlay_example(): + df = test_data.aug_tips() fig0 = px.scatter(df, x="total_bill", y="tip", facet_row="sex", facet_col="smoker") fig1 = px.histogram( df, x="total_bill", y="tip", facet_row="sex", facet_col="smoker"