Skip to content

Commit 049d837

Browse files
Merge pull request #1838 from plotly/facet_wrap
initial build-out of facet wrapping
2 parents be1a182 + ee48cca commit 049d837

File tree

5 files changed

+93
-33
lines changed

5 files changed

+93
-33
lines changed

packages/python/plotly/plotly/express/_chart_types.py

+10
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def scatter(
1616
text=None,
1717
facet_row=None,
1818
facet_col=None,
19+
facet_col_wrap=0,
1920
error_x=None,
2021
error_x_minus=None,
2122
error_y=None,
@@ -65,6 +66,7 @@ def density_contour(
6566
color=None,
6667
facet_row=None,
6768
facet_col=None,
69+
facet_col_wrap=0,
6870
hover_name=None,
6971
hover_data=None,
7072
animation_frame=None,
@@ -120,6 +122,7 @@ def density_heatmap(
120122
z=None,
121123
facet_row=None,
122124
facet_col=None,
125+
facet_col_wrap=0,
123126
hover_name=None,
124127
hover_data=None,
125128
animation_frame=None,
@@ -180,6 +183,7 @@ def line(
180183
text=None,
181184
facet_row=None,
182185
facet_col=None,
186+
facet_col_wrap=0,
183187
error_x=None,
184188
error_x_minus=None,
185189
error_y=None,
@@ -225,6 +229,7 @@ def area(
225229
text=None,
226230
facet_row=None,
227231
facet_col=None,
232+
facet_col_wrap=0,
228233
animation_frame=None,
229234
animation_group=None,
230235
category_orders={},
@@ -267,6 +272,7 @@ def bar(
267272
color=None,
268273
facet_row=None,
269274
facet_col=None,
275+
facet_col_wrap=0,
270276
hover_name=None,
271277
hover_data=None,
272278
custom_data=None,
@@ -318,6 +324,7 @@ def histogram(
318324
color=None,
319325
facet_row=None,
320326
facet_col=None,
327+
facet_col_wrap=0,
321328
hover_name=None,
322329
hover_data=None,
323330
animation_frame=None,
@@ -376,6 +383,7 @@ def violin(
376383
color=None,
377384
facet_row=None,
378385
facet_col=None,
386+
facet_col_wrap=0,
379387
hover_name=None,
380388
hover_data=None,
381389
custom_data=None,
@@ -427,6 +435,7 @@ def box(
427435
color=None,
428436
facet_row=None,
429437
facet_col=None,
438+
facet_col_wrap=0,
430439
hover_name=None,
431440
hover_data=None,
432441
custom_data=None,
@@ -473,6 +482,7 @@ def strip(
473482
color=None,
474483
facet_row=None,
475484
facet_col=None,
485+
facet_col_wrap=0,
476486
hover_name=None,
477487
hover_data=None,
478488
custom_data=None,

packages/python/plotly/plotly/express/_core.py

+47-33
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref):
233233
result["y"] = trendline[:, 1]
234234
hover_header = "<b>LOWESS trendline</b><br><br>"
235235
elif v == "ols":
236-
fit_results = sm.OLS(y, sm.add_constant(x)).fit()
236+
fit_results = sm.OLS(y.values, sm.add_constant(x.values)).fit()
237237
result["y"] = fit_results.predict()
238238
hover_header = "<b>OLS trendline</b><br>"
239239
hover_header += "%s = %f * %s + %f<br>" % (
@@ -747,10 +747,10 @@ def apply_default_cascade(args):
747747
]
748748

749749
# If both marginals and faceting are specified, faceting wins
750-
if args.get("facet_col", None) and args.get("marginal_y", None):
750+
if args.get("facet_col", None) is not None and args.get("marginal_y", None):
751751
args["marginal_y"] = None
752752

753-
if args.get("facet_row", None) and args.get("marginal_x", None):
753+
if args.get("facet_row", None) is not None and args.get("marginal_x", None):
754754
args["marginal_x"] = None
755755

756756

@@ -874,7 +874,7 @@ def build_dataframe(args, attrables, array_attrables):
874874
"pandas MultiIndex is not supported by plotly express "
875875
"at the moment." % field
876876
)
877-
## ----------------- argument is a col name ----------------------
877+
# ----------------- argument is a col name ----------------------
878878
if isinstance(argument, str) or isinstance(
879879
argument, int
880880
): # just a column name given as str or int
@@ -1042,6 +1042,13 @@ def infer_config(args, constructor, trace_patch):
10421042
args[position] = args["marginal"]
10431043
args[other_position] = None
10441044

1045+
if (
1046+
args.get("marginal_x", None) is not None
1047+
or args.get("marginal_y", None) is not None
1048+
or args.get("facet_row", None) is not None
1049+
):
1050+
args["facet_col_wrap"] = 0
1051+
10451052
# Compute applicable grouping attributes
10461053
for k in group_attrables:
10471054
if k in args:
@@ -1098,15 +1105,14 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
10981105

10991106
orders, sorted_group_names = get_orderings(args, grouper, grouped)
11001107

1101-
has_marginal_x = bool(args.get("marginal_x", False))
1102-
has_marginal_y = bool(args.get("marginal_y", False))
1103-
11041108
subplot_type = _subplot_type_for_trace_type(constructor().type)
11051109

11061110
trace_names_by_frame = {}
11071111
frames = OrderedDict()
11081112
trendline_rows = []
11091113
nrows = ncols = 1
1114+
col_labels = []
1115+
row_labels = []
11101116
for group_name in sorted_group_names:
11111117
group = grouped.get_group(group_name if len(group_name) > 1 else group_name[0])
11121118
mapping_labels = OrderedDict()
@@ -1188,27 +1194,36 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
11881194
# Find row for trace, handling facet_row and marginal_x
11891195
if m.facet == "row":
11901196
row = m.val_map[val]
1191-
trace._subplot_row_val = val
1197+
if args["facet_row"] and len(row_labels) < row:
1198+
row_labels.append(args["facet_row"] + "=" + str(val))
11921199
else:
1193-
if has_marginal_x and trace_spec.marginal != "x":
1200+
if (
1201+
bool(args.get("marginal_x", False))
1202+
and trace_spec.marginal != "x"
1203+
):
11941204
row = 2
11951205
else:
11961206
row = 1
11971207

1198-
nrows = max(nrows, row)
1199-
if row > 1:
1200-
trace._subplot_row = row
1201-
1208+
facet_col_wrap = args.get("facet_col_wrap", 0)
12021209
# Find col for trace, handling facet_col and marginal_y
12031210
if m.facet == "col":
12041211
col = m.val_map[val]
1205-
trace._subplot_col_val = val
1212+
if args["facet_col"] and len(col_labels) < col:
1213+
col_labels.append(args["facet_col"] + "=" + str(val))
1214+
if facet_col_wrap: # assumes no facet_row, no marginals
1215+
row = 1 + ((col - 1) // facet_col_wrap)
1216+
col = 1 + ((col - 1) % facet_col_wrap)
12061217
else:
12071218
if trace_spec.marginal == "y":
12081219
col = 2
12091220
else:
12101221
col = 1
12111222

1223+
nrows = max(nrows, row)
1224+
if row > 1:
1225+
trace._subplot_row = row
1226+
12121227
ncols = max(ncols, col)
12131228
if col > 1:
12141229
trace._subplot_col = col
@@ -1238,7 +1253,6 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
12381253
if show_colorbar:
12391254
colorvar = "z" if constructor == go.Histogram2d else "color"
12401255
range_color = args["range_color"] or [None, None]
1241-
d = len(args["color_continuous_scale"]) - 1
12421256

12431257
colorscale_validator = ColorscaleValidator("colorscale", "make_figure")
12441258
layout_patch["coloraxis1"] = dict(
@@ -1260,7 +1274,7 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
12601274
layout_patch["legend"]["itemsizing"] = "constant"
12611275

12621276
fig = init_figure(
1263-
args, subplot_type, frame_list, ncols, nrows, has_marginal_x, has_marginal_y
1277+
args, subplot_type, frame_list, nrows, ncols, col_labels, row_labels
12641278
)
12651279

12661280
# Position traces in subplots
@@ -1290,49 +1304,39 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
12901304
return fig
12911305

12921306

1293-
def init_figure(
1294-
args, subplot_type, frame_list, ncols, nrows, has_marginal_x, has_marginal_y
1295-
):
1307+
def init_figure(args, subplot_type, frame_list, nrows, ncols, col_labels, row_labels):
12961308
# Build subplot specs
12971309
specs = [[{}] * ncols for _ in range(nrows)]
1298-
column_titles = [None] * ncols
1299-
row_titles = [None] * nrows
13001310
for frame in frame_list:
13011311
for trace in frame["data"]:
13021312
row0 = trace._subplot_row - 1
13031313
col0 = trace._subplot_col - 1
1304-
13051314
if isinstance(trace, go.Splom):
13061315
# Splom not compatible with make_subplots, treat as domain
13071316
specs[row0][col0] = {"type": "domain"}
13081317
else:
13091318
specs[row0][col0] = {"type": trace.type}
1310-
if args.get("facet_row", None) and hasattr(trace, "_subplot_row_val"):
1311-
row_titles[row0] = args["facet_row"] + "=" + str(trace._subplot_row_val)
1312-
1313-
if args.get("facet_col", None) and hasattr(trace, "_subplot_col_val"):
1314-
column_titles[col0] = (
1315-
args["facet_col"] + "=" + str(trace._subplot_col_val)
1316-
)
13171319

13181320
# Default row/column widths uniform
13191321
column_widths = [1.0] * ncols
13201322
row_heights = [1.0] * nrows
13211323

13221324
# Build column_widths/row_heights
13231325
if subplot_type == "xy":
1324-
if has_marginal_x:
1326+
if bool(args.get("marginal_x", False)):
13251327
if args["marginal_x"] == "histogram" or ("color" in args and args["color"]):
13261328
main_size = 0.74
13271329
else:
13281330
main_size = 0.84
13291331

13301332
row_heights = [main_size] * (nrows - 1) + [1 - main_size]
13311333
vertical_spacing = 0.01
1334+
elif args.get("facet_col_wrap", 0):
1335+
vertical_spacing = 0.07
13321336
else:
13331337
vertical_spacing = 0.03
13341338

1335-
if has_marginal_y:
1339+
if bool(args.get("marginal_y", False)):
13361340
if args["marginal_y"] == "histogram" or ("color" in args and args["color"]):
13371341
main_size = 0.74
13381342
else:
@@ -1351,15 +1355,25 @@ def init_figure(
13511355
vertical_spacing = 0.1
13521356
horizontal_spacing = 0.1
13531357

1358+
facet_col_wrap = args.get("facet_col_wrap", 0)
1359+
if facet_col_wrap:
1360+
subplot_labels = [None] * nrows * ncols
1361+
while len(col_labels) < nrows * ncols:
1362+
col_labels.append(None)
1363+
for i in range(nrows):
1364+
for j in range(ncols):
1365+
subplot_labels[i * ncols + j] = col_labels[(nrows - 1 - i) * ncols + j]
1366+
13541367
# Create figure with subplots
13551368
fig = make_subplots(
13561369
rows=nrows,
13571370
cols=ncols,
13581371
specs=specs,
13591372
shared_xaxes="all",
13601373
shared_yaxes="all",
1361-
row_titles=list(reversed(row_titles)),
1362-
column_titles=column_titles,
1374+
row_titles=[] if facet_col_wrap else list(reversed(row_labels)),
1375+
column_titles=[] if facet_col_wrap else col_labels,
1376+
subplot_titles=subplot_labels if facet_col_wrap else [],
13631377
horizontal_spacing=horizontal_spacing,
13641378
vertical_spacing=vertical_spacing,
13651379
row_heights=row_heights,

packages/python/plotly/plotly/express/_doc.py

+6
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,12 @@
183183
colref_desc,
184184
"Values from this column or array_like are used to assign marks to facetted subplots in the horizontal direction.",
185185
],
186+
facet_col_wrap=[
187+
"int",
188+
"Maximum number of facet columns.",
189+
"Wraps the column variable at this width, so that the column facets span multiple rows.",
190+
"Ignored if 0, and forced to 0 if `facet_row` or a `marginal` is set.",
191+
],
186192
animation_frame=[
187193
colref_type,
188194
colref_desc,

packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py

+3
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ def test_pandas_series():
6161
assert fig.data[0].hovertemplate == "day=%{x}<br>y=%{y}"
6262
fig = px.bar(tips, x="day", y=before_tip, labels={"y": "bill"})
6363
assert fig.data[0].hovertemplate == "day=%{x}<br>bill=%{y}"
64+
# lock down that we can pass df.col to facet_*
65+
fig = px.bar(tips, x="day", y="tip", facet_row=tips.day, facet_col=tips.day)
66+
assert fig.data[0].hovertemplate == "day=%{x}<br>tip=%{y}"
6467

6568

6669
def test_several_dataframes():

test/percy/plotly-express.py

+27
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,33 @@
184184

185185
import plotly.express as px
186186

187+
tips = px.data.tips()
188+
fig = px.scatter(
189+
tips,
190+
x="day",
191+
y="tip",
192+
facet_col="day",
193+
facet_col_wrap=2,
194+
category_orders={"day": ["Thur", "Fri", "Sat", "Sun"]},
195+
)
196+
fig.write_html(os.path.join(dir_name, "facet_wrap_neat.html"))
197+
198+
import plotly.express as px
199+
200+
tips = px.data.tips()
201+
fig = px.scatter(
202+
tips,
203+
x="day",
204+
y="tip",
205+
color="sex",
206+
facet_col="day",
207+
facet_col_wrap=3,
208+
category_orders={"day": ["Thur", "Fri", "Sat", "Sun"]},
209+
)
210+
fig.write_html(os.path.join(dir_name, "facet_wrap_ragged.html"))
211+
212+
import plotly.express as px
213+
187214
gapminder = px.data.gapminder()
188215
fig = px.area(gapminder, x="year", y="pop", color="continent", line_group="country")
189216
fig.write_html(os.path.join(dir_name, "area.html"))

0 commit comments

Comments
 (0)