When you subclass tf.keras.layers.Layer, the model tracks every tf.Variable created inside it as a trainable variable. What you then need to do is create a tf.Variable with the shape of the convolutional filter; these weights will adjust to the task (i.e. learn) during training. The filters need this shape:
(filter_height, filter_width, in_channels, out_channels)
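For example, here is a minimal sketch of creating such a filter variable (the shapes and the initializer are illustrative, not taken from the script below): 3x3 filters mapping 1 input channel to 16 output channels.
import tensorflow as tf

init = tf.initializers.GlorotUniform(seed=0)  # any initializer works here
# (filter_height, filter_width, in_channels, out_channels)
w = tf.Variable(init(shape=(3, 3, 1, 16), dtype='float32'), trainable=True)
print(w.shape)  # (3, 3, 1, 16)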
This tf.keras.layers.Layer subclass will behave exactly like a Keras convolutional layer in a CNN:
class CustomLayer(tf.keras.layers.Layer):
    def __init__(self, filters, kernel_size, padding, strides, activation,
                 kernel_initializer, bias_initializer, use_bias):
        super(CustomLayer, self).__init__()
        self.filters = filters
        self.kernel_size = kernel_size
        self.activation = activation
        self.padding = padding
        self.kernel_initializer = kernel_initializer
        self.bias_initializer = bias_initializer
        self.strides = strides
        self.use_bias = use_bias
        self.w = None
        self.b = None

    def build(self, input_shape):
        # The last dimension of the input is the number of channels.
        *_, n_channels = input_shape
        # Filter bank: (filter_height, filter_width, in_channels, out_channels).
        self.w = tf.Variable(
            initial_value=self.kernel_initializer(shape=(*self.kernel_size,
                                                         n_channels,
                                                         self.filters),
                                                  dtype='float32'),
            trainable=True)
        if self.use_bias:
            # One bias value per output channel.
            self.b = tf.Variable(
                initial_value=self.bias_initializer(shape=(self.filters,),
                                                    dtype='float32'),
                trainable=True)

    def call(self, inputs, training=None):
        x = tf.nn.conv2d(inputs, filters=self.w, strides=self.strides,
                         padding=self.padding)
        if self.use_bias:
            x = x + self.b
        x = self.activation(x)
        return x
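As a quick sanity check (a minimal sketch; the layer arguments and the random input here are illustrative), you can call the layer on a dummy batch and verify the output shape:
layer = CustomLayer(filters=4, kernel_size=(3, 3), padding='VALID',
                    strides=(1, 1), activation=tf.nn.relu,
                    kernel_initializer=tf.initializers.GlorotUniform(seed=0),
                    bias_initializer=tf.initializers.Zeros(), use_bias=True)
x = tf.random.normal((1, 28, 28, 1))
print(layer(x).shape)  # (1, 26, 26, 4): VALID padding shrinks 28 to 26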
You can see that the weights are the filters of the tf.nn.conv2d operation; they are tf.Variable objects, so they are weights that the model updates during training.
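Continuing the sketch above, once the layer has been built it reports exactly these tensors as its trainable variables:
print([v.shape for v in layer.trainable_variables])
# [TensorShape([3, 3, 1, 4]), TensorShape([4])] -- the filter bank w and the bias b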
If you run the entire script below, you will see that it performs the exact same task as the Keras convolutional layers.
import tensorflow as tf

tf.random.set_seed(42)

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))

# Add a channel axis and scale the pixel values to [0, 1].
rescale = lambda x, y: (tf.divide(tf.expand_dims(x, axis=-1), 255), y)

AUTOTUNE = tf.data.experimental.AUTOTUNE

train_ds = (train_ds.map(rescale).
            shuffle(128, reshuffle_each_iteration=False, seed=11).
            batch(8).
            prefetch(AUTOTUNE))

test_ds = (test_ds.map(rescale).
           shuffle(128, reshuffle_each_iteration=False, seed=11).
           batch(8).
           prefetch(AUTOTUNE))
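# Optional sanity check (not in the original script): each training batch
# should be 8 grayscale 28x28 images scaled to [0, 1].
x_batch, _ = next(iter(train_ds))
assert x_batch.shape == (8, 28, 28, 1)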
class CustomLayer(tf.keras.layers.Layer):
    def __init__(self, filters, kernel_size, padding, strides, activation,
                 kernel_initializer, bias_initializer, use_bias):
        super(CustomLayer, self).__init__()
        self.filters = filters
        self.kernel_size = kernel_size
        self.activation = activation
        self.padding = padding
        self.kernel_initializer = kernel_initializer
        self.bias_initializer = bias_initializer
        self.strides = strides
        self.use_bias = use_bias
        self.w = None
        self.b = None

    def build(self, input_shape):
        *_, n_channels = input_shape
        self.w = tf.Variable(
            initial_value=self.kernel_initializer(shape=(*self.kernel_size,
                                                         n_channels,
                                                         self.filters),
                                                  dtype='float32'),
            trainable=True)
        if self.use_bias:
            self.b = tf.Variable(
                initial_value=self.bias_initializer(shape=(self.filters,),
                                                    dtype='float32'),
                trainable=True)

    def call(self, inputs, training=None):
        x = tf.nn.conv2d(inputs, filters=self.w, strides=self.strides,
                         padding=self.padding)
        if self.use_bias:
            x = x + self.b
        x = self.activation(x)
        return x
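# (Same CustomLayer as defined above, repeated so the script runs standalone.)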
class ModelWithCustomConvLayer(tf.keras.Model):
    def __init__(self, conv_layer):
        super(ModelWithCustomConvLayer, self).__init__()
        self.conv1 = conv_layer(filters=16,
                                kernel_size=(3, 3),
                                strides=(1, 1),
                                activation=tf.nn.relu,
                                padding='VALID',
                                kernel_initializer=tf.initializers.GlorotUniform(seed=42),
                                bias_initializer=tf.initializers.Zeros(),
                                use_bias=True)
        self.maxp = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))
        self.conv2 = conv_layer(filters=32,
                                kernel_size=(3, 3),
                                strides=(1, 1),
                                activation=tf.nn.relu,
                                padding='VALID',
                                kernel_initializer=tf.initializers.GlorotUniform(seed=42),
                                bias_initializer=tf.initializers.Zeros(),
                                use_bias=True)
        self.flat = tf.keras.layers.Flatten()
        self.dense1 = tf.keras.layers.Dense(32, activation='relu',
                                            kernel_initializer=tf.initializers.GlorotUniform(seed=42))
        self.dense2 = tf.keras.layers.Dense(10, activation='softmax',
                                            kernel_initializer=tf.initializers.GlorotUniform(seed=42))

    def call(self, inputs, training=None, mask=None):
        x = self.conv1(inputs)
        x = self.maxp(x)
        x = self.conv2(x)
        x = self.maxp(x)
        x = self.flat(x)
        x = self.dense1(x)
        x = self.dense2(x)
        return x
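# The model takes the conv layer class as an argument, so the exact same
# architecture can be built with either CustomLayer or tf.keras.layers.Conv2D.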
custom = ModelWithCustomConvLayer(CustomLayer)
custom.compile(loss=tf.losses.SparseCategoricalCrossentropy(), optimizer='adam',
               metrics=[tf.metrics.SparseCategoricalAccuracy()])
custom.build(input_shape=next(iter(train_ds))[0].shape)
custom.summary()

normal = ModelWithCustomConvLayer(tf.keras.layers.Conv2D)
normal.compile(loss=tf.losses.SparseCategoricalCrossentropy(), optimizer='adam',
               metrics=[tf.metrics.SparseCategoricalAccuracy()])
normal.build(input_shape=next(iter(train_ds))[0].shape)
normal.summary()
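# Both summaries should report identical parameter counts, since CustomLayer
# creates the same kernel and bias shapes that Conv2D would.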
history_custom = custom.fit(train_ds, validation_data=test_ds, epochs=25,
steps_per_epoch=10, verbose=0)
history_normal = normal.fit(train_ds, validation_data=test_ds, epochs=25,
steps_per_epoch=10, verbose=0)
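# Only 10 batches per epoch are used, so this is a quick comparison of the
# training curves rather than a full training run.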
import matplotlib.pyplot as plt
plt.plot(history_custom.history['loss'], color='red', alpha=.5, lw=4)
plt.plot(history_custom.history['sparse_categorical_accuracy'],
color='blue', alpha=.5, lw=4)
plt.plot(history_custom.history['val_loss'], color='green', alpha=.5, lw=4)
plt.plot(history_custom.history['val_sparse_categorical_accuracy'],
color='orange', alpha=.5, lw=4)
plt.plot(history_normal.history['loss'], ls=':', color='red')
plt.plot(history_normal.history['sparse_categorical_accuracy'], ls=':',
color='blue')
plt.plot(history_normal.history['val_loss'], ls=':', color='green')
plt.plot(history_normal.history['val_sparse_categorical_accuracy'], ls=':',
color='orange')
plt.legend(['custom_' + key for key in history_custom.history.keys()] +
           ['keras_' + key for key in history_normal.history.keys()])
plt.title('Custom Conv Layer vs Keras Conv Layer')
plt.show()
The dotted lines show the model's performance when the Keras layer is used, and the solid lines show the performance of my custom layer built on tf.nn.conv2d. With the seed set, the two are exactly the same.
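If you want to verify this numerically, here is a sketch (not part of the original script; it assumes a deterministic setup, so results on a GPU may differ slightly) that compares the two trained models' predictions directly:
import numpy as np

x_batch, _ = next(iter(test_ds))
print(np.allclose(custom.predict(x_batch), normal.predict(x_batch)))
# Expect True when both models saw identical initial weights and batches.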