@@ -116,6 +116,7 @@ def compute(
116116        self ,
117117        adata : AnnData ,
118118        target_col : str  =  "perturbation" ,
119+         groups_col : str  =  None ,
119120        layer_key : str  =  None ,
120121        embedding_key : str  =  None ,
121122        ** kwargs ,
@@ -125,6 +126,8 @@ def compute(
125126        Args: 
126127            adata: AnnData object of size cells x genes 
127128            target_col: .obs column that stores the label of the perturbation applied to each cell. 
129+             groups_col: Optional .obs column that stores a grouping label to consider for pseudobulk computation. 
130+                 The summarized expression per perturbation (target_col) and group (groups_col) is computed. Defaults to None. 
128131            layer_key: If specified, pseudobulk computation is done by using the specified layer. Otherwise, computation is done with .X. 
129132            embedding_key: `obsm` key of the AnnData embedding to use for computation. Defaults to the 'X' matrix otherwise. 
130133            **kwargs: Are passed to decoupler's get_pseudobulk. 
@@ -136,11 +139,8 @@ def compute(
136139            >>> import pertpy as pt 
137140            >>> mdata = pt.dt.papalexi_2021() 
138141            >>> ps = pt.tl.PseudobulkSpace() 
139-             >>> ps_adata = ps.compute(mdata["rna"], target_col="gene_target", groups_col="gene_target" ) 
142+             >>> ps_adata = ps.compute(mdata["rna"], target_col="gene_target") 
140143        """ 
141-         if  "groups_col"  not  in kwargs :
142-             kwargs ["groups_col" ] =  "perturbation" 
143- 
144144        if  layer_key  is  not None  and  embedding_key  is  not None :
145145            raise  ValueError ("Please, select just either layer or embedding for computation." )
146146
@@ -159,7 +159,8 @@ def compute(
159159                adata_emb .obs  =  adata .obs 
160160                adata  =  adata_emb 
161161
162-         ps_adata  =  dc .get_pseudobulk (adata , sample_col = target_col , layer = layer_key , ** kwargs )  # type: ignore 
162+         adata .obs [target_col ] =  adata .obs [target_col ].astype ("category" )
163+         ps_adata  =  dc .get_pseudobulk (adata , sample_col = target_col , layer = layer_key , groups_col = groups_col , ** kwargs )  # type: ignore 
163164
164165        ps_adata .obs [target_col ] =  ps_adata .obs [target_col ].astype ("category" )
165166
0 commit comments