I stumbled across this question while dealing with a similar issue. I came up with a solution based on a Python generator, together with the TF dataset construction method from_generator. Because we use a generator, the HDF5 file should be opened for reading only once and kept open as long as there are entries to read. So it will not be opened, read, and then closed for every single call to get the next data element.
Generator definition
To allow the user to pass in the HDF5 filename as an argument, I wrote a class with a __call__ method, since from_generator specifies that the generator has to be callable. This is the generator:
import h5py
import tensorflow as tf

class generator:
    def __init__(self, file):
        self.file = file

    def __call__(self):
        # The file is opened once and stays open while the loop yields images.
        with h5py.File(self.file, 'r') as hf:
            for im in hf["train_img"]:
                yield im
By using a generator, the code should pick up from where it left off at each call, instead of running everything from the beginning again; in this case, it resumes on the next iteration of the inner for loop. So the file is not opened again for reading but kept open as long as there is data to yield. For more on generators, see this excellent Q&A.
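As a toy illustration of that resumption behavior (independent of HDF5 or TF):

def count_up_to(n):
    for i in range(n):
        yield i  # execution pauses here and resumes on the next call

g = count_up_to(3)
print(next(g))  # 0 - runs until the first yield
print(next(g))  # 1 - resumes inside the loop, nothing restarts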
Of course, you will have to replace everything inside the with block to match how your dataset is constructed and which outputs you want to obtain.
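For instance, if you also stored labels in your HDF5 file, the generator could yield (image, label) pairs. A sketch, where the "train_labels" dataset name is a hypothetical example:

class generator:
    def __init__(self, file):
        self.file = file

    def __call__(self):
        with h5py.File(self.file, 'r') as hf:
            # "train_labels" is an assumed dataset name; adjust to your layout.
            for im, label in zip(hf["train_img"], hf["train_labels"]):
                yield im, label

In that case, from_generator also needs a tuple of output types, e.g. (tf.uint8, tf.int64), and a matching tuple of shapes.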
Usage example
ds = tf.data.Dataset.from_generator(
    generator(hdf5_path),
    tf.uint8,
    tf.TensorShape([427, 561, 3]))

value = ds.make_one_shot_iterator().get_next()

# Example of how to read elements
with tf.Session() as sess:
    while True:
        try:
            data = sess.run(value)
            print(data.shape)
        except tf.errors.OutOfRangeError:
            print('done.')
            break
Again, in my case I had stored uint8 images of height 427, width 561, and 3 color channels in my dataset, so you will need to modify these in the above call to match your use case.
Handling multiple files
I have a proposed solution for handling multiple HDF5 files. The basic idea is to construct a Dataset from the filenames as usual, and then use the interleave method to process many input files concurrently, for example getting samples from each of them to form a batch. The idea is as follows:
ds = tf.data.Dataset.from_tensor_slices(filenames)
# You might want to shuffle() the filenames here depending on the application
# Note: inside the lambda below, filename is a tf.string tensor rather than a
# Python string; newer TF versions let you forward it to the generator through
# from_generator's args parameter instead.
ds = ds.interleave(lambda filename: tf.data.Dataset.from_generator(
        generator(filename),
        tf.uint8,
        tf.TensorShape([427, 561, 3])),
    cycle_length, block_length)
What this does is open cycle_length files concurrently, and produce block_length items from each before moving on to the next file; see the interleave documentation for details. You can set the values here to match what is appropriate for your application: e.g., do you need to process one file at a time or several concurrently, do you only want a single sample at a time from each file, and so on.
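For instance, with illustrative values, the following reads from four files at once, takes one sample from each in turn, and then forms mixed batches of 32:

ds = tf.data.Dataset.from_tensor_slices(filenames)
ds = ds.interleave(lambda filename: tf.data.Dataset.from_generator(
        generator(filename),
        tf.uint8,
        tf.TensorShape([427, 561, 3])),
    cycle_length=4,   # four files open at a time
    block_length=1)   # one sample from each before cycling to the next
ds = ds.batch(32)     # each batch mixes samples from the open files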
Edit: for a parallel version, take a look at tf.contrib.data.parallel_interleave!
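A sketch of what that might look like, applied with Dataset.apply (parameter values are illustrative):

ds = ds.apply(tf.contrib.data.parallel_interleave(
    lambda filename: tf.data.Dataset.from_generator(
        generator(filename),
        tf.uint8,
        tf.TensorShape([427, 561, 3])),
    cycle_length=4,
    sloppy=True))  # sloppy=True trades deterministic order for throughput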
Possible caveats
Be aware of the peculiarities of using from_generator if you decide to go with this solution. For TensorFlow 1.6.0, the documentation of from_generator mentions these two notes.
It may be challenging to apply this across different environments or with distributed training:
NOTE: The current implementation of Dataset.from_generator() uses
tf.py_func and inherits the same constraints. In particular, it
requires the Dataset- and Iterator-related operations to be placed on
a device in the same process as the Python program that called
Dataset.from_generator(). The body of generator will not be serialized
in a GraphDef, and you should not use this method if you need to
serialize your model and restore it in a different environment.
Be careful if the generator depends on external state:
NOTE: If generator depends on mutable global variables or other
external state, be aware that the runtime may invoke generator
multiple times (in order to support repeating the Dataset) and at any
time between the call to Dataset.from_generator() and the production
of the first element from the generator. Mutating global variables or
external state can cause undefined behavior, and we recommend that you
explicitly cache any external state in generator before calling
Dataset.from_generator().
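As a sketch of that last recommendation, any external state the generator needs could be copied into the object when it is constructed (the augment_params argument here is hypothetical):

class generator:
    def __init__(self, file, augment_params):
        self.file = file
        # Cache a private copy so later mutations of the caller's dict
        # cannot affect re-invocations of the generator.
        self.augment_params = dict(augment_params)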