11from collections import ChainMap
22from collections .abc import Sequence
3- from typing import Any , Literal , Union
3+ from typing import Any , Literal
44
55import blitzgsea
66import numpy as np
@@ -64,6 +64,7 @@ def score(
6464 method : Literal ["mean" , "seurat" ] = "mean" ,
6565 n_bins : int = 25 ,
6666 ctrl_size : int = 50 ,
67+ key_added : str = "pertpy_enrichment" ,
6768 ) -> None :
6869 """Obtain per-cell scoring of gene groups of interest.
6970
@@ -88,6 +89,9 @@ def score(
8889 layer: Specifies which `.layers` of AnnData to use for expression values. Defaults to `.X` if None.
8990 n_bins: The number of expression bins for the `'seurat'` method.
9091 ctrl_size: The number of genes to randomly sample from each expression bin for the `"seurat"` method.
92+ key_added: Prefix key that adds the results to `uns`.
93+ Note that the actual values are `key_added_score`, `key_added_variables`, `key_added_genes`, `key_added_all_genes`.
94+ Defaults to `pertpy_enrichment`.
9195
9296 Returns:
9397 An AnnData object with scores.
@@ -144,15 +148,15 @@ def score(
144148 seurat = np .dot (control_profiles , drug_weights )
145149 scores = scores - seurat
146150
147- adata .uns ["pertpy_enrichment_score " ] = scores
148- adata .uns ["pertpy_enrichment_variables " ] = weights .columns
151+ adata .uns [f" { key_added } _score " ] = scores
152+ adata .uns [f" { key_added } _variables " ] = weights .columns
149153
150- adata .uns ["pertpy_enrichment_genes " ] = {"var" : pd .DataFrame (columns = ["genes" ]).astype (object )}
151- adata .uns ["pertpy_enrichment_all_genes " ] = {"var" : pd .DataFrame (columns = ["all_genes" ]).astype (object )}
154+ adata .uns [f" { key_added } _genes " ] = {"var" : pd .DataFrame (columns = ["genes" ]).astype (object )}
155+ adata .uns [f" { key_added } _all_genes " ] = {"var" : pd .DataFrame (columns = ["all_genes" ]).astype (object )}
152156
153157 for drug in weights .columns :
154- adata .uns ["pertpy_enrichment_genes " ]["var" ].loc [drug , "genes" ] = "|" .join (adata .var_names [targets [drug ]])
155- adata .uns ["pertpy_enrichment_all_genes " ]["var" ].loc [drug , "all_genes" ] = "|" .join (full_targets [drug ])
158+ adata .uns [f" { key_added } _genes " ]["var" ].loc [drug , "genes" ] = "|" .join (adata .var_names [targets [drug ]])
159+ adata .uns [f" { key_added } _all_genes " ]["var" ].loc [drug , "all_genes" ] = "|" .join (full_targets [drug ])
156160
157161 def hypergeometric (
158162 self ,
@@ -172,9 +176,11 @@ def hypergeometric(
172176 If `None`, will use `d2c.score()` output if present, and if not present load the ChEMBL-derived drug target sets distributed with the package.
173177 Accepts two forms:
174178 - A dictionary with the names of the groups as keys, and the entries being the corresponding gene lists.
175- - A dictionary of dictionaries defined like above, with names of gene group categories as keys. If passing one of those, specify `nested=True`.
179+ - A dictionary of dictionaries defined like above, with names of gene group categories as keys.
180+ If passing one of those, specify `nested=True`.
176181 nested: Whether `targets` is a dictionary of dictionaries with group categories as keys.
177- categories: If `targets=None` or `nested=True`, this argument can be used to subset the gene groups to one or more categories (keys of the original dictionary). In case of the ChEMBL drug targets, these are ATC level 1/level 2 category codes.
182+ categories: If `targets=None` or `nested=True`, this argument can be used to subset the gene groups to one or more categories (keys of the original dictionary).
183+ In case of the ChEMBL drug targets, these are ATC level 1/level 2 category codes.
178184 pvals_adj_thresh: The `pvals_adj` cutoff to use on the `sc.tl.rank_genes_groups()` output to identify markers.
179185 direction: Whether to seek out up/down-regulated genes for the groups, based on the values from `scores`.
180186 Can be `up`, `down`, or `both` (for no selection).
@@ -235,6 +241,7 @@ def gsea(
235241 nested : bool = False ,
236242 categories : str | list [str ] | None = None ,
237243 absolute : bool = False ,
244+ key_added : str = "pertpy_enrichment_gsea" ,
238245 ) -> dict [str , pd .DataFrame ] | tuple [dict [str , pd .DataFrame ], dict [str , dict ]]: # pragma: no cover
239246 """Perform gene set enrichment analysis on the marker gene scores using blitzgsea.
240247
@@ -251,6 +258,8 @@ def gsea(
251258 applicable if `targets=None` or `nested=True`. Defaults to None.
252259 absolute: If True, passes the absolute values of scores to GSEA, improving
253260 statistical power. Defaults to False.
261+ key_added: Prefix key that adds the results to `uns`.
262+ Defaults to `pertpy_enrichment_gsea`.
254263
255264 Returns:
256265 A dictionary with clusters as keys and data frames of test results sorted on
@@ -272,7 +281,7 @@ def gsea(
272281 enrichment [cluster ] = blitzgsea .gsea (df , targets )
273282 plot_gsea_args ["scores" ][cluster ] = df
274283
275- adata .uns ["pertpy_enrichment_gsea" ] = plot_gsea_args
284+ adata .uns [key_added ] = plot_gsea_args
276285
277286 return enrichment
278287
@@ -282,6 +291,7 @@ def plot_dotplot(
282291 targets : dict [str , list [str ]] | dict [str , dict [str , list [str ]]] = None ,
283292 categories : Sequence [str ] = None ,
284293 groupby : str = None ,
294+ key : str = "pertpy_enrichment" ,
285295 ** kwargs ,
286296 ) -> DotPlot | dict | None :
287297 """Plots a dotplot by groupby and categories.
@@ -298,6 +308,8 @@ def plot_dotplot(
298308 categories: To subset the gene groups to specific categories, especially when `targets=None` or `nested=True`.
299309 For ChEMBL drug targets, these are ATC level 1/level 2 category codes.
300310 groupby: dotplot groupby such as clusters or cell types.
311+ key: Prefix key of enrichment results in `uns`.
312+ Defaults to `pertpy_enrichment`.
301313 kwargs: Passed to scanpy dotplot.
302314
303315 Returns:
@@ -330,8 +342,8 @@ def plot_dotplot(
330342 var_group_labels : list [str ] = []
331343 start = 0
332344
333- enrichment_score_adata = AnnData (adata .uns ["pertpy_enrichment_score " ], obs = adata .obs )
334- enrichment_score_adata .var_names = adata .uns ["pertpy_enrichment_variables " ]
345+ enrichment_score_adata = AnnData (adata .uns [f" { key } _score " ], obs = adata .obs )
346+ enrichment_score_adata .var_names = adata .uns [f" { key } _variables " ]
335347
336348 for group in targets :
337349 targets [group ] = list ( # type: ignore
@@ -352,7 +364,9 @@ def plot_dotplot(
352364
353365 return sc .pl .dotplot (enrichment_score_adata , groupby = groupby , swap_axes = True , ** plot_args , ** kwargs )
354366
355- def plot_gsea (self , adata : AnnData , enrichment : dict [str , pd .DataFrame ], n : int = 10 ) -> None :
367+ def plot_gsea (
368+ self , adata : AnnData , enrichment : dict [str , pd .DataFrame ], n : int = 10 , key : str = "pertpy_enrichment_gsea"
369+ ) -> None :
356370 """Generates a blitzgsea top_table plot.
357371
358372 This function is designed to visualize the results from a Gene Set Enrichment Analysis (GSEA).
@@ -363,11 +377,12 @@ def plot_gsea(self, adata: AnnData, enrichment: dict[str, pd.DataFrame], n: int
363377 adata: AnnData object to plot.
364378 enrichment: Cluster names as keys, blitzgsea's ``gsea()`` output as values.
365379 n: How many top scores to show for each group. Defaults to 10.
380+ key: GSEA results key in `uns`. Defaults to "pertpy_enrichment_gsea".
366381 """
367382 for cluster in enrichment :
368383 fig = blitzgsea .plot .top_table (
369- adata .uns ["pertpy_enrichment_gsea" ]["scores" ][cluster ],
370- adata .uns ["pertpy_enrichment_gsea" ]["targets" ],
384+ adata .uns [key ]["scores" ][cluster ],
385+ adata .uns [key ]["targets" ],
371386 enrichment [cluster ],
372387 n = n ,
373388 )
0 commit comments