What you are trying to do is frequently done with tf.scatter_nd_update. However, in most cases that is not the right way to do it: you should not need a variable, just another tensor produced from the original one with some of its values replaced. Unfortunately, there is no straightforward way to do this in general. If your original tensor really is all zeros, you can simply use tf.scatter_nd:
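import tensorflow as tf
# Column to write in each row; row_idx doubles as the row numbers and the values written.
idx = tf.constant([0, 1, 1, 0, 1])
row_idx = tf.range(5)
# (row, column) pairs: [[0, 0], [1, 1], [2, 1], [3, 0], [4, 1]]
indices = tf.stack([row_idx, idx], axis=1)
# Write row_idx at those positions of an all-zeros (5, 2) tensor.
a = tf.scatter_nd(indices, row_idx, (5, 2))
# TF 1.x graph mode; in TF 2.x (eager), skip the session and use print(a.numpy())
with tf.Session() as sess:
    print(sess.run(a))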
# [[0 0]
# [0 1]
# [0 2]
# [3 0]
# [0 4]]
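To make the semantics concrete, here is a minimal pure-Python sketch of what tf.scatter_nd computes in this 2-D case. It uses no TensorFlow, and the helper name scatter_nd_2d is hypothetical. One detail worth knowing: tf.scatter_nd sums updates that land on the same index, which the sketch mirrors with +=.

```python
def scatter_nd_2d(indices, updates, shape):
    """Hypothetical pure-Python model of tf.scatter_nd for a 2-D output."""
    rows, cols = shape
    # Start from an all-zeros "tensor" of the requested shape.
    out = [[0] * cols for _ in range(rows)]
    for (r, c), value in zip(indices, updates):
        # Duplicate indices accumulate, as tf.scatter_nd does.
        out[r][c] += value
    return out


idx = [0, 1, 1, 0, 1]
row_idx = list(range(5))
# Equivalent of tf.stack([row_idx, idx], axis=1): one (row, column) pair per row.
indices = list(zip(row_idx, idx))
a = scatter_nd_2d(indices, row_idx, (5, 2))
print(a)  # [[0, 0], [0, 1], [0, 2], [3, 0], [0, 4]]
```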
However, if the initial tensor is not all zeros, it is more complicated. One way to handle it is to do the same as above, then build a mask marking the updated positions, and use tf.where to select between the original and the updated values according to that mask:
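import tensorflow as tf
# Original tensor, not all zeros this time.
a = tf.ones((5, 2), dtype=tf.int32)
idx = tf.constant([0, 1, 1, 0, 1])
row_idx = tf.range(5)
indices = tf.stack([row_idx, idx], axis=1)
# New values scattered into a zeros tensor, as before.
a_update = tf.scatter_nd(indices, row_idx, (5, 2))
# Boolean mask that is True exactly at the updated positions.
update_mask = tf.scatter_nd(indices, tf.ones_like(row_idx, dtype=tf.bool), (5, 2))
# Take the update where the mask is set and the original value elsewhere.
a = tf.where(update_mask, a_update, a)
# TF 1.x graph mode; in TF 2.x (eager), skip the session and use print(a.numpy())
with tf.Session() as sess:
    print(sess.run(a))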
# [[0 1]
# [1 1]
# [1 2]
# [3 1]
# [1 4]]
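The separate mask matters because it cannot be recovered from a_update alone: position [0, 0] is updated with the value 0, so in a_update it looks exactly like an untouched zero. The pure-Python sketch below (no TensorFlow; the helpers scatter_with_mask and where are hypothetical stand-ins for the two tf.scatter_nd calls and for tf.where) reproduces the result above and keeps that 0 in the output.

```python
def scatter_with_mask(indices, updates, shape):
    """Hypothetical model of the two scatters: values into zeros, plus a bool mask."""
    rows, cols = shape
    a_update = [[0] * cols for _ in range(rows)]
    update_mask = [[False] * cols for _ in range(rows)]
    for (r, c), value in zip(indices, updates):
        a_update[r][c] += value
        update_mask[r][c] = True  # mark the position as updated, whatever the value
    return a_update, update_mask


def where(mask, x, y):
    """Element-wise select like tf.where(mask, x, y) for same-shape 2-D lists."""
    return [
        [xv if m else yv for m, xv, yv in zip(mrow, xrow, yrow)]
        for mrow, xrow, yrow in zip(mask, x, y)
    ]


a = [[1] * 2 for _ in range(5)]  # like tf.ones((5, 2))
idx = [0, 1, 1, 0, 1]
row_idx = list(range(5))
indices = list(zip(row_idx, idx))
a_update, update_mask = scatter_with_mask(indices, row_idx, (5, 2))
a = where(update_mask, a_update, a)
print(a)  # [[0, 1], [1, 1], [1, 2], [3, 1], [1, 4]]
```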