Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others


tensorflow - Scatter operation for middle dimension of a tensor

I have a 3-D tensor in which I need to preserve the vector at a given position along the second dimension and zero out the remaining vectors. The positions are specified as a 1-D array, one per batch entry. I think the best way to do this is to multiply the tensor by a binary mask.

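Here's a simple NumPy version:

import numpy as np

b, n, m = 3, 5, 4
A = np.random.randn(b, n, m)        # A.shape: (b, n, m)
indices = np.array([2, 1, 4])       # indices.shape: (b,)

mask = np.zeros(A.shape)
for i in range(b):
    mask[i, indices[i]] = 1         # keep row indices[i] of the i-th matrix
result = A * mask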

So for each n×m matrix in A, I need to preserve the row specified by the corresponding entry of indices and zero out the rest.
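The loop can also be written without Python-level iteration. The following is a minimal sketch, assuming NumPy advanced indexing and illustrative sizes b=3, n=5, m=4 (the seed and sizes are assumptions, not part of the question). Pairing np.arange(b) with indices addresses element (i, indices[i]) for every i, which is the same (batch, row) pairing an N-d scatter operation needs.

```python
import numpy as np

b, n, m = 3, 5, 4                      # illustrative sizes (assumed)
rng = np.random.default_rng(0)         # fixed seed, only for repeatability
A = rng.standard_normal((b, n, m))
indices = np.array([2, 1, 4])          # one row to keep per matrix

# Pair each batch position i with indices[i]; the mask has shape (b, n).
mask = np.zeros((b, n), dtype=A.dtype)
mask[np.arange(b), indices] = 1

# Add a trailing axis so the mask broadcasts across the m columns.
result = A * mask[:, :, None]
```

Keeping the mask at shape (b, n) and broadcasting avoids allocating a full (b, n, m) array of ones and zeros.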

I'm trying to do this in TensorFlow using the tf.scatter_nd op, but I can't figure out the correct shape of indices:

shape = tf.constant([3,5,4])
A = tf.random_normal(shape)       
indices = tf.constant([2,1,4])   #???   
updates = tf.ones((3,4))           
mask = tf.scatter_nd(indices, updates, shape) 
result = A*mask

1 Reply


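Here's one way to do it: build a (b, n) mask with tf.scatter_nd, then apply it with tf.where. The indices passed to tf.scatter_nd need to be (batch, row) pairs with shape (b, 2), not a flat vector of row numbers: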

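# Written against the TF 1.x API (tf.contrib.eager, tf.random_normal).
import tensorflow as tf
import tensorflow.contrib.eager as tfe
tfe.enable_eager_execution()

shape = tf.constant([3,5,4])
A = tf.random_normal(shape)

array_shape = tf.shape(A)
indices = tf.constant([2,1,4])
# Pair each batch position with its row to keep: [[0, 2], [1, 1], [2, 4]], shape (3, 2).
non_zero_indices = tf.stack((tf.range(array_shape[0]), indices), axis=1)
# Scatter a 1 at each (batch, row) pair to get a (3, 5) mask.
should_keep_row = tf.scatter_nd(non_zero_indices, tf.ones_like(indices),
                                shape=[array_shape[0], array_shape[1]])
print("should_keep_row", should_keep_row)
# Tile the mask across the last dimension so it matches A, then select.
masked = tf.where(tf.cast(tf.tile(should_keep_row[:, :, None],
                                  [1, 1, array_shape[2]]), tf.bool),
                   A,
                   tf.zeros_like(A))
print("masked", masked)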

Prints:

should_keep_row tf.Tensor(
[[0 0 1 0 0]
 [0 1 0 0 0]
 [0 0 0 0 1]], shape=(3, 5), dtype=int32)
masked tf.Tensor(
[[[ 0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.        ]
  [ 0.02036316 -0.07163608 -3.16707373  1.31406844]
  [ 0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.        ]]

 [[ 0.          0.          0.          0.        ]
  [-0.76696759 -0.28313264  0.87965059 -1.28844094]
  [ 0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.        ]]

 [[ 0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.        ]
  [ 1.03188455  0.44305769  0.71291149  1.59758031]]], shape=(3, 5, 4), dtype=float32)

(The example uses eager execution, but the same ops work with graph execution in a Session.)
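A shorter alternative, which is not part of the answer above, is to build the mask with tf.one_hot and let broadcasting handle the last dimension. This is only a sketch, assuming TensorFlow 2.x (eager execution on by default, tf.random.normal in place of tf.random_normal); keep_rows is a hypothetical helper name introduced here for illustration.

```python
import tensorflow as tf

def keep_rows(A, indices):
    """Zero every row of each (n, m) matrix in A except row indices[i].

    Hypothetical helper: A has shape (b, n, m), indices has shape (b,).
    """
    n = tf.shape(A)[1]
    # one_hot yields a (b, n) mask with a single 1 per batch entry;
    # the added trailing axis lets it broadcast across the m columns.
    mask = tf.one_hot(indices, depth=n, dtype=A.dtype)[:, :, tf.newaxis]
    return A * mask

A = tf.random.normal([3, 5, 4])
indices = tf.constant([2, 1, 4])
masked = keep_rows(A, indices)
print("masked", masked)
```

Compared with the scatter_nd version, this skips the explicit (batch, row) index pairs and the tf.tile call, at the cost of being limited to one kept row per batch entry.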

