To apply a different augmentation to each element of the input tensor (e.g. conditioned on `label_input`), you will need to:
- First, compute each possible augmentation for each element of the batch.
- Second, select the desired augmentations according to the label.
Simply indexing the list of augmentations with the labels is unfortunately impossible, because the input tensor and `label_input` are both batched (multi-dimensional): there is one label per element rather than a single index. If you were applying the same augmentation to every element of the batch, you could instead use a conditional TensorFlow op such as `tf.case`.
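For contrast, the batch-wide case can be sketched as follows. This assumes a single scalar label (`batch_label`, a hypothetical name) shared by the whole batch, and uses `tf.switch_case`, the integer-indexed counterpart of `tf.case`, since the branch is chosen by an integer:

```python
import tensorflow as tf

aug_models = [lambda x: x, lambda x: x * 2, lambda x: x * 3, lambda x: x * 4]
inputs = tf.ones((3, 1))  # Shape=(bs, 1)

# Hypothetical: one scalar label for the whole batch (not per element).
batch_label = tf.constant(2)

# tf.switch_case calls zero-argument branch functions, so each augmentation is
# wrapped in a closure; the default argument f=f binds the current function.
branch_fns = [lambda f=f: f(inputs) for f in aug_models]
batch_augmentation = tf.switch_case(batch_label, branch_fns)  # Shape=(bs, 1)
print(batch_augmentation)
```

Every element receives the same augmentation here (`x * 3` for label 2), which is exactly what a per-element label cannot express.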
Here is a minimal working example of the two-step approach described above:
```python
import tensorflow as tf

inputs = tf.ones((3, 1))  # Shape=(bs, 1)
label_input = tf.constant([3, 2, 1])  # Shape=(bs,)
aug_models = [lambda x: x, lambda x: x * 2, lambda x: x * 3, lambda x: x * 4]
nb_classes = len(aug_models)

# Step 1: apply every augmentation to the whole batch.
augmented_data = tf.stack([aug_model(inputs) for aug_model in aug_models])  # Shape=(nb_classes, bs, 1)
# Step 2: build a one-hot mask from the labels and keep only the selected augmentation.
selector = tf.transpose(tf.one_hot(label_input, depth=nb_classes))  # Shape=(nb_classes, bs)
augmentation = tf.reduce_sum(selector[..., None] * augmented_data, axis=0)  # Shape=(bs, 1)
print(augmentation)
# prints:
# tf.Tensor(
# [[4.]
#  [3.]
#  [2.]], shape=(3, 1), dtype=float32)
```
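As an alternative for the second step, the one-hot mask can be replaced by a batched `tf.gather` that picks, for each batch element `i`, the augmentation at index `label_input[i]`. This is a sketch of a different selection op, not part of the approach above; it reuses the same toy augmentations and should yield the same values as the one-hot version:

```python
import tensorflow as tf

inputs = tf.ones((3, 1))  # Shape=(bs, 1)
label_input = tf.constant([3, 2, 1])  # Shape=(bs,)
aug_models = [lambda x: x, lambda x: x * 2, lambda x: x * 3, lambda x: x * 4]

# Step 1 is unchanged: every augmentation is still computed for the whole batch.
augmented_data = tf.stack([aug_model(inputs) for aug_model in aug_models])  # Shape=(nb_classes, bs, 1)

# Step 2 via gather: move the batch axis first, then for each element i
# take per_example[i, label_input[i], :]. With batch_dims=1, axis defaults to 1.
per_example = tf.transpose(augmented_data, perm=[1, 0, 2])  # Shape=(bs, nb_classes, 1)
gathered = tf.gather(per_example, label_input, batch_dims=1)  # Shape=(bs, 1)
print(gathered)
```

Both variants still compute every augmentation for every element; they differ only in how the selected result is extracted.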
NOTE: You might need to wrap these operations in a Keras `Lambda` layer (e.g. when using them inside a functional Keras model).
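A minimal sketch of that wrapping, assuming the functional Keras API and reusing the toy augmentations from the example; `select_augmentation`, `x_in` and `label_in` are illustrative names. `output_shape` is passed explicitly because, depending on the Keras version, `Lambda` may otherwise try to infer the output shape by tracing the function:

```python
import tensorflow as tf

aug_models = [lambda x: x, lambda x: x * 2, lambda x: x * 3, lambda x: x * 4]


def select_augmentation(args):
    """Applies aug_models[label] to each batch element via one-hot selection."""
    x, labels = args
    nb_classes = len(aug_models)
    augmented_data = tf.stack([aug_model(x) for aug_model in aug_models])  # (nb_classes, bs, 1)
    selector = tf.transpose(tf.one_hot(labels, depth=nb_classes))  # (nb_classes, bs)
    return tf.reduce_sum(selector[..., None] * augmented_data, axis=0)  # (bs, 1)


x_in = tf.keras.Input(shape=(1,))
label_in = tf.keras.Input(shape=(), dtype="int32")  # one integer label per element
# output_shape excludes the batch axis.
augmented = tf.keras.layers.Lambda(select_augmentation, output_shape=(1,))([x_in, label_in])
model = tf.keras.Model(inputs=[x_in, label_in], outputs=augmented)

result = model([tf.ones((3, 1)), tf.constant([3, 2, 1])])
print(result)
```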