Skip to content

Commit

Permalink
Add knn imputation
Browse files Browse the repository at this point in the history
Signed-off-by: zethson <[email protected]>
  • Loading branch information
Zethson committed Jan 30, 2024
1 parent 02032df commit be64f1f
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 81 deletions.
2 changes: 1 addition & 1 deletion pertpy/tools/_distances/_distances.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,7 +796,7 @@ def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
label = ["c"] * X.shape[0] + ["p"] * Y_train.shape[0]
train = np.concatenate([X, Y_train])

reg = LogisticRegression() # TODO dynamically pass this?
reg = LogisticRegression()
reg.fit(train, label)
test_labels = reg.predict_proba(Y_test)
return np.mean(test_labels[:, 1])
Expand Down
61 changes: 55 additions & 6 deletions pertpy/tools/_perturbation_space/_perturbation_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
import pandas as pd
from anndata import AnnData
from pynndescent import NNDescent
from rich import print

if TYPE_CHECKING:
Expand All @@ -25,7 +26,7 @@ def __init__(self):
def compute_control_diff( # type: ignore
self,
adata: AnnData,
target_col: str = "perturbations",
target_col: str = "perturbation",
group_col: str = None,
reference_key: str = "control",
layer_key: str = None,
Expand Down Expand Up @@ -147,16 +148,16 @@ def add(
perturbations: Iterable[str],
reference_key: str = "control",
ensure_consistency: bool = False,
target_col: str = "perturbations",
) -> AnnData:
target_col: str = "perturbation",
) -> tuple[AnnData, AnnData] | AnnData:
"""Add perturbations linearly. Assumes input of size n_perts x dimensionality
Args:
adata: Anndata object of size n_perts x dim.
perturbations: Perturbations to add.
reference_key: perturbation source from which the perturbation summation starts. Defaults to 'control'.
ensure_consistency: If True, runs differential expression on all data matrices to ensure consistency of linear space.
target_col: .obs column name that stores the label of the perturbation applied to each cell. Defaults to 'perturbations'.
target_col: .obs column name that stores the label of the perturbation applied to each cell. Defaults to 'perturbation'.
Returns:
Anndata object of size (n_perts+1) x dim, where the last row is the addition of the specified perturbations.
Expand Down Expand Up @@ -256,8 +257,8 @@ def subtract(
perturbations: Iterable[str],
reference_key: str = "control",
ensure_consistency: bool = False,
target_col: str = "perturbations",
) -> AnnData:
target_col: str = "perturbation",
) -> tuple[AnnData, AnnData] | AnnData:
"""Subtract perturbations linearly. Assumes input of size n_perts x dimensionality
Args:
Expand Down Expand Up @@ -358,3 +359,51 @@ def subtract(
return new_perturbation, adata

return new_perturbation

def knn_impute(
    self,
    adata: AnnData,
    column: str = "perturbation",
    target_val: str = "unknown",
    n_neighbors: int = 5,
    use_rep: str = "X_umap",
) -> None:
    """Impute missing values in the specified column using KNN imputation in the space defined by `use_rep`.

    Every cell whose `column` entry equals `target_val` receives the most common label among its
    nearest neighbors in `adata.obsm[use_rep]`. Only neighbors that carry a real label vote:
    neighbors that are themselves `target_val` (which includes the cell itself, since the index is
    queried with its own points) are ignored. Votes are taken on the labels as they were before
    imputation, so the result does not depend on the order of the cells. Cells without any labeled
    neighbor keep `target_val`. The column is modified in place.

    Args:
        adata: The AnnData object containing single-cell data.
        column: The column name in AnnData object to perform imputation on. Defaults to "perturbation".
        target_val: The target value to impute. Defaults to "unknown".
        n_neighbors: Number of neighbors to use for imputation. Defaults to 5.
        use_rep: The key in `adata.obsm` where the embedding (UMAP, PCA, etc.) is stored. Defaults to 'X_umap'.

    Raises:
        ValueError: If `use_rep` is not a key of `adata.obsm`.

    Examples:
        >>> import pertpy as pt
        >>> import scanpy as sc
        >>> import numpy as np
        >>> adata = sc.datasets.pbmc68k_reduced()
        >>> rng = np.random.default_rng()
        >>> adata.obs["perturbation"] = rng.choice(["A", "B", "C", "unknown"], size=adata.n_obs, p=[0.33, 0.33, 0.33, 0.01])
        >>> sc.pp.neighbors(adata)
        >>> sc.tl.umap(adata)
        >>> ps = pt.tl.PseudobulkSpace()
        >>> ps.knn_impute(adata, n_neighbors=5, use_rep="X_umap")
    """
    if use_rep not in adata.obsm:
        raise ValueError(f"Representation {use_rep} not found in the AnnData object.")

    embedding = adata.obsm[use_rep]

    nnd = NNDescent(embedding, n_neighbors=n_neighbors)
    indices, _ = nnd.query(embedding, k=n_neighbors)

    # `labels` is a read-only snapshot used for voting; results go into `imputed`
    # so that already-imputed cells never influence later ones.
    labels = np.array(adata.obs[column])
    imputed = labels.copy()
    missing_mask = labels == target_val

    for idx in np.flatnonzero(missing_mask):
        neighbor_indices = indices[idx]
        # NOTE(review): the approximate search may pad unfilled slots with -1;
        # drop them so they do not silently index the last row - confirm against pynndescent.
        neighbor_indices = neighbor_indices[neighbor_indices >= 0]

        votes = labels[neighbor_indices]
        # Unlabeled neighbors (including the cell itself) carry no information.
        votes = votes[votes != target_val]
        if votes.size == 0:
            continue

        modes = pd.Series(votes).mode()
        if modes.empty:
            # All remaining votes were NaN; nothing sensible to assign.
            continue
        imputed[idx] = modes[0]

    adata.obs[column] = imputed
8 changes: 4 additions & 4 deletions pertpy/tools/_perturbation_space/_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class CentroidSpace(PerturbationSpace):
def compute(
self,
adata: AnnData,
target_col: str = "perturbations",
target_col: str = "perturbation",
layer_key: str = None,
embedding_key: str = "X_umap",
keep_obs: bool = True,
Expand Down Expand Up @@ -115,7 +115,7 @@ class PseudobulkSpace(PerturbationSpace):
def compute(
self,
adata: AnnData,
target_col: str = "perturbations",
target_col: str = "perturbation",
layer_key: str = None,
embedding_key: str = None,
**kwargs,
Expand All @@ -133,13 +133,13 @@ def compute(
AnnData object with one observation per perturbation.
Examples:
>>> import pertpy as pp
>>> import pertpy as pt
>>> mdata = pt.dt.papalexi_2021()
>>> ps = pt.tl.PseudobulkSpace()
>>> ps_adata = ps.compute(mdata["rna"], target_col="gene_target", groups_col="gene_target")
"""
if "groups_col" not in kwargs:
kwargs["groups_col"] = "perturbations"
kwargs["groups_col"] = "perturbation"

if layer_key is not None and embedding_key is not None:
raise ValueError("Please, select just either layer or embedding for computation.")
Expand Down
2 changes: 1 addition & 1 deletion tests/tools/_distances/test_distances.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_distance_axioms(self, adata, distance):
assert all(np.diag(df.values) == 0) # distance to self is 0

# (M2) Positivity
assert len(df) == np.sum(df.values == 0) # distance to other is not 0 (TODO)
assert len(df) == np.sum(df.values == 0) # distance to other is not 0
assert all(df.values.flatten() >= 0) # distance is non-negative

# (M3) Symmetry
Expand Down
Loading

0 comments on commit be64f1f

Please sign in to comment.