There are two options:
First, heatmap is an Axes-level function, so you could set up a large main heatmap axes for the correlation matrix and flank it with smaller heatmaps that you pass the class colors to yourself. This takes a bit of work, but gives you lots of control over how everything works.
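A minimal sketch of that first option, assuming synthetic data and made-up class names ("a", "b", "c") in place of the real dataset. The grid layout, the strip proportions and the "Set2" palette are illustrative choices, not anything prescribed by seaborn:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.colors import ListedColormap

# Synthetic stand-in data: 12 variables in 3 made-up classes
rng = np.random.default_rng(0)
classes = np.repeat(["a", "b", "c"], 4)
data = pd.DataFrame(rng.normal(size=(50, 12)))
corr = data.corr()

# Encode each class as an integer code and give every code one palette color
codes, uniques = pd.factorize(classes)
cmap = ListedColormap(sns.color_palette("Set2", len(uniques)))

# 2x2 grid: thin strips for the class colors, one large cell for the matrix
fig, axes = plt.subplots(
    2, 2, figsize=(6, 6),
    gridspec_kw={"width_ratios": [1, 20], "height_ratios": [1, 20],
                 "wspace": 0.02, "hspace": 0.02},
)
axes[0, 0].axis("off")  # unused corner
ax_top, ax_left, ax_main = axes[0, 1], axes[1, 0], axes[1, 1]

strip_kws = dict(cmap=cmap, cbar=False, xticklabels=False, yticklabels=False)
sns.heatmap(codes[np.newaxis, :], ax=ax_top, **strip_kws)   # 1 x n strip
sns.heatmap(codes[:, np.newaxis], ax=ax_left, **strip_kws)  # n x 1 strip
sns.heatmap(corr, ax=ax_main, cbar=False,
            xticklabels=False, yticklabels=False)
```

Every axes here is under your control, which is the appeal of this route, but the spacing, colorbar and legend all have to be arranged by hand.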
This is more or less an option in clustermap, though, so I'm going to demonstrate how to do it that way here. It's a bit of a hack, but it will work.
First, we'll load the sample data and do a few roundabout transformations to get colors for the class labels.
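import pandas as pd
import seaborn as sns

networks = sns.load_dataset("brain_networks", index_col=0, header=[0, 1, 2])
network_labels = networks.columns.get_level_values("network")
network_pal = sns.cubehelix_palette(network_labels.unique().size,
                                    light=.9, dark=.1, reverse=True,
                                    start=1, rot=-2)
# Lookup table from network label to color
network_lut = dict(zip(map(str, network_labels.unique()), network_pal))
# Index the colors by the data columns; recent seaborn versions align
# a Series of row/col colors with the data index
network_colors = pd.Series(network_labels, index=networks.columns).map(network_lut)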
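To make the label-to-color step easier to follow, here is the same lookup-table idea on a toy example. The labels and RGB tuples below are hypothetical stand-ins, not values from the brain_networks dataset:

```python
import pandas as pd

# Hypothetical labels standing in for the "network" column level
labels = pd.Index(["1", "1", "2", "3", "2"], name="network")

# Hypothetical RGB tuples; any palette with one color per class would do
palette = [(0.9, 0.1, 0.1), (0.1, 0.9, 0.1), (0.1, 0.1, 0.9)]

# Lookup table from class label to color, in order of first appearance
lut = dict(zip(labels.unique(), palette))

# One color per entry, in the same order as the labels
colors = pd.Series(labels).map(lut)
print(colors.tolist())
```

Repeated labels get the same color, which is what lets the color strips show class membership.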
Next we call clustermap to make the main plot.
g = sns.clustermap(networks.corr(),

                   # Turn off the clustering
                   row_cluster=False, col_cluster=False,

                   # Add colored class labels
                   row_colors=network_colors, col_colors=network_colors,

                   # Make the plot look better when many rows/cols
                   linewidths=0, xticklabels=False, yticklabels=False)
The side colors are drawn with a heatmap, which matplotlib treats as quantitative data, so there's no straightforward way to get a legend directly from it. Instead, we'll add invisible zero-height bars with the right colors and labels, then add a legend for those.
for label in network_labels.unique():
    g.ax_col_dendrogram.bar(0, 0, color=network_lut[label],
                            label=label, linewidth=0)
g.ax_col_dendrogram.legend(loc="center", ncol=6)
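If the zero-height bars feel too hacky, matplotlib's proxy artists are another way to get the same kind of legend. This is a sketch on a plain axes with a hypothetical three-entry lookup table; in the real plot you would presumably build the handles from network_lut and call legend on g.ax_col_dendrogram instead:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

# Hypothetical label-to-color table, standing in for network_lut
lut = {"1": "tab:blue", "2": "tab:orange", "3": "tab:green"}

fig, ax = plt.subplots()

# Patches that are never drawn on the axes; they only describe legend entries
handles = [Patch(facecolor=color, label=label) for label, color in lut.items()]
legend = ax.legend(handles=handles, loc="center", ncol=3)
```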
Finally, let's move the colorbar to take up the empty space where the row dendrogram would normally be and save the figure.
g.cax.set_position([.15, .2, .03, .45])
g.savefig("clustermap.png")
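The four numbers passed to set_position are [left, bottom, width, height], each given as a fraction of the figure size. A small standalone illustration on a bare matplotlib figure (the initial placement here is arbitrary):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(6, 6))
cax = fig.add_axes([0.8, 0.1, 0.05, 0.8])  # arbitrary initial placement

# [left, bottom, width, height] in figure-fraction coordinates
cax.set_position([0.15, 0.2, 0.03, 0.45])

# Read the position back to confirm where the axes ended up
box = cax.get_position()
print(box.x0, box.y0, box.width, box.height)
```

The values in the answer were picked by eye for this figure, so expect to adjust them if you change the figure size or the side color layout.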