Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share



python - Filter rows of a numpy array?

I am looking to apply a function to each row of a numpy array. If this function evaluates to true I will keep the row, otherwise I will discard it. For example, my function might be:

def f(row):
    if sum(row)>10: return True
    else: return False

I was wondering if there was something similar to:

np.apply_along_axis()

which applies a function to each row of a numpy array and returns the result. I was hoping for something like:

np.filter_over_axes()

which would apply a function to each row of a NumPy array and return only the rows for which the function returned True. Is there anything like this, or should I just use a for loop?



1 Reply


Ideally, you would implement a vectorized version of your function and use its result for boolean indexing. For the vast majority of problems this is the right solution. NumPy provides many functions that operate along a chosen axis, as well as all the basic arithmetic operations and comparisons, so most useful conditions can be vectorized.

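import numpy as np

x = np.random.randn(20, 3)
# np.sum(x, axis=1) gives one sum per row; comparing it with .5 yields a
# boolean array with one entry per row, which selects the rows to keep.
x_new = x[np.sum(x, axis=1) > .5]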
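To make the boolean-indexing step concrete, here is a small sketch on a fixed 3x3 array (the values are made up purely for illustration). The comparison produces one True/False per row, and indexing with that mask keeps the True rows. Several vectorized conditions can be combined element-wise with `&` and `|`; each comparison needs its own parentheses because those operators bind more tightly than `>`.

```python
import numpy as np

# Fixed, illustrative values so the result can be checked by hand.
# Row sums: 6.0, -3.3, 0.9
x = np.array([[1.0, 2.0, 3.0],
              [-4.0, 0.5, 0.2],
              [0.9, -0.1, 0.1]])

# One boolean per row: True where the row sum exceeds 0.5.
mask = np.sum(x, axis=1) > .5
print(mask.tolist())   # [True, False, True]

# Boolean indexing keeps only the rows whose mask entry is True.
print(x[mask].shape)   # (2, 3)

# Combine conditions element-wise: row sum > 0.5 AND every element positive.
mask2 = (np.sum(x, axis=1) > .5) & np.all(x > 0, axis=1)
print(mask2.tolist())  # [True, False, False]
```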

If you are absolutely sure that you can't do the above, I would suggest using a list comprehension (or np.apply_along_axis) to create an array of bools to index with.

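def myfunc(row):
    return sum(row) > .5

# Evaluate the (non-vectorized) predicate on every row, then use the
# resulting boolean array as an index, exactly as in the vectorized case.
bool_arr = np.array([myfunc(row) for row in x])
x_new = x[bool_arr]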
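The np.apply_along_axis route mentioned above can be wrapped into something close to the np.filter_over_axes the question asks for. This is a minimal sketch: `filter_rows` is a hypothetical helper name, not a NumPy function, and it still calls the predicate once per row in Python, so it is no faster than the list comprehension.

```python
import numpy as np

def myfunc(row):
    # Same row predicate as above: True if the row sum exceeds 0.5.
    return sum(row) > .5

def filter_rows(arr, pred):
    """Hypothetical helper (not part of NumPy): keep the rows of a 2-D array
    for which pred(row) is true. Note that np.apply_along_axis raises a
    ValueError if arr has zero rows."""
    # apply_along_axis(func1d, axis, arr) calls pred on each 1-D slice along
    # axis 1 (i.e. each row) and collects the results into a 1-D array.
    mask = np.apply_along_axis(pred, 1, arr).astype(bool)
    return arr[mask]

x = np.random.randn(20, 3)
x_new = filter_rows(x, myfunc)
```

The helper only hides the loop; if the condition can be written with whole-array operations, the vectorized form above is still preferable.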

This gets the job done fairly cleanly, but because it calls a Python function once per row it is much slower than the vectorized version. For example, in IPython:

x = np.random.randn(5000, 200)

%timeit x[np.sum(x, axis=1) > .5]
# 100 loops, best of 3: 5.71 ms per loop

%timeit x[np.array([myfunc(row) for row in x])]
# 1 loops, best of 3: 217 ms per loop
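%timeit is an IPython magic, so the lines above will not run in a plain Python script. Assuming you want to reproduce the comparison outside IPython, a sketch using the standard-library timeit module follows; the absolute timings depend on your machine and NumPy version, so no expected numbers are given.

```python
import timeit
import numpy as np

x = np.random.randn(5000, 200)

def myfunc(row):
    return sum(row) > .5

def vectorized():
    # One row-wise reduction done inside NumPy, then boolean indexing.
    return x[np.sum(x, axis=1) > .5]

def per_row():
    # Calls myfunc (and Python's built-in sum) once for every row.
    return x[np.array([myfunc(row) for row in x])]

for fn in (vectorized, per_row):
    # Best of 3 single-call runs, reported in milliseconds.
    best = min(timeit.repeat(fn, number=1, repeat=3))
    print(f"{fn.__name__}: {best * 1e3:.2f} ms per call")
```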
