You can actually do that now with tf.gather_nd. Let's say you have a matrix m like the following:
    | 1 2 3 4 |
    | 5 6 7 8 |
And you want to build a matrix r of size, let's say, 4x2, from elements of m, like this:
    | 3 6 |
    | 2 7 |
    | 5 3 |
    | 1 1 |
Each element of r comes from one row/column position in m, so you can describe r with two matrices, rows and cols, holding those indices (zero-based, since we are programming, not doing math!):
           | 0 1 |          | 2 1 |
    rows = | 0 1 |   cols = | 1 2 |
           | 1 0 |          | 0 2 |
           | 0 0 |          | 0 0 |
You can stack rows and cols along a new last axis into a 3-dimensional (4x2x2) tensor, where each entry is a [row, col] pair:
    | | 0 2 | | 1 1 | |
    | | 0 1 | | 1 2 | |
    | | 1 0 | | 0 2 | |
    | | 0 0 | | 0 0 | |
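If you want to check the stacking step on its own, NumPy's np.stack behaves like tf.stack here. A small NumPy-only sketch (no TensorFlow), using the rows and cols values from above and assuming axis=-1 as in the TensorFlow code:

```python
import numpy as np

rows = np.array([[0, 1], [0, 1], [1, 0], [0, 0]])
cols = np.array([[2, 1], [1, 2], [0, 2], [0, 0]])

# Stack along a new last axis, so idx[i, j] == [rows[i, j], cols[i, j]]
idx = np.stack((rows, cols), axis=-1)

print(idx.shape)        # (4, 2, 2)
print(idx[2].tolist())  # [[1, 0], [0, 2]]
```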
This way, tf.gather_nd can look up each [row, col] pair in m, so you can get from m to r through rows and cols as follows:
    import numpy as np
    import tensorflow as tf

    m = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
    rows = np.array([[0, 1], [0, 1], [1, 0], [0, 0]])
    cols = np.array([[2, 1], [1, 2], [0, 2], [0, 0]])

    # TF1-style graph: placeholders for the source matrix and the two index matrices
    x = tf.placeholder('float32', (None, None))
    idx1 = tf.placeholder('int32', (None, None))
    idx2 = tf.placeholder('int32', (None, None))
    # Pair the indices as [row, col] along the last axis and look each pair up in x
    result = tf.gather_nd(x, tf.stack((idx1, idx2), -1))

    with tf.Session() as sess:
        r = sess.run(result, feed_dict={
            x: m,
            idx1: rows,
            idx2: cols,
        })
    print(r)
Output:
    [[ 3.  6.]
     [ 2.  7.]
     [ 5.  3.]
     [ 1.  1.]]
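To see what tf.gather_nd does with that index tensor, here is a rough NumPy re-implementation of its basic behaviour (the default batch_dims=0 case only; it ignores details such as how out-of-range or negative indices are handled). gather_nd_reference is a hypothetical helper written for illustration, not part of TensorFlow or NumPy:

```python
import numpy as np

def gather_nd_reference(params, indices):
    """Hypothetical NumPy sketch of tf.gather_nd (batch_dims=0 only).

    The last axis of `indices` holds an index tuple into the leading axes
    of `params`; all other axes of `indices` are kept in the output shape.
    """
    params = np.asarray(params)
    indices = np.asarray(indices)
    depth = indices.shape[-1]  # number of leading params axes each tuple indexes
    out_shape = indices.shape[:-1] + params.shape[depth:]
    flat = indices.reshape(-1, depth)
    # A full-length tuple picks a single element; a shorter one picks a whole slice.
    picked = [params[tuple(t)] for t in flat]
    return np.array(picked).reshape(out_shape)

m = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
rows = np.array([[0, 1], [0, 1], [1, 0], [0, 0]])
cols = np.array([[2, 1], [1, 2], [0, 2], [0, 0]])

r = gather_nd_reference(m, np.stack((rows, cols), axis=-1))
print(r.tolist())  # [[3, 6], [2, 7], [5, 3], [1, 1]]

# NumPy's integer-array indexing gives the same matrix directly
print((m[rows, cols] == r).all())  # True

# Index tuples of length 1 select whole rows of m instead of single elements
print(gather_nd_reference(m, [[1], [0]]).tolist())  # [[5, 6, 7, 8], [1, 2, 3, 4]]
```

The last call shows why the stacked index tensor ends in an axis of length 2: with two entries per tuple, every lookup lands on a single element of m rather than on a whole row.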