From be64f1f53ebd676d8f718e605f4d84cbdc17e958 Mon Sep 17 00:00:00 2001 From: zethson Date: Tue, 30 Jan 2024 22:44:59 +0100 Subject: [PATCH] Add knn imputation Signed-off-by: zethson --- pertpy/tools/_distances/_distances.py | 2 +- .../_perturbation_space.py | 61 ++++++++- pertpy/tools/_perturbation_space/_simple.py | 8 +- tests/tools/_distances/test_distances.py | 2 +- .../test_simple_perturbation_space.py | 129 ++++++++---------- 5 files changed, 121 insertions(+), 81 deletions(-) diff --git a/pertpy/tools/_distances/_distances.py b/pertpy/tools/_distances/_distances.py index 9a2c3620..a33a4b14 100644 --- a/pertpy/tools/_distances/_distances.py +++ b/pertpy/tools/_distances/_distances.py @@ -796,7 +796,7 @@ def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: label = ["c"] * X.shape[0] + ["p"] * Y_train.shape[0] train = np.concatenate([X, Y_train]) - reg = LogisticRegression() # TODO dynamically pass this? + reg = LogisticRegression() reg.fit(train, label) test_labels = reg.predict_proba(Y_test) return np.mean(test_labels[:, 1]) diff --git a/pertpy/tools/_perturbation_space/_perturbation_space.py b/pertpy/tools/_perturbation_space/_perturbation_space.py index bb533c22..5da60bfa 100644 --- a/pertpy/tools/_perturbation_space/_perturbation_space.py +++ b/pertpy/tools/_perturbation_space/_perturbation_space.py @@ -5,6 +5,7 @@ import numpy as np import pandas as pd from anndata import AnnData +from pynndescent import NNDescent from rich import print if TYPE_CHECKING: @@ -25,7 +26,7 @@ def __init__(self): def compute_control_diff( # type: ignore self, adata: AnnData, - target_col: str = "perturbations", + target_col: str = "perturbation", group_col: str = None, reference_key: str = "control", layer_key: str = None, @@ -147,8 +148,8 @@ def add( perturbations: Iterable[str], reference_key: str = "control", ensure_consistency: bool = False, - target_col: str = "perturbations", - ) -> AnnData: + target_col: str = "perturbation", + ) -> tuple[AnnData, AnnData] | AnnData: """Add perturbations linearly. Assumes input of size n_perts x dimensionality Args: @@ -156,7 +157,7 @@ def add( perturbations: Perturbations to add. reference_key: perturbation source from which the perturbation summation starts. Defaults to 'control'. ensure_consistency: If True, runs differential expression on all data matrices to ensure consistency of linear space. - target_col: .obs column name that stores the label of the perturbation applied to each cell. Defaults to 'perturbations'. + target_col: .obs column name that stores the label of the perturbation applied to each cell. Defaults to 'perturbation'. Returns: Anndata object of size (n_perts+1) x dim, where the last row is the addition of the specified perturbations. @@ -256,8 +257,8 @@ def subtract( perturbations: Iterable[str], reference_key: str = "control", ensure_consistency: bool = False, - target_col: str = "perturbations", - ) -> AnnData: + target_col: str = "perturbation", + ) -> tuple[AnnData, AnnData] | AnnData: """Subtract perturbations linearly. Assumes input of size n_perts x dimensionality Args: @@ -358,3 +359,51 @@ def subtract( return new_perturbation, adata return new_perturbation + + def knn_impute( + self, + adata: AnnData, + column: str = "perturbation", + target_val: str = "unknown", + n_neighbors: int = 5, + use_rep: str = "X_umap", + ) -> None: + """Impute missing values in the specified column using KNN imputation in the space defined by `use_rep`. + + Args: + adata: The AnnData object containing single-cell data. + column: The column name in AnnData object to perform imputation on. Defaults to "perturbation". + target_val: The target value to impute. Defaults to "unknown". + n_neighbors: Number of neighbors to use for imputation. Defaults to 5. + use_rep: The key in `adata.obsm` where the embedding (UMAP, PCA, etc.) is stored. Defaults to 'X_umap'. + + Examples: + >>> import pertpy as pt + >>> import scanpy as sc + >>> import numpy as np + >>> adata = sc.datasets.pbmc68k_reduced() + >>> rng = np.random.default_rng() + >>> adata.obs["perturbation"] = rng.choice(["A", "B", "C", "unknown"], size=adata.n_obs, p=[0.33, 0.33, 0.33, 0.01]) + >>> sc.pp.neighbors(adata) + >>> sc.tl.umap(adata) + >>> ps = pt.tl.PseudobulkSpace() + >>> ps.knn_impute(adata, n_neighbors=5, use_rep="X_umap") + """ + if use_rep not in adata.obsm: + raise ValueError(f"Representation {use_rep} not found in the AnnData object.") + + embedding = adata.obsm[use_rep] + + nnd = NNDescent(embedding, n_neighbors=n_neighbors) + indices, _ = nnd.query(embedding, k=n_neighbors) + + perturbations = np.array(adata.obs[column]) + missing_mask = perturbations == target_val + + for idx in np.where(missing_mask)[0]: + neighbor_indices = indices[idx] + neighbor_categories = perturbations[neighbor_indices] + most_common = pd.Series(neighbor_categories).mode()[0] + perturbations[idx] = most_common + + adata.obs[column] = perturbations diff --git a/pertpy/tools/_perturbation_space/_simple.py b/pertpy/tools/_perturbation_space/_simple.py index 49a2dec3..dc245bb0 100644 --- a/pertpy/tools/_perturbation_space/_simple.py +++ b/pertpy/tools/_perturbation_space/_simple.py @@ -15,7 +15,7 @@ class CentroidSpace(PerturbationSpace): def compute( self, adata: AnnData, - target_col: str = "perturbations", + target_col: str = "perturbation", layer_key: str = None, embedding_key: str = "X_umap", keep_obs: bool = True, @@ -115,7 +115,7 @@ class PseudobulkSpace(PerturbationSpace): def compute( self, adata: AnnData, - target_col: str = "perturbations", + target_col: str = "perturbation", layer_key: str = None, embedding_key: str = None, **kwargs, @@ -133,13 +133,13 @@ def compute( AnnData object with one observation per perturbation. Examples: - >>> import pertpy as pp + >>> import pertpy as pt >>> mdata = pt.dt.papalexi_2021() >>> ps = pt.tl.PseudobulkSpace() >>> ps_adata = ps.compute(mdata["rna"], target_col="gene_target", groups_col="gene_target") """ if "groups_col" not in kwargs: - kwargs["groups_col"] = "perturbations" + kwargs["groups_col"] = "perturbation" if layer_key is not None and embedding_key is not None: raise ValueError("Please, select just either layer or embedding for computation.") diff --git a/tests/tools/_distances/test_distances.py b/tests/tools/_distances/test_distances.py index 73f367cc..a9ac44e2 100644 --- a/tests/tools/_distances/test_distances.py +++ b/tests/tools/_distances/test_distances.py @@ -64,7 +64,7 @@ def test_distance_axioms(self, adata, distance): assert all(np.diag(df.values) == 0) # distance to self is 0 # (M2) Positivity - assert len(df) == np.sum(df.values == 0) # distance to other is not 0 (TODO) + assert len(df) == np.sum(df.values == 0) # distance to other is not 0 assert all(df.values.flatten() >= 0) # distance is non-negative # (M3) Symmetry diff --git a/tests/tools/_perturbation_space/test_simple_perturbation_space.py b/tests/tools/_perturbation_space/test_simple_perturbation_space.py index 142c8b83..eef0dd9b 100644 --- a/tests/tools/_perturbation_space/test_simple_perturbation_space.py +++ b/tests/tools/_perturbation_space/test_simple_perturbation_space.py @@ -2,15 +2,34 @@ import pandas as pd import pertpy as pt import pytest +import scanpy as sc from anndata import AnnData -def test_differential_response(): - rng = np.random.default_rng() +@pytest.fixture +def rng(): + return np.random.default_rng() + + +@pytest.fixture +def adata(rng): + X = rng.random((69, 50)) + adata = AnnData(X) + perturbations = np.array(["control", "target1", "target2"] * 22 + ["unknown"] * 3) + adata.obs["perturbation"] = perturbations + sc.pp.pca(adata) + sc.pp.neighbors(adata, use_rep="X") + sc.tl.umap(adata) + + return adata + + +@pytest.fixture +def adata_simple(rng): X = rng.random(size=(10, 5)) obs = pd.DataFrame( { - "perturbations": [ + "perturbation": [ "control", "target1", "target1", @@ -26,19 +45,20 @@ def test_differential_response(): ) adata = AnnData(X, obs=obs) - # Compute the differential response + return adata + + +def test_differential_response(adata_simple): ps = pt.tl.PseudobulkSpace() - ps_adata = ps.compute_control_diff(adata, copy=True) + ps_adata = ps.compute_control_diff(adata_simple, target_col="perturbation", copy=True) - # Test that the differential response was computed correctly - expected_diff_matrix = adata.X - adata.X[0, :] + expected_diff_matrix = adata_simple.X - adata_simple.X[0, :] np.testing.assert_allclose(ps_adata.X, expected_diff_matrix, rtol=1e-4) - # Check that the function raises an error if the reference key is not found with pytest.raises(ValueError): ps.compute_control_diff( - adata, - target_col="perturbations", + adata_simple, + target_col="perturbation", reference_key="not_found", layer_key="counts", new_layer_key="counts_diff", @@ -48,59 +68,32 @@ def test_differential_response(): ) -def test_pseudobulk_response(): - rng = np.random.default_rng() - X = rng.random(size=(10, 5)) - obs = pd.DataFrame( - { - "perturbations": [ - "control", - "target1", - "target1", - "target2", - "target2", - "target1", - "target1", - "target2", - "target2", - "target2", - ] - } - ) - adata = AnnData(X, obs=obs) - - # Compute the pseudobulk +def test_pseudobulk_response(adata_simple): ps = pt.tl.PseudobulkSpace() - psadata = ps.compute(adata, mode="mean", min_cells=0, min_counts=0) + psadata = ps.compute(adata_simple, mode="mean", min_cells=0, min_counts=0) - # Test that the pseudobulk response was computed correctly - adata_target1 = adata[adata.obs.perturbations == "target1"].X.mean(0) + adata_target1 = adata_simple[adata_simple.obs.perturbation == "target1"].X.mean(0) np.testing.assert_allclose(adata_target1, psadata["target1"].X[0], rtol=1e-4) - # Test in UMAP space - adata.obsm["X_umap"] = X + adata_simple.obsm["X_umap"] = adata_simple.X - # Compute the pseudobulk ps = pt.tl.PseudobulkSpace() - psadata = ps.compute(adata, embedding_key="X_umap", mode="mean", min_cells=0, min_counts=0) + psadata = ps.compute(adata_simple, embedding_key="X_umap", mode="mean", min_cells=0, min_counts=0) - # Test that the pseudobulk response was computed correctly - adata_target1 = adata[adata.obs.perturbations == "target1"].obsm["X_umap"].mean(0) + adata_target1 = adata_simple[adata_simple.obs.perturbation == "target1"].obsm["X_umap"].mean(0) np.testing.assert_allclose(adata_target1, psadata["target1"].X[0], rtol=1e-4) - # Check that the function raises an error if the layer key is not found with pytest.raises(ValueError): ps.compute( - adata, - target_col="perturbations", + adata_simple, + target_col="perturbation", layer_key="not_found", ) - # Check that the function raises an error if the layer key and embedding key are used at the same time with pytest.raises(ValueError): ps.compute( - adata, - target_col="perturbations", + adata_simple, + target_col="perturbation", embedding_key="not_found", layer_key="not_found", ) @@ -130,39 +123,34 @@ def test_centroid_umap_response(): elif value == "target2": X[i, :] = 30 - obs = pd.DataFrame({"perturbations": pert_index}) + obs = pd.DataFrame({"perturbation": pert_index}) adata = AnnData(X, obs=obs) adata.obsm["X_umap"] = X - # Compute the centroids ps = pt.tl.CentroidSpace() psadata = ps.compute(adata, embedding_key="X_umap") - # Test that the centroids response was computed correctly - adata_target1 = adata[adata.obs.perturbations == "target1"].obsm["X_umap"].mean(0) + adata_target1 = adata[adata.obs.perturbation == "target1"].obsm["X_umap"].mean(0) np.testing.assert_allclose(adata_target1, psadata["target1"].X[0], rtol=1e-4) ps = pt.tl.CentroidSpace() psadata = ps.compute(adata) # if nothing specific, compute with X, and X and X_umap are the same - # Test that the centroids response was computed correctly - adata_target1 = adata[adata.obs.perturbations == "target1"].obsm["X_umap"].mean(0) + adata_target1 = adata[adata.obs.perturbation == "target1"].obsm["X_umap"].mean(0) np.testing.assert_allclose(adata_target1, psadata["target1"].X[0], rtol=1e-4) - # Check that the function raises an error if the embedding key is not found with pytest.raises(ValueError): ps.compute( adata, - target_col="perturbations", + target_col="perturbation", embedding_key="not_found", ) - # Check that the function raises an error if the layer key and embedding key are used at the same time with pytest.raises(ValueError): ps.compute( adata, - target_col="perturbations", + target_col="perturbation", embedding_key="not_found", layer_key="not_found", ) @@ -174,7 +162,7 @@ def test_linear_operations(): X = rng.random(size=(10, 5)) obs = pd.DataFrame( { - "perturbations": [ + "perturbation": [ "control", "target1", "target1", @@ -191,21 +179,17 @@ def test_linear_operations(): adata = AnnData(X, obs=obs) adata.obsm["X_umap"] = X - # Compute pseudobulk ps = pt.tl.PseudobulkSpace() psadata = ps.compute(adata, mode="mean", min_cells=0, min_counts=0) psadata_umap = ps.compute(adata, mode="mean", min_cells=0, min_counts=0, embedding_key="X_umap") psadata.obsm["X_umap"] = psadata_umap.X - # Perform summation ps_adata, data_compare = ps.add(psadata, perturbations=["target1", "target2"], ensure_consistency=True) - # Test in X test = data_compare["control"].X + data_compare["target1"].X + data_compare["target2"].X np.testing.assert_allclose(test, ps_adata["target1+target2"].X, rtol=1e-4) - # Test in UMAP embedding test = ( data_compare["control"].obsm["X_umap_control_diff"] + data_compare["target1"].obsm["X_umap_control_diff"] @@ -213,26 +197,20 @@ def test_linear_operations(): ) np.testing.assert_allclose(test, ps_adata["target1+target2"].obsm["X_umap"], rtol=1e-4) - # Perform subtraction ps_adata, data_compare = ps.subtract( psadata, reference_key="target1", perturbations=["target2"], ensure_consistency=True ) - # Test in X test = data_compare["target1"].X - data_compare["target2"].X np.testing.assert_allclose(test, ps_adata["target1-target2"].X, rtol=1e-4) - # Operations after control diff, do the results match? ps_adata = ps.compute_control_diff(psadata, copy=True) - # Do summation ps_adata2 = ps.add(ps_adata, perturbations=["target1", "target2"]) - # Test in X test = ps_adata["control"].X + ps_adata["target1"].X + ps_adata["target2"].X np.testing.assert_allclose(test, ps_adata2["target1+target2"].X, rtol=1e-4) - # Do subtract ps_adata2 = ps.subtract(ps_adata, reference_key="target1", perturbations=["target1"]) ps_vector = ps_adata2["target1-target1"].X np.testing.assert_allclose(ps_adata2["control"].X, ps_adata2["target1-target1"].X, rtol=1e-4) @@ -245,7 +223,6 @@ def test_linear_operations(): # Compare process data vs pseudobulk before, should be the same np.testing.assert_allclose(ps_inner_vector, ps_vector, rtol=1e-4) - # Check result in UMAP np.testing.assert_allclose( data_compare["control"].obsm["X_umap_control_diff"], ps_adata2["target1-target1"].obsm["X_umap"], rtol=1e-4 ) @@ -263,3 +240,17 @@ def test_linear_operations(): ps_adata, perturbations=["target1", "target3"], ) + + +def test_knn_impute(): + rng = np.random.default_rng() + X = rng.standard_normal((69, 50)) + adata = AnnData(X) + perturbations = np.array(["A", "B", "C"] * 22 + ["unknown"] * 3) + adata.obs["perturbation"] = perturbations + sc.pp.neighbors(adata, use_rep="X") + sc.tl.umap(adata) + + ps = pt.tl.PseudobulkSpace() + ps.knn_impute(adata, n_neighbors=5, use_rep="X_umap") + assert "unknown" not in adata.obs["perturbation"]