Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others


python - How to draw a line through a scatter graph with no overflow

So I am currently plotting a scatter graph with many x and ys in matplotlib:

plt.scatter(x, y)

I want to draw a line on this scatter graph that crosses the whole graph (i.e. hits two 'borders'). I know the gradient and the intercept: the m and the c in the equation y = mx + c.

I have thought about acquiring the 4 corner points of the plot (by calculating the min and max scatter x and y values), calculating the min and max coordinates for the line from those, and then plotting it, but that seems very convoluted. Is there a better way to do this, bearing in mind that the line may not even be 'within' the plot?


[Image: example of scatter graph]

As identified visually in the plot, the four bordering coordinates are roughly:

  • bottom left: -1,-2
  • top left: -1,2
  • bottom right: 6,-2
  • top right: 6,2

I now have a line that I need to plot. It must not exceed these boundaries, but if it enters the plot it must touch two of the borders.

So I could check what y equals when x = -1 and then check whether that value is between -2 and 2; if it is, the line must cross the left border, so plot it, and so on and so forth.
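
The border check described above can be sketched roughly as follows. This is only an illustration of the manual approach, not code from the original post: the function name `clip_line_to_box` and its parameters are made up here, and the box values are the approximate ones read off the plot.

```python
def clip_line_to_box(m, c, xmin, xmax, ymin, ymax):
    """Return the two endpoints of y = m*x + c inside the box, or None.

    Hypothetical helper for illustration; all names are assumptions.
    """
    if m == 0:
        # Horizontal line: visible only if it lies between bottom and top
        if ymin <= c <= ymax:
            return (xmin, c), (xmax, c)
        return None
    # x values at which the line meets the bottom and top borders
    x_at_ymin = (ymin - c) / m
    x_at_ymax = (ymax - c) / m
    # Visible x range: overlap of [xmin, xmax] with the span between them
    x_lo = max(xmin, min(x_at_ymin, x_at_ymax))
    x_hi = min(xmax, max(x_at_ymin, x_at_ymax))
    if x_lo >= x_hi:
        return None  # the line misses the box (or only touches a corner)
    return (x_lo, m * x_lo + c), (x_hi, m * x_hi + c)


# Box taken from the four approximate corner coordinates listed above
print(clip_line_to_box(0.4, -1, -1, 6, -2, 2))
print(clip_line_to_box(2, 0, -1, 6, -2, 2))
print(clip_line_to_box(1, 10, -1, 6, -2, 2))
```

The two returned points could then be passed to `plt.plot`. The drawback is that the endpoints have to be recomputed whenever the axes limits change.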


Ideally though I would create a line from -infinity to infinity and then crop it to fit the plot.


1 Reply


The idea here is to draw a line with the equation y=m*x+y0 into the plot. This can be achieved by taking a horizontal line, originally given in axes coordinates, transforming it into data coordinates, applying an Affine2D transform according to the line equation, and transforming back to screen coordinates.

The advantage here is that you do not need to know the axes limits at all. You may also freely zoom or pan your plot; the line will always stay within the axes boundaries. It hence effectively implements a line ranging from -infinity to +infinity.

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.transforms as mtransforms

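def axaline(m, y0, ax=None, **kwargs):
    """Draw y = m*x + y0 across the full width of the axes."""
    if ax is None:
        ax = plt.gca()
    # Map the unit interval (axes coordinates) onto the current view limits
    # in data coordinates; this follows the limits on zoom and pan.
    tr = (mtransforms.BboxTransformTo(
              mtransforms.TransformedBbox(ax.viewLim, ax.transScale))
          + ax.transScale.inverted())
    # Affine map (x, y) -> (x, m*x + y0)
    aff = mtransforms.Affine2D.from_values(1, m, 0, 0, 0, y0)
    # Back from data coordinates to display coordinates
    trinv = ax.transData
    line = plt.Line2D([0, 1], [0, 0], transform=tr + aff + trinv, **kwargs)
    ax.add_line(line)
    return line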

x = np.random.rand(20)*6-0.7
y = (np.random.rand(20)-.5)*4
c = (x > 3).astype(int)

fig, ax = plt.subplots()
ax.scatter(x,y, c=c, cmap="bwr")

# draw y=m*x+y0 into the plot
m = 0.4; y0 = -1
axaline(m,y0, ax=ax, color="limegreen", linewidth=5)

plt.show()


While this solution looks a bit complicated at first sight, one does not need to fully understand it. Just copy the axaline function into your code and use it as is.
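
For anyone who does want to see what the chain is doing, here is a small sketch (not part of the original answer) that pushes the two endpoints of the unit line through the first two steps by hand. The axes limits of -1..6 and -2..2 are taken from the question. `axes_to_data` is just a name chosen here, and it uses `ax.transAxes + ax.transData.inverted()` as a simpler stand-in for the `tr` built inside `axaline`, which should be equivalent for linear axes.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so no display is needed
import matplotlib.pyplot as plt
import matplotlib.transforms as mtransforms
import numpy as np

m, y0 = 0.4, -1.0

fig, ax = plt.subplots()
ax.set_xlim(-1, 6)
ax.set_ylim(-2, 2)

# Step 1: axes coordinates (0..1) -> data coordinates
axes_to_data = ax.transAxes + ax.transData.inverted()
# Step 2: (x, y) -> (x, m*x + y0); the incoming y value is discarded
aff = mtransforms.Affine2D.from_values(1, m, 0, 0, 0, y0)

# Endpoints of the horizontal unit line, given in axes coordinates
ends_axes = np.array([[0.0, 0.0], [1.0, 0.0]])
ends_data = (axes_to_data + aff).transform(ends_axes)
print(ends_data)
```

With these limits the endpoints come out at roughly (-1, -1.4) and (6, 1.4), which is where y = 0.4*x - 1 meets the left and right borders. The final `ax.transData` step in `axaline` then only converts those data points to screen pixels.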


Alternatively, to get the automatic updating without relying on the transforms themselves to track the view limits, one may add callbacks that reset the transform every time the axes limits change.

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import transforms

class axaline():
    def __init__(self, m, y0, ax=None, **kwargs):
        if ax is None:
            ax = plt.gca()
        self.ax = ax
        # Affine map (x, y) -> (x, m*x + y0)
        self.aff = transforms.Affine2D.from_values(1, m, 0, 0, 0, y0)
        self.line = plt.Line2D([0, 1], [0, 0], **kwargs)
        self.update()
        self.ax.add_line(self.line)
        # Rebuild the transform whenever the axes limits change
        self.ax.callbacks.connect('xlim_changed', self.update)
        self.ax.callbacks.connect('ylim_changed', self.update)

    def update(self, evt=None):
        # Use self.ax rather than a global ax so the class works on any axes
        tr = self.ax.transAxes - self.ax.transData
        trinv = self.ax.transData
        self.line.set_transform(tr + self.aff + trinv)

x = np.random.rand(20)*6-0.7
y = (np.random.rand(20)-.5)*4
c = (x > 3).astype(int)

fig, ax = plt.subplots()
ax.scatter(x,y, c=c, cmap="bwr")

# draw y=m*x+y0 into the plot
m = 0.4; y0 = -1
al = axaline(m,y0, ax=ax, color="limegreen", linewidth=5)

plt.show()
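
As a side note that is not part of the original answer: Matplotlib 3.3 and later provide `Axes.axline`, which draws an infinite line through a point with a given slope and keeps it spanning the axes on zoom or pan. It is a different technique from the transform chain above and is shown here only as a sketch. The data and styling follow the example above; the fixed seed and the `Agg` backend are assumptions added so that it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; drop for interactive use
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)  # fixed seed, only for reproducibility
x = rng.random(20) * 6 - 0.7
y = (rng.random(20) - 0.5) * 4

fig, ax = plt.subplots()
ax.scatter(x, y)

# draw y = m*x + y0: the line passes through (0, y0) with slope m
m = 0.4
y0 = -1
line = ax.axline((0, y0), slope=m, color="limegreen", linewidth=5)

fig.canvas.draw()  # render once; use plt.show() in an interactive session
```

If an older Matplotlib has to be supported, the `axaline` helpers above remain the way to go.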
