
tensorflow federated - Is there a way for TFF clients to have internal states?

The code in the TFF tutorials and in the research projects I see generally only keeps track of server state. I’d like there to be internal client states (for instance, additional client-internal neural networks that are completely decentralized and don’t update in a federated manner) that would influence the federated client computations.

However, the client computations I have seen are only functions of the server state and the data. Is it possible to accomplish the above?


1 Answer


Yup, this is easy to express in TFF, and it will execute just fine in the default execution stacks.

As you've noticed, the TFF repository generally has examples of cross-device Federated Learning (Kairouz et al. 2019). Generally we talk about the state having tff.SERVER placement, and the function signature for one "round" of federated learning has the structure (for details on TFF's type shorthand, see the Federated data section of the tutorials):

(<State@SERVER, {Dataset}@CLIENTS> -> State@SERVER)

We can represent stateful clients by simply extending this signature:

(<State@SERVER, {State}@CLIENTS, {Dataset}@CLIENTS> -> <State@SERVER, {State}@CLIENTS>)
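
To make the later snippets concrete, the types they reference (model_type, client_state_type, client_data_type, server_state_type) could be declared roughly as follows. The particular fields and shapes here are illustrative assumptions (a small linear model and a per-client counter), not something the signature above prescribes:

import collections

import tensorflow as tf
import tensorflow_federated as tff

# Hypothetical type definitions; substitute the structure of your own model
# and client state.
model_type = tff.to_type(collections.OrderedDict(
    weights=tff.TensorType(tf.float32, [784, 10]),
    bias=tff.TensorType(tf.float32, [10])))
# Per-client state that never leaves the client, e.g. a participation counter.
client_state_type = tff.to_type(collections.OrderedDict(
    num_rounds_participated=tff.TensorType(tf.int32)))
# Each client's data arrives as a sequence (a tf.data.Dataset) of examples.
client_data_type = tff.SequenceType(tff.to_type(collections.OrderedDict(
    x=tff.TensorType(tf.float32, [784]),
    y=tff.TensorType(tf.int32))))
# The server state holds the global model.
server_state_type = tff.to_type(collections.OrderedDict(model=model_type))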

Implementing a version of Federated Averaging (McMahan et al. 2016) that includes a client state object might look something like:

@tff.tf_computation(
  model_type,
  client_state_type,  # Additional state parameter.
  client_data_type)
def client_training_fn(model, state, dataset):
  model_update, new_state = ...  # Do some local training here.
  return model_update, new_state  # Return a tuple that includes the updated state.
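
As a hypothetical fill-in for the "do some local training" placeholder, assuming the linear-model types sketched above: one pass of plain SGD over the client's dataset, with the client state simply counting how many rounds this client has participated in. This is a sketch, not the canonical FedAvg client update:

@tff.tf_computation(model_type, client_state_type, client_data_type)
def client_training_fn(model, state, dataset):
  learning_rate = tf.constant(0.1)

  def sgd_step(weights_and_bias, example):
    # One gradient step on a single example; the loop state is the model.
    weights, bias = weights_and_bias
    with tf.GradientTape() as tape:
      tape.watch([weights, bias])
      logits = tf.matmul(tf.reshape(example['x'], [1, -1]), weights) + bias
      loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
          labels=tf.reshape(example['y'], [1]), logits=logits)
    grad_w, grad_b = tape.gradient(loss, [weights, bias])
    return (weights - learning_rate * grad_w, bias - learning_rate * grad_b)

  new_weights, new_bias = dataset.reduce(
      (model['weights'], model['bias']), sgd_step)
  # Report the delta from the broadcast model as the update to be averaged.
  model_update = collections.OrderedDict(
      weights=new_weights - model['weights'],
      bias=new_bias - model['bias'])
  # The client state is updated locally and never aggregated.
  new_state = collections.OrderedDict(
      num_rounds_participated=state['num_rounds_participated'] + 1)
  return model_update, new_state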

@tff.federated_computation(
  tff.FederatedType(server_state_type, tff.SERVER),
  tff.FederatedType(client_state_type, tff.CLIENTS),  # New parameter for the client states.
  tff.FederatedType(client_data_type, tff.CLIENTS))
def run_fed_avg(server_state, client_states, client_datasets):
  client_initial_models = tff.federated_broadcast(server_state.model)
  client_updates, new_client_states = tff.federated_map(client_training_fn,
    # Pass the client states as an argument alongside the model and data.
    (client_initial_models, client_states, client_datasets))
  average_update = tff.federated_mean(client_updates)
  new_server_state = tff.federated_map(server_update_fn, (server_state, average_update))
  # Make sure to return the client states so they can be used in later rounds.
  return new_server_state, new_client_states
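
The server_update_fn referenced above is not defined in the snippet; a minimal sketch, assuming the hypothetical types from earlier and a plain additive update of the global model by the averaged client delta, could be:

# Hypothetical server update: apply the averaged client update to the model
# held in the server state. Not part of the original snippet.
@tff.tf_computation(server_state_type, model_type)
def server_update_fn(server_state, average_update):
  new_model = collections.OrderedDict(
      weights=server_state['model']['weights'] + average_update['weights'],
      bias=server_state['model']['bias'] + average_update['bias'])
  return collections.OrderedDict(model=new_model)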

Invoking run_fed_avg requires passing a Python list of tensors/structures for each client participating in a round; the result of the invocation will be the new server state and a list of client states.
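
For example, a driver loop that carries the client states from round to round might look like the following; make_client_dataset is a hypothetical helper that builds each client's tf.data.Dataset, and the initial values just need to match the types sketched earlier:

import collections

import numpy as np

NUM_CLIENTS = 3  # Hypothetical number of participating clients.

# Initial server and client states matching the illustrative types above.
server_state = collections.OrderedDict(
    model=collections.OrderedDict(
        weights=np.zeros([784, 10], dtype=np.float32),
        bias=np.zeros([10], dtype=np.float32)))
client_states = [
    collections.OrderedDict(num_rounds_participated=0)
    for _ in range(NUM_CLIENTS)]
client_datasets = [make_client_dataset(i) for i in range(NUM_CLIENTS)]

for round_num in range(10):
  # Feed last round's client states back in, so each client's state persists
  # across rounds without ever being aggregated at the server.
  server_state, client_states = run_fed_avg(
      server_state, client_states, client_datasets)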