Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(chore): Simplify BasePlot’s var_groups #3462

Merged
merged 2 commits into from
Feb 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/scanpy/neighbors/_connectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def umap(
warnings.filterwarnings("ignore", message=r"Tensorflow not installed")
from umap.umap_ import fuzzy_simplicial_set

X = coo_matrix(([], ([], [])), shape=(n_obs, 1))
X = coo_matrix((n_obs, 1))
connectivities, _sigmas, _rhos = fuzzy_simplicial_set(
X,
n_neighbors,
Expand Down
115 changes: 61 additions & 54 deletions src/scanpy/plotting/_baseplot_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@
norm: Normalize | None


class VarGroups(NamedTuple):
    """Labelled groups of consecutive variables, drawn as brackets next to a plot.

    ``labels`` and ``positions`` are parallel sequences: ``positions[i]`` is the
    index range that belongs to ``labels[i]``.
    """

    # NOTE(review): being a two-field tuple, an instance is always truthy, even
    # when both sequences are empty; `if var_groups:` only distinguishes an
    # instance from `None`.

    # One label per group.
    labels: Sequence[str]
    # One `(start, end)` pair per group, indexing into the flat list of
    # variable names; `end` is inclusive (the group is sliced as
    # `var_names[start : end + 1]`).
    positions: Sequence[tuple[int, int]]


doc_common_groupby_plot_args = """\
title
Title for the figure
Expand Down Expand Up @@ -87,6 +92,8 @@

MAX_NUM_CATEGORIES = 500 # maximum number of categories allowed to be plotted

var_groups: VarGroups | None

@old_positionals(
"use_raw",
"log",
Expand Down Expand Up @@ -129,18 +136,24 @@
norm: Normalize | None = None,
**kwds,
):
self.var_names = var_names
self.var_group_labels = var_group_labels
self.var_group_positions = var_group_positions
self.var_names, self.var_groups = _var_groups(var_names, ref=adata.var_names)
match (var_group_labels, var_group_positions, self.var_groups):
case (None, None, _):
pass # inferred from `var_names`
case (None, _, _) | (_, None, _):
msg = "both or none of var_group_labels and var_group_positions must be set"
raise TypeError(msg)

Check warning on line 145 in src/scanpy/plotting/_baseplot_class.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/plotting/_baseplot_class.py#L144-L145

Added lines #L144 - L145 were not covered by tests
case (_, _, None):
if len(var_group_labels) != len(var_group_positions):
msg = "var_group_labels and var_group_positions must have the same length"
raise ValueError(msg)

Check warning on line 149 in src/scanpy/plotting/_baseplot_class.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/plotting/_baseplot_class.py#L148-L149

Added lines #L148 - L149 were not covered by tests
self.var_groups = VarGroups(var_group_labels, var_group_positions)
case (_, _, _):
msg = "var_group_labels and var_group_positions cannot be set if var_names is a dict"
raise TypeError(msg)

Check warning on line 153 in src/scanpy/plotting/_baseplot_class.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/plotting/_baseplot_class.py#L151-L153

Added lines #L151 - L153 were not covered by tests
self.var_group_rotation = var_group_rotation
self.width, self.height = figsize if figsize is not None else (None, None)

self.has_var_groups = (
var_group_positions is not None and len(var_group_positions) > 0
)

self._update_var_groups()

self.categories, self.obs_tidy = _prepare_dataframe(
adata,
self.var_names,
Expand Down Expand Up @@ -702,7 +715,7 @@
width_ratios=[mainplot_width + self.group_extra_size, self.legends_width],
)

if self.has_var_groups:
if self.var_groups:
# add some space in case 'brackets' want to be plotted on top of the image
if self.are_axes_swapped:
var_groups_height = category_height
Expand Down Expand Up @@ -754,14 +767,14 @@
if self.plot_group_extra is not None:
group_extra_ax = self.fig.add_subplot(mainplot_gs[2, 1], sharey=main_ax)
group_extra_orientation = "right"
if self.has_var_groups:
if self.var_groups:
gene_groups_ax = self.fig.add_subplot(mainplot_gs[1, 0], sharex=main_ax)
var_group_orientation = "top"
else:
if self.plot_group_extra:
group_extra_ax = self.fig.add_subplot(mainplot_gs[1, 0], sharex=main_ax)
group_extra_orientation = "top"
if self.has_var_groups:
if self.var_groups:
gene_groups_ax = self.fig.add_subplot(mainplot_gs[2, 1], sharey=main_ax)
var_group_orientation = "right"

Expand All @@ -781,11 +794,11 @@
return_ax_dict["group_extra_ax"] = group_extra_ax

# plot group legends on top or left of main_ax (if given)
if self.has_var_groups:
if self.var_groups:
self._plot_var_groups_brackets(
gene_groups_ax,
group_positions=self.var_group_positions,
group_labels=self.var_group_labels,
group_positions=self.var_groups.positions,
group_labels=self.var_groups.labels,
rotation=self.var_group_rotation,
left_adjustment=0.2,
right_adjustment=0.7,
Expand Down Expand Up @@ -924,31 +937,30 @@
if self.var_names is not None:
var_names_idx_ordered = list(range(len(self.var_names)))

if self.has_var_groups:
if set(self.var_group_labels) == set(self.categories):
if self.var_groups:
if set(self.var_groups.labels) == set(self.categories):
positions_ordered = []
labels_ordered = []
position_start = 0
var_names_idx_ordered = []
for cat_name in categories_ordered:
idx = self.var_group_labels.index(cat_name)
position = self.var_group_positions[idx]
idx = self.var_groups.labels.index(cat_name)
position = self.var_groups.positions[idx]
_var_names = self.var_names[position[0] : position[1] + 1]
var_names_idx_ordered.extend(range(position[0], position[1] + 1))
positions_ordered.append(
(position_start, position_start + len(_var_names) - 1)
)
position_start += len(_var_names)
labels_ordered.append(self.var_group_labels[idx])
self.var_group_labels = labels_ordered
self.var_group_positions = positions_ordered
labels_ordered.append(self.var_groups.labels[idx])
self.var_groups = VarGroups(labels_ordered, positions_ordered)
else:
logg.warning(
"Groups are not reordered because the `groupby` categories "
"and the `var_group_labels` are different.\n"
f"categories: {_format_first_three_categories(self.categories)}\n"
"var_group_labels: "
f"{_format_first_three_categories(self.var_group_labels)}"
f"{_format_first_three_categories(self.var_groups.labels)}"
)

if var_names_idx_ordered is not None:
Expand Down Expand Up @@ -1082,35 +1094,30 @@
axis="x", bottom=False, labelbottom=False, labeltop=False
)

def _update_var_groups(self) -> None:
"""
checks if var_names is a dict. Is this is the cases, then set the
correct values for var_group_labels and var_group_positions

updates var_names, var_group_labels, var_group_positions
"""
if isinstance(self.var_names, Mapping):
if self.has_var_groups:
logg.warning(
"`var_names` is a dictionary. This will reset the current "
"values of `var_group_labels` and `var_group_positions`."
)
var_group_labels = []
_var_names = []
var_group_positions = []
start = 0
for label, vars_list in self.var_names.items():
if isinstance(vars_list, str):
vars_list = [vars_list]
# use list() in case var_list is a numpy array or pandas series
_var_names.extend(list(vars_list))
var_group_labels.append(label)
var_group_positions.append((start, start + len(vars_list) - 1))
start += len(vars_list)
self.var_names = _var_names
self.var_group_labels = var_group_labels
self.var_group_positions = var_group_positions
self.has_var_groups = True

elif isinstance(self.var_names, str):
self.var_names = [self.var_names]
def _var_groups(
var_names: _VarNames | Mapping[str, _VarNames], *, ref: pd.Index[str]
) -> tuple[Sequence[str], VarGroups | None]:
"""
Normalize var_names.
If it’s a mapping, also return var_group_labels and var_group_positions.
"""

if not isinstance(var_names, Mapping):
var_names = [var_names] if isinstance(var_names, str) else var_names
return var_names, None

var_group_labels: list[str] = []
var_names_seq: list[str] = []
var_group_positions: list[tuple[int, int]] = []
for label, vars_list in var_names.items():
vars_list = [vars_list] if isinstance(vars_list, str) else vars_list
start = len(var_names_seq)
# use list() in case var_list is a numpy array or pandas series
var_names_seq.extend(list(vars_list))
var_group_labels.append(label)
var_group_positions.append((start, start + len(vars_list) - 1))
if not var_names_seq:
msg = "No valid var_names were passed."
raise ValueError(msg)

Check warning on line 1122 in src/scanpy/plotting/_baseplot_class.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/plotting/_baseplot_class.py#L1121-L1122

Added lines #L1121 - L1122 were not covered by tests
return var_names_seq, VarGroups(var_group_labels, var_group_positions)
Loading