The Story That Explains Federated Learning
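A company wants to write the best cookbook in the world, so it asks ten chefs to send in their secret recipes. But the chefs refuse. Their recipes are their livelihood, and the raw data never leaves their kitchens.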
So the company tries something clever. Instead of collecting recipes, it sends each chef the same draft cookbook. Each chef cooks from it for a week, scribbles small improvements in the margins based only on their own kitchen, and mails back just the margin notes — never the recipes themselves. Head office averages all ten sets of notes into a better draft, and mails the improved book out again. Round after round, the shared cookbook gets brilliant — yet not a single secret recipe ever left a kitchen.
That shared cookbook is the global model. Each chef's week of private cooking is local training on device. The mailed margin notes are model updates. And the averaging at head office is aggregation. That is federated learning.
What Is Federated Learning & On-Device Training?
Federated Learning (FL) is a way to train one shared machine-learning model across many devices — phones, cars, hospitals, IoT sensors — without ever collecting their raw data in one place. The classic phrase is: bring the model to the data, not the data to the model.
In traditional machine learning, all data is copied to a central server and the model trains there. In federated learning, the model is shipped down to each device, trained locally on data that physically stays put, and only the resulting weight updates travel back. This on-device step — the part where each phone quietly improves the model using your own typing, photos, or sensor readings — is called local training on device, and it is the heart of this tutorial.
Left: every device ships raw data to one server. Right: the model travels to each device, trains locally behind a lock, and only learned updates return.
The Federated Training Round — Step by Step
Federated learning runs in repeated communication rounds. Each round has five stages. Stages 2 and 3 — highlighted below — are the on-device local training that this tutorial focuses on.
Blue = global model broadcast down · Amber pulse = on-device local training · Green = weight updates returning. The loop repeats for many rounds until the model converges.
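As a rough sketch of one such loop, the code below runs repeated rounds over a toy problem in which the "model" is a single number estimating a mean. The stage breakdown in the comments (select, broadcast, train locally, upload, aggregate) is a common one and is an assumption here, as are the names `local_step`, `run_round`, and the `fraction` parameter; none of them come from the figure.

```python
import numpy as np

def local_step(model, data, epochs=2, lr=0.1):
    """Hypothetical on-device work: nudge a scalar toward the local mean."""
    for _ in range(epochs):
        grad = 2.0 * (model - data.mean())   # gradient of the mean squared error
        model = model - lr * grad
    return model, len(data)                  # updated model + sample count only

def run_round(global_model, client_data, fraction, rng):
    """One communication round over a random subset of clients."""
    n_pick = max(1, int(round(fraction * len(client_data))))
    chosen = rng.choice(len(client_data), size=n_pick, replace=False)  # select
    # Broadcast the current model and train locally on each chosen client.
    updates = [local_step(global_model, client_data[i]) for i in chosen]
    # Upload: only (model, count) pairs reach the server. Aggregate by data size.
    total = sum(n for _, n in updates)
    return sum(m * n for m, n in updates) / total

rng = np.random.default_rng(0)
client_data = [rng.normal(loc=5.0, scale=1.0, size=50) for _ in range(10)]

model = 0.0
for _ in range(30):
    model = run_round(model, client_data, fraction=0.5, rng=rng)
print(f"global estimate after 30 rounds: {model:.2f}")
```

Sampling only a fraction of clients per round is one way to keep each round cheap; the estimate should still settle near the shared mean of 5.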
Inside a Single Device — What "Local Training" Actually Does
When the global model lands on a device, the device behaves like a tiny, self-contained training loop. It does not train forever, because that would drain the battery and overfit to one user. Instead it runs a small, fixed amount of work: typically E local epochs over the device's data in mini-batches, using standard mini-batch stochastic gradient descent (SGD).
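A minimal sketch of that bounded loop, assuming a one-feature linear model and squared-error loss. The function name `local_train_minibatch` and the parameters `batch_size` and `seed` are illustrative choices, not part of any library.

```python
import numpy as np

def mse(w, b, X, y):
    """Mean squared error of the linear model y ≈ w * x + b."""
    return float(np.mean((w * X + b - y) ** 2))

def local_train_minibatch(w, b, X, y, epochs=3, batch_size=32, lr=0.05, seed=0):
    """E local epochs of mini-batch SGD; the data never leaves this function."""
    rng = np.random.default_rng(seed)
    n = len(X)
    for _ in range(epochs):                      # fixed, small number of epochs
        order = rng.permutation(n)               # reshuffle every epoch
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            err = w * X[idx] + b - y[idx]
            w = w - lr * 2.0 * np.mean(err * X[idx])   # gradient w.r.t. w
            b = b - lr * 2.0 * np.mean(err)            # gradient w.r.t. b
    return w, b, n

data_rng = np.random.default_rng(1)
X = data_rng.random(200)
y = 3.0 * X + 1.0

loss_before = mse(0.0, 0.0, X, y)
w, b, n = local_train_minibatch(0.0, 0.0, X, y)
loss_after = mse(w, b, X, y)
print(f"loss before: {loss_before:.3f}, after 3 local epochs: {loss_after:.3f}")
```

The loop stops after `epochs` passes regardless of how well the model fits, which is what keeps the on-device cost predictable.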
The amber dot is the device descending its local loss curve over a few epochs — a miniature training run happening entirely on the hardware in your pocket.
The Math Behind It — Federated Averaging (FedAvg)
The algorithm that ties all the local updates together is Federated Averaging (FedAvg), introduced by researchers at Google in 2017. The intuition is simple: after local training, the server takes a weighted average of the models returned by the participating devices, where each device is weighted by how much data it holds.
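In symbols, the aggregation step can be written as below. The notation is assumed here for illustration: $K$ participating devices, device $k$ holding $n_k$ samples and returning weights $w_{t+1}^{k}$ at the end of round $t$.

```latex
n = \sum_{k=1}^{K} n_k, \qquad
w_{t+1} = \sum_{k=1}^{K} \frac{n_k}{n}\, w_{t+1}^{k}
```

For three devices holding 600, 300, and 100 samples, the weights $n_k / n$ are 0.6, 0.3, and 0.1, so the largest device counts six times as much as the smallest.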
Python Implementation — Local Training + FedAvg from Scratch
Below is a minimal, dependency-light simulation. It defines what happens on a single device (local_train), then the server loop that broadcasts, collects updates, and aggregates with FedAvg.
```python
import numpy as np


# ── A tiny linear model: predict y from x with weight w and bias b ──
def local_train(global_w, global_b, X, y, epochs=3, lr=0.05):
    """Runs ON the device. Raw data X, y never leave this function."""
    w, b = global_w, global_b            # start from the shared global model
    n = len(X)
    for _ in range(epochs):              # a few LOCAL epochs only
        y_pred = w * X + b
        error = y_pred - y
        grad_w = (2 / n) * np.dot(error, X)
        grad_b = (2 / n) * np.sum(error)
        w -= lr * grad_w                 # local gradient step (full batch)
        b -= lr * grad_b
    return w, b, n                       # send back weights + sample count ONLY


def fed_avg(updates):
    """Server side: weighted average by each client's data size."""
    total = sum(n for _, _, n in updates)
    new_w = sum(w * n for w, _, n in updates) / total
    new_b = sum(b * n for _, b, n in updates) / total
    return new_w, new_b


# ── Simulate 3 devices, each with its own private data ──
# All three draw from the same distribution here; only the amount of data differs.
np.random.seed(0)
clients = [
    (np.random.rand(600), lambda x: 3.0 * x + 1.0),   # Device A: lots of data
    (np.random.rand(300), lambda x: 3.0 * x + 1.0),   # Device B
    (np.random.rand(100), lambda x: 3.0 * x + 1.0),   # Device C: little data
]
client_data = [(X, f(X)) for X, f in clients]

# ── The federated training loop (server orchestration) ──
global_w, global_b = 0.0, 0.0            # start untrained
for rnd in range(20):                    # 20 communication rounds
    updates = []
    for X, y in client_data:             # BROADCAST + LOCAL TRAIN
        upd = local_train(global_w, global_b, X, y)
        updates.append(upd)              # only weights + sample count come back
    global_w, global_b = fed_avg(updates)   # AGGREGATE

print(f"Learned w={global_w:.3f}, b={global_b:.3f}")
print("True target was w=3.000, b=1.000")
```
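Notice the design: local_train receives X, y but never returns them; it returns only weights and a sample count. The server in fed_avg never touches a single raw data point, yet every round moves the global model toward the true target. With the small settings used here (3 local epochs, lr=0.05, 20 rounds) the printed estimate is still some distance from w=3, b=1; more rounds or a larger learning rate close the gap. That is on-device federated learning in about 30 lines.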
Where On-Device Federated Learning Is Used Today
| Application | What trains on-device | Why federated |
|---|---|---|
| Mobile keyboards (Gboard) | Next-word & emoji prediction, query suggestions | Your typing is deeply private — it must never reach a server |
| Voice assistants | "Hey ___" wake-word / keyword spotting | Real on-device audio beats proxy data; raw audio stays local |
| Healthcare | Diagnostic models across hospitals | Patient records can't leave the hospital (HIPAA / GDPR) |
| Finance | Fraud-detection across banks | Transaction data is regulated and competitively sensitive |
| Connected cars / IoT | Driving behaviour, sensor anomaly models | Bandwidth is scarce; data volume is huge; latency matters |
The Hard Parts — Challenges of Local On-Device Training
| Challenge | What goes wrong | Common fix |
|---|---|---|
| Non-IID data | Each device's data is unrepresentative; one phone types in French, another in code | FedProx, adaptive aggregation, fewer local epochs |
| Communication cost | Sending full model updates over mobile networks is expensive | Update compression, quantization, fewer participating clients |
| Client drift | Too much local training pulls devices apart, hurting the global model | Limit local epochs (E=1–5); proximal regularization |
| Stragglers / dropout | Slow or offline devices delay or skip rounds | Asynchronous FL; client selection; deadlines |
| Privacy leakage | Even updates can leak data via inference attacks | Differential privacy + secure aggregation |
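Two of the fixes in the table can be sketched in a few lines. The first function adds a FedProx-style proximal term, (mu/2)·(w − global_w)², to the local loss so that local training is pulled back toward the global model. The second clips an update to a maximum L2 norm and optionally adds Gaussian noise, which is the basic mechanism behind differentially private updates. The names `local_train_prox` and `clip_and_noise` and the parameters `mu`, `clip_norm`, and `noise_std` are illustrative assumptions, and the noise level here is not calibrated to any formal privacy guarantee.

```python
import numpy as np

def local_train_prox(global_w, X, y, mu=0.1, epochs=5, lr=0.05):
    """Local training for y ≈ w * x with a proximal pull toward global_w."""
    w = global_w
    for _ in range(epochs):
        grad_loss = 2.0 * np.mean((w * X - y) * X)   # squared-error gradient
        grad_prox = mu * (w - global_w)              # gradient of (mu/2)(w - global_w)^2
        w = w - lr * (grad_loss + grad_prox)
    return w

def clip_and_noise(delta, clip_norm=1.0, noise_std=0.0, rng=None):
    """Clip an update to L2 norm <= clip_norm, then optionally add Gaussian noise."""
    delta = np.asarray(delta, dtype=float)
    norm = np.linalg.norm(delta)
    scale = min(1.0, clip_norm / norm) if norm > 0 else 1.0
    clipped = delta * scale
    if noise_std > 0:
        if rng is None:
            rng = np.random.default_rng()
        clipped = clipped + rng.normal(0.0, noise_std, size=clipped.shape)
    return clipped

rng = np.random.default_rng(0)
X = rng.random(100)
y = 3.0 * X

w_plain = local_train_prox(0.0, X, y, mu=0.0)   # no proximal term
w_prox = local_train_prox(0.0, X, y, mu=5.0)    # strong pull toward the global model
print(f"drift without proximal term: {abs(w_plain):.3f}, with: {abs(w_prox):.3f}")

clipped = clip_and_noise([3.0, 4.0], clip_norm=1.0)
print("clipped update norm:", np.linalg.norm(clipped))
```

A larger `mu` keeps each device closer to the global model, at the cost of slower local progress. In a real deployment the clipping and noise would be combined with secure aggregation and a privacy accountant, which this sketch leaves out.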