The Story That Explains Federated Learning
Three hospitals each want a model that can read medical scans, but none of them may share its patients' images with the others. So they hire a neutral coordinator. The coordinator never sees a single scan. Instead, it mails each hospital the same blank model. Each hospital trains it privately on its own scans, then mails back only the adjusted dials: the model weights, not the data. The coordinator averages the three sets of dials into one improved model and mails that back out. Repeat a few dozen times.
After a month, all three hospitals own a model that learned from everyone's scans — yet not one image ever left a hospital. That is Federated Learning: the model travels to the data, the data never travels.
Federated Learning (FL) is a way to train one shared machine-learning model across many devices or organisations without centralising their data. A coordinating server sends out the current global model; each client trains it locally on private data; clients send back only model updates; the server aggregates those updates into a better global model. What crosses the network is a set of numbers (model parameters), not records.
Classic machine learning moves data to the model (copy everything into one datacentre, then train). Federated learning inverts this: it moves the model to the data. Because raw data stays put, FL unlocks training on sensitive, regulated, or simply too-large-to-move datasets — phones, hospitals, banks, factories — that could never be pooled centrally.
The Communication Flow — Animated
Everything in federated learning revolves around a repeating communication loop between one server and many clients. The diagram below animates a single round: amber packets flow down (the server broadcasts the global model) and green packets flow up (clients return their local updates).
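Notice what is missing from the wire: the data. Only model parameters move in either direction. This single property is the foundation of FL's privacy benefits, although on its own it is not a complete guarantee, as the discussion of the three hard problems explains.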
Anatomy of One Communication Round
A round is the heartbeat of federated learning. Steps 1–5 repeat until the global model converges. Each round is one full up-and-down cycle of the diagram above.
The pulse never stops: broadcast → train → upload → aggregate → broadcast again. Convergence usually takes anywhere from tens to thousands of rounds depending on data heterogeneity.
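The cycle of broadcast, local training, upload, and aggregation, repeated until convergence, can be sketched as a short loop. This is a toy under stated assumptions: each hypothetical client k holds a private target `c_k` and minimises 0.5·‖w − c_k‖² with a few gradient steps, and the server stops once the global weights move by less than a tolerance `tol`, which is one common convergence test among several. The names `run_federation`, `local_train`, and `tol` are illustrative, not taken from any framework.

```python
import numpy as np

def local_train(w, center, epochs=3, lr=0.5):
    """Hypothetical client step: gradient descent on 0.5 * ||w - center||^2."""
    w = w.copy()
    for _ in range(epochs):
        w -= lr * (w - center)  # gradient of the local quadratic loss
    return w

def run_federation(centers, sizes, tol=1e-8, max_rounds=1000):
    """Repeat rounds until the global weights stop moving (or max_rounds is hit)."""
    sizes = np.asarray(sizes, dtype=float)
    w_global = np.zeros_like(centers[0], dtype=float)
    for rnd in range(1, max_rounds + 1):
        # Broadcast w_global; each client trains locally and uploads its new weights.
        local = [local_train(w_global, c) for c in centers]
        # Aggregate with FedAvg: weight each client by its share of the data.
        w_new = sum((n / sizes.sum()) * w for n, w in zip(sizes, local))
        if np.linalg.norm(w_new - w_global) < tol:
            return w_new, rnd
        w_global = w_new
    return w_global, max_rounds

centers = [np.array([1.0, 0.0]), np.array([0.0, 2.0]), np.array([4.0, 4.0])]
w, rounds = run_federation(centers, sizes=[10, 30, 60])
print(rounds, w)
```

Because every client's loss here has the same curvature, the loop converges to the data-weighted mean of the targets, [2.5, 3.0]. Real models trained on heterogeneous data carry no such guarantee, which is one reason round counts vary so widely.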
FedAvg — The Aggregation Math
The default aggregation algorithm is Federated Averaging (FedAvg), introduced by McMahan et al. in 2017. The key idea: don't just average the client models equally — weight each client's contribution by how much data it trained on, so a client with 100 samples influences the result ten times more than one with 10 samples.
With $K$ clients, where client $k$ holds $n_k$ samples and produced local weights $w_k$, and total samples $n = \sum_{k=1}^{K} n_k$:

$$w_{\text{global}} = \sum_{k=1}^{K} \frac{n_k}{n}\, w_k$$

Each client's weight vector is scaled by its data share $n_k/n$, then summed.
That is the entire aggregation step: deceptively simple, remarkably effective.
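A direct translation of the formula helps make the weighting concrete. The sketch below assumes three hypothetical clients holding 10, 30, and 60 samples with made-up weight vectors, and computes the plain unweighted mean alongside for comparison; `fedavg` is an illustrative helper name.

```python
import numpy as np

def fedavg(weights, num_samples):
    """w_global = sum_k (n_k / n) * w_k, the FedAvg aggregation rule."""
    weights = np.asarray(weights, dtype=float)       # shape (K, d): one row per client
    shares = np.asarray(num_samples, dtype=float)
    shares /= shares.sum()                           # n_k / n, sums to 1
    return shares @ weights                          # data-weighted sum over clients

client_weights = [[1.0, 0.0], [2.0, 1.0], [3.0, 2.0]]  # illustrative w_k vectors
client_sizes = [10, 30, 60]                            # illustrative n_k

print(fedavg(client_weights, client_sizes))   # weighted by data share
print(np.mean(client_weights, axis=0))        # naive unweighted mean
```

The weighted result is [2.5, 1.5], while the unweighted mean is [2.0, 1.0]: the 60-sample client pulls the global model toward its own weights, exactly as the $n_k/n$ factor intends.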
Centralized vs Federated
| | Centralized | Federated |
|---|---|---|
| Step 1 | Copy all raw data to server | Send model to data |
| Step 2 | Data leaves user devices | Data stays on device |
| Step 3 | Train one model centrally | Aggregate weight updates |
| Privacy | Weak: data pooled | Strong: data never moves |
| Bandwidth | Huge upfront data transfer | Only weights, repeatedly |
Python Implementation — FedAvg From Scratch
This self-contained simulation builds a small federation of 5 clients with non-IID data, runs local logistic-regression training on each, and aggregates with weighted FedAvg. Watch the global accuracy climb while the server never receives any client's X or y.
```python
import numpy as np

np.random.seed(42)

NUM_CLIENTS = 5
ROUNDS = 10
LOCAL_EPOCHS = 3
LR = 0.1

# Build a federation with non-IID local datasets: clients differ in size AND
# each one's features are shifted by its own random offset, so the local
# feature and label distributions differ from client to client.
def make_client(n, shift):
    X = np.random.randn(n, 4) + shift
    w_true = np.array([1.5, -2.0, 0.8, 1.0])
    y = (X @ w_true + 0.3 * np.random.randn(n) > 0).astype(float)
    return X, y

clients = [
    make_client(np.random.randint(80, 200), np.random.uniform(-1.0, 1.0, size=4))
    for _ in range(NUM_CLIENTS)
]

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# One client trains locally on its OWN data only
def local_update(w, X, y, epochs):
    w = w.copy()
    for _ in range(epochs):
        grad = X.T @ (sigmoid(X @ w) - y) / len(y)  # logistic-loss gradient
        w -= LR * grad
    return w

# Simulation-only monitoring: a real server could not read client data,
# but here we peek at it to report accuracy after each round.
def accuracy(w):
    correct = total = 0
    for X, y in clients:
        correct += ((sigmoid(X @ w) > 0.5) == y).sum()
        total += len(y)
    return correct / total

# Server orchestrates the communication rounds
global_w = np.zeros(4)
for rnd in range(1, ROUNDS + 1):
    updates, sizes = [], []
    for X, y in clients:                  # steps 2-3: broadcast + local train
        updates.append(local_update(global_w, X, y, LOCAL_EPOCHS))
        sizes.append(len(y))
    n = sum(sizes)                        # steps 4-5: weighted FedAvg
    global_w = sum((s / n) * u for u, s in zip(updates, sizes))
    print(f"Round {rnd:2d} | global accuracy: {accuracy(global_w):.3f}")

print("Done: raw data never left any client.")
```
What Actually Crosses the Network
The payloads make the privacy story concrete — weights go out and back, data never does.
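```python
import numpy as np

global_w = np.zeros(4)           # server's current global weights
w_k, n_k = np.zeros(4), 150      # client k's trained weights and sample count (placeholders)

# server -> client (downstream broadcast)
payload_down = {"global_weights": global_w}

# client -> server (upstream update)
payload_up = {"weights": w_k, "num_samples": n_k}

# The raw training data (X, y) is NEVER part of any payload.
```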
In production FL frameworks such as Flower, TensorFlow Federated, and OpenFL, these two payloads appear in some form. The server-side aggregation strategy (typically FedAvg as the baseline) and the client-side local_update are the two pieces you customise; the framework handles the communication scaffolding for you.
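The division of labour can be expressed with plain Python classes: a client turns the global weights into a `(weights, num_samples)` update, and a server-side strategy turns a list of those updates into new global weights. This is a minimal sketch; `Update`, `Client`, `fit`, `FedAvgStrategy`, `aggregate`, and `run_rounds` are hypothetical names chosen for illustration and are not the API of Flower, TensorFlow Federated, or OpenFL.

```python
from dataclasses import dataclass

import numpy as np

@dataclass
class Update:
    """Upstream payload: trained weights plus the sample count used to train them."""
    weights: np.ndarray
    num_samples: int

class Client:
    """Hypothetical client: its private (X, y) never leaves this object."""

    def __init__(self, X, y, lr=0.1, epochs=3):
        self._X, self._y, self.lr, self.epochs = X, y, lr, epochs

    def fit(self, global_weights):
        # Client-side customisation point: local training (logistic regression here).
        w = global_weights.copy()
        for _ in range(self.epochs):
            p = 1.0 / (1.0 + np.exp(-(self._X @ w)))
            grad = self._X.T @ (p - self._y) / len(self._y)
            w -= self.lr * grad
        return Update(weights=w, num_samples=len(self._y))

class FedAvgStrategy:
    """Server-side customisation point: how updates become the next global model."""

    def aggregate(self, updates):
        n = sum(u.num_samples for u in updates)
        return sum((u.num_samples / n) * u.weights for u in updates)

def run_rounds(clients, strategy, dim, rounds):
    global_w = np.zeros(dim)                          # downstream payload
    for _ in range(rounds):
        updates = [c.fit(global_w) for c in clients]  # upstream payloads
        global_w = strategy.aggregate(updates)
    return global_w

rng = np.random.default_rng(0)
clients = []
for n in (50, 120, 80):
    X = rng.normal(size=(n, 3))
    y = (X @ np.array([1.0, -1.0, 0.5]) > 0).astype(float)
    clients.append(Client(X, y))

w = run_rounds(clients, FedAvgStrategy(), dim=3, rounds=20)
print(w)
```

Swapping `FedAvgStrategy` for another class with an `aggregate` method changes the server behaviour without touching any client code, which is the separation the frameworks above are built around.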
Communication Cost — The Real Bottleneck
In FL, computation is cheap (it's spread across thousands of devices) but communication is expensive. Every round ships a full model both ways, often over slow, metered, unreliable mobile links. Reducing rounds and shrinking payloads is where most FL engineering effort goes.
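A back-of-the-envelope calculation shows why communication dominates. The sketch assumes uncompressed float32 parameters (4 bytes each) and that every participating client downloads and uploads the full model once per round; the model size, client counts, and round counts are hypothetical inputs, not measurements, and `fl_traffic_bytes` is an illustrative helper.

```python
MB = 1_000_000          # decimal megabyte
BYTES_PER_PARAM = 4     # assumption: uncompressed float32 weights
PARAMS = 2_500_000      # hypothetical model size

def fl_traffic_bytes(num_params, clients_per_round, rounds, bytes_per_param=BYTES_PER_PARAM):
    """Total bytes moved if every participant downloads and uploads the full model each round."""
    per_round = 2 * num_params * bytes_per_param * clients_per_round  # down + up
    return per_round * rounds

print(f"one client, one round: {2 * PARAMS * BYTES_PER_PARAM / MB:,.0f} MB")
# Baseline: 100 clients per round for 500 rounds (hypothetical).
print(f"baseline:              {fl_traffic_bytes(PARAMS, 100, 500) / MB:,.0f} MB")
# Assume more local epochs let training converge in half as many rounds.
print(f"half the rounds:       {fl_traffic_bytes(PARAMS, 100, 250) / MB:,.0f} MB")
# Sample only 10 clients per round, keeping the round count fixed.
print(f"10 clients per round:  {fl_traffic_bytes(PARAMS, 10, 500) / MB:,.0f} MB")
```

Total traffic is linear in rounds, participants per round, and bytes per parameter, so each technique in the table below attacks one of those factors. In practice the factors interact: sampling fewer clients per round, for example, may require more rounds to reach the same accuracy.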
| Technique | What It Does | Effect on Communication |
|---|---|---|
| More local epochs | Train longer per round before uploading | Fewer rounds needed → less total traffic |
| Client subsampling | Only a fraction of clients join each round | Lower per-round bandwidth |
| Gradient compression / quantization | Send low-precision or sparse updates | Smaller payloads (often 10–100×) |
| Knowledge distillation | Send predictions, not full weights | Much smaller, but adds complexity |
| Naive FedSGD (upload every step) | Communicate after each mini-batch | Extremely chatty — avoid |
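As one concrete instance of the "sparse updates" idea, here is top-k sparsification: the client sends only the k largest-magnitude entries of its update plus their indices, and the server scatters them back into a dense vector. This is a minimal sketch assuming float32 values and int32 indices; `topk_compress` and `topk_decompress` are illustrative names, and the error feedback that practical schemes often add (accumulating the dropped residual locally for the next round) is omitted.

```python
import numpy as np

def topk_compress(update, k):
    """Keep the k entries with the largest magnitude; return (indices, values)."""
    idx = np.argpartition(np.abs(update), -k)[-k:]
    return idx.astype(np.int32), update[idx].astype(np.float32)

def topk_decompress(idx, values, size):
    """Server side: rebuild a dense update, with zeros everywhere else."""
    dense = np.zeros(size, dtype=np.float32)
    dense[idx] = values
    return dense

rng = np.random.default_rng(0)
update = rng.normal(size=100_000).astype(np.float32)  # stand-in for a model update

idx, vals = topk_compress(update, k=1_000)            # keep 1% of entries
restored = topk_decompress(idx, vals, update.size)

dense_bytes = update.nbytes                # 100_000 * 4 bytes
sparse_bytes = idx.nbytes + vals.nbytes    # 1_000 * (4 + 4) bytes
print(f"compression: {dense_bytes / sparse_bytes:.0f}x")
print("kept entries match:", np.array_equal(restored[idx], update[idx]))
```

Keeping 1% of the entries yields 50x rather than 100x here because every kept value also carries a 4-byte index.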
In Snips' federated wake-word study, upstream communication was estimated at roughly 8 MB per user to reach target accuracy — reasonable for a smart-home device. The headline finding: an adaptive averaging strategy cut the number of communication rounds dramatically, which matters far more than raw compute.
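The study's exact method is not reproduced here, but the general idea of adaptive server-side averaging can be sketched in the style of FedAdam: the server treats the data-weighted average of (client weights minus global weights) as a pseudo-gradient and feeds it to an Adam-like update. The class name `FedAdamServer`, the toy quadratic clients, and the hyperparameters (`server_lr`, `beta1`, `beta2`, `tau`) are illustrative assumptions, not tuned values, and the sketch omits Adam's bias correction.

```python
import numpy as np

class FedAdamServer:
    """Sketch of an Adam-style server optimizer applied to FedAvg pseudo-gradients."""

    def __init__(self, w_init, server_lr=0.1, beta1=0.9, beta2=0.99, tau=1e-3):
        self.w = np.asarray(w_init, dtype=float).copy()
        self.m = np.zeros_like(self.w)  # running mean of pseudo-gradients
        self.v = np.zeros_like(self.w)  # running mean of squared pseudo-gradients
        self.lr, self.beta1, self.beta2, self.tau = server_lr, beta1, beta2, tau

    def step(self, client_weights, num_samples):
        shares = np.asarray(num_samples, dtype=float)
        shares /= shares.sum()  # n_k / n, as in FedAvg
        # Pseudo-gradient: data-weighted average of how far each client moved (w_k - w).
        delta = shares @ (np.asarray(client_weights, dtype=float) - self.w)
        self.m = self.beta1 * self.m + (1 - self.beta1) * delta
        self.v = self.beta2 * self.v + (1 - self.beta2) * delta**2
        # Per-coordinate adaptive step; no bias correction in this sketch.
        self.w = self.w + self.lr * self.m / (np.sqrt(self.v) + self.tau)
        return self.w

def local_train(w, center, epochs=3, lr=0.5):
    """Hypothetical client: a few GD steps on 0.5 * ||w - center||^2."""
    for _ in range(epochs):
        w = w - lr * (w - center)
    return w

centers = np.array([[1.0, 0.0], [0.0, 2.0], [4.0, 4.0]])  # illustrative client optima
sizes = [10, 30, 60]
server = FedAdamServer(w_init=np.zeros(2))
for _ in range(300):
    local = [local_train(server.w, c) for c in centers]
    server.step(local, sizes)
print(server.w)  # moves toward the data-weighted mean of the centers, [2.5, 3.0]
```

Replacing the last update line in `step` with `self.w = self.w + delta` recovers plain FedAvg, since w + Σ(n_k/n)(w_k − w) = Σ(n_k/n)w_k; the adaptive version only changes how far the server moves along that same direction.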
Where Federated Learning Shines
The Three Hard Problems
Keeping raw data on-device is necessary but not sufficient, because model updates can still leak information about the data they were trained on. Secure aggregation ensures the server only ever sees the sum of updates (never any single client's), and differential privacy adds calibrated noise so no individual's contribution can be singled out. Privacy-sensitive production deployments typically combine all three safeguards.
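The core trick behind secure aggregation, and the clip-then-noise step behind differentially private FL, both fit in a few lines. This is a toy sketch under loud assumptions: the pairwise masks come from one shared random generator instead of cryptographic key agreement, no client drops out mid-round, noise is added on each client rather than once to the aggregate, and the noise scale is arbitrary rather than calibrated to a privacy budget. Real deployments must address all of these. `pairwise_masks` and `clip_and_noise` are illustrative names.

```python
import numpy as np

def pairwise_masks(num_clients, dim, rng):
    """Toy secure-aggregation masks. For each pair i < j, client i adds a shared
    random vector and client j subtracts it, so every mask cancels in the sum.
    Assumption: one trusted RNG stands in for pairwise key agreement."""
    masks = np.zeros((num_clients, dim))
    for i in range(num_clients):
        for j in range(i + 1, num_clients):
            shared = rng.normal(scale=100.0, size=dim)
            masks[i] += shared
            masks[j] -= shared
    return masks

def clip_and_noise(update, clip_norm, noise_std, rng):
    """DP-style client step: bound the update's L2 norm, then add Gaussian noise.
    noise_std is illustrative only, not derived from an (epsilon, delta) budget."""
    norm = np.linalg.norm(update)
    if norm > clip_norm:
        update = update * (clip_norm / norm)
    return update + rng.normal(scale=noise_std, size=update.shape)

rng = np.random.default_rng(0)
updates = rng.normal(size=(4, 3))  # 4 clients, 3 parameters each

private = np.array([clip_and_noise(u, clip_norm=1.0, noise_std=0.1, rng=rng) for u in updates])
masked = private + pairwise_masks(4, 3, rng)  # what each client actually uploads

# The server sees only masked vectors, yet their sum equals the sum of private updates.
print(np.allclose(masked.sum(axis=0), private.sum(axis=0)))
```

Each masked upload looks like large random noise on its own, yet the server still recovers the exact sum it needs for aggregation.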
FedAvg vs Its Successors
| Algorithm | Key Idea | Best For |
|---|---|---|
| FedAvg | Weighted average of client weights | The default baseline; near-IID data |
| FedProx | Adds a proximal term keeping locals near the global | Non-IID data & stragglers |
| SCAFFOLD | Uses control variates to correct client drift | Heavy heterogeneity, fewer rounds |
| FedNova | Normalizes for differing local step counts | Clients doing unequal local work |
| FedAdam / FedYogi | Adaptive optimizer on the server side | Faster, more stable convergence |
Start with FedAvg — it is the universally supported baseline and works well when client data is reasonably similar. Only reach for FedProx, SCAFFOLD, or a server-side adaptive optimizer once you actually observe slow or unstable convergence caused by non-IID data. Don't pay the complexity tax until your data demands it.
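FedProx changes only the client side: the local objective gains a proximal term (μ/2)·‖w − w_global‖², whose gradient μ·(w − w_global) pulls local training back toward the global model. A minimal sketch for one logistic-regression client follows; `fedprox_local_update` and the value `mu=1.0` are illustrative assumptions, the skewed client data is synthetic, and setting `mu=0` recovers plain FedAvg local training.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def fedprox_local_update(w_global, X, y, epochs=5, lr=0.1, mu=0.1):
    """Local gradient descent on: logistic loss + (mu / 2) * ||w - w_global||^2."""
    w = w_global.copy()
    for _ in range(epochs):
        grad = X.T @ (sigmoid(X @ w) - y) / len(y)  # ordinary logistic-loss gradient
        grad += mu * (w - w_global)                 # proximal term: stay near the global model
        w -= lr * grad
    return w

rng = np.random.default_rng(1)
X = rng.normal(loc=2.0, size=(100, 3))              # a skewed (non-IID) client
y = (X @ np.array([1.0, -1.0, 0.5]) > 0).astype(float)
w_global = np.zeros(3)

drift_fedavg = np.linalg.norm(fedprox_local_update(w_global, X, y, epochs=50, mu=0.0) - w_global)
drift_fedprox = np.linalg.norm(fedprox_local_update(w_global, X, y, epochs=50, mu=1.0) - w_global)
print(f"local drift  mu=0: {drift_fedavg:.3f}   mu=1: {drift_fedprox:.3f}")
```

Larger `mu` limits how far a skewed client can drag its local model away from the global one in a round, at the cost of slower local progress; in practice `mu` is tuned per task.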
When To Use It — And When Not To
Golden Rules
Weight each client's update by its data share ($n_k/n$). Plain unweighted averaging lets small, noisy clients distort the global model out of proportion.