🌼🌸FLOWER 🌼🌸
The idea behind Federated Learning is to train a model between multiple clients and a server without having to share any data. This is done by letting each client train the model locally on its data and send its parameters back to the server, which then aggregates all the clients' parameters together using a predefined strategy. This process is made very simple by using the Flower framework 😀. Flower: A Friendly Federated Learning Research Framework on MNIST Data.
We naturally first need to import torch and torchvision and load the MNIST dataset.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import flwr as fl
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, DataCollatorWithPadding
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
CHECKPOINT = "distilbert-base-uncased"
from torchvision.transforms import ToTensor, Normalize, Compose
import torch.nn as nn
import torchvision
import torch.nn.functional as F
import torchvision.transforms as transforms
from flwr.common import Metrics
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
# Print library versions so readers can reproduce the environment
print("flwr", fl.__version__)
print("numpy", np.__version__)
print("torch", torch.__version__)
print("torchvision", torchvision.__version__)
# NOTE(review): this overrides the CUDA-aware DEVICE assigned earlier in the
# import section — presumably to pin the tutorial to CPU-only Colab instances; confirm
DEVICE = torch.device("cpu")
# DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Training on {DEVICE}")
1
2
3
# Net is the CNN classifier defined elsewhere in the tutorial (not shown in this excerpt)
model = Net(num_classes=10)
# state_dict() contains buffers as well as learnable parameters, so this counts
# every tensor the model stores
num_parameters = sum(value.numel() for value in model.state_dict().values())
print(f"{num_parameters = }")
Visualise 64 images from the dataset.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import random
import numpy as np
def visualise_n_random_examples(
    trainset_, n: int, verbose: bool = True, num_cols: int = 16
):
    """Display `n` images drawn at random from `trainset_` on a single canvas.

    Args:
        trainset_: dataset exposing raw images via `.data` (e.g. torchvision MNIST).
        n: number of images to display.
        verbose: if True, print the sampled indices.
        num_cols: images per row of the canvas (generalised; was hard-coded to 16).
    """
    # take n examples at random
    idx = list(range(len(trainset_.data)))
    random.shuffle(idx)
    idx = idx[:n]
    if verbose:
        print(f"will display images with idx: {idx}")

    # construct canvas (4 inches per column/row; 16 cols * 4 == the original 64)
    num_rows = int(np.ceil(len(idx) / num_cols))
    fig, axs = plt.subplots(
        figsize=(num_cols * 4, num_rows * 4), nrows=num_rows, ncols=num_cols
    )

    # display images on canvas
    for c_i, i in enumerate(idx):
        axs.flat[c_i].imshow(trainset_.data[i], cmap="gray")
    # FIX: hide the unused axes — the original left empty frames whenever n was
    # not an exact multiple of num_cols
    for ax in axs.flat[len(idx):]:
        ax.axis("off")
1
2
3
# it is likely that the plot this function will generate looks familiar to other plots you might have generated before
# or you might have encountered in other tutorials. So far, we aren't doing anything new, Federated Learning will start soon!
# NOTE(review): `trainset` must already be in scope here (presumably returned by
# get_mnist()) — it is not defined in this excerpt
visualise_n_random_examples(trainset, n=64)
The training loop
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
def train(net, trainloader, optimizer, epochs):
    """Run `epochs` passes of standard supervised training over `trainloader`.

    Uses cross-entropy loss; updates `net` in place and returns it.
    """
    loss_fn = torch.nn.CrossEntropyLoss()
    net.train()
    for _epoch in range(epochs):
        for batch_images, batch_labels in trainloader:
            optimizer.zero_grad()
            batch_loss = loss_fn(net(batch_images), batch_labels)
            batch_loss.backward()
            optimizer.step()
    return net
def test(net, testloader):
    """Evaluate `net` over the whole `testloader`.

    Returns a tuple `(loss, accuracy)` where `loss` is the sum of the per-batch
    cross-entropy losses and `accuracy` is the fraction of correctly
    classified examples.
    """
    criterion = torch.nn.CrossEntropyLoss()
    net.eval()
    total_loss = 0.0
    num_correct = 0
    with torch.no_grad():
        for images, labels in testloader:
            outputs = net(images)
            total_loss += criterion(outputs, labels).item()
            # prediction = index of the largest logit
            _, predicted = torch.max(outputs.data, 1)
            num_correct += (predicted == labels).sum().item()
    accuracy = num_correct / len(testloader.dataset)
    return total_loss, accuracy
def run_centralised(epochs: int, lr: float, momentum: float = 0.9):
    """A minimal (but complete) centralised training loop.

    Builds the model and an SGD optimiser with the supplied hyperparameters,
    trains on the full MNIST training set, then prints loss/accuracy on the
    test set.
    """
    # model + optimiser
    model = Net(num_classes=10)
    optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)

    # dataset and dataloaders
    trainset, testset = get_mnist()
    trainloader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
    testloader = DataLoader(testset, batch_size=128)

    # train for the requested number of epochs, then evaluate on the test set
    trained_model = train(model, trainloader, optim, epochs)
    loss, accuracy = test(trained_model, testloader)
    print(f"{loss = }")
    print(f"{accuracy = }")
Create 2 partitions and extract some statistics from one partition.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from torch.utils.data import random_split
def prepare_dataset(num_partitions: int, batch_size: int, val_ratio: float = 0.1):
    """Partition the MNIST training set into `num_partitions` disjoint subsets,
    one per client, and split each of those into train/validation loaders.

    The test set is left intact and will be used by the central server to
    assess the performance of the global model.

    Args:
        num_partitions: number of clients (and therefore of partitions).
        batch_size: batch size for each client's train/val DataLoader.
        val_ratio: fraction of each partition held out for validation.

    Returns:
        (trainloaders, valloaders, testloader)
    """
    # get the MNIST dataset (get_mnist is defined earlier in the tutorial)
    trainset, testset = get_mnist()

    # split trainset into `num_partitions` trainsets.
    # FIX: the original used `len(trainset) // num_partitions` for every
    # partition, which makes random_split raise whenever the dataset size is
    # not divisible by num_partitions; distribute the remainder instead
    # (identical lengths in the divisible case, so behaviour is unchanged).
    base_len = len(trainset) // num_partitions
    remainder = len(trainset) % num_partitions
    partition_len = [
        base_len + (1 if i < remainder else 0) for i in range(num_partitions)
    ]
    trainsets = random_split(
        trainset, partition_len, torch.Generator().manual_seed(2023)
    )

    # create dataloaders with train+val support
    trainloaders = []
    valloaders = []
    for trainset_ in trainsets:
        num_total = len(trainset_)
        num_val = int(val_ratio * num_total)
        num_train = num_total - num_val
        for_train, for_val = random_split(
            trainset_, [num_train, num_val], torch.Generator().manual_seed(2023)
        )
        trainloaders.append(
            DataLoader(for_train, batch_size=batch_size, shuffle=True, num_workers=2)
        )
        valloaders.append(
            DataLoader(for_val, batch_size=batch_size, shuffle=False, num_workers=2)
        )

    # create dataloader for the test set (kept whole, for server-side evaluation)
    testloader = DataLoader(testset, batch_size=128)
    return trainloaders, valloaders, testloader
Create 100 partitions and extract some statistics from one partition.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import matplotlib.pyplot as plt
trainloaders, valloaders, testloader = prepare_dataset(
    num_partitions=100, batch_size=32
)

# first partition (the Subset wrapped by the first client's DataLoader)
train_partition = trainloaders[0].dataset

# count data points
partition_indices = train_partition.indices
print(f"number of images: {len(partition_indices)}")

# FIX: `partition_indices` index into the *intermediate* Subset produced by
# the first random_split, not into the raw MNIST dataset, so they must be
# mapped through that Subset's own indices before indexing `targets`;
# the original indexed the raw targets directly and plotted the wrong labels.
raw_indices = [train_partition.dataset.indices[i] for i in partition_indices]

# visualise histogram of the class labels in this client's partition
plt.hist(train_partition.dataset.dataset.targets[raw_indices], bins=10, color="red", ec="green")
plt.grid()
plt.xticks(range(10))
plt.xlabel("Label")
plt.ylabel("Number of images")
plt.title("Class labels distribution for MNIST")
Defining a Flower Client
We can think of a client in FL as an entity that owns some data and trains a model using this data. The caveat is that the model is being trained collaboratively in Federation by multiple clients (sometimes up to hundreds of thousands) and, in most instances of FL, is sent by a central server.
A Flower Client is a simple Python class with four distinct methods:
-
fit(): With this method, the client does on-device training for a number of epochs using its own data. At the end, the resulting model is sent back to the server for aggregation.
-
evaluate(): With this method, the server can evaluate the performance of the global model on the local validation set of a client. This can be used for instance when there is no centralised dataset on the server for validation/test. Also, this method can be used to assess the degree of personalisation of the model being federated.
-
set_parameters(): This method takes the parameters sent by the server and uses them to initialise the parameters of the local model that is ML framework specific (e.g. TF, Pytorch, etc).
-
get_parameters(): It extracts the parameters from the local model and transforms them into a list of NumPy arrays. This ML framework-agnostic representation of the model will be sent to the server.
Start by importing Flower!
1
2
3
4
5
6
import flwr as fl
import torch
# Pin the simulation to CPU (uncomment the line below to use a GPU when available)
DEVICE = torch.device("cpu")
# DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Training on {DEVICE}")
Then we need to define our Flower Client class.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from collections import OrderedDict
from typing import Dict, Tuple
import torch
from flwr.common import NDArrays, Scalar
class FlowerClient(fl.client.NumPyClient):
    """A Flower client: owns one data partition and trains a local model on it."""

    def __init__(self, trainloader, vallodaer) -> None:
        super().__init__()
        self.trainloader = trainloader
        # NOTE(review): the parameter name is misspelled ("vallodaer") but it is
        # passed by keyword elsewhere in the tutorial, so renaming it would
        # break callers; keep it for backward compatibility.
        self.valloader = vallodaer
        self.model = Net(num_classes=10)

    def set_parameters(self, parameters):
        """Overwrite this client's model with the parameters received from the
        server (a list of NumPy arrays, one per state_dict entry)."""
        params_dict = zip(self.model.state_dict().keys(), parameters)
        # NOTE(review): torch.Tensor(...) always produces float tensors; fine
        # for float-only models, but it would corrupt integer buffers (e.g.
        # BatchNorm's num_batches_tracked) — confirm Net has none.
        state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
        # now replace the parameters
        self.model.load_state_dict(state_dict, strict=True)

    def get_parameters(self, config: Dict[str, Scalar]):
        """Extract all model parameters and convert them to a list of NumPy
        arrays — the ML-framework-agnostic representation the server works with."""
        return [val.cpu().numpy() for _, val in self.model.state_dict().items()]

    def fit(self, parameters, config):
        """Train the model (initialised with the server's parameters) on this
        client's local dataset, then send the updated parameters back."""
        # copy parameters sent by the server into client's local model
        self.set_parameters(parameters)
        # same optimiser/training loop as the centralised example above,
        # but now using only this client's data
        optim = torch.optim.SGD(self.model.parameters(), lr=0.01, momentum=0.9)
        train(self.model, self.trainloader, optim, epochs=1)
        # FIX: report the number of training *examples*, not batches —
        # len(DataLoader) is the batch count, and FedAvg weights each client's
        # update by the example count returned here.
        return self.get_parameters({}), len(self.trainloader.dataset), {}

    def evaluate(self, parameters: NDArrays, config: Dict[str, Scalar]):
        """Evaluate the model sent by the server on this client's local
        validation set, then return performance metrics."""
        self.set_parameters(parameters)
        # same `test` function as in the centralised setting, on the local val set
        loss, accuracy = test(self.model, self.valloader)
        # FIX: as in fit(), report the number of evaluation examples, not batches
        return float(loss), len(self.valloader.dataset), {"accuracy": accuracy}
Spend a few minutes to inspect the FlowerClient class above. Please ask questions if there is something unclear !
The keen-eyed among you might have realised that if we were to fuse the client's fit() and evaluate() methods, we'll end up with essentially the same as in the run_centralised() function we used in the Centralised Training part of this tutorial. And it is true!! In Federated Learning, the way clients perform local training makes use of the same principles as a more traditional centralised setup. The key difference is that the dataset now is much smaller and it's never "seen" by the entity running the FL workload (i.e. the central server).
Talking about the central server… we should define what strategy we want to make use of, so that the updated models sent from the clients back to the server at the end of the fit() method are aggregated.
Choosing a Flower Strategy 😀😀😀
A strategy sits at the core of the Federated Learning experiment. It is involved in all stages of a FL pipeline: sampling clients; sending the global model to the clients so they can do fit(); receive the updated models from the clients and aggregate these to construct a new global model; define and execute global or federated evaluation; and more.
The way FedAvg works is simple but performs surprisingly well in practice. It is therefore one good strategy to start your experimentation. FedAvg, as its name implies, derives a new version of the global model by taking the average of all the models sent by clients participating in the round.
Letโs see how we can define FedAvg using Flower. We use one of the callbacks called evaluate_fn so we can easily evaluate the state of the global model using a small centralised testset. Note this functionality is user-defined since it requires a choice in terms of ML-framework. (if you recall, Flower is framework agnostic).
This being said, centralised evaluation of the global model is only possible if there exists a centralised dataset that somewhat follows a similar distribution as the data thatโs spread across clients. In some cases having such centralised dataset for validation is not possible, so the only solution is to federate the evaluation of the global model. This is the default behaviour in Flower. If you donโt specify the evaluate_fn argument in your strategy, then, centralised global evaluation wonโt be performed.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def get_evalulate_fn(testloader):
    """Return an `evaluate_fn` callback for the strategy.

    The returned function is executed by the strategy at the end of every
    round to evaluate the current global model on the centralised test set.
    (The misspelled name is kept because the strategy definition references it.)
    """

    def evaluate_fn(server_round: int, parameters, config):
        """Instantiate a model, load the global parameters into it, and
        evaluate it on the whole MNIST test set."""
        model = Net(num_classes=10)
        # rebuild a state_dict from the server-provided NumPy arrays
        keys = model.state_dict().keys()
        state_dict = OrderedDict(
            (name, torch.Tensor(array)) for name, array in zip(keys, parameters)
        )
        model.load_state_dict(state_dict, strict=True)
        # same `test` function as in the centralised setting
        loss, accuracy = test(model, testloader)
        return loss, {"accuracy": accuracy}

    return evaluate_fn
# now we can define the strategy
strategy = fl.server.strategy.FedAvg(
    fraction_fit=0.1,  # let's sample 10% of the clients each round to do local training
    fraction_evaluate=0.1,  # after each round, let's sample 10% of the clients to assess how well the global model is doing
    min_available_clients=50,  # total number of clients available in the experiment
    evaluate_fn=get_evalulate_fn(testloader),  # a callback the strategy executes to evaluate the global model on a centralised dataset
)
So far we have:
- created the dataset partitions (one for each client)
- defined the client class
- decided on a strategy to use
Now we just need to launch the Flower FL experiment… not so fast! Just one final function: let's create another callback that the Simulation Engine will use in order to spawn VirtualClients. As you can see this is really simple: construct a FlowerClient object, assigning each their own data partition.
1
2
3
4
5
6
7
8
9
10
11
12
def generate_client_fn(trainloaders, valloaders):
    """Return a factory that builds the FlowerClient for a given client id."""

    def client_fn(cid: str):
        """Construct a FlowerClient that owns the cid-th data partition."""
        partition = int(cid)
        return FlowerClient(
            trainloader=trainloaders[partition], vallodaer=valloaders[partition]
        )

    return client_fn
# build the client factory, binding it to the per-client dataloaders created earlier
client_fn_callback = generate_client_fn(trainloaders, valloaders)
Launch the FL experiment using Flower simulation.
1
2
3
4
5
6
# Run the simulation: the Simulation Engine spawns virtual clients on demand
# and returns a History object with the metrics collected across rounds
history = fl.simulation.start_simulation(
    client_fn=client_fn_callback,  # a callback to construct a client
    num_clients=100,  # total number of clients in the experiment
    config=fl.server.ServerConfig(num_rounds=10),  # let's run for 10 rounds
    strategy=strategy,  # the strategy that will orchestrate the whole FL pipeline
)
Doing 10 rounds should take less than 2 minutes on a CPU-only Colab instance ← Flower Simulation is fast! Doing 20 rounds should take less than 15 minutes on a CPU-only Colab instance ← Flower Simulation is fast!
We can then use the returned History object to either save the results to disk or do some visualisation (or both of course, or neither if you like chaos). Below you can see how you can plot the centralised accuracy obtained at the end of each round (including at the very beginning of the experiment) for the global model. This is what the function evaluate_fn() that we passed to the strategy reports.
1
2
3
4
5
6
7
8
9
10
print(f"{history.metrics_centralized = }")
# each entry of the centralised-accuracy history is a (round, value) pair
global_accuracy_centralised = history.metrics_centralized["accuracy"]
# FIX: don't shadow the builtin `round` — use a plural name for the x-axis values
rounds = [data[0] for data in global_accuracy_centralised]
acc = [100.0 * data[1] for data in global_accuracy_centralised]
plt.plot(rounds, acc)
plt.grid()
plt.ylabel("Accuracy (%)")
plt.xlabel("Round")
plt.title("MNIST - IID - 100 clients with 10 clients per round")
So… we can modify the code provided above to train 200, 400, 600, or even thousands of clients in a federated environment using 🌼🌸FLOWER 🌼🌸. Anyway, I just tried up to 200 clients using this MNIST data.