Yes, this function is hard to understand, until you get the point.
In its simplest form, it is similar to tf.gather. It returns the elements of params according to the indexes specified by ids.
For example (assuming you are inside tf.InteractiveSession()):
params = tf.constant([10,20,30,40])
ids = tf.constant([0,1,2,3])
print(tf.nn.embedding_lookup(params, ids).eval())
would return [10 20 30 40], because the first element (index 0) of params is 10, the second element of params (index 1) is 20, etc.
Similarly,
params = tf.constant([10,20,30,40])
ids = tf.constant([1,1,3])
print(tf.nn.embedding_lookup(params, ids).eval())
would return [20 20 40].
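To make the single-tensor case concrete without TensorFlow, here is a minimal pure-Python sketch of the same lookup. The helper name simple_lookup is made up for illustration and is not part of TensorFlow; it only mimics the behaviour described above for a one-dimensional params.

```python
def simple_lookup(params, ids):
    """Return params[i] for every i in ids (hypothetical helper, not a TF API)."""
    return [params[i] for i in ids]


params = [10, 20, 30, 40]
print(simple_lookup(params, [0, 1, 2, 3]))  # [10, 20, 30, 40]
print(simple_lookup(params, [1, 1, 3]))     # [20, 20, 40]
```

Note that an index may appear more than once in ids, and the output always has the same length as ids.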
But embedding_lookup is more than that. The params argument can be a list of tensors, rather than a single tensor.
params1 = tf.constant([1,2])
params2 = tf.constant([10,20])
ids = tf.constant([2,0,2,1,2,3])
result = tf.nn.embedding_lookup([params1, params2], ids)
In such a case, the indexes specified in ids correspond to elements of the tensors according to a partition strategy, where the default partition strategy is 'mod'.
In the 'mod' strategy, index 0 corresponds to the first element of the first tensor in the list. Index 1 corresponds to the first element of the second tensor. Index 2 corresponds to the first element of the third tensor, and so on. In general, index i corresponds to the first element of the (i+1)th tensor, for all the indexes 0..(n-1), assuming params is a list of n tensors.
Now, index n cannot correspond to tensor n+1, because the list params contains only n tensors. So index n wraps around and corresponds to the second element of the first tensor. Similarly, index n+1 corresponds to the second element of the second tensor, etc. This wrap-around is where the name 'mod' comes from: the tensor is chosen by the index modulo n.
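The rule above boils down to integer division with remainder. As a rough sketch (the function name mod_location is hypothetical, and both the tensor number and the position are counted from zero here), an index maps to a location like this:

```python
def mod_location(index, num_tensors):
    """Map an id to (tensor number, position inside that tensor), both 0-based.

    Hypothetical helper that mirrors the 'mod' rule described above.
    """
    # divmod returns (index // num_tensors, index % num_tensors)
    position, tensor_number = divmod(index, num_tensors)
    return tensor_number, position


n = 3  # assume params is a list of 3 tensors
for i in range(6):
    tensor_number, position = mod_location(i, n)
    print("index %d -> tensor %d, element %d" % (i, tensor_number, position))
```

With n = 3, indexes 0, 1, 2 hit element 0 of tensors 0, 1, 2, and index 3 wraps back to element 1 of tensor 0.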
So, in the code
params1 = tf.constant([1,2])
params2 = tf.constant([10,20])
ids = tf.constant([2,0,2,1,2,3])
result = tf.nn.embedding_lookup([params1, params2], ids)
index 0 corresponds to the first element of the first tensor: 1
index 1 corresponds to the first element of the second tensor: 10
index 2 corresponds to the second element of the first tensor: 2
index 3 corresponds to the second element of the second tensor: 20
Thus, the result would be:
[ 2 1 2 10 2 20]
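This result can be double-checked without running TensorFlow. The sketch below re-implements the 'mod' lookup for plain Python lists; lookup_mod is an illustrative name, not a TensorFlow function, and it assumes one-dimensional tensors as in the example.

```python
def lookup_mod(params_list, ids):
    """Emulate the 'mod' partition strategy on plain lists (illustrative only)."""
    n = len(params_list)
    result = []
    for i in ids:
        # i % n selects the tensor, i // n selects the element inside it
        result.append(params_list[i % n][i // n])
    return result


params1 = [1, 2]
params2 = [10, 20]
ids = [2, 0, 2, 1, 2, 3]
print(lookup_mod([params1, params2], ids))  # [2, 1, 2, 10, 2, 20]
```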