diff --git a/pertpy/tools/_mixscape.py b/pertpy/tools/_mixscape.py index f0416e94..19bac9bb 100644 --- a/pertpy/tools/_mixscape.py +++ b/pertpy/tools/_mixscape.py @@ -41,6 +41,7 @@ def perturbation_signature( adata: AnnData, pert_key: str, control: str, + *, ref_selection_mode: Literal["nn", "split_by"] = "nn", split_by: str | None = None, n_neighbors: int = 20, @@ -53,10 +54,10 @@ def perturbation_signature( ): """Calculate perturbation signature. - The perturbation signature is calculated by subtracting the averaged mRNA expression profile of the control - cells (selected according to `ref_selection_mode`) from the mRNA expression profile of each cell. + The perturbation signature is calculated by subtracting the mRNA expression profile of each cell from the averaged + mRNA expression profile of the control cells (selected according to `ref_selection_mode`). The implementation resembles https://satijalab.org/seurat/reference/runmixscape. Note that in the original implementation, the - perturbation signature is calculated on unscaled data by default and we therefore recommend to do the same. + perturbation signature is calculated on unscaled data by default, and we therefore recommend to do the same. Args: adata: The annotated data object. @@ -184,11 +185,15 @@ def mixscape( adata: AnnData, labels: str, control: str, + *, new_class_name: str | None = "mixscape_class", - min_de_genes: int | None = 5, layer: str | None = None, + min_de_genes: int | None = 5, logfc_threshold: float | None = 0.25, + de_layer: str | None = None, + test_method: str | None = "wilcoxon", iter_num: int | None = 10, + scale: bool | None = True, split_by: str | None = None, pval_cutoff: float | None = 5e-2, perturbation_type: str | None = "KO", @@ -202,14 +207,16 @@ def mixscape( Args: adata: The annotated data object. labels: The column of `.obs` with target gene labels. - control: Control category from the `pert_key` column. + control: Control category from the `labels` column. new_class_name: Name of mixscape classification to be stored in `.obs`. - min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells. layer: Key from adata.layers whose value will be used to perform tests on. Default is using `.layers["X_pert"]`. + min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells. logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells (default: 0.25). + de_layer: Layer to use for identifying differentially expressed genes. If `None`, adata.X is used. + test_method: Method to use for differential expression testing. iter_num: Number of normalmixEM iterations to run if convergence does not occur. - split_by: Provide the column `.obs` if multiple biological replicates exist to calculate - the perturbation signature for every replicate separately. + scale: Scale the data specified in `layer` before running the GaussianMixture model on it. + split_by: Provide `.obs` column with experimental condition/cell type annotation, if perturbations are condition/cell type-specific. pval_cutoff: P-value cut-off for selection of significantly DE genes. perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications. random_state: Random seed for the GaussianMixture model. @@ -235,8 +242,8 @@ def mixscape( >>> import pertpy as pt >>> mdata = pt.dt.papalexi_2021() >>> ms_pt = pt.tl.Mixscape() - >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate") - >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert") + >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate") + >>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert") """ if copy: adata = adata.copy() @@ -250,7 +257,16 @@ def mixscape( split_masks = [split_obs == category for category in categories] perturbation_markers = self._get_perturbation_markers( - adata, split_masks, categories, labels, control, layer, pval_cutoff, min_de_genes, logfc_threshold + adata=adata, + split_masks=split_masks, + categories=categories, + labels=labels, + control=control, + layer=de_layer, + pval_cutoff=pval_cutoff, + min_de_genes=min_de_genes, + logfc_threshold=logfc_threshold, + test_method=test_method, ) adata_comp = adata @@ -278,8 +294,8 @@ def mixscape( adata.obs[f"{new_class_name}_p_{perturbation_type.lower()}"] = 0.0 for split, split_mask in enumerate(split_masks): category = categories[split] - genes = list(set(adata[split_mask].obs[labels]).difference([control])) - for gene in genes: + gene_targets = list(set(adata[split_mask].obs[labels]).difference([control])) + for gene in gene_targets: post_prob = 0 orig_guide_cells = (adata.obs[labels] == gene) & split_mask orig_guide_cells_index = list(orig_guide_cells.index[orig_guide_cells]) @@ -288,28 +304,38 @@ def mixscape( if len(perturbation_markers[(category, gene)]) == 0: adata.obs.loc[orig_guide_cells, new_class_name] = f"{gene} NP" + else: de_genes = perturbation_markers[(category, gene)] de_genes_indices = self._get_column_indices(adata, list(de_genes)) + dat = X[np.asarray(all_cells)][:, de_genes_indices] + dat_cells = all_cells[all_cells].index + if scale: + dat = sc.pp.scale(dat) + converged = False n_iter = 0 - old_classes = adata.obs[labels][all_cells] + old_classes = adata.obs[new_class_name][all_cells] + while not converged and n_iter < iter_num: # Get all cells in current split&Gene - guide_cells = (adata.obs[labels] == gene) & split_mask + guide_cells = (adata.obs[new_class_name] == gene) & split_mask + # get average value for each gene over all selected cells # all cells in current split&Gene minus all NT cells in current split # Each row is for each cell, each column is for each gene, get mean for each column - vec = np.mean(X[np.asarray(guide_cells)][:, de_genes_indices], axis=0) - np.mean( - X[np.asarray(nt_cells)][:, de_genes_indices], axis=0 - ) + guide_cells_dat_idx = all_cells[all_cells].index.get_indexer(guide_cells[guide_cells].index) + nt_cells_dat_idx = all_cells[all_cells].index.get_indexer(nt_cells[nt_cells].index) + vec = np.mean(dat[guide_cells_dat_idx], axis=0) - np.mean(dat[nt_cells_dat_idx], axis=0) + # project cells onto the perturbation vector if isinstance(dat, spmatrix): - pvec = np.sum(np.multiply(dat.toarray(), vec), axis=1) / np.sum(np.multiply(vec, vec)) + pvec = np.dot(dat.toarray(), vec) / np.dot(vec, vec) else: - pvec = np.sum(np.multiply(dat, vec), axis=1) / np.sum(np.multiply(vec, vec)) + pvec = np.dot(dat, vec) / np.dot(vec, vec) pvec = pd.Series(np.asarray(pvec).flatten(), index=list(all_cells.index[all_cells])) + if n_iter == 0: gv = pd.DataFrame(columns=["pvec", labels]) gv["pvec"] = pvec @@ -319,20 +345,22 @@ def mixscape( gv_list[gene] = {} gv_list[gene][category] = gv - guide_norm = self._define_normal_mixscape(pvec[guide_cells]) - nt_norm = self._define_normal_mixscape(pvec[nt_cells]) - means_init = np.array([[nt_norm[0]], [guide_norm[0]]]) - precisions_init = np.array([nt_norm[1], guide_norm[1]]) - mm = GaussianMixture( + means_init = np.array([[pvec[nt_cells].mean()], [pvec[guide_cells].mean()]]) + std_init = np.array([pvec[nt_cells].std(), pvec[guide_cells].std()]) + mm = MixscapeGaussianMixture( n_components=2, covariance_type="spherical", means_init=means_init, - precisions_init=precisions_init, + precisions_init=1 / (std_init ** 2), random_state=random_state, + max_iter=5000, + fixed_means=[pvec[nt_cells].mean(), None], + fixed_covariances=[pvec[nt_cells].std() ** 2, None], ).fit(np.asarray(pvec).reshape(-1, 1)) probabilities = mm.predict_proba(np.array(pvec[orig_guide_cells_index]).reshape(-1, 1)) lik_ratio = probabilities[:, 0] / probabilities[:, 1] post_prob = 1 / (1 + lik_ratio) + # based on the posterior probability, assign cells to the two classes adata.obs.loc[ [orig_guide_cells_index[cell] for cell in np.where(post_prob > 0.5)[0]], new_class_name @@ -340,11 +368,13 @@ def mixscape( adata.obs.loc[ [orig_guide_cells_index[cell] for cell in np.where(post_prob <= 0.5)[0]], new_class_name ] = f"{gene} NP" + if sum(adata.obs[new_class_name][split_mask] == gene) < min_de_genes: adata.obs.loc[guide_cells, new_class_name] = "NP" converged = True if adata.obs[new_class_name][all_cells].equals(old_classes): converged = True + old_classes = adata.obs[new_class_name][all_cells] n_iter += 1 @@ -364,11 +394,13 @@ def lda( adata: AnnData, labels: str, control: str, + *, mixscape_class_global: str | None = "mixscape_class_global", layer: str | None = None, n_comps: int | None = 10, min_de_genes: int | None = 5, logfc_threshold: float | None = 0.25, + test_method: str | None = "wilcoxon", split_by: str | None = None, pval_cutoff: float | None = 5e-2, perturbation_type: str | None = "KO", @@ -381,12 +413,13 @@ def lda( labels: The column of `.obs` with target gene labels. control: Control category from the `pert_key` column. mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT). - layer: Key from `adata.layers` whose value will be used to perform tests on. + layer: Layer to use for identifying differentially expressed genes. If `None`, adata.X is used. control: Control category from the `pert_key` column. n_comps: Number of principal components to use. min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells. logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells. - split_by: Provide the column `.obs` if multiple biological replicates exist to calculate + test_method: Method to use for differential expression testing. + split_by: Provide `.obs` column with experimental condition/cell type annotation, if perturbations are condition/cell type-specific. pval_cutoff: P-value cut-off for selection of significantly DE genes. perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications. copy: Determines whether a copy of the `adata` is returned. @@ -404,9 +437,9 @@ def lda( >>> import pertpy as pt >>> mdata = pt.dt.papalexi_2021() >>> ms_pt = pt.tl.Mixscape() - >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate") - >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert") - >>> ms_pt.lda(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert") + >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate") + >>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert") + >>> ms_pt.lda(mdata["rna"], "gene_target", "NT") """ if copy: adata = adata.copy() @@ -422,9 +455,8 @@ def lda( categories = split_obs.unique() split_masks = [split_obs == category for category in categories] - mixscape_identifier = pt.tl.Mixscape() # determine gene sets across all splits/groups through differential gene expression - perturbation_markers = mixscape_identifier._get_perturbation_markers( + perturbation_markers = self._get_perturbation_markers( adata=adata, split_masks=split_masks, categories=categories, @@ -434,6 +466,7 @@ def lda( pval_cutoff=pval_cutoff, min_de_genes=min_de_genes, logfc_threshold=logfc_threshold, + test_method=test_method, ) adata_subset = adata[ (adata.obs[mixscape_class_global] == perturbation_type) | (adata.obs[mixscape_class_global] == control) @@ -479,12 +512,21 @@ def _get_perturbation_markers( pval_cutoff: float, min_de_genes: float, logfc_threshold: float, + test_method: str, ) -> dict[tuple, np.ndarray]: """Determine gene sets across all splits/groups through differential gene expression Args: adata: :class:`~anndata.AnnData` object - col_names: Column names to extract the indices for + split_masks: List of boolean masks for each split/group. + categories: List of split/group names. + labels: The column of `.obs` with target gene labels. + control: Control category from the `labels` column. + layer: Key from adata.layers whose value will be used to compare gene expression. + pval_cutoff: P-value cut-off for selection of significantly DE genes. + min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells. + logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells. + test_method: Method to use for differential expression testing. Returns: Set of column indices. @@ -493,21 +535,21 @@ def _get_perturbation_markers( for split, split_mask in enumerate(split_masks): category = categories[split] # get gene sets for each split - genes = list(set(adata[split_mask].obs[labels]).difference([control])) + gene_targets = list(set(adata[split_mask].obs[labels]).difference([control])) adata_split = adata[split_mask].copy() # find top DE genes between cells with targeting and non-targeting gRNAs sc.tl.rank_genes_groups( adata_split, layer=layer, groupby=labels, - groups=genes, + groups=gene_targets, reference=control, - method="t-test", + method=test_method, use_raw=False, ) - # get DE genes for each gene - for gene in genes: - logfc_threshold_mask = adata_split.uns["rank_genes_groups"]["logfoldchanges"][gene] >= logfc_threshold + # get DE genes for each target gene + for gene in gene_targets: + logfc_threshold_mask = np.abs(adata_split.uns["rank_genes_groups"]["logfoldchanges"][gene]) >= logfc_threshold de_genes = adata_split.uns["rank_genes_groups"]["names"][gene][logfc_threshold_mask] pvals_adj = adata_split.uns["rank_genes_groups"]["pvals_adj"][gene][logfc_threshold_mask] de_genes = de_genes[pvals_adj < pval_cutoff] @@ -528,20 +570,6 @@ def _get_column_indices(self, adata, col_names): return indices - def _define_normal_mixscape(self, X: np.ndarray | sparse.spmatrix | pd.DataFrame | None) -> list[float]: - """Calculates the mean and standard deviation of a matrix. - - Args: - X: The matrix to calculate the properties for. - - Returns: - Mean and standard deviation of the matrix. - """ - mu = X.mean() - sd = X.std() - - return [mu, sd] - @_doc_params(common_plot_args=doc_common_plot_args) def plot_barplot( # pragma: no cover self, @@ -581,8 +609,8 @@ def plot_barplot( # pragma: no cover >>> import pertpy as pt >>> mdata = pt.dt.papalexi_2021() >>> ms_pt = pt.tl.Mixscape() - >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate") - >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert") + >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate") + >>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert") >>> ms_pt.plot_barplot(mdata["rna"], guide_rna_column="NT") Preview: @@ -687,8 +715,8 @@ def plot_heatmap( # pragma: no cover >>> import pertpy as pt >>> mdata = pt.dt.papalexi_2021() >>> ms_pt = pt.tl.Mixscape() - >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate") - >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert") + >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate") + >>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert") >>> ms_pt.plot_heatmap( ... adata=mdata["rna"], labels="gene_target", target_gene="IFNGR2", layer="X_pert", control="NT" ... ) @@ -763,8 +791,8 @@ def plot_perturbscore( # pragma: no cover >>> import pertpy as pt >>> mdata = pt.dt.papalexi_2021() >>> ms_pt = pt.tl.Mixscape() - >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate") - >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert") + >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate") + >>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert") >>> ms_pt.plot_perturbscore(adata=mdata["rna"], labels="gene_target", target_gene="IFNGR2", color="orange") Preview: @@ -937,8 +965,8 @@ def plot_violin( # pragma: no cover >>> import pertpy as pt >>> mdata = pt.dt.papalexi_2021() >>> ms_pt = pt.tl.Mixscape() - >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate") - >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert") + >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate") + >>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert") >>> ms_pt.plot_violin( ... adata=mdata["rna"], target_gene_idents=["NT", "IFNGR2 NP", "IFNGR2 KO"], groupby="mixscape_class" ... ) @@ -1121,9 +1149,9 @@ def plot_lda( # pragma: no cover >>> import pertpy as pt >>> mdata = pt.dt.papalexi_2021() >>> ms_pt = pt.tl.Mixscape() - >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", "replicate") - >>> ms_pt.mixscape(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert") - >>> ms_pt.lda(adata=mdata["rna"], control="NT", labels="gene_target", layer="X_pert") + >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate") + >>> ms_pt.mixscape(mdata["rna"], "gene_target", "NT", layer="X_pert") + >>> ms_pt.lda(mdata["rna"], "gene_target", "NT", split_by="replicate") >>> ms_pt.plot_lda(adata=mdata["rna"], control="NT") Preview: @@ -1157,3 +1185,40 @@ def plot_lda( # pragma: no cover return fig plt.show() return None + +class MixscapeGaussianMixture(GaussianMixture): + def __init__( + self, + n_components: int, + fixed_means: Sequence[float] | None = None, + fixed_covariances: Sequence[float] | None = None, + **kwargs + ): + """Custom Gaussian Mixture Model where means and covariances can be fixed for specific components. + + Args: + n_components: Number of Gaussian components + fixed_means: Means to fix (use None for those that should be estimated) + fixed_covariances: Covariances to fix (use None for those that should be estimated) + kwargs: Additional arguments passed to scikit-learn's GaussianMixture + """ + super().__init__(n_components=n_components, **kwargs) + self.fixed_means = fixed_means + self.fixed_covariances = fixed_covariances + + def _m_step(self, X: np.ndarray, log_resp: np.ndarray): + """Modified M-step to respect fixed means and covariances.""" + super()._m_step(X, log_resp) + + if self.fixed_means is not None: + for i in range(self.n_components): + if self.fixed_means[i] is not None: + self.means_[i] = self.fixed_means[i] + + if self.fixed_covariances is not None: + for i in range(self.n_components): + if self.fixed_covariances[i] is not None: + self.covariances_[i] = self.fixed_covariances[i] + + return self + diff --git a/tests/tools/test_mixscape.py b/tests/tools/test_mixscape.py index bb50e2e1..02b71f91 100644 --- a/tests/tools/test_mixscape.py +++ b/tests/tools/test_mixscape.py @@ -7,6 +7,8 @@ import pytest from scipy import sparse +from pertpy.tools._mixscape import MixscapeGaussianMixture + CWD = Path(__file__).parent.resolve() # Random generate data settings @@ -49,7 +51,7 @@ def adata(): # obs for random AnnData gene_target = {"gene_target": ["NT"] * num_cells_per_group + ["target_gene_a"] * num_cells_per_group * 2} gene_target = pd.DataFrame(gene_target) - label = {"label": ["control", "treatment", "treatment"] * num_cells_per_group} + label = {"label": ["control"] * num_cells_per_group + ["treatment"] * num_cells_per_group* 2 } label = pd.DataFrame(label) obs = pd.concat([gene_target, label], axis=1) obs = obs.set_index(np.arange(num_cells_per_group * 3)) @@ -68,9 +70,9 @@ def adata(): def test_mixscape(adata): - mixscape_identifier = pt.tl.Mixscape() adata.layers["X_pert"] = adata.X - mixscape_identifier.mixscape(adata=adata, control="NT", labels="gene_target") + mixscape_identifier = pt.tl.Mixscape() + mixscape_identifier.mixscape(adata=adata, labels="gene_target", control="NT", test_method="t-test") np_result = adata.obs["mixscape_class_global"] == "NP" np_result_correct = np_result[num_cells_per_group : num_cells_per_group * 2] @@ -94,8 +96,8 @@ def test_perturbation_signature(adata): def test_lda(adata): adata.layers["X_pert"] = adata.X mixscape_identifier = pt.tl.Mixscape() - mixscape_identifier.mixscape(adata=adata, control="NT", labels="gene_target") - mixscape_identifier.lda(adata=adata, labels="gene_target", control="NT") + mixscape_identifier.mixscape(adata=adata, labels="gene_target", control="NT", test_method="t-test") + mixscape_identifier.lda(adata=adata, labels="gene_target", control="NT", test_method="t-test") assert "mixscape_lda" in adata.uns @@ -159,3 +161,16 @@ def test_deterministic_perturbation_signature(): assert np.allclose( adata.layers["X_pert"][obs["cell_class"] == "KO"], -np.concatenate([pert_effect] * len(groups), axis=0) ) + + +def test_mixscape_gaussian_mixture(): + X = np.random.rand(100) + + fixed_means = [0.2, None] + fixed_covariances = [None, 0.1] + + model = MixscapeGaussianMixture(n_components=2, fixed_means=fixed_means, fixed_covariances=fixed_covariances) + model.fit(X.reshape(-1, 1)) + + assert np.allclose(model.means_[0], fixed_means[0]) + assert np.allclose(model.covariances_[1], fixed_covariances[1])