Similar to this question, I was running an asynchronous reinforcement learning algorithm and needed to run model prediction in multiple threads to collect training data more quickly. My code is based on DDPG-keras on GitHub, whose neural network is built on top of Keras & TensorFlow. Pieces of my code are shown below:
Asynchronous Thread creation and join:
for roundNo in xrange(self.param['max_round']):
    AgentPool = [AgentThread(self.getEnv(), self.actor, self.critic, eps,
                             self.param['n_step'], self.param['gamma'])]
    for agent in AgentPool:
        agent.start()
    for agent in AgentPool:
        agent.join()
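(For reference, the workaround from the EDIT at the bottom would slot in right before this loop. A minimal sketch; self.graph is a new attribute of my own that each AgentThread would receive, not something in the code above:)

import tensorflow as tf

# Build the Keras predict functions once, in the main thread, before any
# worker thread calls predict() (see the EDIT below).
self.actor.model._make_predict_function()
self.critic.model._make_predict_function()

# Remember the graph the models were built in so worker threads can re-enter it.
self.graph = tf.get_default_graph()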
Agent thread code:
"""Agent Thread for collecting data"""
def __init__(self, env_, actor_, critic_, eps_, n_step_, gamma_):
super(AgentThread, self).__init__()
self.env = env_ # type: Environment
self.actor = actor_ # type: ActorNetwork
# TODO: use Q(s,a)
self.critic = critic_ # type: CriticNetwork
self.eps = eps_ # type: float
self.n_step = n_step_ # type: int
self.gamma = gamma_
self.data = {}
def run(self):
"""run behavior policy self.actor to collect experience data in self.data"""
state = self.env.get_state()
action = self.actor.model.predict(state[np.newaxis, :])[0]
action = np.maximum(np.random.normal(action, self.eps, action.shape), np.ones_like(action) * 1e-3)
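(With a graph handle passed into each thread, the hypothetical self.graph from the sketch above, the predict call in run() can be wrapped so it executes against the graph the model was actually built in:)

# inside AgentThread.run(); self.graph is the graph captured in the main thread
with self.graph.as_default():
    action = self.actor.model.predict(state[np.newaxis, :])[0]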
While running this code, I encountered a TensorFlow exception:
Using TensorFlow backend.
create_actor_network
Exception in thread Thread-1:
Traceback (most recent call last):
File "/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/threading.py", line 801, in __bootstrap_inner
self.run()
File "/Users/niyan/code/routerRL/A3C.py", line 26, in run
action = self.actor.model.predict(state[np.newaxis, :])[0]
File "/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/keras/engine/training.py", line 1269, in predict
self._make_predict_function()
File "/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/keras/engine/training.py", line 798, in _make_predict_function
**kwargs)
File "/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/keras/backend/tensorflow_backend.py", line 1961, in function
return Function(inputs, outputs, updates=updates)
File "/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/keras/backend/tensorflow_backend.py", line 1919, in __init__
with tf.control_dependencies(self.outputs):
File "/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 3583, in control_dependencies
return get_default_graph().control_dependencies(control_inputs)
File "/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 3314, in control_dependencies
c = self.as_graph_element(c)
File "/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2405, in as_graph_element
return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
File "/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2484, in _as_graph_element_locked
raise ValueError("Tensor %s is not an element of this graph." % obj)
ValueError: Tensor Tensor("concat:0", shape=(?, 4), dtype=float32) is not an element of this graph.
So how can I use a trained Keras model (with the TensorFlow backend) to predict concurrently in multiple threads?
Update on April 2nd:
I tried copying the model over by weights, but it didn't work:
for roundNo in xrange(self.param['max_round']):
    for agent in self.AgentPool:
        agent.syncModel(self.getEnv(), self.actor, self.critic, eps)
        agent.start()
    for agent in self.AgentPool:
        agent.join()
def syncModel(self, env_, actor_, critic_, eps_):
    """Synchronize A-C models before collecting data."""
    # TODO: copy env, actor, critic
    self.env = env_      # shallow copy
    self.actor.model.set_weights(actor_.model.get_weights())      # deep copy, by weights
    self.critic.model.set_weights(critic_.model.get_weights())    # deep copy, by weights
    self.eps = eps_      # shallow copy
    self.data = {}
EDIT:
See jaara/AI-blog on GitHub; it seems that

model._make_predict_function()    # have to initialize before threading

works. The author explains this a little. For further discussion, see this issue on Keras.
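Putting both pieces together, this is the minimal pattern that seems to avoid the exception with the TensorFlow backend (a self-contained sketch; the toy model, worker function, and variable names are mine, not from the DDPG code above):

import threading
import numpy as np
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense

# toy model standing in for the actor network
model = Sequential([Dense(4, input_dim=4)])
model.compile(optimizer='sgd', loss='mse')

model._make_predict_function()    # build the predict function before threading
graph = tf.get_default_graph()    # remember the graph the model lives in

def worker():
    state = np.random.rand(1, 4)
    with graph.as_default():      # re-enter that graph from the worker thread
        print(model.predict(state))

threads = [threading.Thread(target=worker) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()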