All right, the solution I found is the following:
1) In your Iterator, create a method that returns the weight matrix (with the same shape as the mask). The iterator's output must contain [image, mask, weights]
2) Create a Lambda layer containing the loss function
3) Create an Identity loss function
def weighted_binary_loss(X):
    """Compute a per-element weighted binary cross-entropy.

    Intended to be wrapped in a ``Lambda`` layer, so it operates on backend
    tensors rather than on Keras layers.

    Args:
        X: Sequence of three tensors, in the order
            ``[y_pred, weights, y_true]``.

    Returns:
        The element-wise binary cross-entropy between ``y_true`` and
        ``y_pred``, multiplied element-wise by ``weights``.
    """
    import keras.backend as K

    y_pred, weights, y_true = X
    # Pass by keyword: Keras 2 expects (target, output) positionally, so the
    # original positional call (y_pred, y_true) swapped the two tensors.
    loss = K.binary_crossentropy(target=y_true, output=y_pred)
    # Inside a Lambda the operands are plain backend tensors, so ordinary
    # multiplication is the element-wise product; the legacy `merge` name
    # imported here was a module and could not be called.
    return loss * weights
def identity_loss(y_true, y_pred):
    """Return ``y_pred`` unchanged, ignoring ``y_true``.

    Used as the compiled loss when the model's output already *is* the loss
    tensor (here, the output of the ``Lambda`` loss layer).

    Args:
        y_true: Unused; present only to satisfy the ``(y_true, y_pred)``
            loss-function signature.
        y_pred: The model output, passed through as the loss.

    Returns:
        ``y_pred`` itself.
    """
    return y_pred
def get_unet_w_lambda_loss(input_shape=(1024, 1024, 3), mask_shape=(1024, 1024, 1)):
    """Build a U-Net whose single output is the weighted loss tensor.

    The model takes three inputs, in this order: the images, the per-element
    mask weights, and the ground-truth masks. Its output is produced by a
    ``Lambda`` layer wrapping ``weighted_binary_loss``.

    Args:
        input_shape: Shape of one input image, without the batch dimension.
        mask_shape: Shape of one mask (and of one weight map), without the
            batch dimension. Also used as the Lambda layer's output shape.

    Returns:
        The constructed ``Model``.
    """
    images = Input(input_shape)
    mask_weights = Input(mask_shape)
    true_masks = Input(mask_shape)

    # NOTE: the U-Net layers that compute `up1` from `images` are omitted
    # from this snippet and must be inserted here.
    y_pred = Conv2D(1, (1, 1), activation='sigmoid')(up1)  # output of original unet

    # Use mask_shape instead of a hard-coded (1024, 1024, 1) so that
    # non-default mask sizes are declared correctly.
    loss = Lambda(weighted_binary_loss, output_shape=mask_shape)(
        [y_pred, mask_weights, true_masks]
    )
    model = Model(inputs=[images, mask_weights, true_masks], outputs=loss)
    return model