Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share

python - Issue implementing q-rnn in tf-agents

I have been trying to build an RL agent using TF-Agents in TensorFlow. I first hit the issue in a custom-built environment, but I reproduced it using an official TF-Agents Colab example. The problem occurs whenever I use QRnnNetwork as the network for the DqnAgent. The agent works fine with a regular QNetwork, but with QRnnNetwork the policy_state_spec gets reshaped. How can I fix this?

This is what the policy_state_spec gets converted to, even though the original shape is ():

ListWrapper([TensorSpec(shape=(16,), dtype=tf.float32, name='network_state_0'), TensorSpec(shape=(16,), dtype=tf.float32, name='network_state_1')])

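import tensorflow as tf
from tf_agents.agents.dqn import dqn_agent
from tf_agents.environments import suite_gym, tf_py_environment
from tf_agents.networks import q_rnn_network
from tf_agents.utils import common

# Setup not shown in the question; values assumed from the DQN tutorial.
learning_rate = 1e-3
train_env = tf_py_environment.TFPyEnvironment(suite_gym.load('CartPole-v0'))

# Q-network with a single LSTM layer of 16 units.
q_net = q_rnn_network.QRnnNetwork(
    train_env.observation_spec(),
    train_env.action_spec(),
    lstm_size=(16,),
)
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
train_step_counter = tf.Variable(0)

agent = dqn_agent.DqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=train_step_counter)
agent.initialize()

collect_policy = agent.collect_policy
example_environment = tf_py_environment.TFPyEnvironment(
    suite_gym.load('CartPole-v0'))
time_step = example_environment.reset()

# No policy_state is passed here; this call raises the TypeError below.
collect_policy.action(time_step)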

I get this error:

TypeError: policy_state and policy_state_spec structures do not match:
  ()
vs.
  ListWrapper([., .])
Question from: https://stackoverflow.com/questions/65949867/issue-implementing-q-rnn-in-tf-agents
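The structure mismatch in the error can be illustrated without TensorFlow. The sketch below is a hypothetical, dependency-free model: `Spec`, `Tensor`, `same_structure` and `zero_state` are made-up stand-ins, not tf-agents APIs. It assumes the two entries in the spec are the LSTM's two state tensors (presumably h and c), and that calling `action(time_step)` without a state passes the default empty tuple `()`, which has zero elements while the spec has two.

```python
from dataclasses import dataclass
from typing import Any, Tuple


@dataclass(frozen=True)
class Spec:
    """Simplified stand-in for tf.TensorSpec: just a shape and a name."""
    shape: Tuple[int, ...]
    name: str


@dataclass(frozen=True)
class Tensor:
    """Simplified stand-in for a zero-filled tensor: only its shape is kept."""
    shape: Tuple[int, ...]


def same_structure(a: Any, b: Any) -> bool:
    """Hypothetical, simplified analogue of tf.nest.assert_same_structure.

    Returns True when a and b nest lists/tuples the same way. Unlike tf.nest,
    which also checks sequence types by default, it only compares nesting.
    """
    a_is_seq = isinstance(a, (list, tuple))
    b_is_seq = isinstance(b, (list, tuple))
    if a_is_seq != b_is_seq:
        return False
    if not a_is_seq:
        return True  # two leaves always match
    return len(a) == len(b) and all(same_structure(x, y) for x, y in zip(a, b))


def zero_state(spec: Any, batch_size: int) -> Any:
    """Hypothetical analogue of a policy's initial state: zeros shaped like spec."""
    if isinstance(spec, (list, tuple)):
        return type(spec)(zero_state(s, batch_size) for s in spec)
    # Each leaf gets a leading batch dimension, e.g. (16,) -> (batch_size, 16).
    return Tensor((batch_size,) + spec.shape)


# Mirrors the spec printed in the question (QRnnNetwork with lstm_size=(16,)).
policy_state_spec = [
    Spec((16,), "network_state_0"),
    Spec((16,), "network_state_1"),
]

# action() was called without a policy_state, so it received the default ().
print(same_structure((), policy_state_spec))  # False: mirrors the TypeError

state = zero_state(policy_state_spec, batch_size=1)
print(same_structure(state, policy_state_spec))  # True
print([t.shape for t in state])  # [(1, 16), (1, 16)]
```

In tf-agents itself, the analogous step is presumably to build the state with `collect_policy.get_initial_state(batch_size=...)` and pass it as `collect_policy.action(time_step, policy_state)`; treat this as an untested suggestion.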
