diff --git a/docs/index.md b/docs/index.md index 6f530605..58dd23b4 100644 --- a/docs/index.md +++ b/docs/index.md @@ -54,8 +54,12 @@ Discussions references ``` -- Consider citing [scanpy Genome Biology (2018)] along with original {doc}`references `. -- A paper for pertpy is in the works. +## Citation + +[Lukas Heumos, Yuge Ji, Lilly May, Tessa Green, Xinyue Zhang, Xichen Wu, Johannes Ostner, Stefan Peidli, Antonia Schumacher, Karin Hrovatin, Michaela Mueller, Faye Chong, Gregor Sturm, Alejandro Tejada, Emma Dann, Mingze Dong, Mojtaba Bahrami, Ilan Gold, Sergei Rybakov, Altana Namsaraeva, Amir Ali Moinfar, Zihe Zheng, Eljas Roellin, Isra Mekki, Chris Sander, Mohammad Lotfollahi, Herbert B. Schiller, Fabian J. Theis +bioRxiv 2024.08.04.606516; doi: https://doi.org/10.1101/2024.08.04.606516](https://www.biorxiv.org/content/10.1101/2024.08.04.606516v1) + +Consider citing [scanpy Genome Biology (2018)] along with the original {doc}`references `. # Indices and tables diff --git a/pertpy/tools/_mixscape.py b/pertpy/tools/_mixscape.py index e906a034..625a9bac 100644 --- a/pertpy/tools/_mixscape.py +++ b/pertpy/tools/_mixscape.py @@ -41,6 +41,7 @@ def perturbation_signature( adata: AnnData, pert_key: str, control: str, + ref_selection_mode: Literal["nn", "split_by"] = "nn", split_by: str | None = None, n_neighbors: int = 20, use_rep: str | None = None, @@ -52,14 +53,18 @@ def perturbation_signature( ): """Calculate perturbation signature. - For each cell, we identify `n_neighbors` cells from the control pool with the most similar mRNA expression profiles. The perturbation signature is calculated by subtracting the averaged mRNA expression profile of the control - neighbors from the mRNA expression profile of each cell. + cells (selected according to `ref_selection_mode`) from the mRNA expression profile of each cell. + The implementation resembles https://satijalab.org/seurat/reference/runmixscape. Note that in the original implementation, the + perturbation signature is calculated on unscaled data by default and we therefore recommend to do the same. Args: adata: The annotated data object. pert_key: The column of `.obs` with perturbation categories, should also contain `control`. - control: Control category from the `pert_key` column. + control: Name of the control category from the `pert_key` column. + ref_selection_mode: Method to select reference cells for the perturbation signature calculation. If `nn`, + the `n_neighbors` cells from the control pool with the most similar mRNA expression profiles are selected. If `split_by`, + the control cells from the same split in `split_by` (e.g. indicating biological replicates) are used to calculate the perturbation signature. split_by: Provide the column `.obs` if multiple biological replicates exist to calculate the perturbation signature for every replicate separately. n_neighbors: Number of neighbors from the control to use for the perturbation signature. @@ -87,8 +92,13 @@ def perturbation_signature( >>> import pertpy as pt >>> mdata = pt.dt.papalexi_2021() >>> ms_pt = pt.tl.Mixscape() - >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate") + >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate") """ + if ref_selection_mode not in ["nn", "split_by"]: + raise ValueError("ref_selection_mode must be either 'nn' or 'split_by'.") + if ref_selection_mode == "split_by" and split_by is None: + raise ValueError("split_by must be provided if ref_selection_mode is 'split_by'.") + if copy: adata = adata.copy() @@ -96,63 +106,70 @@ def perturbation_signature( control_mask = adata.obs[pert_key] == control - if split_by is None: - split_masks = [np.full(adata.n_obs, True, dtype=bool)] + if ref_selection_mode == "split_by": + for split in adata.obs[split_by].unique(): + split_mask = adata.obs[split_by] == split + control_mask_group = control_mask & split_mask + control_mean_expr = adata.X[control_mask_group].mean(0) + adata.layers["X_pert"][split_mask] = np.repeat(control_mean_expr.reshape(1, -1), split_mask.sum(), axis=0) - adata.layers["X_pert"][split_mask] else: - split_obs = adata.obs[split_by] - split_masks = [split_obs == cat for cat in split_obs.unique()] + if split_by is None: + split_masks = [np.full(adata.n_obs, True, dtype=bool)] + else: + split_obs = adata.obs[split_by] + split_masks = [split_obs == cat for cat in split_obs.unique()] - representation = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs) - if n_dims is not None and n_dims < representation.shape[1]: - representation = representation[:, :n_dims] + representation = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs) + if n_dims is not None and n_dims < representation.shape[1]: + representation = representation[:, :n_dims] - for split_mask in split_masks: - control_mask_split = control_mask & split_mask + for split_mask in split_masks: + control_mask_split = control_mask & split_mask - R_split = representation[split_mask] - R_control = representation[np.asarray(control_mask_split)] + R_split = representation[split_mask] + R_control = representation[np.asarray(control_mask_split)] - from pynndescent import NNDescent + from pynndescent import NNDescent - eps = kwargs.pop("epsilon", 0.1) - nn_index = NNDescent(R_control, **kwargs) - indices, _ = nn_index.query(R_split, k=n_neighbors, epsilon=eps) + eps = kwargs.pop("epsilon", 0.1) + nn_index = NNDescent(R_control, **kwargs) + indices, _ = nn_index.query(R_split, k=n_neighbors, epsilon=eps) - X_control = np.expm1(adata.X[np.asarray(control_mask_split)]) + X_control = np.expm1(adata.X[np.asarray(control_mask_split)]) - n_split = split_mask.sum() - n_control = X_control.shape[0] + n_split = split_mask.sum() + n_control = X_control.shape[0] - if batch_size is None: - col_indices = np.ravel(indices) - row_indices = np.repeat(np.arange(n_split), n_neighbors) + if batch_size is None: + col_indices = np.ravel(indices) + row_indices = np.repeat(np.arange(n_split), n_neighbors) - neigh_matrix = csr_matrix( - (np.ones_like(col_indices, dtype=np.float64), (row_indices, col_indices)), - shape=(n_split, n_control), - ) - neigh_matrix /= n_neighbors - adata.layers["X_pert"][split_mask] = ( - np.log1p(neigh_matrix @ X_control) - adata.layers["X_pert"][split_mask] - ) - else: - is_sparse = issparse(X_control) - split_indices = np.where(split_mask)[0] - for i in range(0, n_split, batch_size): - size = min(i + batch_size, n_split) - select = slice(i, size) + neigh_matrix = csr_matrix( + (np.ones_like(col_indices, dtype=np.float64), (row_indices, col_indices)), + shape=(n_split, n_control), + ) + neigh_matrix /= n_neighbors + adata.layers["X_pert"][split_mask] = ( + np.log1p(neigh_matrix @ X_control) - adata.layers["X_pert"][split_mask] + ) + else: + is_sparse = issparse(X_control) + split_indices = np.where(split_mask)[0] + for i in range(0, n_split, batch_size): + size = min(i + batch_size, n_split) + select = slice(i, size) - batch = np.ravel(indices[select]) - split_batch = split_indices[select] + batch = np.ravel(indices[select]) + split_batch = split_indices[select] - size = size - i + size = size - i - # sparse is very slow - means_batch = X_control[batch] - means_batch = means_batch.toarray() if is_sparse else means_batch - means_batch = means_batch.reshape(size, n_neighbors, -1).mean(1) + # sparse is very slow + means_batch = X_control[batch] + means_batch = means_batch.toarray() if is_sparse else means_batch + means_batch = means_batch.reshape(size, n_neighbors, -1).mean(1) - adata.layers["X_pert"][split_batch] = np.log1p(means_batch) - adata.layers["X_pert"][split_batch] + adata.layers["X_pert"][split_batch] = np.log1p(means_batch) - adata.layers["X_pert"][split_batch] if copy: return adata @@ -175,8 +192,7 @@ def mixscape( ): """Identify perturbed and non-perturbed gRNA expressing cells that accounts for multiple treatments/conditions/chemical perturbations. - The implementation resembles https://satijalab.org/seurat/reference/runmixscape. Note that in the original implementation, the - perturbation signature is calculated on unscaled data by default and we therefore recommend to do the same. + The implementation resembles https://satijalab.org/seurat/reference/runmixscape. Args: adata: The annotated data object. diff --git a/tests/tools/test_mixscape.py b/tests/tools/test_mixscape.py index 83616729..adff733f 100644 --- a/tests/tools/test_mixscape.py +++ b/tests/tools/test_mixscape.py @@ -145,3 +145,14 @@ def test_deterministic_perturbation_signature(): assert np.allclose( adata.layers["X_pert"][obs["cell_class"] == "KO"], -np.concatenate([pert_effect] * len(groups), axis=0) ) + + del adata.layers["X_pert"] + + mixscape_identifier = pt.tl.Mixscape() + mixscape_identifier.perturbation_signature(adata, pert_key="perturbation", control="control", ref_selection_mode="split_by", split_by="group") + + assert "X_pert" in adata.layers + assert np.allclose(adata.layers["X_pert"][obs["cell_class"] == "NT"], 0) + assert np.allclose(adata.layers["X_pert"][obs["cell_class"] == "NP"], 0) + assert np.allclose(adata.layers["X_pert"][obs["cell_class"] == "KO"], -np.concatenate([pert_effect] * len(groups), axis=0)) +