We continue our discussion of Federated Learning technology and its applications. Check out the first part here, where we covered FL basics. This part focuses on the mechanics of FL.
Selecting Devices in Cross-Device Learning
Training data that resides on devices can be sensitive and may need to be kept on device. This creates a challenge for many Google teams that are accustomed to training machine learning models in the cloud.
So, instead of using classical cloud machine learning, an engineer would now use federated learning to get the desired model. At Google, the teams aim to build infrastructure that lets engineers seamlessly train these models on remote devices without worrying about the complexity of orchestration between the company’s servers and devices.
To ensure a good user experience, the Google FL system places constraints on which devices can participate and when. For example, the system verifies that participating devices have the latest security updates, so that no new attack surfaces are opened.
Moreover, the system verifies that devices are plugged in and charging, so training does not drain the battery, and that they are idle, so training does not interfere with the user experience. Other device conditions are checked as well.
When all these conditions are satisfied, the device pings the FL server to indicate that the criteria are met and it is ready to participate in federated learning.
The server often receives millions of such requests, far more than can be supported in one round of training; it can only handle about a thousand devices per round. So, through screening, the server finds and accepts the devices that meet the required conditions. That’s when a federated learning training round begins.
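Below is a rough sketch of what these eligibility checks and the server-side screening might look like. The names, fields, and cohort size here are assumptions made for illustration, not Google’s actual check-in protocol.

```python
import random
from dataclasses import dataclass

# Illustrative only: these fields and the cohort size are assumptions for the
# sketch, not Google's actual check-in protocol.
@dataclass
class DeviceState:
    has_latest_security_update: bool
    is_charging: bool
    is_idle: bool
    on_unmetered_network: bool  # one example of the "other device conditions"

def is_eligible(device: DeviceState) -> bool:
    """A device only pings the FL server once every condition is met."""
    return (device.has_latest_security_update
            and device.is_charging
            and device.is_idle
            and device.on_unmetered_network)

def sample_participants(checked_in_devices, round_size=1000):
    """The server receives far more check-ins than it can use, so it screens
    them and keeps a cohort of roughly a thousand devices per round."""
    eligible = [d for d in checked_in_devices if is_eligible(d)]
    return random.sample(eligible, min(round_size, len(eligible)))
```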
Practical Considerations in Federated Learning
How do federated learning rounds proceed in practice? As in traditional machine learning, you start with a model that is either pre-trained on publicly available datasets or randomly initialized. You send the model to the device. Then, on the device, using the local data, you update the model and send back a minimal, focused update that captures what changed. The key phrase here is “minimal, focused update.” The local training examples and labels are never sent back to Google; only the model update is sent back to the server. The update is focused on your specific machine learning task, and this focus is vitally important.
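To make the device-side step concrete, here is a minimal sketch, assuming a PyTorch-style model and a local dataset that yields (inputs, labels) batches. The function name, hyperparameters, and library choice are assumptions for illustration, not Google’s on-device training stack.

```python
import copy
import torch

def local_update(global_model, local_dataset, epochs=1, lr=0.01):
    """Train briefly on local data, then return only a focused model delta."""
    model = copy.deepcopy(global_model)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    loss_fn = torch.nn.CrossEntropyLoss()
    for _ in range(epochs):
        for inputs, labels in local_dataset:
            optimizer.zero_grad()
            loss = loss_fn(model(inputs), labels)
            loss.backward()
            optimizer.step()
    # The "minimal, focused update": parameter deltas plus an example count,
    # never the raw training examples or labels.
    delta = {name: (new - old).detach()
             for (name, new), (_, old) in zip(model.named_parameters(),
                                              global_model.named_parameters())}
    num_examples = sum(len(labels) for _, labels in local_dataset)
    return delta, num_examples
```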
The minimal, focused update strips out much of the data that could potentially reveal personal information about the device owner. The server is designed so that these focused updates are kept only for a short time, until they have all been aggregated. The aggregate then updates the base model on Google’s server. This usually happens within a few minutes.
Once the base model has been updated, all the individual updates are deleted; only the combined, updated model remains. At this point engineers can inspect it to see whether training has converged and whether they are satisfied with the model quality. If not, another round is conducted, and this repeats until the desired metrics are achieved.
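Here is a matching sketch of the server side of a round, under the same illustrative assumptions as the device-side snippet above: average the focused updates (weighted by each client’s example count), fold the result into the base model, then delete the individual updates.

```python
import torch

def aggregate_and_apply(global_model, client_updates):
    """client_updates: list of (delta_dict, num_examples) tuples from local_update."""
    total_examples = sum(n for _, n in client_updates)
    with torch.no_grad():
        for name, param in global_model.named_parameters():
            # Example-count weighting, as in standard federated averaging.
            weighted_delta = sum(delta[name] * (n / total_examples)
                                 for delta, n in client_updates)
            param.add_(weighted_delta)
    client_updates.clear()  # individual updates are discarded once aggregated
    return global_model
```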
Engineers typically run hundreds or thousands of rounds until the model converges, and each round takes between one and ten minutes. If you do the math, that comes to about a week to train a model: for example, 1,000 rounds at 10 minutes per round is roughly 10,000 minutes, or about seven days. If you also want to select hyperparameters, it can take a lot longer.
This process of federated learning is applicable to anything, not just model training. It is simply an orchestration between a central entity (Google’s servers) and a large fleet of devices. You sample a few of the online devices that are available and send them a query of some kind. It could be “hey, train this model for me,” or it could be something else entirely.
Then you collect the focused, minimized updates, aggregate them, and repeat the process in rounds.
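As a toy illustration of that generality, here is a self-contained sketch that uses the same round structure to estimate a simple statistic (average message length) instead of training a model. Every name in it is hypothetical; only the round structure mirrors the description above.

```python
import random

def device_answer(device_data, query):
    # Each device answers locally and returns only a focused summary (sum, count).
    values = [query(item) for item in device_data]
    return sum(values), len(values)

def run_rounds(fleet, query, num_rounds=3, cohort_size=1000):
    total, count = 0.0, 0
    for _ in range(num_rounds):
        cohort = random.sample(fleet, min(cohort_size, len(fleet)))
        for device_data in cohort:
            s, n = device_answer(device_data, query)
            total, count = total + s, count + n
    return total / count  # aggregated result; per-device summaries are discarded

# For example: average message length across two (tiny) devices.
fleet = [["hi", "see you soon"], ["ok", "sounds good", "thanks"]]
print(run_rounds(fleet, query=len, num_rounds=1, cohort_size=2))
```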
But simple averaging of user updates is not always the best way to train models with federated learning, because the data on user devices can be imbalanced and can come from very different distributions. For instance, different users have different preferences and interests, and may use their phone very little or a lot, so the number of training examples varies widely from device to device. Even so, simple averaging often produces good results: when federated learning was first used at Google, the team was doing exactly that, and they were amazed by how well it worked in production.
One way to improve on simple aggregation schemes is to use adaptive optimization strategies (on the server, on the devices, or both). In addition, one could weight user updates differently or train multiple models simultaneously to address the data and distribution heterogeneity issues.
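As a sketch of the first idea, here is one way a server could treat the weighted average of client deltas as a pseudo-gradient and feed it to an Adam-style optimizer, in the spirit of the adaptive federated optimization literature. The class name and hyperparameters are illustrative assumptions, not a reference implementation.

```python
import torch

class ServerAdam:
    """Server-side adaptive optimizer: the averaged client delta plays the
    role of a gradient direction for the global model."""
    def __init__(self, lr=0.1, beta1=0.9, beta2=0.99, eps=1e-3):
        self.lr, self.b1, self.b2, self.eps = lr, beta1, beta2, eps
        self.m, self.v = {}, {}  # per-parameter first and second moments

    def step(self, global_model, avg_delta):
        # avg_delta: the (weighted) average of client deltas for this round.
        with torch.no_grad():
            for name, param in global_model.named_parameters():
                d = avg_delta[name]
                self.m[name] = self.b1 * self.m.get(name, torch.zeros_like(d)) + (1 - self.b1) * d
                self.v[name] = self.b2 * self.v.get(name, torch.zeros_like(d)) + (1 - self.b2) * d * d
                param.add_(self.lr * self.m[name] / (self.v[name].sqrt() + self.eps))
```

Compared with plain averaging, the server-side moments smooth out noisy rounds and adapt the effective step size per parameter, which can help when client data is heterogeneous.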
These are major directions of ongoing work in FL technology, and if you want to learn more, check out this paper: Advances and Open Problems in Federated Learning.
P.S. This article is based on Peter Kairouz’s talk, Federated Learning at Google and Beyond. The author thanks Peter for the inspiration.