In TensorFlow 0.8, I used to create an RNN like this:
import tensorflow as tf
from tensorflow.python.ops import rnn, rnn_cell
# Define a lstm cell with tensorflow
lstm_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)
# Get lstm cell output
outputs, states = rnn.rnn(cell=lstm_cell, inputs=x, dtype=tf.float32)
rnn.rnn() is no longer available, and it seems to have been moved to tf.contrib. What is the exact code to create an RNN out of a BasicLSTMCell?
Or, in the case of a stacked LSTM:
lstm_cell = tf.contrib.rnn.BasicLSTMCell(hidden_size, forget_bias=0.0)
stacked_lstm = tf.contrib.rnn.MultiRNNCell([lstm_cell] * num_layers)
outputs, new_state = tf.nn.rnn(stacked_lstm, inputs, initial_state=_initial_state)
What is the replacement for tf.nn.rnn in newer versions of TensorFlow?