From 98e2bdbbde143b1e6c855ed2ea494d7367f22a17 Mon Sep 17 00:00:00 2001
From: Stefan Peidli
Date: Mon, 30 Sep 2024 18:35:08 +0200
Subject: [PATCH] Improve KNN label_transfer in PerturbationSpace (#658)

* Add uncertainty score in KNN label_transfer in PerturbationSpace

Certainty is quantified as the fraction of nearest neighbors that belong to
the assigned (i.e. the most abundant) label, relative to the total number of
nearest neighbors.

* Update pre-commit-config.yaml

Replaces the yanked mypy dependency "types-pkg-resources" with
"types-setuptools", as recommended:
https://pypi.org/project/types-pkg-resources/

* Improve label imputation in PerturbationSpace class

Key changes:
- Now uses the KNN graph already stored in adata: saves compute and increases consistency
- Vectorized operations instead of an expensive for loop
- Distance weighting for KNN imputation
- Quantifies uncertainty as local KNN label entropy
---
 .pre-commit-config.yaml               |  2 +-
 .../_perturbation_space.py            | 69 +++++++++++--------
 .../test_simple_perturbation_space.py | 12 +++-
 3 files changed, 50 insertions(+), 33 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 3dc84253..65184420 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -32,7 +32,7 @@ repos:
       - id: mypy
         args: [--no-strict-optional, --ignore-missing-imports]
         additional_dependencies:
-          ["types-pkg-resources", "types-requests", "types-attrs"]
+          ["types-setuptools", "types-requests", "types-attrs"]
   - repo: local
     hooks:
       - id: forbid-to-commit
diff --git a/pertpy/tools/_perturbation_space/_perturbation_space.py b/pertpy/tools/_perturbation_space/_perturbation_space.py
index 9a5a475c..04f58d6b 100644
--- a/pertpy/tools/_perturbation_space/_perturbation_space.py
+++ b/pertpy/tools/_perturbation_space/_perturbation_space.py
@@ -7,6 +7,7 @@
 from anndata import AnnData
 from lamin_utils import logger
 from rich import print
+from scipy.stats import entropy
 
 if TYPE_CHECKING:
     from collections.abc import Iterable
@@ -364,50 +365,58 @@ def label_transfer(
         self,
         adata: AnnData,
         column: str = "perturbation",
+        column_uncertainty_score_key: str = "perturbation_transfer_uncertainty",
         target_val: str = "unknown",
-        n_neighbors: int = 5,
-        use_rep: str = "X_umap",
+        neighbors_key: str = "neighbors",
     ) -> None:
-        """Impute missing values in the specified column using KNN imputation in the space defined by `use_rep`.
+        """Impute missing values in the specified column using KNN imputation on the precomputed neighbors graph.
 
+        Uncertainty is calculated as the entropy of the label distribution in the neighborhood of the target cell.
+        In other words, a cell where all neighbors share the same label will have an uncertainty of 0, whereas a cell
+        whose neighbors carry many different labels will have high uncertainty.
+
         Args:
             adata: The AnnData object containing single-cell data.
-            column: The column name in AnnData object to perform imputation on.
+            column: The column name in adata.obs to perform imputation on.
+            column_uncertainty_score_key: The column name in adata.obs to store the uncertainty score of the label transfer.
             target_val: The target value to impute.
-            n_neighbors: Number of neighbors to use for imputation.
-            use_rep: The key in `adata.obsm` where the embedding (UMAP, PCA, etc.) is stored.
+            neighbors_key: The key in adata.uns where the neighbors are stored.
 
         Examples:
             >>> import pertpy as pt
             >>> import scanpy as sc
             >>> import numpy as np
            >>> adata = sc.datasets.pbmc68k_reduced()
-            >>> rng = np.random.default_rng()
-            >>> adata.obs["perturbation"] = rng.choice(
-            ...     ["A", "B", "C", "unknown"], size=adata.n_obs, p=[0.33, 0.33, 0.33, 0.01]
-            ... )
+            >>> # randomly drop out 10% of the annotations
+            >>> adata.obs["perturbation"] = adata.obs["louvain"].astype(str).copy()
+            >>> random_cells = np.random.choice(adata.obs.index, int(adata.obs.shape[0] * 0.1), replace=False)
+            >>> adata.obs.loc[random_cells, "perturbation"] = "unknown"
             >>> sc.pp.neighbors(adata)
             >>> sc.tl.umap(adata)
             >>> ps = pt.tl.PseudobulkSpace()
-            >>> ps.label_transfer(adata, n_neighbors=5, use_rep="X_umap")
+            >>> ps.label_transfer(adata)
        """
-        if use_rep not in adata.obsm:
-            raise ValueError(f"Representation {use_rep} not found in the AnnData object.")
-
-        embedding = adata.obsm[use_rep]
-
-        from pynndescent import NNDescent
-
-        nnd = NNDescent(embedding, n_neighbors=n_neighbors)
-        indices, _ = nnd.query(embedding, k=n_neighbors)
-
-        perturbations = np.array(adata.obs[column])
-        missing_mask = perturbations == target_val
-
-        for idx in np.where(missing_mask)[0]:
-            neighbor_indices = indices[idx]
-            neighbor_categories = perturbations[neighbor_indices]
-            most_common = pd.Series(neighbor_categories).mode()[0]
-            perturbations[idx] = most_common
-
-        adata.obs[column] = perturbations
+        if neighbors_key not in adata.uns:
+            raise ValueError(f"Key {neighbors_key} not found in adata.uns. Please run `sc.pp.neighbors` first.")
+
+        labels = adata.obs[column].astype(str)
+        target_cells = labels == target_val
+
+        connectivities = adata.obsp[adata.uns[neighbors_key]["connectivities_key"]]
+        # convert labels to an incidence matrix
+        one_hot_encoded_labels = adata.obs[column].astype(str).str.get_dummies()
+        # convert to distance-weighted neighborhood incidence matrix
+        weighted_label_occurence = pd.DataFrame(
+            (one_hot_encoded_labels.values.T * connectivities).T,
+            index=adata.obs_names,
+            columns=one_hot_encoded_labels.columns,
+        )
+        # choose best label for each target cell
+        best_labels = weighted_label_occurence.drop(target_val, axis=1)[target_cells].idxmax(axis=1)
+        adata.obs[column] = labels
+        adata.obs.loc[target_cells, column] = best_labels
+
+        # calculate uncertainty
+        uncertainty = np.zeros(adata.n_obs)
+        uncertainty[target_cells] = entropy(weighted_label_occurence.drop(target_val, axis=1)[target_cells], axis=1)
+        adata.obs[column_uncertainty_score_key] = uncertainty
diff --git a/tests/tools/_perturbation_space/test_simple_perturbation_space.py b/tests/tools/_perturbation_space/test_simple_perturbation_space.py
index 8fd04aeb..bc2559f8 100644
--- a/tests/tools/_perturbation_space/test_simple_perturbation_space.py
+++ b/tests/tools/_perturbation_space/test_simple_perturbation_space.py
@@ -240,9 +240,17 @@ def test_label_transfer():
     adata = AnnData(X)
     perturbations = np.array(["A", "B", "C"] * 22 + ["unknown"] * 3)
     adata.obs["perturbation"] = perturbations
+
+    with pytest.raises(ValueError):
+        ps = pt.tl.PseudobulkSpace()
+        ps.label_transfer(adata)
+
     sc.pp.neighbors(adata, use_rep="X")
     sc.tl.umap(adata)
-    ps = pt.tl.PseudobulkSpace()
-    ps.label_transfer(adata, n_neighbors=5, use_rep="X_umap")
+    ps.label_transfer(adata)
     assert "unknown" not in adata.obs["perturbation"]
+    assert all(adata.obs["perturbation_transfer_uncertainty"] >= 0)
+    assert not all(adata.obs["perturbation_transfer_uncertainty"] == 0)
+    is_known = perturbations != "unknown"
+    assert all(adata.obs.loc[is_known, "perturbation_transfer_uncertainty"] == 0)
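
The heart of the new implementation is a distance-weighted label vote followed by an entropy
score. Below is a minimal sketch of that idea on a hand-built toy graph; the cell names, labels,
and connectivity weights are invented for illustration, and only the sequence of operations
(one-hot encoding, weighting by connectivities, argmax over known labels, entropy of the vote
distribution) mirrors the patch.

    import numpy as np
    import pandas as pd
    from scipy.sparse import csr_matrix
    from scipy.stats import entropy

    # Toy labels: c3 is the cell whose label should be imputed.
    labels = pd.Series(["A", "A", "B", "unknown"], index=["c0", "c1", "c2", "c3"])

    # Toy symmetric connectivity matrix (the role played by adata.obsp[...] in the patch):
    # c3 is strongly connected to the two "A" cells and only weakly to the "B" cell.
    connectivities = csr_matrix(
        np.array(
            [
                [0.0, 0.8, 0.1, 0.9],
                [0.8, 0.0, 0.1, 0.7],
                [0.1, 0.1, 0.0, 0.2],
                [0.9, 0.7, 0.2, 0.0],
            ]
        )
    )

    # One-hot label matrix (cells x labels), then weight each neighbor's vote by its connectivity.
    one_hot = labels.str.get_dummies()
    weighted_votes = pd.DataFrame(
        (one_hot.values.T * connectivities).T,
        index=labels.index,
        columns=one_hot.columns,
    )

    # Impute the "unknown" cell with the label that received the largest weighted vote, and
    # score uncertainty as the entropy of its vote distribution over the known labels.
    target = labels == "unknown"
    votes_known = weighted_votes.drop("unknown", axis=1)
    print(votes_known[target].idxmax(axis=1))    # c3 -> "A" (vote 1.6 vs 0.2 for "B")
    print(entropy(votes_known[target], axis=1))  # small entropy: the neighbors largely agree

On a real AnnData object the connectivities come from adata.obsp via the key recorded in
adata.uns under neighbors_key, which is exactly what the patched label_transfer reads.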
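
An end-to-end usage sketch against the new signature, extending the doctest in the patch by also
reading back the uncertainty column; the pbmc68k_reduced dataset, the "louvain" starting
annotation, and the 10% dropout are only for demonstration:

    import numpy as np
    import scanpy as sc
    import pertpy as pt

    adata = sc.datasets.pbmc68k_reduced()

    # Start from an existing annotation and hide 10% of it as "unknown".
    adata.obs["perturbation"] = adata.obs["louvain"].astype(str)
    random_cells = np.random.choice(adata.obs.index, int(adata.n_obs * 0.1), replace=False)
    adata.obs.loc[random_cells, "perturbation"] = "unknown"

    # label_transfer now reads the KNN graph from adata, so the neighbors must be computed first.
    sc.pp.neighbors(adata)

    ps = pt.tl.PseudobulkSpace()
    ps.label_transfer(adata, column="perturbation", target_val="unknown")

    # Imputed labels plus the per-cell entropy score (0 for cells that were never imputed).
    print(adata.obs.loc[random_cells, ["perturbation", "perturbation_transfer_uncertainty"]].head())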