
Hands-On: FedAvg from Scratch in Python

A build-it-yourself walkthrough that implements Federated Averaging from scratch in pure NumPy, with no PyTorch, TensorFlow, or FL library. It covers a pipeline diagram, five runnable steps (simulating clients with uneven data sizes, a logistic-regression model, the local-SGD client update, weighted aggregation, and the training loop), measured output, a loss-convergence chart, a hyperparameter table, and a checklist of common bugs.

Section 01

The Story That Frames This Build

Build the Engine, Don't Just Drive the Car
Anyone can call a library function named federated_train() and watch numbers scroll by. But you only truly understand an engine once you've bolted it together yourself — pistons, crankshaft, spark.

FedAvg is a surprisingly small engine. Strip away the frameworks and it is just four moving parts: split the data across clients, train locally, average the weights, and repeat. In this hands-on build we'll write every part from scratch in plain NumPy — no PyTorch, no TensorFlow, no FL library — train a real classifier across simulated clients, and watch the global loss fall round after round. By the end you'll have a complete, runnable FedAvg you wrote line by line.
🔧
What You Need

Only NumPy. We'll train a logistic-regression classifier (the simplest model with a clean gradient) across five simulated clients, so everything stays readable and runs in under a second.


Section 02

The Four Parts We're Building

Here is the whole engine as a pipeline. The data is created once. After that, every round runs client_update once per selected client, followed by a single fedavg aggregation and one evaluation, and the loop repeats for many rounds.

⚙️ Figure: the FedAvg build pipeline. 1. make_federated_data → 2. client_update (local SGD; runs for each selected client) → 3. fedavg (weighted avg) → 4. evaluate (loss / acc), then repeat × rounds.

We'll write each box as one small function, then wire them together in a loop.


Section 03

Step 1 — Create Federated Data

Real federated data is usually non-IID: each client sees its own slice of the world. This simulation keeps things simpler. Its five clients differ only in how many samples they hold (quantity skew); every client draws its features from the same distribution and its labels from the same underlying rule (a logistic model with a hidden true weight vector, including a bias term).

import numpy as np

def make_federated_data(n_clients=5, seed=0):
    """Simulate n_clients, each holding its own private (X, y)."""
    rng    = np.random.default_rng(seed)
    true_w = np.array([1.5, -2.0, 0.8, 0.0])      # last entry = bias
    clients = []
    for k in range(n_clients):
        size = rng.integers(150, 400)           # uneven sizes → weighting matters
        X    = rng.standard_normal((size, 3))
        Xb   = np.hstack([X, np.ones((size, 1))]) # augment a 1s column for bias
        p    = 1 / (1 + np.exp(-(Xb @ true_w)))     # true probabilities
        y    = (rng.random(size) < p).astype(float) # sampled 0/1 labels
        clients.append((Xb, y))
    return clients, true_w
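The function above varies only how many samples each client holds; features and labels follow the same distribution everywhere. To stress-test the build against a harsher, label-skewed split, one simple option is to sort a pooled dataset by label and deal out contiguous shards. The sketch below is illustrative only: split_by_label and the toy data are hypothetical and not part of the walkthrough's code.

```python
import numpy as np

def split_by_label(X, y, n_clients=5):
    """Hypothetical helper: sort samples by label, then cut into contiguous shards.

    Early shards are dominated by label 0 and late shards by label 1,
    so each client ends up with a very different label mix.
    """
    order  = np.argsort(y, kind="stable")        # all 0-labelled rows first, then the 1s
    shards = np.array_split(order, n_clients)    # near-equal contiguous index blocks
    return [(X[idx], y[idx]) for idx in shards]

rng = np.random.default_rng(0)
X = rng.standard_normal((1000, 3))
y = (rng.random(1000) < 0.5).astype(float)       # roughly balanced toy labels

clients   = split_by_label(X, y)
pos_rates = [float(c[1].mean()) for c in clients]  # fraction of 1-labels per client
print(pos_rates)
```

The first client sees only zeros and the last only ones, which is far more extreme than most real deployments but makes client drift easy to observe.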
💡
The Bias Trick

Appending a column of ones to X lets the last weight act as the bias term — so we never need a separate bias variable. One vector holds everything.
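A quick way to convince yourself the trick is exact is to compute the logits both ways and compare. This is a minimal sketch; the coefficient and bias values are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 3))                 # 4 samples, 3 features
coef, bias = np.array([1.5, -2.0, 0.8]), 0.3    # illustrative values only

# Explicit form: a separate bias variable
z_explicit = X @ coef + bias

# Augmented form: append a ones column and fold the bias into the weight vector
Xb = np.hstack([X, np.ones((len(X), 1))])
w  = np.append(coef, bias)
z_augmented = Xb @ w

print(np.allclose(z_explicit, z_augmented))
```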


Section 04

Step 2 — The Model: Predict, Gradient, Score

Logistic regression in four one-line functions plus a short loss function. The gradient of the mean cross-entropy loss is famously clean: Xᵀ(σ(Xw) − y) / n.

def sigmoid(z):    return 1 / (1 + np.exp(-z))
def predict(w, X): return sigmoid(X @ w)
def grad(w, X, y): return X.T @ (predict(w, X) - y) / len(y)
def accuracy(w, X, y): return np.mean((predict(w, X) > 0.5) == y)

def bce(w, X, y):                         # binary cross-entropy loss
    p = np.clip(predict(w, X), 1e-7, 1 - 1e-7)
    return -np.mean(y * np.log(p) + (1 - y) * np.log(1 - p))
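Before building on grad, it can be worth checking it against a numerical derivative of bce. The sketch below repeats the model functions so it runs on its own; numerical_grad, its eps value, and the random test data are assumptions added for this check.

```python
import numpy as np

def sigmoid(z):    return 1 / (1 + np.exp(-z))
def predict(w, X): return sigmoid(X @ w)
def grad(w, X, y): return X.T @ (predict(w, X) - y) / len(y)

def bce(w, X, y):
    p = np.clip(predict(w, X), 1e-7, 1 - 1e-7)
    return -np.mean(y * np.log(p) + (1 - y) * np.log(1 - p))

def numerical_grad(w, X, y, eps=1e-6):
    """Central finite differences of bce, one coordinate at a time."""
    g = np.zeros_like(w)
    for i in range(len(w)):
        step = np.zeros_like(w)
        step[i] = eps
        g[i] = (bce(w + step, X, y) - bce(w - step, X, y)) / (2 * eps)
    return g

rng = np.random.default_rng(0)
X = rng.standard_normal((50, 4))
y = (rng.random(50) < 0.5).astype(float)
w = rng.standard_normal(4) * 0.1     # small weights keep p well away from the clip

max_gap = np.max(np.abs(grad(w, X, y) - numerical_grad(w, X, y)))
print(max_gap)                       # should be tiny if the analytic gradient is right
```

A gap many orders of magnitude below the gradient's own size suggests the formula and the code agree; a gap near the gradient's size usually points to a sign error or a missing 1/n.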

Section 05

Step 3 — ClientUpdate (Local SGD)

This is what runs on each selected client: start from the global weights, run a few local epochs of mini-batch SGD on the client's own data, and return the updated weights plus the sample count.

def client_update(w, X, y, lr=0.1, epochs=2, batch=64, seed=0):
    """Local SGD on ONE client. Returns (new_weights, n_samples)."""
    w   = w.copy()                       # ★ never mutate the global model ★
    n   = len(y)
    rng = np.random.default_rng(seed)
    for _ in range(epochs):              # E local epochs
        idx = rng.permutation(n)
        for s in range(0, n, batch):      # mini-batches of size B
            b = idx[s:s + batch]
            w = w - lr * grad(w, X[b], y[b])  # the weight update
    return w, n
⚠️
The #1 Beginner Bug

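Forgetting w = w.copy() once the update is written in place. The line w = w - lr * grad(...) used above happens to build a new array on every step, but as soon as it is shortened to w -= lr * grad(...), local training writes straight into the shared global array, every client corrupts the others, and your averaging becomes meaningless. Always copy first.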
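Whether the caller's array changes depends on how the update is written. The sketch below uses three hypothetical step functions (not part of the walkthrough's code) and a stand-in gradient to contrast rebinding, updating in place, and updating in place on a private copy.

```python
import numpy as np

def step_rebind(w, g, lr=0.1):
    w = w - lr * g        # builds a NEW array; the caller's array is untouched
    return w

def step_in_place(w, g, lr=0.1):
    w -= lr * g           # writes into the SAME array the caller passed in
    return w

def step_in_place_safe(w, g, lr=0.1):
    w = w.copy()          # private copy first, so the in-place write is harmless
    w -= lr * g
    return w

g = np.ones(3)            # stand-in gradient, purely illustrative
global_w = np.zeros(3)

step_rebind(global_w, g)
after_rebind = global_w.copy()       # still all zeros

step_in_place_safe(global_w, g)
after_safe = global_w.copy()         # still all zeros

step_in_place(global_w, g)
after_in_place = global_w.copy()     # now -0.1 everywhere: the global model changed
print(after_rebind, after_safe, after_in_place)
```

Copying at the top of the function costs one small allocation and makes the function safe regardless of which update style is used later.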


Section 06

Step 4 — The Server: Weighted Average

The whole server-side aggregation is a single weighted sum: each client's weights scaled by its share of the samples held by the clients that reported this round.

def fedavg(weights, sizes):
    """Weighted average of client weight vectors:  w = Σ (n_k / n) · w_k."""
    n = sum(sizes)
    return sum((nk / n) * wk for wk, nk in zip(weights, sizes))
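A two-client worked example makes the effect of the weighting visible. The weight vectors and sample counts below are made-up numbers chosen so the arithmetic is easy to follow by hand.

```python
import numpy as np

def fedavg(weights, sizes):
    """Weighted average of client weight vectors, as defined above."""
    n = sum(sizes)
    return sum((nk / n) * wk for wk, nk in zip(weights, sizes))

# Two toy clients: the second holds three times as much data
weights = [np.array([1.0, 1.0]), np.array([4.0, 4.0])]
sizes   = [100, 300]

weighted   = fedavg(weights, sizes)       # 0.25 * 1 + 0.75 * 4 = 3.25
unweighted = np.mean(weights, axis=0)     # (1 + 4) / 2 = 2.5
print(weighted, unweighted)
```

The weighted result sits closer to the larger client's model, which is the behaviour the sample counts are there to produce.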

Section 07

Step 5 — Wire It Together: the Training Loop

Now the four parts meet. Each round: select a fraction of clients, broadcast, collect their updates, average, and measure the global loss.

def federated_train(clients, d, rounds=12, frac=0.6, seed=0):
    rng  = np.random.default_rng(seed)
    w    = np.zeros(d)                    # global model starts blank
    Xall = np.vstack([c[0] for c in clients])  # pooled set, for evaluation only
    yall = np.concatenate([c[1] for c in clients])

    for t in range(rounds):              # each t = one communication round
        m   = max(1, int(frac * len(clients)))  # cohort size
        sel = rng.choice(len(clients), m, replace=False)

        W, S = [], []
        for k in sel:                     # broadcast + local train
            wk, nk = client_update(w, *clients[k], seed=t * 10 + int(k))
            W.append(wk); S.append(nk)

        w = fedavg(W, S)                  # aggregate into new global model
        print(f"Round {t+1:2d}  |  loss={bce(w, Xall, yall):.3f}  |  acc={accuracy(w, Xall, yall):.3f}")
    return w


# ─── Run the whole thing ─────────────────────────────────────
clients, true_w = make_federated_data()
d = clients[0][0].shape[1]
w = federated_train(clients, d, rounds=12)
print("Learned:", np.round(w, 2))
print("True   :", true_w)
OUTPUT (excerpt; rounds 5, 7, 9, and 11 are not shown)
Round  1  |  loss=0.607  |  acc=0.819
Round  2  |  loss=0.544  |  acc=0.821
Round  3  |  loss=0.507  |  acc=0.822
Round  4  |  loss=0.483  |  acc=0.823
Round  6  |  loss=0.447  |  acc=0.821
Round  8  |  loss=0.431  |  acc=0.820
Round 10  |  loss=0.419  |  acc=0.817
Round 12  |  loss=0.411  |  acc=0.819
Learned: [ 0.85 -1.26 0.53 0.02]
True   : [ 1.5 -2. 0.8 0. ]
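The cohort rule inside the loop, max(1, int(frac * len(clients))), rounds down but never drops below one client. The sketch below isolates that rule in a helper named cohort_size (a hypothetical name, not in the walkthrough's code) and draws one cohort the same way the loop does.

```python
import numpy as np

def cohort_size(n_clients, frac):
    """Same rule as the training loop: round down, but keep at least one client."""
    return max(1, int(frac * n_clients))

# Cohort sizes for five clients at a few fractions
sizes = {frac: cohort_size(5, frac) for frac in (0.1, 0.5, 0.6, 1.0)}
print(sizes)

rng = np.random.default_rng(0)
sel = rng.choice(5, cohort_size(5, 0.6), replace=False)   # distinct client indices
print(sorted(sel.tolist()))
```

With five clients, the default frac=0.6 selects three per round, and any frac below 0.2 still selects one.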

Section 08

Watching It Converge

The global loss falls smoothly every round — each cycle of local training plus averaging nudges the shared model closer to the truth, without any client's raw data ever leaving home.

📉 Figure: global BCE loss across rounds. x-axis: round; y-axis: BCE loss (ticks 0.60, 0.50, 0.42). Annotations: steep early drop, flattening. Legend: global loss per round; latest round.

From 0.607 to 0.411 in twelve rounds. The curve mirrors the printed output above.


Section 09

The Knobs You Just Built

| Argument | Role | Try lowering | Try raising |
|----------|------|--------------|-------------|
| rounds | Communication rounds | Underfit | More convergence (to a point) |
| frac | Client fraction C | Noisier averages | Steadier, more traffic |
| epochs | Local epochs E | Less progress per round | Client drift on non-IID |
| lr | Learning rate η | Crawls | May diverge |
| batch | Mini-batch size B | Jumpier updates | Smoother, fewer steps |
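The epochs and batch knobs together set how many SGD updates a client performs in one round: the inner loops in client_update run epochs × ceil(n / batch) times. The helper below, local_steps, is a hypothetical name added here to make that arithmetic concrete.

```python
import math

def local_steps(n_samples, epochs=2, batch=64):
    """SGD updates one client performs per round: epochs x ceil(n / batch)."""
    return epochs * math.ceil(n_samples / batch)

# 150 and 399 are the smallest and largest sizes the simulation can draw
for n in (150, 399):
    print(n, local_steps(n), local_steps(n, batch=16), local_steps(n, epochs=5))
```

With the defaults, the smallest possible client takes 6 steps per round and the largest takes 14, so larger clients also move further from the global model before averaging.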

Section 10

Golden Rules for Building FedAvg

🔧 FedAvg From Scratch — Hard-Won Rules
1. Always w = w.copy() at the top of client_update. Any in-place update (such as w -= ...) would otherwise write into the global array, which is one of the most common and most confusing bugs.
2. Return both weights and the sample count from each client: the count is what makes the average correctly weighted.
3. The pooled set (Xall, yall) is only for evaluation. Never let client_update touch it, because that would defeat the entire point of FL.
4. Track loss, not just accuracy. Accuracy can plateau early while the weights, and the loss, are still meaningfully improving.
5. If the loss explodes or turns into NaN, your lr is most likely too high. Halve it. If it barely moves, raise it or add rounds.
6. Seed your RNGs so runs are reproducible while you debug; randomise later for honest evaluation.
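Rule 4 can be seen in isolation: scaling a weight vector by a positive constant leaves every predicted class unchanged, so accuracy stays the same while the loss still moves. The sketch below repeats the model functions so it runs on its own; the 0.25 scale factor and the 5000-sample dataset are arbitrary choices for illustration.

```python
import numpy as np

def sigmoid(z):        return 1 / (1 + np.exp(-z))
def predict(w, X):     return sigmoid(X @ w)
def accuracy(w, X, y): return np.mean((predict(w, X) > 0.5) == y)

def bce(w, X, y):
    p = np.clip(predict(w, X), 1e-7, 1 - 1e-7)
    return -np.mean(y * np.log(p) + (1 - y) * np.log(1 - p))

rng    = np.random.default_rng(0)
true_w = np.array([1.5, -2.0, 0.8, 0.0])
X      = np.hstack([rng.standard_normal((5000, 3)), np.ones((5000, 1))])
y      = (rng.random(5000) < predict(true_w, X)).astype(float)

w_small = 0.25 * true_w      # right direction, under-confident magnitude
acc_small, acc_full   = accuracy(w_small, X, y), accuracy(true_w, X, y)
loss_small, loss_full = bce(w_small, X, y), bce(true_w, X, y)
print(acc_small, acc_full, loss_small, loss_full)
```

Both weight vectors classify every sample identically, yet the under-scaled one pays a noticeably higher loss, which is why a flat accuracy curve does not mean training has finished.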

In one line: FedAvg from scratch is four small NumPy functions — make data, train locally, weighted-average, loop — and once you've written them yourself, every federated learning framework you ever touch will feel transparent.
