I have a tensor logits with the dimensions [batch_size, num_rows, num_coordinates] (i.e. each element of the batch is a matrix). In my case the batch size is 2, and there are 4 rows and 4 coordinates.
import tensorflow as tf

logits = tf.constant([[[10.0, 10.0, 20.0, 20.0],
                       [11.0, 10.0, 10.0, 30.0],
                       [12.0, 10.0, 10.0, 20.0],
                       [13.0, 10.0, 10.0, 20.0]],
                      [[14.0, 11.0, 21.0, 31.0],
                       [15.0, 11.0, 11.0, 21.0],
                       [16.0, 11.0, 11.0, 21.0],
                       [17.0, 11.0, 11.0, 21.0]]])
I want to select the first and second rows of the first batch element, and the second and fourth rows of the second batch element:
indices = tf.constant([[0, 1], [1, 3]])
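To pin down what these indices mean, here is a small sketch in NumPy (used only for illustration; it is not part of the original question). Row b of indices lists which rows to keep from batch element b, so result[b, j] = logits[b, indices[b, j]]. The variable names loop_result, batch and fancy_result are hypothetical.

```python
import numpy as np

logits = np.array([[[10.0, 10.0, 20.0, 20.0],
                    [11.0, 10.0, 10.0, 30.0],
                    [12.0, 10.0, 10.0, 20.0],
                    [13.0, 10.0, 10.0, 20.0]],
                   [[14.0, 11.0, 21.0, 31.0],
                    [15.0, 11.0, 11.0, 21.0],
                    [16.0, 11.0, 11.0, 21.0],
                    [17.0, 11.0, 11.0, 21.0]]])
indices = np.array([[0, 1], [1, 3]])

# Explicit loop: for each batch element b, keep the rows listed in indices[b].
loop_result = np.stack([logits[b, indices[b]] for b in range(logits.shape[0])])

# Same selection with advanced indexing: a column of batch numbers (shape [2, 1])
# broadcasts against indices (shape [2, 2]), pairing each row index with its batch.
batch = np.arange(logits.shape[0])[:, None]
fancy_result = logits[batch, indices]  # shape (2, 2, 4)
```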
So the desired output would be:
expected = tf.constant([[[10.0, 10.0, 20.0, 20.0],
                         [11.0, 10.0, 10.0, 30.0]],
                        [[15.0, 11.0, 11.0, 21.0],
                         [17.0, 11.0, 11.0, 21.0]]])
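One way to produce exactly this output, sketched here assuming TensorFlow 2.x with eager execution, is to pair every row index with its batch index and pass the resulting (batch, row) coordinates to tf.gather_nd. The names batch_idx, coords and selected are illustrative only.

```python
import tensorflow as tf

logits = tf.constant([[[10.0, 10.0, 20.0, 20.0],
                       [11.0, 10.0, 10.0, 30.0],
                       [12.0, 10.0, 10.0, 20.0],
                       [13.0, 10.0, 10.0, 20.0]],
                      [[14.0, 11.0, 21.0, 31.0],
                       [15.0, 11.0, 11.0, 21.0],
                       [16.0, 11.0, 11.0, 21.0],
                       [17.0, 11.0, 11.0, 21.0]]])
indices = tf.constant([[0, 1], [1, 3]])

# Batch number for every entry of `indices`, same shape: [[0, 0], [1, 1]].
batch_size = tf.shape(logits)[0]
batch_idx = tf.broadcast_to(tf.range(batch_size)[:, tf.newaxis], tf.shape(indices))

# (batch, row) coordinate pairs, shape [2, 2, 2]:
# [[[0, 0], [0, 1]], [[1, 1], [1, 3]]]
coords = tf.stack([batch_idx, indices], axis=-1)

# Each coordinate pair selects one full row of 4 values, giving shape [2, 2, 4].
selected = tf.gather_nd(logits, coords)
```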
How do I do this using TensorFlow? I tried using tf.gather(logits, indices), but it did not return what I expected. Thanks!
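Per the tf.gather documentation, a plain tf.gather(logits, indices) gathers along axis 0, so it treats every value in indices as a batch number and would return a [2, 2, 4, 4] tensor; index 3 is also out of range for a batch axis of size 2 (the docs say CPU raises an error, while GPU fills in zeros). The batch_dims argument changes this: with batch_dims=1, indices[b] is applied only to logits[b]. A minimal sketch, assuming TensorFlow 2.x:

```python
import tensorflow as tf

logits = tf.constant([[[10.0, 10.0, 20.0, 20.0],
                       [11.0, 10.0, 10.0, 30.0],
                       [12.0, 10.0, 10.0, 20.0],
                       [13.0, 10.0, 10.0, 20.0]],
                      [[14.0, 11.0, 21.0, 31.0],
                       [15.0, 11.0, 11.0, 21.0],
                       [16.0, 11.0, 11.0, 21.0],
                       [17.0, 11.0, 11.0, 21.0]]])
indices = tf.constant([[0, 1], [1, 3]])

# batch_dims=1: the first dimension of `indices` lines up with the batch
# dimension of `logits`, so indices[b] only selects rows from logits[b].
# The gather axis defaults to batch_dims, i.e. axis 1 (the row axis).
selected = tf.gather(logits, indices, batch_dims=1)  # shape [2, 2, 4]
```

Because axis defaults to the first non-batch dimension, no explicit axis argument is needed here.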