You can ignore weights by setting them to 0. For this, you can directly get a weight W and do W.assign(tf.zeros_like(W)). I know that you care about speeding up inference, but unless you rewrite your code to use sparse representations, you will probably not speed it up: zeroed weights are still stored and multiplied, so they are not actually removed.
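To make the point about sparse representations concrete, here is a minimal sketch in plain Python (deliberately not TensorFlow, so it runs anywhere). The helper names dense_matvec, to_sparse and sparse_matvec are made up for this illustration; they only count multiplications and are not how TensorFlow implements its kernels.

```python
# Conceptual sketch (hypothetical helpers, no TensorFlow): a zeroed weight
# still costs a multiplication in a dense layout, but not in a sparse one.

def dense_matvec(weights, x):
    """Multiply a row-major dense matrix by a vector, counting multiplications."""
    mults = 0
    out = []
    for row in weights:
        acc = 0.0
        for w, xi in zip(row, x):
            acc += w * xi  # executed even when w == 0
            mults += 1
        out.append(acc)
    return out, mults


def to_sparse(weights):
    """Keep only the non-zero entries of each row as (column, value) pairs."""
    return [[(j, w) for j, w in enumerate(row) if w != 0.0] for row in weights]


def sparse_matvec(sparse_rows, x):
    """Same product as dense_matvec, but only non-zero weights are visited."""
    mults = 0
    out = []
    for row in sparse_rows:
        acc = 0.0
        for j, w in row:
            acc += w * x[j]
            mults += 1
        out.append(acc)
    return out, mults


# A mostly "pruned" (zeroed) 3x4 weight matrix.
weights = [
    [0.5, 0.0, 0.0, -1.0],
    [0.0, 0.0, 2.0, 0.0],
    [0.0, 0.0, 0.0, 0.0],
]
x = [1.0, 2.0, 3.0, 4.0]

dense_out, dense_mults = dense_matvec(weights, x)
sparse_out, sparse_mults = sparse_matvec(to_sparse(weights), x)
print(dense_out, dense_mults)
print(sparse_out, sparse_mults)
```

Both versions return the same result, but the dense one performs one multiplication per matrix entry regardless of how many entries are zero. Whether a sparse layout is actually faster on real hardware depends on the sparsity level and the kernels available, so treat this only as an illustration of the principle.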
Alternatively, you can look at existing solutions for pruning in custom layers:
    import tensorflow as tf
    import tensorflow_model_optimization as tfmot

    class MyDenseLayer(tf.keras.layers.Dense, tfmot.sparsity.keras.PrunableLayer):

        def get_prunable_weights(self):
            # Prune bias also, though that usually harms model accuracy too much.
            return [self.kernel, self.bias]

    # Use `prune_low_magnitude` to make the `MyDenseLayer` layer train with pruning.
    model_for_pruning = tf.keras.Sequential([
        tfmot.sparsity.keras.prune_low_magnitude(MyDenseLayer(20, input_shape=input_shape)),
        tf.keras.layers.Flatten()
    ])
For example, you can use ConstantSparsity (see the tfmot.sparsity.keras API documentation) and set its parameters such that your layers are fully pruned.
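As a rough picture of what magnitude pruning to a fixed target sparsity means, here is a conceptual sketch in plain Python. It is not the tfmot implementation: prune_to_sparsity is a hypothetical helper, and the real ConstantSparsity schedule has its own parameters and argument validation, so check its documentation for the range of target sparsity it accepts.

```python
# Conceptual sketch only (hypothetical helper, not the tfmot implementation):
# zero out the smallest-magnitude fraction of a flat list of weights.

def prune_to_sparsity(weights, target_sparsity):
    """Return a copy of `weights` with the smallest-magnitude fraction set to 0."""
    if not 0.0 <= target_sparsity <= 1.0:
        raise ValueError("target_sparsity must be in [0, 1]")
    n_prune = int(round(target_sparsity * len(weights)))
    # Indices sorted by ascending magnitude; the first n_prune get masked out.
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = set(order[:n_prune])
    return [0.0 if i in pruned else w for i, w in enumerate(weights)]


kernel = [0.8, -0.05, 0.3, -1.2, 0.01, 0.6]

half_pruned = prune_to_sparsity(kernel, 0.5)   # the three smallest weights are zeroed
fully_pruned = prune_to_sparsity(kernel, 1.0)  # every weight is zeroed
print(half_pruned)
print(fully_pruned)
```

With a sparsity of 0.5 only the larger half of the weights survives; pushing the target towards 1 removes everything, which is the "fully pruned" case described above.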
Another alternative is to construct a second, smaller model that you use only for inference. After training, you can save the required weights separately (instead of saving the entire model) and load them into the second model.
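The idea of copying only the required weights into a smaller model can be sketched in plain Python as follows. This assumes the Keras Dense kernel layout (inputs x units); dense_forward, extract_units and the toy numbers are made up for the illustration and stand in for the real save and load calls.

```python
# Conceptual sketch (hypothetical helpers): copy only the kept units of a
# dense layer into a smaller layer and check that the outputs agree.

def dense_forward(kernel, bias, x):
    """y[j] = sum_i x[i] * kernel[i][j] + bias[j], kernel shaped inputs x units."""
    return [
        sum(x[i] * kernel[i][j] for i in range(len(x))) + bias[j]
        for j in range(len(bias))
    ]


def extract_units(kernel, bias, keep):
    """Slice out the kernel columns and bias entries of the units listed in `keep`."""
    small_kernel = [[row[j] for j in keep] for row in kernel]
    small_bias = [bias[j] for j in keep]
    return small_kernel, small_bias


# "Trained" layer with 2 inputs and 3 units; unit 1 has been pruned to zero.
kernel = [[1.0, 0.0, 2.0],
          [0.5, 0.0, -1.0]]
bias = [0.1, 0.0, -0.2]
keep = [0, 2]  # units whose weights the smaller model needs

small_kernel, small_bias = extract_units(kernel, bias, keep)

x = [2.0, 4.0]
full_out = dense_forward(kernel, bias, x)
small_out = dense_forward(small_kernel, small_bias, x)
print(full_out)
print(small_out)
```

The smaller layer reproduces the outputs of the kept units exactly. Note that dropping units changes the layer's output size, so in a real model the following layer's kernel would need its matching input rows sliced in the same way.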