It's because of the batch normalization layers.
In the training phase, the batch is normalized w.r.t. its own mean and variance. In the testing phase, however, the batch is normalized w.r.t. the moving averages of the previously observed means and variances.
Now this is a problem when the number of observed batches is small (e.g., 5 in your example) because in the BatchNormalization layer, by default moving_mean is initialized to be 0 and moving_variance is initialized to be 1. Given also that the default momentum is 0.99, you'll need to update the moving averages quite a lot of times before they converge to the "real" mean and variance.
That's why the prediction is wrong in the early stage, but is correct after 1000 epochs.
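To make the slow convergence concrete, here is a minimal pure-Python sketch of the moving-average update (moving = momentum * moving + (1 - momentum) * batch_statistic). The "real" mean and variance (50 and 100) are made-up numbers chosen for illustration, and the batch statistic is assumed to be the same at every step, which is a simplification of real training.

```python
def update_moving_average(moving, batch_stat, momentum=0.99):
    """One exponential-moving-average step of a batch norm statistic."""
    return momentum * moving + (1.0 - momentum) * batch_stat


def simulate(batch_stat, initial, n_updates, momentum=0.99):
    """Apply n_updates steps, assuming the batch statistic never changes."""
    moving = initial
    for _ in range(n_updates):
        moving = update_moving_average(moving, batch_stat, momentum)
    return moving


# Hypothetical "real" statistics of the activations (illustrative values only).
true_mean, true_var = 50.0, 100.0

for n in (5, 100, 1000):
    mean_n = simulate(true_mean, initial=0.0, n_updates=n)
    var_n = simulate(true_var, initial=1.0, n_updates=n)
    print(f"{n:5d} updates: moving_mean={mean_n:.3f}, moving_variance={var_n:.3f}")
```

With momentum 0.99, five updates move the averages only about 4.9% of the way from their initial values (1 - 0.99^5), so the statistics used at test time are still far from the ones used during training.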
You can verify it by forcing the BatchNormalization layers to operate in "training mode".
During training, the accuracy is 1 and the loss is close to zero:
model.fit(imgs, y, epochs=5, shuffle=True)
Epoch 1/5
3/3 [==============================] - 19s 6s/step - loss: 1.4624 - acc: 0.3333
Epoch 2/5
3/3 [==============================] - 0s 63ms/step - loss: 0.6051 - acc: 0.6667
Epoch 3/5
3/3 [==============================] - 0s 57ms/step - loss: 0.2168 - acc: 1.0000
Epoch 4/5
3/3 [==============================] - 0s 56ms/step - loss: 1.1921e-07 - acc: 1.0000
Epoch 5/5
3/3 [==============================] - 0s 53ms/step - loss: 1.1921e-07 - acc: 1.0000
Now if we evaluate the model, we'll observe high loss and low accuracy because after 5 updates, the moving averages are still pretty close to the initial values:
model.evaluate(imgs, y)
3/3 [==============================] - 3s 890ms/step
[10.745396614074707, 0.3333333432674408]
However, if we manually specify the "learning phase" variable and let the BatchNormalization layers use the "real" batch mean and variance, the result becomes the same as what's observed in fit().
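import numpy as np

# Older standalone Keras API: the learning phase is passed as the last input
sample_weights = np.ones(3)  # one weight per sample
learning_phase = 1  # 1 means "training", 0 means "testing"
ins = [imgs, y, sample_weights, learning_phase]
model.test_function(ins)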
[1.192093e-07, 1.0]
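The difference between the two modes can also be seen without Keras. The following is a simplified pure-Python sketch of what a batch norm layer computes, not the library's implementation: it works on a list of scalars, ignores the learned scale and offset, and uses made-up activations. The helper name batch_norm and its arguments are hypothetical.

```python
import math


def batch_norm(batch, moving_mean, moving_var, training, eps=1e-3):
    """Normalize a list of scalars (learned scale and offset are ignored)."""
    if training:
        # Training mode: use the statistics of the current batch.
        mean = sum(batch) / len(batch)
        var = sum((x - mean) ** 2 for x in batch) / len(batch)
    else:
        # Inference mode: use the stored moving averages.
        mean, var = moving_mean, moving_var
    return [(x - mean) / math.sqrt(var + eps) for x in batch]


# Made-up activations; moving statistics are still at their initial values.
batch = [40.0, 50.0, 60.0]
train_out = batch_norm(batch, moving_mean=0.0, moving_var=1.0, training=True)
test_out = batch_norm(batch, moving_mean=0.0, moving_var=1.0, training=False)

print("training mode: ", [round(v, 3) for v in train_out])
print("inference mode:", [round(v, 3) for v in test_out])
```

In training mode the outputs are centered around zero, while in inference mode they stay close to the raw activations. Layers after the batch norm were trained on the centered values, so the un-normalized inputs produce the high loss seen in evaluate().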
It's also possible to verify it by changing the momentum to a smaller value. For example, by adding momentum=0.01 to all the batch norm layers in ResNet50, the prediction after 20 epochs is:
model.predict(imgs)
array([[ 1.00000000e+00, 1.34882026e-08, 3.92139575e-22],
[ 0.00000000e+00, 1.00000000e+00, 0.00000000e+00],
[ 8.70998792e-06, 5.31159838e-10, 9.99991298e-01]], dtype=float32)
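As a rough check on why the smaller momentum helps: if the batch statistic were constant (an assumption made only for this sketch), the gap between the moving average and the "real" value would shrink by a factor of momentum on every update, so after n updates a fraction momentum^n of the initial gap remains. The helper updates_needed below is a hypothetical name, not part of Keras.

```python
import math


def updates_needed(momentum, tolerance):
    """Smallest n with momentum**n <= tolerance, i.e. the number of updates
    after which at most `tolerance` of the initial gap remains."""
    return math.ceil(math.log(tolerance) / math.log(momentum))


for momentum in (0.99, 0.9, 0.01):
    n = updates_needed(momentum, tolerance=0.001)
    print(f"momentum={momentum}: {n} updates until 99.9% of the gap is closed")
```

With momentum=0.01 a couple of updates are enough, whereas the default 0.99 needs several hundred, which is consistent with correct predictions after 20 epochs in one case and only after about 1000 in the other.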