The key observation here is that TensorFlow's sampled softmax function returns actual losses. It does not return a set of predictions over the possible labels that you would then compare with the ground truth data to compute losses as a separate step. This makes the model setup a little bit weird.
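To make "returns losses, not predictions" concrete, here is a rough NumPy sketch of the idea behind a sampled softmax loss. It is a simplification written for illustration, not TensorFlow's implementation: it assumes uniform negative sampling and skips the sampling-probability correction that the real function applies. The function name and the toy sizes are made up for this example.

```python
import numpy as np

def sampled_softmax_loss_sketch(weights, biases, hidden, labels, num_sampled, rng):
    """Simplified sampled softmax (illustrative only).

    weights: [num_classes, dim], biases: [num_classes]
    hidden:  [batch, dim], labels: [batch] integer class indices
    Returns one loss value per example, shape [batch].
    """
    num_classes = weights.shape[0]
    # One shared set of negative classes for the whole batch (uniform sampling is an assumption).
    sampled = rng.choice(num_classes, size=num_sampled, replace=False)

    # Logit of the true class for each example: [batch]
    true_logits = np.sum(hidden * weights[labels], axis=1) + biases[labels]
    # Logits of the sampled classes: [batch, num_sampled]
    sampled_logits = hidden @ weights[sampled].T + biases[sampled]
    # Mask sampled classes that happen to equal the true label.
    hits = sampled[None, :] == labels[:, None]
    sampled_logits = np.where(hits, -np.inf, sampled_logits)

    # Cross-entropy over the small candidate set, with the true class in column 0.
    logits = np.concatenate([true_logits[:, None], sampled_logits], axis=1)
    m = logits.max(axis=1, keepdims=True)
    log_z = m[:, 0] + np.log(np.exp(logits - m).sum(axis=1))
    return log_z - true_logits

rng = np.random.default_rng(0)
weights = rng.normal(scale=0.1, size=(50, 16))   # toy sizes, not the 200000 x 300 of the answer
biases = np.zeros(50)
hidden = rng.normal(size=(4, 16))
labels = np.array([3, 7, 7, 42])
losses = sampled_softmax_loss_sketch(weights, biases, hidden, labels, num_sampled=8, rng=rng)
print(losses.shape)  # one loss per example, no per-class probabilities
```

The result has one entry per example and nothing per class, which is why it cannot be fed to a standard Keras loss that expects predictions.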
First, we add a second input layer to the model that encodes the target (training) data a second time as an input, in addition to being the target output. This is used for the labels argument of the sampled_softmax_loss function. It needs to be a Keras input, because it's treated as an input when we go to instantiate and set up the model.
Second, we construct a new custom Keras layer that calls the sampled_softmax_loss function with the outputs of two Keras layers as its inputs: the output of the dense layer that predicts our classes, and the second input that contains a copy of the training data. Note that we're doing some serious hackery accessing the _keras_history instance variable to fetch the weight and bias tensors from the output tensor of the original fully-connected layer.
Finally, we have to construct a new "dumb" loss function that ignores the training targets Keras passes in (y_true) and just uses the loss reported by the sampled_softmax_loss function.
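Two of the pieces described above are small enough to show in isolation. The NumPy sketch below mirrors what the argmax-and-reshape step does to the one-hot targets, and what the "dumb" loss does with the per-example losses. The function names are hypothetical and only stand in for the TensorFlow calls used in the actual layer.

```python
import numpy as np

def one_hot_to_label_column(one_hot):
    """Turn one-hot rows into a [batch, 1] column of class indices."""
    return np.argmax(one_hot, axis=1).reshape(-1, 1)

def passthrough_loss(y_true, y_pred):
    """Ignore y_true; y_pred already holds per-example losses, so just average them."""
    return np.mean(y_pred)

one_hot = np.array([[0, 0, 1, 0],
                    [0, 1, 0, 0]])
labels = one_hot_to_label_column(one_hot)
print(labels.tolist())  # [[2], [1]]

print(passthrough_loss(None, np.array([1.0, 3.0])))  # 2.0
```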
Note that because the sampled softmax function returns losses, not class predictions, you can't use this model specification for validation or inference. You'll need to re-use the trained layers from this "training version" in a new specification that applies a standard softmax function to the output of the original dense layer, which was created with the default (linear) activation.
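As a rough illustration of what that inference-time specification computes, here is a NumPy sketch that applies a standard softmax to the dense layer's logits. It assumes the trained kernel and bias have been pulled out of the layer (for example with the layer's get_weights() method); the random arrays below are stand-ins for trained values, and predict_proba is a name invented for this example.

```python
import numpy as np

def predict_proba(features, kernel, bias):
    """Full softmax over all classes: features [batch, dim], kernel [dim, num_classes]."""
    logits = features @ kernel + bias
    logits = logits - logits.max(axis=1, keepdims=True)  # for numerical stability
    exp = np.exp(logits)
    return exp / exp.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
kernel = rng.normal(size=(300, 10))   # stand-in for the trained kernel (10 classes to keep it small)
bias = np.zeros(10)                   # stand-in for the trained bias
features = rng.normal(size=(3, 300))

probs = predict_proba(features, kernel, bias)
predicted = probs.argmax(axis=1)      # most likely class per example
print(probs.shape, predicted.shape)
```

In Keras itself the equivalent would be a second model that calls the same trained dense layer on a fresh input and follows it with a softmax activation, so the weights are shared rather than copied.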
There is definitely a more elegant way to do this, but I believe this works, so I figured I'd post it here now as-is rather than wait until I have something that's a little bit neater. For example, you'd probably want to make the number of classes an argument of the SampledSoftmax layer, or better yet, condense this all into the loss function as in the original question and avoid passing in the training data twice.
from keras.models import Model
from keras.layers import Input, Dense, Layer
from keras import backend as K


class SampledSoftmax(Layer):
    def __init__(self, **kwargs):
        super(SampledSoftmax, self).__init__(**kwargs)

    def call(self, inputs):
        """
        The first input should be the output of the dense layer, and the
        second the target (i.e., a repeat of the training data), which is
        used to compute the labels argument
        """
        # the labels input to this function is batch size by 1, where the
        # value at position (i, 0) is the index that is true (not zero)
        # e.g., (0, 0, 1) => (2) or (0, 1, 0, 0) => (1)
        # NOTE: tf.nn.sampled_softmax_loss documents weights as
        # [num_classes, dim] and inputs as [batch_size, dim]; check that the
        # tensors passed below match those shapes in your setup.
        return K.tf.nn.sampled_softmax_loss(
            # hack: reach back to the Dense layer that produced inputs[0]
            weights=inputs[0]._keras_history[0].weights[0],
            biases=inputs[0]._keras_history[0].bias,
            inputs=inputs[0],
            labels=K.tf.reshape(K.tf.argmax(inputs[1], 1), [-1, 1]),
            num_sampled=1000,
            num_classes=200000)


def custom_loss(y_true, y_pred):
    # y_pred already holds the per-example losses; y_true is ignored
    return K.tf.reduce_mean(y_pred)


num_classes = 200000
input = Input(shape=(300,))
target_input = Input(shape=(num_classes,))

dense = Dense(num_classes)
outputs = dense(input)
outputs = SampledSoftmax()([outputs, target_input])

model = Model([input, target_input], outputs)
model.compile(optimizer='adam', loss=custom_loss)
# train as desired