Here is my understanding of swapaxes.
Suppose you have an array (with NumPy imported as np):
In [1]: arr = np.arange(16).reshape((2, 2, 4))
In [2]: arr
Out[2]:
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7]],

       [[ 8,  9, 10, 11],
        [12, 13, 14, 15]]])
The shape of arr is (2, 2, 4). To get the value 7, you index it with:
In [3]: arr[0, 1, 3]
Out[3]: 7
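Because arr was built with np.arange and reshaped in NumPy's default C (row-major) order, each value equals its flat position, so you can check where 7 sits with a little arithmetic. A minimal sketch, assuming the same arr as above:

```python
import numpy as np

arr = np.arange(16).reshape((2, 2, 4))

# In C order, moving one step along axes 0, 1, 2 of a (2, 2, 4) array
# skips 8, 4 and 1 elements, so arr[i, j, k] holds the value i*8 + j*4 + k.
i, j, k = 0, 1, 3
flat = i * 8 + j * 4 + k
print(flat, arr[i, j, k])  # 7 7

# np.unravel_index goes the other way: from a flat position to a full index.
print(tuple(int(x) for x in np.unravel_index(7, arr.shape)))  # (0, 1, 3)
```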
There are three axes: 0, 1 and 2. Now we swap axes 0 and 2:
In [4]: arr_swap = arr.swapaxes(0, 2)
In [5]: arr_swap
Out[5]:
array([[[ 0,  8],
        [ 4, 12]],

       [[ 1,  9],
        [ 5, 13]],

       [[ 2, 10],
        [ 6, 14]],

       [[ 3, 11],
        [ 7, 15]]])
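Swapping axes 0 and 2 also swaps the lengths of those axes, so the (2, 2, 4) array becomes (4, 2, 2). For a 3-D array, this is the same axis permutation as transpose(2, 1, 0). NumPy documents that swapaxes on an ndarray returns a view rather than a copy, which np.shares_memory can confirm. A quick check, assuming the arrays above:

```python
import numpy as np

arr = np.arange(16).reshape((2, 2, 4))
arr_swap = arr.swapaxes(0, 2)

print(arr.shape, arr_swap.shape)  # (2, 2, 4) (4, 2, 2)

# Swapping axes 0 and 2 is the axis permutation (2, 1, 0).
print(np.array_equal(arr_swap, arr.transpose(2, 1, 0)))  # True

# swapaxes returns a view: no data is copied, only shape and strides change.
print(np.shares_memory(arr, arr_swap))  # True
print(arr.strides, arr_swap.strides)  # byte strides; values depend on dtype
```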
As you might guess, the index of 7 is now (3, 1, 0): the entries for axes 0 and 2 have traded places, while the entry for axis 1 is unchanged:
In [6]: arr_swap[3, 1, 0]
Out[6]: 7
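You can also let NumPy find the position of 7 in both arrays with np.argwhere, which returns the index of every matching element. The second index is the first with its entries for axes 0 and 2 exchanged:

```python
import numpy as np

arr = np.arange(16).reshape((2, 2, 4))
arr_swap = arr.swapaxes(0, 2)

# np.argwhere returns one row per match; 7 occurs once, so take row 0.
idx = tuple(int(x) for x in np.argwhere(arr == 7)[0])
idx_swap = tuple(int(x) for x in np.argwhere(arr_swap == 7)[0])
print(idx, idx_swap)  # (0, 1, 3) (3, 1, 0)
```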
So, from the perspective of indexing, swapping two axes just swaps the corresponding entries in each value's index. For example:
In [7]: arr[0, 0, 1]
Out[7]: 1
In [8]: arr_swap[1, 0, 0]
Out[8]: 1
In [9]: arr[0, 1, 2]
Out[9]: 6
In [10]: arr_swap[2, 1, 0]
Out[10]: 6
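The pattern in these examples holds for every element: the value at arr[i, j, k] appears at arr_swap[k, j, i]. A short exhaustive check over all 16 positions, using np.ndindex to iterate over every index tuple:

```python
import numpy as np

arr = np.arange(16).reshape((2, 2, 4))
arr_swap = arr.swapaxes(0, 2)

# Visit every index (i, j, k) of the original array.
for i, j, k in np.ndindex(arr.shape):
    # Axes 0 and 2 trade places; axis 1 stays where it is.
    assert arr_swap[k, j, i] == arr[i, j, k]

print("all", arr.size, "positions match")  # all 16 positions match
```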
So, if the swapped-axis array is hard to picture, just swap the indices: in general arr_swap[i, j, k] == arr[k, j, i], for example arr_swap[2, 1, 0] == arr[0, 1, 2].
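Turning that rule into code gives a way to build the swapped array by hand. The function my_swapaxes below is a hypothetical illustration, not how NumPy implements swapaxes (NumPy returns a view by rearranging shape and strides); this sketch copies each element to its swapped index and then compares the result with arr.swapaxes:

```python
import numpy as np

def my_swapaxes(a, axis1, axis2):
    """Hypothetical helper: build a copy of `a` with axis1 and axis2 exchanged."""
    # The new shape is the old shape with the two axis lengths exchanged.
    new_shape = list(a.shape)
    new_shape[axis1], new_shape[axis2] = new_shape[axis2], new_shape[axis1]
    out = np.empty(new_shape, dtype=a.dtype)
    for idx in np.ndindex(a.shape):
        # Swap the two entries in the index tuple, keep the rest.
        new_idx = list(idx)
        new_idx[axis1], new_idx[axis2] = new_idx[axis2], new_idx[axis1]
        out[tuple(new_idx)] = a[idx]
    return out

arr = np.arange(16).reshape((2, 2, 4))
print(np.array_equal(my_swapaxes(arr, 0, 2), arr.swapaxes(0, 2)))  # True
print(np.array_equal(my_swapaxes(arr, 1, 2), arr.swapaxes(1, 2)))  # True
```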