From 0c08bf5104339a997b5d3f86c206633053a244b1 Mon Sep 17 00:00:00 2001 From: Lilly Date: Mon, 9 Dec 2024 16:34:27 +0100 Subject: [PATCH 1/8] Added Mixscape seeds and test --- pertpy/tools/_mixscape.py | 16 +++++++++++---- tests/tools/test_mixscape.py | 38 ++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 4 deletions(-) diff --git a/pertpy/tools/_mixscape.py b/pertpy/tools/_mixscape.py index fddc54c2..30f466b9 100644 --- a/pertpy/tools/_mixscape.py +++ b/pertpy/tools/_mixscape.py @@ -44,6 +44,7 @@ def perturbation_signature( split_by: str | None = None, n_neighbors: int = 20, use_rep: str | None = None, + n_dims: int | None = 15, n_pcs: int | None = None, batch_size: int | None = None, copy: bool = False, @@ -66,7 +67,8 @@ def perturbation_signature( If `None`, the representation is chosen automatically: For `.n_vars` < 50, `.X` is used, otherwise 'X_pca' is used. If 'X_pca' is not present, it’s computed with default parameters. - n_pcs: Use this many PCs. If `n_pcs==0` use `.X` if `use_rep is None`. + n_dims: Number of dimensions to use from the representation to calculate the perturbation signature. If `None`, use all dimensions. + n_pcs: If PCA representation is used, the number of principal components to compute. If `n_pcs==0` use `.X` if `use_rep is None`. batch_size: Size of batch to calculate the perturbation signature. If 'None', the perturbation signature is calcuated in the full mode, requiring more memory. The batched mode is very inefficient for sparse data. @@ -99,6 +101,8 @@ def perturbation_signature( split_masks = [split_obs == cat for cat in split_obs.unique()] representation = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs) + if n_dims is not None and n_dims < representation.shape[1]: + representation = representation[:, :n_dims] for split_mask in split_masks: control_mask_split = control_mask & split_mask @@ -126,7 +130,7 @@ def perturbation_signature( shape=(n_split, n_control), ) neigh_matrix /= n_neighbors - adata.layers["X_pert"][split_mask] -= np.log1p(neigh_matrix @ X_control) + adata.layers["X_pert"][split_mask] = np.log1p(neigh_matrix @ X_control) - adata.layers["X_pert"][split_mask] else: is_sparse = issparse(X_control) split_indices = np.where(split_mask)[0] @@ -144,7 +148,7 @@ def perturbation_signature( means_batch = means_batch.toarray() if is_sparse else means_batch means_batch = means_batch.reshape(size, n_neighbors, -1).mean(1) - adata.layers["X_pert"][split_batch] -= np.log1p(means_batch) + adata.layers["X_pert"][split_batch] = np.log1p(means_batch) - adata.layers["X_pert"][split_batch] if copy: return adata @@ -162,11 +166,13 @@ def mixscape( split_by: str | None = None, pval_cutoff: float | None = 5e-2, perturbation_type: str | None = "KO", + random_state: int | None = 0, copy: bool | None = False, ): """Identify perturbed and non-perturbed gRNA expressing cells that accounts for multiple treatments/conditions/chemical perturbations. - The implementation resembles https://satijalab.org/seurat/reference/runmixscape + The implementation resembles https://satijalab.org/seurat/reference/runmixscape. Note that in the original implementation, the + perturbation signature is calculated on unscaled data by default and we therefore recommend to do the same. Args: adata: The annotated data object. @@ -181,6 +187,7 @@ def mixscape( the perturbation signature for every replicate separately. pval_cutoff: P-value cut-off for selection of significantly DE genes. perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications. + random_state: Random seed for the GaussianMixture model. copy: Determines whether a copy of the `adata` is returned. Returns: @@ -293,6 +300,7 @@ def mixscape( covariance_type="spherical", means_init=means_init, precisions_init=precisions_init, + random_state=random_state, ).fit(np.asarray(pvec).reshape(-1, 1)) probabilities = mm.predict_proba(np.array(pvec[orig_guide_cells_index]).reshape(-1, 1)) lik_ratio = probabilities[:, 0] / probabilities[:, 1] diff --git a/tests/tools/test_mixscape.py b/tests/tools/test_mixscape.py index ec9387c1..8f160669 100644 --- a/tests/tools/test_mixscape.py +++ b/tests/tools/test_mixscape.py @@ -98,3 +98,41 @@ def test_lda(adata): mixscape_identifier.lda(adata=adata, labels="gene_target", control="NT") assert "mixscape_lda" in adata.uns + +def test_deterministic_perturbation_signature(): + n_genes = 5 + n_cells_per_class = 50 + cell_classes = ["NT", "KO", "NP"] + groups = ["Group1", "Group2"] + + cell_classes_array = np.repeat(cell_classes, n_cells_per_class) + groups_array = np.tile(np.repeat(groups, n_cells_per_class // 2), len(cell_classes)) + obs = pd.DataFrame({"cell_class": cell_classes_array, "group": groups_array, + "perturbation": ["control" if cell_class == "NT" else "pert1" for cell_class in cell_classes_array]}) + + data = np.zeros((len(obs), n_genes)) + pert_effect = np.random.uniform(-1, 1, size=(n_cells_per_class//len(groups), n_genes)) + for group_idx, group in enumerate(groups): + baseline_expr = 2 if group == "Group1" else 10 + group_mask = obs["group"] == group + + nt_mask = (obs["cell_class"] == "NT") & group_mask + data[nt_mask] = baseline_expr + + ko_mask = (obs["cell_class"] == "KO") & group_mask + data[ko_mask] = baseline_expr + pert_effect + + np_mask = (obs["cell_class"] == "NP") & group_mask + data[np_mask] = baseline_expr + + var = pd.DataFrame(index=[f"Gene{i + 1}" for i in range(n_genes)]) + adata = anndata.AnnData(X=data, obs=obs, var=var) + + mixscape_identifier = pt.tl.Mixscape() + mixscape_identifier.perturbation_signature(adata, pert_key="perturbation", control="control", n_neighbors=5, split_by="group") + + assert "X_pert" in adata.layers + assert np.allclose(adata.layers["X_pert"][obs["cell_class"] == "NT"], 0) + assert np.allclose(adata.layers["X_pert"][obs["cell_class"] == "NP"], 0) + assert np.allclose(adata.layers["X_pert"][obs["cell_class"] == "KO"], -np.concatenate([pert_effect] * len(groups), axis=0)) + From 97083b4ea34ca429bfd6fe414e6a6c51d70ae87a Mon Sep 17 00:00:00 2001 From: Lilly Date: Sat, 4 Jan 2025 17:44:18 +0100 Subject: [PATCH 2/8] Added manual reference cell selection method --- pertpy/tools/_mixscape.py | 105 ++++++++++++++++++++--------------- tests/tools/test_mixscape.py | 11 ++++ 2 files changed, 72 insertions(+), 44 deletions(-) diff --git a/pertpy/tools/_mixscape.py b/pertpy/tools/_mixscape.py index 30f466b9..04ebf7ca 100644 --- a/pertpy/tools/_mixscape.py +++ b/pertpy/tools/_mixscape.py @@ -41,7 +41,9 @@ def perturbation_signature( adata: AnnData, pert_key: str, control: str, + ref_selection_mode: Literal["neighest_neighbors", "manual"] = "neighest_neighbors", split_by: str | None = None, + group_key: str | None = None, n_neighbors: int = 20, use_rep: str | None = None, n_dims: int | None = 15, @@ -52,14 +54,16 @@ def perturbation_signature( ): """Calculate perturbation signature. - For each cell, we identify `n_neighbors` cells from the control pool with the most similar mRNA expression profiles. The perturbation signature is calculated by subtracting the averaged mRNA expression profile of the control - neighbors from the mRNA expression profile of each cell. + cells (selected according to `ref_selection_mode`) from the mRNA expression profile of each cell. Args: adata: The annotated data object. pert_key: The column of `.obs` with perturbation categories, should also contain `control`. - control: Control category from the `pert_key` column. + control: Name of the control category from the `pert_key` column. + ref_selection_mode: Method to select reference cells for the perturbation signature calculation. If `neighest_neighbors`, + the `n_neighbors` cells from the control pool with the most similar mRNA expression profiles are selected. If `manual`, + the control cells from the same split in `split_by` (e.g. indicating biological replicates) are used to calculate the perturbation signature. split_by: Provide the column `.obs` if multiple biological replicates exist to calculate the perturbation signature for every replicate separately. n_neighbors: Number of neighbors from the control to use for the perturbation signature. @@ -87,6 +91,11 @@ def perturbation_signature( >>> ms_pt = pt.tl.Mixscape() >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate") """ + if ref_selection_mode not in ["neighest_neighbors", "manual"]: + raise ValueError("ref_selection_mode must be either 'neighest_neighbors' or 'manual'.") + if ref_selection_mode == "manual" and split_by is None: + raise ValueError("split_by must be provided if ref_selection_mode is 'manual'.") + if copy: adata = adata.copy() @@ -94,61 +103,69 @@ def perturbation_signature( control_mask = adata.obs[pert_key] == control - if split_by is None: - split_masks = [np.full(adata.n_obs, True, dtype=bool)] + if ref_selection_mode == "manual": + for split in adata.obs[split_by].unique(): + split_mask = adata.obs[split_by] == split + control_mask_group = control_mask & split_mask + control_mean_expr = adata.X[control_mask_group].mean(0) + adata.layers["X_pert"][split_mask] = np.repeat(control_mean_expr.reshape(1, -1), split_mask.sum(), axis=0) - adata.layers["X_pert"][split_mask] + else: - split_obs = adata.obs[split_by] - split_masks = [split_obs == cat for cat in split_obs.unique()] + if split_by is None: + split_masks = [np.full(adata.n_obs, True, dtype=bool)] + else: + split_obs = adata.obs[split_by] + split_masks = [split_obs == cat for cat in split_obs.unique()] - representation = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs) - if n_dims is not None and n_dims < representation.shape[1]: - representation = representation[:, :n_dims] + representation = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs) + if n_dims is not None and n_dims < representation.shape[1]: + representation = representation[:, :n_dims] - for split_mask in split_masks: - control_mask_split = control_mask & split_mask + for split_mask in split_masks: + control_mask_split = control_mask & split_mask - R_split = representation[split_mask] - R_control = representation[control_mask_split] + R_split = representation[split_mask] + R_control = representation[control_mask_split] - from pynndescent import NNDescent + from pynndescent import NNDescent - eps = kwargs.pop("epsilon", 0.1) - nn_index = NNDescent(R_control, **kwargs) - indices, _ = nn_index.query(R_split, k=n_neighbors, epsilon=eps) + eps = kwargs.pop("epsilon", 0.1) + nn_index = NNDescent(R_control, **kwargs) + indices, _ = nn_index.query(R_split, k=n_neighbors, epsilon=eps) - X_control = np.expm1(adata.X[control_mask_split]) + X_control = np.expm1(adata.X[control_mask_split]) - n_split = split_mask.sum() - n_control = X_control.shape[0] + n_split = split_mask.sum() + n_control = X_control.shape[0] - if batch_size is None: - col_indices = np.ravel(indices) - row_indices = np.repeat(np.arange(n_split), n_neighbors) + if batch_size is None: + col_indices = np.ravel(indices) + row_indices = np.repeat(np.arange(n_split), n_neighbors) - neigh_matrix = csr_matrix( - (np.ones_like(col_indices, dtype=np.float64), (row_indices, col_indices)), - shape=(n_split, n_control), - ) - neigh_matrix /= n_neighbors - adata.layers["X_pert"][split_mask] = np.log1p(neigh_matrix @ X_control) - adata.layers["X_pert"][split_mask] - else: - is_sparse = issparse(X_control) - split_indices = np.where(split_mask)[0] - for i in range(0, n_split, batch_size): - size = min(i + batch_size, n_split) - select = slice(i, size) + neigh_matrix = csr_matrix( + (np.ones_like(col_indices, dtype=np.float64), (row_indices, col_indices)), + shape=(n_split, n_control), + ) + neigh_matrix /= n_neighbors + adata.layers["X_pert"][split_mask] = np.log1p(neigh_matrix @ X_control) - adata.layers["X_pert"][split_mask] + else: + is_sparse = issparse(X_control) + split_indices = np.where(split_mask)[0] + for i in range(0, n_split, batch_size): + size = min(i + batch_size, n_split) + select = slice(i, size) - batch = np.ravel(indices[select]) - split_batch = split_indices[select] + batch = np.ravel(indices[select]) + split_batch = split_indices[select] - size = size - i + size = size - i - # sparse is very slow - means_batch = X_control[batch] - means_batch = means_batch.toarray() if is_sparse else means_batch - means_batch = means_batch.reshape(size, n_neighbors, -1).mean(1) + # sparse is very slow + means_batch = X_control[batch] + means_batch = means_batch.toarray() if is_sparse else means_batch + means_batch = means_batch.reshape(size, n_neighbors, -1).mean(1) - adata.layers["X_pert"][split_batch] = np.log1p(means_batch) - adata.layers["X_pert"][split_batch] + adata.layers["X_pert"][split_batch] = np.log1p(means_batch) - adata.layers["X_pert"][split_batch] if copy: return adata diff --git a/tests/tools/test_mixscape.py b/tests/tools/test_mixscape.py index 8f160669..c18fff12 100644 --- a/tests/tools/test_mixscape.py +++ b/tests/tools/test_mixscape.py @@ -99,6 +99,7 @@ def test_lda(adata): assert "mixscape_lda" in adata.uns + def test_deterministic_perturbation_signature(): n_genes = 5 n_cells_per_class = 50 @@ -136,3 +137,13 @@ def test_deterministic_perturbation_signature(): assert np.allclose(adata.layers["X_pert"][obs["cell_class"] == "NP"], 0) assert np.allclose(adata.layers["X_pert"][obs["cell_class"] == "KO"], -np.concatenate([pert_effect] * len(groups), axis=0)) + del adata.layers["X_pert"] + + mixscape_identifier = pt.tl.Mixscape() + mixscape_identifier.perturbation_signature(adata, pert_key="perturbation", control="control", ref_selection_mode="manual", split_by="group") + + assert "X_pert" in adata.layers + assert np.allclose(adata.layers["X_pert"][obs["cell_class"] == "NT"], 0) + assert np.allclose(adata.layers["X_pert"][obs["cell_class"] == "NP"], 0) + assert np.allclose(adata.layers["X_pert"][obs["cell_class"] == "KO"], -np.concatenate([pert_effect] * len(groups), axis=0)) + From ff8c894b201d424b35baff5a099f3c2bfcf96c06 Mon Sep 17 00:00:00 2001 From: Lilly Date: Fri, 10 Jan 2025 15:26:09 +0100 Subject: [PATCH 3/8] Updated docs method descriptions --- pertpy/tools/_mixscape.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/pertpy/tools/_mixscape.py b/pertpy/tools/_mixscape.py index 04ebf7ca..36109ee8 100644 --- a/pertpy/tools/_mixscape.py +++ b/pertpy/tools/_mixscape.py @@ -41,9 +41,8 @@ def perturbation_signature( adata: AnnData, pert_key: str, control: str, - ref_selection_mode: Literal["neighest_neighbors", "manual"] = "neighest_neighbors", + ref_selection_mode: Literal["nn", "manual"] = "nn", split_by: str | None = None, - group_key: str | None = None, n_neighbors: int = 20, use_rep: str | None = None, n_dims: int | None = 15, @@ -56,12 +55,14 @@ def perturbation_signature( The perturbation signature is calculated by subtracting the averaged mRNA expression profile of the control cells (selected according to `ref_selection_mode`) from the mRNA expression profile of each cell. + The implementation resembles https://satijalab.org/seurat/reference/runmixscape. Note that in the original implementation, the + perturbation signature is calculated on unscaled data by default and we therefore recommend to do the same. Args: adata: The annotated data object. pert_key: The column of `.obs` with perturbation categories, should also contain `control`. control: Name of the control category from the `pert_key` column. - ref_selection_mode: Method to select reference cells for the perturbation signature calculation. If `neighest_neighbors`, + ref_selection_mode: Method to select reference cells for the perturbation signature calculation. If `nn`, the `n_neighbors` cells from the control pool with the most similar mRNA expression profiles are selected. If `manual`, the control cells from the same split in `split_by` (e.g. indicating biological replicates) are used to calculate the perturbation signature. split_by: Provide the column `.obs` if multiple biological replicates exist to calculate @@ -89,10 +90,10 @@ def perturbation_signature( >>> import pertpy as pt >>> mdata = pt.dt.papalexi_2021() >>> ms_pt = pt.tl.Mixscape() - >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate") + >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate") """ - if ref_selection_mode not in ["neighest_neighbors", "manual"]: - raise ValueError("ref_selection_mode must be either 'neighest_neighbors' or 'manual'.") + if ref_selection_mode not in ["nn", "manual"]: + raise ValueError("ref_selection_mode must be either 'nn' or 'manual'.") if ref_selection_mode == "manual" and split_by is None: raise ValueError("split_by must be provided if ref_selection_mode is 'manual'.") @@ -188,8 +189,7 @@ def mixscape( ): """Identify perturbed and non-perturbed gRNA expressing cells that accounts for multiple treatments/conditions/chemical perturbations. - The implementation resembles https://satijalab.org/seurat/reference/runmixscape. Note that in the original implementation, the - perturbation signature is calculated on unscaled data by default and we therefore recommend to do the same. + The implementation resembles https://satijalab.org/seurat/reference/runmixscape. Args: adata: The annotated data object. From 4495864de1cacf0df2fe24fbff67c2ea1092d0f2 Mon Sep 17 00:00:00 2001 From: Lilly Date: Fri, 10 Jan 2025 15:54:29 +0100 Subject: [PATCH 4/8] Merged main into branch --- pertpy/tools/_mixscape.py | 21 +++++++++---------- tests/tools/test_mixscape.py | 40 ------------------------------------ 2 files changed, 10 insertions(+), 51 deletions(-) diff --git a/pertpy/tools/_mixscape.py b/pertpy/tools/_mixscape.py index 48512595..3c6df1a5 100644 --- a/pertpy/tools/_mixscape.py +++ b/pertpy/tools/_mixscape.py @@ -110,7 +110,6 @@ def perturbation_signature( control_mask_group = control_mask & split_mask control_mean_expr = adata.X[control_mask_group].mean(0) adata.layers["X_pert"][split_mask] = np.repeat(control_mean_expr.reshape(1, -1), split_mask.sum(), axis=0) - adata.layers["X_pert"][split_mask] - else: if split_by is None: split_masks = [np.full(adata.n_obs, True, dtype=bool)] @@ -125,8 +124,8 @@ def perturbation_signature( for split_mask in split_masks: control_mask_split = control_mask & split_mask - R_split = representation[split_mask] - R_control = representation[np.asarray(control_mask_split)] + R_split = representation[split_mask] + R_control = representation[np.asarray(control_mask_split)] from pynndescent import NNDescent @@ -134,7 +133,7 @@ def perturbation_signature( nn_index = NNDescent(R_control, **kwargs) indices, _ = nn_index.query(R_split, k=n_neighbors, epsilon=eps) - X_control = np.expm1(adata.X[np.asarray(control_mask_split)]) + X_control = np.expm1(adata.X[np.asarray(control_mask_split)]) n_split = split_mask.sum() n_control = X_control.shape[0] @@ -156,15 +155,15 @@ def perturbation_signature( size = min(i + batch_size, n_split) select = slice(i, size) - batch = np.ravel(indices[select]) - split_batch = split_indices[select] + batch = np.ravel(indices[select]) + split_batch = split_indices[select] - size = size - i + size = size - i - # sparse is very slow - means_batch = X_control[batch] - means_batch = means_batch.toarray() if is_sparse else means_batch - means_batch = means_batch.reshape(size, n_neighbors, -1).mean(1) + # sparse is very slow + means_batch = X_control[batch] + means_batch = means_batch.toarray() if is_sparse else means_batch + means_batch = means_batch.reshape(size, n_neighbors, -1).mean(1) adata.layers["X_pert"][split_batch] = np.log1p(means_batch) - adata.layers["X_pert"][split_batch] diff --git a/tests/tools/test_mixscape.py b/tests/tools/test_mixscape.py index 2f1defe0..c43c5ebd 100644 --- a/tests/tools/test_mixscape.py +++ b/tests/tools/test_mixscape.py @@ -98,46 +98,6 @@ def test_lda(adata): mixscape_identifier.lda(adata=adata, labels="gene_target", control="NT") assert "mixscape_lda" in adata.uns - -def test_deterministic_perturbation_signature(): - n_genes = 5 - n_cells_per_class = 50 - cell_classes = ["NT", "KO", "NP"] - groups = ["Group1", "Group2"] - - cell_classes_array = np.repeat(cell_classes, n_cells_per_class) - groups_array = np.tile(np.repeat(groups, n_cells_per_class // 2), len(cell_classes)) - obs = pd.DataFrame({"cell_class": cell_classes_array, "group": groups_array, - "perturbation": ["control" if cell_class == "NT" else "pert1" for cell_class in cell_classes_array]}) - - data = np.zeros((len(obs), n_genes)) - pert_effect = np.random.uniform(-1, 1, size=(n_cells_per_class//len(groups), n_genes)) - for group_idx, group in enumerate(groups): - baseline_expr = 2 if group == "Group1" else 10 - group_mask = obs["group"] == group - - nt_mask = (obs["cell_class"] == "NT") & group_mask - data[nt_mask] = baseline_expr - - ko_mask = (obs["cell_class"] == "KO") & group_mask - data[ko_mask] = baseline_expr + pert_effect - - np_mask = (obs["cell_class"] == "NP") & group_mask - data[np_mask] = baseline_expr - - var = pd.DataFrame(index=[f"Gene{i + 1}" for i in range(n_genes)]) - adata = anndata.AnnData(X=data, obs=obs, var=var) - - mixscape_identifier = pt.tl.Mixscape() - mixscape_identifier.perturbation_signature(adata, pert_key="perturbation", control="control", n_neighbors=5, split_by="group") - - assert "X_pert" in adata.layers - assert np.allclose(adata.layers["X_pert"][obs["cell_class"] == "NT"], 0) - assert np.allclose(adata.layers["X_pert"][obs["cell_class"] == "NP"], 0) - assert np.allclose(adata.layers["X_pert"][obs["cell_class"] == "KO"], -np.concatenate([pert_effect] * len(groups), axis=0)) - - - def test_deterministic_perturbation_signature(): n_genes = 5 n_cells_per_class = 50 From 0087dcdbf0fd6e99eef2b7fa1128923acdc5febe Mon Sep 17 00:00:00 2001 From: Lilly Date: Fri, 10 Jan 2025 16:08:26 +0100 Subject: [PATCH 5/8] Fixed merge conflict --- pertpy/tools/_mixscape.py | 46 +++++++++++++++++++-------------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/pertpy/tools/_mixscape.py b/pertpy/tools/_mixscape.py index 3c6df1a5..adc363e2 100644 --- a/pertpy/tools/_mixscape.py +++ b/pertpy/tools/_mixscape.py @@ -117,9 +117,9 @@ def perturbation_signature( split_obs = adata.obs[split_by] split_masks = [split_obs == cat for cat in split_obs.unique()] - representation = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs) - if n_dims is not None and n_dims < representation.shape[1]: - representation = representation[:, :n_dims] + representation = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs) + if n_dims is not None and n_dims < representation.shape[1]: + representation = representation[:, :n_dims] for split_mask in split_masks: control_mask_split = control_mask & split_mask @@ -142,30 +142,30 @@ def perturbation_signature( col_indices = np.ravel(indices) row_indices = np.repeat(np.arange(n_split), n_neighbors) - neigh_matrix = csr_matrix( - (np.ones_like(col_indices, dtype=np.float64), (row_indices, col_indices)), - shape=(n_split, n_control), - ) - neigh_matrix /= n_neighbors - adata.layers["X_pert"][split_mask] = np.log1p(neigh_matrix @ X_control) - adata.layers["X_pert"][split_mask] - else: - is_sparse = issparse(X_control) - split_indices = np.where(split_mask)[0] - for i in range(0, n_split, batch_size): - size = min(i + batch_size, n_split) - select = slice(i, size) + neigh_matrix = csr_matrix( + (np.ones_like(col_indices, dtype=np.float64), (row_indices, col_indices)), + shape=(n_split, n_control), + ) + neigh_matrix /= n_neighbors + adata.layers["X_pert"][split_mask] = np.log1p(neigh_matrix @ X_control) - adata.layers["X_pert"][split_mask] + else: + is_sparse = issparse(X_control) + split_indices = np.where(split_mask)[0] + for i in range(0, n_split, batch_size): + size = min(i + batch_size, n_split) + select = slice(i, size) - batch = np.ravel(indices[select]) - split_batch = split_indices[select] + batch = np.ravel(indices[select]) + split_batch = split_indices[select] - size = size - i + size = size - i - # sparse is very slow - means_batch = X_control[batch] - means_batch = means_batch.toarray() if is_sparse else means_batch - means_batch = means_batch.reshape(size, n_neighbors, -1).mean(1) + # sparse is very slow + means_batch = X_control[batch] + means_batch = means_batch.toarray() if is_sparse else means_batch + means_batch = means_batch.reshape(size, n_neighbors, -1).mean(1) - adata.layers["X_pert"][split_batch] = np.log1p(means_batch) - adata.layers["X_pert"][split_batch] + adata.layers["X_pert"][split_batch] = np.log1p(means_batch) - adata.layers["X_pert"][split_batch] if copy: return adata From 232f24394e514a98b0555519bf31f78d0577576e Mon Sep 17 00:00:00 2001 From: Lilly Date: Wed, 5 Feb 2025 17:12:39 +0100 Subject: [PATCH 6/8] Merged main --- pertpy/tools/_mixscape.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/pertpy/tools/_mixscape.py b/pertpy/tools/_mixscape.py index 24904ba3..6a5ba010 100644 --- a/pertpy/tools/_mixscape.py +++ b/pertpy/tools/_mixscape.py @@ -144,20 +144,20 @@ def perturbation_signature( col_indices = np.ravel(indices) row_indices = np.repeat(np.arange(n_split), n_neighbors) - neigh_matrix = csr_matrix( - (np.ones_like(col_indices, dtype=np.float64), (row_indices, col_indices)), - shape=(n_split, n_control), - ) - neigh_matrix /= n_neighbors - adata.layers["X_pert"][split_mask] = ( - np.log1p(neigh_matrix @ X_control) - adata.layers["X_pert"][split_mask] - ) - else: - is_sparse = issparse(X_control) - split_indices = np.where(split_mask)[0] - for i in range(0, n_split, batch_size): - size = min(i + batch_size, n_split) - select = slice(i, size) + neigh_matrix = csr_matrix( + (np.ones_like(col_indices, dtype=np.float64), (row_indices, col_indices)), + shape=(n_split, n_control), + ) + neigh_matrix /= n_neighbors + adata.layers["X_pert"][split_mask] = ( + np.log1p(neigh_matrix @ X_control) - adata.layers["X_pert"][split_mask] + ) + else: + is_sparse = issparse(X_control) + split_indices = np.where(split_mask)[0] + for i in range(0, n_split, batch_size): + size = min(i + batch_size, n_split) + select = slice(i, size) batch = np.ravel(indices[select]) split_batch = split_indices[select] From 9923dc8b65adcdf2c4c32b8fd6acfc463030668e Mon Sep 17 00:00:00 2001 From: Lilly Date: Wed, 5 Feb 2025 17:14:01 +0100 Subject: [PATCH 7/8] Added citation to docs --- docs/index.md | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/docs/index.md b/docs/index.md index 6f530605..58dd23b4 100644 --- a/docs/index.md +++ b/docs/index.md @@ -54,8 +54,12 @@ Discussions references ``` -- Consider citing [scanpy Genome Biology (2018)] along with original {doc}`references `. -- A paper for pertpy is in the works. +## Citation + +[Lukas Heumos, Yuge Ji, Lilly May, Tessa Green, Xinyue Zhang, Xichen Wu, Johannes Ostner, Stefan Peidli, Antonia Schumacher, Karin Hrovatin, Michaela Mueller, Faye Chong, Gregor Sturm, Alejandro Tejada, Emma Dann, Mingze Dong, Mojtaba Bahrami, Ilan Gold, Sergei Rybakov, Altana Namsaraeva, Amir Ali Moinfar, Zihe Zheng, Eljas Roellin, Isra Mekki, Chris Sander, Mohammad Lotfollahi, Herbert B. Schiller, Fabian J. Theis +bioRxiv 2024.08.04.606516; doi: https://doi.org/10.1101/2024.08.04.606516](https://www.biorxiv.org/content/10.1101/2024.08.04.606516v1) + +Consider citing [scanpy Genome Biology (2018)] along with the original {doc}`references `. # Indices and tables From 3adafe46106a4727298074034f2de4373b759873 Mon Sep 17 00:00:00 2001 From: Lilly Date: Fri, 7 Feb 2025 16:34:12 +0100 Subject: [PATCH 8/8] Renamed 'manual' selection mode to 'split_by' --- pertpy/tools/_mixscape.py | 14 +++++++------- tests/tools/test_mixscape.py | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pertpy/tools/_mixscape.py b/pertpy/tools/_mixscape.py index 6a5ba010..625a9bac 100644 --- a/pertpy/tools/_mixscape.py +++ b/pertpy/tools/_mixscape.py @@ -41,7 +41,7 @@ def perturbation_signature( adata: AnnData, pert_key: str, control: str, - ref_selection_mode: Literal["nn", "manual"] = "nn", + ref_selection_mode: Literal["nn", "split_by"] = "nn", split_by: str | None = None, n_neighbors: int = 20, use_rep: str | None = None, @@ -63,7 +63,7 @@ def perturbation_signature( pert_key: The column of `.obs` with perturbation categories, should also contain `control`. control: Name of the control category from the `pert_key` column. ref_selection_mode: Method to select reference cells for the perturbation signature calculation. If `nn`, - the `n_neighbors` cells from the control pool with the most similar mRNA expression profiles are selected. If `manual`, + the `n_neighbors` cells from the control pool with the most similar mRNA expression profiles are selected. If `split_by`, the control cells from the same split in `split_by` (e.g. indicating biological replicates) are used to calculate the perturbation signature. split_by: Provide the column `.obs` if multiple biological replicates exist to calculate the perturbation signature for every replicate separately. @@ -94,10 +94,10 @@ def perturbation_signature( >>> ms_pt = pt.tl.Mixscape() >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate") """ - if ref_selection_mode not in ["nn", "manual"]: - raise ValueError("ref_selection_mode must be either 'nn' or 'manual'.") - if ref_selection_mode == "manual" and split_by is None: - raise ValueError("split_by must be provided if ref_selection_mode is 'manual'.") + if ref_selection_mode not in ["nn", "split_by"]: + raise ValueError("ref_selection_mode must be either 'nn' or 'split_by'.") + if ref_selection_mode == "split_by" and split_by is None: + raise ValueError("split_by must be provided if ref_selection_mode is 'split_by'.") if copy: adata = adata.copy() @@ -106,7 +106,7 @@ def perturbation_signature( control_mask = adata.obs[pert_key] == control - if ref_selection_mode == "manual": + if ref_selection_mode == "split_by": for split in adata.obs[split_by].unique(): split_mask = adata.obs[split_by] == split control_mask_group = control_mask & split_mask diff --git a/tests/tools/test_mixscape.py b/tests/tools/test_mixscape.py index 40512c35..adff733f 100644 --- a/tests/tools/test_mixscape.py +++ b/tests/tools/test_mixscape.py @@ -149,7 +149,7 @@ def test_deterministic_perturbation_signature(): del adata.layers["X_pert"] mixscape_identifier = pt.tl.Mixscape() - mixscape_identifier.perturbation_signature(adata, pert_key="perturbation", control="control", ref_selection_mode="manual", split_by="group") + mixscape_identifier.perturbation_signature(adata, pert_key="perturbation", control="control", ref_selection_mode="split_by", split_by="group") assert "X_pert" in adata.layers assert np.allclose(adata.layers["X_pert"][obs["cell_class"] == "NT"], 0)