diff --git a/.gitignore b/.gitignore index c9e7058fe9..2a08b48c17 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,5 @@ htmlcov/ .DS_Store notes/ notebooks/ +mytest.py +*.png diff --git a/seaborn/categorical.py b/seaborn/categorical.py index ee8aa0908b..393f83d9e6 100644 --- a/seaborn/categorical.py +++ b/seaborn/categorical.py @@ -535,6 +535,11 @@ def plot_swarms( if "marker" in plot_kws and not MarkerStyle(plot_kws["marker"]).is_filled(): plot_kws.pop("edgecolor", None) + + keep_gutters = True + if "keep_gutters" in plot_kws: + keep_gutters = plot_kws.pop("keep_gutters") + plot_kws.pop("keep_gutters", None) for sub_vars, sub_data in self.iter_data(iter_vars, from_comp_data=True, @@ -557,7 +562,7 @@ def plot_swarms( if not sub_data.empty: point_collections[(ax, sub_data[self.orient].iloc[0])] = points - beeswarm = Beeswarm(width=width, orient=self.orient, warn_thresh=warn_thresh) + beeswarm = Beeswarm(width=width, orient=self.orient, warn_thresh=warn_thresh, keep_gutters=keep_gutters) for (ax, center), points in point_collections.items(): if points.get_offsets().shape[0] > 1: @@ -3223,11 +3228,14 @@ def catplot( class Beeswarm: """Modifies a scatterplot artist to show a beeswarm plot.""" - def __init__(self, orient="x", width=0.8, warn_thresh=.05): + def __init__(self, orient="x", width=0.8, warn_thresh=.05, keep_gutters=True): self.orient = orient self.width = width self.warn_thresh = warn_thresh + self.gutters = False #BetterBeeswarm modification + self.shrink_factor = 0.9 #BetterBeeswarm modification + self.keep_gutters = keep_gutters #BetterBeeswarm modification def __call__(self, points, center): """Swarm `points`, a PathCollection, around the `center` position.""" @@ -3259,35 +3267,58 @@ def __call__(self, points, center): sizes = np.repeat(sizes, orig_xy.shape[0]) edge = points.get_linewidth().item() radii = (np.sqrt(sizes) + edge) / 2 * (dpi / 72) - orig_xy = np.c_[orig_xy, radii] - # Sort along the value axis to facilitate the beeswarm - sorter = np.argsort(orig_xy[:, 1]) - orig_xyr = orig_xy[sorter] + #BetterBeeswarm modified added while loop to check for gutters + checking_gutters = True + radius_shrink_factor = self.shrink_factor + original_radius_fraction = 1.0 + orig_xy_copy = orig_xy.copy() - # Adjust points along the categorical axis to prevent overlaps - new_xyr = np.empty_like(orig_xyr) - new_xyr[sorter] = self.beeswarm(orig_xyr) + while checking_gutters: - # Transform the point coordinates back to data coordinates - if self.orient == "y": - new_xy = new_xyr[:, [1, 0]] - else: - new_xy = new_xyr[:, :2] - new_x_data, new_y_data = ax.transData.inverted().transform(new_xy).T + orig_xy = np.c_[orig_xy_copy, radii] #BetterBeeswarm modified by changing orig_xy to orig_xy_copy - # Add gutters - t_fwd, t_inv = _get_transform_functions(ax, self.orient) - if self.orient == "y": - self.add_gutters(new_y_data, center, t_fwd, t_inv) - else: - self.add_gutters(new_x_data, center, t_fwd, t_inv) + # Sort along the value axis to facilitate the beeswarm + sorter = np.argsort(orig_xy[:, 1]) + orig_xyr = orig_xy[sorter] - # Reposition the points so they do not overlap - if self.orient == "y": - points.set_offsets(np.c_[orig_x_data, new_y_data]) - else: - points.set_offsets(np.c_[new_x_data, orig_y_data]) + # Adjust points along the categorical axis to prevent overlaps + new_xyr = np.empty_like(orig_xyr) + new_xyr[sorter] = self.beeswarm(orig_xyr) + + # Transform the point coordinates back to data coordinates + if self.orient == "y": + new_xy = new_xyr[:, [1, 0]] + else: + new_xy = new_xyr[:, :2] + new_x_data, new_y_data = ax.transData.inverted().transform(new_xy).T + + # Add gutters + t_fwd, t_inv = _get_transform_functions(ax, self.orient) + if self.orient == "y": + self.add_gutters(new_y_data, center, t_fwd, t_inv) + else: + self.add_gutters(new_x_data, center, t_fwd, t_inv) + + # Reposition the points so they do not overlap + if self.orient == "y": + points.set_offsets(np.c_[orig_x_data, new_y_data]) + else: + points.set_offsets(np.c_[new_x_data, orig_y_data]) + + # if keep_gutters is True then we will keep the gutters and stop the while loop + if self.keep_gutters: + break + + else: + # Check if gutters were added + if self.gutters: + # Shrink the radii and try again + radii = radii * radius_shrink_factor + original_radius_fraction = original_radius_fraction * radius_shrink_factor + print(f"Shrinking radii to the {original_radius_fraction:.1%} of the original point size.") + else: + checking_gutters = False def beeswarm(self, orig_xyr): """Adjust x position of points to avoid overlaps.""" @@ -3398,12 +3429,30 @@ def add_gutters(self, points, center, trans_fwd, trans_inv): points[off_high] = high_gutter gutter_prop = (off_high + off_low).sum() / len(points) + + if gutter_prop > self.warn_thresh: - msg = ( - "{:.1%} of the points cannot be placed; you may want " - "to decrease the size of the markers or use stripplot." - ).format(gutter_prop) - warnings.warn(msg, UserWarning) + + self.gutters = True + + if self.keep_gutters: + msg = ( + "{:.1%} of the points cannot be placed; you may want " + "to decrease the size of the markers, use stripplot, or " + "set 'keep_gutters=False' to increase point density." + ).format(gutter_prop) + warnings.warn(msg, UserWarning) + + else: + print(f"{gutter_prop:.1%} of the points cannot be placed. Iteratively shrinking radii distance by {(1 - self.shrink_factor):.1%}") + # msg = ( + # "{:.1%} of the points cannot be placed. " + # f"Iteratively shrinking radii distance by {(1 - self.shrink_factor) * 100:.1f}%" + # ).format(gutter_prop) + # warnings.warn(msg, UserWarning) + + else: + self.gutters = False return points