Commit 4c4d6c6

Merge branch 'main' into plot_example
2 parents: 20176bc + 4b20f6d

File tree: 10 files changed (+143 / -114 lines)


docs/tutorials/notebooks

pertpy/tools/_augur.py

Lines changed: 8 additions & 7 deletions
@@ -6,6 +6,7 @@
from math import floor, nan
from typing import TYPE_CHECKING, Any, Literal

+ import anndata as ad
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

@@ -140,8 +141,8 @@ def load(
# filter samples according to label
if condition_label is not None and treatment_label is not None:
print(f"Filtering samples with {condition_label} and {treatment_label} labels.")
- adata = AnnData.concatenate(
- adata[adata.obs["label"] == condition_label], adata[adata.obs["label"] == treatment_label]
+ adata = ad.concat(
+ [adata[adata.obs["label"] == condition_label], adata[adata.obs["label"] == treatment_label]]
)
label_encoder = LabelEncoder()
adata.obs["y_"] = label_encoder.fit_transform(adata.obs["label"])

@@ -235,7 +236,7 @@ def sample(self, adata: AnnData, categorical: bool, subsample_size: int, random_
random_state=random_state,
)
)
- subsample = AnnData.concatenate(*label_subsamples, index_unique=None)
+ subsample = ad.concat([*label_subsamples], index_unique=None)
else:
subsample = sc.pp.subsample(adata[:, features], n_obs=subsample_size, copy=True, random_state=random_state)

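Note: the `AnnData.concatenate` to `ad.concat` switch above tracks anndata's move away from the deprecated bound `concatenate` method to a module-level function that takes a list of objects. A minimal sketch of the equivalent call on toy data (variable names and labels here are made up for illustration, not taken from the pertpy source):

    import anndata as ad
    import numpy as np
    import pandas as pd

    # Toy AnnData with a two-level "label" column.
    adata = ad.AnnData(
        X=np.random.default_rng(0).normal(size=(6, 3)),
        obs=pd.DataFrame(
            {"label": ["ctrl", "ctrl", "ctrl", "stim", "stim", "stim"]},
            index=[f"cell{i}" for i in range(6)],
        ),
    )

    # Old API: AnnData.concatenate(a, b)   (deprecated bound method, positional args)
    # New API: ad.concat([a, b])           (module-level function, takes a list)
    filtered = ad.concat(
        [adata[adata.obs["label"] == "ctrl"], adata[adata.obs["label"] == "stim"]]
    )
    print(filtered.obs["label"].value_counts())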
@@ -414,17 +415,17 @@ def set_scorer(
"""
if multiclass:
return {
- "augur_score": make_scorer(roc_auc_score, multi_class="ovo", needs_proba=True),
- "auc": make_scorer(roc_auc_score, multi_class="ovo", needs_proba=True),
+ "augur_score": make_scorer(roc_auc_score, multi_class="ovo", response_method="predict_proba"),
+ "auc": make_scorer(roc_auc_score, multi_class="ovo", response_method="predict_proba"),
"accuracy": make_scorer(accuracy_score),
"precision": make_scorer(precision_score, average="macro", zero_division=zero_division),
"f1": make_scorer(f1_score, average="macro"),
"recall": make_scorer(recall_score, average="macro"),
}
return (
{
- "augur_score": make_scorer(roc_auc_score, needs_proba=True),
- "auc": make_scorer(roc_auc_score, needs_proba=True),
+ "augur_score": make_scorer(roc_auc_score, response_method="predict_proba"),
+ "auc": make_scorer(roc_auc_score, response_method="predict_proba"),
"accuracy": make_scorer(accuracy_score),
"precision": make_scorer(precision_score, average="binary", zero_division=zero_division),
"f1": make_scorer(f1_score, average="binary"),

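Note: replacing `needs_proba=True` with `response_method="predict_proba"` follows scikit-learn's newer `make_scorer` signature, which deprecated the boolean `needs_proba`/`needs_threshold` flags in favour of a single `response_method` argument. A minimal sketch of the updated scorer in isolation, assuming scikit-learn >= 1.4 and toy data:

    import numpy as np
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import make_scorer, roc_auc_score
    from sklearn.model_selection import cross_val_score

    # AUC scorer built with the newer response_method argument: the scorer asks the
    # estimator for predict_proba output instead of hard class predictions.
    auc_scorer = make_scorer(roc_auc_score, response_method="predict_proba")

    rng = np.random.default_rng(0)
    X = rng.normal(size=(200, 5))
    y = (X[:, 0] + rng.normal(scale=0.5, size=200) > 0).astype(int)

    scores = cross_val_score(LogisticRegression(), X, y, scoring=auc_scorer, cv=3)
    print(scores.mean())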
pertpy/tools/_distances/_distances.py

Lines changed: 1 addition & 1 deletion
@@ -796,7 +796,7 @@ def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
label = ["c"] * X.shape[0] + ["p"] * Y_train.shape[0]
train = np.concatenate([X, Y_train])

- reg = LogisticRegression() # TODO dynamically pass this?
+ reg = LogisticRegression()
reg.fit(train, label)
test_labels = reg.predict_proba(Y_test)
return np.mean(test_labels[:, 1])

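Note: the hunk above touches pertpy's classifier-based distance. Stripped of the class machinery, the idea is to train a logistic regression to separate control cells (X) from a training split of perturbed cells (Y) and report the mean predicted "perturbed" probability on held-out perturbed cells. A self-contained sketch with simulated data (shapes and the 50/50 split are assumptions for illustration):

    import numpy as np
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import train_test_split

    rng = np.random.default_rng(0)
    X = rng.normal(0.0, 1.0, size=(200, 10))   # control cells
    Y = rng.normal(0.5, 1.0, size=(200, 10))   # perturbed cells
    Y_train, Y_test = train_test_split(Y, test_size=0.5, random_state=0)

    # Train a classifier to separate control from perturbed cells.
    train = np.concatenate([X, Y_train])
    label = ["c"] * X.shape[0] + ["p"] * Y_train.shape[0]

    reg = LogisticRegression()
    reg.fit(train, label)

    # Column 1 is the "p" class (classes_ are sorted), so this is the mean
    # probability that held-out perturbed cells are classified as perturbed.
    score = reg.predict_proba(Y_test)[:, 1].mean()
    print(score)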
pertpy/tools/_mixscape.py

Lines changed: 9 additions & 20 deletions
@@ -1,7 +1,6 @@
from __future__ import annotations

import copy
- import warnings
from collections import OrderedDict
from typing import TYPE_CHECKING, Literal

@@ -10,7 +9,6 @@
import scanpy as sc
import seaborn as sns
from matplotlib import pyplot as pl
- from rich import print
from scanpy import get
from scanpy._settings import settings
from scanpy._utils import _check_use_raw, sanitize_anndata

@@ -477,15 +475,6 @@ def _get_perturbation_markers(
return perturbation_markers

def _get_column_indices(self, adata, col_names):
- """Fetches the column indices in X for a given list of column names
-
- Args:
- adata: :class:`~anndata.AnnData` object
- col_names: Column names to extract the indices for
-
- Returns:
- Set of column indices
- """
if isinstance(col_names, str): # pragma: no cover
col_names = [col_names]

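Note: the hunk above only drops the helper's docstring; per that docstring, `_get_column_indices` maps a list of column names to their integer positions in the expression matrix. A hypothetical standalone re-implementation of that behaviour, assuming the names are looked up in `adata.var_names` (illustration only, not the pertpy source):

    import anndata as ad
    import numpy as np
    import pandas as pd

    def get_column_indices(adata, col_names):
        # Resolve variable names to integer column positions in adata.X.
        if isinstance(col_names, str):
            col_names = [col_names]
        var_names = list(adata.var_names)
        return [var_names.index(name) for name in col_names if name in var_names]

    adata = ad.AnnData(
        X=np.zeros((2, 3)),
        var=pd.DataFrame(index=["GeneA", "GeneB", "GeneC"]),
    )
    print(get_column_indices(adata, ["GeneC", "GeneA"]))  # [2, 0]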
@@ -827,7 +816,7 @@ def plot_perturbscore( # pragma: no cover
)
pl.xlabel("Perturbation score", fontsize=16)
pl.ylabel("Cell density", fontsize=16)
- pl.title("Density Plot using Seaborn and Matplotlib", fontsize=18)
+ pl.title("Density", fontsize=18)
pl.legend(title="mixscape class", title_fontsize=14, fontsize=12)
sns.despine()

@@ -855,19 +844,21 @@ def plot_violin( # pragma: no cover
ax: Axes | None = None,
**kwargs,
):
- """Violin plot using mixscape results. Requires `pt.tl.mixscape` to be run first.
+ """Violin plot using mixscape results.
+
+ Requires `pt.tl.mixscape` to be run first.

Args:
adata: The annotated data object.
- target_gene: Target gene name to plot.
+ target_gene_idents: Target gene name to plot.
keys: Keys for accessing variables of `.var_names` or fields of `.obs`. Default is 'mixscape_class_p_ko'.
groupby: The key of the observation grouping to consider. Default is 'mixscape_class'.
log: Plot on logarithmic axis.
use_raw: Whether to use `raw` attribute of `adata`. Defaults to `True` if `.raw` is present.
stripplot: Add a stripplot on top of the violin plot.
order: Order in which to show the categories.
- xlabel: Label of the x axis. Defaults to `groupby` if `rotation` is `None`, otherwise, no label is shown.
- ylabel: Label of the y axis. If `None` and `groupby` is `None`, defaults to `'value'`.
+ xlabel: Label of the x-axis. Defaults to `groupby` if `rotation` is `None`, otherwise, no label is shown.
+ ylabel: Label of the y-axis. If `None` and `groupby` is `None`, defaults to `'value'`.
If `None` and `groubpy` is not `None`, defaults to `keys`.
show: Show the plot, do not return axis.
save: If `True` or a `str`, save the figure. A string is appended to the default filename.

@@ -1072,11 +1063,9 @@ def plot_lda( # pragma: no cover
.. image:: ../_static/docstring_previews/mixscape_lda.png
"""
if mixscape_class not in adata.obs:
- raise ValueError(
- f'Did not find `.obs["{mixscape_class!r}"]`. Please run the `mixscape` function first first.'
- )
+ raise ValueError(f'Did not find `.obs["{mixscape_class!r}"]`. Please run the `mixscape` function first.')
if lda_key not in adata.uns:
- raise ValueError(f'Did not find `.uns["{lda_key!r}"]`. Run the `lda` function first.')
+ raise ValueError(f'Did not find `.uns["{lda_key!r}"]`. Please run the `lda` function first.')

adata_subset = adata[
(adata.obs[mixscape_class_global] == perturbation_type) | (adata.obs[mixscape_class_global] == control)

pertpy/tools/_perturbation_space/_perturbation_space.py

Lines changed: 55 additions & 6 deletions
@@ -5,6 +5,7 @@
import numpy as np
import pandas as pd
from anndata import AnnData
+ from pynndescent import NNDescent
from rich import print

if TYPE_CHECKING:

@@ -25,7 +26,7 @@ def __init__(self):
def compute_control_diff( # type: ignore
self,
adata: AnnData,
- target_col: str = "perturbations",
+ target_col: str = "perturbation",
group_col: str = None,
reference_key: str = "control",
layer_key: str = None,

@@ -147,16 +148,16 @@ def add(
perturbations: Iterable[str],
reference_key: str = "control",
ensure_consistency: bool = False,
- target_col: str = "perturbations",
- ) -> AnnData:
+ target_col: str = "perturbation",
+ ) -> tuple[AnnData, AnnData] | AnnData:
"""Add perturbations linearly. Assumes input of size n_perts x dimensionality

Args:
adata: Anndata object of size n_perts x dim.
perturbations: Perturbations to add.
reference_key: perturbation source from which the perturbation summation starts. Defaults to 'control'.
ensure_consistency: If True, runs differential expression on all data matrices to ensure consistency of linear space.
- target_col: .obs column name that stores the label of the perturbation applied to each cell. Defaults to 'perturbations'.
+ target_col: .obs column name that stores the label of the perturbation applied to each cell. Defaults to 'perturbation'.

Returns:
Anndata object of size (n_perts+1) x dim, where the last row is the addition of the specified perturbations.

@@ -256,8 +257,8 @@ def subtract(
perturbations: Iterable[str],
reference_key: str = "control",
ensure_consistency: bool = False,
- target_col: str = "perturbations",
- ) -> AnnData:
+ target_col: str = "perturbation",
+ ) -> tuple[AnnData, AnnData] | AnnData:
"""Subtract perturbations linearly. Assumes input of size n_perts x dimensionality

Args:

@@ -358,3 +359,51 @@ def subtract(
return new_perturbation, adata

return new_perturbation
+
+ def label_transfer(
+ self,
+ adata: AnnData,
+ column: str = "perturbation",
+ target_val: str = "unknown",
+ n_neighbors: int = 5,
+ use_rep: str = "X_umap",
+ ) -> None:
+ """Impute missing values in the specified column using KNN imputation in the space defined by `use_rep`.
+
+ Args:
+ adata: The AnnData object containing single-cell data.
+ column: The column name in AnnData object to perform imputation on. Defaults to "perturbation".
+ target_val: The target value to impute. Defaults to "unknown".
+ n_neighbors: Number of neighbors to use for imputation. Defaults to 5.
+ use_rep: The key in `adata.obsm` where the embedding (UMAP, PCA, etc.) is stored. Defaults to 'X_umap'.
+
+ Examples:
+ >>> import pertpy as pt
+ >>> import scanpy as sc
+ >>> import numpy as np
+ >>> adata = sc.datasets.pbmc68k_reduced()
+ >>> rng = np.random.default_rng()
+ >>> adata.obs["perturbation"] = rng.choice(["A", "B", "C", "unknown"], size=adata.n_obs, p=[0.33, 0.33, 0.33, 0.01])
+ >>> sc.pp.neighbors(adata)
+ >>> sc.tl.umap(adata)
+ >>> ps = pt.tl.PseudobulkSpace()
+ >>> ps.label_transfer(adata, n_neighbors=5, use_rep="X_umap")
+ """
+ if use_rep not in adata.obsm:
+ raise ValueError(f"Representation {use_rep} not found in the AnnData object.")
+
+ embedding = adata.obsm[use_rep]
+
+ nnd = NNDescent(embedding, n_neighbors=n_neighbors)
+ indices, _ = nnd.query(embedding, k=n_neighbors)
+
+ perturbations = np.array(adata.obs[column])
+ missing_mask = perturbations == target_val
+
+ for idx in np.where(missing_mask)[0]:
+ neighbor_indices = indices[idx]
+ neighbor_categories = perturbations[neighbor_indices]
+ most_common = pd.Series(neighbor_categories).mode()[0]
+ perturbations[idx] = most_common
+
+ adata.obs[column] = perturbations

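Note: the new `label_transfer` method imputes a cell's missing label by majority vote over its nearest neighbours in an embedding. The same idea in a self-contained sketch, using scikit-learn's exact NearestNeighbors in place of the approximate pynndescent index (the toy embedding and labels are made up for illustration):

    import numpy as np
    import pandas as pd
    from sklearn.neighbors import NearestNeighbors

    rng = np.random.default_rng(0)
    embedding = rng.normal(size=(100, 2))  # stand-in for adata.obsm["X_umap"]
    labels = rng.choice(["A", "B", "unknown"], size=100, p=[0.5, 0.45, 0.05])

    nn = NearestNeighbors(n_neighbors=5).fit(embedding)
    _, indices = nn.kneighbors(embedding)

    # Replace each "unknown" label with the most frequent label among its neighbours.
    for idx in np.where(labels == "unknown")[0]:
        neighbor_labels = labels[indices[idx]]
        labels[idx] = pd.Series(neighbor_labels).mode()[0]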
pertpy/tools/_perturbation_space/_simple.py

Lines changed: 5 additions & 5 deletions
@@ -15,7 +15,7 @@ class CentroidSpace(PerturbationSpace):
def compute(
self,
adata: AnnData,
- target_col: str = "perturbations",
+ target_col: str = "perturbation",
layer_key: str = None,
embedding_key: str = "X_umap",
keep_obs: bool = True,

@@ -115,7 +115,7 @@ class PseudobulkSpace(PerturbationSpace):
def compute(
self,
adata: AnnData,
- target_col: str = "perturbations",
+ target_col: str = "perturbation",
layer_key: str = None,
embedding_key: str = None,
**kwargs,

@@ -133,13 +133,13 @@ def compute(
AnnData object with one observation per perturbation.

Examples:
- >>> import pertpy as pp
+ >>> import pertpy as pt
>>> mdata = pt.dt.papalexi_2021()
>>> ps = pt.tl.PseudobulkSpace()
>>> ps_adata = ps.compute(mdata["rna"], target_col="gene_target", groups_col="gene_target")
"""
if "groups_col" not in kwargs:
- kwargs["groups_col"] = "perturbations"
+ kwargs["groups_col"] = "perturbation"

if layer_key is not None and embedding_key is not None:
raise ValueError("Please, select just either layer or embedding for computation.")

@@ -244,7 +244,7 @@ def compute( # type: ignore
copy: bool = True,
return_object: bool = False,
**kwargs,
- ) -> tuple[AnnData, object | AnnData]:
+ ) -> tuple[AnnData, object] | AnnData:
"""Computes a clustering using Density-based spatial clustering of applications (DBSCAN).

Args:

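Note: the annotation fix in the last hunk is subtle. `tuple[AnnData, object | AnnData]` describes a 2-tuple whose second element may be either type, whereas the corrected `tuple[AnnData, object] | AnnData` says the method returns either an (AnnData, clusterer) pair or a bare AnnData, matching the `return_object` switch. A hypothetical function with the same return shape (not the pertpy implementation):

    from __future__ import annotations

    import numpy as np
    from anndata import AnnData
    from sklearn.cluster import DBSCAN

    def compute_sketch(adata: AnnData, return_object: bool = False) -> tuple[AnnData, object] | AnnData:
        # Cluster the cells, then return either (AnnData, fitted clusterer)
        # or just the AnnData, depending on return_object.
        clusterer = DBSCAN().fit(adata.X)
        adata.obs["dbscan"] = clusterer.labels_.astype(str)
        return (adata, clusterer) if return_object else adata

    adata = AnnData(X=np.random.default_rng(0).normal(size=(20, 3)))
    adata_out, model = compute_sketch(adata, return_object=True)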
tests/conftest.py

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+ import numpy as np
+ import pytest
+
+
+ @pytest.fixture
+ def rng():
+ return np.random.default_rng()

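Note: the new conftest.py provides a shared `rng` fixture; any test that lists `rng` as an argument receives the `numpy.random.Generator` returned by the fixture. A hypothetical test using it (not part of this commit):

    import numpy as np

    def test_rng_is_a_generator(rng):
        # pytest injects the Generator returned by the rng fixture in conftest.py.
        assert isinstance(rng, np.random.Generator)
        draws = rng.normal(size=4)
        assert draws.shape == (4,)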
tests/tools/_distances/test_distances.py

Lines changed: 1 addition & 1 deletion
@@ -64,7 +64,7 @@ def test_distance_axioms(self, adata, distance):
assert all(np.diag(df.values) == 0) # distance to self is 0

# (M2) Positivity
- assert len(df) == np.sum(df.values == 0) # distance to other is not 0 (TODO)
+ assert len(df) == np.sum(df.values == 0) # distance to other is not 0
assert all(df.values.flatten() >= 0) # distance is non-negative

# (M3) Symmetry

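Note: the positivity assertion that loses its TODO above is terse: for an n x n pairwise distance matrix, requiring the number of zero entries to equal n means the only zeros sit on the diagonal, so distances between distinct groups are strictly positive. A toy illustration of the metric checks on a made-up 3 x 3 matrix:

    import numpy as np
    import pandas as pd

    df = pd.DataFrame(
        [[0.0, 1.2, 0.7],
         [1.2, 0.0, 0.9],
         [0.7, 0.9, 0.0]]
    )

    assert all(np.diag(df.values) == 0)         # (M1) identity: d(x, x) = 0
    assert len(df) == np.sum(df.values == 0)    # (M2) positivity: zeros only on the diagonal
    assert all(df.values.flatten() >= 0)        # distances are non-negative
    assert np.allclose(df.values, df.values.T)  # (M3) symmetry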