Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others


python - NumPy Fancy Indexing - Crop different ROIs from different channels

Let's assume we have the following 4D NumPy array (3x1x4x4):

import numpy as np
n, c, h, w = 3, 1, 4, 4

data = np.arange(n * c * h * w).reshape(n, c, h, w)

>>> array([[[[ 0,  1,  2,  3],
             [ 4,  5,  6,  7],
             [ 8,  9, 10, 11],
             [12, 13, 14, 15]]],

           [[[16, 17, 18, 19],
             [20, 21, 22, 23],
             [24, 25, 26, 27],
             [28, 29, 30, 31]]],

           [[[32, 33, 34, 35],
             [36, 37, 38, 39],
             [40, 41, 42, 43],
             [44, 45, 46, 47]]]])

Now I want to crop each of the n subarrays at a different location, all with the same size:

size = 2
locations = np.array([
    [0, 1],
    [1, 1],
    [0, 2]
])

The slow way to do this is the following:

crops = np.stack([d[:, y:y+size, x:x+size] 
     for d, (y,x) in zip(data, locations)])

>>> array([[[[ 1,  2],
             [ 5,  6]]],

           [[[21, 22],
             [25, 26]]],

           [[[34, 35],
             [38, 39]]]])

Now I'm looking for a way to achieve this with NumPy's fancy indexing. I have already spent hours trying to figure out how to solve this problem. Am I overlooking a simple way to do it? Are there any NumPy indexing experts out there who can help me?



1 Reply


We can extend the sliding-window approach to your 4D case by using views built on np.lib.stride_tricks.as_strided (here through scikit-image's view_as_windows) for efficient patch extraction, like so:

import numpy as np
from skimage.util.shape import view_as_windows

def get_patches(data, locations, size):
    # Get 2D sliding windows for each element of data.
    # w has shape (n, c, h-size+1, w-size+1, 1, 1, size, size).
    w = view_as_windows(data, (1,1,size,size))
    
    # Use fancy/advanced indexing to select the window at each (y, x),
    # then drop the two singleton window axes.
    return w[np.arange(len(locations)), :, locations[:,0], locations[:,1]][:,:,0,0]

We need those 1,1 entries in the window shape passed to view_as_windows because it expects the window shape to have one entry per dimension of the input data. We slide only along the last two axes of data, so the window has length 1 along the first two axes, which leaves the sample and channel axes as they are.
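If scikit-image is not available, the same idea can be sketched with NumPy alone. This assumes NumPy 1.20 or newer, where np.lib.stride_tricks.sliding_window_view exists; the function name get_patches_numpy is only illustrative and is not part of the original answer. Because the axis argument restricts the windows to the last two axes, no singleton window axes need to be dropped afterwards.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def get_patches_numpy(data, locations, size):
    # Windows over the last two axes only (a read-only view, no copy yet).
    # Shape: (n, c, h-size+1, w-size+1, size, size)
    windows = sliding_window_view(data, (size, size), axis=(2, 3))
    # Pair each sample index with its (y, x) start; the slice keeps all channels.
    # The advanced indexing step copies only the selected windows.
    return windows[np.arange(len(locations)), :, locations[:, 0], locations[:, 1]]

n, c, h, w = 3, 1, 4, 4
data = np.arange(n * c * h * w).reshape(n, c, h, w)
locations = np.array([[0, 1], [1, 1], [0, 2]])

patches = get_patches_numpy(data, locations, size=2)
print(patches.shape)  # (3, 1, 2, 2)
```

The result for the sample data should match the crops array from the question.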

Sample runs for single-channel and multi-channel data:

In [78]: n, c, h, w = 3, 1, 4, 4 # number of channels = 1
    ...: data = np.arange(n * c * h * w).reshape(n, c, h, w)
    ...: 
    ...: size = 2
    ...: locations = np.array([
    ...:     [0, 1],
    ...:     [1, 1],
    ...:     [0, 2]
    ...: ])
    ...: 
    ...: crops = np.stack([d[:, y:y+size, x:x+size] 
    ...:      for d, (y,x) in zip(data, locations)])

In [79]: print(np.allclose(get_patches(data, locations, size), crops))
True

In [80]: n, c, h, w = 3, 5, 4, 4 # number of channels = 5
    ...: data = np.arange(n * c * h * w).reshape(n, c, h, w)
    ...: 
    ...: size = 2
    ...: locations = np.array([
    ...:     [0, 1],
    ...:     [1, 1],
    ...:     [0, 2]
    ...: ])
    ...: 
    ...: crops = np.stack([d[:, y:y+size, x:x+size] 
    ...:      for d, (y,x) in zip(data, locations)])

In [81]: print(np.allclose(get_patches(data, locations, size), crops))
True
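Since the question asks specifically for fancy indexing, here is a minimal sketch that uses only broadcast index arrays, with no window views. It is not part of the original answer and is not included in the benchmark; the name crop_fancy is made up for illustration. The four index arrays broadcast to shape (n, c, size, size), so the result comes out in that layout directly.

```python
import numpy as np

def crop_fancy(data, locations, size):
    n, c = data.shape[:2]
    offsets = np.arange(size)
    rows = locations[:, 0, None] + offsets  # (n, size): row indices per sample
    cols = locations[:, 1, None] + offsets  # (n, size): column indices per sample
    # Each index array occupies its own broadcast axis:
    # sample -> axis 0, channel -> axis 1, row -> axis 2, column -> axis 3.
    return data[
        np.arange(n)[:, None, None, None],
        np.arange(c)[None, :, None, None],
        rows[:, None, :, None],
        cols[:, None, None, :],
    ]

n, c, h, w = 3, 5, 4, 4
data = np.arange(n * c * h * w).reshape(n, c, h, w)
locations = np.array([[0, 1], [1, 1], [0, 2]])
size = 2

expected = np.stack([d[:, y:y+size, x:x+size]
                     for d, (y, x) in zip(data, locations)])
result = crop_fancy(data, locations, size)
print(np.array_equal(result, expected))
```

This variant builds index arrays with n * c * size * size entries, so it may not be faster than the windowed view for large crops; that would need to be measured.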

Benchmarking

Other approaches -

# Original soln
def stack(data, locations, size):
    crops = np.stack([d[:, y:y+size, x:x+size] 
         for d, (y,x) in zip(data, locations)])    
    return crops

# scholi's soln (copies channel 0 only; result is float64)
def allocate_assign(data, locations, size):
    n, c, h, w = data.shape
    crops = np.zeros((n,c,size,size))
    for i, (y,x) in enumerate(locations):
        crops[i,0,:,:] = data[i,0,y:y+size,x:x+size]
    return crops
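Note that allocate_assign as written fills only channel 0 and returns a float array, so it matches stack only for single-channel data. A hedged sketch of a variant that keeps all channels and the input dtype follows; the name allocate_assign_all_channels is hypothetical and this variant was not timed in the original answer.

```python
import numpy as np

def allocate_assign_all_channels(data, locations, size):
    n, c = data.shape[:2]
    # Preallocate with the input dtype instead of the float64 default.
    crops = np.empty((n, c, size, size), dtype=data.dtype)
    for i, (y, x) in enumerate(locations):
        # Copy every channel of sample i, not just channel 0.
        crops[i] = data[i, :, y:y+size, x:x+size]
    return crops

n, c, h, w = 3, 5, 4, 4
data = np.arange(n * c * h * w).reshape(n, c, h, w)
locations = np.array([[0, 1], [1, 1], [0, 2]])
size = 2

expected = np.stack([d[:, y:y+size, x:x+size]
                     for d, (y, x) in zip(data, locations)])
result = allocate_assign_all_channels(data, locations, size)
print(np.array_equal(result, expected), result.dtype == data.dtype)
```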

From the comments, it seems the OP is interested in data of shape (512,1,60,60) with size set to 12, 24 and 48. So, let's set up the data for those cases with a function:

# Setup data
def create_inputs(size):
    np.random.seed(0)
    n, c, h, w = 512, 1, 60, 60
    data = np.arange(n * c * h * w).reshape(n, c, h, w)
    locations = np.random.randint(0,3,(n,2))
    return data, locations, size

Timings -

In [186]: data, locations, size = create_inputs(size=12)

In [187]: %timeit stack(data, locations, size)
     ...: %timeit allocate_assign(data, locations, size)
     ...: %timeit get_patches(data, locations, size)
1000 loops, best of 3: 1.26 ms per loop
1000 loops, best of 3: 1.06 ms per loop
10000 loops, best of 3: 124 μs per loop

In [188]: data, locations, size = create_inputs(size=24)

In [189]: %timeit stack(data, locations, size)
     ...: %timeit allocate_assign(data, locations, size)
     ...: %timeit get_patches(data, locations, size)
1000 loops, best of 3: 1.66 ms per loop
1000 loops, best of 3: 1.55 ms per loop
1000 loops, best of 3: 470 μs per loop

In [190]: data, locations, size = create_inputs(size=48)

In [191]: %timeit stack(data, locations, size)
     ...: %timeit allocate_assign(data, locations, size)
     ...: %timeit get_patches(data, locations, size)
100 loops, best of 3: 2.8 ms per loop
100 loops, best of 3: 3.33 ms per loop
1000 loops, best of 3: 1.45 ms per loop
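The timings above come from IPython's %timeit. Outside IPython, a comparable measurement can be sketched with the standard timeit module. The helper best_time and the repeat counts are assumptions chosen for illustration, and the random locations use NumPy's default_rng rather than the seeded legacy generator in create_inputs, so the numbers will differ from those reported above and will vary by machine.

```python
import timeit
import numpy as np

def stack(data, locations, size):
    # Loop-based reference implementation from the question.
    return np.stack([d[:, y:y+size, x:x+size]
                     for d, (y, x) in zip(data, locations)])

def best_time(func, *args, number=20, repeat=3):
    # Best average time per call, in seconds, over `repeat` runs.
    runs = timeit.repeat(lambda: func(*args), number=number, repeat=repeat)
    return min(runs) / number

rng = np.random.default_rng(0)
n, c, h, w = 512, 1, 60, 60
data = np.arange(n * c * h * w).reshape(n, c, h, w)
size = 12
locations = rng.integers(0, 3, size=(n, 2))  # start offsets in {0, 1, 2}

seconds = best_time(stack, data, locations, size)
print(f"stack, size={size}: {seconds * 1e6:.0f} us per call")
```

Any of the other functions can be passed to best_time in the same way to compare them on the same inputs.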
