The Story That Explains Model Aggregation
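Three hospitals want one shared model that is better than any of them could train alone, but none of them can hand over its patient records. So they make a clever deal. Each hospital trains the same model on its own private patients, behind its own firewall. Then, instead of sending patient records, each sends only the lessons the model learned (its updated numbers, the weights) to a neutral coordinator.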
The coordinator never sees a single patient. It does one job: it blends the three sets of lessons into one improved model and sends that combined model back to everyone. Crucially, the city hospital saw 600 patients and the rural practice saw 100 — so the coordinator gives the city hospital's lessons more weight. That blending step is exactly what we call model aggregation at the server.
Federated Learning lets many devices or organizations train one shared model without ever centralizing their raw data. The clients do the learning; the server does the merging. This tutorial zooms into that merge — the single most important operation on the server side.
What Actually Happens on the Server
In one federated round, the server runs a tight loop: it broadcasts the current global model, each client trains locally on its private data, every client uploads only its model update, and the server aggregates those updates into the next global model. Only model weights move through this loop, never raw data.
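The round described above can be sketched in a few lines. This is a minimal illustration under stated assumptions, not a production loop: `local_train` and `run_round` are hypothetical names, the model is a two-parameter linear regression, and the three client datasets are synthetic.

```python
import numpy as np

def local_train(global_w, X, y, lr=0.1, epochs=5):
    """Hypothetical client step: a few epochs of gradient descent on squared error."""
    w = global_w.copy()                          # start from the broadcast global model
    for _ in range(epochs):
        grad = 2 * X.T @ (X @ w - y) / len(y)    # gradient of the mean squared error
        w -= lr * grad
    return w, len(y)                             # upload weights and sample count only

def run_round(global_w, client_data):
    """One round: broadcast, local training, upload, weighted aggregation."""
    updates = [local_train(global_w, X, y) for X, y in client_data]
    total = sum(n for _, n in updates)
    return sum((n / total) * w for w, n in updates)

# Synthetic private datasets (assumption): same underlying rule, different sizes.
rng = np.random.default_rng(0)
true_w = np.array([2.0, -1.0])
client_data = []
for n in (60, 30, 10):
    X = rng.normal(size=(n, 2))
    client_data.append((X, X @ true_w))          # noiseless targets for clarity

global_w = np.zeros(2)
for _ in range(20):                              # 20 federated rounds
    global_w = run_round(global_w, client_data)

print("Global model after 20 rounds:", global_w)
```

The server function never touches `X` or `y`; it only receives each client's weight vector and sample count, which mirrors the data flow in the text.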
FedAvg — The Aggregation Formula
The standard server-side aggregation is Federated Averaging (FedAvg), introduced by McMahan et al. It is a weighted average of every client's model, where each client's weight is the fraction of total data it holds. After clients train locally, the server computes one new global model for the next round.
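Written out, one common way to state the FedAvg update is shown below. The symbols are notation assumed here for illustration: $K$ clients, $n_k$ samples held by client $k$, $n$ the total sample count, and $w_{t+1}^{k}$ the weights client $k$ returns after local training in round $t$.

```latex
w_{t+1} \;=\; \sum_{k=1}^{K} \frac{n_k}{n}\, w_{t+1}^{k},
\qquad
n \;=\; \sum_{k=1}^{K} n_k
```

The coefficients $n_k / n$ are non-negative and sum to 1, so the new global model is a convex combination of the client models.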
A Worked Numerical Example
Let's aggregate a single weight value from our three hospitals. Each trained the model locally and arrived at a different value for one parameter. (Real models have millions of these — the math is applied element-by-element.)
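Using the values this tutorial assigns to the three hospitals (Client A: 0.80 from 600 patients, Client B: 0.50 from 300 patients, Client C: 0.20 from 100 patients), the FedAvg arithmetic can be written out step by step:

```latex
\begin{aligned}
n &= 600 + 300 + 100 = 1000 \\
w &= \tfrac{600}{1000}(0.80) + \tfrac{300}{1000}(0.50) + \tfrac{100}{1000}(0.20) \\
  &= 0.48 + 0.15 + 0.02 \\
  &= 0.65
\end{aligned}
```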
Compare that to a plain average: (0.80 + 0.50 + 0.20) / 3 = 0.50. FedAvg lands at 0.65 instead, pulled toward Client A's value — correctly, because Client A backed its number with six times more data than Client C. That 0.15 difference is the data-weighting at work.
Aggregation Strategies at a Glance
FedAvg is the default, but the server can blend updates in several ways depending on whether you care about data heterogeneity, privacy, or defending against malicious clients.
| Strategy | What the server combines | Core idea | Best when… |
|---|---|---|---|
| FedSGD | One gradient per client, per step | Server averages gradients every step; clients do 1 local step | You want each round to behave like a single SGD step on the pooled data; communication is cheap |
| FedAvg (default) | Full local model weights after several epochs | Weighted average by sample count; far fewer rounds | Communication is expensive (the usual case) |
| FedProx | Local weights + a proximal penalty | Keeps local models from drifting too far from global | Clients have very different (non-IID) data |
| FedAvgM | Aggregated update + server momentum | Server keeps momentum across rounds for stability | Training is noisy or oscillating |
| Secure Agg. | Masked / encrypted weights | Server sees only the sum, never an individual update | Privacy of single updates must be guaranteed |
| Krum / Median | A robust statistic, not the mean | Discards outlier updates to resist poisoning | Some clients may be malicious or faulty |
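To make the last row of the table concrete, here is a small sketch comparing a plain mean with a coordinate-wise median when one update is poisoned. The update values are made up for illustration, and this version ignores sample counts; it is one simple robust statistic, not the Krum algorithm.

```python
import numpy as np

# Four honest clients cluster near 0.5; the last row is a hypothetical poisoned update.
updates = np.array([
    [0.52, 0.48],
    [0.49, 0.51],
    [0.50, 0.50],
    [0.51, 0.49],
    [50.0, -50.0],   # malicious outlier
])

mean_agg = updates.mean(axis=0)            # unweighted mean: dragged by the outlier
median_agg = np.median(updates, axis=0)    # per-coordinate median: ignores the outlier

print("mean:  ", mean_agg)
print("median:", median_agg)
```

A single extreme update moves the mean by roughly a fifth of its magnitude here, while the median stays inside the range of the honest clients.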
Implementing Server Aggregation in Python
Here is FedAvg as the server would run it. Each client returns its trained weights plus how many samples it trained on. The server blends them into one global model.
```python
import numpy as np

# Each client sends back: (model_weights, num_samples)
# Here we use the 3-hospital values from the worked example above.
w_A, w_B, w_C = np.array([0.80]), np.array([0.50]), np.array([0.20])

client_updates = [
    (w_A, 600),
    (w_B, 300),
    (w_C, 100),
]

def federated_average(updates):
    """Weighted average of client weights by sample count (FedAvg)."""
    total_samples = sum(n for _, n in updates)

    # float accumulator shaped exactly like the model weights
    global_w = np.zeros_like(updates[0][0], dtype=float)

    for weights, n in updates:
        alpha = n / total_samples       # this client's data share
        global_w += alpha * weights     # weighted contribution
    return global_w

new_global = federated_average(client_updates)
print("Aggregated global weight:", new_global)
```
In PyTorch, a network's weights are stored in a state_dict: a dictionary of named tensors, with one entry per parameter or buffer in each layer. The server applies the same weighted average key by key across every entry:
```python
def aggregate_state_dicts(client_states, client_sizes):
    """FedAvg over PyTorch-style state_dicts (layer by layer)."""
    total = sum(client_sizes)
    avg_state = {}

    # every client shares the same layer names (same architecture)
    for key in client_states[0].keys():
        avg_state[key] = sum(
            (client_sizes[i] / total) * client_states[i][key]
            for i in range(len(client_states))
        )
    return avg_state  # load this back into the global model
```
Challenges the Server Must Handle
| Challenge | Why it hurts aggregation | Common fix |
|---|---|---|
| Non-IID data | Clients hold very different distributions, so their models point in conflicting directions | FedProx, FedAvgM, or personalization layers |
| Stragglers | Slow or offline clients delay the round | Client sampling + timeouts; aggregate whoever returns |
| Communication cost | Sending full models every round is expensive | More local epochs (FedAvg), gradient compression |
| Poisoning attacks | A malicious client sends a corrupt update to skew the mean | Robust aggregation (Krum, trimmed mean, median) |
| Update leakage | Raw updates can leak information about private data | Secure aggregation + differential privacy noise |
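The idea behind the secure aggregation fix in the last row can be illustrated with a toy pairwise-masking sketch. This is an assumption-heavy simplification: real protocols use key agreement, finite-field arithmetic, and dropout recovery, none of which appear here. It also assumes each client knows the total sample count so it can pre-scale its own weights.

```python
import numpy as np

rng = np.random.default_rng(42)

# The three-hospital values used elsewhere in this tutorial.
client_weights = [np.array([0.80]), np.array([0.50]), np.array([0.20])]
client_sizes = [600, 300, 100]
total = sum(client_sizes)

# Each client pre-scales by its data share, so the server only needs a plain sum.
scaled = [(n / total) * w for w, n in zip(client_weights, client_sizes)]

# Pairwise masks: for every pair i < j, client i adds a random mask
# and client j subtracts the same mask. Individually the uploads look random.
masked = [s.copy() for s in scaled]
num_clients = len(masked)
for i in range(num_clients):
    for j in range(i + 1, num_clients):
        mask = rng.normal(scale=100.0, size=scaled[0].shape)
        masked[i] += mask
        masked[j] -= mask

# The server sums the masked uploads; every mask cancels in the total.
aggregate = sum(masked)

print("Masked upload from client 0:", masked[0])
print("Sum seen by the server:     ", aggregate)
```

Each masked upload is dominated by random noise, yet the sum matches the FedAvg result up to floating-point rounding, which is the property the table describes: the server learns the aggregate and nothing about any single update.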