Skip to content

Commit e8942ed

Browse files
committed
BUG: fix handling of color argument for a variety of plotting functions
parallel_coordinates:
- fix reordering of class column (from set) causing possible color/class mismatch
- deprecated use of argument colors in favor of color
radviz:
- fix reordering of class column (from set) causing possible color/class mismatch
- added explicit color keyword argument (avoids multiple values for 'color' being passed to the plotting method)
andrews_curves:
- added explicit color keyword argument (avoids multiple values for 'color' being passed to the plotting method)
1 parent 49aece0 commit e8942ed

File tree

3 files changed

+114
-71
lines changed

3 files changed

+114
-71
lines changed

doc/source/release.rst

+4
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,10 @@ Deprecations
229229
returned if possible, otherwise a copy will be made. Previously the user could think that ``copy=False`` would
230230
ALWAYS return a view. (:issue:`6894`)
231231

232+
- The :func:`parallel_coordinates` function now takes argument ``color``
233+
instead of ``colors``. A ``FutureWarning`` is raised to alert that
234+
the old ``colors`` argument will not be supported in a future release.
235+
232236
Prior Version Deprecations/Changes
233237
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
234238

pandas/tests/test_graphics.py

+43-7
Original file line numberDiff line numberDiff line change
@@ -1220,11 +1220,26 @@ def scat2(x, y, by=None, ax=None, figsize=None):
12201220
def test_andrews_curves(self):
12211221
from pandas import read_csv
12221222
from pandas.tools.plotting import andrews_curves
1223-
1223+
from matplotlib import cm
1224+
12241225
path = os.path.join(curpath(), 'data', 'iris.csv')
12251226
df = read_csv(path)
12261227

12271228
_check_plot_works(andrews_curves, df, 'Name')
1229+
_check_plot_works(andrews_curves, df, 'Name',
1230+
color=('#556270', '#4ECDC4', '#C7F464'))
1231+
_check_plot_works(andrews_curves, df, 'Name',
1232+
color=['dodgerblue', 'aquamarine', 'seagreen'])
1233+
_check_plot_works(andrews_curves, df, 'Name', colormap=cm.jet)
1234+
1235+
colors = ['b', 'g', 'r']
1236+
df = DataFrame({"A": [1, 2, 3],
1237+
"B": [1, 2, 3],
1238+
"C": [1, 2, 3],
1239+
"Name": colors})
1240+
ax = andrews_curves(df, 'Name', color=colors)
1241+
legend_colors = [l.get_color() for l in ax.legend().get_lines()]
1242+
self.assertEqual(colors, legend_colors)
12281243

12291244
@slow
12301245
def test_parallel_coordinates(self):
@@ -1235,20 +1250,25 @@ def test_parallel_coordinates(self):
12351250
df = read_csv(path)
12361251
_check_plot_works(parallel_coordinates, df, 'Name')
12371252
_check_plot_works(parallel_coordinates, df, 'Name',
1238-
colors=('#556270', '#4ECDC4', '#C7F464'))
1239-
_check_plot_works(parallel_coordinates, df, 'Name',
1240-
colors=['dodgerblue', 'aquamarine', 'seagreen'])
1253+
color=('#556270', '#4ECDC4', '#C7F464'))
12411254
_check_plot_works(parallel_coordinates, df, 'Name',
1242-
colors=('#556270', '#4ECDC4', '#C7F464'))
1243-
_check_plot_works(parallel_coordinates, df, 'Name',
1244-
colors=['dodgerblue', 'aquamarine', 'seagreen'])
1255+
color=['dodgerblue', 'aquamarine', 'seagreen'])
12451256
_check_plot_works(parallel_coordinates, df, 'Name', colormap=cm.jet)
12461257

12471258
df = read_csv(path, header=None, skiprows=1, names=[1, 2, 4, 8,
12481259
'Name'])
12491260
_check_plot_works(parallel_coordinates, df, 'Name', use_columns=True)
12501261
_check_plot_works(parallel_coordinates, df, 'Name',
12511262
xticks=[1, 5, 25, 125])
1263+
1264+
colors = ['b', 'g', 'r']
1265+
df = DataFrame({"A": [1, 2, 3],
1266+
"B": [1, 2, 3],
1267+
"C": [1, 2, 3],
1268+
"Name": colors})
1269+
ax = parallel_coordinates(df, 'Name', color=colors)
1270+
legend_colors = [l.get_color() for l in ax.legend().get_lines()]
1271+
self.assertEqual(colors, legend_colors)
12521272

12531273
@slow
12541274
def test_radviz(self):
@@ -1259,8 +1279,24 @@ def test_radviz(self):
12591279
path = os.path.join(curpath(), 'data', 'iris.csv')
12601280
df = read_csv(path)
12611281
_check_plot_works(radviz, df, 'Name')
1282+
_check_plot_works(radviz, df, 'Name',
1283+
color=('#556270', '#4ECDC4', '#C7F464'))
1284+
_check_plot_works(radviz, df, 'Name',
1285+
color=['dodgerblue', 'aquamarine', 'seagreen'])
12621286
_check_plot_works(radviz, df, 'Name', colormap=cm.jet)
12631287

1288+
colors = [[0., 0., 1., 1.],
1289+
[0., 0.5, 1., 1.],
1290+
[1., 0., 0., 1.]]
1291+
df = DataFrame({"A": [1, 2, 3],
1292+
"B": [2, 1, 3],
1293+
"C": [3, 2, 1],
1294+
"Name": ['b', 'g', 'r']})
1295+
ax = radviz(df, 'Name', color=colors)
1296+
legend_colors = [c.get_facecolor().squeeze().tolist()
1297+
for c in ax.collections]
1298+
self.assertEqual(colors, legend_colors)
1299+
12641300
@slow
12651301
def test_plot_int_columns(self):
12661302
df = DataFrame(randn(100, 4)).cumsum()

pandas/tools/plotting.py

+67-64
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import numpy as np
1010

11-
from pandas.util.decorators import cache_readonly
11+
from pandas.util.decorators import cache_readonly, deprecate_kwarg
1212
import pandas.core.common as com
1313
from pandas.core.index import MultiIndex
1414
from pandas.core.series import Series, remove_na
@@ -355,18 +355,22 @@ def _get_marker_compat(marker):
355355
return marker
356356

357357

358-
def radviz(frame, class_column, ax=None, colormap=None, **kwds):
358+
def radviz(frame, class_column, ax=None, color=None, colormap=None, **kwds):
359359
"""RadViz - a multivariate data visualization algorithm
360360
361361
Parameters:
362362
-----------
363-
frame: DataFrame object
364-
class_column: Column name that contains information about class membership
363+
frame: DataFrame
364+
class_column: str
365+
Column name containing class names
365366
ax: Matplotlib axis object, optional
367+
color: list or tuple, optional
368+
Colors to use for the different classes
366369
colormap : str or matplotlib colormap object, default None
367370
Colormap to select colors from. If string, load colormap with that name
368371
from matplotlib.
369-
kwds: Matplotlib scatter method keyword arguments, optional
372+
kwds: keywords
373+
Options to pass to matplotlib scatter plotting method
370374
371375
Returns:
372376
--------
@@ -380,44 +384,42 @@ def normalize(series):
380384
b = max(series)
381385
return (series - a) / (b - a)
382386

383-
column_names = [column_name for column_name in frame.columns
384-
if column_name != class_column]
385-
386-
df = frame[column_names].apply(normalize)
387+
n = len(frame)
388+
classes = frame[class_column].drop_duplicates()
389+
class_col = frame[class_column]
390+
df = frame.drop(class_column, axis=1).apply(normalize)
387391

388392
if ax is None:
389393
ax = plt.gca(xlim=[-1, 1], ylim=[-1, 1])
390394

391-
classes = set(frame[class_column])
392395
to_plot = {}
393-
394396
colors = _get_standard_colors(num_colors=len(classes), colormap=colormap,
395-
color_type='random', color=kwds.get('color'))
397+
color_type='random', color=color)
396398

397-
for class_ in classes:
398-
to_plot[class_] = [[], []]
399+
for kls in classes:
400+
to_plot[kls] = [[], []]
399401

400402
n = len(frame.columns) - 1
401403
s = np.array([(np.cos(t), np.sin(t))
402404
for t in [2.0 * np.pi * (i / float(n))
403405
for i in range(n)]])
404406

405-
for i in range(len(frame)):
406-
row = df.irow(i).values
407+
for i in range(n):
408+
row = df.iloc[i].values
407409
row_ = np.repeat(np.expand_dims(row, axis=1), 2, axis=1)
408410
y = (s * row_).sum(axis=0) / row.sum()
409-
class_name = frame[class_column].iget(i)
410-
to_plot[class_name][0].append(y[0])
411-
to_plot[class_name][1].append(y[1])
411+
kls = class_col.iat[i]
412+
to_plot[kls][0].append(y[0])
413+
to_plot[kls][1].append(y[1])
412414

413-
for i, class_ in enumerate(classes):
414-
ax.scatter(to_plot[class_][0], to_plot[class_][1], color=colors[i],
415-
label=com.pprint_thing(class_), **kwds)
415+
for i, kls in enumerate(classes):
416+
ax.scatter(to_plot[kls][0], to_plot[kls][1], color=colors[i],
417+
label=com.pprint_thing(kls), **kwds)
416418
ax.legend()
417419

418420
ax.add_patch(patches.Circle((0.0, 0.0), radius=1.0, facecolor='none'))
419421

420-
for xy, name in zip(s, column_names):
422+
for xy, name in zip(s, df.columns):
421423

422424
ax.add_patch(patches.Circle(xy, radius=0.025, facecolor='gray'))
423425

@@ -438,20 +440,23 @@ def normalize(series):
438440
return ax
439441

440442

441-
def andrews_curves(data, class_column, ax=None, samples=200, colormap=None,
442-
**kwds):
443+
def andrews_curves(frame, class_column, ax=None, samples=200, color=None,
444+
colormap=None, **kwds):
443445
"""
444446
Parameters:
445447
-----------
446-
data : DataFrame
448+
frame : DataFrame
447449
Data to be plotted, preferably normalized to (0.0, 1.0)
448450
class_column : Name of the column containing class names
449451
ax : matplotlib axes object, default None
450452
samples : Number of points to plot in each curve
453+
color: list or tuple, optional
454+
Colors to use for the different classes
451455
colormap : str or matplotlib colormap object, default None
452456
Colormap to select colors from. If string, load colormap with that name
453457
from matplotlib.
454-
kwds : Optional plotting arguments to be passed to matplotlib
458+
kwds: keywords
459+
Options to pass to matplotlib plotting method
455460
456461
Returns:
457462
--------
@@ -475,30 +480,31 @@ def f(x):
475480
return result
476481
return f
477482

478-
n = len(data)
479-
class_col = data[class_column]
480-
uniq_class = class_col.drop_duplicates()
481-
columns = [data[col] for col in data.columns if (col != class_column)]
483+
n = len(frame)
484+
class_col = frame[class_column]
485+
classes = frame[class_column].drop_duplicates()
486+
df = frame.drop(class_column, axis=1)
482487
x = [-pi + 2.0 * pi * (t / float(samples)) for t in range(samples)]
483488
used_legends = set([])
484489

485-
colors = _get_standard_colors(num_colors=len(uniq_class), colormap=colormap,
486-
color_type='random', color=kwds.get('color'))
487-
col_dict = dict([(klass, col) for klass, col in zip(uniq_class, colors)])
490+
color_values = _get_standard_colors(num_colors=len(classes),
491+
colormap=colormap, color_type='random',
492+
color=color)
493+
colors = dict(zip(classes, color_values))
488494
if ax is None:
489495
ax = plt.gca(xlim=(-pi, pi))
490496
for i in range(n):
491-
row = [columns[c][i] for c in range(len(columns))]
497+
row = df.iloc[i].values
492498
f = function(row)
493499
y = [f(t) for t in x]
494-
label = None
495-
if com.pprint_thing(class_col[i]) not in used_legends:
496-
label = com.pprint_thing(class_col[i])
500+
kls = class_col.iat[i]
501+
label = com.pprint_thing(kls)
502+
if label not in used_legends:
497503
used_legends.add(label)
498-
ax.plot(x, y, color=col_dict[class_col[i]], label=label, **kwds)
504+
ax.plot(x, y, color=colors[kls], label=label, **kwds)
499505
else:
500-
ax.plot(x, y, color=col_dict[class_col[i]], **kwds)
501-
506+
ax.plot(x, y, color=colors[kls], **kwds)
507+
502508
ax.legend(loc='upper right')
503509
ax.grid()
504510
return ax
@@ -564,31 +570,31 @@ def bootstrap_plot(series, fig=None, size=50, samples=500, **kwds):
564570
plt.setp(axis.get_yticklabels(), fontsize=8)
565571
return fig
566572

567-
568-
def parallel_coordinates(data, class_column, cols=None, ax=None, colors=None,
569-
use_columns=False, xticks=None, colormap=None, **kwds):
573+
@deprecate_kwarg(old_arg_name='colors', new_arg_name='color')
574+
def parallel_coordinates(frame, class_column, cols=None, ax=None, color=None,
575+
use_columns=False, xticks=None, colormap=None,
576+
**kwds):
570577
"""Parallel coordinates plotting.
571578
572579
Parameters
573580
----------
574-
data: DataFrame
575-
A DataFrame containing data to be plotted
581+
frame: DataFrame
576582
class_column: str
577583
Column name containing class names
578584
cols: list, optional
579585
A list of column names to use
580586
ax: matplotlib.axis, optional
581587
matplotlib axis object
582-
colors: list or tuple, optional
588+
color: list or tuple, optional
583589
Colors to use for the different classes
584590
use_columns: bool, optional
585591
If true, columns will be used as xticks
586592
xticks: list or tuple, optional
587593
A list of values to use for xticks
588594
colormap: str or matplotlib colormap, default None
589595
Colormap to use for line colors.
590-
kwds: list, optional
591-
A list of keywords for matplotlib plot method
596+
kwds: keywords
597+
Options to pass to matplotlib plotting method
592598
593599
Returns
594600
-------
@@ -600,20 +606,19 @@ def parallel_coordinates(data, class_column, cols=None, ax=None, colors=None,
600606
>>> from pandas.tools.plotting import parallel_coordinates
601607
>>> from matplotlib import pyplot as plt
602608
>>> df = read_csv('https://raw.github.com/pydata/pandas/master/pandas/tests/data/iris.csv')
603-
>>> parallel_coordinates(df, 'Name', colors=('#556270', '#4ECDC4', '#C7F464'))
609+
>>> parallel_coordinates(df, 'Name', color=('#556270', '#4ECDC4', '#C7F464'))
604610
>>> plt.show()
605611
"""
606612
import matplotlib.pyplot as plt
607613

608-
609-
n = len(data)
610-
classes = set(data[class_column])
611-
class_col = data[class_column]
614+
n = len(frame)
615+
classes = frame[class_column].drop_duplicates()
616+
class_col = frame[class_column]
612617

613618
if cols is None:
614-
df = data.drop(class_column, axis=1)
619+
df = frame.drop(class_column, axis=1)
615620
else:
616-
df = data[cols]
621+
df = frame[cols]
617622

618623
used_legends = set([])
619624

@@ -638,19 +643,17 @@ def parallel_coordinates(data, class_column, cols=None, ax=None, colors=None,
638643

639644
color_values = _get_standard_colors(num_colors=len(classes),
640645
colormap=colormap, color_type='random',
641-
color=colors)
646+
color=color)
642647

643648
colors = dict(zip(classes, color_values))
644649

645650
for i in range(n):
646-
row = df.irow(i).values
647-
y = row
648-
kls = class_col.iget_value(i)
649-
if com.pprint_thing(kls) not in used_legends:
650-
label = com.pprint_thing(kls)
651+
y = df.iloc[i].values
652+
kls = class_col.iat[i]
653+
label = com.pprint_thing(kls)
654+
if label not in used_legends:
651655
used_legends.add(label)
652-
ax.plot(x, y, color=colors[kls],
653-
label=label, **kwds)
656+
ax.plot(x, y, color=colors[kls], label=label, **kwds)
654657
else:
655658
ax.plot(x, y, color=colors[kls], **kwds)
656659

0 commit comments

Comments
 (0)