The functional API gives you all the flexibility you need for this.
When using the functional API, you need to keep track of inputs and outputs, instead of just defining layers.
You define a layer, then you call the layer with an input tensor to get the output tensor. Models and layers can be called exactly the same way.
For the merge layer, I prefer using other merge layers that are more intuitive, such as Add(), Multiply() and Concatenate(), for instance.
from keras.layers import Input, Add, Flatten, Dense, Dropout
mergedOut = Add()([model1.output, model2.output])
# Add() -> creates a merge layer that sums the inputs
# The second pair of parentheses "calls" the layer with the output tensors of the two models
# This requires that model1 and model2 have the same output shape
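To make the shape requirement concrete, here is a minimal sketch. The two tiny branches and their sizes are made up for illustration, they are not the models from the question. Add() sums tensors of matching shape, while Concatenate() only needs the other axes to match and stacks along the last one:

```python
from keras.layers import Input, Dense, Add, Concatenate

# Hypothetical inputs and branch outputs, sizes chosen only for illustration
in_a = Input(shape=(10,))
in_b = Input(shape=(10,))
out_a = Dense(8)(in_a)   # (None, 8)
out_b = Dense(8)(in_b)   # (None, 8)
out_c = Dense(4)(in_b)   # (None, 4)

# Add needs matching shapes, the result keeps that shape
summed = Add()([out_a, out_b])

# Concatenate joins along the last axis, so 8 + 4 features come out
joined = Concatenate()([out_a, out_c])

print(summed.shape, joined.shape)
```

If your two models end with different output sizes, Concatenate() is usually the easier choice.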
The same idea applies to all the following layers. We keep updating the output tensor, passing it to each layer and getting a new output (if we wanted to create branches, we would use a different variable for each output of interest to keep track of them):
mergedOut = Flatten()(mergedOut)
mergedOut = Dense(256, activation='relu')(mergedOut)
mergedOut = Dropout(.5)(mergedOut)
mergedOut = Dense(128, activation='relu')(mergedOut)
mergedOut = Dropout(.35)(mergedOut)
# output layer
mergedOut = Dense(5, activation='softmax')(mergedOut)
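As for the branching case mentioned above, a rough sketch could look like the following. The layer sizes and variable names (shared, branchA, branchB) are my own assumptions, only there to show the bookkeeping:

```python
from keras.layers import Input, Dense, Concatenate
from keras.models import Model

inp = Input(shape=(16,))
shared = Dense(32, activation='relu')(inp)

# Each branch gets its own variable so both tensors stay reachable later
branchA = Dense(8, activation='relu')(shared)
branchB = Dense(4, activation='relu')(shared)

# Bring the branches back together and finish with the output layer
joined = Concatenate()([branchA, branchB])
out = Dense(5, activation='softmax')(joined)

model = Model(inp, out)
print(model.output_shape)
```

Had we reused a single variable for both branches, the first branch's tensor would have been overwritten and lost.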
Now that we have created the "path", it's time to create the Model. Creating the model is just a matter of telling it at which input tensors it starts and where it ends:
from keras.models import Model
newModel = Model([model1.input,model2.input], mergedOut)
#use lists if you want more than one input or output
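Putting the pieces together, here is a small self-contained sketch. Since the original model1 and model2 are not shown, make_branch is a hypothetical stand-in that builds two tiny conv models with identical output shapes, and the dense stack is shortened:

```python
from keras.layers import Input, Conv2D, Add, Flatten, Dense
from keras.models import Model

def make_branch(name):
    # Hypothetical stand-in for model1 / model2: one small conv layer
    inp = Input(shape=(8, 8, 1))
    out = Conv2D(4, 3, activation='relu')(inp)
    return Model(inp, out, name=name)

model1 = make_branch('branch_1')
model2 = make_branch('branch_2')

# Both outputs have the same shape, so they can be summed
mergedOut = Add()([model1.output, model2.output])
mergedOut = Flatten()(mergedOut)
mergedOut = Dense(5, activation='softmax')(mergedOut)

# The model starts at the two original inputs and ends at the merged output
newModel = Model([model1.input, model2.input], mergedOut)
print(len(newModel.inputs), newModel.output_shape)
```

Note that newModel shares its layers (and weights) with model1 and model2, it does not copy them.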
Notice that since this model has two inputs, you have to train it with two different X_training vars in a list:
newModel.fit([X_train_1, X_train_2], Y_train, ....)
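For a quick sanity check of the two-input training call, something like this should do. The data is random, and the model, loss and batch settings are assumptions of mine rather than anything from the question:

```python
import numpy as np
from keras.layers import Input, Dense, Add
from keras.models import Model

# Small hypothetical two-input model
in_1 = Input(shape=(4,))
in_2 = Input(shape=(4,))
h_1 = Dense(8, activation='relu')(in_1)
h_2 = Dense(8, activation='relu')(in_2)
out = Dense(5, activation='softmax')(Add()([h_1, h_2]))
model = Model([in_1, in_2], out)

# Integer class labels, hence the sparse loss
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

# Random placeholder data: one array per input, same number of samples
X_train_1 = np.random.rand(32, 4).astype('float32')
X_train_2 = np.random.rand(32, 4).astype('float32')
Y_train = np.random.randint(0, 5, size=(32,))

# The list order must match the order of inputs given to Model(...)
model.fit([X_train_1, X_train_2], Y_train, epochs=1, batch_size=8, verbose=0)
preds = model.predict([X_train_1, X_train_2], verbose=0)
print(preds.shape)
```

The same list format applies to predict() and evaluate().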
Now, suppose you wanted only one input, and both model1 and model2 would take the same input.
The functional API allows that quite easily by creating an input tensor and feeding it to the models (we call the models as if they were layers):
commonInput = Input(input_shape)
out1 = model1(commonInput)
out2 = model2(commonInput)
mergedOut = Add()([out1,out2])
In this case, the Model takes only this single input:
oneInputModel = Model(commonInput,mergedOut)
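Here is the shared-input variant as a runnable sketch. The input_shape value and the make_model helper are hypothetical placeholders for your real models:

```python
from keras.layers import Input, Dense, Add
from keras.models import Model

input_shape = (10,)  # assumed shape, replace with your own

def make_model(name):
    # Hypothetical stand-in for model1 / model2
    inp = Input(shape=input_shape)
    return Model(inp, Dense(3)(inp), name=name)

model1 = make_model('model_1')
model2 = make_model('model_2')

# One input tensor, fed to both models as if they were layers
commonInput = Input(shape=input_shape)
out1 = model1(commonInput)
out2 = model2(commonInput)
mergedOut = Add()([out1, out2])

oneInputModel = Model(commonInput, mergedOut)
print(len(oneInputModel.inputs), oneInputModel.output_shape)
```

Training then takes a single array again, e.g. oneInputModel.fit(X_train, Y_train, ...).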