We can extend this solution to your case, where data has shape (n, c, h, w), by leveraging sliding-window views built on np.lib.stride_tricks.as_strided (scikit-image's view_as_windows wraps it) for efficient patch extraction, like so:
import numpy as np
from skimage.util.shape import view_as_windows

def get_patches(data, locations, size):
    # Get 2D sliding windows for each element of data
    w = view_as_windows(data, (1,1,size,size))
    # Use fancy/advanced indexing to select the required ones
    return w[np.arange(len(locations)), :, locations[:,0], locations[:,1]][:,:,0,0]
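If scikit-image is not available, the same windowed view can be built directly with as_strided. The sketch below is a hypothetical variant (the name get_patches_strided is illustrative, not from the original) that assumes every location satisfies 0 <= y <= h - size and 0 <= x <= w - size; an out-of-range location raises IndexError because the view's shape only covers valid window positions.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

def get_patches_strided(data, locations, size):
    """Hypothetical skimage-free variant of get_patches built directly on as_strided."""
    n, c, h, w = data.shape
    sn, sc, sh, sw = data.strides
    # Read-only view of shape (n, c, h-size+1, w-size+1, size, size):
    # window (i, j, y, x) starts at data[i, j, y, x], and its two inner axes
    # step by the original row and column strides.
    windows = as_strided(
        data,
        shape=(n, c, h - size + 1, w - size + 1, size, size),
        strides=(sn, sc, sh, sw, sh, sw),
        writeable=False,
    )
    rows = np.arange(len(locations))
    # The advanced indices (axes 0, 2, 3) are separated by a slice, so NumPy
    # puts the broadcast axis first: the result has shape (n, c, size, size).
    return windows[rows, :, locations[:, 0], locations[:, 1]]

# Usage: 5-channel data, one 2x2 patch per sample
data = np.arange(3 * 5 * 4 * 4).reshape(3, 5, 4, 4)
locations = np.array([[0, 1], [1, 1], [0, 2]])
print(get_patches_strided(data, locations, 2).shape)  # (3, 5, 2, 2)
```

Because the window axes are laid out without the two singleton axes, no trailing [:,:,0,0] is needed here.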
We need those 1,1 in the window shape passed to view_as_windows because it expects the window to have as many elements as the input data has dimensions. We are sliding only along the last two axes of data, so we keep the first two window lengths as 1, which amounts to no sliding along the first two axes.
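To see what those singleton window axes do to the output shape, here is a small sketch using NumPy's own sliding_window_view (an assumption: NumPy 1.20 or newer), which is used here as a stand-in for view_as_windows. With a window over all four axes it produces the same layout as view_as_windows(data, (1, 1, size, size)); its axis argument lets you window only the last two axes instead.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view  # NumPy >= 1.20

n, c, h, w = 3, 1, 4, 4
size = 2
data = np.arange(n * c * h * w).reshape(n, c, h, w)

# Window spans every axis, so the two leading axes get a window length of 1
# (no sliding there), mirroring view_as_windows(data, (1, 1, size, size)).
full = sliding_window_view(data, (1, 1, size, size))
print(full.shape)  # (3, 1, 3, 3, 1, 1, 2, 2)

# Restricting the window to the last two axes avoids the singleton axes.
last2 = sliding_window_view(data, (size, size), axis=(2, 3))
print(last2.shape)  # (3, 1, 3, 3, 2, 2)

# Both are read-only strided views; nothing is copied until you index them.
print(last2.flags.writeable)  # False
```

The extra pair of length-1 axes in the first view is exactly what the trailing [:,:,0,0] in get_patches drops.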
Sample runs for one-channel and multi-channel data:
In [78]: n, c, h, w = 3, 1, 4, 4 # number of channels = 1
    ...: data = np.arange(n * c * h * w).reshape(n, c, h, w)
    ...:
    ...: size = 2
    ...: locations = np.array([
    ...:     [0, 1],
    ...:     [1, 1],
    ...:     [0, 2]
    ...: ])
    ...:
    ...: crops = np.stack([d[:, y:y+size, x:x+size]
    ...:                   for d, (y,x) in zip(data, locations)])

In [79]: print(np.allclose(get_patches(data, locations, size), crops))
True
In [80]: n, c, h, w = 3, 5, 4, 4 # number of channels = 5
    ...: data = np.arange(n * c * h * w).reshape(n, c, h, w)
    ...:
    ...: size = 2
    ...: locations = np.array([
    ...:     [0, 1],
    ...:     [1, 1],
    ...:     [0, 2]
    ...: ])
    ...:
    ...: crops = np.stack([d[:, y:y+size, x:x+size]
    ...:                   for d, (y,x) in zip(data, locations)])

In [81]: print(np.allclose(get_patches(data, locations, size), crops))
True
Benchmarking
Other approaches:
# Original soln
def stack(data, locations, size):
    crops = np.stack([d[:, y:y+size, x:x+size]
                      for d, (y,x) in zip(data, locations)])
    return crops
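# scholi's soln
def allocate_assign(data, locations, size):
    n, c, h, w = data.shape
    # Note: np.zeros gives float64 output, and only channel 0 is copied,
    # which matches the single-channel (c = 1) benchmark below
    crops = np.zeros((n,c,size,size))
    for i, (y,x) in enumerate(locations):
        crops[i,0,:,:] = data[i,0,y:y+size,x:x+size]
    return crops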
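For multi-channel data like the 5-channel sample run above, the preallocate-and-assign loop needs to copy every channel. A minimal sketch of that generalization follows; allocate_assign_all_channels is a hypothetical name, not scholi's code, and it also keeps the input dtype instead of defaulting to float64.

```python
import numpy as np

def allocate_assign_all_channels(data, locations, size):
    """Hypothetical generalization of allocate_assign: copies all channels, keeps dtype."""
    c = data.shape[1]
    # np.empty is safe here because every row is overwritten in the loop below;
    # sizing by len(locations) avoids leaving uninitialised rows.
    crops = np.empty((len(locations), c, size, size), dtype=data.dtype)
    for i, (y, x) in enumerate(locations):
        # Slice all channels of sample i at once
        crops[i] = data[i, :, y:y+size, x:x+size]
    return crops
```

This is still a Python-level loop over samples, so it should scale like the original allocate_assign rather than like the strided-view approach.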
From the comments, it seems OP is interested in data of shape (512,1,60,60) with size set to 12, 24 or 48. So, let's set up the data for that case with a function:
# Setup data
def create_inputs(size):
    np.random.seed(0)
    n, c, h, w = 512, 1, 60, 60
    data = np.arange(n * c * h * w).reshape(n, c, h, w)
    locations = np.random.randint(0,3,(n,2))
    return data, locations, size
Timings:
In [186]: data, locations, size = create_inputs(size=12)
In [187]: %timeit stack(data, locations, size)
...: %timeit allocate_assign(data, locations, size)
...: %timeit get_patches(data, locations, size)
1000 loops, best of 3: 1.26 ms per loop
1000 loops, best of 3: 1.06 ms per loop
10000 loops, best of 3: 124 μs per loop
In [188]: data, locations, size = create_inputs(size=24)
In [189]: %timeit stack(data, locations, size)
...: %timeit allocate_assign(data, locations, size)
...: %timeit get_patches(data, locations, size)
1000 loops, best of 3: 1.66 ms per loop
1000 loops, best of 3: 1.55 ms per loop
1000 loops, best of 3: 470 μs per loop
In [190]: data, locations, size = create_inputs(size=48)
In [191]: %timeit stack(data, locations, size)
...: %timeit allocate_assign(data, locations, size)
...: %timeit get_patches(data, locations, size)
100 loops, best of 3: 2.8 ms per loop
100 loops, best of 3: 3.33 ms per loop
1000 loops, best of 3: 1.45 ms per loop
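To reproduce this kind of comparison outside IPython, the standard-library timeit module can stand in for %timeit. The sketch below reuses create_inputs and stack from above and assumes NumPy 1.20+ for a skimage-free get_patches_swv, a hypothetical variant built on sliding_window_view. It checks that both methods agree before timing them; absolute numbers will depend on your machine and library versions, so none are claimed here.

```python
import timeit
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view  # NumPy >= 1.20

def create_inputs(size):
    np.random.seed(0)
    n, c, h, w = 512, 1, 60, 60
    data = np.arange(n * c * h * w).reshape(n, c, h, w)
    locations = np.random.randint(0, 3, (n, 2))
    return data, locations, size

def stack(data, locations, size):
    return np.stack([d[:, y:y+size, x:x+size]
                     for d, (y, x) in zip(data, locations)])

def get_patches_swv(data, locations, size):
    # Hypothetical skimage-free variant: window only the last two axes
    w = sliding_window_view(data, (size, size), axis=(2, 3))
    return w[np.arange(len(locations)), :, locations[:, 0], locations[:, 1]]

def bench(size, number=20):
    data, locations, size = create_inputs(size)
    # Sanity check: both methods must extract identical patches
    assert np.array_equal(stack(data, locations, size),
                          get_patches_swv(data, locations, size))
    results = {}
    for fn in (stack, get_patches_swv):
        # Best of 3 repeats, converted to microseconds per call
        best = min(timeit.repeat(lambda: fn(data, locations, size),
                                 number=number, repeat=3))
        results[fn.__name__] = best / number * 1e6
    return results

if __name__ == "__main__":
    for size in (12, 24, 48):
        print(size, bench(size))
```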