Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

tensorflow2.0 - Tensorflow: binary mask of max values along tensor axis

If I have an N-dimensional tensor, I would like to create another tensor with the same shape, containing 0s and 1s, where the 1s mark the positions of the maximum elements of the original tensor along a given axis.
One constraint is that I want only the first maximum element along that axis when there are duplicates.
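As a reference point for the desired output (not part of the original question), here is a minimal NumPy sketch. It relies on `np.argmax`, whose documentation states that ties resolve to the first occurrence; the function name `first_max_mask_np` is hypothetical.

```python
import numpy as np

def first_max_mask_np(x, axis=-1):
    """Hypothetical reference: 1.0 at the first maximum along `axis`, 0.0 elsewhere."""
    x = np.asarray(x)
    # np.argmax returns the index of the first occurrence when there are ties.
    idx = np.expand_dims(np.argmax(x, axis=axis), axis=axis)
    mask = np.zeros_like(x, dtype=np.float32)
    # Write a single 1.0 per slice at the argmax position.
    np.put_along_axis(mask, idx, 1.0, axis=axis)
    return mask

x = np.array([[7, 2, 3],
              [5, 0, 1],
              [3, 8, 2]], dtype=np.float32)
print(first_max_mask_np(x))
print(first_max_mask_np(np.zeros((3, 3))))  # all-equal rows: only the first column is marked
```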

For simplicity, I will use a 2-D example.

>>> x = tf.constant([[7, 2, 3], 
                     [5, 0, 1], 
                     [3, 8, 2]], dtype=tf.float32)

>>> tf.reduce_max(x, axis=-1)
tf.Tensor([7. 5. 8.], shape=(3,), dtype=float32)

What I want is:

tf.Tensor([[1. 0. 0.]
           [1. 0. 0.]
           [0. 1. 0.]], shape=(3, 3), dtype=float32)

What I've tried (and realized was wrong):

>>> tf.cast(tf.equal(x, tf.reduce_max(x, axis=-1, keepdims=True)), dtype=tf.float32)

# works fine when there are no duplicates
tf.Tensor([[1. 0. 0.]
           [1. 0. 0.]
           [0. 1. 0.]], shape=(3, 3), dtype=float32)


>>> y = tf.zeros([3,3])
>>> tf.cast(tf.equal(y, tf.reduce_max(y, axis=-1, keepdims=True)), dtype=tf.float32)

# fails when the maximum value appears more than once along the axis
tf.Tensor([[1. 1. 1.]
           [1. 1. 1.]
           [1. 1. 1.]], shape=(3, 3), dtype=float32)
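The failed attempt can also be repaired without sorting: after the `tf.equal` comparison, a running count along the axis keeps only the first `True` in each slice. This cumulative-sum approach is a swapped-in technique, not the method from the question or the answer, and the function name `first_max_mask_cumsum` is hypothetical.

```python
import tensorflow as tf

def first_max_mask_cumsum(x, axis=-1):
    # True wherever x equals the maximum along `axis` (possibly several per slice).
    is_max = tf.equal(x, tf.reduce_max(x, axis=axis, keepdims=True))
    # Running count of maxima seen so far along `axis`.
    seen = tf.cumsum(tf.cast(is_max, tf.int32), axis=axis)
    # Keep only the maximum where that count first reaches 1.
    first = tf.logical_and(is_max, tf.equal(seen, 1))
    return tf.cast(first, tf.float32)

y = tf.zeros([3, 3])
print(first_max_mask_cumsum(y))
```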

Edit: Solved

tf.cast(tf.equal(tf.argsort(tf.argsort(x, 1, direction='DESCENDING'), 1), 0), tf.float32)
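The one-liner in the edit can be unpacked into named steps and generalized to any axis. This is a sketch, with the hypothetical name `first_max_mask`; it passes `stable=True`, a parameter `tf.argsort` accepts and which its documentation recommends when the order of equal elements matters.

```python
import tensorflow as tf

def first_max_mask(x, axis=-1):
    # Indices that order each slice along `axis` from largest to smallest;
    # with a stable sort, equal values keep their original (lower index first) order.
    order = tf.argsort(x, axis=axis, direction='DESCENDING', stable=True)
    # Sorting the permutation inverts it: ranks[i] is the sorted position of x[i].
    ranks = tf.argsort(order, axis=axis, stable=True)
    # Rank 0 is the largest value and, among ties, the one with the lowest index.
    return tf.cast(tf.equal(ranks, 0), tf.float32)

x = tf.constant([[7, 2, 3],
                 [5, 0, 1],
                 [3, 8, 2]], dtype=tf.float32)
print(first_max_mask(x))
print(first_max_mask(tf.zeros([3, 3])))
```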
Question from: https://stackoverflow.com/questions/66048913/tensorflow-binary-mask-of-max-values-along-tensor-axis

He who fights dragons too long becomes a dragon himself; gaze too long into the abyss, and the abyss gazes back into you…

1 Reply


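You could apply tf.argsort() twice to get the rank of each element along axis 1, then select the element with the highest rank. The inner argsort sorts each row; the outer argsort inverts that permutation, so each position receives its rank in ascending order. Because equal values keep their original order, this marks the last occurrence of the maximum as the top rank. Here is an example with duplicate elements: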

x = tf.constant([[7, 2, 3],  # max is 7
                 [5, 0, 5],  # max is 5, duplicated in the same row
                 [7, 8, 7]]) # max is 8 (7 appears twice here and also in the first row)

tf.cast(tf.equal(tf.argsort(tf.argsort(x, 1), 1), x.shape[1] - 1), tf.int32)
<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
array([[1, 0, 0],
       [0, 0, 1],
       [0, 1, 0]], dtype=int32)>
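A sketch of the answer's approach generalized to any axis (the name `last_max_mask` is hypothetical). It takes the top rank from the length of the sorted axis via `tf.shape(x)[axis]`; using the number of rows only works when the matrix is square. As the answer notes, this marks the last occurrence of the maximum.

```python
import tensorflow as tf

def last_max_mask(x, axis=-1):
    # Ascending stable argsort, then invert the permutation to get per-element ranks.
    ranks = tf.argsort(tf.argsort(x, axis=axis, stable=True), axis=axis, stable=True)
    # The top rank is n - 1, where n is the length of the sorted axis.
    n = tf.shape(x)[axis]
    return tf.cast(tf.equal(ranks, n - 1), tf.int32)

x = tf.constant([[7, 2, 3],
                 [5, 0, 5],
                 [7, 8, 7]])
print(last_max_mask(x))
print(last_max_mask(tf.constant([[1, 3, 3, 0]])))  # non-square input
```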
