Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others


How to select rows from a 3-D Tensor in TensorFlow?

I have a tensor logits with shape [batch_size, num_rows, num_coordinates] (i.e. each element of the batch is a matrix). In my case the batch size is 2, and there are 4 rows and 4 coordinates.

logits = tf.constant([[[10.0, 10.0, 20.0, 20.0],
                      [11.0, 10.0, 10.0, 30.0],
                      [12.0, 10.0, 10.0, 20.0],
                      [13.0, 10.0, 10.0, 20.0]],
                     [[14.0, 11.0, 21.0, 31.0],
                      [15.0, 11.0, 11.0, 21.0],
                      [16.0, 11.0, 11.0, 21.0],
                      [17.0, 11.0, 11.0, 21.0]]])

I want to select the first and second rows of the first batch element and the second and fourth rows of the second batch element.

indices = tf.constant([[0, 1], [1, 3]])

So the desired output would be

logits = tf.constant([[[10.0, 10.0, 20.0, 20.0],
                      [11.0, 10.0, 10.0, 30.0]],
                     [[15.0, 11.0, 11.0, 21.0],
                      [17.0, 11.0, 11.0, 21.0]]])

How do I do this using TensorFlow? I tried using tf.gather(logits, indices) but it did not return what I expected. Thanks!


1 Reply


This is possible in TensorFlow, but slightly inconvenient, because by default tf.gather() selects slices only from the 0th dimension of a tensor and applies the same indices to the whole tensor, not a separate set of indices per batch element. However, it is still possible to solve your problem efficiently by transforming the arguments so that they can be passed to tf.gather():

import tensorflow as tf

logits = ...  # [2 x 4 x 4] tensor
indices = tf.constant([[0, 1], [1, 3]])

# Use tf.shape() to make this work with dynamic shapes.
batch_size = tf.shape(logits)[0]
rows_per_batch = tf.shape(logits)[1]
indices_per_batch = tf.shape(indices)[1]

# Offset to add to each row in indices. We use `tf.expand_dims()` to make 
# this broadcast appropriately.
offset = tf.expand_dims(tf.range(0, batch_size) * rows_per_batch, 1)

# Convert indices and logits into appropriate form for `tf.gather()`. 
flattened_indices = tf.reshape(indices + offset, [-1])
# Note: tf.concat() takes the list of tensors first and the axis second.
flattened_logits = tf.reshape(logits, tf.concat([[-1], tf.shape(logits)[2:]], axis=0))

selected_rows = tf.gather(flattened_logits, flattened_indices)

# tf.pack() was renamed to tf.stack(); tf.concat() takes values first, then axis.
result = tf.reshape(selected_rows,
                    tf.concat([tf.stack([batch_size, indices_per_batch]),
                               tf.shape(logits)[2:]], axis=0))
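
To see why adding the offset works, the same index arithmetic can be traced by hand on the data from the question. The following is only an illustrative sketch in plain Python (no TensorFlow), using nested lists in place of tensors; the variable names mirror the ones above.

```python
# Plain-Python trace of the flatten-and-offset trick, using the question's data.
logits = [
    [[10.0, 10.0, 20.0, 20.0], [11.0, 10.0, 10.0, 30.0],
     [12.0, 10.0, 10.0, 20.0], [13.0, 10.0, 10.0, 20.0]],
    [[14.0, 11.0, 21.0, 31.0], [15.0, 11.0, 11.0, 21.0],
     [16.0, 11.0, 11.0, 21.0], [17.0, 11.0, 11.0, 21.0]],
]
indices = [[0, 1], [1, 3]]

rows_per_batch = len(logits[0])
indices_per_batch = len(indices[0])

# Flatten [batch, rows, coords] into [batch * rows, coords].
flattened_logits = [row for batch in logits for row in batch]

# Row r of batch element b sits at position b * rows_per_batch + r.
flattened_indices = [b * rows_per_batch + r
                     for b, rows in enumerate(indices)
                     for r in rows]

selected_rows = [flattened_logits[i] for i in flattened_indices]

# Regroup into [batch, indices_per_batch, coords].
result = [selected_rows[i:i + indices_per_batch]
          for i in range(0, len(selected_rows), indices_per_batch)]

print(flattened_indices)
print(result)
```

Batch element 1 starts at flattened row 4, so its local rows 1 and 3 become flattened rows 5 and 7.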

Note that, since the TensorFlow code above uses tf.reshape() and not tf.transpose(), it doesn't need to rearrange the (potentially large) data in the logits tensor, so it should be fairly efficient.
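
As a possible alternative, later TensorFlow releases add a batch_dims argument to tf.gather() that handles per-batch indices directly. The sketch below assumes TensorFlow 2.x with eager execution; it is not part of the approach above, just a shorter route to the same result on the question's data.

```python
import tensorflow as tf  # assumption: TensorFlow 2.x, eager execution

logits = tf.constant([[[10.0, 10.0, 20.0, 20.0],
                       [11.0, 10.0, 10.0, 30.0],
                       [12.0, 10.0, 10.0, 20.0],
                       [13.0, 10.0, 10.0, 20.0]],
                      [[14.0, 11.0, 21.0, 31.0],
                       [15.0, 11.0, 11.0, 21.0],
                       [16.0, 11.0, 11.0, 21.0],
                       [17.0, 11.0, 11.0, 21.0]]])
indices = tf.constant([[0, 1], [1, 3]])

# batch_dims=1 treats the leading dimension of both tensors as a batch
# dimension, so indices[b] selects rows from logits[b].
result = tf.gather(logits, indices, batch_dims=1)

print(result.numpy())
```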

