Reshaping the strided array is somewhat costly, for the reason you mentioned (the non-contiguous array has to be copied), but not as costly as you might think. np.einsum can actually be the bottleneck in your application, depending on the tensor sizes. As mentioned in Convolutional layer in Python using Numpy, np.tensordot can be a good candidate to replace np.einsum.
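As a quick illustration of that equivalence (a toy sketch on small random tensors, independent of your shapes): an einsum that contracts a single shared axis maps directly onto tensordot.
import numpy as np

a = np.random.rand(2, 3, 4)
b = np.random.rand(4, 5)

# 'ijk,kl->ijl' contracts the last axis of a with the first axis of b,
# which is what tensordot expresses as axes=(2, 0).
r_einsum = np.einsum('ijk,kl->ijl', a, b)
r_tensordot = np.tensordot(a, b, axes=(2, 0))
assert np.allclose(r_einsum, r_tensordot)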
Just to give you a quick example:
import numpy as np
from numpy.lib.stride_tricks import as_strided

x = np.arange(64*221*221*3).reshape((64, 221, 221, 3))
f = np.arange(4*4*3*5).reshape((4, 4, 3, 5))
s = (2, 2)

B, H, W, C = x.shape    # 64, 221, 221, 3
Fh, Fw, C, _ = f.shape  # 4, 4, 3, 5
Sh, Sw = s              # 2, 2

strided_shape = B, 1 + (H - Fh) // Sh, 1 + (W - Fw) // Sw, Fh, Fw, C
print(strided_shape)
# (64, 109, 109, 4, 4, 3)
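Note that an assignment made inside %timeit does not persist in the IPython session, so we build the view once explicitly before timing (with the same strides as the timed call below); this also lets us sanity-check what the view contains:
x_strided = as_strided(
    x, strided_shape,
    strides=(x.strides[0],       # batch
             Sh * x.strides[1],  # window start, vertical
             Sw * x.strides[2],  # window start, horizontal
             x.strides[1],       # row inside a window
             x.strides[2],       # column inside a window
             x.strides[3]))      # channel

# Window (i, j) is just a view of x[b, i*Sh : i*Sh+Fh, j*Sw : j*Sw+Fw, :];
# nothing is copied.
assert np.array_equal(x_strided[0, 1, 2], x[0, 1*Sh:1*Sh+Fh, 2*Sw:2*Sw+Fw, :])
assert np.shares_memory(x_strided, x)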
After initializing the variables, we can time the individual parts of the code:
%timeit x_strided = as_strided(x, strided_shape, strides=(x.strides[0], Sh * x.strides[1], Sw * x.strides[2], x.strides[1], x.strides[2], x.strides[3]), )
>>> 7.11 μs ± 118 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit f_reshaped = f.reshape(-1, f.shape[-1])
>>> 450 ns ± 7.43 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
%timeit x_reshaped = x_strided.reshape(*x_strided.shape[:3], -1) # Bottleneck!
>>> 94.6 ms ± 896 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
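The reshape is slow precisely because the windows in the strided view overlap in memory, so the view is not contiguous and NumPy must materialize a copy. You can confirm this directly:
print(x_strided.flags['C_CONTIGUOUS'])          # False: overlapping windows
x_reshaped = x_strided.reshape(*x_strided.shape[:3], -1)
print(np.shares_memory(x_strided, x_reshaped))  # False: reshape returned a copy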
# einsum without reshape
%timeit np.einsum('wxy...,...d->wxyd', x_strided, f, optimize='optimal')
>>> 809 ms ± 1.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# einsum with reshape
%%timeit
f_reshaped = f.reshape(-1, f.shape[-1])
x_reshaped = x_strided.reshape(*x_strided.shape[:3], -1) # Bottleneck!
k = np.einsum('wxyz,zd->wxyd', x_reshaped, f_reshaped, optimize='optimal')
>>> 549 ms ± 3.05 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# tensordot without reshape
%timeit k = np.tensordot(x_strided, f, axes=3)
>>> 271 ms ± 4.89 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# tensordot with reshape
%%timeit
f_reshaped = f.reshape(-1, f.shape[-1])
x_reshaped = x_strided.reshape(*x_strided.shape[:3], -1) # Bottleneck!
k = np.tensordot(x_reshaped, f_reshaped, axes=(3, 0))
>>> 266 ms ± 3.15 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
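Before trusting the timings, it is worth checking that the variants really compute the same result; for example:
k_einsum = np.einsum('wxy...,...d->wxyd', x_strided, f, optimize='optimal')
# axes=3 contracts the last 3 axes of x_strided with the first 3 axes of f
k_tensordot = np.tensordot(x_strided, f, axes=3)
assert np.allclose(k_einsum, k_tensordot)
assert k_tensordot.shape == (64, 109, 109, 5)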
I got similar results with the tensor sizes in your code (i.e. 64, 16, 16, 3 and 4, 4, 3, 3).
As you can see, the reshape adds overhead, but it makes the subsequent matrix operation faster because the data becomes contiguous in memory. Please be aware that the results will vary depending on CPU speed, CPU architecture/generation, etc.