From e6e9645ede24828a75e1cc3171fb531237c750d7 Mon Sep 17 00:00:00 2001 From: Altana Namsaraeva <99650244+namsaraeva@users.noreply.github.com> Date: Wed, 24 Jan 2024 23:15:44 +0100 Subject: [PATCH] harmonise augur --- pertpy/tools/_augur.py | 37 +++++++++++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/pertpy/tools/_augur.py b/pertpy/tools/_augur.py index c279b7fa..97989b5e 100644 --- a/pertpy/tools/_augur.py +++ b/pertpy/tools/_augur.py @@ -968,7 +968,11 @@ def predict_differential_prioritization( return delta def plot_dp_scatter( - self, results: pd.DataFrame, top_n=None, ax: Axes = None, return_figure: bool = False + self, + results: pd.DataFrame, + top_n: int = None, + ax: Axes = None, + return_figure: bool = False ) -> Figure | Axes: """Plot scatterplot of differential prioritization. @@ -997,6 +1001,9 @@ def plot_dp_scatter( >>> pvals = ag_rfc.predict_differential_prioritization(augur_results1=results_15, augur_results2=results_48, \ permuted_results1=results_15_permute, permuted_results2=results_48_permute) >>> ag_rfc.plot_dp_scatter(pvals) + + Preview: + # TODO: add preview """ x = results["mean_augur_score1"] y = results["mean_augur_score2"] @@ -1026,7 +1033,12 @@ def plot_dp_scatter( return fig if return_figure else ax def plot_important_features( - self, data: dict[str, Any], key: str = "augurpy_results", top_n=10, ax: Axes = None, return_figure: bool = False + self, + data: dict[str, Any], + key: str = "augurpy_results", + top_n: int = 10, + ax: Axes = None, + return_figure: bool = False ) -> Figure | Axes: """Plot a lollipop plot of the n features with largest feature importances. @@ -1047,6 +1059,9 @@ def plot_important_features( >>> loaded_data = ag_rfc.load(adata) >>> v_adata, v_results = ag_rfc.predict(loaded_data, subsample_size=20, select_variance_features=True, n_threads=4) >>> ag_rfc.plot_important_features(v_results) + + Preview: + # TODO: add preview """ if isinstance(data, AnnData): results = data.uns[key] @@ -1077,7 +1092,11 @@ def plot_important_features( return fig if return_figure else ax def plot_lollipop( - self, data: dict[str, Any], key: str = "augurpy_results", ax: Axes = None, return_figure: bool = False + self, + data: dict[str, Any], + key: str = "augurpy_results", + ax: Axes = None, + return_figure: bool = False ) -> Figure | Axes: """Plot a lollipop plot of the mean augur values. @@ -1097,6 +1116,9 @@ def plot_lollipop( >>> loaded_data = ag_rfc.load(adata) >>> v_adata, v_results = ag_rfc.predict(loaded_data, subsample_size=20, select_variance_features=True, n_threads=4) >>> ag_rfc.plot_lollipop(v_results) + + Preview: + # TODO: add preview """ if isinstance(data, AnnData): results = data.uns[key] @@ -1124,7 +1146,11 @@ def plot_lollipop( return fig if return_figure else ax def plot_scatterplot( - self, results1: dict[str, Any], results2: dict[str, Any], top_n=None, return_figure: bool = False + self, + results1: dict[str, Any], + results2: dict[str, Any], + top_n: int = None, + return_figure: bool = False ) -> Figure | Axes: """Create scatterplot with two augur results. @@ -1145,6 +1171,9 @@ def plot_scatterplot( >>> h_adata, h_results = ag_rfc.predict(loaded_data, subsample_size=20, n_threads=4) >>> v_adata, v_results = ag_rfc.predict(loaded_data, subsample_size=20, select_variance_features=True, n_threads=4) >>> ag_rfc.plot_scatterplot(v_results, h_results) + + Preview: + # TODO: add preview """ cell_types = results1["summary_metrics"].columns