From 5a53238c8bfa64ce36039b3bb27609ea32b73eba Mon Sep 17 00:00:00 2001 From: zethson Date: Sat, 6 Jan 2024 17:13:18 +0100 Subject: [PATCH] Add key to enrichments Signed-off-by: zethson --- pertpy/tools/_enrichment.py | 45 ++++++++++++++++++++++++------------- 1 file changed, 30 insertions(+), 15 deletions(-) diff --git a/pertpy/tools/_enrichment.py b/pertpy/tools/_enrichment.py index f2e1e076..52345dd8 100644 --- a/pertpy/tools/_enrichment.py +++ b/pertpy/tools/_enrichment.py @@ -1,6 +1,6 @@ from collections import ChainMap from collections.abc import Sequence -from typing import Any, Literal, Union +from typing import Any, Literal import blitzgsea import numpy as np @@ -64,6 +64,7 @@ def score( method: Literal["mean", "seurat"] = "mean", n_bins: int = 25, ctrl_size: int = 50, + key_added: str = "pertpy_enrichment", ) -> None: """Obtain per-cell scoring of gene groups of interest. @@ -88,6 +89,9 @@ def score( layer: Specifies which `.layers` of AnnData to use for expression values. Defaults to `.X` if None. n_bins: The number of expression bins for the `'seurat'` method. ctrl_size: The number of genes to randomly sample from each expression bin for the `"seurat"` method. + key_added: Prefix key that adds the results to `uns`. + Note that the actual values are `key_added_score`, `key_added_variables`, `key_added_genes`, `key_added_all_genes`. + Defaults to `pertpy_enrichment`. Returns: An AnnData object with scores. @@ -144,15 +148,15 @@ def score( seurat = np.dot(control_profiles, drug_weights) scores = scores - seurat - adata.uns["pertpy_enrichment_score"] = scores - adata.uns["pertpy_enrichment_variables"] = weights.columns + adata.uns[f"{key_added}_score"] = scores + adata.uns[f"{key_added}_variables"] = weights.columns - adata.uns["pertpy_enrichment_genes"] = {"var": pd.DataFrame(columns=["genes"]).astype(object)} - adata.uns["pertpy_enrichment_all_genes"] = {"var": pd.DataFrame(columns=["all_genes"]).astype(object)} + adata.uns[f"{key_added}_genes"] = {"var": pd.DataFrame(columns=["genes"]).astype(object)} + adata.uns[f"{key_added}_all_genes"] = {"var": pd.DataFrame(columns=["all_genes"]).astype(object)} for drug in weights.columns: - adata.uns["pertpy_enrichment_genes"]["var"].loc[drug, "genes"] = "|".join(adata.var_names[targets[drug]]) - adata.uns["pertpy_enrichment_all_genes"]["var"].loc[drug, "all_genes"] = "|".join(full_targets[drug]) + adata.uns[f"{key_added}_genes"]["var"].loc[drug, "genes"] = "|".join(adata.var_names[targets[drug]]) + adata.uns[f"{key_added}_all_genes"]["var"].loc[drug, "all_genes"] = "|".join(full_targets[drug]) def hypergeometric( self, @@ -172,9 +176,11 @@ def hypergeometric( If `None`, will use `d2c.score()` output if present, and if not present load the ChEMBL-derived drug target sets distributed with the package. Accepts two forms: - A dictionary with the names of the groups as keys, and the entries being the corresponding gene lists. - - A dictionary of dictionaries defined like above, with names of gene group categories as keys. If passing one of those, specify `nested=True`. + - A dictionary of dictionaries defined like above, with names of gene group categories as keys. + If passing one of those, specify `nested=True`. nested: Whether `targets` is a dictionary of dictionaries with group categories as keys. - categories: If `targets=None` or `nested=True`, this argument can be used to subset the gene groups to one or more categories (keys of the original dictionary). In case of the ChEMBL drug targets, these are ATC level 1/level 2 category codes. + categories: If `targets=None` or `nested=True`, this argument can be used to subset the gene groups to one or more categories (keys of the original dictionary). + In case of the ChEMBL drug targets, these are ATC level 1/level 2 category codes. pvals_adj_thresh: The `pvals_adj` cutoff to use on the `sc.tl.rank_genes_groups()` output to identify markers. direction: Whether to seek out up/down-regulated genes for the groups, based on the values from `scores`. Can be `up`, `down`, or `both` (for no selection). @@ -235,6 +241,7 @@ def gsea( nested: bool = False, categories: str | list[str] | None = None, absolute: bool = False, + key_added: str = "pertpy_enrichment_gsea", ) -> dict[str, pd.DataFrame] | tuple[dict[str, pd.DataFrame], dict[str, dict]]: # pragma: no cover """Perform gene set enrichment analysis on the marker gene scores using blitzgsea. @@ -251,6 +258,8 @@ def gsea( applicable if `targets=None` or `nested=True`. Defaults to None. absolute: If True, passes the absolute values of scores to GSEA, improving statistical power. Defaults to False. + key_added: Prefix key that adds the results to `uns`. + Defaults to `pertpy_enrichment_gsea`. Returns: A dictionary with clusters as keys and data frames of test results sorted on @@ -272,7 +281,7 @@ def gsea( enrichment[cluster] = blitzgsea.gsea(df, targets) plot_gsea_args["scores"][cluster] = df - adata.uns["pertpy_enrichment_gsea"] = plot_gsea_args + adata.uns[key_added] = plot_gsea_args return enrichment @@ -282,6 +291,7 @@ def plot_dotplot( targets: dict[str, list[str]] | dict[str, dict[str, list[str]]] = None, categories: Sequence[str] = None, groupby: str = None, + key: str = "pertpy_enrichment", **kwargs, ) -> DotPlot | dict | None: """Plots a dotplot by groupby and categories. @@ -298,6 +308,8 @@ def plot_dotplot( categories: To subset the gene groups to specific categories, especially when `targets=None` or `nested=True`. For ChEMBL drug targets, these are ATC level 1/level 2 category codes. groupby: dotplot groupby such as clusters or cell types. + key: Prefix key of enrichment results in `uns`. + Defaults to `pertpy_enrichment`. kwargs: Passed to scanpy dotplot. Returns: @@ -330,8 +342,8 @@ def plot_dotplot( var_group_labels: list[str] = [] start = 0 - enrichment_score_adata = AnnData(adata.uns["pertpy_enrichment_score"], obs=adata.obs) - enrichment_score_adata.var_names = adata.uns["pertpy_enrichment_variables"] + enrichment_score_adata = AnnData(adata.uns[f"{key}_score"], obs=adata.obs) + enrichment_score_adata.var_names = adata.uns[f"{key}_variables"] for group in targets: targets[group] = list( # type: ignore @@ -352,7 +364,9 @@ def plot_dotplot( return sc.pl.dotplot(enrichment_score_adata, groupby=groupby, swap_axes=True, **plot_args, **kwargs) - def plot_gsea(self, adata: AnnData, enrichment: dict[str, pd.DataFrame], n: int = 10) -> None: + def plot_gsea( + self, adata: AnnData, enrichment: dict[str, pd.DataFrame], n: int = 10, key: str = "pertpy_enrichment_gsea" + ) -> None: """Generates a blitzgsea top_table plot. This function is designed to visualize the results from a Gene Set Enrichment Analysis (GSEA). @@ -363,11 +377,12 @@ def plot_gsea(self, adata: AnnData, enrichment: dict[str, pd.DataFrame], n: int adata: AnnData object to plot. enrichment: Cluster names as keys, blitzgsea's ``gsea()`` output as values. n: How many top scores to show for each group. Defaults to 10. + key: GSEA results key in `uns`. Defaults to "pertpy_enrichment_gsea". """ for cluster in enrichment: fig = blitzgsea.plot.top_table( - adata.uns["pertpy_enrichment_gsea"]["scores"][cluster], - adata.uns["pertpy_enrichment_gsea"]["targets"], + adata.uns[key]["scores"][cluster], + adata.uns[key]["targets"], enrichment[cluster], n=n, )