@@ -74,15 +74,15 @@ def perturbation_signature(
7474
7575 Returns:
7676 If `copy=True`, returns the copy of `adata` with the perturbation signature in `.layers["X_pert"]`.
77- Otherwise writes the perturbation signature directly to `.layers["X_pert"]` of the provided `adata`.
77+ Otherwise, writes the perturbation signature directly to `.layers["X_pert"]` of the provided `adata`.
7878
7979 Examples:
8080 Calculate perturbation signature for each cell in the dataset:
8181
8282 >>> import pertpy as pt
8383 >>> mdata = pt.dt.papalexi_2021()
84- >>> mixscape_identifier = pt.tl.Mixscape()
85- >>> mixscape_identifier .perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
84+ >>> ms_pt = pt.tl.Mixscape()
85+ >>> ms_pt .perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
8686 """
8787 if copy :
8888 adata = adata .copy ()
@@ -95,18 +95,17 @@ def perturbation_signature(
9595 split_masks = [np .full (adata .n_obs , True , dtype = bool )]
9696 else :
9797 split_obs = adata .obs [split_by ]
98- cats = split_obs .unique ()
99- split_masks = [split_obs == cat for cat in cats ]
98+ split_masks = [split_obs == cat for cat in split_obs .unique ()]
10099
101- R = _choose_representation (adata , use_rep = use_rep , n_pcs = n_pcs )
100+ representation = _choose_representation (adata , use_rep = use_rep , n_pcs = n_pcs )
102101
103102 for split_mask in split_masks :
104103 control_mask_split = control_mask & split_mask
105104
106- R_split = R [split_mask ]
107- R_control = R [control_mask_split ]
105+ R_split = representation [split_mask ]
106+ R_control = representation [control_mask_split ]
108107
109- from pynndescent import NNDescent # saves a lot of import time
108+ from pynndescent import NNDescent
110109
111110 eps = kwargs .pop ("epsilon" , 0.1 )
112111 nn_index = NNDescent (R_control , ** kwargs )
@@ -170,7 +169,6 @@ def mixscape(
170169
171170 Args:
172171 adata: The annotated data object.
173- pert_key: The column of `.obs` with perturbation categories, should also contain `control`.
174172 labels: The column of `.obs` with target gene labels.
175173 control: Control category from the `pert_key` column.
176174 new_class_name: Name of mixscape classification to be stored in `.obs`.
@@ -186,26 +184,26 @@ def mixscape(
186184
187185 Returns:
188186 If `copy=True`, returns the copy of `adata` with the classification result in `.obs`.
189- Otherwise writes the results directly to `.obs` of the provided `adata`.
187+ Otherwise, writes the results directly to `.obs` of the provided `adata`.
190188
191- mixscape_class: pandas.Series (`adata.obs['mixscape_class']`).
192- Classification result with cells being either classified as perturbed (KO, by default) or non-perturbed (NP) based on their target gene class.
189+ - mixscape_class: pandas.Series (`adata.obs['mixscape_class']`).
190+ Classification result with cells being either classified as perturbed (KO, by default) or non-perturbed (NP) based on their target gene class.
193191
194- mixscape_class_global: pandas.Series (`adata.obs['mixscape_class_global']`).
195- Global classification result (perturbed, NP or NT)
192+ - mixscape_class_global: pandas.Series (`adata.obs['mixscape_class_global']`).
193+ Global classification result (perturbed, NP or NT).
196194
197- mixscape_class_p_ko: pandas.Series (`adata.obs['mixscape_class_p_ko']`).
198- Posterior probabilities used to determine if a cell is KO (default).
199- Name of this item will change to match perturbation_type parameter setting. (>0.5) or NP
195+ - mixscape_class_p_ko: pandas.Series (`adata.obs['mixscape_class_p_ko']`).
196+ Posterior probabilities used to determine if a cell is KO (default; probability >0.5) or NP.
197+ Name of this item will change to match the perturbation_type parameter setting.
200198
201199 Examples:
202200 Calculate perturbation signature for each cell in the dataset:
203201
204202 >>> import pertpy as pt
205203 >>> mdata = pt.dt.papalexi_2021()
206- >>> mixscape_identifier = pt.tl.Mixscape()
207- >>> mixscape_identifier .perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
208- >>> mixscape_identifier .mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
204+ >>> ms_pt = pt.tl.Mixscape()
205+ >>> ms_pt .perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
206+ >>> ms_pt .mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
209207 """
210208 if copy :
211209 adata = adata .copy ()
@@ -229,10 +227,9 @@ def mixscape(
229227 try :
230228 X = adata_comp .layers ["X_pert" ]
231229 except KeyError :
232- print (
233- '[bold yellow]No "X_pert" found in .layers! -- Please run pert_sign first to calculate perturbation signature!'
234- )
235- raise
230+ raise KeyError (
231+ "No 'X_pert' found in .layers! Please run pert_sign first to calculate perturbation signature!"
232+ ) from None
236233 # initialize return variables
237234 adata .obs [f"{ new_class_name } _p_{ perturbation_type .lower ()} " ] = 0
238235 adata .obs [new_class_name ] = adata .obs [labels ].astype (str )
@@ -351,15 +348,17 @@ def lda(
351348 control: Control category from the `pert_key` column. Defaults to 'NT'.
352349 n_comps: Number of principal components to use. Defaults to 10.
353350 min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
354- logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells. Defaults to 0.25.
351+ logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells.
352+ Defaults to 0.25.
355353 split_by: Provide the column `.obs` if multiple biological replicates exist to calculate the perturbation signature for every replicate separately.
356354 pval_cutoff: P-value cut-off for selection of significantly DE genes.
357- perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications. Defaults to KO.
355+ perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
356+ Defaults to KO.
358357 copy: Determines whether a copy of the `adata` is returned.
359358
360359 Returns:
361360 If `copy=True`, returns the copy of `adata` with the LDA result in `.uns`.
362- Otherwise writes the results directly to `.uns` of the provided `adata`.
361+ Otherwise, writes the results directly to `.uns` of the provided `adata`.
363362
364363 mixscape_lda: numpy.ndarray (`adata.uns['mixscape_lda']`).
365364 LDA result.
@@ -369,10 +368,10 @@ def lda(
369368
370369 >>> import pertpy as pt
371370 >>> mdata = pt.dt.papalexi_2021()
372- >>> mixscape_identifier = pt.tl.Mixscape()
373- >>> mixscape_identifier .perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
374- >>> mixscape_identifier .mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
375- >>> mixscape_identifier .lda(adata=mdata['rna'], control='NT', labels='gene_target', layer='X_pert')
371+ >>> ms_pt = pt.tl.Mixscape()
372+ >>> ms_pt .perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
373+ >>> ms_pt .mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
374+ >>> ms_pt .lda(adata=mdata['rna'], control='NT', labels='gene_target', layer='X_pert')
376375 """
377376 if copy :
378377 adata = adata .copy ()
@@ -524,12 +523,12 @@ def plot_barplot( # pragma: no cover
524523 show : bool | None = None ,
525524 save : bool | str | None = None ,
526525 ):
527- """Barplot to visualize perturbation scores calculated from RunMixscape function.
526+ """Barplot to visualize perturbation scores calculated by the `mixscape` function.
528527
529528 Args:
530529 adata: The annotated data object.
531530 guide_rna_column: The column of `.obs` with guide RNA labels. The target gene labels.
532- The format must be <gene_target>g<#>. For example, 'STAT2g1' and 'ATF2g1'.
531+ The format must be <gene_target>g<#>. Examples are 'STAT2g1' and 'ATF2g1'.
533532 mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
534533 show: Show the plot, do not return axis.
535534 save: If True or a str, save the figure. A string is appended to the default filename.
@@ -541,13 +540,13 @@ def plot_barplot( # pragma: no cover
541540 Examples:
542541 >>> import pertpy as pt
543542 >>> mdata = pt.dt.papalexi_2021()
544- >>> mixscape_identifier = pt.tl.Mixscape()
545- >>> mixscape_identifier .perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
546- >>> mixscape_identifier .mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
547- >>> pt.pl.ms.barplot (mdata['rna'], guide_rna_column='NT')
543+ >>> ms_pt = pt.tl.Mixscape()
544+ >>> ms_pt .perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
545+ >>> ms_pt .mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
546+ >>> ms_pt.plot_barplot (mdata['rna'], guide_rna_column='NT')
548547 """
549548 if mixscape_class_global not in adata .obs :
550- raise ValueError ("Please run `pt.tl. mixscape` first." )
549+ raise ValueError ("Please run the ` mixscape` function first." )
551550 count = pd .crosstab (index = adata .obs [mixscape_class_global ], columns = adata .obs [guide_rna_column ])
552551 all_cells_percentage = pd .melt (count / count .sum (), ignore_index = False ).reset_index ()
553552 KO_cells_percentage = all_cells_percentage [all_cells_percentage [mixscape_class_global ] == "KO" ]
@@ -604,7 +603,7 @@ def plot_barplot( # pragma: no cover
604603 )
605604 pl .tight_layout ()
606605
607- return pl . gcf ()
606+ return ax
608607
609608 def plot_heatmap ( # pragma: no cover
610609 self ,
@@ -642,10 +641,10 @@ def plot_heatmap( # pragma: no cover
642641 Examples:
643642 >>> import pertpy as pt
644643 >>> mdata = pt.dt.papalexi_2021()
645- >>> ms = pt.tl.Mixscape()
646- >>> ms .perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
647- >>> ms .mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
648- >>> ms .plot_heatmap(adata = mdata['rna'], labels='gene_target', target_gene='IFNGR2', layer='X_pert', control='NT')
644+ >>> ms_pt = pt.tl.Mixscape()
645+ >>> ms_pt .perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
646+ >>> ms_pt .mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
647+ >>> ms_pt .plot_heatmap(adata = mdata['rna'], labels='gene_target', target_gene='IFNGR2', layer='X_pert', control='NT')
649648 """
650649 if "mixscape_class" not in adata .obs :
651650 raise ValueError ("Please run `pt.tl.mixscape` first." )
@@ -676,8 +675,10 @@ def plot_perturbscore( # pragma: no cover
676675 split_by : str = None ,
677676 before_mixscape = False ,
678677 perturbation_type : str = "KO" ,
679- ):
680- """Density plots to visualize perturbation scores calculated by the `pt.tl.mixscape` function. Requires `pt.tl.mixscape` to be run first.
678+ ) -> None :
679+ """Density plots to visualize perturbation scores calculated by the `pt.tl.mixscape` function.
680+
681+ Requires `pt.tl.mixscape` to be run first.
681682
682683 https://satijalab.org/seurat/reference/plotperturbscore
683684
@@ -688,9 +689,10 @@ def plot_perturbscore( # pragma: no cover
688689 mixscape_class: The column of `.obs` with mixscape classifications.
689690 color: Specify color of target gene class or knockout cell class. For control non-targeting and non-perturbed cells, colors are set to different shades of grey.
690691 split_by: Provide the column `.obs` if multiple biological replicates exist to calculate
691- the perturbation signature for every replicate separately.
692- before_mixscape: Option to split densities based on mixscape classification (default) or original target gene classification. Default is set to NULL and plots cells by original class ID.
693- perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications. Default is KO.
692+ the perturbation signature for every replicate separately.
693+ before_mixscape: Option to split densities based on mixscape classification (default) or original target gene classification.
694+ Default is set to NULL and plots cells by original class ID.
695+ perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications. Defaults to `KO`.
694696
695697 Returns:
696698 The ggplot object used for drawn.
@@ -700,13 +702,13 @@ def plot_perturbscore( # pragma: no cover
700702
701703 >>> import pertpy as pt
702704 >>> mdata = pt.dt.papalexi_2021()
703- >>> mixscape_identifier = pt.tl.Mixscape()
704- >>> mixscape_identifier .perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
705- >>> mixscape_identifier .mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
706- >>> pt.pl.ms.perturbscore (adata = mdata['rna'], labels='gene_target', target_gene='IFNGR2', color = 'orange')
705+ >>> ms_pt = pt.tl.Mixscape()
706+ >>> ms_pt .perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
707+ >>> ms_pt .mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
708+ >>> ms_pt.plot_perturbscore (adata = mdata['rna'], labels='gene_target', target_gene='IFNGR2', color = 'orange')
707709 """
708710 if "mixscape" not in adata .uns :
709- raise ValueError ("Please run `pt.tl. mixscape` first." )
711+ raise ValueError ("Please run the ` mixscape` function first." )
710712 perturbation_score = None
711713 for key in adata .uns ["mixscape" ][target_gene ].keys ():
712714 perturbation_score_temp = adata .uns ["mixscape" ][target_gene ][key ]
@@ -807,8 +809,6 @@ def plot_perturbscore( # pragma: no cover
807809 pl .legend (title = "mixscape class" , title_fontsize = 14 , fontsize = 12 )
808810 sns .despine ()
809811
810- return pl .gcf ()
811-
812812 def plot_violin ( # pragma: no cover
813813 self ,
814814 adata : AnnData ,
@@ -859,10 +859,10 @@ def plot_violin( # pragma: no cover
859859 Examples:
860860 >>> import pertpy as pt
861861 >>> mdata = pt.dt.papalexi_2021()
862- >>> ms = pt.tl.Mixscape()
863- >>> ms .perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
864- >>> ms .mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
865- >>> ms .plot_violin(adata = mdata['rna'], target_gene_idents=['NT', 'IFNGR2 NP', 'IFNGR2 KO'], groupby='mixscape_class')
862+ >>> ms_pt = pt.tl.Mixscape()
863+ >>> ms_pt .perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
864+ >>> ms_pt .mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
865+ >>> ms_pt .plot_violin(adata = mdata['rna'], target_gene_idents=['NT', 'IFNGR2 NP', 'IFNGR2 KO'], groupby='mixscape_class')
866866 """
867867 if isinstance (target_gene_idents , str ):
868868 mixscape_class_mask = adata .obs [groupby ] == target_gene_idents
@@ -1023,12 +1023,11 @@ def plot_lda( # pragma: no cover
10231023 Args:
10241024 adata: The annotated data object.
10251025 control: Control category from the `pert_key` column.
1026- labels: The column of `.obs` with target gene labels.
10271026 mixscape_class: The column of `.obs` with the mixscape classification result.
10281027 mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
10291028 perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
10301029 Defaults to 'KO'.
1031- lda_key: If not speficied , lda looks .uns["mixscape_lda"] for the LDA results.
1030+ lda_key: If not specified, lda looks in .uns["mixscape_lda"] for the LDA results.
10321031 n_components: The number of dimensions of the embedding.
10331032 show: Show the plot, do not return axis.
10341033 save: If `True` or a `str`, save the figure. A string is appended to the default filename.
@@ -1038,16 +1037,18 @@ def plot_lda( # pragma: no cover
10381037 Examples:
10391038 >>> import pertpy as pt
10401039 >>> mdata = pt.dt.papalexi_2021()
1041- >>> ms = pt.tl.Mixscape()
1042- >>> ms .perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
1043- >>> ms .mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
1044- >>> ms .lda(adata=mdata['rna'], control='NT', labels='gene_target', layer='X_pert')
1045- >>> ms .plot_lda(adata=mdata['rna'], control='NT')
1040+ >>> ms_pt = pt.tl.Mixscape()
1041+ >>> ms_pt .perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
1042+ >>> ms_pt .mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
1043+ >>> ms_pt .lda(adata=mdata['rna'], control='NT', labels='gene_target', layer='X_pert')
1044+ >>> ms_pt .plot_lda(adata=mdata['rna'], control='NT')
10461045 """
10471046 if mixscape_class not in adata .obs :
1048- raise ValueError (f'Did not find .obs["{ mixscape_class !r} "]. Please run `pt.tl.mixscape` first.' )
1047+ raise ValueError (
1048+ f'Did not find `.obs["{ mixscape_class !r} "]`. Please run the `mixscape` function first first.'
1049+ )
10491050 if lda_key not in adata .uns :
1050- raise ValueError (f'Did not find .uns["{ lda_key !r} "]. Run `pt.tl.neighbors` first.' )
1051+ raise ValueError (f'Did not find ` .uns["{ lda_key !r} "]` . Run the `lda` function first.' )
10511052
10521053 adata_subset = adata [
10531054 (adata .obs [mixscape_class_global ] == perturbation_type ) | (adata .obs [mixscape_class_global ] == control )
0 commit comments