Skip to content

Commit 5a53238

Browse files
committed
Add key to enrichments
Signed-off-by: zethson <[email protected]>
1 parent 00052a4 commit 5a53238

File tree

1 file changed

+30
-15
lines changed

1 file changed

+30
-15
lines changed

pertpy/tools/_enrichment.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections import ChainMap
22
from collections.abc import Sequence
3-
from typing import Any, Literal, Union
3+
from typing import Any, Literal
44

55
import blitzgsea
66
import numpy as np
@@ -64,6 +64,7 @@ def score(
6464
method: Literal["mean", "seurat"] = "mean",
6565
n_bins: int = 25,
6666
ctrl_size: int = 50,
67+
key_added: str = "pertpy_enrichment",
6768
) -> None:
6869
"""Obtain per-cell scoring of gene groups of interest.
6970
@@ -88,6 +89,9 @@ def score(
8889
layer: Specifies which `.layers` of AnnData to use for expression values. Defaults to `.X` if None.
8990
n_bins: The number of expression bins for the `'seurat'` method.
9091
ctrl_size: The number of genes to randomly sample from each expression bin for the `"seurat"` method.
92+
key_added: Prefix of the keys under which the results are stored in `uns`.
93+
Note that the actual keys are `{key_added}_score`, `{key_added}_variables`, `{key_added}_genes`, and `{key_added}_all_genes`.
94+
Defaults to `pertpy_enrichment`.
9195
9296
Returns:
9397
None. The scores are stored in `adata.uns` of the passed AnnData object.
@@ -144,15 +148,15 @@ def score(
144148
seurat = np.dot(control_profiles, drug_weights)
145149
scores = scores - seurat
146150

147-
adata.uns["pertpy_enrichment_score"] = scores
148-
adata.uns["pertpy_enrichment_variables"] = weights.columns
151+
adata.uns[f"{key_added}_score"] = scores
152+
adata.uns[f"{key_added}_variables"] = weights.columns
149153

150-
adata.uns["pertpy_enrichment_genes"] = {"var": pd.DataFrame(columns=["genes"]).astype(object)}
151-
adata.uns["pertpy_enrichment_all_genes"] = {"var": pd.DataFrame(columns=["all_genes"]).astype(object)}
154+
adata.uns[f"{key_added}_genes"] = {"var": pd.DataFrame(columns=["genes"]).astype(object)}
155+
adata.uns[f"{key_added}_all_genes"] = {"var": pd.DataFrame(columns=["all_genes"]).astype(object)}
152156

153157
for drug in weights.columns:
154-
adata.uns["pertpy_enrichment_genes"]["var"].loc[drug, "genes"] = "|".join(adata.var_names[targets[drug]])
155-
adata.uns["pertpy_enrichment_all_genes"]["var"].loc[drug, "all_genes"] = "|".join(full_targets[drug])
158+
adata.uns[f"{key_added}_genes"]["var"].loc[drug, "genes"] = "|".join(adata.var_names[targets[drug]])
159+
adata.uns[f"{key_added}_all_genes"]["var"].loc[drug, "all_genes"] = "|".join(full_targets[drug])
156160

157161
def hypergeometric(
158162
self,
@@ -172,9 +176,11 @@ def hypergeometric(
172176
If `None`, will use `d2c.score()` output if present, and if not present load the ChEMBL-derived drug target sets distributed with the package.
173177
Accepts two forms:
174178
- A dictionary with the names of the groups as keys, and the entries being the corresponding gene lists.
175-
- A dictionary of dictionaries defined like above, with names of gene group categories as keys. If passing one of those, specify `nested=True`.
179+
- A dictionary of dictionaries defined like above, with names of gene group categories as keys.
180+
If passing one of those, specify `nested=True`.
176181
nested: Whether `targets` is a dictionary of dictionaries with group categories as keys.
177-
categories: If `targets=None` or `nested=True`, this argument can be used to subset the gene groups to one or more categories (keys of the original dictionary). In case of the ChEMBL drug targets, these are ATC level 1/level 2 category codes.
182+
categories: If `targets=None` or `nested=True`, this argument can be used to subset the gene groups to one or more categories (keys of the original dictionary).
183+
In case of the ChEMBL drug targets, these are ATC level 1/level 2 category codes.
178184
pvals_adj_thresh: The `pvals_adj` cutoff to use on the `sc.tl.rank_genes_groups()` output to identify markers.
179185
direction: Whether to seek out up/down-regulated genes for the groups, based on the values from `scores`.
180186
Can be `up`, `down`, or `both` (for no selection).
@@ -235,6 +241,7 @@ def gsea(
235241
nested: bool = False,
236242
categories: str | list[str] | None = None,
237243
absolute: bool = False,
244+
key_added: str = "pertpy_enrichment_gsea",
238245
) -> dict[str, pd.DataFrame] | tuple[dict[str, pd.DataFrame], dict[str, dict]]: # pragma: no cover
239246
"""Perform gene set enrichment analysis on the marker gene scores using blitzgsea.
240247
@@ -251,6 +258,8 @@ def gsea(
251258
applicable if `targets=None` or `nested=True`. Defaults to None.
252259
absolute: If True, passes the absolute values of scores to GSEA, improving
253260
statistical power. Defaults to False.
261+
key_added: Key under which the GSEA plotting arguments (scores and targets) are stored in `uns`.
262+
Defaults to `pertpy_enrichment_gsea`.
254263
255264
Returns:
256265
A dictionary with clusters as keys and data frames of test results sorted on
@@ -272,7 +281,7 @@ def gsea(
272281
enrichment[cluster] = blitzgsea.gsea(df, targets)
273282
plot_gsea_args["scores"][cluster] = df
274283

275-
adata.uns["pertpy_enrichment_gsea"] = plot_gsea_args
284+
adata.uns[key_added] = plot_gsea_args
276285

277286
return enrichment
278287

@@ -282,6 +291,7 @@ def plot_dotplot(
282291
targets: dict[str, list[str]] | dict[str, dict[str, list[str]]] = None,
283292
categories: Sequence[str] = None,
284293
groupby: str = None,
294+
key: str = "pertpy_enrichment",
285295
**kwargs,
286296
) -> DotPlot | dict | None:
287297
"""Plots a dotplot by groupby and categories.
@@ -298,6 +308,8 @@ def plot_dotplot(
298308
categories: To subset the gene groups to specific categories, especially when `targets=None` or `nested=True`.
299309
For ChEMBL drug targets, these are ATC level 1/level 2 category codes.
300310
groupby: dotplot groupby such as clusters or cell types.
311+
key: Prefix of the keys in `uns` under which the enrichment results are stored (`{key}_score`, `{key}_variables`).
312+
Defaults to `pertpy_enrichment`.
301313
kwargs: Passed to scanpy dotplot.
302314
303315
Returns:
@@ -330,8 +342,8 @@ def plot_dotplot(
330342
var_group_labels: list[str] = []
331343
start = 0
332344

333-
enrichment_score_adata = AnnData(adata.uns["pertpy_enrichment_score"], obs=adata.obs)
334-
enrichment_score_adata.var_names = adata.uns["pertpy_enrichment_variables"]
345+
enrichment_score_adata = AnnData(adata.uns[f"{key}_score"], obs=adata.obs)
346+
enrichment_score_adata.var_names = adata.uns[f"{key}_variables"]
335347

336348
for group in targets:
337349
targets[group] = list( # type: ignore
@@ -352,7 +364,9 @@ def plot_dotplot(
352364

353365
return sc.pl.dotplot(enrichment_score_adata, groupby=groupby, swap_axes=True, **plot_args, **kwargs)
354366

355-
def plot_gsea(self, adata: AnnData, enrichment: dict[str, pd.DataFrame], n: int = 10) -> None:
367+
def plot_gsea(
368+
self, adata: AnnData, enrichment: dict[str, pd.DataFrame], n: int = 10, key: str = "pertpy_enrichment_gsea"
369+
) -> None:
356370
"""Generates a blitzgsea top_table plot.
357371
358372
This function is designed to visualize the results from a Gene Set Enrichment Analysis (GSEA).
@@ -363,11 +377,12 @@ def plot_gsea(self, adata: AnnData, enrichment: dict[str, pd.DataFrame], n: int
363377
adata: AnnData object to plot.
364378
enrichment: Cluster names as keys, blitzgsea's ``gsea()`` output as values.
365379
n: How many top scores to show for each group. Defaults to 10.
380+
key: GSEA results key in `uns`. Defaults to "pertpy_enrichment_gsea".
366381
"""
367382
for cluster in enrichment:
368383
fig = blitzgsea.plot.top_table(
369-
adata.uns["pertpy_enrichment_gsea"]["scores"][cluster],
370-
adata.uns["pertpy_enrichment_gsea"]["targets"],
384+
adata.uns[key]["scores"][cluster],
385+
adata.uns[key]["targets"],
371386
enrichment[cluster],
372387
n=n,
373388
)

0 commit comments

Comments
 (0)