It might also be worth looking into numba.jit; without it, the vectorized version will likely beat a straightforward pure-Python search in most scenarios, but once the code is compiled, the ordinary loop takes the lead, at least in my testing:
In [63]: a = np.array([np.nan if i % 10000 == 9999 else 3 for i in range(100000)])
In [70]: %paste
import numba
def naive(a):
    for i in range(len(a)):
        if np.isnan(a[i]):
            return i

def short(a):
    return np.isnan(a).argmax()

@numba.jit
def naive_jit(a):
    for i in range(len(a)):
        if np.isnan(a[i]):
            return i

@numba.jit
def short_jit(a):
    return np.isnan(a).argmax()
## -- End pasted text --
In [71]: %timeit naive(a)
100 loops, best of 3: 7.22 ms per loop
In [72]: %timeit short(a)
The slowest run took 4.59 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 37.7 μs per loop
In [73]: %timeit naive_jit(a)
The slowest run took 6821.16 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 6.79 μs per loop
In [74]: %timeit short_jit(a)
The slowest run took 395.51 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 144 μs per loop
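If you want to reproduce the comparison outside IPython, something along these lines should do. It is only a sketch under a few assumptions of mine: the `bench` helper is hypothetical, the loop returns -1 when no NaN is found (the version above implicitly returns None), and the code falls back to the uncompiled function when numba is not installed. The warm-up call keeps the one-time JIT compilation out of the measurement, which is presumably what caused the large "slowest run" ratios above.

```python
import timeit

import numpy as np

try:
    import numba
    jit = numba.jit
except ImportError:
    # Assumption: without numba, just time the plain Python functions.
    def jit(func):
        return func


def naive(a):
    # Plain loop that stops at the first NaN.
    for i in range(len(a)):
        if np.isnan(a[i]):
            return i
    return -1  # assumption: sentinel for "no NaN found"


def short(a):
    # Vectorized: builds the full boolean mask, then takes the first True.
    return np.isnan(a).argmax()


naive_jit = jit(naive)
short_jit = jit(short)


def bench(func, a, number=10, repeat=3):
    # Hypothetical helper: best average time per call, in seconds.
    func(a)  # warm-up call so JIT compilation is not timed
    best = min(timeit.repeat(lambda: func(a), number=number, repeat=repeat))
    return best / number


a = np.array([np.nan if i % 10000 == 9999 else 3 for i in range(100000)])

for name, func in [("naive", naive), ("short", short),
                   ("naive_jit", naive_jit), ("short_jit", short_jit)]:
    print(f"{name}: {bench(func, a) * 1e6:.1f} us per call")
```

The absolute numbers will differ from mine depending on hardware and library versions, so treat only the relative ordering as meaningful.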
Edit: As pointed out by @hpaulj in their answer, numpy actually ships with an optimized, short-circuiting search: calling argmax directly on the array stops at the first NaN, and its performance is comparable to that of the JITted search above:
In [26]: %paste
def plain(a):
    return a.argmax()

@numba.jit
def plain_jit(a):
    return a.argmax()
## -- End pasted text --
In [35]: %timeit naive(a)
100 loops, best of 3: 7.13 ms per loop
In [36]: %timeit plain(a)
The slowest run took 4.37 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 7.04 μs per loop
In [37]: %timeit naive_jit(a)
100000 loops, best of 3: 6.91 μs per loop
In [38]: %timeit plain_jit(a)
10000 loops, best of 3: 125 μs per loop
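One thing to keep in mind with the argmax approach: if the array contains no NaN at all, a.argmax() simply returns the position of the maximum, so the result has to be checked. A minimal sketch of such a wrapper follows; `first_nan` is a name I made up for illustration, and returning None for "not found" is my own choice.

```python
import numpy as np


def first_nan(a):
    """Return the index of the first NaN in a 1-D float array, or None."""
    if a.size == 0:
        return None  # argmax raises ValueError on empty arrays
    i = a.argmax()  # lands on the first NaN if the array contains one
    # Without any NaN, i is just the position of the maximum, so verify it.
    return int(i) if np.isnan(a[i]) else None


print(first_nan(np.array([3.0, 3.0, np.nan, 3.0, np.nan])))
print(first_nan(np.array([1.0, 5.0, 2.0])))
```

The extra np.isnan check is a single scalar test, so it should not change the timings above in any meaningful way.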