
Local SGD and Weight Updates in Federated Learning

A visual, hands-on guide to the client-side inner loop of federated learning. Using a foggy-mountain analogy, animated diagrams, a one-step-by-hand example, color-coded Python, and comparison tables, it explains the w ← w − η∇ℓ weight-update rule, full-batch vs stochastic vs mini-batch SGD, how local steps and epochs accumulate per round, the three training dials (learning rate, batch size, epochs), and how local steps drive client drift.

Section 01

The Story That Explains Local SGD

Descending a Foggy Mountain, One Felt Step at a Time
You are stranded near the top of a mountain in thick fog. You cannot see the valley below, but you desperately want to reach the lowest point. You have one tool: you can feel the slope of the ground right under your feet.

So you adopt a simple rule: feel which way the ground tilts downhill, then take one step that way. Feel again. Step again. The steeper the slope, the more confident your stride; near the flat valley floor, your steps shrink to almost nothing. Step by step, you descend — never seeing the whole mountain, only the patch beneath you.

That is exactly Stochastic Gradient Descent (SGD). The "mountain" is the model's loss; the "slope" is the gradient; each "step" is a weight update. And the word local is the federated twist: in federated learning, every client walks down its own private mountain using its own data, for several steps, before reporting its new position to the server. Understanding this single inner loop — local SGD and its weight updates — is understanding what every client is really doing.
Section 02

The Heart of It — the Weight Update Rule

Every step of local training is one application of the same tiny formula. Master this and you understand SGD.

The SGD weight update
w ← w − η · ∇ℓ(w; batch)
New weights = old weights minus learning rate η times the gradient of the loss on the current mini-batch.
What each piece means
η = step size  ·  ∇ℓ = slope
The gradient points uphill; the minus sign turns us downhill. η controls how big a step we take.

The minus sign is the whole trick: the gradient ∇ℓ points in the direction of steepest increase of the loss, so stepping the opposite way reduces it. The animation shows a marker descending a loss curve, one step at a time.

[Figure: animated gradient descent on a loss curve ℓ(w) plotted against the weight w. The walker marks the current weights, each step is −η∇ℓ, and the minimum marks the best weights.]

Steps are large where the slope is steep and shrink near the flat minimum — that's the gradient doing the steering.
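The shrinking steps can be checked numerically. The sketch below assumes a toy one-dimensional loss ℓ(w) = ½(w − 3)², whose gradient is simply w − 3; the loss, the starting point, and the names `grad` and `descend` are illustrative choices, not part of the original walkthrough.

```python
def grad(w, w_star=3.0):
    """Slope of the toy loss l(w) = 0.5 * (w - w_star)**2 (an assumed example loss)."""
    return w - w_star


def descend(w, lr=0.3, num_steps=5):
    """Apply w <- w - lr * grad(w) repeatedly and record the size of each step."""
    step_sizes = []
    for _ in range(num_steps):
        step = -lr * grad(w)      # the minus sign turns "uphill" into "downhill"
        w = w + step
        step_sizes.append(abs(step))
    return w, step_sizes


w_final, step_sizes = descend(w=0.0)
print("step sizes:", [round(s, 4) for s in step_sizes])
print("final w   :", round(w_final, 4))
```

Each recorded step is smaller than the one before it, because the gradient itself shrinks as w approaches the minimum at 3.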

Section 03

Three Flavours: Full-Batch, Stochastic, Mini-Batch

They share the same update rule; they differ only in how much data they look at before each step.

| Flavour | Data per step | Steps per epoch | Character |
|---|---|---|---|
| Full-batch GD | All n samples | 1 | Smooth but slow and memory-hungry |
| Stochastic (SGD) | 1 sample | n | Fast, very noisy, jumpy path |
| Mini-batch SGD | B samples | n / B | The practical sweet spot; what FL uses |

In federated learning, each client runs mini-batch SGD on its local data. The mini-batch makes each step cheap and adds just enough noise to escape bad spots, without the wild jumpiness of pure single-sample SGD.
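To make the "data per step" difference concrete, here is a small sketch that counts how many batches one epoch produces for each flavour. The helper `iter_batches`, the dataset size of 1000, and the batch size of 50 are assumptions chosen for illustration.

```python
import numpy as np


def iter_batches(n, batch_size, rng):
    """Yield shuffled index batches that together cover all n samples exactly once."""
    idx = rng.permutation(n)
    for start in range(0, n, batch_size):
        yield idx[start:start + batch_size]


rng = np.random.default_rng(0)
n = 1000  # hypothetical local dataset size

flavours = {
    "full-batch GD": n,     # every sample in one step
    "stochastic SGD": 1,    # one sample per step
    "mini-batch SGD": 50,   # B samples per step
}

# One weight update happens per batch, so the batch count is the step count.
counts = {name: sum(1 for _ in iter_batches(n, b, rng)) for name, b in flavours.items()}
for name, steps in counts.items():
    print(f"{name:15s} -> {steps} step(s) per epoch")
```

The only thing that changes between the three flavours is the batch size passed in; the update rule applied to each batch stays the same.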

Section 04

How Many Updates Happen Locally? Steps vs Epochs

Inside one federated round, a client doesn't take just one step; it takes many. The count is set by the local epochs E, the batch size B, and the client's data size n_k.

Local updates per round
steps = E · ⌈n_k / B⌉
Epochs × number of mini-batches per epoch (the ceiling counts a final partial batch). Each step is one weight update.
Example
E = 5, n_k = 1000, B = 50 → 100 steps
5 × ⌈1000 / 50⌉ = 5 × 20 = 100 weight updates on this client before it reports back.
The nested loops inside one round
Round
Client receives the global weights w once.
For each epoch
Shuffle local data, split into mini-batches of size B. (Repeat E times.)
For each batch
Compute gradient ∇ℓ on that batch, then update: w ← w − η∇ℓ. One step.
After E epochs
Send the final w back. Those 100 little nudges add up to one big local improvement.
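The step count can be computed directly. This is a minimal sketch; the function name `local_steps` and the second example (a dataset that does not divide evenly into batches) are illustrative additions.

```python
import math


def local_steps(epochs, n_k, batch_size):
    """Weight updates one client performs in a round: E * ceil(n_k / B)."""
    return epochs * math.ceil(n_k / batch_size)


print(local_steps(epochs=5, n_k=1000, batch_size=50))  # 5 * 20 = 100
print(local_steps(epochs=5, n_k=1010, batch_size=50))  # 5 * 21 = 105 (last batch holds 10 samples)
```

The ceiling matters whenever n_k is not a multiple of B: the leftover samples still form one smaller batch and trigger one more update per epoch.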

The catch: more local steps means the client travels further down its own mountain — which may be a different mountain from its neighbours when data is non-IID. This is client drift, the price of doing lots of local work between communication rounds.
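A toy picture of drift, under strong simplifying assumptions: two hypothetical clients each minimise a one-dimensional loss ½(w − c_k)² with a different minimum c_k, which stands in for non-IID data. Both start from the same global weight, and the gap between their final weights is measured for different numbers of local steps. All names and numbers here are illustrative.

```python
def local_train(w, optimum, lr, num_steps):
    """Run num_steps of w <- w - lr * (w - optimum) on one client's toy loss."""
    for _ in range(num_steps):
        w = w - lr * (w - optimum)
    return w


w_global = 0.0
optima = {"client_a": 2.0, "client_b": -1.0}  # different "mountains" (assumed values)

gaps = []
for num_steps in (1, 10, 100):
    finals = {name: local_train(w_global, c, lr=0.1, num_steps=num_steps)
              for name, c in optima.items()}
    gap = abs(finals["client_a"] - finals["client_b"])
    gaps.append(gap)
    print(f"{num_steps:3d} local steps -> gap between clients = {gap:.3f}")
```

With one local step the two clients barely disagree; with many, each has nearly reached its own minimum and the gap approaches the full distance between the two optima.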

Section 05

A Worked Example — One Weight Update by Hand

Take the simplest model: predict ŷ = w · x with a single weight, using the squared-error loss ℓ = ½(w·x − y)², where the ½ cancels the 2 that appears on differentiation. The gradient of the loss for one sample is then ∇ℓ = (w·x − y) · x. Suppose w = 0.50, learning rate η = 0.10, and a sample (x = 2, y = 3).

Computing one step
Predict
ŷ = w·x = 0.50 × 2 = 1.0
Error
ŷ − y = 1.0 − 3.0 = −2.0
Gradient
∇ℓ = (ŷ − y)·x = (−2.0) × 2 = −4.0
Update
w ← w − η∇ℓ = 0.50 − 0.10×(−4.0) = 0.50 + 0.40 = 0.90
Check
New prediction: 0.90 × 2 = 1.8, closer to the target 3.0 than 1.0 was. The step worked.

The number line below shows that single update: the weight slides from 0.50 toward 0.90, nudged by −η∇ℓ.

[Figure: animated number line from 0.4 to 0.9. The weight moves from old w = 0.50 to new w = 0.90, a shift of −η∇ℓ = +0.40.]

The marker is pulled toward the target each step; repeat this 100 times locally and the client has meaningfully improved its model.
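The hand calculation can be reproduced in a few lines. The function name `sgd_step` is an illustrative choice; the numbers are the ones used in the worked example.

```python
def sgd_step(w, x, y, lr):
    """One update for the model y_hat = w * x with loss 0.5 * (w*x - y)**2."""
    y_hat = w * x             # predict
    error = y_hat - y         # error
    grad = error * x          # gradient of the loss with respect to w
    w_new = w - lr * grad     # update
    return y_hat, error, grad, w_new


y_hat, error, grad, w_new = sgd_step(w=0.50, x=2.0, y=3.0, lr=0.10)
print(f"prediction = {y_hat:.2f}, error = {error:.2f}, gradient = {grad:.2f}")
print(f"new w = {w_new:.2f}, new prediction = {w_new * 2.0:.2f}")
```

The printed values match the steps above: prediction 1.00, error −2.00, gradient −4.00, new weight 0.90, new prediction 1.80.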

Section 06

Local SGD in Python — the Client's Inner Loop

This is exactly what runs on a client during one federated round: shuffle, batch, compute the gradient, apply the w ← w − η∇ℓ update, repeat for E epochs.

import numpy as np

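def mse_grad(w, Xb, yb):
    """Gradient of half the mean squared error for a linear model on one batch.

    Differentiating 0.5 * mean((Xb @ w - yb)**2) gives Xb.T @ resid / len(yb),
    the batch version of the one-sample gradient (w*x - y) * x.
    """
    resid = Xb @ w - yb                  # prediction error
    return Xb.T @ resid / len(yb)       # ∇ℓ  (averaged over batch)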


def local_sgd(w, X, y, lr=0.05, epochs=5, batch=32):
    """Run mini-batch SGD on ONE client's local data for E epochs."""
    w = w.copy()                         # start from the global weights
    n = len(y)
    steps = 0

    for ep in range(epochs):             # ── E local epochs ──
        idx = np.random.permutation(n)   # reshuffle each epoch
        for s in range(0, n, batch):      # ── mini-batches ──
            b      = idx[s:s + batch]
            grad   = mse_grad(w, X[b], y[b])
            w      = w - lr * grad        # ★ THE WEIGHT UPDATE ★
            steps += 1

        loss = np.mean((X @ w - y) ** 2)  # full MSE on all local data, for monitoring only
        print(f"  epoch {ep+1}: loss={loss:.4f}")

    print(f"  total local steps this round = {steps}")
    return w


# ─── Demo: one client's local training ───────────────────────
np.random.seed(0)
true_w = np.array([2.0, -1.0, 0.5])
X = np.random.randn(200, 3)
y = X @ true_w + 0.05 * np.random.randn(200)

w0 = np.zeros(3)                        # received global model
print("Local SGD on client:")
w_local = local_sgd(w0, X, y, lr=0.1, epochs=5, batch=50)

print("Learned:", np.round(w_local, 3))
print("True   :", true_w)
▶ Output (illustrative)
Local SGD on client:
  epoch 1: loss=1.0237
  epoch 2: loss=0.2841
  epoch 3: loss=0.0913
  epoch 4: loss=0.0331
  epoch 5: loss=0.0148
  total local steps this round = 20
Learned: [ 1.962 -0.981  0.488]
True   : [ 2.  -1.   0.5]

With E=5, n=200, B=50, the client takes 5 × 4 = 20 weight updates in this round — matching the printed step count — and the loss falls steadily as the weights converge toward the truth.

Section 07

The Three Dials of Local Training

| Dial | Symbol | Too small | Too large |
|---|---|---|---|
| Learning rate | η | Painfully slow descent | Overshoots and diverges (loss explodes) |
| Batch size | B | Noisy, jumpy updates | Smoother but fewer steps per epoch |
| Local epochs | E | Little local progress → more rounds | More client drift on non-IID data |

These three together set how far each client travels per round. Tuning them is the difference between a model that converges in 20 rounds and one that never converges at all.
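The learning-rate row of the table can be seen on the one-weight model ŷ = w · x with the sample (x = 2, y = 3). The sketch below only varies η; the three values tried and the helper name `final_error` are illustrative assumptions.

```python
def final_error(lr, num_steps=20, w=0.5, x=2.0, y=3.0):
    """Run num_steps of w <- w - lr * (w*x - y) * x and return |w*x - y| at the end."""
    for _ in range(num_steps):
        w = w - lr * (w * x - y) * x
    return abs(w * x - y)


# The initial prediction error is |0.5 * 2 - 3| = 2.0.
results = {lr: final_error(lr) for lr in (0.001, 0.1, 0.6)}
for lr, err in results.items():
    print(f"lr = {lr:<5} -> error after 20 steps = {err:.5f}")
```

For this model each step multiplies the error by (1 − η·x²) = (1 − 4η). With η = 0.001 that factor is 0.996, so the error hardly moves. With η = 0.1 it is 0.6, so the error collapses quickly. With η = 0.6 it is −1.4, so the error grows in magnitude every step and training diverges.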

Section 08

Golden Rules for Local SGD & Weight Updates

🎯 Remember These
1. Every step is the same rule: w ← w − η∇ℓ. The minus sign turns "uphill slope" into "downhill step."
2. Use mini-batch SGD: cheaper than full-batch, calmer than single-sample. It's what federated clients run.
3. Local updates per round = E · ⌈n_k / B⌉. Know this number; it drives both progress and drift.
4. If loss explodes, your η is too big. If it barely moves, η is too small. Tune it first.
5. More local epochs cut communication but increase client drift on non-IID data; that's where FedProx and SCAFFOLD step in.
6. Always train a copy of the global weights; never mutate the received model in place.
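The last rule is easy to break by accident with NumPy, because `w -= ...` modifies the array the caller passed in. The sketch below contrasts the two behaviours; both function names and the example values are hypothetical.

```python
import numpy as np


def train_in_place(w, grad, lr=0.1):
    """Risky: `-=` writes into the caller's array, so the received model changes."""
    w -= lr * grad
    return w


def train_on_copy(w, grad, lr=0.1):
    """Safe: work on a private copy and leave the received weights untouched."""
    w = w.copy()
    w -= lr * grad
    return w


global_w = np.array([1.0, 2.0])   # weights received from the server (example values)
grad = np.array([10.0, 10.0])     # a made-up gradient
snapshot = global_w.copy()

safe = train_on_copy(global_w, grad)
untouched_after_copy = np.array_equal(global_w, snapshot)
print("after train_on_copy :", global_w, "| unchanged:", untouched_after_copy)

unsafe = train_in_place(global_w, grad)
print("after train_in_place:", global_w, "| unchanged:", np.array_equal(global_w, snapshot))
```

An update written as `w = w - lr * grad` allocates a new array on every step and so does not touch the input either, but an explicit `w.copy()` at the top keeps the function safe even if someone later switches to an in-place update.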

In one line: Local SGD is the client's inner loop, repeatedly nudging its weights with w ← w − η∇ℓ on mini-batches of its own private data; the weights that result from those accumulated updates are what each client sends back to be averaged into the global model.