You are approaching the problem from the wrong direction.
First, download the data using `tfds.load`, `cifar10` for example (for simplicity we will use the default `TRAIN` and `TEST` splits):
```python
import tensorflow as tf
import tensorflow_datasets as tfds

dataloader = tfds.load("cifar10", as_supervised=True)
train, test = dataloader["train"], dataloader["test"]
```
(you can use custom `tfds.Split` objects to create validation datasets and others, see the documentation)
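For example, a minimal sketch using the string slicing split API (the 90/10 proportions here are just an arbitrary choice):

```python
import tensorflow_datasets as tfds

# Carve a validation split out of TRAIN; passing a list of splits
# returns a list of datasets in the same order
train, valid = tfds.load(
    "cifar10",
    as_supervised=True,
    split=["train[:90%]", "train[90%:]"],
)
```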
`train` and `test` are `tf.data.Dataset` objects, so you can use `map`, `apply`, `batch` and similar functions on each of them.
Below is an example where I will (using `tf.image` mostly):

- convert each image to `tf.float32` in the `0-1` range (don't cast and divide by `255` manually like the snippet in the official docs; `tf.image.convert_image_dtype` ensures a correct image format regardless of the input dtype)
- `cache()` the results, as those can be re-used after each `repeat`
- randomly flip each image left to right
- randomly change the contrast of each image
- shuffle the data and batch it
- IMPORTANT: repeat all the steps when the dataset is exhausted. This means that after one epoch all of the above transformations are applied again (except for the ones which were cached).
Here is the code doing the above (you can change the `lambda`s to functors or functions):
```python
train = train.map(
    # Deterministic conversion, safe to cache
    lambda image, label: (tf.image.convert_image_dtype(image, tf.float32), label)
).cache().map(
    # Random augmentations go after cache() so they are redrawn every epoch
    lambda image, label: (tf.image.random_flip_left_right(image), label)
).map(
    lambda image, label: (tf.image.random_contrast(image, lower=0.0, upper=1.0), label)
).shuffle(
    100
).batch(
    64
).repeat()
```
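The `test` pipeline would follow the same pattern, minus the random augmentations, `shuffle` and `repeat`; a minimal sketch (the batch size of `64` simply mirrors the one above):

```python
test = test.map(
    # Same deterministic conversion as for train, but no augmentation
    lambda image, label: (tf.image.convert_image_dtype(image, tf.float32), label)
).batch(64)
```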
Such a `tf.data.Dataset` can be passed directly to Keras's `fit`, `evaluate` and `predict` methods.
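For instance, a minimal sketch (the model below is just a placeholder, any `tf.keras.Model` accepting `32x32x3` images works); note that `repeat()` makes the dataset infinite, so Keras needs `steps_per_epoch` to know where one epoch ends:

```python
# Placeholder model, only here to show how the dataset is consumed
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(32, 32, 3)),
    tf.keras.layers.Dense(10),
])
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

# CIFAR10 has 50000 training images, batches of 64 -> ~781 steps per epoch
model.fit(train, epochs=5, steps_per_epoch=50000 // 64)
model.evaluate(test)
```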
Verifying it actually works like that
I see you are highly suspicious of my explanation, so let's go through an example:
1. Get a small subset of data
Here is one way to take a single element; admittedly it's unreadable and unintuitive, but you should be fine with it if you do anything with `Tensorflow`:
```python
# Horrible API is horrible
element = tfds.load(
    "cifar10",
    as_supervised=True,
    # Take the first one percent of TEST and take 1 element from it
    split=tfds.Split.TEST.subsplit(tfds.percent[:1]),
).take(1)
```
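(If you are on a newer `tensorflow_datasets` version, the `subsplit` API is gone and the slicing is spelled as a string instead; a sketch of the equivalent:)

```python
element = tfds.load(
    "cifar10",
    as_supervised=True,
    split="test[:1%]",
).take(1)
```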
2. Repeat the data and check whether it is the same:
Using `Tensorflow 2.0` one can actually do it without stupid workarounds (almost):
```python
element = element.repeat(2)
# You can iterate through tf.data.Dataset now, finally...
images = [image[0] for image in element]
print(f"Are the same: {tf.reduce_all(tf.equal(images[0], images[1]))}")
```
And it unsurprisingly returns:

```
Are the same: True
```
3. Check whether data differs after each repeat with random augmentation

The snippet below `repeat`s a single element 5 times and checks which instances are equal and which are different.
```python
element = (
    tfds.load(
        "cifar10",
        as_supervised=True,
        # Take the first one percent of TEST and take 1 element from it
        split=tfds.Split.TEST.subsplit(tfds.percent[:1]),
    )
    .take(1)
    # The random flip is re-drawn on every repetition
    .map(lambda image, label: (tf.image.random_flip_left_right(image), label))
    .repeat(5)
)
```
```python
images = [image[0] for image in element]
for i in range(len(images)):
    for j in range(i, len(images)):
        print(
            f"{i} same as {j}: {tf.reduce_all(tf.equal(images[i], images[j]))}"
        )
```
Output (in my case; each run will be different, since the flip is applied with a 50% chance, so any two repetitions agree roughly half the time):
```
0 same as 0: True
0 same as 1: False
0 same as 2: True
0 same as 3: False
0 same as 4: False
1 same as 1: True
1 same as 2: False
1 same as 3: True
1 same as 4: True
2 same as 2: True
2 same as 3: False
2 same as 4: False
3 same as 3: True
3 same as 4: True
4 same as 4: True
```
You could convert each of those images to `numpy` as well and see the images for yourself using `skimage.io.imshow`, `matplotlib.pyplot.imshow` or other alternatives.
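For example, a quick sketch with `matplotlib` (re-using `element` from step 3; `.numpy()` works here because of TF 2.0 eager execution):

```python
import matplotlib.pyplot as plt

# Drop the labels and convert eager tensors to numpy arrays
images = [image.numpy() for image, label in element]

fig, axes = plt.subplots(1, len(images))
for ax, image in zip(axes, images):
    ax.imshow(image)  # uint8 images in 0-255, imshow handles those directly
    ax.axis("off")
plt.show()
```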
Another example of visualization of real-time data augmentation

This answer provides a more comprehensive and readable view on data augmentation using `Tensorboard` and `MNIST`; you might want to check that one out (yeah, shameless plug, but useful I guess).