Skip to content

Commit 3c06675

Browse files
committed
Adapt palette for mixscape plot
Signed-off-by: zethson <[email protected]>
1 parent cc3be2f commit 3c06675

File tree

1 file changed

+21
-11
lines changed

1 file changed

+21
-11
lines changed

pertpy/tools/_mixscape.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -672,6 +672,7 @@ def plot_perturbscore( # pragma: no cover
672672
target_gene: str,
673673
mixscape_class="mixscape_class",
674674
color="orange",
675+
palette: dict[str, str] = None,
675676
split_by: str = None,
676677
before_mixscape=False,
677678
perturbation_type: str = "KO",
@@ -688,11 +689,13 @@ def plot_perturbscore( # pragma: no cover
688689
target_gene: Target gene name to visualize perturbation scores for.
689690
mixscape_class: The column of `.obs` with mixscape classifications.
690691
color: Specify color of target gene class or knockout cell class. For control non-targeting and non-perturbed cells, colors are set to different shades of grey.
692+
palette: Optional full color palette to overwrite all colors.
691693
split_by: Provide the column `.obs` if multiple biological replicates exist to calculate
692694
the perturbation signature for every replicate separately.
693695
before_mixscape: Option to split densities based on mixscape classification (default) or original target gene classification.
694696
Default is set to NULL and plots cells by original class ID.
695-
perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications. Defaults to `KO`.
697+
perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
698+
Defaults to `KO`.
696699
697700
Returns:
698701
The ggplot object used for drawn.
@@ -721,7 +724,7 @@ def plot_perturbscore( # pragma: no cover
721724
gd = list(set(perturbation_score[labels]).difference({target_gene}))[0]
722725
# If before_mixscape is True, split densities based on original target gene classification
723726
if before_mixscape is True:
724-
cols = {gd: "#7d7d7d", target_gene: color}
727+
palette = {gd: "#7d7d7d", target_gene: color}
725728
plot_dens = sns.kdeplot(data=perturbation_score, x="pvec", hue=labels, fill=False, common_norm=False)
726729
top_r = max(plot_dens.get_lines()[cond].get_data()[1].max() for cond in range(len(plot_dens.get_lines())))
727730
pl.close()
@@ -737,10 +740,10 @@ def plot_perturbscore( # pragma: no cover
737740
if split_by is not None:
738741
sns.set(style="whitegrid")
739742
g = sns.FacetGrid(
740-
data=perturbation_score, col=split_by, hue=split_by, palette=cols, height=5, sharey=False
743+
data=perturbation_score, col=split_by, hue=split_by, palette=palette, height=5, sharey=False
741744
)
742-
g.map(sns.kdeplot, "pvec", fill=True, common_norm=False)
743-
g.map(sns.scatterplot, "pvec", "y_jitter", s=10, alpha=0.5)
745+
g.map(sns.kdeplot, "pvec", fill=True, common_norm=False, palette=palette)
746+
g.map(sns.scatterplot, "pvec", "y_jitter", s=10, alpha=0.5, palette=palette)
744747
g.set_axis_labels("Perturbation score", "Cell density")
745748
g.add_legend(title=split_by, fontsize=14, title_fontsize=16)
746749
g.despine(left=True)
@@ -749,10 +752,10 @@ def plot_perturbscore( # pragma: no cover
749752
else:
750753
sns.set(style="whitegrid")
751754
sns.kdeplot(
752-
data=perturbation_score, x="pvec", hue="gene_target", fill=True, common_norm=False, palette=cols
755+
data=perturbation_score, x="pvec", hue="gene_target", fill=True, common_norm=False, palette=palette
753756
)
754757
sns.scatterplot(
755-
data=perturbation_score, x="pvec", y="y_jitter", hue="gene_target", palette=cols, s=10, alpha=0.5
758+
data=perturbation_score, x="pvec", y="y_jitter", hue="gene_target", palette=palette, s=10, alpha=0.5
756759
)
757760
pl.xlabel("Perturbation score", fontsize=16)
758761
pl.ylabel("Cell density", fontsize=16)
@@ -762,7 +765,8 @@ def plot_perturbscore( # pragma: no cover
762765

763766
# If before_mixscape is False, split densities based on mixscape classifications
764767
else:
765-
cols = {gd: "#7d7d7d", f"{target_gene} NP": "#c9c9c9", f"{target_gene} {perturbation_type}": color}
768+
if palette is None:
769+
palette = {gd: "#7d7d7d", f"{target_gene} NP": "#c9c9c9", f"{target_gene} {perturbation_type}": color}
766770
plot_dens = sns.kdeplot(data=perturbation_score, x="pvec", hue=labels, fill=False, common_norm=False)
767771
top_r = max(plot_dens.get_lines()[i].get_data()[1].max() for i in range(len(plot_dens.get_lines())))
768772
pl.close()
@@ -786,7 +790,7 @@ def plot_perturbscore( # pragma: no cover
786790
if split_by is not None:
787791
sns.set(style="whitegrid")
788792
g = sns.FacetGrid(
789-
data=perturbation_score, col=split_by, hue="mix", palette=cols, height=5, sharey=False
793+
data=perturbation_score, col=split_by, hue="mix", palette=palette, height=5, sharey=False
790794
)
791795
g.map(sns.kdeplot, "pvec", fill=True, common_norm=False, alpha=0.7)
792796
g.map(sns.scatterplot, "pvec", "y_jitter", s=10, alpha=0.5)
@@ -798,10 +802,16 @@ def plot_perturbscore( # pragma: no cover
798802
else:
799803
sns.set(style="whitegrid")
800804
sns.kdeplot(
801-
data=perturbation_score, x="pvec", hue="mix", fill=True, common_norm=False, palette=cols, alpha=0.7
805+
data=perturbation_score,
806+
x="pvec",
807+
hue="mix",
808+
fill=True,
809+
common_norm=False,
810+
palette=palette,
811+
alpha=0.7,
802812
)
803813
sns.scatterplot(
804-
data=perturbation_score, x="pvec", y="y_jitter", hue="mix", palette=cols, s=10, alpha=0.5
814+
data=perturbation_score, x="pvec", y="y_jitter", hue="mix", palette=palette, s=10, alpha=0.5
805815
)
806816
pl.xlabel("Perturbation score", fontsize=16)
807817
pl.ylabel("Cell density", fontsize=16)

0 commit comments

Comments
 (0)