Skip to content

Commit

Permalink
Add key to enrichments
Browse files Browse the repository at this point in the history
Signed-off-by: zethson <[email protected]>
  • Loading branch information
Zethson committed Jan 6, 2024
1 parent 00052a4 commit 5a53238
Showing 1 changed file with 30 additions and 15 deletions.
45 changes: 30 additions & 15 deletions pertpy/tools/_enrichment.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections import ChainMap
from collections.abc import Sequence
from typing import Any, Literal, Union
from typing import Any, Literal

import blitzgsea
import numpy as np
Expand Down Expand Up @@ -64,6 +64,7 @@ def score(
method: Literal["mean", "seurat"] = "mean",
n_bins: int = 25,
ctrl_size: int = 50,
key_added: str = "pertpy_enrichment",
) -> None:
"""Obtain per-cell scoring of gene groups of interest.
Expand All @@ -88,6 +89,9 @@ def score(
layer: Specifies which `.layers` of AnnData to use for expression values. Defaults to `.X` if None.
n_bins: The number of expression bins for the `'seurat'` method.
ctrl_size: The number of genes to randomly sample from each expression bin for the `"seurat"` method.
key_added: Prefix of the keys under which the results are stored in `uns`.
The actual keys are `{key_added}_score`, `{key_added}_variables`, `{key_added}_genes` and `{key_added}_all_genes`.
Defaults to `pertpy_enrichment`.
Returns:
None. The scores are written to `adata.uns` under keys prefixed with `key_added`.
Expand Down Expand Up @@ -144,15 +148,15 @@ def score(
seurat = np.dot(control_profiles, drug_weights)
scores = scores - seurat

adata.uns["pertpy_enrichment_score"] = scores
adata.uns["pertpy_enrichment_variables"] = weights.columns
adata.uns[f"{key_added}_score"] = scores
adata.uns[f"{key_added}_variables"] = weights.columns

adata.uns["pertpy_enrichment_genes"] = {"var": pd.DataFrame(columns=["genes"]).astype(object)}
adata.uns["pertpy_enrichment_all_genes"] = {"var": pd.DataFrame(columns=["all_genes"]).astype(object)}
adata.uns[f"{key_added}_genes"] = {"var": pd.DataFrame(columns=["genes"]).astype(object)}
adata.uns[f"{key_added}_all_genes"] = {"var": pd.DataFrame(columns=["all_genes"]).astype(object)}

for drug in weights.columns:
adata.uns["pertpy_enrichment_genes"]["var"].loc[drug, "genes"] = "|".join(adata.var_names[targets[drug]])
adata.uns["pertpy_enrichment_all_genes"]["var"].loc[drug, "all_genes"] = "|".join(full_targets[drug])
adata.uns[f"{key_added}_genes"]["var"].loc[drug, "genes"] = "|".join(adata.var_names[targets[drug]])
adata.uns[f"{key_added}_all_genes"]["var"].loc[drug, "all_genes"] = "|".join(full_targets[drug])

def hypergeometric(
self,
Expand All @@ -172,9 +176,11 @@ def hypergeometric(
If `None`, will use `d2c.score()` output if present, and if not present load the ChEMBL-derived drug target sets distributed with the package.
Accepts two forms:
- A dictionary with the names of the groups as keys, and the entries being the corresponding gene lists.
- A dictionary of dictionaries defined like above, with names of gene group categories as keys. If passing one of those, specify `nested=True`.
- A dictionary of dictionaries defined like above, with names of gene group categories as keys.
If passing one of those, specify `nested=True`.
nested: Whether `targets` is a dictionary of dictionaries with group categories as keys.
categories: If `targets=None` or `nested=True`, this argument can be used to subset the gene groups to one or more categories (keys of the original dictionary). In case of the ChEMBL drug targets, these are ATC level 1/level 2 category codes.
categories: If `targets=None` or `nested=True`, this argument can be used to subset the gene groups to one or more categories (keys of the original dictionary).
In case of the ChEMBL drug targets, these are ATC level 1/level 2 category codes.
pvals_adj_thresh: The `pvals_adj` cutoff to use on the `sc.tl.rank_genes_groups()` output to identify markers.
direction: Whether to seek out up/down-regulated genes for the groups, based on the values from `scores`.
Can be `up`, `down`, or `both` (for no selection).
Expand Down Expand Up @@ -235,6 +241,7 @@ def gsea(
nested: bool = False,
categories: str | list[str] | None = None,
absolute: bool = False,
key_added: str = "pertpy_enrichment_gsea",
) -> dict[str, pd.DataFrame] | tuple[dict[str, pd.DataFrame], dict[str, dict]]: # pragma: no cover
"""Perform gene set enrichment analysis on the marker gene scores using blitzgsea.
Expand All @@ -251,6 +258,8 @@ def gsea(
applicable if `targets=None` or `nested=True`. Defaults to None.
absolute: If True, passes the absolute values of scores to GSEA, improving
statistical power. Defaults to False.
key_added: Key under which the GSEA plotting arguments (scores and targets) are stored in `uns`.
Defaults to `pertpy_enrichment_gsea`.
Returns:
A dictionary with clusters as keys and data frames of test results sorted on
Expand All @@ -272,7 +281,7 @@ def gsea(
enrichment[cluster] = blitzgsea.gsea(df, targets)
plot_gsea_args["scores"][cluster] = df

adata.uns["pertpy_enrichment_gsea"] = plot_gsea_args
adata.uns[key_added] = plot_gsea_args

return enrichment

Expand All @@ -282,6 +291,7 @@ def plot_dotplot(
targets: dict[str, list[str]] | dict[str, dict[str, list[str]]] = None,
categories: Sequence[str] = None,
groupby: str = None,
key: str = "pertpy_enrichment",
**kwargs,
) -> DotPlot | dict | None:
"""Plots a dotplot by groupby and categories.
Expand All @@ -298,6 +308,8 @@ def plot_dotplot(
categories: To subset the gene groups to specific categories, especially when `targets=None` or `nested=True`.
For ChEMBL drug targets, these are ATC level 1/level 2 category codes.
groupby: dotplot groupby such as clusters or cell types.
key: Prefix of the `uns` keys that hold the enrichment results (`{key}_score` and `{key}_variables` are read).
Defaults to `pertpy_enrichment`.
kwargs: Passed to scanpy dotplot.
Returns:
Expand Down Expand Up @@ -330,8 +342,8 @@ def plot_dotplot(
var_group_labels: list[str] = []
start = 0

enrichment_score_adata = AnnData(adata.uns["pertpy_enrichment_score"], obs=adata.obs)
enrichment_score_adata.var_names = adata.uns["pertpy_enrichment_variables"]
enrichment_score_adata = AnnData(adata.uns[f"{key}_score"], obs=adata.obs)
enrichment_score_adata.var_names = adata.uns[f"{key}_variables"]

for group in targets:
targets[group] = list( # type: ignore
Expand All @@ -352,7 +364,9 @@ def plot_dotplot(

return sc.pl.dotplot(enrichment_score_adata, groupby=groupby, swap_axes=True, **plot_args, **kwargs)

def plot_gsea(self, adata: AnnData, enrichment: dict[str, pd.DataFrame], n: int = 10) -> None:
def plot_gsea(
self, adata: AnnData, enrichment: dict[str, pd.DataFrame], n: int = 10, key: str = "pertpy_enrichment_gsea"
) -> None:
"""Generates a blitzgsea top_table plot.
This function is designed to visualize the results from a Gene Set Enrichment Analysis (GSEA).
Expand All @@ -363,11 +377,12 @@ def plot_gsea(self, adata: AnnData, enrichment: dict[str, pd.DataFrame], n: int
adata: AnnData object to plot.
enrichment: Cluster names as keys, blitzgsea's ``gsea()`` output as values.
n: How many top scores to show for each group. Defaults to 10.
key: GSEA results key in `uns`. Defaults to "pertpy_enrichment_gsea".
"""
for cluster in enrichment:
fig = blitzgsea.plot.top_table(
adata.uns["pertpy_enrichment_gsea"]["scores"][cluster],
adata.uns["pertpy_enrichment_gsea"]["targets"],
adata.uns[key]["scores"][cluster],
adata.uns[key]["targets"],
enrichment[cluster],
n=n,
)
Expand Down

0 comments on commit 5a53238

Please sign in to comment.