Is there a simple way to specify bar colors by column name using the Pandas DataFrame.plot(kind='bar') method?
I have a script that generates multiple DataFrames from several different data files in a directory. For example, it does something like this:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pds
data_files = ['a', 'b', 'c', 'd']
df1 = pds.DataFrame(np.random.rand(4,3), columns=data_files[:-1])
df2 = pds.DataFrame(np.random.rand(4,3), columns=data_files[1:])
df1.plot(kind='bar', ax=plt.subplot(121))
df2.plot(kind='bar', ax=plt.subplot(122))
plt.show()
With the following output:
Unfortunately, the column colors aren't consistent for each label across the different plots. Is it possible to pass in a dictionary of (filename: color) pairs, so that any particular column always has the same color? For example, I could imagine creating this by zipping up the filenames with the Matplotlib color cycle:
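data_files = ['a', 'b', 'c', 'd']
# 'axes.color_cycle' was removed in Matplotlib 2.0; 'axes.prop_cycle' replaces it
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
print(list(zip(data_files, colors)))
[('a', '#1f77b4'), ('b', '#ff7f0e'), ('c', '#2ca02c'), ('d', '#d62728')]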
I could figure out how to do this directly with Matplotlib; I just thought there might be a simpler, built-in solution.
Edit:
Below is a partial solution that works in pure Matplotlib. However, I'm using this in an IPython notebook that will be distributed to non-programmer colleagues, and I'd like to keep the amount of plotting code to a minimum.
import numpy as np
import matplotlib.pyplot as plt
import pandas as pds
data_files = ['a', 'b', 'c', 'd']
# 'axes.color_cycle' was removed in Matplotlib 2.0; 'axes.prop_cycle' replaces it
mpl_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
# Map each column name to one fixed color
colors = dict(zip(data_files, mpl_colors))
def bar_plotter(df, colors, sub):
    ncols = df.shape[1]
    width = 1./(ncols+2.)
    # Left-most bar position for each group, so the group sits around its index
    starts = df.index.values - width*ncols/2.
    plt.subplot(120+sub)
    # One set of bars per column, colored by column name
    for n, col in enumerate(df):
        plt.bar(starts + width*n, df[col].values, color=colors[col],
                width=width, label=col)
    plt.xticks(df.index.values)
    plt.grid()
    plt.legend()
df1 = pds.DataFrame(np.random.rand(4,3), columns=data_files[:-1])
df2 = pds.DataFrame(np.random.rand(4,3), columns=data_files[1:])
bar_plotter(df1, colors, 1)
bar_plotter(df2, colors, 2)
plt.show()
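For comparison, here is a shorter sketch of the kind of thing I was hoping for. It assumes that DataFrame.plot accepts a list of colors through its color keyword, with one entry per column in column order, and it builds that list from the name-to-color dictionary. The helper name plot_bars is made up for illustration, and the colors come from whatever color cycle is currently active:

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pds

data_files = ['a', 'b', 'c', 'd']

# One fixed color per column name, taken from the active Matplotlib color cycle.
cycle = plt.rcParams['axes.prop_cycle'].by_key()['color']
colors = dict(zip(data_files, cycle))


def plot_bars(df, colors, ax):
    """Hypothetical helper: bar plot whose colors are looked up by column name."""
    # Build the color list in the same order as the DataFrame's columns.
    return df.plot(kind='bar', ax=ax, color=[colors[col] for col in df.columns])


df1 = pds.DataFrame(np.random.rand(4, 3), columns=data_files[:-1])
df2 = pds.DataFrame(np.random.rand(4, 3), columns=data_files[1:])

fig, (ax1, ax2) = plt.subplots(1, 2)
plot_bars(df1, colors, ax1)
plot_bars(df2, colors, ax2)
```

In a notebook the figure should render inline; in a plain script, a final plt.show() would be needed. With this, columns 'b' and 'c' should get the same color in both subplots even though they sit at different positions in each DataFrame.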