Skip to content

Commit 6408ee0

Browse files
author
Alexander Ororbia
committed
minor tweak to dim-reduce in utils
1 parent 73e5aa1 commit 6408ee0

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

ngclearn/utils/viz/dim_reduce.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import matplotlib
22
import matplotlib.pyplot as plt
3-
cmap = plt.cm.jet
3+
default_cmap = plt.cm.jet
44

55
import numpy as np
66
from sklearn.decomposition import IncrementalPCA
@@ -66,7 +66,8 @@ def extract_tsne_latents(vectors, perplexity=30, n_pca_comp=32): ## tSNE mapping
6666
z_2D = vectors
6767
return z_2D
6868

69-
def plot_latents(code_vectors, labels, plot_fname="2Dcode_plot.jpg", alpha=1.):
69+
def plot_latents(code_vectors, labels, plot_fname="2Dcode_plot.jpg", alpha=1.,
70+
cmap=None):
7071
"""
7172
Produces a label-overlaid (label map to distinct colors) scatterplot for
7273
visualizing two-dimensional latent codes (produced by either PCA or t-SNE).
@@ -80,7 +81,9 @@ def plot_latents(code_vectors, labels, plot_fname="2Dcode_plot.jpg", alpha=1.):
8081
8182
plot_fname: /path/to/plot_fname.<suffix> for saving the plot to disk
8283
83-
alpha:
84+
alpha: alpha intensity level to present colors in scatterplot
85+
86+
cmap: custom color-map to provide
8487
"""
8588
curr_backend = plt.rcParams["backend"]
8689
matplotlib.use('Agg') ## temporarily go in Agg plt backend for tsne plotting
@@ -92,7 +95,11 @@ def plot_latents(code_vectors, labels, plot_fname="2Dcode_plot.jpg", alpha=1.):
9295
if lab.shape[1] > 1: ## extract integer class labels from a one-hot matrix
9396
lab = np.argmax(lab, 1)
9497
plt.figure(figsize=(8, 6))
95-
plt.scatter(code_vectors[:, 0], code_vectors[:, 1], c=lab, cmap=cmap, alpha=alpha)
98+
_cmap = cmap
99+
if _cmap is None:
100+
_cmap = default_cmap
101+
#print("> USING DEFAULT CMAP!")
102+
plt.scatter(code_vectors[:, 0], code_vectors[:, 1], c=lab, cmap=_cmap, alpha=alpha)
96103
colorbar = plt.colorbar()
97104
#colorbar.set_alpha(1)
98105
#plt.draw_all()

0 commit comments

Comments
 (0)