Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add perturbation signature calculation from replicate control cells #695

Merged
merged 10 commits into from
Feb 7, 2025
8 changes: 6 additions & 2 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,12 @@ Discussions <https://github.com/scverse/pertpy/discussions>
references
```

- Consider citing [scanpy Genome Biology (2018)] along with original {doc}`references <references>`.
- A paper for pertpy is in the works.
## Citation

[Lukas Heumos, Yuge Ji, Lilly May, Tessa Green, Xinyue Zhang, Xichen Wu, Johannes Ostner, Stefan Peidli, Antonia Schumacher, Karin Hrovatin, Michaela Mueller, Faye Chong, Gregor Sturm, Alejandro Tejada, Emma Dann, Mingze Dong, Mojtaba Bahrami, Ilan Gold, Sergei Rybakov, Altana Namsaraeva, Amir Ali Moinfar, Zihe Zheng, Eljas Roellin, Isra Mekki, Chris Sander, Mohammad Lotfollahi, Herbert B. Schiller, Fabian J. Theis
bioRxiv 2024.08.04.606516; doi: https://doi.org/10.1101/2024.08.04.606516](https://www.biorxiv.org/content/10.1101/2024.08.04.606516v1)

Consider citing [scanpy Genome Biology (2018)] along with the original {doc}`references <references>`.

# Indices and tables

Expand Down
114 changes: 65 additions & 49 deletions pertpy/tools/_mixscape.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def perturbation_signature(
adata: AnnData,
pert_key: str,
control: str,
ref_selection_mode: Literal["nn", "manual"] = "nn",
split_by: str | None = None,
n_neighbors: int = 20,
use_rep: str | None = None,
Expand All @@ -52,14 +53,18 @@ def perturbation_signature(
):
"""Calculate perturbation signature.

For each cell, we identify `n_neighbors` cells from the control pool with the most similar mRNA expression profiles.
The perturbation signature is calculated by subtracting the averaged mRNA expression profile of the control
neighbors from the mRNA expression profile of each cell.
cells (selected according to `ref_selection_mode`) from the mRNA expression profile of each cell.
The implementation resembles https://satijalab.org/seurat/reference/runmixscape. Note that in the original implementation, the
perturbation signature is calculated on unscaled data by default and we therefore recommend to do the same.

Args:
adata: The annotated data object.
pert_key: The column of `.obs` with perturbation categories, should also contain `control`.
control: Control category from the `pert_key` column.
control: Name of the control category from the `pert_key` column.
ref_selection_mode: Method to select reference cells for the perturbation signature calculation. If `nn`,
the `n_neighbors` cells from the control pool with the most similar mRNA expression profiles are selected. If `manual`,
the control cells from the same split in `split_by` (e.g. indicating biological replicates) are used to calculate the perturbation signature.
split_by: Provide the column `.obs` if multiple biological replicates exist to calculate
the perturbation signature for every replicate separately.
n_neighbors: Number of neighbors from the control to use for the perturbation signature.
Expand Down Expand Up @@ -87,72 +92,84 @@ def perturbation_signature(
>>> import pertpy as pt
>>> mdata = pt.dt.papalexi_2021()
>>> ms_pt = pt.tl.Mixscape()
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate")
>>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
"""
if ref_selection_mode not in ["nn", "manual"]:
raise ValueError("ref_selection_mode must be either 'nn' or 'manual'.")
if ref_selection_mode == "manual" and split_by is None:
raise ValueError("split_by must be provided if ref_selection_mode is 'manual'.")

if copy:
adata = adata.copy()

adata.layers["X_pert"] = adata.X.copy()

control_mask = adata.obs[pert_key] == control

if split_by is None:
split_masks = [np.full(adata.n_obs, True, dtype=bool)]
if ref_selection_mode == "manual":
for split in adata.obs[split_by].unique():
split_mask = adata.obs[split_by] == split
control_mask_group = control_mask & split_mask
control_mean_expr = adata.X[control_mask_group].mean(0)
adata.layers["X_pert"][split_mask] = np.repeat(control_mean_expr.reshape(1, -1), split_mask.sum(), axis=0) - adata.layers["X_pert"][split_mask]
else:
split_obs = adata.obs[split_by]
split_masks = [split_obs == cat for cat in split_obs.unique()]
if split_by is None:
split_masks = [np.full(adata.n_obs, True, dtype=bool)]
else:
split_obs = adata.obs[split_by]
split_masks = [split_obs == cat for cat in split_obs.unique()]

representation = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs)
if n_dims is not None and n_dims < representation.shape[1]:
representation = representation[:, :n_dims]
representation = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs)
if n_dims is not None and n_dims < representation.shape[1]:
representation = representation[:, :n_dims]

for split_mask in split_masks:
control_mask_split = control_mask & split_mask
for split_mask in split_masks:
control_mask_split = control_mask & split_mask

R_split = representation[split_mask]
R_control = representation[np.asarray(control_mask_split)]
R_split = representation[split_mask]
R_control = representation[np.asarray(control_mask_split)]

from pynndescent import NNDescent
from pynndescent import NNDescent

eps = kwargs.pop("epsilon", 0.1)
nn_index = NNDescent(R_control, **kwargs)
indices, _ = nn_index.query(R_split, k=n_neighbors, epsilon=eps)
eps = kwargs.pop("epsilon", 0.1)
nn_index = NNDescent(R_control, **kwargs)
indices, _ = nn_index.query(R_split, k=n_neighbors, epsilon=eps)

X_control = np.expm1(adata.X[np.asarray(control_mask_split)])
X_control = np.expm1(adata.X[np.asarray(control_mask_split)])

n_split = split_mask.sum()
n_control = X_control.shape[0]
n_split = split_mask.sum()
n_control = X_control.shape[0]

if batch_size is None:
col_indices = np.ravel(indices)
row_indices = np.repeat(np.arange(n_split), n_neighbors)
if batch_size is None:
col_indices = np.ravel(indices)
row_indices = np.repeat(np.arange(n_split), n_neighbors)

neigh_matrix = csr_matrix(
(np.ones_like(col_indices, dtype=np.float64), (row_indices, col_indices)),
shape=(n_split, n_control),
)
neigh_matrix /= n_neighbors
adata.layers["X_pert"][split_mask] = (
np.log1p(neigh_matrix @ X_control) - adata.layers["X_pert"][split_mask]
)
else:
is_sparse = issparse(X_control)
split_indices = np.where(split_mask)[0]
for i in range(0, n_split, batch_size):
size = min(i + batch_size, n_split)
select = slice(i, size)
neigh_matrix = csr_matrix(
(np.ones_like(col_indices, dtype=np.float64), (row_indices, col_indices)),
shape=(n_split, n_control),
)
neigh_matrix /= n_neighbors
adata.layers["X_pert"][split_mask] = (
np.log1p(neigh_matrix @ X_control) - adata.layers["X_pert"][split_mask]
)
else:
is_sparse = issparse(X_control)
split_indices = np.where(split_mask)[0]
for i in range(0, n_split, batch_size):
size = min(i + batch_size, n_split)
select = slice(i, size)

batch = np.ravel(indices[select])
split_batch = split_indices[select]
batch = np.ravel(indices[select])
split_batch = split_indices[select]

size = size - i
size = size - i

# sparse is very slow
means_batch = X_control[batch]
means_batch = means_batch.toarray() if is_sparse else means_batch
means_batch = means_batch.reshape(size, n_neighbors, -1).mean(1)
# sparse is very slow
means_batch = X_control[batch]
means_batch = means_batch.toarray() if is_sparse else means_batch
means_batch = means_batch.reshape(size, n_neighbors, -1).mean(1)

adata.layers["X_pert"][split_batch] = np.log1p(means_batch) - adata.layers["X_pert"][split_batch]
adata.layers["X_pert"][split_batch] = np.log1p(means_batch) - adata.layers["X_pert"][split_batch]

if copy:
return adata
Expand All @@ -175,8 +192,7 @@ def mixscape(
):
"""Identify perturbed and non-perturbed gRNA expressing cells that accounts for multiple treatments/conditions/chemical perturbations.

The implementation resembles https://satijalab.org/seurat/reference/runmixscape. Note that in the original implementation, the
perturbation signature is calculated on unscaled data by default and we therefore recommend to do the same.
The implementation resembles https://satijalab.org/seurat/reference/runmixscape.

Args:
adata: The annotated data object.
Expand Down
11 changes: 11 additions & 0 deletions tests/tools/test_mixscape.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,14 @@ def test_deterministic_perturbation_signature():
assert np.allclose(
adata.layers["X_pert"][obs["cell_class"] == "KO"], -np.concatenate([pert_effect] * len(groups), axis=0)
)

del adata.layers["X_pert"]

mixscape_identifier = pt.tl.Mixscape()
mixscape_identifier.perturbation_signature(adata, pert_key="perturbation", control="control", ref_selection_mode="manual", split_by="group")

assert "X_pert" in adata.layers
assert np.allclose(adata.layers["X_pert"][obs["cell_class"] == "NT"], 0)
assert np.allclose(adata.layers["X_pert"][obs["cell_class"] == "NP"], 0)
assert np.allclose(adata.layers["X_pert"][obs["cell_class"] == "KO"], -np.concatenate([pert_effect] * len(groups), axis=0))

Loading