Skip to content

Commit

Permalink
harmonise augur
Browse files Browse the repository at this point in the history
  • Loading branch information
namsaraeva committed Jan 24, 2024
1 parent c7c008a commit e6e9645
Showing 1 changed file with 33 additions and 4 deletions.
37 changes: 33 additions & 4 deletions pertpy/tools/_augur.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,7 +968,11 @@ def predict_differential_prioritization(
return delta

def plot_dp_scatter(
self, results: pd.DataFrame, top_n=None, ax: Axes = None, return_figure: bool = False
self,
results: pd.DataFrame,
top_n: int = None,
ax: Axes = None,
return_figure: bool = False
) -> Figure | Axes:
"""Plot scatterplot of differential prioritization.
Expand Down Expand Up @@ -997,6 +1001,9 @@ def plot_dp_scatter(
>>> pvals = ag_rfc.predict_differential_prioritization(augur_results1=results_15, augur_results2=results_48, \
permuted_results1=results_15_permute, permuted_results2=results_48_permute)
>>> ag_rfc.plot_dp_scatter(pvals)
Preview:
# TODO: add preview
"""
x = results["mean_augur_score1"]
y = results["mean_augur_score2"]
Expand Down Expand Up @@ -1026,7 +1033,12 @@ def plot_dp_scatter(
return fig if return_figure else ax

def plot_important_features(
self, data: dict[str, Any], key: str = "augurpy_results", top_n=10, ax: Axes = None, return_figure: bool = False
self,
data: dict[str, Any],
key: str = "augurpy_results",
top_n: int = 10,
ax: Axes = None,
return_figure: bool = False
) -> Figure | Axes:
"""Plot a lollipop plot of the n features with largest feature importances.
Expand All @@ -1047,6 +1059,9 @@ def plot_important_features(
>>> loaded_data = ag_rfc.load(adata)
>>> v_adata, v_results = ag_rfc.predict(loaded_data, subsample_size=20, select_variance_features=True, n_threads=4)
>>> ag_rfc.plot_important_features(v_results)
Preview:
# TODO: add preview
"""
if isinstance(data, AnnData):
results = data.uns[key]
Expand Down Expand Up @@ -1077,7 +1092,11 @@ def plot_important_features(
return fig if return_figure else ax

def plot_lollipop(
self, data: dict[str, Any], key: str = "augurpy_results", ax: Axes = None, return_figure: bool = False
self,
data: dict[str, Any],
key: str = "augurpy_results",
ax: Axes = None,
return_figure: bool = False
) -> Figure | Axes:
"""Plot a lollipop plot of the mean augur values.
Expand All @@ -1097,6 +1116,9 @@ def plot_lollipop(
>>> loaded_data = ag_rfc.load(adata)
>>> v_adata, v_results = ag_rfc.predict(loaded_data, subsample_size=20, select_variance_features=True, n_threads=4)
>>> ag_rfc.plot_lollipop(v_results)
Preview:
# TODO: add preview
"""
if isinstance(data, AnnData):
results = data.uns[key]
Expand Down Expand Up @@ -1124,7 +1146,11 @@ def plot_lollipop(
return fig if return_figure else ax

def plot_scatterplot(
self, results1: dict[str, Any], results2: dict[str, Any], top_n=None, return_figure: bool = False
self,
results1: dict[str, Any],
results2: dict[str, Any],
top_n: int = None,
return_figure: bool = False
) -> Figure | Axes:
"""Create scatterplot with two augur results.
Expand All @@ -1145,6 +1171,9 @@ def plot_scatterplot(
>>> h_adata, h_results = ag_rfc.predict(loaded_data, subsample_size=20, n_threads=4)
>>> v_adata, v_results = ag_rfc.predict(loaded_data, subsample_size=20, select_variance_features=True, n_threads=4)
>>> ag_rfc.plot_scatterplot(v_results, h_results)
Preview:
# TODO: add preview
"""
cell_types = results1["summary_metrics"].columns

Expand Down

0 comments on commit e6e9645

Please sign in to comment.