In TensorFlow Federated (TFF), you can pass a `broadcast_process` and an `aggregation_process` to `tff.learning.build_federated_averaging_process`; these processes can embed customized encoders, e.g. to apply custom compression.
Getting to the point of my question, I am trying to implement an encoder to sparsify model updates/model weights.
I am trying to build such an encoder by implementing the `EncodingStageInterface` from `tensorflow_model_optimization.python.core.internal`.
However, I am struggling to implement a (local) state that accumulates the zeroed-out coordinates of the model updates/model weights round by round. Note that this state should not be communicated and only needs to be maintained locally (so the `AdaptiveEncodingStageInterface` should not be helpful). In general, the question is how to maintain a local state inside an encoder that is then passed to the FedAvg process.
I attach the code of my encoder implementation (that, besides the state I would like to add, works fine as stateless as expected).
I then attach the excerpt of my code where I use the encoder implementation.
If I uncomment the commented-out parts in stateful_encoding_stage_topk.py, the code does not work: I can't figure out how to manage the state (which is a Tensor) in TF non-eager (graph) mode.
stateful_encoding_stage_topk.py
import tensorflow as tf
import numpy as np
from tensorflow_model_optimization.python.core.internal import tensor_encoding as te
@te.core.tf_style_encoding_stage
class StatefulTopKEncodingStage(te.core.EncodingStageInterface):
    """Top-k sparsification stage: keeps the largest-magnitude coordinates.

    `encode` flattens its input, selects a fixed fraction of the coordinates
    with the largest absolute value and emits their (signed) values together
    with their positions. `decode` scatters those values back into a zero
    tensor of the original shape.

    NOTE: despite its name, the stage does not keep any local state yet; the
    places where a residual accumulator would be read and written are marked
    with comments in `__init__` and `encode`.
    """

    ENCODED_VALUES_KEY = 'stateful_topk_values'
    INDICES_KEY = 'indices'

    def __init__(self, fraction=0.4):
        """Initializes the stage.

        Args:
          fraction: Fraction of coordinates kept by `encode`, in (0, 1].
            Defaults to 0.4.

        Raises:
          ValueError: If `fraction` is outside (0, 1].
        """
        super().__init__()
        if not 0.0 < fraction <= 1.0:
            raise ValueError(
                'fraction must be in (0, 1], got {}'.format(fraction))
        self._fraction = fraction
        # Here I would like to init my state
        # self.A = tf.zeros([800], dtype=tf.float32)

    @property
    def name(self):
        """See base class."""
        return 'stateful_topk'

    @property
    def compressible_tensors_keys(self):
        """See base class."""
        return [self.ENCODED_VALUES_KEY]

    @property
    def commutes_with_sum(self):
        """See base class."""
        # Every encoded tensor carries its own set of indices, so encoded
        # values from different inputs do not line up and cannot be summed
        # before decoding.
        return False

    @property
    def decode_needs_input_shape(self):
        """See base class."""
        return True

    def get_params(self):
        """See base class."""
        return {}, {}

    def encode(self, x, encode_params):
        """See base class.

        Returns a dict with the kept values under `ENCODED_VALUES_KEY` and
        their positions in the flattened input, as a `[k, 1]` tensor, under
        `INDICES_KEY`.
        """
        del encode_params  # Unused.
        flat_update = tf.reshape(x, [-1])

        # Here I would like to retrieve the state
        # A = self.A
        residual = tf.zeros_like(flat_update)
        accumulated = tf.math.add(residual, flat_update)

        num_coordinates = tf.cast(tf.size(flat_update), tf.float32)
        k = tf.cast(
            tf.math.round(self._fraction * num_coordinates), dtype=tf.int32)

        # Rank by magnitude, but transmit the signed values: the values
        # returned by top_k(abs(.)) would drop the sign of every negative
        # coordinate.
        _, top_positions = tf.math.top_k(
            tf.math.abs(accumulated), k=k, sorted=False)
        values = tf.gather(accumulated, top_positions)
        indices = tf.expand_dims(top_positions, 1)

        # Here I would like to update the state
        # self.A = accumulated - tf.scatter_nd(
        #     indices, values, tf.shape(accumulated))

        return {
            self.ENCODED_VALUES_KEY: values,
            self.INDICES_KEY: indices,
        }

    def decode(self,
               encoded_tensors,
               decode_params,
               num_summands=None,
               shape=None):
        """See base class.

        Rebuilds a tensor of shape `shape` that is zero everywhere except at
        the encoded indices.

        Raises:
          ValueError: If `shape` is not provided.
        """
        del decode_params, num_summands  # Unused.
        if shape is None:
            raise ValueError('decode requires the shape of the encoded input.')

        indices = encoded_tensors[self.INDICES_KEY]
        values = encoded_tensors[self.ENCODED_VALUES_KEY]

        # Size the flat output from the original shape instead of assuming a
        # fixed number of elements.
        num_elements = tf.cast(tf.reduce_prod(shape), indices.dtype)
        flat_decoded = tf.scatter_nd(
            indices, values, tf.reshape(num_elements, [1]))
        return tf.reshape(flat_decoded, shape)
def sparse_quantizing_encoder():
    """Return an encoder made of a single `StatefulTopKEncodingStage`."""
    topk_stage = StatefulTopKEncodingStage()
    composer = te.core.EncoderComposer(topk_stage)
    return composer.make()
fedavg_with_sparsification.py
[...]
def sparsification_broadcast_encoder_fn(value):
    """Return a simple identity encoder matching `value`'s shape and dtype."""
    identity_encoder = te.encoders.identity()
    value_spec = tf.TensorSpec(value.shape, value.dtype)
    return te.encoders.as_simple_encoder(identity_encoder, value_spec)
def sparsification_mean_encoder_fn(value):
    """Return the gather encoder to use for `value`.

    A tensor with exactly 800 elements gets the top-k sparsifying encoder;
    every other tensor gets the identity encoder.
    """
    value_spec = tf.TensorSpec(value.shape, value.dtype)
    if value.shape.num_elements() != 800:
        base_encoder = te.encoders.identity()
    else:
        base_encoder = stateful_encoding_stage_topk.sparse_quantizing_encoder()
    return te.encoders.as_gather_encoder(base_encoder, value_spec)
# Broadcast process built from the model, using the broadcast encoder function.
encoded_broadcast_process = \
    tff.learning.framework.build_encoded_broadcast_process_from_model(
        model_fn,
        sparsification_broadcast_encoder_fn,
    )

# Mean (aggregation) process built from the model, using the mean encoder
# function.
encoded_mean_process = \
    tff.learning.framework.build_encoded_mean_process_from_model(
        model_fn,
        sparsification_mean_encoder_fn,
    )

# Federated averaging wired to the two encoded processes above; every client
# gets a constant weight of 1.0.
iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.004),
    client_weight_fn=lambda _: tf.constant(1.0),
    broadcast_process=encoded_broadcast_process,
    aggregation_process=encoded_mean_process,
)
[...]
I am using:
- tensorflow 2.4.0
- tensorflow-federated 0.17.0
question from:
https://stackoverflow.com/questions/65830370/tensorflow-federated-compression-how-to-implement-a-stateful-encoder-to-be-used