Add knn imputation #517

Merged · 1 commit · Jan 30, 2024
2 changes: 1 addition & 1 deletion pertpy/tools/_distances/_distances.py
@@ -796,7 +796,7 @@ def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
label = ["c"] * X.shape[0] + ["p"] * Y_train.shape[0]
train = np.concatenate([X, Y_train])

- reg = LogisticRegression() # TODO dynamically pass this?
+ reg = LogisticRegression()
reg.fit(train, label)
test_labels = reg.predict_proba(Y_test)
return np.mean(test_labels[:, 1])
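For context, the hunk above only drops a stale TODO; the surrounding code is a classifier-probability style distance that trains a logistic regression to separate control from perturbed cells and reports the mean predicted "perturbed" probability on held-out cells. A minimal self-contained sketch of that idea on synthetic data (the arrays, sizes, and the 0.5 mean shift are invented purely for illustration):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
X = rng.normal(0.0, 1.0, size=(200, 10))   # synthetic "control" cells
Y = rng.normal(0.5, 1.0, size=(200, 10))   # synthetic "perturbed" cells
Y_train, Y_test = Y[:100], Y[100:]

label = ["c"] * X.shape[0] + ["p"] * Y_train.shape[0]
train = np.concatenate([X, Y_train])

reg = LogisticRegression()
reg.fit(train, label)

# reg.classes_ is sorted, so column 1 is the "p" (perturbed) class; the mean
# held-out probability is the returned score: ~0.5 when the groups are inseparable.
print(np.mean(reg.predict_proba(Y_test)[:, 1]))
```

Larger mean shifts push the score toward 1, while identical distributions keep it near 0.5.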
61 changes: 55 additions & 6 deletions pertpy/tools/_perturbation_space/_perturbation_space.py
@@ -5,6 +5,7 @@
import numpy as np
import pandas as pd
from anndata import AnnData
+ from pynndescent import NNDescent
from rich import print

if TYPE_CHECKING:
@@ -25,7 +26,7 @@ def __init__(self):
def compute_control_diff( # type: ignore
self,
adata: AnnData,
target_col: str = "perturbations",
target_col: str = "perturbation",
group_col: str = None,
reference_key: str = "control",
layer_key: str = None,
@@ -147,16 +148,16 @@ def add(
perturbations: Iterable[str],
reference_key: str = "control",
ensure_consistency: bool = False,
target_col: str = "perturbations",
) -> AnnData:
target_col: str = "perturbation",
) -> tuple[AnnData, AnnData] | AnnData:
"""Add perturbations linearly. Assumes input of size n_perts x dimensionality

Args:
adata: Anndata object of size n_perts x dim.
perturbations: Perturbations to add.
reference_key: perturbation source from which the perturbation summation starts. Defaults to 'control'.
ensure_consistency: If True, runs differential expression on all data matrices to ensure consistency of linear space.
- target_col: .obs column name that stores the label of the perturbation applied to each cell. Defaults to 'perturbations'.
+ target_col: .obs column name that stores the label of the perturbation applied to each cell. Defaults to 'perturbation'.

Returns:
Anndata object of size (n_perts+1) x dim, where the last row is the addition of the specified perturbations.
@@ -256,8 +257,8 @@ def subtract(
perturbations: Iterable[str],
reference_key: str = "control",
ensure_consistency: bool = False,
target_col: str = "perturbations",
) -> AnnData:
target_col: str = "perturbation",
) -> tuple[AnnData, AnnData] | AnnData:
"""Subtract perturbations linearly. Assumes input of size n_perts x dimensionality

Args:
@@ -358,3 +359,51 @@ def subtract(
return new_perturbation, adata

return new_perturbation

def knn_impute(
    self,
    adata: AnnData,
    column: str = "perturbation",
    target_val: str = "unknown",
    n_neighbors: int = 5,
    use_rep: str = "X_umap",
) -> None:
    """Impute missing values in the specified column using KNN imputation in the space defined by `use_rep`.

    Args:
        adata: The AnnData object containing single-cell data.
        column: The column name in AnnData object to perform imputation on. Defaults to "perturbation".
        target_val: The target value to impute. Defaults to "unknown".
        n_neighbors: Number of neighbors to use for imputation. Defaults to 5.
        use_rep: The key in `adata.obsm` where the embedding (UMAP, PCA, etc.) is stored. Defaults to 'X_umap'.

    Examples:
        >>> import pertpy as pt
        >>> import scanpy as sc
        >>> import numpy as np
        >>> adata = sc.datasets.pbmc68k_reduced()
        >>> rng = np.random.default_rng()
        >>> adata.obs["perturbation"] = rng.choice(["A", "B", "C", "unknown"], size=adata.n_obs, p=[0.33, 0.33, 0.33, 0.01])
        >>> sc.pp.neighbors(adata)
        >>> sc.tl.umap(adata)
        >>> ps = pt.tl.PseudobulkSpace()
        >>> ps.knn_impute(adata, n_neighbors=5, use_rep="X_umap")
    """
    if use_rep not in adata.obsm:
        raise ValueError(f"Representation {use_rep} not found in the AnnData object.")

    embedding = adata.obsm[use_rep]

    nnd = NNDescent(embedding, n_neighbors=n_neighbors)
    indices, _ = nnd.query(embedding, k=n_neighbors)

    perturbations = np.array(adata.obs[column])
    missing_mask = perturbations == target_val

    for idx in np.where(missing_mask)[0]:
        neighbor_indices = indices[idx]
        neighbor_categories = perturbations[neighbor_indices]
        most_common = pd.Series(neighbor_categories).mode()[0]
        perturbations[idx] = most_common

    adata.obs[column] = perturbations
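The new method is the core of this PR and reduces to a k-nearest-neighbour majority vote over labels in the chosen embedding. A minimal standalone sketch of that idea with pynndescent on toy data (the two Gaussian clusters, their labels, and the "unknown" positions are made up; this is independent of the pertpy API):

```python
import numpy as np
import pandas as pd
from pynndescent import NNDescent

rng = np.random.default_rng(0)

# Two well-separated 2D clusters with known labels, plus two cells marked "unknown".
embedding = np.vstack([rng.normal(0.0, 0.1, (50, 2)), rng.normal(5.0, 0.1, (50, 2))])
labels = np.array(["A"] * 50 + ["B"] * 50, dtype=object)
labels[[3, 70]] = "unknown"

# Index the embedding and fetch each point's 5 approximate nearest neighbors.
nnd = NNDescent(embedding, n_neighbors=5)
indices, _ = nnd.query(embedding, k=5)

# Majority vote: replace each "unknown" label with the most common neighbor label.
for idx in np.where(labels == "unknown")[0]:
    labels[idx] = pd.Series(labels[indices[idx]]).mode()[0]

print(labels[3], labels[70])  # expected: A B
```

knn_impute applies the same vote per cell on adata.obsm[use_rep] and writes the result back to adata.obs[column].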
8 changes: 4 additions & 4 deletions pertpy/tools/_perturbation_space/_simple.py
@@ -15,7 +15,7 @@ class CentroidSpace(PerturbationSpace):
def compute(
self,
adata: AnnData,
target_col: str = "perturbations",
target_col: str = "perturbation",
layer_key: str = None,
embedding_key: str = "X_umap",
keep_obs: bool = True,
@@ -115,7 +115,7 @@ class PseudobulkSpace(PerturbationSpace):
def compute(
self,
adata: AnnData,
target_col: str = "perturbations",
target_col: str = "perturbation",
layer_key: str = None,
embedding_key: str = None,
**kwargs,
@@ -133,13 +133,13 @@ def compute(
AnnData object with one observation per perturbation.

Examples:
>>> import pertpy as pp
>>> import pertpy as pt
>>> mdata = pt.dt.papalexi_2021()
>>> ps = pt.tl.PseudobulkSpace()
>>> ps_adata = ps.compute(mdata["rna"], target_col="gene_target", groups_col="gene_target")
"""
if "groups_col" not in kwargs:
kwargs["groups_col"] = "perturbations"
kwargs["groups_col"] = "perturbation"

if layer_key is not None and embedding_key is not None:
raise ValueError("Please, select just either layer or embedding for computation.")
2 changes: 1 addition & 1 deletion tests/tools/_distances/test_distances.py
@@ -64,7 +64,7 @@ def test_distance_axioms(self, adata, distance):
assert all(np.diag(df.values) == 0) # distance to self is 0

# (M2) Positivity
- assert len(df) == np.sum(df.values == 0) # distance to other is not 0 (TODO)
+ assert len(df) == np.sum(df.values == 0) # distance to other is not 0
assert all(df.values.flatten() >= 0) # distance is non-negative

# (M3) Symmetry