Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share

python - TensorFlow Federated Compression: How to implement a stateful encoder to be used in TFF's build_federated_averaging_process?

In TensorFlow Federated (TFF), you can pass a broadcast_process and an aggregation_process to tff.learning.build_federated_averaging_process; these can embed customized encoders, e.g. to apply custom compression.

Getting to the point of my question, I am trying to implement an encoder to sparsify model updates/model weights.

I am trying to build such an encoder by implementing EncodingStageInterface from tensorflow_model_optimization.python.core.internal. However, I am struggling to implement a (local) state that accumulates the zeroed-out coordinates of the model updates/model weights round by round. Note that this state should not be communicated; it only needs to be maintained locally (so AdaptiveEncodingStageInterface does not help here). In general, the question is how to maintain a local state inside an encoder that is then passed to the FedAvg process.
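For reference, here is the per-round update that encode is meant to compute once the commented-out state handling works, read off the code below (this scheme is commonly known as top-k sparsification with error feedback). Here ΔW_t is the flattened update with n entries (800 in this case), A_t is the local residual, and TopK_k keeps the k largest-magnitude entries with their signs and zeroes the rest:

```latex
\begin{aligned}
k &= \operatorname{round}(0.4\,n) \\
S_t &= \operatorname{TopK}_k\!\left(A_t + \Delta W_t\right) \\
A_{t+1} &= A_t + \Delta W_t - S_t, \qquad A_0 = 0
\end{aligned}
```

Only S_t would be sent to the server; A_{t+1} would stay on the client.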

I attach the code of my encoder implementation (which, apart from the state I would like to add, works as expected as a stateless encoder). I then attach the excerpt of my code where I use the encoder. If I uncomment the commented-out parts in stateful_encoding_stage_topk.py, the code does not work: I can't figure out how to manage the state (which is a Tensor) in TF graph (non-eager) mode.

stateful_encoding_stage_topk.py

import tensorflow as tf
import numpy as np
from tensorflow_model_optimization.python.core.internal import tensor_encoding as te


@te.core.tf_style_encoding_stage
class StatefulTopKEncodingStage(te.core.EncodingStageInterface):

  ENCODED_VALUES_KEY = 'stateful_topk_values'
  INDICES_KEY = 'indices'
  
  
  def __init__(self):
    super().__init__()
    # Here I would like to init my state
    #self.A = tf.zeros([800], dtype=tf.float32)

  @property
  def name(self):
    """See base class."""
    return 'stateful_topk'

  @property
  def compressible_tensors_keys(self):
    """See base class."""
    return [self.ENCODED_VALUES_KEY]

  @property
  def commutes_with_sum(self):
    """See base class."""
    return True

  @property
  def decode_needs_input_shape(self):
    """See base class."""
    return True

  def get_params(self):
    """See base class."""
    return {}, {}

  def encode(self, x, encode_params):
    """See base class."""
    del encode_params  # Unused.

    dW = tf.reshape(x, [-1])
    # Here I would like to retrieve the state
    A = tf.zeros([800], dtype=tf.float32)
    #A = self.A
    
    dW_and_A = tf.math.add(A, dW)

    percentage = tf.constant(0.4, dtype=tf.float32)
    k_float = tf.multiply(percentage, tf.cast(tf.size(dW), tf.float32))
    k_int = tf.cast(tf.math.round(k_float), dtype=tf.int32)

    values, indices = tf.math.top_k(tf.math.abs(dW_and_A), k=k_int, sorted=False)
    indices = tf.expand_dims(indices, 1)
    sparse_dW = tf.scatter_nd(indices, values, tf.shape(dW_and_A))
    
    # Here I would like to update the state
    A_updated = tf.math.subtract(dW_and_A, sparse_dW)
    #self.A = A_updated
    
    encoded_x = {self.ENCODED_VALUES_KEY: values,
                 self.INDICES_KEY: indices}

    return encoded_x

  def decode(self,
             encoded_tensors,
             decode_params,
             num_summands=None,
             shape=None):
    """See base class."""
    del decode_params, num_summands  # Unused.
    
    indices = encoded_tensors[self.INDICES_KEY]
    values = encoded_tensors[self.ENCODED_VALUES_KEY]
    tensor = tf.fill([800], 0.0)
    decoded_values = tf.tensor_scatter_nd_update(tensor, indices, values)
    
    return tf.reshape(decoded_values, shape)



def sparse_quantizing_encoder():
  encoder = te.core.EncoderComposer(StatefulTopKEncodingStage())
  return encoder.make()

fedavg_with_sparsification.py

[...]

def sparsification_broadcast_encoder_fn(value):
  spec = tf.TensorSpec(value.shape, value.dtype)
  return te.encoders.as_simple_encoder(te.encoders.identity(), spec)

def sparsification_mean_encoder_fn(value):
  spec = tf.TensorSpec(value.shape, value.dtype)
  
  if value.shape.num_elements() == 800:
    return te.encoders.as_gather_encoder(
        stateful_encoding_stage_topk.sparse_quantizing_encoder(), spec)

  else:
    return te.encoders.as_gather_encoder(te.encoders.identity(), spec)
  
encoded_broadcast_process = (
    tff.learning.framework.build_encoded_broadcast_process_from_model(
        model_fn, sparsification_broadcast_encoder_fn))

encoded_mean_process = (
    tff.learning.framework.build_encoded_mean_process_from_model(
        model_fn, sparsification_mean_encoder_fn))


iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.004),
    client_weight_fn=lambda _: tf.constant(1.0),
    broadcast_process=encoded_broadcast_process,
    aggregation_process=encoded_mean_process)

[...]

I am using:

  • tensorflow 2.4.0
  • tensorflow-federated 0.17.0
Question from: https://stackoverflow.com/questions/65830370/tensorflow-federated-compression-how-to-implement-a-stateful-encoder-to-be-used


1 Reply


I'll try to answer in two parts: (1) the top_k encoder without state, and (2) realizing the stateful idea you seem to want in TFF.

(1)

To get the TopKEncodingStage working without state, I see a few details to change.

The commutes_with_sum property should be set to False. In pseudo-code, it states whether sum_x(decode(encode(x))) == decode(sum_x(encode(x))). This does not hold for the representation your encode method returns: summing the encodings would also sum the indices, which then point at the wrong coordinates.

I think the implementation of the decode method can be simplified to:

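# encode flattened x, so scatter into a flat vector, then restore the input shape.
return tf.reshape(
    tf.scatter_nd(
        indices=encoded_tensors[self.INDICES_KEY],
        updates=encoded_tensors[self.ENCODED_VALUES_KEY],
        shape=[tf.reduce_prod(shape)]),
    shape)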
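To make the commutes_with_sum point concrete, here is a small NumPy sketch (an illustration outside TensorFlow; topk_encode and topk_decode are hypothetical names) that mirrors the stage's logic: flatten, keep the k largest-magnitude entries, and scatter them back into the input shape. Note that it keeps the signed entries at the top-k positions; tf.math.top_k applied to tf.math.abs(...) returns magnitudes, so the sign is lost unless you gather the original entries (e.g. with tf.gather) at the returned indices.

```python
import numpy as np


def topk_encode(x, fraction=0.4):
    """Keep the round(fraction * size) largest-magnitude entries of flattened x."""
    flat = np.asarray(x).reshape(-1)
    k = int(round(fraction * flat.size))
    # Positions of the k largest magnitudes (stable sort gives deterministic ties).
    indices = np.argsort(-np.abs(flat), kind="stable")[:k]
    # Gather the signed entries, not their absolute values.
    return {"values": flat[indices], "indices": indices}


def topk_decode(encoded, shape):
    """Scatter the kept values into a flat zero vector, then restore the shape."""
    flat = np.zeros(int(np.prod(shape)), dtype=encoded["values"].dtype)
    flat[encoded["indices"]] = encoded["values"]
    return flat.reshape(shape)


x1 = np.array([[3.0, -1.0], [0.5, -4.0], [2.0, 0.0]])
x2 = np.array([[-2.0, 5.0], [1.0, 0.0], [0.0, 3.0]])
e1, e2 = topk_encode(x1), topk_encode(x2)

# sum_x(decode(encode(x))): decode each client's encoding, then sum.
sum_of_decoded = topk_decode(e1, x1.shape) + topk_decode(e2, x2.shape)

# decode(sum_x(encode(x))): sum the encodings field by field, then decode once.
summed = {"values": e1["values"] + e2["values"],
          "indices": e1["indices"] + e2["indices"]}
decoded_sum = topk_decode(summed, x1.shape)

print(sum_of_decoded)
print(decoded_sum)
print(np.allclose(sum_of_decoded, decoded_sum))  # False: top-k does not commute with sum
```

Summing the two encodings adds the index vectors as well, so the decoded result lands on unrelated coordinates; that is why this stage cannot declare that it commutes with sum.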

(2)

What you describe cannot be achieved this way using tff.learning.build_federated_averaging_process. The process returned by this method has no mechanism for maintaining client/local state. Whatever state is expressed in your StatefulTopKEncodingStage would end up as server state, not local state.

To work with client/local state, you will probably need to write more custom code. As a starting point, see examples/stateful_clients, which you can adapt to store the state you refer to.

Keep in mind that in TFF, this needs to be expressed as functional transformations. Storing values in class attributes (such as self.A in your commented-out code) and using them elsewhere can lead to surprising errors.
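As a rough illustration of what "functional" means here, the sketch below is plain NumPy, not TFF: sparsify_with_residual and run_round are hypothetical names, and the dictionary of residuals merely stands in for whatever client-side storage you end up using (for example, by adapting examples/stateful_clients). The residual is passed in and the updated residual is returned, instead of being assigned to an attribute; only the sparse update would be sent to the server.

```python
import numpy as np


def sparsify_with_residual(update, residual, fraction=0.4):
    """One client step of top-k sparsification with a locally kept residual.

    Pure function: takes the current residual and returns the new one instead
    of storing it on an object. Returns (sparse_update, new_residual).
    """
    corrected = residual + update                      # dW + A
    k = int(round(fraction * corrected.size))
    keep = np.argsort(-np.abs(corrected), kind="stable")[:k]
    sparse = np.zeros_like(corrected)
    sparse[keep] = corrected[keep]                     # signed top-k entries
    return sparse, corrected - sparse                  # new A = dW + A - sparse


def run_round(client_updates, client_residuals, fraction=0.4):
    """Hypothetical simulation of one round.

    client_updates:   {client_id: flat update array} computed this round
    client_residuals: {client_id: residual array}, kept on the client side only
    Returns the server-side mean of the sparse updates and the new residuals.
    """
    new_residuals = {}
    sparse_updates = []
    for cid, update in client_updates.items():
        residual = client_residuals.get(cid, np.zeros_like(update))
        sparse, new_residuals[cid] = sparsify_with_residual(update, residual, fraction)
        sparse_updates.append(sparse)  # only this would be sent to the server
    return np.mean(sparse_updates, axis=0), new_residuals


updates = {"a": np.array([1.0, -3.0, 0.5, 2.0, 0.2]),
           "b": np.array([0.0, 0.0, 4.0, 0.0, -1.0])}
residuals = {}  # client-side state; missing entries start as all-zero residuals
mean1, residuals = run_round(updates, residuals)
mean2, residuals = run_round(updates, residuals)
print(mean1)
print(mean2)
print(residuals["a"])
```

In the second round, client "a" sends coordinate 0 (update 1.0, not selected in the first round) because its accumulated value has grown to 2.0. The residual dictionary never leaves the client side of the simulation.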

