The Story That Explains Attention
The old approach: memorise every single word the diplomat says, compress it all into one mental note, then translate from that single note. The longer the speech, the harder it becomes to hold everything, and details from the beginning start to blur.
The attention approach: as you translate each English word, you glance back at the French words that are most relevant to the English word you're currently producing. Translating "bank"? You glance at the French word for context — was it "rive" (riverbank) or "banque" (financial bank)? You pay more attention to the words that matter right now.
That is the entire idea behind the Attention Mechanism.
Before attention, sequence models like RNNs were forced to compress an entire input sentence into a single fixed-length vector — a bottleneck that made them struggle with long sentences. Attention broke that constraint by allowing the model to look at all input tokens simultaneously and dynamically focus on the relevant ones at each decoding step.
Attention replaces the single bottleneck vector with a dynamic context vector built fresh at every decoding step — a weighted sum of all encoder hidden states, where weights reflect relevance. This lets models "look back" freely, solving the long-range dependency problem that crippled RNNs.
The Foundation — Why RNNs Needed Help
To understand why attention was revolutionary, you first need to feel the pain it solved. The classic encoder-decoder RNN architecture compressed an entire source sentence into a single vector — the final hidden state — before decoding.
RNNs suffer from two problems at once: vanishing gradients (early tokens receive near-zero gradient signal during backpropagation) and the information bottleneck (the entire input is squeezed into one vector). LSTMs partially solved the gradient problem, but the bottleneck remained until Bahdanau et al. introduced attention in 2014.
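The vanishing-gradient effect can be reproduced with arithmetic alone. The sketch below is an illustrative toy rather than a real translation model: it assumes a scalar RNN h_t = tanh(w · h_{t-1} + x) with a hypothetical recurrent weight w = 0.5 and a constant input x, and multiplies the per-step derivatives to get the gradient of the last state with respect to the first.

```python
import math

def gradient_through_time(steps, w=0.5, x=0.1):
    """Return |d h_T / d h_0| for a toy scalar RNN h_t = tanh(w * h_{t-1} + x)."""
    h = 0.0
    grad = 1.0
    for _ in range(steps):
        h = math.tanh(w * h + x)
        # Chain rule: d h_t / d h_{t-1} = w * (1 - h_t^2)
        grad *= w * (1.0 - h * h)
    return abs(grad)

for steps in (1, 5, 10, 50):
    print(f"{steps:3d} steps -> gradient magnitude {gradient_through_time(steps):.3e}")
```

Because every factor is at most 0.5 in this setup, the gradient shrinks at least geometrically with the number of steps, which is the signal loss that early tokens experience.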
Bahdanau Attention — The Original Breakthrough
Bahdanau attention removes the bottleneck by letting the decoder consult every encoder state. At each decoding step, it assigns a score (relevance weight) to every encoder hidden state, then builds a focused summary weighted by those scores. The decoder always has access to all encoder states; it just learns which ones matter most right now.
The Three-Step Mechanism
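Score formula: e_{ti} = a(s_{t-1}, h_i), where a is a small learned alignment network, s_{t-1} is the previous decoder state and h_i is the i-th encoder state
Weight formula: α_{ti} = exp(e_{ti}) / Σ_k exp(e_{tk}), i.e. a softmax over all source positions
Context: c_t = Σ_i α_{ti} · h_i, recomputed at every decoding step t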
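The three steps can be traced by hand on a tiny example. The following is a minimal sketch with made-up numbers: three encoder states of dimension 2, and hand-picked scores standing in for the output of the learned alignment function a(s_{t-1}, h_i). Only the softmax and the weighted sum reflect the actual mechanism.

```python
import numpy as np

# Three encoder hidden states h_1..h_3 (arbitrary values, for illustration only).
h = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [1.0, 1.0]])

# Step 1: alignment scores e_{ti} for one decoding step t.
# In the real model these come from a(s_{t-1}, h_i); here they are hand-picked.
e = np.array([2.0, 0.5, 0.1])

# Step 2: softmax over source positions gives the weights alpha_{ti}.
alpha = np.exp(e - e.max())
alpha /= alpha.sum()

# Step 3: the context vector is the weighted sum of encoder states.
c = alpha @ h

print("weights:", alpha.round(3))
print("context:", c.round(3))
```

With a different set of scores at the next decoding step, the same encoder states would produce a different context vector.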
The alignment matrix (attention weights visualised over source × target positions) was a revelation. For French-to-English translation, the diagonal structure showed the model had learned near-monotone alignment — without ever being told that "le chat" maps to "the cat". The model discovered linguistic structure purely from data.
Python Implementation — Bahdanau Attention
import torch
import torch.nn as nn
import torch.nn.functional as F


class BahdanauAttention(nn.Module):
    """
    Bahdanau (Additive) Attention, Bahdanau et al., 2014
    'Neural Machine Translation by Jointly Learning to Align and Translate'
    """

    def __init__(self, hidden_dim):
        super().__init__()
        # W_a: projects decoder state
        self.W_a = nn.Linear(hidden_dim, hidden_dim, bias=False)
        # U_a: projects encoder states
        self.U_a = nn.Linear(hidden_dim, hidden_dim, bias=False)
        # v_a: collapses to a scalar score
        self.v_a = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, decoder_state, encoder_outputs):
        # decoder_state:   (batch, hidden_dim)
        # encoder_outputs: (batch, seq_len, hidden_dim)

        # Step 1 (score): combine decoder state with each encoder state
        dec_proj = self.W_a(decoder_state).unsqueeze(1)   # (batch, 1, hidden)
        enc_proj = self.U_a(encoder_outputs)              # (batch, seq, hidden)
        energy = self.v_a(torch.tanh(dec_proj + enc_proj)).squeeze(-1)
        # energy: (batch, seq_len), the raw alignment scores

        # Step 2 (weights): softmax over source positions
        attn_weights = F.softmax(energy, dim=-1)          # (batch, seq_len)

        # Step 3 (context): weighted sum of encoder states
        context = torch.bmm(
            attn_weights.unsqueeze(1),   # (batch, 1, seq_len)
            encoder_outputs              # (batch, seq_len, hidden)
        ).squeeze(1)                     # → (batch, hidden)
        return context, attn_weights


# ── Quick test ──────────────────────────────────────────────
batch, seq_len, hidden = 4, 12, 256
attn = BahdanauAttention(hidden)
decoder_state = torch.randn(batch, hidden)
encoder_outputs = torch.randn(batch, seq_len, hidden)
context, weights = attn(decoder_state, encoder_outputs)
print(f"Context shape : {context.shape}")    # (4, 256)
print(f"Weights shape : {weights.shape}")    # (4, 12)
print(f"Weights sum : {weights.sum(-1)}")    # ≈ tensor([1., 1., 1., 1.])
Luong Attention — The Streamlined Successor
Think of two wine sommeliers comparing a glass to a reference. The first (Bahdanau) chemically analyses both wines, blends them, then scores the combination. The second (Luong) simply measures how similar the two wines taste by multiplying their taste profiles together — a dot product. Faster, and often just as accurate. Luong proposed three scoring variants and showed they match or surpass Bahdanau on many tasks.
Luong vs Bahdanau — Key Differences
| Property | Bahdanau (Additive) | Luong (Multiplicative) |
|---|---|---|
| Score function | v·tanh(Ws + Uh) — MLP | dot, general, or concat |
| Context timing | Used to compute decoder step t | Used after decoder step t is computed |
| Decoder state used | s_{t-1} (previous state) | h_t (current state) |
| Trainable params | More — W_a, U_a, v_a | Fewer — just W for "general" variant |
| Speed | Slightly slower | Faster — simpler computation |
| Typical performance | Excellent on long sentences | Competitive; "general" often best |
Three Luong Scoring Functions
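For reference, the three variants from Luong et al. (2015) can be written as follows. The symbols follow the paper's notation rather than anything defined earlier in this article: h_t is the current decoder state, \bar{h}_s is an encoder state, and W_a and v_a are learned parameters.

```latex
\mathrm{score}(h_t, \bar{h}_s) =
\begin{cases}
h_t^{\top} \bar{h}_s & \text{(dot)} \\
h_t^{\top} W_a \bar{h}_s & \text{(general)} \\
v_a^{\top} \tanh\!\left(W_a \, [h_t ; \bar{h}_s]\right) & \text{(concat)}
\end{cases}
```

The dot variant has no parameters at all, the general variant adds one matrix, and the concat variant is closest in spirit to Bahdanau's additive score.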
Python Implementation — Luong Attention
class LuongAttention(nn.Module):
    """
    Luong (Multiplicative) Attention, Luong et al., 2015
    'Effective Approaches to Attention-based Neural Machine Translation'
    Supports 'dot', 'general', and 'concat' scoring.
    """

    def __init__(self, hidden_dim, method="general"):
        super().__init__()
        self.method = method
        if method == "general":
            self.W_a = nn.Linear(hidden_dim, hidden_dim, bias=False)
        elif method == "concat":
            self.W_a = nn.Linear(hidden_dim * 2, hidden_dim, bias=False)
            self.v_a = nn.Linear(hidden_dim, 1, bias=False)

    def score(self, decoder_h, encoder_outputs):
        # decoder_h:       (batch, hidden)
        # encoder_outputs: (batch, seq_len, hidden)
        if self.method == "dot":
            # Simple dot product
            return torch.bmm(
                encoder_outputs,           # (batch, seq, hidden)
                decoder_h.unsqueeze(-1)    # (batch, hidden, 1)
            ).squeeze(-1)                  # → (batch, seq)
        elif self.method == "general":
            # Learned transformation of encoder states
            transformed = self.W_a(encoder_outputs)   # (batch, seq, hidden)
            return torch.bmm(
                transformed,
                decoder_h.unsqueeze(-1)
            ).squeeze(-1)
        elif self.method == "concat":
            # Concatenate and project through MLP
            dec_exp = decoder_h.unsqueeze(1).expand_as(encoder_outputs)
            combined = torch.cat([dec_exp, encoder_outputs], dim=-1)
            return self.v_a(torch.tanh(self.W_a(combined))).squeeze(-1)

    def forward(self, decoder_h, encoder_outputs):
        # Step 1 + 2: Score → Softmax → Weights
        energy = self.score(decoder_h, encoder_outputs)   # (batch, seq)
        attn_weights = F.softmax(energy, dim=-1)          # (batch, seq)
        # Step 3: Context vector
        context = torch.bmm(
            attn_weights.unsqueeze(1),   # (batch, 1, seq)
            encoder_outputs              # (batch, seq, hidden)
        ).squeeze(1)                     # → (batch, hidden)
        return context, attn_weights


# ── Compare all three methods ───────────────────────────────
for method in ["dot", "general", "concat"]:
    attn = LuongAttention(256, method=method)
    ctx, w = attn(decoder_state, encoder_outputs)
    sums_ok = w.sum(-1).allclose(torch.ones(4))
    print(f"[{method:8s}] context: {ctx.shape}, weights sum≈1: {sums_ok}")
Self-Attention — When a Sentence Reads Itself
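Consider the sentence: "The trophy didn't fit in the suitcase because it was too big." What does "it" refer to? The trophy or the suitcase?
Answer: the trophy. But how do you know? Because "too big" combined with "didn't fit" points to the trophy being the obstacle.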
A human resolves this by letting every word in the sentence attend to every other word — "it" attends strongly to "trophy", weakly to "suitcase". This is precisely self-attention: each token in a sequence computing relevance scores against all other tokens in the same sequence — including itself. No encoder/decoder split. The sequence attends to itself.
Self-attention (also called intra-attention) was the key architectural innovation in the Transformer (Vaswani et al., 2017). Unlike Bahdanau/Luong attention which connected encoder states to decoder states, self-attention lets every position in a sequence build a representation informed by every other position.
The Query–Key–Value Framework
Self-attention is built around three learned linear projections of every token embedding: Query (Q), Key (K), and Value (V).
Imagine a search engine: you type a Query ("best Italian restaurant"), the engine matches it against document Keys ("Italian cuisine near you"), and returns the Values (actual restaurant info). Self-attention does this simultaneously for every token, in parallel, using learnable linear projections for Q, K, V.
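The search-engine analogy can be made concrete as a "soft" dictionary lookup. The sketch below uses hand-written keys and values (hypothetical toy data, not learned projections) to show how a single query retrieves a blend of values weighted by query-key similarity; the helper name soft_lookup is introduced here purely for illustration.

```python
import numpy as np

def soft_lookup(query, keys, values):
    """Return (blended value, weights) for one query against all keys."""
    scores = keys @ query / np.sqrt(query.shape[-1])   # similarity per key
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()                            # softmax over keys
    return weights @ values, weights

# Three stored items: each key describes an item, each value is its payload.
keys = np.array([[4.0, 0.0],     # item 0
                 [0.0, 4.0],     # item 1
                 [2.0, 2.0]])    # item 2
values = np.array([[10.0], [20.0], [30.0]])

query = np.array([4.0, 0.0])     # most similar to item 0's key
result, weights = soft_lookup(query, keys, values)
print("weights:", weights.round(3))
print("result :", result.round(3))
```

Unlike a hard dictionary lookup, every item contributes a little; the best-matching key simply dominates the mixture.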
The Scaled Dot-Product Attention Formula
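The formula is Attention(Q, K, V) = softmax(QKᵀ / √d_k) · V. The division by √d_k matters: if the components of a query and a key are independent with zero mean and unit variance, their dot product has a variance of about d_k, so about 64 when d_k = 64. Large magnitudes push softmax into regions where gradients are near zero, and the model barely learns. Dividing by √64 = 8 brings the variance back to about 1, keeping gradients healthy. This small scaling term was an important part of making Transformers trainable.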
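The variance argument is easy to check empirically. The sketch below assumes query and key components drawn independently from a standard normal distribution, which is the idealised setting behind the argument rather than a property of trained models.

```python
import numpy as np

rng = np.random.default_rng(0)
d_k = 64
n_samples = 20_000

# Components drawn i.i.d. from N(0, 1): the assumption behind the variance argument.
q = rng.standard_normal((n_samples, d_k))
k = rng.standard_normal((n_samples, d_k))

raw = (q * k).sum(axis=1)        # unscaled dot products
scaled = raw / np.sqrt(d_k)      # scaled as in the Transformer

print(f"variance of q.k             : {raw.var():.1f}")
print(f"variance of q.k / sqrt(d_k) : {scaled.var():.2f}")
```

The first number should land near 64 and the second near 1, up to sampling noise.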
Python Implementation — Scaled Dot-Product Self-Attention
import math


class SelfAttention(nn.Module):
    """
    Single-head Scaled Dot-Product Self-Attention.
    Vaswani et al. 2017, 'Attention Is All You Need'
    """

    def __init__(self, d_model, d_k=None):
        super().__init__()
        self.d_k = d_k if d_k else d_model
        # Three learned projection matrices
        self.W_Q = nn.Linear(d_model, self.d_k, bias=False)
        self.W_K = nn.Linear(d_model, self.d_k, bias=False)
        self.W_V = nn.Linear(d_model, self.d_k, bias=False)

    def forward(self, x, mask=None):
        # x: (batch, seq_len, d_model)

        # Step 1: Project to Q, K, V spaces
        Q = self.W_Q(x)   # (batch, seq, d_k)
        K = self.W_K(x)   # (batch, seq, d_k)
        V = self.W_V(x)   # (batch, seq, d_k)

        # Step 2: Compute scaled dot-product scores
        # scores[i,j] = how much token i attends to token j
        scale = math.sqrt(self.d_k)
        scores = torch.bmm(Q, K.transpose(1, 2)) / scale   # (batch, seq, seq)

        # Step 3: Optional mask (causal / padding)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        # Step 4: Softmax → attention weights
        attn_weights = F.softmax(scores, dim=-1)   # (batch, seq, seq)

        # Step 5: Weighted sum of values
        output = torch.bmm(attn_weights, V)        # (batch, seq, d_k)
        return output, attn_weights


# ── Demo ─────────────────────────────────────────────────────
batch, seq, d_model = 2, 8, 64
sa = SelfAttention(d_model)
x_in = torch.randn(batch, seq, d_model)
out, weights = sa(x_in)
print(f"Input shape: {x_in.shape}")       # (2, 8, 64)
print(f"Output shape: {out.shape}")       # (2, 8, 64)
print(f"Attn shape: {weights.shape}")     # (2, 8, 8)
print(f"Row sums ≈ 1: {weights[0].sum(-1)}")
Attention Map — Visualising "it" Resolution
Multi-Head Attention — Parallel Perspectives
Multi-head attention runs h independent attention mechanisms in parallel, each looking at the same sentence but from a different learned representational subspace. One head might specialise in syntactic dependencies ("subject-verb agreement"). Another in coreference ("it" → "trophy"). Another in positional proximity. The final output concatenates all heads, giving the model multiple simultaneous perspectives on the same input.
Why Multiple Heads?
The Full Multi-Head Attention Formula
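In the notation of Vaswani et al. (2017), the multi-head computation can be written as below. The shapes assumed here match the parameter table in this article: each W_i^Q and W_i^K maps d_model to d_k, each W_i^V maps d_model to d_v, and W^O maps the concatenated h · d_v dimensions back to d_model.

```latex
\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W^{O},
\qquad
\mathrm{head}_i = \mathrm{Attention}\!\left(Q W_i^{Q},\; K W_i^{K},\; V W_i^{V}\right)
```

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
```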
Parameter Count — Where Does the Memory Go?
| Component | Shape | Parameters (d_model=512, h=8) |
|---|---|---|
| W_i^Q per head | (d_model, d_k) | 512 × 64 = 32,768 |
| W_i^K per head | (d_model, d_k) | 512 × 64 = 32,768 |
| W_i^V per head | (d_model, d_v) | 512 × 64 = 32,768 |
| All 8 heads (Q+K+V) | 3 × 8 matrices | 3 × 8 × 32,768 = 786,432 |
| W^O output proj | (d_model, d_model) | 512 × 512 = 262,144 |
| Total per attention layer (no biases) | - | 786,432 + 262,144 = 1,048,576 (≈ 1M params) |
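The arithmetic in the table generalises to any model width and head count. The helper below is a small sketch (the function name mha_param_count and its bias flag are introduced here for illustration) that assumes d_k = d_v = d_model / h, as in the original Transformer.

```python
def mha_param_count(d_model, num_heads, bias=False):
    """Parameter count of one multi-head attention layer with d_k = d_v = d_model / h."""
    d_k = d_model // num_heads
    per_head = d_model * d_k               # one of W_i^Q, W_i^K or W_i^V
    qkv = 3 * num_heads * per_head         # all heads, all three projections
    out = d_model * d_model                # W^O
    total = qkv + out
    if bias:
        total += 4 * d_model               # one bias vector per projection
    return total

print(f"{mha_param_count(512, 8):,}")      # 1,048,576
print(f"{mha_param_count(768, 12):,}")     # 2,359,296
```

Note that the head count cancels out: splitting d_model across more heads changes how the parameters are used, not how many there are.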
Python Implementation — Multi-Head Attention
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Scaled Dot-Product Attention.
    Vaswani et al. 2017, 'Attention Is All You Need'
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads   # per-head dimension
        # Unified projection matrices: project all heads at once, then split
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)   # output projection

    def split_heads(self, x):
        # x: (batch, seq, d_model) → (batch, num_heads, seq, d_k)
        batch, seq, _ = x.size()
        x = x.view(batch, seq, self.num_heads, self.d_k)
        return x.transpose(1, 2)   # → (batch, heads, seq, d_k)

    def scaled_dot_product(self, Q, K, V, mask=None):
        scale = math.sqrt(self.d_k)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        weights = F.softmax(scores, dim=-1)
        return torch.matmul(weights, V), weights

    def forward(self, query, key, value, mask=None):
        # For self-attention: query = key = value = x
        # For cross-attention: query from decoder, key/value from encoder
        batch = query.size(0)

        # Step 1: Project and split into h heads
        Q = self.split_heads(self.W_Q(query))   # (batch, heads, seq_q, d_k)
        K = self.split_heads(self.W_K(key))     # (batch, heads, seq_k, d_k)
        V = self.split_heads(self.W_V(value))   # (batch, heads, seq_v, d_k)

        # Step 2: Per-head attention (runs in parallel on all heads)
        x, attn_weights = self.scaled_dot_product(Q, K, V, mask)
        # x: (batch, heads, seq, d_k)

        # Step 3: Concatenate heads
        x = x.transpose(1, 2).contiguous()      # (batch, seq, heads, d_k)
        x = x.view(batch, -1, self.d_model)     # (batch, seq, d_model)

        # Step 4: Final output projection
        output = self.W_O(x)                    # (batch, seq, d_model)
        return output, attn_weights


# ── Demo: BERT-base style (12 heads, d_model=768) ───────────
d_model, num_heads = 768, 12
mha = MultiHeadAttention(d_model, num_heads)
x_bert = torch.randn(2, 16, d_model)    # batch=2, seq=16
out, w = mha(x_bert, x_bert, x_bert)    # self-attention
total_params = sum(p.numel() for p in mha.parameters())
print(f"Output shape : {out.shape}")               # (2, 16, 768)
print(f"Attn map shape : {w.shape}")               # (2, 12, 16, 16)
print(f"Total params : {total_params:,}")          # 2,359,296
print(f"d_k per head : {d_model // num_heads}")    # 64
Interpretability research (Clark et al., 2019, "What Does BERT Look At?") found that individual attention heads in BERT often specialise: some heads track direct objects of verbs, others track coreferent mentions, some attend almost exclusively to the previous or next token, and many put a large share of their attention on special tokens such as [SEP]. Many heads behave more like specialists than like redundant copies of one another.
Three Types of Attention in a Transformer
The full Transformer architecture uses multi-head attention in three distinct roles, each with a different configuration of Q, K, V sources.
| Attention Type | Query Source | Key/Value Source | Masking? | Example Model |
|---|---|---|---|---|
| Encoder Self-Attention | Encoder layer input | Same | None | BERT, RoBERTa |
| Decoder Self-Attention | Decoder layer input | Same | Causal mask | GPT-2, GPT-4, LLaMA |
| Cross-Attention | Decoder states | Encoder outputs | None | T5, BART, mT5 |
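The causal mask used in decoder self-attention can be illustrated in a few lines. The sketch below uses NumPy and random scores (illustrative data only) to show that, after masking, position i places zero weight on every later position j > i; the function name causal_attention_weights is introduced here for illustration.

```python
import numpy as np

def causal_attention_weights(scores):
    """Softmax over keys after hiding future positions (j > i)."""
    n = scores.shape[-1]
    mask = np.tril(np.ones((n, n), dtype=bool))      # True where j <= i
    masked = np.where(mask, scores, -np.inf)         # future positions get -inf
    masked = masked - masked.max(axis=-1, keepdims=True)
    weights = np.exp(masked)                         # exp(-inf) = 0
    return weights / weights.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
scores = rng.standard_normal((4, 4))
weights = causal_attention_weights(scores)
print(weights.round(2))
```

The first row has a single non-zero entry (the token can only see itself), and each later row spreads its weight over one more position.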
Attention Complexity — The Elephant in the Room
Multi-head attention is powerful but expensive. Understanding its complexity is essential for production systems, and it explains why sparse-attention models like Longformer and optimised kernels like FlashAttention exist.
| Resource | Standard Attention | Implication |
|---|---|---|
| Time complexity | O(n² · d) | Quadratic in sequence length n — doubling the sequence length quadruples compute time |
| Memory complexity | O(n²) | The attention matrix (n × n) must fit in GPU memory — bottleneck for long documents |
| Sequence length 512 | 262,144 entries | Manageable on modern hardware |
| Sequence length 4,096 | 16,777,216 entries | Requires efficient attention (FlashAttention, etc.) |
| Sequence length 32,768 | 1,073,741,824 entries (over 1 billion) | Impractical to materialise for every head and layer; needs tiled or sparse attention |
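The entry counts in the table translate directly into memory. The sketch below assumes 2 bytes per entry (fp16 or bf16) and counts a single head of a single sequence; real usage multiplies this by the number of heads and the batch size.

```python
def attention_matrix_cost(seq_len, bytes_per_entry=2):
    """Return (entries, MiB) for one n x n attention matrix."""
    entries = seq_len * seq_len
    return entries, entries * bytes_per_entry / 2**20

for n in (512, 4096, 32768):
    entries, mib = attention_matrix_cost(n)
    print(f"n={n:6d}: {entries:>13,} entries, {mib:8.1f} MiB per head")
```

Going from 512 to 32,768 tokens is a 64× longer sequence but a 4,096× larger matrix, which is the quadratic growth in action.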
FlashAttention (Dao et al., 2022) reorders the attention computation to be IO-aware: it tiles the Q, K, V matrices and avoids materialising the full n×n attention matrix in slow HBM memory. It computes the same mathematical result as standard attention, but its memory use grows linearly rather than quadratically with sequence length, and it is commonly reported to run about 2–4× faster in wall-clock time. Many modern LLM training and inference stacks rely on FlashAttention or similar fused kernels.
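The core trick that makes tiling possible is an "online" softmax: keys and values are processed block by block while a running maximum and running normaliser are kept per query. The NumPy sketch below illustrates only that idea. It is not the FlashAttention kernel (it tiles keys and values but not queries, and ignores GPU memory entirely), and the function names are introduced here for illustration.

```python
import numpy as np

def standard_attention(Q, K, V):
    """Reference implementation that materialises the full (n, n) matrix."""
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)
    w = np.exp(scores)
    w /= w.sum(axis=-1, keepdims=True)
    return w @ V

def tiled_attention(Q, K, V, block=4):
    """Same result, but only an (n, block) slice of scores exists at any time."""
    n, d = Q.shape
    out = np.zeros((n, V.shape[-1]))
    row_max = np.full((n, 1), -np.inf)   # running max of scores per query
    row_sum = np.zeros((n, 1))           # running softmax denominator
    for start in range(0, K.shape[0], block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        s = Q @ Kb.T / np.sqrt(d)                          # (n, block)
        new_max = np.maximum(row_max, s.max(axis=-1, keepdims=True))
        correction = np.exp(row_max - new_max)             # rescale old totals
        p = np.exp(s - new_max)
        row_sum = row_sum * correction + p.sum(axis=-1, keepdims=True)
        out = out * correction + p @ Vb
        row_max = new_max
    return out / row_sum

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((16, 8)) for _ in range(3))
print(np.allclose(standard_attention(Q, K, V), tiled_attention(Q, K, V)))
```

The rescaling by correction is what lets partial results computed under an older maximum be merged with newer blocks without ever revisiting them.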
Positional Encoding — Giving Attention a Sense of Order
Self-attention on its own has no notion of word order: shuffle the input tokens and each token's output is unchanged, just shuffled along with it. The Transformer solves this by adding a positional encoding to each token embedding before feeding it into the attention layers. Each position in the sequence gets a unique signal, like stamping each telegram with a timestamp, so the model can distinguish "word at position 3" from "word at position 7".
import numpy as np


def positional_encoding(max_seq, d_model):
    """
    Compute sinusoidal positional encodings.
    Returns a NumPy array of shape (max_seq, d_model); d_model is assumed even.
    """
    PE = np.zeros((max_seq, d_model))
    pos = np.arange(max_seq)[:, np.newaxis]        # (max_seq, 1)
    i = np.arange(0, d_model, 2)[np.newaxis, :]    # even indices 0, 2, 4, ...: (1, d_model/2)
    # i already holds the even dimension index 2k, so the exponent is i / d_model
    div_term = np.power(10000, i / d_model)
    PE[:, 0::2] = np.sin(pos / div_term)           # even dims → sine
    PE[:, 1::2] = np.cos(pos / div_term)           # odd dims → cosine
    return PE


PE = positional_encoding(100, 512)
print(f"PE shape: {PE.shape}")                               # (100, 512)
print(f"Values range: [{PE.min():.3f}, {PE.max():.3f}]")     # [-1.000, 1.000]
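To see why positional information has to be injected at all, it helps to confirm that plain self-attention treats its input as an unordered set. The sketch below uses random weights and inputs (illustrative only, with no positional encoding added) and checks that permuting the tokens merely permutes the outputs.

```python
import numpy as np

def self_attention(x, W_q, W_k, W_v):
    """Single-head self-attention on one sequence, without positional encoding."""
    Q, K, V = x @ W_q, x @ W_k, x @ W_v
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)
    w = np.exp(scores)
    w /= w.sum(axis=-1, keepdims=True)
    return w @ V

rng = np.random.default_rng(0)
seq_len, d_model = 5, 8
x = rng.standard_normal((seq_len, d_model))
W_q, W_k, W_v = (rng.standard_normal((d_model, d_model)) for _ in range(3))

perm = np.array([4, 2, 0, 3, 1])
out = self_attention(x, W_q, W_k, W_v)
out_shuffled = self_attention(x[perm], W_q, W_k, W_v)

# Shuffling the tokens only shuffles the outputs: order carries no information.
print(np.allclose(out[perm], out_shuffled))   # True
```

Once a position-dependent vector is added to each row of x, this equality no longer holds, which is exactly the point of the encoding.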
Full Transformer Block — Putting It All Together
class TransformerEncoderBlock(nn.Module):
    """
    A single Transformer encoder block:
    Multi-Head Self-Attention → Add&Norm → FFN → Add&Norm
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        # ── Multi-head self-attention ───────────────────────────
        self.mha = MultiHeadAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.drop1 = nn.Dropout(dropout)
        # ── Feed-forward network ────────────────────────────────
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )
        self.norm2 = nn.LayerNorm(d_model)
        self.drop2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # 1. Multi-head self-attention + residual
        attn_out, _ = self.mha(x, x, x, mask)
        x = self.norm1(x + self.drop1(attn_out))   # Add & Norm
        # 2. Position-wise feed-forward + residual
        ffn_out = self.ffn(x)
        x = self.norm2(x + self.drop2(ffn_out))    # Add & Norm
        return x


# ── Stack 6 blocks (original Transformer encoder) ───────────
# nn.Sequential passes only x between blocks, so no mask is applied here.
encoder = nn.Sequential(
    *[TransformerEncoderBlock(d_model=512, num_heads=8, d_ff=2048)
      for _ in range(6)]
)
src = torch.randn(2, 32, 512)   # batch=2, seq=32, d_model=512
out = encoder(src)
total = sum(p.numel() for p in encoder.parameters())
print(f"Encoder output : {out.shape}")   # (2, 32, 512)
print(f"Encoder params : {total:,}")     # 18,902,016 (about 18.9M, embeddings excluded)
Comparison — Bahdanau vs Luong vs Self-Attention
| Property | Bahdanau (2014) | Luong (2015) | Self-Attention / MHA (2017) |
|---|---|---|---|
| Architecture role | Encoder → Decoder bridge | Encoder → Decoder bridge | Within-sequence or cross-sequence |
| Score function | Additive MLP (v·tanh(Ws+Uh)) | Dot / General / Concat | Scaled dot-product (QKᵀ/√d_k) |
| Parallelisable? | No — sequential decoding | No — sequential decoding | Yes — O(1) sequential steps |
| Multiple perspectives? | Single head | Single head | h parallel heads |
| Seq complexity | O(n) sequential steps | O(n) sequential steps | O(n²) attention matrix |
| State of the art in 2025? | Superseded | Superseded | Foundation of most modern LLMs |
| Still studied? | Historically important | Historically important | Active research frontier |
Golden Rules — Attention in Practice
Prefer the built-in torch.nn.functional.scaled_dot_product_attention (PyTorch 2.0+), which automatically dispatches to FlashAttention when available.
Never implement naive O(n²) attention for production systems.
Inspect attention weights when debugging, for example with BertViz or a simple heatmap with matplotlib.
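As a sketch of the first rule, the snippet below calls the fused PyTorch function and compares it with the explicit computation on random tensors. It assumes PyTorch 2.0 or newer; which backend actually runs (FlashAttention, a memory-efficient kernel, or the plain math path) depends on the hardware and dtypes, so on CPU this only demonstrates the API and the numerical agreement.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq, d_k = 2, 4, 16, 32
q = torch.randn(batch, heads, seq, d_k)
k = torch.randn(batch, heads, seq, d_k)
v = torch.randn(batch, heads, seq, d_k)

# Fused call (PyTorch 2.0+); is_causal=True applies a lower-triangular mask.
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Reference: the explicit computation, which materialises the (seq, seq) matrix.
scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
causal = torch.tril(torch.ones(seq, seq, dtype=torch.bool))
scores = scores.masked_fill(~causal, float("-inf"))
manual = torch.softmax(scores, dim=-1) @ v

print(fused.shape)
print(torch.allclose(fused, manual, atol=1e-4))
```

The inputs use the (batch, heads, seq, d_k) layout, the same layout produced by split_heads in the MultiHeadAttention class, so the fused call can stand in for a hand-written scaled dot-product step.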
Modern Variants and Where Attention Is Going
Use full multi-head attention for sequences under 4K tokens and when fine-tuning pretrained models. Use FlashAttention (via PyTorch 2.0 SDPA) for GPU training wherever it is supported. Use grouped-query attention (GQA) or multi-query attention (MQA) when deploying with long contexts and memory constraints. For research on contexts longer than 32K tokens, explore sliding-window and global-token hybrids. Bahdanau and Luong attention remain useful for teaching and for custom seq2seq tasks where you want explicit control over the alignment model.
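The memory argument for GQA and MQA comes down to how many key/value heads must be cached during generation. The arithmetic below is a rough sketch under stated assumptions: a hypothetical 32-layer model with 32 query heads of dimension 128, fp16 storage, and one 32K-token sequence. None of these figures describe a specific published model.

```python
def kv_cache_gib(seq_len, num_layers, num_kv_heads, head_dim, bytes_per_value=2):
    """Size of the key/value cache for one sequence, in GiB."""
    # Factor 2 accounts for storing both keys and values.
    values = 2 * num_layers * num_kv_heads * head_dim * seq_len
    return values * bytes_per_value / 2**30

# Hypothetical configuration, chosen only to make the comparison concrete.
config = dict(seq_len=32_768, num_layers=32, head_dim=128)
for name, kv_heads in [("MHA (32 KV heads)", 32), ("GQA (8 KV heads)", 8), ("MQA (1 KV head)", 1)]:
    size = kv_cache_gib(num_kv_heads=kv_heads, **config)
    print(f"{name:18s}: {size:6.2f} GiB")
```

Sharing key/value heads across groups of query heads shrinks the cache in direct proportion to the number of KV heads, while the number of query heads, and hence most of the model's capacity, stays the same.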