The Story That Explains Transformers
"The bank by the river was steep — but the bank refused my loan anyway."
The word "bank" appears twice. As a human, you instantly understand that the first means a riverbank and the second means a financial institution — because you relate each word to every other word in the sentence simultaneously. You don't read left-to-right and forget the beginning; you hold the whole sentence in your mind and resolve ambiguity through context.
A Transformer works the same way, and that was a sharp departure from the recurrent language models that came before it. Older models (RNNs, LSTMs) read word by word, like a person reading aloud. Transformers read the whole sentence at once, attending to every word in relation to every other. That single idea, self-attention, changed everything.
The Transformer architecture was introduced in the landmark 2017 paper "Attention Is All You Need" by Vaswani et al., a team from Google Brain, Google Research, and the University of Toronto. In just a few years it replaced RNNs and LSTMs as the dominant architecture for natural language processing, and it has since spread to computer vision, protein structure prediction, audio generation, and more.
RNNs process sequences one token at a time, so information from step 1 must travel through every intermediate step to reach step 100. Gradients shrink along that path (the vanishing-gradient problem), which makes long-range dependencies very hard to learn; LSTMs ease the problem but do not remove it. Transformers eliminated this bottleneck: any token can attend directly to any other token, however far apart they are. And because positions are not processed one after another, the whole sequence can be processed in parallel on GPUs during training, which made training dramatically faster.
The Bird's-Eye View — Encoder & Decoder
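The original Transformer is an encoder–decoder architecture designed for sequence-to-sequence tasks like machine translation. Think of it as two specialists passing a baton: the encoder reads the whole input sequence and turns it into contextual representations, and the decoder generates the output one token at a time while consulting those representations.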
Input Embeddings & Positional Encoding
Self-attention by itself has no sense of word order: it treats the input as an unordered set of tokens. The solution: add a positional encoding to each embedding, a unique mathematical "address", so that the sum encodes not just what a word is but where it sits in the sentence. Position 0 gets a different sinusoidal pattern than position 1, 2, or 100. Now the model knows the order without ever processing tokens sequentially.
import torch
import math

class PositionalEncoding(torch.nn.Module):
    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = torch.nn.Dropout(p=dropout)
        # Create a (max_len, d_model) matrix of positional encodings
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float()
                             * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)  # even dims → sin
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dims → cos
        pe = pe.unsqueeze(0)  # shape: (1, max_len, d_model), leading batch dimension
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x shape: (batch, seq_len, d_model)
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
Using raw integers (0, 1, 2 …) would give the model an unbounded signal that grows with sequence length and can swamp the embedding values. Sinusoids stay between −1 and +1 regardless of sequence length, and the original authors hypothesised that their structure would help the model extrapolate to sequences longer than those seen in training. Many models (BERT, GPT-2) instead learn positional embeddings as parameters, which is simpler but cannot represent positions beyond the trained maximum.
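One property behind that hypothesis can be checked numerically: the encoding at position pos + k is a fixed linear transformation (a block-diagonal rotation) of the encoding at position pos, whatever pos is. The sketch below is illustrative only; `sinusoidal_pe` and `shift_matrix` are hypothetical helper names, and d_model is assumed to be even.

```python
import math
import torch

def sinusoidal_pe(max_len: int, d_model: int) -> torch.Tensor:
    """Return a (max_len, d_model) table of sinusoidal encodings (d_model even)."""
    position = torch.arange(max_len, dtype=torch.float64).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float64)
                         * (-math.log(10000.0) / d_model))
    pe = torch.zeros(max_len, d_model, dtype=torch.float64)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe

def shift_matrix(k: int, d_model: int) -> torch.Tensor:
    """Build M_k such that pe[pos] @ M_k == pe[pos + k] for every pos."""
    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float64)
                         * (-math.log(10000.0) / d_model))
    M = torch.zeros(d_model, d_model, dtype=torch.float64)
    for i, w in enumerate(div_term):
        angle = float(w) * k
        c, s = math.cos(angle), math.sin(angle)
        # 2x2 rotation acting on the (sin, cos) pair of frequency i
        M[2 * i, 2 * i], M[2 * i, 2 * i + 1] = c, -s
        M[2 * i + 1, 2 * i], M[2 * i + 1, 2 * i + 1] = s, c
    return M

pe = sinusoidal_pe(max_len=200, d_model=16)
M = shift_matrix(k=7, d_model=16)

print(pe.abs().max().item() <= 1.0)                    # values stay bounded
print(torch.allclose(pe[:-7] @ M, pe[7:], atol=1e-9))  # one matrix shifts every position by 7
```

Because the same matrix works for every starting position, relative offsets are, at least in principle, easy for a linear layer to express.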
The Heart of It — Scaled Dot-Product Attention
That is attention. Query × Key gives relevance. Softmax normalises it to probabilities. Those probabilities weight the Values. The result is a context-aware blend of information from every position in the sequence — for every position simultaneously.
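Written as a single formula, in the notation of "Attention Is All You Need", the steps above are:

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
```

Here d_k is the dimension of each key vector. Dividing by its square root keeps the dot products from growing with d_k, which would otherwise push the softmax into regions where its gradients are very small.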
import torch
import torch.nn.functional as F
import math

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Q: (batch, heads, seq_len, d_k)
    K: (batch, heads, seq_len, d_k)
    V: (batch, heads, seq_len, d_v)
    Returns: (batch, heads, seq_len, d_v), attention_weights
    """
    d_k = Q.size(-1)
    # Step 1: compute raw scores → (batch, heads, seq, seq)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    # Step 2: apply mask (decoder causal mask or padding mask)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    # Step 3: softmax over last dimension (which key each query attends to)
    attn_weights = F.softmax(scores, dim=-1)  # (batch, heads, seq, seq)
    # Step 4: weighted sum of values
    output = torch.matmul(attn_weights, V)  # (batch, heads, seq, d_v)
    return output, attn_weights

# Quick numerical demo: 1 head, 4 tokens, d_k=8
batch, heads, seq, d_k = 1, 1, 4, 8
Q = torch.randn(batch, heads, seq, d_k)
K = torch.randn(batch, heads, seq, d_k)
V = torch.randn(batch, heads, seq, d_k)
out, weights = scaled_dot_product_attention(Q, K, V)
print(f"Output shape: {out.shape}")  # (1, 1, 4, 8)
print(f"Weights shape: {weights.shape}")  # (1, 1, 4, 4), a 4×4 attention map
print(f"Weights sum: {weights[0,0,0].sum():.4f}")  # 1.0000, softmax guarantee
Multi-Head Attention — Why One Perspective Isn't Enough
Multi-head attention does exactly this. Instead of computing attention once with d_model-dimensional Q, K, V matrices, it splits into h heads, each with its own learned projection. Each head learns to attend to a different type of relationship: syntactic dependencies, co-reference, semantic similarity, position proximity. The results are concatenated and projected back, giving the layer a multi-dimensional understanding of context.
| Component | Shape | Purpose |
|---|---|---|
| Input X | (batch, seq, d_model) | Raw embeddings + positional encoding |
| W^Q, W^K, W^V per head | (d_model, d_k) | Learned projections, one set per head |
| Each head output | (batch, seq, d_k) | Context-aware representation from one perspective |
| Concatenated heads | (batch, seq, h×d_k) = (batch, seq, d_model) | All perspectives merged |
| W^O projection | (d_model, d_model) | Mixes information across heads |
| Final output | (batch, seq, d_model) | Same shape as input — ready for next layer |
import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # dimension per head
        # Single weight matrices; we split them per head inside forward()
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def split_heads(self, x, batch_size):
        # x: (batch, seq, d_model) → (batch, heads, seq, d_k)
        x = x.view(batch_size, -1, self.num_heads, self.d_k)
        return x.transpose(1, 2)

    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)
        # Project and split into h heads
        Q = self.split_heads(self.W_q(Q), batch_size)
        K = self.split_heads(self.W_k(K), batch_size)
        V = self.split_heads(self.W_v(V), batch_size)
        # Attention on all heads simultaneously
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = torch.nn.functional.softmax(scores, dim=-1)
        context = torch.matmul(attn, V)  # (batch, heads, seq, d_k)
        # Merge heads: (batch, heads, seq, d_k) → (batch, seq, d_model)
        context = context.transpose(1, 2).contiguous()
        context = context.view(batch_size, -1, self.d_model)
        return self.W_o(context)  # final linear projection

# Test: batch=2, seq=10, d_model=512, 8 heads → d_k = 64 per head
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)
out = mha(x, x, x)
print(f"Input: {x.shape}")   # torch.Size([2, 10, 512])
print(f"Output: {out.shape}")  # torch.Size([2, 10, 512]) (same shape)
The Encoder Block — Full Architecture
A single encoder block has just two sub-layers, multi-head self-attention and a position-wise feed-forward network, each followed by a residual connection and layer normalisation. The full encoder stacks N of these blocks (N=6 in the original paper), and each block refines the representation produced by the one before it.
import torch
import torch.nn as nn

# Reuses the MultiHeadAttention class defined above
class EncoderBlock(nn.Module):
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        # Sub-layer 1: Multi-Head Self-Attention
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        # Sub-layer 2: Position-wise FFN
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Sub-layer 1: self-attention + residual + norm
        attn_out = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout1(attn_out))
        # Sub-layer 2: FFN + residual + norm
        ffn_out = self.ffn(x)
        x = self.norm2(x + self.dropout2(ffn_out))
        return x  # shape unchanged: (batch, seq, d_model)

class Encoder(nn.Module):
    def __init__(self, num_blocks: int, d_model: int, num_heads: int, d_ff: int):
        super().__init__()
        self.blocks = nn.ModuleList(
            [EncoderBlock(d_model, num_heads, d_ff) for _ in range(num_blocks)]
        )

    def forward(self, x, mask=None):
        for block in self.blocks:
            x = block(x, mask)
        return x  # final encoder output → passed to decoder

# 6-layer encoder: batch=2, seq=10, d_model=512
encoder = Encoder(num_blocks=6, d_model=512, num_heads=8, d_ff=2048)
x = torch.randn(2, 10, 512)
enc_out = encoder(x)
print(f"Encoder output: {enc_out.shape}")  # torch.Size([2, 10, 512])
Without residual connections, gradients in a deep stack must pass through every sublayer on the way back and can shrink badly before they reach the early layers. The formula x + Sublayer(x) means that even if a sublayer contributes almost nothing (as can happen early in training), the gradient still flows through the identity path. The idea is borrowed from ResNet (2015) and is now standard in most deep architectures.
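The identity path can be seen directly in a toy setting. The sketch below is an illustration, not part of the model code: it uses a hypothetical sublayer whose weights are deliberately set to zero (an extreme case chosen to make the effect obvious) and compares the gradient reaching the input with and without the residual connection.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 8

# Hypothetical sublayer with all weights zeroed, so it outputs exactly zero
sublayer = nn.Linear(d_model, d_model)
nn.init.zeros_(sublayer.weight)
nn.init.zeros_(sublayer.bias)

x = torch.randn(3, d_model, requires_grad=True)

# Plain stacking: the gradient w.r.t. x has to pass through the sublayer
plain_grad, = torch.autograd.grad(sublayer(x).sum(), x)

# Residual form: x + Sublayer(x) adds an identity path for the gradient
residual_grad, = torch.autograd.grad((x + sublayer(x)).sum(), x)

print(plain_grad.abs().max().item())     # 0.0: nothing flows back
print(residual_grad.abs().max().item())  # 1.0: the identity path survives
```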
The Decoder Block — Three Sub-Layers
1. They look at what they've already written — the previous chapters — to stay consistent in voice and plot (Masked Self-Attention).
2. They consult their research notes and source material — facts about the era, historical events — to ensure accuracy (Cross-Attention over the encoder output).
3. They think deeply and creatively about how to construct the next sentence given what they know (Feed-Forward Network).
The decoder block does all three — in exactly this order — for every token it generates.
class DecoderBlock(nn.Module):
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        # 1. Masked self-attention (decoder attends to itself)
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        # 2. Cross-attention (decoder Q, encoder K and V)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.norm2 = nn.LayerNorm(d_model)
        # 3. FFN
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, tgt, enc_output, tgt_mask=None, src_mask=None):
        # tgt: decoder input (batch, tgt_seq, d_model)
        # enc_output: encoder output (batch, src_seq, d_model)
        # 1. Masked self-attention
        sa_out = self.self_attn(tgt, tgt, tgt, tgt_mask)
        tgt = self.norm1(tgt + self.dropout(sa_out))
        # 2. Cross-attention: Q from decoder, K/V from encoder
        ca_out = self.cross_attn(tgt, enc_output, enc_output, src_mask)
        tgt = self.norm2(tgt + self.dropout(ca_out))
        # 3. FFN
        ffn_out = self.ffn(tgt)
        tgt = self.norm3(tgt + self.dropout(ffn_out))
        return tgt
Without the causal mask in the decoder's self-attention, the model at position i could directly see the token at position i+1 — the answer it's supposed to predict. Training would be trivially solved by copying the next token. The mask fills future positions with −∞ before softmax, making those attention weights exactly zero. Removing the mask is one of the most common bugs in Transformer implementations.
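A causal mask is simply a lower-triangular matrix. The sketch below builds one with `torch.tril` and applies it with the same `masked_fill(mask == 0, -inf)` convention used in the attention code in this article; `causal_mask` is a hypothetical helper name, and the random Q and K are placeholders.

```python
import math
import torch
import torch.nn.functional as F

def causal_mask(seq_len: int) -> torch.Tensor:
    """Lower-triangular mask: entry (i, j) is 1 when query i may see key j."""
    return torch.tril(torch.ones(seq_len, seq_len))

torch.manual_seed(0)
seq_len, d_k = 4, 8
Q = torch.randn(1, 1, seq_len, d_k)
K = torch.randn(1, 1, seq_len, d_k)

mask = causal_mask(seq_len)  # (seq, seq); broadcasts over batch and heads
scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
scores = scores.masked_fill(mask == 0, float('-inf'))
weights = F.softmax(scores, dim=-1)

print(mask)
print(weights[0, 0])  # every entry above the diagonal is exactly 0
```

The first row of the weights is always [1, 0, 0, 0]: the first token can only attend to itself. The same mask can be passed as `tgt_mask` to the decoder block above.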
The Full Transformer — Putting It Together
| Component | Location | Purpose | Trainable Params (d_model=512) |
|---|---|---|---|
| Input Embedding | Encoder + Decoder | Map token IDs → dense vectors | vocab × 512 |
| Positional Encoding | Encoder + Decoder | Inject position information | 0 (sinusoidal, fixed) |
| Multi-Head Self-Attention | Encoder (×6) + Decoder (×6) | Context across all positions | 4 × (512 × 512) per block |
| Cross-Attention | Decoder only (×6) | Bridge encoder understanding to decoder | 4 × (512 × 512) per block |
| FFN | Encoder (×6) + Decoder (×6) | Deep per-position reasoning | 2 × (512 × 2048) per block |
| Layer Norm | Every sub-layer | Stabilise activations | 2 × 512 per instance |
| Output Linear + Softmax | Decoder top | Project to vocabulary probabilities | 512 × vocab |
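The per-block figures in the table can be turned into a rough total for the two stacks. This is back-of-the-envelope arithmetic under stated assumptions: weights only for attention and the FFN (the FFN biases in the code in this article are ignored), and embeddings and the output projection left out because they depend on vocabulary size.

```python
d_model, d_ff, num_blocks = 512, 2048, 6

attn_params = 4 * d_model * d_model  # W_q, W_k, W_v, W_o, no biases
ffn_params = 2 * d_model * d_ff      # two linear layers, weights only
norm_params = 2 * d_model            # scale and shift of one LayerNorm

# Encoder block: self-attention + FFN + 2 LayerNorms
encoder_block = attn_params + ffn_params + 2 * norm_params
# Decoder block: self-attention + cross-attention + FFN + 3 LayerNorms
decoder_block = 2 * attn_params + ffn_params + 3 * norm_params

stack_total = num_blocks * (encoder_block + decoder_block)

print(f"attention per block: {attn_params:,}")    # 1,048,576
print(f"FFN per block:       {ffn_params:,}")     # 2,097,152
print(f"encoder block:       {encoder_block:,}")  # 3,147,776
print(f"decoder block:       {decoder_block:,}")  # 4,197,376
print(f"both stacks (x6):    {stack_total:,}")    # 44,070,912
```

With a 37,000-token vocabulary, each embedding table adds a further 37,000 × 512 = 18,944,000 parameters on top of this.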
class Transformer(nn.Module):
    def __init__(self, src_vocab: int, tgt_vocab: int,
                 d_model=512, num_heads=8, num_blocks=6,
                 d_ff=2048, max_len=5000, dropout=0.1):
        super().__init__()
        # Separate source and target embeddings here
        # (the paper shares these weights with the output projection)
        self.src_embed = nn.Embedding(src_vocab, d_model)
        self.tgt_embed = nn.Embedding(tgt_vocab, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len, dropout)
        # Encoder stack
        self.encoder = Encoder(num_blocks, d_model, num_heads, d_ff)
        # Decoder stack
        self.decoder = nn.ModuleList(
            [DecoderBlock(d_model, num_heads, d_ff, dropout) for _ in range(num_blocks)]
        )
        # Output projection: d_model → tgt_vocab
        self.out_proj = nn.Linear(d_model, tgt_vocab)

    def encode(self, src, src_mask):
        # Embeddings are scaled by sqrt(d_model) before adding positions
        src = self.pos_enc(self.src_embed(src) * (self.src_embed.embedding_dim ** 0.5))
        return self.encoder(src, src_mask)

    def decode(self, tgt, enc_out, tgt_mask, src_mask):
        tgt = self.pos_enc(self.tgt_embed(tgt) * (self.tgt_embed.embedding_dim ** 0.5))
        for block in self.decoder:
            tgt = block(tgt, enc_out, tgt_mask, src_mask)
        return tgt

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        enc_out = self.encode(src, src_mask)
        dec_out = self.decode(tgt, enc_out, tgt_mask, src_mask)
        return self.out_proj(dec_out)  # logits: (batch, tgt_seq, tgt_vocab)

# Instantiate the original "base" model from the paper
model = Transformer(src_vocab=37000, tgt_vocab=37000,
                    d_model=512, num_heads=8, num_blocks=6, d_ff=2048)
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}")
Encoder-Only vs Decoder-Only vs Encoder–Decoder
The original Transformer used both encoder and decoder. But the research community quickly discovered that different tasks benefit from different subsets of the architecture. This led to three dominant families of modern models, each a specialised evolution:
| Property | Encoder-Only (BERT) | Decoder-Only (GPT) | Encoder–Decoder (T5) |
|---|---|---|---|
| Attention direction | Bidirectional | Causal (left→right) | Enc: bi / Dec: causal |
| Can generate text? | No | Yes | Yes |
| Best for | Classification, NER, QA | Chat, completion, LLMs | Translation, summarisation |
| Training objective | Masked LM (MLM) | Causal LM (CLM) | Span corruption / seq2seq |
| Uses cross-attention? | No | No | Yes |
Training a Transformer — The Key Techniques
Label smoothing (0.1 in the original paper) is available as nn.CrossEntropyLoss(label_smoothing=0.1) in modern PyTorch.
import torch
import torch.nn as nn

class WarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, d_model: int, warmup_steps: int):
        # Set attributes first: the base-class constructor already calls get_lr()
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        super().__init__(optimizer)

    def get_lr(self):
        step = max(1, self._step_count)
        scale = self.d_model ** (-0.5) * min(
            step ** (-0.5),
            step * self.warmup_steps ** (-1.5)
        )
        # Multiply by the base LR (1.0 below), so the schedule sets the value
        return [base_lr * scale for base_lr in self.base_lrs]

# Training setup, original paper configuration
model = Transformer(src_vocab=37000, tgt_vocab=37000)
optimizer = torch.optim.Adam(model.parameters(),
                             lr=1.0, betas=(0.9, 0.98), eps=1e-9)
scheduler = WarmupScheduler(optimizer, d_model=512, warmup_steps=4000)
# ignore_index=0 assumes token ID 0 is the padding token
criterion = nn.CrossEntropyLoss(label_smoothing=0.1, ignore_index=0)

# Training step (simplified: no source padding mask)
def train_step(src, tgt):
    model.train()
    tgt_in = tgt[:, :-1]   # input: "BOS I am happy"
    tgt_out = tgt[:, 1:]   # target: "I am happy EOS"
    # Causal mask so position i cannot attend to positions after i
    seq_len = tgt_in.size(1)
    tgt_mask = torch.tril(torch.ones(seq_len, seq_len, device=tgt.device))
    logits = model(src, tgt_in, tgt_mask=tgt_mask)  # (batch, tgt_seq, vocab)
    logits = logits.reshape(-1, logits.size(-1))    # flatten batch × seq
    # reshape, not view: the tgt[:, 1:] slice is not contiguous
    loss = criterion(logits, tgt_out.reshape(-1))
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    scheduler.step()
    return loss.item()
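The formula inside WarmupScheduler is easier to read as a plain function. The sketch below evaluates the same expression at a few steps; `transformer_lr` is a hypothetical helper name, and the defaults are the d_model=512 and warmup_steps=4000 values used above.

```python
def transformer_lr(step: int, d_model: int = 512, warmup_steps: int = 4000) -> float:
    """Learning rate at a given step: linear warm-up, then inverse-sqrt decay."""
    step = max(1, step)
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

for step in (1, 1000, 4000, 16000, 100000):
    print(f"step {step:>6}: lr = {transformer_lr(step):.6f}")
```

The rate rises linearly until step 4000, where it peaks at roughly 7.0e-4, and then decays in proportion to 1/sqrt(step), so it has halved by step 16000.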
Transformer Variants — The Family Tree
The original 2017 Transformer spawned an entire dynasty of architectures. Understanding where each fits helps you choose the right tool for any task.
| Model | Year | Type | Key Innovation | Best Use |
|---|---|---|---|---|
| Transformer | 2017 | Enc–Dec | The original — attention is all you need | Translation |
| BERT | 2018 | Encoder | Bidirectional + masked language modelling pre-training | Classification, NER, QA |
| GPT-2/3/4 | 2019–23 | Decoder | Scale + RLHF alignment; emergent few-shot ability | Chat, generation, agents |
| T5 | 2020 | Enc–Dec | "Text-to-text" — every NLP task framed as seq2seq | Summarisation, translation |
| Vision Transformer (ViT) | 2020 | Encoder | Patches as tokens — Transformer conquers images | Image classification |
| LLaMA / Mistral | 2023 | Decoder | RoPE, GQA, SwiGLU — efficient open-weight LLMs | Open-source LLM apps |
Modern LLMs modify the original design in several ways. RoPE (Rotary Positional Embedding) replaces sinusoidal PE for better length generalisation. RMSNorm replaces LayerNorm for efficiency. SwiGLU/GeGLU replaces ReLU in the FFN for better empirical quality. Grouped Query Attention (GQA) reduces the KV cache memory by sharing key/value heads across query heads. Flash Attention rewrites the attention kernel for hardware efficiency without changing its result. The core attention principle, however, has not changed.
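To make one of these changes concrete, here is a minimal sketch of RMSNorm, assuming the common formulation: divide by the root-mean-square of the features and apply a learned scale, with no mean subtraction and no bias. The class name and the `eps` default are illustrative choices, not taken from any particular model's code.

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Normalise by the root-mean-square over the last dimension."""
    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))  # learned per-feature scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * x / rms

torch.manual_seed(0)
x = torch.randn(2, 10, 512)
y = RMSNorm(512)(x)
print(y.shape)                               # torch.Size([2, 10, 512])
print(y.pow(2).mean(dim=-1).sqrt()[0, :3])   # each value is close to 1
```

Compared with LayerNorm, this skips computing and subtracting the mean, which is where the efficiency gain comes from.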
Computational Complexity — The Quadratic Problem
The attention score matrix has shape (seq_len × seq_len). For a sequence of 1,000 tokens, that's 1,000,000 values. For 10,000 tokens — a long document — it's 100,000,000 values. Memory and compute scale quadratically with sequence length. This is why standard Transformers were limited to 512–2048 tokens for years, and why a large chunk of modern AI research is dedicated to solving this bottleneck.
| Sequence Length | Attention Matrix Size | Memory (float32) | Practical? |
|---|---|---|---|
| 512 tokens | 262,144 entries | ~1 MB | Yes — standard |
| 2,048 tokens | 4,194,304 entries | ~16 MB | Yes — GPT-3 context |
| 8,192 tokens | 67,108,864 entries | ~256 MB | Marginal — needs Flash Attention |
| 128,000 tokens | 16,384,000,000 entries | ~61 GiB (65.5 GB) | Impractical naively; needs a kernel that never materialises the full matrix (e.g. Flash Attention) or sparse/linear attention |
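The table's figures follow from one multiplication. The sketch below recomputes them, assuming float32 (4 bytes per value) and counting a single attention matrix, that is, one head of one layer for one sequence; `attention_matrix_bytes` is a hypothetical helper name.

```python
def attention_matrix_bytes(seq_len: int, bytes_per_value: int = 4) -> int:
    """Bytes needed to store one (seq_len x seq_len) attention matrix."""
    return seq_len * seq_len * bytes_per_value

for n in (512, 2048, 8192, 128_000):
    size = attention_matrix_bytes(n)
    print(f"{n:>7} tokens: {n * n:>14,} entries, {size / 2**20:>8,.0f} MiB")
```

A real model holds one such matrix per head and per layer, so a naive implementation needs many times these amounts.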
Using a Pre-trained Transformer — Practical Guide
In practice, you almost never train a Transformer from scratch. You use a pre-trained model and either fine-tune it for your task or use it directly via the Hugging Face ecosystem.
# pip install transformers torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# ── ENCODER-ONLY: Text Classification with BERT ──────────────
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
# Note: the 2-label classification head is randomly initialised here, so the
# predictions below are arbitrary until the model is fine-tuned on labelled data
model = AutoModelForSequenceClassification.from_pretrained(
    'bert-base-uncased', num_labels=2
)
model.eval()  # make sure dropout is disabled for inference
texts = [
    "The Transformer architecture revolutionised NLP.",
    "I had a terrible experience with this product."
]
inputs = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
with torch.no_grad():  # no gradients needed for inference
    outputs = model(**inputs)
probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
for text, prob in zip(texts, probs):
    pred = "POSITIVE" if prob.argmax() == 1 else "NEGATIVE"
    print(f"{pred} ({prob.max():.2%}) - {text[:50]}")

# ── DECODER-ONLY: Text Generation with GPT-2 ─────────────────
from transformers import pipeline
gen = pipeline('text-generation', model='gpt2')
out = gen("The Transformer architecture works by",
          max_new_tokens=50, num_return_sequences=1)
print(out[0]['generated_text'])
For encoder models (BERT): add a classification head and fine-tune all weights on your labelled dataset; for simple tasks, around 1,000 examples is often enough. For decoder models (GPT, Claude): try few-shot prompting with 3–5 examples before fine-tuning. Fine-tune only if accuracy is critical and you have thousands of labelled examples. As a rule of thumb, large models (7B+ parameters) respond well to prompting, while small models (125M–1B) benefit more from fine-tuning.
Golden Rules — Transformers in Practice
Always load the tokeniser that matches your model checkpoint with AutoTokenizer.from_pretrained('model-name'). A mismatched tokeniser is a silent, hard-to-debug failure mode.