What is Federated Machine Learning?
In its broadest (or most vanilla) definition, Federated Learning (FL), or Federated Machine Learning, is simply a Machine Learning (ML) setting where many clients collaboratively train a model under the orchestration of a central server while keeping the training data decentralised. That’s all!
The term was introduced in a paper by McMahan et al. (2016) [1] from Google and has taken off since then as a way to facilitate collaboration whilst preserving privacy. The technology was initially introduced for mobile and edge device applications, where we have millions or potentially even billions of unreliable clients (unreliable in the sense that a client may drop out at any time for a variety of reasons, e.g. network issues, battery status, or compute limitations), and we want to train a global model that we ship to all devices.
For example, let’s say you want to predict the next word a user is going to type on their mobile phone. Previously, in order to do this, each mobile phone would have to upload its data to the cloud, where a central orchestrator (e.g. Google) could train a model on all users’ typing data in one place. In practice, this would likely be spread across numerous data centres and compute nodes, but the key point is that, in this historical scenario, all the data is visible in its entirety to the orchestrator.
What Federated Learning does is turn this idea on its head. Instead of sending the data to the model, we send the model to the data.
‘If the mountain will not come to Muhammad, then Muhammad must go to the mountain’ – Francis Bacon, 1625
How does it work?
The algorithm works as follows:
1. Client Selection: The central server orchestrating the training process samples from a set of clients meeting eligibility requirements.
2. Broadcast: The selected clients download the current model weights and training program (assuming training is not being done from scratch).
3. Client computation: Each selected device locally computes an update to the model parameters.
4. Aggregation: The central server collects all the model updates from the devices and aggregates them.
5. Model update: The server locally updates the shared model based on the aggregated update.
Steps 2-5 are repeated until the model has converged.
This procedure is demonstrated in the figure below:
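In code, a single round of this loop might look like the following minimal sketch, written here with plain NumPy. The `local_update` method and `num_examples` attribute on each client are hypothetical stand-ins for whatever local training routine you use; the aggregation step follows the weighted average of Federated Averaging (FedAvg) from [1].

```python
import random
import numpy as np

def run_round(global_weights, clients, fraction=0.1):
    """One round of federated training (steps 1-5), FedAvg-style."""
    # 1. Client selection: sample a fraction of the eligible clients.
    selected = random.sample(clients, max(1, int(fraction * len(clients))))

    updates, sizes = [], []
    for client in selected:
        # 2. Broadcast: the client receives the current global weights.
        # 3. Client computation: the client trains locally on its own
        #    data and returns new weights (hypothetical method).
        new_weights = client.local_update(global_weights)
        updates.append(new_weights)
        sizes.append(client.num_examples)  # hypothetical attribute

    # 4. Aggregation: weighted average of the client weights, each
    #    client contributing in proportion to its dataset size.
    total = sum(sizes)
    aggregated = [
        sum(w[i] * (n / total) for w, n in zip(updates, sizes))
        for i in range(len(global_weights))
    ]

    # 5. Model update: the server adopts the aggregated weights.
    return [np.asarray(layer) for layer in aggregated]

# Steps 2-5 repeat until convergence:
# for _ in range(num_rounds):
#     global_weights = run_round(global_weights, clients)
```

In practice, frameworks such as TensorFlow Federated or Flower handle the orchestration details this sketch glosses over, such as stragglers, drop-outs, and retries.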
Instead of being able to see everyone’s data, the orchestrator receives only each user’s update to the model parameters. The user data never leaves the device, yet the orchestrator still ends up with a trained model, much as if it had collected all the data centrally. This method is now in use by both Google and Apple for a variety of use cases, such as speech recognition and language modelling, and is termed cross-device Federated Learning.
But I don’t have access to millions of devices…
Unless you are Google or Apple, chances are you don’t have ready access to millions of devices that are happy to compute things for you using their data. Don’t fear! This doesn’t mean Federated Machine Learning is the sole purview of large tech conglomerates. We can still apply the same technique to a smaller number of clients, which is known as cross-silo FL.
Where cross-device Federated Learning is characterised by a very large number of unreliable devices, cross-silo Federated Machine Learning is characterised by a small number of reliable organisations or entities that are actively collaborating with one another to train a global model from which they will all benefit. Cross-silo FL therefore does away with many of the communication and computation bottlenecks present in cross-device FL: computation can be performed on powerful GPUs rather than small mobile devices, and communication involves only a handful of participants, typically just 2-100. This extra flexibility opens up a whole host of new opportunities, such as data alliances between historically competitive parties, collaborations between industry and academia, and other similar data collaboration configurations.
This is what makes FL truly exciting: the prospect of finally unlocking the vast troves of confidential data that have been siloed for competitive, regulatory, or privacy reasons. ML techniques generally require heaps of data to produce good predictions, but in many cases we’ve only had access to small, fragmented, and sequestered datasets. This has been a big barrier to the adoption (or at least the full leverage) of ML methods, particularly in areas like healthcare and finance.
What about privacy?
So far, as we’ve described it, Federated Machine Learning seems like a great way to collaborate with other parties to jointly train a model without actually sharing your data. So, is this sufficient to also protect privacy? Unfortunately, with no additional protections, the central orchestrator to whom each party sends its model parameter updates can infer a surprising amount of information about the underlying data. For instance, in cross-device language modelling, it has been demonstrated that individual words a user has typed can be learnt by inspecting that user’s most recent parameter update [2]. Further research has shown that this extends even to rare words or numbers typed only once, e.g. sensitive data such as bank card details or social security numbers [3]. Strikingly, this memorisation actually peaks when the test-set loss is at its lowest! Naturally, this poses significant privacy concerns despite the user’s data never having left the device.
In FL scenarios, we often think of the central orchestrator as being honest-but-curious (you can think of this as one rung below malicious). An honest-but-curious adversary is one who will not deviate from the defined protocol but will attempt to learn all the information they can from the messages they receive. In most cases this is a reasonable assumption and typical of how we can expect FL orchestrators to act: they won’t do anything they shouldn’t as far as the protocol is concerned, but if we send them any kind of data, they will use it. Clearly, FL alone in this paradigm is not a suitable privacy-preserving framework. It is therefore almost always paired with one or more of the techniques listed below:
- Differential Privacy
- Secure Multi-Party Computation
- Homomorphic Encryption
- Trusted Execution Environments
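To give a flavour of how pairing FL with one of these works, here is a minimal sketch of the first technique, differential privacy, applied on-device: before uploading, the client clips its model update and adds Gaussian noise. The function name `privatise_update` and the `clip_norm` and `noise_multiplier` values are purely illustrative assumptions, not part of any particular library.

```python
import numpy as np

def privatise_update(update, clip_norm=1.0, noise_multiplier=1.1, rng=None):
    """Clip a client's model delta and add Gaussian noise before upload.

    `update` is a list of per-layer numpy arrays: the difference
    between the locally trained weights and the global weights.
    `clip_norm` and `noise_multiplier` are illustrative values; real
    deployments tune them against a privacy budget.
    """
    if rng is None:
        rng = np.random.default_rng()

    # Clip the whole update to a maximum L2 norm so no single client
    # can dominate the aggregate (this bounds each client's
    # sensitivity, which the noise is then calibrated to).
    flat = np.concatenate([layer.ravel() for layer in update])
    norm = np.linalg.norm(flat)
    scale = min(1.0, clip_norm / (norm + 1e-12))

    # Add Gaussian noise proportional to the clipping norm. The larger
    # the noise relative to the clip, the stronger the privacy
    # guarantee (and the slower the model converges).
    return [
        layer * scale
        + rng.normal(0.0, noise_multiplier * clip_norm, size=layer.shape)
        for layer in update
    ]
```

Depending on the trust model, the noise can be added on-device as shown here (local differential privacy) or by the server after aggregation (central differential privacy); either way, the privacy guarantee comes at the cost of slower convergence, which is why the parameters must be tuned against a privacy budget.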
Resources
If you are interested in a more comprehensive assessment of FL, take a look at Advances and Open Problems in Federated Learning [4]. Or to get up to date with the most recent research, follow https://github.com/innovation-cat/Awesome-Federated-Machine-Learning.
References