While researching this I came across some information that answers my questions.
**Note:** As mentioned in the updated question, `fit_generator()` is deprecated in newer TensorFlow/Keras versions (TF > 2). Instead, it is recommended to use `fit()` with the generator. However, the answer below applies to `fit()` with a generator just as well.
**1. Does Keras emit this warning only because the generator does not inherit from `Sequence`, or does Keras also check whether a generator is thread-safe in general?**
Taken from Keras' GitHub repo (`training_generator.py`), I found the following in lines 46-52:
```python
use_sequence_api = is_sequence(generator)
if not use_sequence_api and use_multiprocessing and workers > 1:
    warnings.warn(
        UserWarning('Using a generator with `use_multiprocessing=True`'
                    ' and multiple workers may duplicate your data.'
                    ' Please consider using the `keras.utils.Sequence'
                    ' class.'))
```
The definition of `is_sequence()`, taken from `training_utils.py` lines 624-635, is:
```python
def is_sequence(seq):
    """Determine if an object follows the Sequence API.

    # Arguments
        seq: a possible Sequence object

    # Returns
        boolean, whether the object follows the Sequence API.
    """
    # TODO Dref360: Decide which pattern to follow. First needs a new TF Version.
    return (getattr(seq, 'use_sequence_api', False)
            or set(dir(Sequence())).issubset(set(dir(seq) + ['use_sequence_api'])))
```
Judging from this piece of code, Keras only checks whether the passed generator is a Keras Sequence (or rather uses Keras' Sequence API); it does not check whether a generator is thread-safe in general.
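To illustrate the difference, here is a small self-contained sketch (the data and the class `MySequence` are made up purely for this example). Keras' actual check is duck-typed via `dir()` as quoted above, but a plain `isinstance` test mirrors it in the common case:

```python
import numpy as np
from tensorflow.keras.utils import Sequence

def plain_generator(x, y, batch_size):
    # An ordinary Python generator: yields batches, but exposes no Sequence API
    for i in range(0, len(x), batch_size):
        yield x[i:i + batch_size], y[i:i + batch_size]

class MySequence(Sequence):
    # A minimal Sequence subclass: indexable and with a defined length
    def __init__(self, x, y, batch_size):
        self.x, self.y, self.batch_size = x, y, batch_size

    def __len__(self):
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        batch = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        return self.x[batch], self.y[batch]

x, y = np.zeros((10, 2)), np.zeros((10, 1))
print(isinstance(plain_generator(x, y, 4), Sequence))  # False -> warning path
print(isinstance(MySequence(x, y, 4), Sequence))       # True  -> no warning
```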
**2. Is using the approach I chose as thread-safe as using the `generatorClass(Sequence)` version from the Keras docs?**
As Omer Zohar has shown on GitHub, his decorator is thread-safe, and I don't see any reason why it shouldn't be just as thread-safe for Keras (even though Keras will still warn, as shown in 1.).
The implementation based on `threading.Lock()` can be considered thread-safe according to the docs:

> A factory function that returns a new primitive lock object. Once a thread has acquired it, subsequent attempts to acquire it block, until it is released; any thread may release it.
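For reference, a minimal sketch of such a lock-based decorator (condensed by me; the names `ThreadSafeIterator` and `threadsafe_generator` are illustrative, see Omer Zohar's GitHub post for the original code):

```python
import threading

class ThreadSafeIterator:
    """Wraps an iterator so that calls to next() are serialized by a lock."""

    def __init__(self, iterator):
        self.iterator = iterator
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        with self.lock:
            return next(self.iterator)

def threadsafe_generator(func):
    """Decorator turning a generator function into one that returns a thread-safe iterator."""
    def wrapper(*args, **kwargs):
        return ThreadSafeIterator(func(*args, **kwargs))
    return wrapper

@threadsafe_generator
def generator(data):
    for sample in data:
        yield sample
```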
The generator is also picklable, which can be tested as follows (see this SO Q&A for further information):
```python
import pickle

# Dump the yielded data in order to check whether it is picklable
with open("test.pickle", "wb") as outfile:
    for yielded_data in generator(data):
        pickle.dump(yielded_data, outfile, protocol=pickle.HIGHEST_PROTOCOL)
```
Summing up, I would even suggest adding a `threading.Lock()` when you extend Keras' `Sequence`, like so:
```python
import threading

import numpy as np
from tensorflow.keras.utils import Sequence

class generatorClass(Sequence):

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size
        self.lock = threading.Lock()  # Set self.lock

    def __len__(self):
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        with self.lock:  # Use self.lock
            batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
            batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]
            # e.g. return the batch as arrays (adapt to your data/preprocessing)
            return np.array(batch_x), np.array(batch_y)
```
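For completeness, a small usage sketch (random data and a toy model, purely for illustration) showing how such a `Sequence` is then passed directly to `fit()` in TF 2.x, which replaces the deprecated `fit_generator()`:

```python
import numpy as np
import tensorflow as tf

x_train = np.random.rand(128, 8).astype("float32")
y_train = np.random.rand(128, 1).astype("float32")

model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation="relu", input_shape=(8,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")

train_seq = generatorClass(x_train, y_train, batch_size=16)

# fit() accepts Sequence objects (and generators) directly
model.fit(train_seq, epochs=2, workers=2, use_multiprocessing=True)
```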
**Edit 24/04/2020:** By using `self.lock = threading.Lock()` you might run into the following error:

```
TypeError: can't pickle _thread.lock objects
```
In case this happens, try replacing `with self.lock:` inside `__getitem__` with `with threading.Lock():`, and comment out / delete the `self.lock = threading.Lock()` inside `__init__`. It seems there are some problems when storing the `lock` object inside a class (see for example this Q&A).
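A minimal sketch of that workaround (illustrative only; note that each `__getitem__` call now creates and acquires its own fresh lock, and nothing is stored on the instance that would break pickling):

```python
import threading

import numpy as np
from tensorflow.keras.utils import Sequence

class generatorClass(Sequence):

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size
        # No self.lock attribute, so the Sequence itself stays picklable

    def __len__(self):
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        with threading.Lock():  # a fresh lock is created and acquired per call
            batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
            batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]
            # e.g. return the batch as arrays (adapt to your data/preprocessing)
            return np.array(batch_x), np.array(batch_y)
```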
**3. Are there any other approaches, different from these two examples, that lead to a thread-safe generator Keras can deal with?**
During my research I did not encounter any other method.
Of course I cannot say this with 100% certainty.