Skip to content

Commit 423063f

Browse files
authored
Add parameter explanation for groups_col (#529)
1 parent b8f7359 commit 423063f

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

pertpy/tools/_perturbation_space/_simple.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -116,6 +116,7 @@ def compute(
116116
self,
117117
adata: AnnData,
118118
target_col: str = "perturbation",
119+
groups_col: str = None,
119120
layer_key: str = None,
120121
embedding_key: str = None,
121122
**kwargs,
@@ -125,6 +126,8 @@ def compute(
125126
Args:
126127
adata: AnnData object of size cells x genes
127128
target_col: .obs column that stores the label of the perturbation applied to each cell.
129+
groups_col: Optional .obs column that stores a grouping label to consider for pseudobulk computation.
130+
The summarized expression per perturbation (target_col) and group (groups_col) is computed. Defaults to None.
128131
layer_key: If specified, pseudobulk computation is done by using the specified layer. Otherwise, computation is done with .X
129132
embedding_key: `obsm` key of the AnnData embedding to use for computation. Defaults to the 'X' matrix otherwise.
130133
**kwargs: Are passed to decoupler's get_pseudobulk.
@@ -136,11 +139,8 @@ def compute(
136139
>>> import pertpy as pt
137140
>>> mdata = pt.dt.papalexi_2021()
138141
>>> ps = pt.tl.PseudobulkSpace()
139-
>>> ps_adata = ps.compute(mdata["rna"], target_col="gene_target", groups_col="gene_target")
142+
>>> ps_adata = ps.compute(mdata["rna"], target_col="gene_target")
140143
"""
141-
if "groups_col" not in kwargs:
142-
kwargs["groups_col"] = "perturbation"
143-
144144
if layer_key is not None and embedding_key is not None:
145145
raise ValueError("Please, select just either layer or embedding for computation.")
146146

@@ -159,7 +159,8 @@ def compute(
159159
adata_emb.obs = adata.obs
160160
adata = adata_emb
161161

162-
ps_adata = dc.get_pseudobulk(adata, sample_col=target_col, layer=layer_key, **kwargs) # type: ignore
162+
adata.obs[target_col] = adata.obs[target_col].astype("category")
163+
ps_adata = dc.get_pseudobulk(adata, sample_col=target_col, layer=layer_key, groups_col=groups_col, **kwargs) # type: ignore
163164

164165
ps_adata.obs[target_col] = ps_adata.obs[target_col].astype("category")
165166

0 commit comments

Comments
 (0)