The Story That Explains Local SGD
So you adopt a simple rule: feel which way the ground tilts downhill, then take one step that way. Feel again. Step again. The steeper the slope, the more confident your stride; near the flat valley floor, your steps shrink to almost nothing. Step by step, you descend — never seeing the whole mountain, only the patch beneath you.
That is the essence of Stochastic Gradient Descent (SGD). The "mountain" is the model's loss; the "slope" is the gradient, estimated from a small sample of data rather than the whole dataset (that is the "stochastic" part); each "step" is a weight update. The word "local" is the federated twist: in federated learning, every client walks down its own private mountain using its own data, for several steps, before reporting its new position to the server. Understanding this single inner loop, local SGD and its weight updates, is understanding what every client is really doing.
The Heart of It — the Weight Update Rule
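Every step of local training is one application of the same tiny formula: w ← w − η∇ℓ, where w is the current weights, η is the learning rate (the step size), and ∇ℓ is the gradient of the loss on the current batch. Master this and you understand SGD.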
The minus sign is the whole trick: the gradient ∇ℓ points in the direction of steepest increase of the loss, so stepping the opposite way reduces it. The animation shows a marker descending a loss curve, one step at a time.
Steps are large where the slope is steep and shrink near the flat minimum — that's the gradient doing the steering.
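To make the shrinking steps concrete, here is a minimal sketch on a toy one-dimensional loss. The loss 0.5 · (w − 3)², the starting point, and the learning rate are assumptions chosen for illustration, not values from the text.

```python
def loss_grad(w, target=3.0):
    """Gradient of the toy loss 0.5 * (w - target)**2 (hypothetical example)."""
    return w - target


def descend(w, lr=0.3, steps=8):
    """Apply w <- w - lr * grad repeatedly and record the size of each step."""
    step_sizes = []
    for _ in range(steps):
        step = lr * loss_grad(w)   # steep slope -> large step, flat slope -> small step
        w = w - step               # the minus sign moves w downhill
        step_sizes.append(abs(step))
    return w, step_sizes


w_final, step_sizes = descend(w=0.0)
print("final w:", round(w_final, 3))
print("step sizes:", [round(s, 3) for s in step_sizes])
```

Each recorded step is smaller than the one before it, because the gradient itself shrinks as w approaches the minimum.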
Three Flavours: Full-Batch, Stochastic, Mini-Batch
They share the same update rule; they differ only in how much data they look at before each step.
| Flavour | Data per step | Steps per epoch | Character |
|---|---|---|---|
| Full-batch GD | All n samples | 1 | Smooth but slow & memory-hungry |
| Stochastic (SGD) | 1 sample | n | Fast, very noisy, jumpy path |
| Mini-batch SGD | B samples | n / B | The practical sweet spot — what FL uses |
In federated learning, each client runs mini-batch SGD on its local data. The mini-batch makes each step cheap and adds just enough noise to escape bad spots, without the wild jumpiness of pure single-sample SGD.
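The trade-off between the three flavours can be sketched by measuring how much the gradient estimate varies from batch to batch. The synthetic data, the single-weight model, and the helper name `batch_gradients` are assumptions made for this illustration.

```python
import numpy as np


def batch_gradients(w, x, y, batch_size, rng):
    """Return the gradient estimate from every batch in one shuffled epoch."""
    idx = rng.permutation(len(y))
    grads = []
    for start in range(0, len(y), batch_size):
        b = idx[start:start + batch_size]
        # gradient of 0.5 * (w*x - y)**2, averaged over the batch
        grads.append(np.mean((w * x[b] - y[b]) * x[b]))
    return np.array(grads)


rng = np.random.default_rng(0)
n = 200
x = rng.normal(size=n)
y = 2.0 * x + 0.1 * rng.normal(size=n)   # synthetic data, true weight = 2

results = {}
for name, B in [("full-batch", n), ("stochastic", 1), ("mini-batch", 50)]:
    g = batch_gradients(0.0, x, y, B, rng)
    results[name] = (len(g), float(np.std(g)))   # (steps per epoch, gradient spread)
    print(f"{name:>10}: steps/epoch={len(g):3d}  gradient std={np.std(g):.3f}")
```

Full-batch yields a single gradient per epoch with no spread, single-sample SGD yields n very scattered gradients, and the mini-batch sits between the two.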
How Many Updates Happen Locally? Steps vs Epochs
Inside one federated round, a client doesn't take just one step; it takes many. The count is set by the local epochs E, the batch size B, and the client's data size n_k: local steps = E × ⌈n_k / B⌉, where the last batch of an epoch may hold fewer than B samples.
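As a rough sketch, the number of local updates per round can be computed as E × ⌈n_k / B⌉. This assumes the final, partial batch of each epoch is kept rather than dropped; the function name `local_steps` is hypothetical.

```python
import math


def local_steps(n_k, batch_size, epochs):
    """Weight updates one client makes in a round: E * ceil(n_k / B).

    Assumes the last, smaller batch of each epoch is still used.
    """
    return epochs * math.ceil(n_k / batch_size)


print(local_steps(n_k=200, batch_size=50, epochs=5))   # 5 epochs x 4 batches
print(local_steps(n_k=210, batch_size=50, epochs=5))   # 5 epochs x 5 batches (last one partial)
```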
The catch: more local steps means the client travels further down its own mountain — which may be a different mountain from its neighbours when data is non-IID. This is client drift, the price of doing lots of local work between communication rounds.
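Client drift can be illustrated with two toy clients whose losses have different minima. Everything here (the quadratic losses, the optima at 1 and 3, the shared starting point) is an assumption for illustration; real non-IID drift is higher-dimensional but follows the same pattern.

```python
def local_descent(w, optimum, lr=0.1, steps=1):
    """Gradient steps on the toy client loss 0.5 * (w - optimum)**2."""
    for _ in range(steps):
        w = w - lr * (w - optimum)
    return w


w_global = 2.0   # both clients start from the same global weight
drift = {}
for steps in (1, 5, 20):
    w_a = local_descent(w_global, optimum=1.0, steps=steps)   # client A's data favours w = 1
    w_b = local_descent(w_global, optimum=3.0, steps=steps)   # client B's data favours w = 3
    drift[steps] = abs(w_a - w_b)
    print(f"{steps:2d} local steps: client A={w_a:.3f}, client B={w_b:.3f}, gap={drift[steps]:.3f}")
```

The gap between the two clients widens as the number of local steps grows, which is the drift the server later has to reconcile.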
A Worked Example — One Weight Update by Hand
Take the simplest model: predict ŷ = w · x with a single weight, using the squared-error loss ℓ = ½(w·x − y)² (the ½ cancels the 2 that appears when differentiating). The gradient of the loss for one sample is ∇ℓ = (w·x − y) · x. Suppose w = 0.50, learning rate η = 0.10, and a sample (x = 2, y = 3).
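The number line below shows that single update. The gradient is ∇ℓ = (0.50 · 2 − 3) · 2 = −4, so the new weight is w = 0.50 − 0.10 · (−4) = 0.90: the weight slides from 0.50 to 0.90, nudged by −η∇ℓ = +0.40.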
The marker is pulled toward the target each step; repeat this 100 times locally and the client has meaningfully improved its model.
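The hand calculation can be checked, and then repeated, with a few lines of Python. This sketch reuses the same single sample for all 100 steps, which is a simplification; a real client would cycle through many samples. The function name `sgd_step` is hypothetical.

```python
def sgd_step(w, x, y, lr):
    """One update w <- w - lr * grad for the loss 0.5 * (w*x - y)**2."""
    grad = (w * x - y) * x
    return w - lr * grad


w = 0.50
history = [w]
for _ in range(100):
    w = sgd_step(w, x=2.0, y=3.0, lr=0.10)
    history.append(w)

print(f"after 1 step:    w = {history[1]:.2f}")
print(f"after 100 steps: w = {history[-1]:.4f}")
```

The first step reproduces the move from 0.50 to 0.90, and further steps settle at w = 1.5, the value for which w · x equals y for this sample.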
Local SGD in Python — the Client's Inner Loop
This is exactly what runs on a client during one federated round: shuffle, batch, compute the gradient, apply the w ← w − η∇ℓ update, repeat for E epochs.
```python
import numpy as np


def mse_grad(w, Xb, yb):
    """Gradient of half the mean-squared error for a linear model on one batch."""
    resid = Xb @ w - yb                  # prediction error
    return Xb.T @ resid / len(yb)        # ∇ℓ (averaged over batch)


def local_sgd(w, X, y, lr=0.05, epochs=5, batch=32):
    """Run mini-batch SGD on ONE client's local data for E epochs."""
    w = w.copy()                         # start from the global weights
    n = len(y)
    steps = 0
    for ep in range(epochs):             # ── E local epochs ──
        idx = np.random.permutation(n)   # reshuffle each epoch
        for s in range(0, n, batch):     # ── mini-batches ──
            b = idx[s:s + batch]
            grad = mse_grad(w, X[b], y[b])
            w = w - lr * grad            # ★ THE WEIGHT UPDATE ★
            steps += 1
        loss = np.mean((X @ w - y) ** 2)
        print(f"  epoch {ep+1}: loss={loss:.4f}")
    print(f"  total local steps this round = {steps}")
    return w


# ─── Demo: one client's local training ───────────────────────
np.random.seed(0)
true_w = np.array([2.0, -1.0, 0.5])
X = np.random.randn(200, 3)
y = X @ true_w + 0.05 * np.random.randn(200)
w0 = np.zeros(3)                         # received global model

print("Local SGD on client:")
w_local = local_sgd(w0, X, y, lr=0.1, epochs=5, batch=50)
print("Learned:", np.round(w_local, 3))
print("True   :", true_w)
```
With E=5, n=200, B=50, the client takes 5 × 4 = 20 weight updates in this round — matching the printed step count — and the loss falls steadily as the weights converge toward the truth.
The Three Dials of Local Training
| Dial | Symbol | Too small | Too large |
|---|---|---|---|
| Learning rate | η | Painfully slow descent | Overshoots & diverges (loss explodes) |
| Batch size | B | Noisy, jumpy updates | Smoother but fewer steps per epoch |
| Local epochs | E | Little local progress → more rounds | More client drift on non-IID data |
These three together set how far each client travels per round. Tuning them can be the difference between a model that converges in, say, 20 rounds and one that never converges at all.
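The learning-rate row of the table can be sketched on the single-weight example (x = 2, y = 3, starting at w = 0.5). The three learning rates below are assumptions picked to show each regime; the helper name `run` is hypothetical, and the other two dials are not varied here.

```python
def run(lr, steps=20, w=0.5, x=2.0, y=3.0):
    """Repeat w <- w - lr * grad on 0.5 * (w*x - y)**2; return distance to the optimum y/x."""
    for _ in range(steps):
        w = w - lr * (w * x - y) * x
    return abs(w - y / x)


errors = {lr: run(lr) for lr in (0.001, 0.1, 0.6)}
for lr, err in errors.items():
    print(f"lr={lr}: distance from optimum after 20 steps = {err:.4g}")
```

With the smallest rate the weight has barely moved after 20 steps, the middle rate lands very close to the optimum, and the largest rate overshoots further on every step so the distance grows instead of shrinking.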
Golden Rules for Local SGD & Weight Updates
In one line: Local SGD is the client's inner loop — repeatedly nudging its weights with w ← w − η∇ℓ on mini-batches of its own private data — and those accumulated weight updates are exactly what each client sends back to be averaged into the global model.
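For completeness, here is a minimal sketch of the averaging step on the server side. The text only says the client results are averaged; weighting each client by its sample count (FedAvg-style) is an assumption, and the function name `average_weights` is hypothetical.

```python
import numpy as np


def average_weights(client_weights, client_sizes):
    """Sample-count-weighted average of client weight vectors (FedAvg-style assumption)."""
    sizes = np.asarray(client_sizes, dtype=float)
    stacked = np.stack(client_weights)             # shape: (num_clients, num_params)
    return (sizes[:, None] * stacked).sum(axis=0) / sizes.sum()


# Two hypothetical clients returning their locally trained weights
w_clients = [np.array([1.0, 0.0]), np.array([3.0, 2.0])]
w_global = average_weights(w_clients, client_sizes=[100, 300])
print(w_global)
```

The client with three times as much data pulls the average three times as hard, so the result sits closer to its weights.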