diff --git a/pertpy/tools/_perturbation_space/_simple.py b/pertpy/tools/_perturbation_space/_simple.py index 42649496..c0085b4b 100644 --- a/pertpy/tools/_perturbation_space/_simple.py +++ b/pertpy/tools/_perturbation_space/_simple.py @@ -116,6 +116,7 @@ def compute( self, adata: AnnData, target_col: str = "perturbation", + groups_col: str = None, layer_key: str = None, embedding_key: str = None, **kwargs, @@ -125,6 +126,8 @@ def compute( Args: adata: Anndata object of size cells x genes target_col: .obs column that stores the label of the perturbation applied to each cell. + groups_col: Optional .obs column that stores a grouping label to consider for pseudobulk computation. + The summarized expression per perturbation (target_col) and group (groups_col) is computed. Defaults to None. layer_key: If specified pseudobulk computation is done by using the specified layer. Otherwise, computation is done with .X embedding_key: `obsm` key of the AnnData embedding to use for computation. Defaults to the 'X' matrix otherwise. **kwargs: Are passed to decoupler's get_pseuobulk. @@ -136,11 +139,8 @@ def compute( >>> import pertpy as pt >>> mdata = pt.dt.papalexi_2021() >>> ps = pt.tl.PseudobulkSpace() - >>> ps_adata = ps.compute(mdata["rna"], target_col="gene_target", groups_col="gene_target") + >>> ps_adata = ps.compute(mdata["rna"], target_col="gene_target") """ - if "groups_col" not in kwargs: - kwargs["groups_col"] = "perturbation" - if layer_key is not None and embedding_key is not None: raise ValueError("Please, select just either layer or embedding for computation.") @@ -159,7 +159,8 @@ def compute( adata_emb.obs = adata.obs adata = adata_emb - ps_adata = dc.get_pseudobulk(adata, sample_col=target_col, layer=layer_key, **kwargs) # type: ignore + adata.obs[target_col] = adata.obs[target_col].astype("category") + ps_adata = dc.get_pseudobulk(adata, sample_col=target_col, layer=layer_key, groups_col=groups_col, **kwargs) # type: ignore ps_adata.obs[target_col] = ps_adata.obs[target_col].astype("category")