I adapted ImportanceOfBeingErnest's answer to "Create a matplotlib mpatches with a rectangle bi-colored for figure legend" to this case. As linked there, the section "Implementing a custom legend handler" in the matplotlib legend guide was particularly helpful.
Result:
Solution:
I created the class HandlerColormap, derived from matplotlib's base class for legend handlers, HandlerBase. HandlerColormap takes a colormap and a number of stripes as arguments.
For the argument cmap, a matplotlib colormap instance should be given.
The argument num_stripes determines how (non-)continuous the color gradient in the legend patch will be.
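To get a feel for what num_stripes does, here is a minimal sketch that samples a colormap once per stripe for a coarse and a fine setting. The helper name stripe_colors is hypothetical and only for illustration; it samples at the stripe midpoints, which is the same choice made in the handler code of this answer.

```python
import matplotlib.pyplot as plt


def stripe_colors(cmap, num_stripes):
    """Return one RGBA tuple per stripe, sampled at the stripe midpoints.

    Hypothetical helper, not part of matplotlib.
    """
    return [cmap((2 * i + 1) / (2 * num_stripes)) for i in range(num_stripes)]


coarse = stripe_colors(plt.cm.spring, 2)   # two clearly separate color blocks
fine = stripe_colors(plt.cm.spring, 64)    # many thin stripes, close to a gradient

print(len(coarse), len(fine))
```

With few stripes the legend patch shows distinct color blocks; with many stripes neighboring colors differ only slightly, so the patch looks like a continuous gradient.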
As instructed in HandlerBase, I override its create_artists method using the given dimensions, so that the result is scaled (automatically) with fontsize. In this new create_artists method I create multiple stripes (slim matplotlib Rectangles) colored according to the input colormap.
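The stripe arithmetic can be looked at on its own, without drawing anything. The following is a sketch in plain Python; the function name stripe_layout and the example numbers (a 20-unit-wide handle box, 4 stripes) are assumptions for illustration. It reproduces only the position, width, and colormap sampling point of each stripe, mirroring the calculation in the handler.

```python
def stripe_layout(xdescent, width, num_stripes):
    """Return (x_left, stripe_width, colormap_position) for each stripe.

    Hypothetical helper; it reproduces the arithmetic only, not the drawing.
    """
    stripe_width = width / num_stripes
    return [
        (
            xdescent + i * stripe_width,          # left edge of stripe i
            stripe_width,                         # all stripes are equally wide
            (2 * i + 1) / (2 * num_stripes),      # midpoint of stripe i in [0, 1]
        )
        for i in range(num_stripes)
    ]


layout = stripe_layout(xdescent=0.0, width=20.0, num_stripes=4)
for x_left, w, pos in layout:
    print(f"x={x_left:5.1f}  width={w:4.1f}  cmap position={pos:.3f}")
```

Sampling at the midpoint of each stripe, rather than at its left edge, keeps the colors centered, so the first and last stripe sit equally far from the two ends of the colormap.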
Code:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.legend_handler import HandlerBase
class HandlerColormap(HandlerBase):
    def __init__(self, cmap, num_stripes=8, **kw):
        HandlerBase.__init__(self, **kw)
        self.cmap = cmap
        self.num_stripes = num_stripes

    def create_artists(self, legend, orig_handle,
                       xdescent, ydescent, width, height, fontsize, trans):
        stripes = []
        for i in range(self.num_stripes):
            # one slim rectangle per stripe, colored at the stripe's midpoint
            s = Rectangle([xdescent + i * width / self.num_stripes, ydescent],
                          width / self.num_stripes,
                          height,
                          fc=self.cmap((2 * i + 1) / (2 * self.num_stripes)),
                          transform=trans)
            stripes.append(s)
        return stripes
x_array = np.linspace(1, 10, 10)
y_array = x_array
param_max = x_array.size
cmaps = [plt.cm.spring, plt.cm.winter] # set of colormaps
# (as many as there are groups of lines)
plt.figure()
for param, (x, y) in enumerate(zip(x_array, y_array)):
    x_line1 = np.linspace(x, 1.5 * x, 10)
    y_line1 = np.linspace(y**2, y**2 - x, 10)
    x_line2 = np.linspace(1.2 * x, 1.5 * x, 10)
    y_line2 = np.linspace(2 * y, 2 * y - x, 10)
    # plot lines with color depending on param using different colormaps:
    plt.plot(x_line1, y_line1, c=cmaps[0](param / param_max))
    plt.plot(x_line2, y_line2, c=cmaps[1](param / param_max))
cmap_labels = [r"parameter 1 $\in$ [0, 10]", r"parameter 2 $\in$ [-1, 1]"]
# create proxy artists as handles:
cmap_handles = [Rectangle((0, 0), 1, 1) for _ in cmaps]
handler_map = dict(zip(cmap_handles,
                       [HandlerColormap(cm, num_stripes=8) for cm in cmaps]))
plt.legend(handles=cmap_handles,
           labels=cmap_labels,
           handler_map=handler_map,
           fontsize=12)
plt.show()