From c51217706bf1d022a6893942ab9a41377a733653 Mon Sep 17 00:00:00 2001 From: zethson Date: Mon, 29 Jan 2024 15:43:36 +0100 Subject: [PATCH 1/5] submodule Signed-off-by: zethson --- docs/tutorials/notebooks | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorials/notebooks b/docs/tutorials/notebooks index d9fdf0b8..3e8186f7 160000 --- a/docs/tutorials/notebooks +++ b/docs/tutorials/notebooks @@ -1 +1 @@ -Subproject commit d9fdf0b8c050755990990c940289366886414fb8 +Subproject commit 3e8186f75765c558863aa08155afc34fb0155c0e From 02032df04378c0131110f64ab5bda0064eb9e90e Mon Sep 17 00:00:00 2001 From: zethson Date: Tue, 30 Jan 2024 14:25:10 +0100 Subject: [PATCH 2/5] Remove Future Warnings in Augur Signed-off-by: zethson --- pertpy/tools/_augur.py | 15 ++++++++------- tests/tools/{test_jax_scgen.py => test_scgen.py} | 0 2 files changed, 8 insertions(+), 7 deletions(-) rename tests/tools/{test_jax_scgen.py => test_scgen.py} (100%) diff --git a/pertpy/tools/_augur.py b/pertpy/tools/_augur.py index c279b7fa..27b43dd3 100644 --- a/pertpy/tools/_augur.py +++ b/pertpy/tools/_augur.py @@ -6,6 +6,7 @@ from math import floor, nan from typing import TYPE_CHECKING, Any, Literal +import anndata as ad import matplotlib.pyplot as plt import numpy as np import pandas as pd @@ -140,8 +141,8 @@ def load( # filter samples according to label if condition_label is not None and treatment_label is not None: print(f"Filtering samples with {condition_label} and {treatment_label} labels.") - adata = AnnData.concatenate( - adata[adata.obs["label"] == condition_label], adata[adata.obs["label"] == treatment_label] + adata = ad.concat( + [adata[adata.obs["label"] == condition_label], adata[adata.obs["label"] == treatment_label]] ) label_encoder = LabelEncoder() adata.obs["y_"] = label_encoder.fit_transform(adata.obs["label"]) @@ -235,7 +236,7 @@ def sample(self, adata: AnnData, categorical: bool, subsample_size: int, random_ random_state=random_state, ) ) - subsample = AnnData.concatenate(*label_subsamples, index_unique=None) + subsample = ad.concat([*label_subsamples], index_unique=None) else: subsample = sc.pp.subsample(adata[:, features], n_obs=subsample_size, copy=True, random_state=random_state) @@ -414,8 +415,8 @@ def set_scorer( """ if multiclass: return { - "augur_score": make_scorer(roc_auc_score, multi_class="ovo", needs_proba=True), - "auc": make_scorer(roc_auc_score, multi_class="ovo", needs_proba=True), + "augur_score": make_scorer(roc_auc_score, multi_class="ovo", response_method="predict_proba"), + "auc": make_scorer(roc_auc_score, multi_class="ovo", response_method="predict_proba"), "accuracy": make_scorer(accuracy_score), "precision": make_scorer(precision_score, average="macro", zero_division=zero_division), "f1": make_scorer(f1_score, average="macro"), @@ -423,8 +424,8 @@ def set_scorer( } return ( { - "augur_score": make_scorer(roc_auc_score, needs_proba=True), - "auc": make_scorer(roc_auc_score, needs_proba=True), + "augur_score": make_scorer(roc_auc_score, response_method="predict_proba"), + "auc": make_scorer(roc_auc_score, response_method="predict_proba"), "accuracy": make_scorer(accuracy_score), "precision": make_scorer(precision_score, average="binary", zero_division=zero_division), "f1": make_scorer(f1_score, average="binary"), diff --git a/tests/tools/test_jax_scgen.py b/tests/tools/test_scgen.py similarity index 100% rename from tests/tools/test_jax_scgen.py rename to tests/tools/test_scgen.py From 1e857ad497b41b4bd987060ef9a8d4d786f06785 Mon Sep 17 00:00:00 2001 From: Lukas Heumos Date: Tue, 30 Jan 2024 14:30:07 -0800 Subject: [PATCH 3/5] Add knn imputation (#517) Signed-off-by: zethson --- pertpy/tools/_distances/_distances.py | 2 +- .../_perturbation_space.py | 61 ++++++++- pertpy/tools/_perturbation_space/_simple.py | 8 +- tests/tools/_distances/test_distances.py | 2 +- .../test_simple_perturbation_space.py | 129 ++++++++---------- 5 files changed, 121 insertions(+), 81 deletions(-) diff --git a/pertpy/tools/_distances/_distances.py b/pertpy/tools/_distances/_distances.py index 9a2c3620..a33a4b14 100644 --- a/pertpy/tools/_distances/_distances.py +++ b/pertpy/tools/_distances/_distances.py @@ -796,7 +796,7 @@ def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: label = ["c"] * X.shape[0] + ["p"] * Y_train.shape[0] train = np.concatenate([X, Y_train]) - reg = LogisticRegression() # TODO dynamically pass this? + reg = LogisticRegression() reg.fit(train, label) test_labels = reg.predict_proba(Y_test) return np.mean(test_labels[:, 1]) diff --git a/pertpy/tools/_perturbation_space/_perturbation_space.py b/pertpy/tools/_perturbation_space/_perturbation_space.py index bb533c22..5da60bfa 100644 --- a/pertpy/tools/_perturbation_space/_perturbation_space.py +++ b/pertpy/tools/_perturbation_space/_perturbation_space.py @@ -5,6 +5,7 @@ import numpy as np import pandas as pd from anndata import AnnData +from pynndescent import NNDescent from rich import print if TYPE_CHECKING: @@ -25,7 +26,7 @@ def __init__(self): def compute_control_diff( # type: ignore self, adata: AnnData, - target_col: str = "perturbations", + target_col: str = "perturbation", group_col: str = None, reference_key: str = "control", layer_key: str = None, @@ -147,8 +148,8 @@ def add( perturbations: Iterable[str], reference_key: str = "control", ensure_consistency: bool = False, - target_col: str = "perturbations", - ) -> AnnData: + target_col: str = "perturbation", + ) -> tuple[AnnData, AnnData] | AnnData: """Add perturbations linearly. Assumes input of size n_perts x dimensionality Args: @@ -156,7 +157,7 @@ def add( perturbations: Perturbations to add. reference_key: perturbation source from which the perturbation summation starts. Defaults to 'control'. ensure_consistency: If True, runs differential expression on all data matrices to ensure consistency of linear space. - target_col: .obs column name that stores the label of the perturbation applied to each cell. Defaults to 'perturbations'. + target_col: .obs column name that stores the label of the perturbation applied to each cell. Defaults to 'perturbation'. Returns: Anndata object of size (n_perts+1) x dim, where the last row is the addition of the specified perturbations. @@ -256,8 +257,8 @@ def subtract( perturbations: Iterable[str], reference_key: str = "control", ensure_consistency: bool = False, - target_col: str = "perturbations", - ) -> AnnData: + target_col: str = "perturbation", + ) -> tuple[AnnData, AnnData] | AnnData: """Subtract perturbations linearly. Assumes input of size n_perts x dimensionality Args: @@ -358,3 +359,51 @@ def subtract( return new_perturbation, adata return new_perturbation + + def knn_impute( + self, + adata: AnnData, + column: str = "perturbation", + target_val: str = "unknown", + n_neighbors: int = 5, + use_rep: str = "X_umap", + ) -> None: + """Impute missing values in the specified column using KNN imputation in the space defined by `use_rep`. + + Args: + adata: The AnnData object containing single-cell data. + column: The column name in AnnData object to perform imputation on. Defaults to "perturbation". + target_val: The target value to impute. Defaults to "unknown". + n_neighbors: Number of neighbors to use for imputation. Defaults to 5. + use_rep: The key in `adata.obsm` where the embedding (UMAP, PCA, etc.) is stored. Defaults to 'X_umap'. + + Examples: + >>> import pertpy as pt + >>> import scanpy as sc + >>> import numpy as np + >>> adata = sc.datasets.pbmc68k_reduced() + >>> rng = np.random.default_rng() + >>> adata.obs["perturbation"] = rng.choice(["A", "B", "C", "unknown"], size=adata.n_obs, p=[0.33, 0.33, 0.33, 0.01]) + >>> sc.pp.neighbors(adata) + >>> sc.tl.umap(adata) + >>> ps = pt.tl.PseudobulkSpace() + >>> ps.knn_impute(adata, n_neighbors=5, use_rep="X_umap") + """ + if use_rep not in adata.obsm: + raise ValueError(f"Representation {use_rep} not found in the AnnData object.") + + embedding = adata.obsm[use_rep] + + nnd = NNDescent(embedding, n_neighbors=n_neighbors) + indices, _ = nnd.query(embedding, k=n_neighbors) + + perturbations = np.array(adata.obs[column]) + missing_mask = perturbations == target_val + + for idx in np.where(missing_mask)[0]: + neighbor_indices = indices[idx] + neighbor_categories = perturbations[neighbor_indices] + most_common = pd.Series(neighbor_categories).mode()[0] + perturbations[idx] = most_common + + adata.obs[column] = perturbations diff --git a/pertpy/tools/_perturbation_space/_simple.py b/pertpy/tools/_perturbation_space/_simple.py index 49a2dec3..dc245bb0 100644 --- a/pertpy/tools/_perturbation_space/_simple.py +++ b/pertpy/tools/_perturbation_space/_simple.py @@ -15,7 +15,7 @@ class CentroidSpace(PerturbationSpace): def compute( self, adata: AnnData, - target_col: str = "perturbations", + target_col: str = "perturbation", layer_key: str = None, embedding_key: str = "X_umap", keep_obs: bool = True, @@ -115,7 +115,7 @@ class PseudobulkSpace(PerturbationSpace): def compute( self, adata: AnnData, - target_col: str = "perturbations", + target_col: str = "perturbation", layer_key: str = None, embedding_key: str = None, **kwargs, @@ -133,13 +133,13 @@ def compute( AnnData object with one observation per perturbation. Examples: - >>> import pertpy as pp + >>> import pertpy as pt >>> mdata = pt.dt.papalexi_2021() >>> ps = pt.tl.PseudobulkSpace() >>> ps_adata = ps.compute(mdata["rna"], target_col="gene_target", groups_col="gene_target") """ if "groups_col" not in kwargs: - kwargs["groups_col"] = "perturbations" + kwargs["groups_col"] = "perturbation" if layer_key is not None and embedding_key is not None: raise ValueError("Please, select just either layer or embedding for computation.") diff --git a/tests/tools/_distances/test_distances.py b/tests/tools/_distances/test_distances.py index 73f367cc..a9ac44e2 100644 --- a/tests/tools/_distances/test_distances.py +++ b/tests/tools/_distances/test_distances.py @@ -64,7 +64,7 @@ def test_distance_axioms(self, adata, distance): assert all(np.diag(df.values) == 0) # distance to self is 0 # (M2) Positivity - assert len(df) == np.sum(df.values == 0) # distance to other is not 0 (TODO) + assert len(df) == np.sum(df.values == 0) # distance to other is not 0 assert all(df.values.flatten() >= 0) # distance is non-negative # (M3) Symmetry diff --git a/tests/tools/_perturbation_space/test_simple_perturbation_space.py b/tests/tools/_perturbation_space/test_simple_perturbation_space.py index 142c8b83..eef0dd9b 100644 --- a/tests/tools/_perturbation_space/test_simple_perturbation_space.py +++ b/tests/tools/_perturbation_space/test_simple_perturbation_space.py @@ -2,15 +2,34 @@ import pandas as pd import pertpy as pt import pytest +import scanpy as sc from anndata import AnnData -def test_differential_response(): - rng = np.random.default_rng() +@pytest.fixture +def rng(): + return np.random.default_rng() + + +@pytest.fixture +def adata(rng): + X = rng.random((69, 50)) + adata = AnnData(X) + perturbations = np.array(["control", "target1", "target2"] * 22 + ["unknown"] * 3) + adata.obs["perturbation"] = perturbations + sc.pp.pca(adata) + sc.pp.neighbors(adata, use_rep="X") + sc.tl.umap(adata) + + return adata + + +@pytest.fixture +def adata_simple(rng): X = rng.random(size=(10, 5)) obs = pd.DataFrame( { - "perturbations": [ + "perturbation": [ "control", "target1", "target1", @@ -26,19 +45,20 @@ def test_differential_response(): ) adata = AnnData(X, obs=obs) - # Compute the differential response + return adata + + +def test_differential_response(adata_simple): ps = pt.tl.PseudobulkSpace() - ps_adata = ps.compute_control_diff(adata, copy=True) + ps_adata = ps.compute_control_diff(adata_simple, target_col="perturbation", copy=True) - # Test that the differential response was computed correctly - expected_diff_matrix = adata.X - adata.X[0, :] + expected_diff_matrix = adata_simple.X - adata_simple.X[0, :] np.testing.assert_allclose(ps_adata.X, expected_diff_matrix, rtol=1e-4) - # Check that the function raises an error if the reference key is not found with pytest.raises(ValueError): ps.compute_control_diff( - adata, - target_col="perturbations", + adata_simple, + target_col="perturbation", reference_key="not_found", layer_key="counts", new_layer_key="counts_diff", @@ -48,59 +68,32 @@ def test_differential_response(): ) -def test_pseudobulk_response(): - rng = np.random.default_rng() - X = rng.random(size=(10, 5)) - obs = pd.DataFrame( - { - "perturbations": [ - "control", - "target1", - "target1", - "target2", - "target2", - "target1", - "target1", - "target2", - "target2", - "target2", - ] - } - ) - adata = AnnData(X, obs=obs) - - # Compute the pseudobulk +def test_pseudobulk_response(adata_simple): ps = pt.tl.PseudobulkSpace() - psadata = ps.compute(adata, mode="mean", min_cells=0, min_counts=0) + psadata = ps.compute(adata_simple, mode="mean", min_cells=0, min_counts=0) - # Test that the pseudobulk response was computed correctly - adata_target1 = adata[adata.obs.perturbations == "target1"].X.mean(0) + adata_target1 = adata_simple[adata_simple.obs.perturbation == "target1"].X.mean(0) np.testing.assert_allclose(adata_target1, psadata["target1"].X[0], rtol=1e-4) - # Test in UMAP space - adata.obsm["X_umap"] = X + adata_simple.obsm["X_umap"] = adata_simple.X - # Compute the pseudobulk ps = pt.tl.PseudobulkSpace() - psadata = ps.compute(adata, embedding_key="X_umap", mode="mean", min_cells=0, min_counts=0) + psadata = ps.compute(adata_simple, embedding_key="X_umap", mode="mean", min_cells=0, min_counts=0) - # Test that the pseudobulk response was computed correctly - adata_target1 = adata[adata.obs.perturbations == "target1"].obsm["X_umap"].mean(0) + adata_target1 = adata_simple[adata_simple.obs.perturbation == "target1"].obsm["X_umap"].mean(0) np.testing.assert_allclose(adata_target1, psadata["target1"].X[0], rtol=1e-4) - # Check that the function raises an error if the layer key is not found with pytest.raises(ValueError): ps.compute( - adata, - target_col="perturbations", + adata_simple, + target_col="perturbation", layer_key="not_found", ) - # Check that the function raises an error if the layer key and embedding key are used at the same time with pytest.raises(ValueError): ps.compute( - adata, - target_col="perturbations", + adata_simple, + target_col="perturbation", embedding_key="not_found", layer_key="not_found", ) @@ -130,39 +123,34 @@ def test_centroid_umap_response(): elif value == "target2": X[i, :] = 30 - obs = pd.DataFrame({"perturbations": pert_index}) + obs = pd.DataFrame({"perturbation": pert_index}) adata = AnnData(X, obs=obs) adata.obsm["X_umap"] = X - # Compute the centroids ps = pt.tl.CentroidSpace() psadata = ps.compute(adata, embedding_key="X_umap") - # Test that the centroids response was computed correctly - adata_target1 = adata[adata.obs.perturbations == "target1"].obsm["X_umap"].mean(0) + adata_target1 = adata[adata.obs.perturbation == "target1"].obsm["X_umap"].mean(0) np.testing.assert_allclose(adata_target1, psadata["target1"].X[0], rtol=1e-4) ps = pt.tl.CentroidSpace() psadata = ps.compute(adata) # if nothing specific, compute with X, and X and X_umap are the same - # Test that the centroids response was computed correctly - adata_target1 = adata[adata.obs.perturbations == "target1"].obsm["X_umap"].mean(0) + adata_target1 = adata[adata.obs.perturbation == "target1"].obsm["X_umap"].mean(0) np.testing.assert_allclose(adata_target1, psadata["target1"].X[0], rtol=1e-4) - # Check that the function raises an error if the embedding key is not found with pytest.raises(ValueError): ps.compute( adata, - target_col="perturbations", + target_col="perturbation", embedding_key="not_found", ) - # Check that the function raises an error if the layer key and embedding key are used at the same time with pytest.raises(ValueError): ps.compute( adata, - target_col="perturbations", + target_col="perturbation", embedding_key="not_found", layer_key="not_found", ) @@ -174,7 +162,7 @@ def test_linear_operations(): X = rng.random(size=(10, 5)) obs = pd.DataFrame( { - "perturbations": [ + "perturbation": [ "control", "target1", "target1", @@ -191,21 +179,17 @@ def test_linear_operations(): adata = AnnData(X, obs=obs) adata.obsm["X_umap"] = X - # Compute pseudobulk ps = pt.tl.PseudobulkSpace() psadata = ps.compute(adata, mode="mean", min_cells=0, min_counts=0) psadata_umap = ps.compute(adata, mode="mean", min_cells=0, min_counts=0, embedding_key="X_umap") psadata.obsm["X_umap"] = psadata_umap.X - # Perform summation ps_adata, data_compare = ps.add(psadata, perturbations=["target1", "target2"], ensure_consistency=True) - # Test in X test = data_compare["control"].X + data_compare["target1"].X + data_compare["target2"].X np.testing.assert_allclose(test, ps_adata["target1+target2"].X, rtol=1e-4) - # Test in UMAP embedding test = ( data_compare["control"].obsm["X_umap_control_diff"] + data_compare["target1"].obsm["X_umap_control_diff"] @@ -213,26 +197,20 @@ def test_linear_operations(): ) np.testing.assert_allclose(test, ps_adata["target1+target2"].obsm["X_umap"], rtol=1e-4) - # Perform subtraction ps_adata, data_compare = ps.subtract( psadata, reference_key="target1", perturbations=["target2"], ensure_consistency=True ) - # Test in X test = data_compare["target1"].X - data_compare["target2"].X np.testing.assert_allclose(test, ps_adata["target1-target2"].X, rtol=1e-4) - # Operations after control diff, do the results match? ps_adata = ps.compute_control_diff(psadata, copy=True) - # Do summation ps_adata2 = ps.add(ps_adata, perturbations=["target1", "target2"]) - # Test in X test = ps_adata["control"].X + ps_adata["target1"].X + ps_adata["target2"].X np.testing.assert_allclose(test, ps_adata2["target1+target2"].X, rtol=1e-4) - # Do subtract ps_adata2 = ps.subtract(ps_adata, reference_key="target1", perturbations=["target1"]) ps_vector = ps_adata2["target1-target1"].X np.testing.assert_allclose(ps_adata2["control"].X, ps_adata2["target1-target1"].X, rtol=1e-4) @@ -245,7 +223,6 @@ def test_linear_operations(): # Compare process data vs pseudobulk before, should be the same np.testing.assert_allclose(ps_inner_vector, ps_vector, rtol=1e-4) - # Check result in UMAP np.testing.assert_allclose( data_compare["control"].obsm["X_umap_control_diff"], ps_adata2["target1-target1"].obsm["X_umap"], rtol=1e-4 ) @@ -263,3 +240,17 @@ def test_linear_operations(): ps_adata, perturbations=["target1", "target3"], ) + + +def test_knn_impute(): + rng = np.random.default_rng() + X = rng.standard_normal((69, 50)) + adata = AnnData(X) + perturbations = np.array(["A", "B", "C"] * 22 + ["unknown"] * 3) + adata.obs["perturbation"] = perturbations + sc.pp.neighbors(adata, use_rep="X") + sc.tl.umap(adata) + + ps = pt.tl.PseudobulkSpace() + ps.knn_impute(adata, n_neighbors=5, use_rep="X_umap") + assert "unknown" not in adata.obs["perturbation"] From 604dfeffdfcacc73701c7a72a2940bb1fe779de7 Mon Sep 17 00:00:00 2001 From: zethson Date: Wed, 31 Jan 2024 22:04:33 +0100 Subject: [PATCH 4/5] Rename knn_impute to label_transfer Signed-off-by: zethson --- .../_perturbation_space/_perturbation_space.py | 4 ++-- pertpy/tools/_perturbation_space/_simple.py | 2 +- tests/conftest.py | 7 +++++++ .../test_simple_perturbation_space.py | 14 +++----------- 4 files changed, 13 insertions(+), 14 deletions(-) create mode 100644 tests/conftest.py diff --git a/pertpy/tools/_perturbation_space/_perturbation_space.py b/pertpy/tools/_perturbation_space/_perturbation_space.py index 5da60bfa..8f832589 100644 --- a/pertpy/tools/_perturbation_space/_perturbation_space.py +++ b/pertpy/tools/_perturbation_space/_perturbation_space.py @@ -360,7 +360,7 @@ def subtract( return new_perturbation - def knn_impute( + def label_transfer( self, adata: AnnData, column: str = "perturbation", @@ -387,7 +387,7 @@ def knn_impute( >>> sc.pp.neighbors(adata) >>> sc.tl.umap(adata) >>> ps = pt.tl.PseudobulkSpace() - >>> ps.knn_impute(adata, n_neighbors=5, use_rep="X_umap") + >>> ps.label_transfer(adata, n_neighbors=5, use_rep="X_umap") """ if use_rep not in adata.obsm: raise ValueError(f"Representation {use_rep} not found in the AnnData object.") diff --git a/pertpy/tools/_perturbation_space/_simple.py b/pertpy/tools/_perturbation_space/_simple.py index dc245bb0..ae4f467a 100644 --- a/pertpy/tools/_perturbation_space/_simple.py +++ b/pertpy/tools/_perturbation_space/_simple.py @@ -244,7 +244,7 @@ def compute( # type: ignore copy: bool = True, return_object: bool = False, **kwargs, - ) -> tuple[AnnData, object | AnnData]: + ) -> tuple[AnnData, object] | AnnData: """Computes a clustering using Density-based spatial clustering of applications (DBSCAN). Args: diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..eee0c858 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,7 @@ +import numpy as np +import pytest + + +@pytest.fixture +def rng(): + return np.random.default_rng() diff --git a/tests/tools/_perturbation_space/test_simple_perturbation_space.py b/tests/tools/_perturbation_space/test_simple_perturbation_space.py index eef0dd9b..8fd04aeb 100644 --- a/tests/tools/_perturbation_space/test_simple_perturbation_space.py +++ b/tests/tools/_perturbation_space/test_simple_perturbation_space.py @@ -6,11 +6,6 @@ from anndata import AnnData -@pytest.fixture -def rng(): - return np.random.default_rng() - - @pytest.fixture def adata(rng): X = rng.random((69, 50)) @@ -135,7 +130,7 @@ def test_centroid_umap_response(): np.testing.assert_allclose(adata_target1, psadata["target1"].X[0], rtol=1e-4) ps = pt.tl.CentroidSpace() - psadata = ps.compute(adata) # if nothing specific, compute with X, and X and X_umap are the same + psadata = ps.compute(adata) adata_target1 = adata[adata.obs.perturbation == "target1"].obsm["X_umap"].mean(0) np.testing.assert_allclose(adata_target1, psadata["target1"].X[0], rtol=1e-4) @@ -220,21 +215,18 @@ def test_linear_operations(): ) ps_inner_vector = ps_adata2["target1-target1"].X - # Compare process data vs pseudobulk before, should be the same np.testing.assert_allclose(ps_inner_vector, ps_vector, rtol=1e-4) np.testing.assert_allclose( data_compare["control"].obsm["X_umap_control_diff"], ps_adata2["target1-target1"].obsm["X_umap"], rtol=1e-4 ) - # Check that the function raises an error if the perturbation is not found with pytest.raises(ValueError): ps.add( ps_adata, perturbations=["target1", "target3"], ) - # Check that the function raises an error if some key is not found with pytest.raises(ValueError): ps.add( ps_adata, @@ -242,7 +234,7 @@ def test_linear_operations(): ) -def test_knn_impute(): +def test_label_transfer(): rng = np.random.default_rng() X = rng.standard_normal((69, 50)) adata = AnnData(X) @@ -252,5 +244,5 @@ def test_knn_impute(): sc.tl.umap(adata) ps = pt.tl.PseudobulkSpace() - ps.knn_impute(adata, n_neighbors=5, use_rep="X_umap") + ps.label_transfer(adata, n_neighbors=5, use_rep="X_umap") assert "unknown" not in adata.obs["perturbation"] From 4b20f6d1084b1a7489473550c1c8598904435d18 Mon Sep 17 00:00:00 2001 From: zethson Date: Wed, 31 Jan 2024 22:11:49 +0100 Subject: [PATCH 5/5] Fix typos Signed-off-by: zethson --- pertpy/tools/_mixscape.py | 29 +++++++++-------------------- 1 file changed, 9 insertions(+), 20 deletions(-) diff --git a/pertpy/tools/_mixscape.py b/pertpy/tools/_mixscape.py index 3e741d7c..32d07729 100644 --- a/pertpy/tools/_mixscape.py +++ b/pertpy/tools/_mixscape.py @@ -1,7 +1,6 @@ from __future__ import annotations import copy -import warnings from collections import OrderedDict from typing import TYPE_CHECKING, Literal @@ -10,7 +9,6 @@ import scanpy as sc import seaborn as sns from matplotlib import pyplot as pl -from rich import print from scanpy import get from scanpy._settings import settings from scanpy._utils import _check_use_raw, sanitize_anndata @@ -477,15 +475,6 @@ def _get_perturbation_markers( return perturbation_markers def _get_column_indices(self, adata, col_names): - """Fetches the column indices in X for a given list of column names - - Args: - adata: :class:`~anndata.AnnData` object - col_names: Column names to extract the indices for - - Returns: - Set of column indices - """ if isinstance(col_names, str): # pragma: no cover col_names = [col_names] @@ -815,7 +804,7 @@ def plot_perturbscore( # pragma: no cover ) pl.xlabel("Perturbation score", fontsize=16) pl.ylabel("Cell density", fontsize=16) - pl.title("Density Plot using Seaborn and Matplotlib", fontsize=18) + pl.title("Density", fontsize=18) pl.legend(title="mixscape class", title_fontsize=14, fontsize=12) sns.despine() @@ -843,19 +832,21 @@ def plot_violin( # pragma: no cover ax: Axes | None = None, **kwargs, ): - """Violin plot using mixscape results. Requires `pt.tl.mixscape` to be run first. + """Violin plot using mixscape results. + + Requires `pt.tl.mixscape` to be run first. Args: adata: The annotated data object. - target_gene: Target gene name to plot. + target_gene_idents: Target gene name to plot. keys: Keys for accessing variables of `.var_names` or fields of `.obs`. Default is 'mixscape_class_p_ko'. groupby: The key of the observation grouping to consider. Default is 'mixscape_class'. log: Plot on logarithmic axis. use_raw: Whether to use `raw` attribute of `adata`. Defaults to `True` if `.raw` is present. stripplot: Add a stripplot on top of the violin plot. order: Order in which to show the categories. - xlabel: Label of the x axis. Defaults to `groupby` if `rotation` is `None`, otherwise, no label is shown. - ylabel: Label of the y axis. If `None` and `groupby` is `None`, defaults to `'value'`. + xlabel: Label of the x-axis. Defaults to `groupby` if `rotation` is `None`, otherwise, no label is shown. + ylabel: Label of the y-axis. If `None` and `groupby` is `None`, defaults to `'value'`. If `None` and `groubpy` is not `None`, defaults to `keys`. show: Show the plot, do not return axis. save: If `True` or a `str`, save the figure. A string is appended to the default filename. @@ -1054,11 +1045,9 @@ def plot_lda( # pragma: no cover >>> ms_pt.plot_lda(adata=mdata['rna'], control='NT') """ if mixscape_class not in adata.obs: - raise ValueError( - f'Did not find `.obs["{mixscape_class!r}"]`. Please run the `mixscape` function first first.' - ) + raise ValueError(f'Did not find `.obs["{mixscape_class!r}"]`. Please run the `mixscape` function first.') if lda_key not in adata.uns: - raise ValueError(f'Did not find `.uns["{lda_key!r}"]`. Run the `lda` function first.') + raise ValueError(f'Did not find `.uns["{lda_key!r}"]`. Please run the `lda` function first.') adata_subset = adata[ (adata.obs[mixscape_class_global] == perturbation_type) | (adata.obs[mixscape_class_global] == control)