Skip to content

Commit

Permalink
Merge branch 'main' into plot_example
Browse files Browse the repository at this point in the history
  • Loading branch information
namsaraeva authored Feb 4, 2024
2 parents 20176bc + 4b20f6d commit 4c4d6c6
Show file tree
Hide file tree
Showing 10 changed files with 143 additions and 114 deletions.
2 changes: 1 addition & 1 deletion docs/tutorials/notebooks
15 changes: 8 additions & 7 deletions pertpy/tools/_augur.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from math import floor, nan
from typing import TYPE_CHECKING, Any, Literal

import anndata as ad
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
Expand Down Expand Up @@ -140,8 +141,8 @@ def load(
# filter samples according to label
if condition_label is not None and treatment_label is not None:
print(f"Filtering samples with {condition_label} and {treatment_label} labels.")
adata = AnnData.concatenate(
adata[adata.obs["label"] == condition_label], adata[adata.obs["label"] == treatment_label]
adata = ad.concat(
[adata[adata.obs["label"] == condition_label], adata[adata.obs["label"] == treatment_label]]
)
label_encoder = LabelEncoder()
adata.obs["y_"] = label_encoder.fit_transform(adata.obs["label"])
Expand Down Expand Up @@ -235,7 +236,7 @@ def sample(self, adata: AnnData, categorical: bool, subsample_size: int, random_
random_state=random_state,
)
)
subsample = AnnData.concatenate(*label_subsamples, index_unique=None)
subsample = ad.concat([*label_subsamples], index_unique=None)
else:
subsample = sc.pp.subsample(adata[:, features], n_obs=subsample_size, copy=True, random_state=random_state)

Expand Down Expand Up @@ -414,17 +415,17 @@ def set_scorer(
"""
if multiclass:
return {
"augur_score": make_scorer(roc_auc_score, multi_class="ovo", needs_proba=True),
"auc": make_scorer(roc_auc_score, multi_class="ovo", needs_proba=True),
"augur_score": make_scorer(roc_auc_score, multi_class="ovo", response_method="predict_proba"),
"auc": make_scorer(roc_auc_score, multi_class="ovo", response_method="predict_proba"),
"accuracy": make_scorer(accuracy_score),
"precision": make_scorer(precision_score, average="macro", zero_division=zero_division),
"f1": make_scorer(f1_score, average="macro"),
"recall": make_scorer(recall_score, average="macro"),
}
return (
{
"augur_score": make_scorer(roc_auc_score, needs_proba=True),
"auc": make_scorer(roc_auc_score, needs_proba=True),
"augur_score": make_scorer(roc_auc_score, response_method="predict_proba"),
"auc": make_scorer(roc_auc_score, response_method="predict_proba"),
"accuracy": make_scorer(accuracy_score),
"precision": make_scorer(precision_score, average="binary", zero_division=zero_division),
"f1": make_scorer(f1_score, average="binary"),
Expand Down
2 changes: 1 addition & 1 deletion pertpy/tools/_distances/_distances.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,7 +796,7 @@ def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
label = ["c"] * X.shape[0] + ["p"] * Y_train.shape[0]
train = np.concatenate([X, Y_train])

reg = LogisticRegression() # TODO dynamically pass this?
reg = LogisticRegression()
reg.fit(train, label)
test_labels = reg.predict_proba(Y_test)
return np.mean(test_labels[:, 1])
Expand Down
29 changes: 9 additions & 20 deletions pertpy/tools/_mixscape.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import copy
import warnings
from collections import OrderedDict
from typing import TYPE_CHECKING, Literal

Expand All @@ -10,7 +9,6 @@
import scanpy as sc
import seaborn as sns
from matplotlib import pyplot as pl
from rich import print
from scanpy import get
from scanpy._settings import settings
from scanpy._utils import _check_use_raw, sanitize_anndata
Expand Down Expand Up @@ -477,15 +475,6 @@ def _get_perturbation_markers(
return perturbation_markers

def _get_column_indices(self, adata, col_names):
"""Fetches the column indices in X for a given list of column names
Args:
adata: :class:`~anndata.AnnData` object
col_names: Column names to extract the indices for
Returns:
Set of column indices
"""
if isinstance(col_names, str): # pragma: no cover
col_names = [col_names]

Expand Down Expand Up @@ -827,7 +816,7 @@ def plot_perturbscore( # pragma: no cover
)
pl.xlabel("Perturbation score", fontsize=16)
pl.ylabel("Cell density", fontsize=16)
pl.title("Density Plot using Seaborn and Matplotlib", fontsize=18)
pl.title("Density", fontsize=18)
pl.legend(title="mixscape class", title_fontsize=14, fontsize=12)
sns.despine()

Expand Down Expand Up @@ -855,19 +844,21 @@ def plot_violin( # pragma: no cover
ax: Axes | None = None,
**kwargs,
):
"""Violin plot using mixscape results. Requires `pt.tl.mixscape` to be run first.
"""Violin plot using mixscape results.
Requires `pt.tl.mixscape` to be run first.
Args:
adata: The annotated data object.
target_gene: Target gene name to plot.
target_gene_idents: Target gene name to plot.
keys: Keys for accessing variables of `.var_names` or fields of `.obs`. Default is 'mixscape_class_p_ko'.
groupby: The key of the observation grouping to consider. Default is 'mixscape_class'.
log: Plot on logarithmic axis.
use_raw: Whether to use `raw` attribute of `adata`. Defaults to `True` if `.raw` is present.
stripplot: Add a stripplot on top of the violin plot.
order: Order in which to show the categories.
xlabel: Label of the x axis. Defaults to `groupby` if `rotation` is `None`, otherwise, no label is shown.
ylabel: Label of the y axis. If `None` and `groupby` is `None`, defaults to `'value'`.
xlabel: Label of the x-axis. Defaults to `groupby` if `rotation` is `None`, otherwise, no label is shown.
ylabel: Label of the y-axis. If `None` and `groupby` is `None`, defaults to `'value'`.
If `None` and `groubpy` is not `None`, defaults to `keys`.
show: Show the plot, do not return axis.
save: If `True` or a `str`, save the figure. A string is appended to the default filename.
Expand Down Expand Up @@ -1072,11 +1063,9 @@ def plot_lda( # pragma: no cover
.. image:: ../_static/docstring_previews/mixscape_lda.png
"""
if mixscape_class not in adata.obs:
raise ValueError(
f'Did not find `.obs["{mixscape_class!r}"]`. Please run the `mixscape` function first first.'
)
raise ValueError(f'Did not find `.obs["{mixscape_class!r}"]`. Please run the `mixscape` function first.')
if lda_key not in adata.uns:
raise ValueError(f'Did not find `.uns["{lda_key!r}"]`. Run the `lda` function first.')
raise ValueError(f'Did not find `.uns["{lda_key!r}"]`. Please run the `lda` function first.')

adata_subset = adata[
(adata.obs[mixscape_class_global] == perturbation_type) | (adata.obs[mixscape_class_global] == control)
Expand Down
61 changes: 55 additions & 6 deletions pertpy/tools/_perturbation_space/_perturbation_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
import pandas as pd
from anndata import AnnData
from pynndescent import NNDescent
from rich import print

if TYPE_CHECKING:
Expand All @@ -25,7 +26,7 @@ def __init__(self):
def compute_control_diff( # type: ignore
self,
adata: AnnData,
target_col: str = "perturbations",
target_col: str = "perturbation",
group_col: str = None,
reference_key: str = "control",
layer_key: str = None,
Expand Down Expand Up @@ -147,16 +148,16 @@ def add(
perturbations: Iterable[str],
reference_key: str = "control",
ensure_consistency: bool = False,
target_col: str = "perturbations",
) -> AnnData:
target_col: str = "perturbation",
) -> tuple[AnnData, AnnData] | AnnData:
"""Add perturbations linearly. Assumes input of size n_perts x dimensionality
Args:
adata: Anndata object of size n_perts x dim.
perturbations: Perturbations to add.
reference_key: perturbation source from which the perturbation summation starts. Defaults to 'control'.
ensure_consistency: If True, runs differential expression on all data matrices to ensure consistency of linear space.
target_col: .obs column name that stores the label of the perturbation applied to each cell. Defaults to 'perturbations'.
target_col: .obs column name that stores the label of the perturbation applied to each cell. Defaults to 'perturbation'.
Returns:
Anndata object of size (n_perts+1) x dim, where the last row is the addition of the specified perturbations.
Expand Down Expand Up @@ -256,8 +257,8 @@ def subtract(
perturbations: Iterable[str],
reference_key: str = "control",
ensure_consistency: bool = False,
target_col: str = "perturbations",
) -> AnnData:
target_col: str = "perturbation",
) -> tuple[AnnData, AnnData] | AnnData:
"""Subtract perturbations linearly. Assumes input of size n_perts x dimensionality
Args:
Expand Down Expand Up @@ -358,3 +359,51 @@ def subtract(
return new_perturbation, adata

return new_perturbation

def label_transfer(
self,
adata: AnnData,
column: str = "perturbation",
target_val: str = "unknown",
n_neighbors: int = 5,
use_rep: str = "X_umap",
) -> None:
"""Impute missing values in the specified column using KNN imputation in the space defined by `use_rep`.
Args:
adata: The AnnData object containing single-cell data.
column: The column name in AnnData object to perform imputation on. Defaults to "perturbation".
target_val: The target value to impute. Defaults to "unknown".
n_neighbors: Number of neighbors to use for imputation. Defaults to 5.
use_rep: The key in `adata.obsm` where the embedding (UMAP, PCA, etc.) is stored. Defaults to 'X_umap'.
Examples:
>>> import pertpy as pt
>>> import scanpy as sc
>>> import numpy as np
>>> adata = sc.datasets.pbmc68k_reduced()
>>> rng = np.random.default_rng()
>>> adata.obs["perturbation"] = rng.choice(["A", "B", "C", "unknown"], size=adata.n_obs, p=[0.33, 0.33, 0.33, 0.01])
>>> sc.pp.neighbors(adata)
>>> sc.tl.umap(adata)
>>> ps = pt.tl.PseudobulkSpace()
>>> ps.label_transfer(adata, n_neighbors=5, use_rep="X_umap")
"""
if use_rep not in adata.obsm:
raise ValueError(f"Representation {use_rep} not found in the AnnData object.")

embedding = adata.obsm[use_rep]

nnd = NNDescent(embedding, n_neighbors=n_neighbors)
indices, _ = nnd.query(embedding, k=n_neighbors)

perturbations = np.array(adata.obs[column])
missing_mask = perturbations == target_val

for idx in np.where(missing_mask)[0]:
neighbor_indices = indices[idx]
neighbor_categories = perturbations[neighbor_indices]
most_common = pd.Series(neighbor_categories).mode()[0]
perturbations[idx] = most_common

adata.obs[column] = perturbations
10 changes: 5 additions & 5 deletions pertpy/tools/_perturbation_space/_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class CentroidSpace(PerturbationSpace):
def compute(
self,
adata: AnnData,
target_col: str = "perturbations",
target_col: str = "perturbation",
layer_key: str = None,
embedding_key: str = "X_umap",
keep_obs: bool = True,
Expand Down Expand Up @@ -115,7 +115,7 @@ class PseudobulkSpace(PerturbationSpace):
def compute(
self,
adata: AnnData,
target_col: str = "perturbations",
target_col: str = "perturbation",
layer_key: str = None,
embedding_key: str = None,
**kwargs,
Expand All @@ -133,13 +133,13 @@ def compute(
AnnData object with one observation per perturbation.
Examples:
>>> import pertpy as pp
>>> import pertpy as pt
>>> mdata = pt.dt.papalexi_2021()
>>> ps = pt.tl.PseudobulkSpace()
>>> ps_adata = ps.compute(mdata["rna"], target_col="gene_target", groups_col="gene_target")
"""
if "groups_col" not in kwargs:
kwargs["groups_col"] = "perturbations"
kwargs["groups_col"] = "perturbation"

if layer_key is not None and embedding_key is not None:
raise ValueError("Please, select just either layer or embedding for computation.")
Expand Down Expand Up @@ -244,7 +244,7 @@ def compute( # type: ignore
copy: bool = True,
return_object: bool = False,
**kwargs,
) -> tuple[AnnData, object | AnnData]:
) -> tuple[AnnData, object] | AnnData:
"""Computes a clustering using Density-based spatial clustering of applications (DBSCAN).
Args:
Expand Down
7 changes: 7 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import numpy as np
import pytest


@pytest.fixture
def rng():
return np.random.default_rng()
2 changes: 1 addition & 1 deletion tests/tools/_distances/test_distances.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_distance_axioms(self, adata, distance):
assert all(np.diag(df.values) == 0) # distance to self is 0

# (M2) Positivity
assert len(df) == np.sum(df.values == 0) # distance to other is not 0 (TODO)
assert len(df) == np.sum(df.values == 0) # distance to other is not 0
assert all(df.values.flatten() >= 0) # distance is non-negative

# (M3) Symmetry
Expand Down
Loading

0 comments on commit 4c4d6c6

Please sign in to comment.