
python - Passing non-tensor parameters to a Keras model during training / using tensors for indexing

I'm trying to train a Keras model that incorporates data augmentation in the model itself. The inputs to the model are images of different classes, and the model is supposed to generate an augmentation model for each class, which is then used for the augmentation step. My code roughly looks like this:

from keras.models import Model
from keras.layers import Input
...further imports...

def get_main_model(input_shape, n_classes):
    encoder_model = get_encoder_model()
    input = Input(input_shape, name="input")
    label_input = Input((1,), name="label_input")
    aug_models = [get_augmentation_model() for i in range(n_classes)]
    augmentation = aug_models[label_input](input)  # <-- this indexing is the problem
    x = encoder_model(input)
    y = encoder_model(augmentation)
    model = Model(inputs=[input, label_input], outputs=[x, y])
    model.add_loss(custom_loss_function(x, y))
    return model 

I would then like to pass batches of data through the model, each consisting of an array of images (fed to input) and a corresponding array of labels (fed to label_input). However, this doesn't work, because whatever is fed into label_input is converted to a tensor by TensorFlow and can therefore not be used for indexing. Here is what I've tried:

  • augmentation = aug_models[int(label_input)](input) --> doesn't work because label_input is a tensor
  • augmentation = aug_models[tf.make_ndarray(label_input)](input) --> the cast doesn't work (presumably because label_input is a symbolic tensor)
  • tf.gather(aug_models, label_input) --> doesn't work because the result of the operation is a Keras model instance, which TensorFlow then tries to cast into a tensor (and obviously fails)

Is there any trick in TensorFlow that would let me pass a parameter to the model during training without it being converted to a tensor, or some other way to tell the model which augmentation model to select? Thanks in advance!

question from: https://stackoverflow.com/questions/65829671/passing-non-tensor-parameters-to-a-keras-model-during-training-using-tensors-f


1 Reply


To apply a different augmentation to each element of the input tensor (e.g. conditioned on label_input), you will need to:

  1. Compute every possible augmentation for each element of the batch.
  2. Select the desired augmentation according to the label.

Plain indexing is unfortunately not possible here, because both the input and label_input tensors are batched, with one label per element. If you were instead applying the same augmentation to the whole batch, you could simply use a conditional TensorFlow statement such as tf.case; see the aside right below.
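
As an aside, here is a minimal sketch of that whole-batch case with tf.case (the scalar label and the toy augmentation functions below are made up for illustration only):

import tensorflow as tf

batch = tf.ones((3, 1))  # Shape=(bs, 1)
label = tf.constant(2)   # a single label shared by the entire batch

# One (predicate, branch) pair per class; the i=i default pins the class index in each lambda
branches = [(tf.equal(label, i), lambda i=i: batch * (i + 1)) for i in range(4)]
augmentation = tf.case(branches, exclusive=True)
print(augmentation)  # the i == 2 branch is taken: [[3.], [3.], [3.]]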


Here is a minimal working example showing how you can achieve this:

import tensorflow as tf

input = tf.ones((3, 1))  # Shape=(bs, 1)
label_input = tf.constant([3, 2, 1])  # Shape=(bs,)
aug_models = [lambda x: x, lambda x: x * 2, lambda x: x * 3, lambda x: x * 4]
nb_classes = len(aug_models)

augmented_data = tf.stack([aug_model(input) for aug_model in aug_models])  # Shape=(nb_classes, bs, 1)
selector = tf.transpose(tf.one_hot(label_input, depth=nb_classes))  # Shape=(nb_classes, bs)
augmentation = tf.reduce_sum(selector[..., None] * augmented_data, axis=0)  # Shape=(bs, 1) 
print(augmentation)

# prints:
# tf.Tensor(
# [[4.]
#  [3.]
#  [2.]], shape=(3, 1), dtype=float32)

NOTE: You might need to wrap these operations into a Keras Lambda layer.
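
For instance, here is a sketch of how the selection could be wrapped into a Lambda layer inside the original get_main_model (written with tf.keras; get_encoder_model, get_augmentation_model and custom_loss_function are the functions from the question and are assumed to be defined elsewhere; the input layer is renamed image_input to avoid shadowing the Python builtin):

import tensorflow as tf
from tensorflow.keras.layers import Input, Lambda
from tensorflow.keras.models import Model

def get_main_model(input_shape, n_classes):
    encoder_model = get_encoder_model()  # from the question, defined elsewhere
    aug_models = [get_augmentation_model() for _ in range(n_classes)]

    image_input = Input(input_shape, name="input")
    label_input = Input((1,), dtype="int32", name="label_input")

    # Apply every augmentation model to the whole batch via the functional API,
    # so their weights (if any) stay tracked by the outer model
    augmented = [aug_model(image_input) for aug_model in aug_models]

    def select_by_label(args):
        label, stacked = args[0], tf.stack(args[1:])  # Shapes=(bs, 1) and (n_classes, bs, *input_shape)
        selector = tf.transpose(tf.one_hot(tf.squeeze(label, axis=-1), depth=n_classes))  # Shape=(n_classes, bs)
        selector = tf.reshape(selector, (n_classes, -1) + (1,) * len(input_shape))  # broadcast over image dims
        return tf.reduce_sum(selector * stacked, axis=0)  # Shape=(bs, *input_shape)

    augmentation = Lambda(select_by_label)([label_input] + augmented)
    x = encoder_model(image_input)
    y = encoder_model(augmentation)
    model = Model(inputs=[image_input, label_input], outputs=[x, y])
    model.add_loss(custom_loss_function(x, y))  # custom loss from the question
    return model

The idea is exactly the same as in the toy example above: every augmentation is computed for every batch element, and the one-hot selector keeps only the one matching the label.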

