diff --git a/pertpy/tools/_enrichment.py b/pertpy/tools/_enrichment.py index 52345dd8..4ea53e33 100644 --- a/pertpy/tools/_enrichment.py +++ b/pertpy/tools/_enrichment.py @@ -365,7 +365,12 @@ def plot_dotplot( return sc.pl.dotplot(enrichment_score_adata, groupby=groupby, swap_axes=True, **plot_args, **kwargs) def plot_gsea( - self, adata: AnnData, enrichment: dict[str, pd.DataFrame], n: int = 10, key: str = "pertpy_enrichment_gsea" + self, + adata: AnnData, + enrichment: dict[str, pd.DataFrame], + n: int = 10, + key: str = "pertpy_enrichment_gsea", + interactive_plot: bool = False, ) -> None: """Generates a blitzgsea top_table plot. @@ -378,6 +383,7 @@ def plot_gsea( enrichment: Cluster names as keys, blitzgsea's ``gsea()`` output as values. n: How many top scores to show for each group. Defaults to 10. key: GSEA results key in `uns`. Defaults to "pertpy_enrichment_gsea". + interactive_plot: Whether to plot interactively or not. Defaults to False. """ for cluster in enrichment: fig = blitzgsea.plot.top_table( @@ -385,6 +391,7 @@ def plot_gsea( adata.uns[key]["targets"], enrichment[cluster], n=n, + interactive_plot=interactive_plot, ) fig.suptitle(cluster) fig.show()