Skip to content

Commit c95bbcb

Browse files
committed
Refactor mixscape
Signed-off-by: zethson <[email protected]>
1 parent 5c784ac commit c95bbcb

File tree

1 file changed

+70
-69
lines changed

1 file changed

+70
-69
lines changed

pertpy/tools/_mixscape.py

Lines changed: 70 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -74,15 +74,15 @@ def perturbation_signature(
7474
7575
Returns:
7676
If `copy=True`, returns the copy of `adata` with the perturbation signature in `.layers["X_pert"]`.
77-
Otherwise writes the perturbation signature directly to `.layers["X_pert"]` of the provided `adata`.
77+
Otherwise, writes the perturbation signature directly to `.layers["X_pert"]` of the provided `adata`.
7878
7979
Examples:
8080
Calculate perturbation signature for each cell in the dataset:
8181
8282
>>> import pertpy as pt
8383
>>> mdata = pt.dt.papalexi_2021()
84-
>>> mixscape_identifier = pt.tl.Mixscape()
85-
>>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
84+
>>> ms_pt = pt.tl.Mixscape()
85+
>>> ms_pt.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
8686
"""
8787
if copy:
8888
adata = adata.copy()
@@ -95,18 +95,17 @@ def perturbation_signature(
9595
split_masks = [np.full(adata.n_obs, True, dtype=bool)]
9696
else:
9797
split_obs = adata.obs[split_by]
98-
cats = split_obs.unique()
99-
split_masks = [split_obs == cat for cat in cats]
98+
split_masks = [split_obs == cat for cat in split_obs.unique()]
10099

101-
R = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs)
100+
representation = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs)
102101

103102
for split_mask in split_masks:
104103
control_mask_split = control_mask & split_mask
105104

106-
R_split = R[split_mask]
107-
R_control = R[control_mask_split]
105+
R_split = representation[split_mask]
106+
R_control = representation[control_mask_split]
108107

109-
from pynndescent import NNDescent # saves a lot of import time
108+
from pynndescent import NNDescent
110109

111110
eps = kwargs.pop("epsilon", 0.1)
112111
nn_index = NNDescent(R_control, **kwargs)
@@ -170,7 +169,6 @@ def mixscape(
170169
171170
Args:
172171
adata: The annotated data object.
173-
pert_key: The column of `.obs` with perturbation categories, should also contain `control`.
174172
labels: The column of `.obs` with target gene labels.
175173
control: Control category from the `pert_key` column.
176174
new_class_name: Name of mixscape classification to be stored in `.obs`.
@@ -186,26 +184,26 @@ def mixscape(
186184
187185
Returns:
188186
If `copy=True`, returns the copy of `adata` with the classification result in `.obs`.
189-
Otherwise writes the results directly to `.obs` of the provided `adata`.
187+
Otherwise, writes the results directly to `.obs` of the provided `adata`.
190188
191-
mixscape_class: pandas.Series (`adata.obs['mixscape_class']`).
192-
Classification result with cells being either classified as perturbed (KO, by default) or non-perturbed (NP) based on their target gene class.
189+
- mixscape_class: pandas.Series (`adata.obs['mixscape_class']`).
190+
Classification result with cells being either classified as perturbed (KO, by default) or non-perturbed (NP) based on their target gene class.
193191
194-
mixscape_class_global: pandas.Series (`adata.obs['mixscape_class_global']`).
195-
Global classification result (perturbed, NP or NT)
192+
- mixscape_class_global: pandas.Series (`adata.obs['mixscape_class_global']`).
193+
Global classification result (perturbed, NP or NT).
196194
197-
mixscape_class_p_ko: pandas.Series (`adata.obs['mixscape_class_p_ko']`).
198-
Posterior probabilities used to determine if a cell is KO (default).
199-
Name of this item will change to match perturbation_type parameter setting. (>0.5) or NP
195+
- mixscape_class_p_ko: pandas.Series (`adata.obs['mixscape_class_p_ko']`).
196+
Posterior probabilities used to determine if a cell is KO (default).
197+
Name of this item will change to match perturbation_type parameter setting. (>0.5) or NP.
200198
201199
Examples:
202200
Calculate perturbation signature for each cell in the dataset:
203201
204202
>>> import pertpy as pt
205203
>>> mdata = pt.dt.papalexi_2021()
206-
>>> mixscape_identifier = pt.tl.Mixscape()
207-
>>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
208-
>>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
204+
>>> ms_pt = pt.tl.Mixscape()
205+
>>> ms_pt.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
206+
>>> ms_pt.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
209207
"""
210208
if copy:
211209
adata = adata.copy()
@@ -229,10 +227,9 @@ def mixscape(
229227
try:
230228
X = adata_comp.layers["X_pert"]
231229
except KeyError:
232-
print(
233-
'[bold yellow]No "X_pert" found in .layers! -- Please run pert_sign first to calculate perturbation signature!'
234-
)
235-
raise
230+
raise KeyError(
231+
"No 'X_pert' found in .layers! Please run pert_sign first to calculate perturbation signature!"
232+
) from None
236233
# initialize return variables
237234
adata.obs[f"{new_class_name}_p_{perturbation_type.lower()}"] = 0
238235
adata.obs[new_class_name] = adata.obs[labels].astype(str)
@@ -351,15 +348,17 @@ def lda(
351348
control: Control category from the `pert_key` column. Defaults to 'NT'.
352349
n_comps: Number of principal components to use. Defaults to 10.
353350
min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
354-
logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells. Defaults to 0.25.
351+
logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells.
352+
Defaults to 0.25.
355353
split_by: Provide the column `.obs` if multiple biological replicates exist to calculate the perturbation signature for every replicate separately.
356354
pval_cutoff: P-value cut-off for selection of significantly DE genes.
357-
perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications. Defaults to KO.
355+
perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
356+
Defaults to KO.
358357
copy: Determines whether a copy of the `adata` is returned.
359358
360359
Returns:
361360
If `copy=True`, returns the copy of `adata` with the LDA result in `.uns`.
362-
Otherwise writes the results directly to `.uns` of the provided `adata`.
361+
Otherwise, writes the results directly to `.uns` of the provided `adata`.
363362
364363
mixscape_lda: numpy.ndarray (`adata.uns['mixscape_lda']`).
365364
LDA result.
@@ -369,10 +368,10 @@ def lda(
369368
370369
>>> import pertpy as pt
371370
>>> mdata = pt.dt.papalexi_2021()
372-
>>> mixscape_identifier = pt.tl.Mixscape()
373-
>>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
374-
>>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
375-
>>> mixscape_identifier.lda(adata=mdata['rna'], control='NT', labels='gene_target', layer='X_pert')
371+
>>> ms_pt = pt.tl.Mixscape()
372+
>>> ms_pt.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
373+
>>> ms_pt.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
374+
>>> ms_pt.lda(adata=mdata['rna'], control='NT', labels='gene_target', layer='X_pert')
376375
"""
377376
if copy:
378377
adata = adata.copy()
@@ -524,12 +523,12 @@ def plot_barplot( # pragma: no cover
524523
show: bool | None = None,
525524
save: bool | str | None = None,
526525
):
527-
"""Barplot to visualize perturbation scores calculated from RunMixscape function.
526+
"""Barplot to visualize perturbation scores calculated by the `mixscape` function.
528527
529528
Args:
530529
adata: The annotated data object.
531530
guide_rna_column: The column of `.obs` with guide RNA labels. The target gene labels.
532-
The format must be <gene_target>g<#>. For example, 'STAT2g1' and 'ATF2g1'.
531+
The format must be <gene_target>g<#>. Examples are 'STAT2g1' and 'ATF2g1'.
533532
mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
534533
show: Show the plot, do not return axis.
535534
save: If True or a str, save the figure. A string is appended to the default filename.
@@ -541,13 +540,13 @@ def plot_barplot( # pragma: no cover
541540
Examples:
542541
>>> import pertpy as pt
543542
>>> mdata = pt.dt.papalexi_2021()
544-
>>> mixscape_identifier = pt.tl.Mixscape()
545-
>>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
546-
>>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
547-
>>> pt.pl.ms.barplot(mdata['rna'], guide_rna_column='NT')
543+
>>> ms_pt = pt.tl.Mixscape()
544+
>>> ms_pt.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
545+
>>> ms_pt.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
546+
>>> ms_pt.plot_barplot(mdata['rna'], guide_rna_column='NT')
548547
"""
549548
if mixscape_class_global not in adata.obs:
550-
raise ValueError("Please run `pt.tl.mixscape` first.")
549+
raise ValueError("Please run the `mixscape` function first.")
551550
count = pd.crosstab(index=adata.obs[mixscape_class_global], columns=adata.obs[guide_rna_column])
552551
all_cells_percentage = pd.melt(count / count.sum(), ignore_index=False).reset_index()
553552
KO_cells_percentage = all_cells_percentage[all_cells_percentage[mixscape_class_global] == "KO"]
@@ -604,7 +603,7 @@ def plot_barplot( # pragma: no cover
604603
)
605604
pl.tight_layout()
606605

607-
return pl.gcf()
606+
return ax
608607

609608
def plot_heatmap( # pragma: no cover
610609
self,
@@ -642,10 +641,10 @@ def plot_heatmap( # pragma: no cover
642641
Examples:
643642
>>> import pertpy as pt
644643
>>> mdata = pt.dt.papalexi_2021()
645-
>>> ms = pt.tl.Mixscape()
646-
>>> ms.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
647-
>>> ms.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
648-
>>> ms.plot_heatmap(adata = mdata['rna'], labels='gene_target', target_gene='IFNGR2', layer='X_pert', control='NT')
644+
>>> ms_pt = pt.tl.Mixscape()
645+
>>> ms_pt.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
646+
>>> ms_pt.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
647+
>>> ms_pt.plot_heatmap(adata = mdata['rna'], labels='gene_target', target_gene='IFNGR2', layer='X_pert', control='NT')
649648
"""
650649
if "mixscape_class" not in adata.obs:
651650
raise ValueError("Please run `pt.tl.mixscape` first.")
@@ -676,8 +675,10 @@ def plot_perturbscore( # pragma: no cover
676675
split_by: str = None,
677676
before_mixscape=False,
678677
perturbation_type: str = "KO",
679-
):
680-
"""Density plots to visualize perturbation scores calculated by the `pt.tl.mixscape` function. Requires `pt.tl.mixscape` to be run first.
678+
) -> None:
679+
"""Density plots to visualize perturbation scores calculated by the `pt.tl.mixscape` function.
680+
681+
Requires `pt.tl.mixscape` to be run first.
681682
682683
https://satijalab.org/seurat/reference/plotperturbscore
683684
@@ -688,9 +689,10 @@ def plot_perturbscore( # pragma: no cover
688689
mixscape_class: The column of `.obs` with mixscape classifications.
689690
color: Specify color of target gene class or knockout cell class. For control non-targeting and non-perturbed cells, colors are set to different shades of grey.
690691
split_by: Provide the column `.obs` if multiple biological replicates exist to calculate
691-
the perturbation signature for every replicate separately.
692-
before_mixscape: Option to split densities based on mixscape classification (default) or original target gene classification. Default is set to NULL and plots cells by original class ID.
693-
perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications. Default is KO.
692+
the perturbation signature for every replicate separately.
693+
before_mixscape: Option to split densities based on mixscape classification (default) or original target gene classification.
694+
Default is set to NULL and plots cells by original class ID.
695+
perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications. Defaults to `KO`.
694696
695697
Returns:
696698
The ggplot object used for drawing.
@@ -700,13 +702,13 @@ def plot_perturbscore( # pragma: no cover
700702
701703
>>> import pertpy as pt
702704
>>> mdata = pt.dt.papalexi_2021()
703-
>>> mixscape_identifier = pt.tl.Mixscape()
704-
>>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
705-
>>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
706-
>>> pt.pl.ms.perturbscore(adata = mdata['rna'], labels='gene_target', target_gene='IFNGR2', color = 'orange')
705+
>>> ms_pt = pt.tl.Mixscape()
706+
>>> ms_pt.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
707+
>>> ms_pt.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
708+
>>> ms_pt.plot_perturbscore(adata = mdata['rna'], labels='gene_target', target_gene='IFNGR2', color = 'orange')
707709
"""
708710
if "mixscape" not in adata.uns:
709-
raise ValueError("Please run `pt.tl.mixscape` first.")
711+
raise ValueError("Please run the `mixscape` function first.")
710712
perturbation_score = None
711713
for key in adata.uns["mixscape"][target_gene].keys():
712714
perturbation_score_temp = adata.uns["mixscape"][target_gene][key]
@@ -807,8 +809,6 @@ def plot_perturbscore( # pragma: no cover
807809
pl.legend(title="mixscape class", title_fontsize=14, fontsize=12)
808810
sns.despine()
809811

810-
return pl.gcf()
811-
812812
def plot_violin( # pragma: no cover
813813
self,
814814
adata: AnnData,
@@ -859,10 +859,10 @@ def plot_violin( # pragma: no cover
859859
Examples:
860860
>>> import pertpy as pt
861861
>>> mdata = pt.dt.papalexi_2021()
862-
>>> ms = pt.tl.Mixscape()
863-
>>> ms.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
864-
>>> ms.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
865-
>>> ms.plot_violin(adata = mdata['rna'], target_gene_idents=['NT', 'IFNGR2 NP', 'IFNGR2 KO'], groupby='mixscape_class')
862+
>>> ms_pt = pt.tl.Mixscape()
863+
>>> ms_pt.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
864+
>>> ms_pt.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
865+
>>> ms_pt.plot_violin(adata = mdata['rna'], target_gene_idents=['NT', 'IFNGR2 NP', 'IFNGR2 KO'], groupby='mixscape_class')
866866
"""
867867
if isinstance(target_gene_idents, str):
868868
mixscape_class_mask = adata.obs[groupby] == target_gene_idents
@@ -1023,12 +1023,11 @@ def plot_lda( # pragma: no cover
10231023
Args:
10241024
adata: The annotated data object.
10251025
control: Control category from the `pert_key` column.
1026-
labels: The column of `.obs` with target gene labels.
10271026
mixscape_class: The column of `.obs` with the mixscape classification result.
10281027
mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
10291028
perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
10301029
Defaults to 'KO'.
1031-
lda_key: If not speficied, lda looks .uns["mixscape_lda"] for the LDA results.
1030+
lda_key: If not specified, lda looks in .uns["mixscape_lda"] for the LDA results.
10321031
n_components: The number of dimensions of the embedding.
10331032
show: Show the plot, do not return axis.
10341033
save: If `True` or a `str`, save the figure. A string is appended to the default filename.
@@ -1038,16 +1037,18 @@ def plot_lda( # pragma: no cover
10381037
Examples:
10391038
>>> import pertpy as pt
10401039
>>> mdata = pt.dt.papalexi_2021()
1041-
>>> ms = pt.tl.Mixscape()
1042-
>>> ms.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
1043-
>>> ms.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
1044-
>>> ms.lda(adata=mdata['rna'], control='NT', labels='gene_target', layer='X_pert')
1045-
>>> ms.plot_lda(adata=mdata['rna'], control='NT')
1040+
>>> ms_pt = pt.tl.Mixscape()
1041+
>>> ms_pt.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
1042+
>>> ms_pt.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
1043+
>>> ms_pt.lda(adata=mdata['rna'], control='NT', labels='gene_target', layer='X_pert')
1044+
>>> ms_pt.plot_lda(adata=mdata['rna'], control='NT')
10461045
"""
10471046
if mixscape_class not in adata.obs:
1048-
raise ValueError(f'Did not find .obs["{mixscape_class!r}"]. Please run `pt.tl.mixscape` first.')
1047+
raise ValueError(
1048+
f'Did not find `.obs["{mixscape_class!r}"]`. Please run the `mixscape` function first first.'
1049+
)
10491050
if lda_key not in adata.uns:
1050-
raise ValueError(f'Did not find .uns["{lda_key!r}"]. Run `pt.tl.neighbors` first.')
1051+
raise ValueError(f'Did not find `.uns["{lda_key!r}"]`. Run the `lda` function first.')
10511052

10521053
adata_subset = adata[
10531054
(adata.obs[mixscape_class_global] == perturbation_type) | (adata.obs[mixscape_class_global] == control)

0 commit comments

Comments
 (0)