Yup, this is easy to express in TFF, and it will execute just fine in the default execution stacks.
As you've noticed, the TFF repository mostly contains examples of cross-device Federated Learning (Kairouz et al. 2019). In those examples the state has tff.SERVER placement, and the function signature for one "round" of federated learning has the following structure (for details about TFF's type shorthand, see the Federated data section of the tutorials):
(<State@SERVER, {Dataset}@CLIENTS> -> State@SERVER)
We can represent stateful clients by simply extending the signature:
(<State@SERVER, {State}@CLIENTS, {Dataset}@CLIENTS> -> <State@SERVER, {State}@CLIENTS>)
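To make the extended signature concrete without any TFF machinery, here is a plain-Python sketch of the semantics of one stateful round. It is only an illustration under stated assumptions: the server state is a single float "model", each client's state is an int counting the examples it has seen so far, datasets are lists of floats, and the toy training step simply moves the model toward the local mean. None of these names are TFF APIs; Python lists stand in for {T}@CLIENTS values.

```python
from typing import List, Tuple

# Plain-Python stand-ins (assumptions for illustration, not TFF types):
#   server state   -> float (a scalar "model")
#   client state   -> int   (running count of examples this client has seen)
#   client dataset -> list of floats


def client_training_fn(model: float, state: int,
                       dataset: List[float]) -> Tuple[float, int]:
    """Toy local training: the update points from the model to the local mean."""
    local_mean = sum(dataset) / len(dataset)
    model_update = local_mean - model
    new_state = state + len(dataset)  # carried over to this client's next round
    return model_update, new_state


def run_round(server_state: float,
              client_states: List[int],
              client_datasets: List[List[float]]) -> Tuple[float, List[int]]:
    """(<State@SERVER, {State}@CLIENTS, {Dataset}@CLIENTS>
        -> <State@SERVER, {State}@CLIENTS>)"""
    # "Broadcast": every client starts from the same server model.
    results = [client_training_fn(server_state, s, d)
               for s, d in zip(client_states, client_datasets)]
    updates = [u for u, _ in results]
    new_client_states = [s for _, s in results]
    # Unweighted mean of the updates, analogous to an unweighted federated mean.
    average_update = sum(updates) / len(updates)
    new_server_state = server_state + average_update
    return new_server_state, new_client_states
```

For example, run_round(0.0, [0, 0], [[1.0, 3.0], [4.0]]) returns (3.0, [2, 1]): the two clients propose updates of 2.0 and 4.0, the server applies their mean, and each client's state grows by the size of its dataset.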
Implementing a version of Federated Averaging (McMahan et al. 2016) that includes a client state object might look something like this:
import tensorflow_federated as tff

# model_type, client_state_type, client_data_type and server_state_type are
# TFF types assumed to be defined elsewhere; server_update_fn is likewise
# assumed to be an existing tff.tf_computation.

@tff.tf_computation(
    model_type,
    client_state_type,  # additional state parameter
    client_data_type)
def client_training_fn(model, state, dataset):
  # Do some local training. local_train is a hypothetical placeholder for your
  # own logic; it returns the model update and the updated client state.
  model_update, new_state = local_train(model, state, dataset)
  return model_update, new_state  # return a tuple including the updated state

@tff.federated_computation(
    tff.FederatedType(server_state_type, tff.SERVER),
    tff.FederatedType(client_state_type, tff.CLIENTS),  # new parameter for state
    tff.FederatedType(client_data_type, tff.CLIENTS))
def run_fed_avg(server_state, client_states, client_datasets):
  client_initial_models = tff.federated_broadcast(server_state.model)
  client_updates, new_client_states = tff.federated_map(
      client_training_fn,
      # Pass the client states as an argument.
      (client_initial_models, client_states, client_datasets))
  average_update = tff.federated_mean(client_updates)
  new_server_state = tff.federated_map(
      server_update_fn, (server_state, average_update))
  # Make sure to return the client states so they can be used in later rounds.
  return new_server_state, new_client_states
Invoking run_fed_avg requires passing, for each client-placed argument, a Python list of tensors/structures with one entry per client participating in the round. The result of the invocation will be the server state and a list of client states.
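Because the client states come back as a Python list, the caller is responsible for keeping them between rounds. A minimal driver sketch, assuming clients are identified by string ids and a round function with the same (server_state, client_states, client_datasets) calling convention as run_fed_avg, can store the states in a dict keyed by client id. run_rounds and toy_round are hypothetical names; toy_round only stands in for the real computation so the loop runs on its own.

```python
import random
from typing import Callable, Dict, List, Tuple

# Calling convention assumed to match run_fed_avg when invoked from Python:
# (server_state, [client_state, ...], [client_dataset, ...])
#     -> (new_server_state, [new_client_state, ...])
RoundFn = Callable[[float, List[int], List[List[float]]], Tuple[float, List[int]]]


def run_rounds(round_fn: RoundFn,
               server_state: float,
               client_datasets: Dict[str, List[float]],
               initial_client_state: int,
               num_rounds: int,
               clients_per_round: int,
               seed: int = 0) -> Tuple[float, Dict[str, int]]:
    """Runs several rounds, persisting each client's state between rounds."""
    rng = random.Random(seed)
    client_states = {cid: initial_client_state for cid in client_datasets}
    for _ in range(num_rounds):
        sampled = rng.sample(sorted(client_datasets), clients_per_round)
        # Both lists use the same client order: entry i belongs to sampled[i].
        states_in = [client_states[cid] for cid in sampled]
        data_in = [client_datasets[cid] for cid in sampled]
        server_state, states_out = round_fn(server_state, states_in, data_in)
        # Write the returned states back under the same client ids.
        for cid, new_state in zip(sampled, states_out):
            client_states[cid] = new_state
    return server_state, client_states


def toy_round(server_state: float, states: List[int],
              datasets: List[List[float]]) -> Tuple[float, List[int]]:
    """Hypothetical stand-in: bumps the server state, counts examples per client."""
    return server_state + 1.0, [s + len(d) for s, d in zip(states, datasets)]
```

The key design point is that the i-th entries of the state list and the dataset list must belong to the same client, and the returned states must be written back under the same ids; otherwise states would be silently shuffled between clients. Clients that are not sampled in a round simply keep their previous state.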