Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

python - Predicting next word using the language model tensorflow example

The TensorFlow tutorial on language modeling shows how to compute the probability of a sentence:

probabilities = tf.nn.softmax(logits)

The comments in the tutorial code below also mention that the LSTM output can be used to predict the next word, but they do not say how. How can I output a word instead of a probability using this example?

lstm = rnn_cell.BasicLSTMCell(lstm_size)
# Initial state of the LSTM memory.
state = tf.zeros([batch_size, lstm.state_size])

loss = 0.0
for current_batch_of_words in words_in_dataset:
    # The value of state is updated after processing each batch of words.
    output, state = lstm(current_batch_of_words, state)

    # The LSTM output can be used to make next word predictions
    logits = tf.matmul(output, softmax_w) + softmax_b
    probabilities = tf.nn.softmax(logits)
    loss += loss_function(probabilities, target_words)

1 Reply

You need to take the argmax of the probabilities and translate the resulting index back to a word by inverting the word_to_id map. To get this to work, you must save the probabilities in the model and then fetch them in the run_epoch function (you could also save just the argmax itself). Here's a snippet:

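# Invert the vocabulary so that word ids map back to words.
inverseDictionary = dict(zip(word_to_id.values(), word_to_id.keys()))

def run_epoch(...):
  # Assumes logits, x and y hold NumPy values for a single prediction,
  # since np.argmax without an axis searches the flattened array.
  decodedWordId = int(np.argmax(logits))
  print(" ".join([inverseDictionary[int(x1)] for x1 in np.nditer(x)])
    + " got: " + inverseDictionary[decodedWordId]
    + " expected: " + inverseDictionary[int(y)])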

See full implementation here: https://github.com/nelken/tf
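
To show the argmax-and-lookup step on its own, here is a minimal NumPy sketch that does not depend on the tutorial code. The toy vocabulary, the made-up logits, and the helper names softmax, predict_words and id_to_word are all assumptions for illustration. It takes the argmax per row (axis=-1), so it also handles a batch of several predictions.

```python
import numpy as np

# Hypothetical toy vocabulary; in the tutorial this map comes from the data reader.
word_to_id = {"the": 0, "cat": 1, "sat": 2, "<eos>": 3}
id_to_word = {idx: word for word, idx in word_to_id.items()}

def softmax(logits):
    """Row-wise softmax, shifted by the row max for numerical stability."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

def predict_words(logits, id_to_word):
    """Return the most probable word for each row of a [batch, vocab] array."""
    probabilities = softmax(np.asarray(logits, dtype=np.float64))
    word_ids = probabilities.argmax(axis=-1)
    return [id_to_word[int(i)] for i in word_ids]

# Made-up logits for a batch of two predictions over the four-word vocabulary.
logits = np.array([[0.1, 2.5, 0.3, -1.0],
                   [1.2, 0.0, 3.1, 0.4]])
predicted = predict_words(logits, id_to_word)
print(predicted)
```

Softmax preserves the ordering of its inputs, so taking the argmax of the logits directly gives the same word ids. The softmax step is only needed if you also want the probability of the predicted word.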

