Here's one way to do it, creating a mask and using tf.where:
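import tensorflow as tf
import tensorflow.contrib.eager as tfe

tfe.enable_eager_execution()

shape = tf.constant([3, 5, 4])
A = tf.random_normal(shape)
array_shape = tf.shape(A)
# For each entry along axis 0, the index along axis 1 of the row to keep.
indices = tf.constant([2, 1, 4])
# Pair each position along axis 0 with its chosen row: [[0, 2], [1, 1], [2, 4]].
non_zero_indices = tf.stack((tf.range(array_shape[0]), indices), axis=1)
# Scatter ones into a [3, 5] mask at those positions; everything else stays 0.
should_keep_row = tf.scatter_nd(non_zero_indices, tf.ones_like(indices),
                                shape=[array_shape[0], array_shape[1]])
print("should_keep_row", should_keep_row)
# Tile the mask along the last axis to match A's shape, then select from A or zeros.
masked = tf.where(tf.cast(tf.tile(should_keep_row[:, :, None],
                                  [1, 1, array_shape[2]]), tf.bool),
                  A,
                  tf.zeros_like(A))
print("masked", masked)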
Prints:
should_keep_row tf.Tensor(
[[0 0 1 0 0]
[0 1 0 0 0]
[0 0 0 0 1]], shape=(3, 5), dtype=int32)
masked tf.Tensor(
[[[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0.02036316 -0.07163608 -3.16707373 1.31406844]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]]
[[ 0. 0. 0. 0. ]
[-0.76696759 -0.28313264 0.87965059 -1.28844094]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]]
[[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 1.03188455 0.44305769 0.71291149 1.59758031]]], shape=(3, 5, 4), dtype=float32)
(The example uses eager execution, but the same ops work with graph execution in a Session.)
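For comparison, the same mask can probably be built more compactly with tf.one_hot. The following is only a sketch, and it assumes TensorFlow 2.x, where eager execution is on by default and tf.where broadcasts its condition, so the tf.tile step is not needed. The helper name mask_rows is made up for illustration and is not part of TensorFlow.

```python
import tensorflow as tf


def mask_rows(a, indices):
    """Zero every row along axis 1 of `a` except the one chosen per axis-0 entry.

    Hypothetical helper for illustration; assumes TensorFlow 2.x.
    """
    depth = tf.shape(a)[1]
    # One-hot over axis 1 gives a [batch, rows] mask with a single 1 per batch entry.
    keep = tf.cast(tf.one_hot(indices, depth), tf.bool)
    # Add a trailing axis so the mask broadcasts across the last dimension of `a`.
    return tf.where(keep[:, :, None], a, tf.zeros_like(a))


A = tf.random.normal([3, 5, 4])
indices = tf.constant([2, 1, 4])
masked = mask_rows(A, indices)
print("masked", masked)
```

Selecting with tf.where instead of multiplying A by the mask means that NaN or inf values in the discarded rows do not leak into the result.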