The idea here is to draw a line with the equation `y = m*x + y0` into the plot. This can be achieved by taking a horizontal line given in axes coordinates, transforming it into data coordinates, applying an `Affine2D` transform according to the line equation, and finally transforming to display (screen) coordinates.
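A quick check of the middle step (a minimal sketch; the sample points are arbitrary): `Affine2D.from_values(a, b, c, d, e, f)` builds the matrix `[[a, c, e], [b, d, f], [0, 0, 1]]`, so the values `(1, m, 0, 0, 0, y0)` map any point `(x, y)` to `(x, m*x + y0)`. A horizontal line therefore becomes the desired line, whatever its `y` value was.

```python
import numpy as np
import matplotlib.transforms as mtransforms

m, y0 = 0.4, -1.0
# from_values(a, b, c, d, e, f) -> matrix [[a, c, e], [b, d, f], [0, 0, 1]]
aff = mtransforms.Affine2D.from_values(1, m, 0, 0, 0, y0)
print(aff.get_matrix())

# Points on the horizontal line y = 0 (the y value is ignored by this affine)
pts = np.array([[-2.0, 0.0], [0.0, 0.0], [5.0, 0.0]])
mapped = aff.transform(pts)
print(mapped)  # each row is (x, m*x + y0)
```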
The advantage is that you do not need to know the axes limits at all. You may also freely zoom or pan the plot; the line always spans the full width of the axes. It hence effectively implements a line ranging from -infinity to +infinity.
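```python
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.transforms as mtransforms

def axaline(m, y0, ax=None, **kwargs):
    """Draw the infinite line y = m*x + y0 into ax (default: current axes)."""
    if ax is None:
        ax = plt.gca()
    # axes coordinates -> data coordinates; reads the live ax.viewLim,
    # so it follows every change of the axis limits
    tr = (mtransforms.BboxTransformTo(
              mtransforms.TransformedBbox(ax.viewLim, ax.transScale))
          + ax.transScale.inverted())
    # (x, y) -> (x, m*x + y0): turns the horizontal line into y = m*x + y0
    aff = mtransforms.Affine2D.from_values(1, m, 0, 0, 0, y0)
    # data coordinates -> display coordinates
    trinv = ax.transData
    # horizontal line from the left (0) to the right (1) edge of the axes
    line = plt.Line2D([0, 1], [0, 0], transform=tr + aff + trinv, **kwargs)
    ax.add_line(line)
    return line

x = np.random.rand(20)*6 - 0.7
y = (np.random.rand(20) - .5)*4
c = (x > 3).astype(int)

fig, ax = plt.subplots()
ax.scatter(x, y, c=c, cmap="bwr")

# draw y=m*x+y0 into the plot
m = 0.4; y0 = -1
axaline(m, y0, ax=ax, color="limegreen", linewidth=5)

plt.show()
```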
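The claim that the limits need not be known can be checked without a window. The following sketch (assuming the non-interactive Agg backend and linear axis scales) rebuilds the axes-to-data part of the chain, `tr` in `axaline`, together with the affine step, and asks for the data coordinates of the line's two ends before and after changing the x limits. Because `ax.viewLim` is a live bounding box, the same transform objects pick up the new limits without being recreated.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, no window needed
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.transforms as mtransforms

fig, ax = plt.subplots()
m, y0 = 0.4, -1.0

# Same axes -> data chain as `tr` in axaline; it reads ax.viewLim on every use
tr = (mtransforms.BboxTransformTo(
          mtransforms.TransformedBbox(ax.viewLim, ax.transScale))
      + ax.transScale.inverted())
aff = mtransforms.Affine2D.from_values(1, m, 0, 0, 0, y0)

# Left and right edge of the axes, in axes coordinates
ends_axes = np.array([[0.0, 0.0], [1.0, 0.0]])

ax.set_xlim(0, 10)
ends_before = (tr + aff).transform(ends_axes)
print(ends_before)  # data coordinates of the two ends

ax.set_xlim(-100, 100)  # "zoom out" without touching tr or aff
ends_after = (tr + aff).transform(ends_axes)
print(ends_after)
```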
While this solution looks a bit complicated at first sight, one does not need to understand it fully. Just copy the `axaline` function into your code and use it as it is.
To get automatic updating without relying on the transform chain to do it, one may instead add callbacks that reset the line's transform every time the axis limits change.
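The reason callbacks are needed in that variant: a transform difference such as `ax.transAxes - ax.transData` inverts the affine parts of `transData` when it is created, so it keeps using the limits of that moment. A small sketch (assuming the headless Agg backend) that maps the lower-right corner of the axes to data coordinates before and after a limit change illustrates this:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, no window needed
import numpy as np
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.set_xlim(0, 10)

right_edge = np.array([[1.0, 0.0]])  # lower-right corner in axes coordinates

snapshot = ax.transAxes - ax.transData  # axes -> data, built with xlim (0, 10)
before = snapshot.transform(right_edge)

ax.set_xlim(0, 20)
stale = snapshot.transform(right_edge)  # still based on xlim (0, 10)
fresh = (ax.transAxes - ax.transData).transform(right_edge)  # rebuilt: xlim (0, 20)
print(before, stale, fresh)
```

Rebuilding the difference after each limit change, which is what a callback does, gives the up-to-date mapping.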
```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import transforms

class axaline:
    def __init__(self, m, y0, ax=None, **kwargs):
        if ax is None:
            ax = plt.gca()
        self.ax = ax
        # (x, y) -> (x, m*x + y0)
        self.aff = transforms.Affine2D.from_values(1, m, 0, 0, 0, y0)
        # horizontal line from the left (0) to the right (1) edge of the axes
        self.line = plt.Line2D([0, 1], [0, 0], **kwargs)
        self.update()
        self.ax.add_line(self.line)
        # rebuild the transform whenever the view limits change
        self.ax.callbacks.connect('xlim_changed', self.update)
        self.ax.callbacks.connect('ylim_changed', self.update)

    def update(self, evt=None):
        # axes -> data, computed for the current limits
        tr = self.ax.transAxes - self.ax.transData
        trinv = self.ax.transData
        self.line.set_transform(tr + self.aff + trinv)

x = np.random.rand(20)*6 - 0.7
y = (np.random.rand(20) - .5)*4
c = (x > 3).astype(int)

fig, ax = plt.subplots()
ax.scatter(x, y, c=c, cmap="bwr")

# draw y=m*x+y0 into the plot
m = 0.4; y0 = -1
al = axaline(m, y0, ax=ax, color="limegreen", linewidth=5)

plt.show()
```
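Side note, separate from the two approaches above: since Matplotlib 3.3, `Axes.axline` draws an infinite line defined by a point and a slope (or by two points), and it also stays correct when zooming or panning. A minimal sketch of the same plot, assuming Matplotlib 3.3 or newer and linear axis scales (the `slope` form is meant for linear scales only):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; omit for interactive use
import numpy as np
import matplotlib.pyplot as plt

x = np.random.rand(20)*6 - 0.7
y = (np.random.rand(20) - .5)*4
c = (x > 3).astype(int)

fig, ax = plt.subplots()
ax.scatter(x, y, c=c, cmap="bwr")

# y = m*x + y0 passes through the point (0, y0) with slope m
m, y0 = 0.4, -1
line = ax.axline((0, y0), slope=m, color="limegreen", linewidth=5)

fig.canvas.draw()  # render once; call plt.show() instead in an interactive session
```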