
Attention Mechanisms in NLP

A code-first tutorial covering the foundational attention mechanisms in NLP. It starts from the problem RNNs handled poorly, long-range dependencies, and then walks through Bahdanau additive attention and its alignment matrices, Luong's three scoring variants (dot, general, concat), the Query-Key-Value self-attention framework, and multi-head attention with its parallel heads.

Section 01

The Story That Explains Attention

The Translator at a Noisy Conference
Imagine you're a live translator at an international conference. A diplomat speaks a long, complex sentence in French. Your task is to translate it into English — word by word — as it is being spoken.

The old approach: memorise every word the diplomat says, compress it all into one mental note, then translate from that single note. The longer the speech, the harder it is to hold everything, and details from the beginning start to blur.

The attention approach: as you translate each English word, you glance back at the French words that are most relevant to the English word you're currently producing. Translating "bank"? You glance at the French word for context — was it "rive" (riverbank) or "banque" (financial bank)? You pay more attention to the words that matter right now.

That is the entire idea behind the Attention Mechanism.

Before attention, sequence models like RNNs were forced to compress an entire input sentence into a single fixed-length vector — a bottleneck that made them struggle with long sentences. Attention broke that constraint by allowing the model to look at all input tokens simultaneously and dynamically focus on the relevant ones at each decoding step.

💡
The Core Insight

Attention replaces the single bottleneck vector with a dynamic context vector built fresh at every decoding step — a weighted sum of all encoder hidden states, where weights reflect relevance. This lets models "look back" freely, solving the long-range dependency problem that crippled RNNs.


Section 02

The Foundation — Why RNNs Needed Help

To understand why attention was revolutionary, you first need to feel the pain it solved. The classic encoder-decoder RNN architecture compressed an entire source sentence into a single vector — the final hidden state — before decoding.

⚠️ The Bottleneck Problem — Machine Translation
Input
"The cat that the dog that the child played with chased sat on the mat." (15 words) compressed into one vector
Encode
RNN reads left-to-right, updating hidden state at each step — early words are gradually overwritten
Bottleneck
Final hidden state h_n must encode the entire sentence — impossible for long inputs
Decode
Every output word is generated from the same fixed vector h_n, regardless of which input word is most relevant
Result
Translation quality degrades sharply as sentences get longer, as shown empirically by Cho et al. (2014)
⚠️
The Vanishing Gradient + Bottleneck Double Curse

RNNs suffer two problems simultaneously: vanishing gradients (early tokens receive near-zero gradient signal during backpropagation) and the information bottleneck (the entire input is squeezed into one vector). LSTMs partially solved the gradient problem but the bottleneck remained — until Bahdanau et al. introduced attention in 2014.


Section 03

Bahdanau Attention — The Original Breakthrough

The Detective and the Witness Board
Imagine a detective reconstructing a crime. On a board hang photographs of every witness statement. When writing the conclusion about the getaway car, the detective doesn't re-read every statement equally — they focus most intensely on the three witnesses who mentioned the car, glance briefly at those who mentioned the street, and ignore those talking about the victim's clothes.

Bahdanau attention works exactly like this. At each decoding step, it assigns a score (relevance weight) to every encoder hidden state, then builds a focused summary weighted by those scores. The decoder always has access to all encoder states — it just learns which ones matter most right now.

The Three-Step Mechanism

01
Score Every Encoder State
For the previous decoder hidden state s_{t-1} and each encoder hidden state h_i, compute a scalar alignment score using a small feed-forward network (the "alignment model"), trained jointly with the rest of the system.
Score formula: e_{ti} = a(s_{t-1}, h_i)
02
Convert Scores to Weights via Softmax
Apply softmax across all encoder positions to get attention weights that sum to 1. A weight close to 1 means "focus here". A weight close to 0 means "ignore this".
Weight formula: α_{ti} = softmax(e_{ti})
03
Build the Context Vector
Compute a weighted sum of all encoder hidden states — a context vector that summarises the most relevant parts of the input for this specific decoding step.
Context: c_t = Σ α_{ti} · h_i — this changes every step!
Alignment Score
e_{ti} = v_a · tanh(W_a s_{t-1} + U_a h_i)
A one-hidden-layer MLP with learned weight matrices W_a, U_a and a learned vector v_a that maps the tanh output to a scalar relevance score.
Attention Weights
α_{ti} = exp(e_{ti}) / Σ exp(e_{tj})
Standard softmax normalisation across all source positions. Ensures weights are non-negative and sum to 1.
Context Vector
c_t = Σ_{i=1}^{T_x} α_{ti} · h_i
Weighted sum of all encoder states. Dynamically reconstructed at every decoding timestep t.
Decoder Update
s_t = f(s_{t-1}, y_{t-1}, c_t)
The new decoder state depends on the previous state, the last output word, and the fresh context vector.
🌟
What Bahdanau Changed Forever

The alignment matrix (attention weights visualised over source × target positions) was a revelation. For English-to-French translation it showed a largely diagonal structure with local reorderings such as adjective-noun swaps, even though the model was never told that "the cat" maps to "le chat". The model discovered linguistic structure purely from data.

Python Implementation — Bahdanau Attention

import torch
import torch.nn as nn
import torch.nn.functional as F

class BahdanauAttention(nn.Module):
    """
    Bahdanau (Additive) Attention — Bahdanau et al., 2014
    'Neural Machine Translation by Jointly Learning to Align and Translate'
    """

    def __init__(self, hidden_dim):
        super().__init__()
        # W_a: projects decoder state
        self.W_a = nn.Linear(hidden_dim, hidden_dim, bias=False)
        # U_a: projects encoder states
        self.U_a = nn.Linear(hidden_dim, hidden_dim, bias=False)
        # v_a: collapses to a scalar score
        self.v_a = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, decoder_state, encoder_outputs):
        # decoder_state:   (batch, hidden_dim)
        # encoder_outputs: (batch, seq_len, hidden_dim)

        # Step 1: Score — combine decoder state with each encoder state
        dec_proj = self.W_a(decoder_state).unsqueeze(1)  # (batch,1,hidden)
        enc_proj = self.U_a(encoder_outputs)               # (batch,seq,hidden)
        energy   = self.v_a(torch.tanh(dec_proj + enc_proj)).squeeze(-1)
        # energy: (batch, seq_len)  — raw alignment scores

        # Step 2: Weights — softmax over source positions
        attn_weights = F.softmax(energy, dim=-1)  # (batch, seq_len)

        # Step 3: Context — weighted sum of encoder states
        context = torch.bmm(
            attn_weights.unsqueeze(1),   # (batch, 1, seq_len)
            encoder_outputs               # (batch, seq_len, hidden)
        ).squeeze(1)                      # → (batch, hidden)

        return context, attn_weights

# ── Quick test ──────────────────────────────────────────────
batch, seq_len, hidden = 4, 12, 256
attn = BahdanauAttention(hidden)

decoder_state    = torch.randn(batch, hidden)
encoder_outputs  = torch.randn(batch, seq_len, hidden)

context, weights = attn(decoder_state, encoder_outputs)
print(f"Context shape : {context.shape}")   # (4, 256)
print(f"Weights shape : {weights.shape}")   # (4, 12)
print(f"Weights sum   : {weights.sum(-1)}") # tensor([1., 1., 1., 1.])
OUTPUT
Context shape : torch.Size([4, 256])
Weights shape : torch.Size([4, 12])
Weights sum   : tensor([1.0000, 1.0000, 1.0000, 1.0000], grad_fn=<SumBackward1>)
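The class above computes only the context vector c_t. To see how c_t feeds the decoder update s_t = f(s_{t-1}, y_{t-1}, c_t), and where the alignment matrix comes from, here is a minimal NumPy sketch of a decoding loop. It rests on simplifying assumptions that are not from the paper: the weights are random and untrained, a plain tanh RNN cell stands in for f (Bahdanau et al. used a gated recurrent unit), and random vectors stand in for the embeddings of previously emitted words. It shows the data flow, not translation quality.

```python
import numpy as np

rng = np.random.default_rng(0)
hidden, src_len, tgt_len = 8, 5, 4

# Hypothetical encoder annotations h_1..h_Tx (random stand-ins).
H = rng.standard_normal((src_len, hidden))

# Additive attention parameters W_a, U_a, v_a (random, untrained).
W_a = rng.standard_normal((hidden, hidden)) * 0.3
U_a = rng.standard_normal((hidden, hidden)) * 0.3
v_a = rng.standard_normal(hidden) * 0.3

# Assumed simple tanh RNN cell standing in for f(s_{t-1}, y_{t-1}, c_t).
W_s = rng.standard_normal((hidden, hidden)) * 0.3
W_y = rng.standard_normal((hidden, hidden)) * 0.3
W_c = rng.standard_normal((hidden, hidden)) * 0.3

def softmax(x):
    e = np.exp(x - x.max())          # subtract the max for numerical stability
    return e / e.sum()

def attend(s_prev, H):
    # e_ti = v_a . tanh(W_a s_{t-1} + U_a h_i), for every source position i at once
    energy = np.tanh(s_prev @ W_a.T + H @ U_a.T) @ v_a   # (src_len,)
    alpha = softmax(energy)                               # attention weights
    context = alpha @ H                                   # c_t, shape (hidden,)
    return context, alpha

s = np.zeros(hidden)                   # initial decoder state s_0
y_prev = rng.standard_normal(hidden)   # stand-in for the start-token embedding
alignment = np.zeros((tgt_len, src_len))

for t in range(tgt_len):
    c_t, alpha = attend(s, H)                          # attention uses s_{t-1}
    s = np.tanh(W_s @ s + W_y @ y_prev + W_c @ c_t)    # s_t = f(s_{t-1}, y_{t-1}, c_t)
    alignment[t] = alpha                               # one alignment row per target step
    y_prev = rng.standard_normal(hidden)               # stand-in for the emitted word's embedding

print(alignment.round(3))
print(alignment.sum(axis=1))           # each row sums to 1
```

Stacking one weight row per target step gives the target × source alignment matrix that Bahdanau et al. visualised; with trained weights, its rows form the near-diagonal patterns described above.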

Section 04

Luong Attention — The Streamlined Successor

The Sommelier's Two Techniques
Bahdanau attention uses a small neural network to compare decoder and encoder states — elegant but relatively expensive. Minh-Thang Luong (2015) asked: can we get equal quality with simpler math?

Think of two wine sommeliers comparing a glass to a reference. The first (Bahdanau) chemically analyses both wines, blends the results, then scores the combination. The second (Luong) simply measures how similar the two wines taste by multiplying their taste profiles together: a dot product. It is faster and often just as accurate. Luong et al. proposed three scoring variants and reported strong results with them on WMT English–German translation.

Luong vs Bahdanau — Key Differences

| Property | Bahdanau (Additive) | Luong (Multiplicative) |
|---|---|---|
| Score function | v·tanh(Ws + Uh), an MLP | dot, general, or concat |
| Context timing | Computed before decoder step t and fed into it | Computed after decoder step t, from the new state |
| Decoder state used | s_{t-1} (previous state) | h_t (current state) |
| Trainable params | More: W_a, U_a, v_a | Fewer: none for "dot", just W_a for "general" |
| Speed | Slightly slower | Faster, simpler computation |
| Typical performance | Strong on long sentences | Competitive; the best variant depends on the setup |
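To make the "Trainable params" row concrete, the sketch below counts the weights in each scoring function for a shared hidden size d. It assumes encoder and decoder states have the same size and that no biases are used, matching the implementations in this tutorial rather than any specific paper configuration.

```python
def score_params(d):
    """Learned weights in each attention scoring function for hidden size d (no biases)."""
    return {
        "bahdanau": d * d + d * d + d,   # W_a (d×d), U_a (d×d), v_a (d)
        "luong_dot": 0,                  # plain inner product, nothing learned
        "luong_general": d * d,          # a single W_a (d×d)
        "luong_concat": 2 * d * d + d,   # W_a (d×2d) and v_a (d)
    }

for name, n in score_params(256).items():
    print(f"{name}: {n:,}")
# bahdanau: 131,328
# luong_dot: 0
# luong_general: 65,536
# luong_concat: 131,328
```

Under these assumptions "concat" has exactly the same parameter count as Bahdanau's additive score, which fits the observation that the two formulations are nearly identical.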

Three Luong Scoring Functions

🤔
Dot Score
score(h_t, h̄_s) = h_t · h̄_s
Simple inner product of decoder and encoder state. Requires both to have the same dimensionality. Fast — no learned parameters. Surprisingly effective for many tasks.
✓ No extra params, very fast
✗ Requires same dim for encoder/decoder
🔧
General Score
score = h_t · W_a · h̄_s
Introduces a learned weight matrix W_a between the two states, allowing different dimensionalities and learning which encoder and decoder dimensions to align. In Luong et al.'s experiments, "general" worked better with their local attention, while "dot" worked well with global attention.
✓ Flexible dims; learns alignment bias
✗ O(d²) parameters in W_a
📚
Concat Score
score = v·tanh(W·[h_t; h̄_s])
Concatenate both states and apply an MLP — very similar to Bahdanau's formulation. Luong included this for comparison. Slightly slower but captures complex interactions.
✓ Captures nonlinear interactions
✗ Closer to Bahdanau; slower

Python Implementation — Luong Attention

class LuongAttention(nn.Module):
    """
    Luong (Multiplicative) Attention — Luong et al., 2015
    'Effective Approaches to Attention-based Neural Machine Translation'
    Supports 'dot', 'general', and 'concat' scoring.
    """

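    def __init__(self, hidden_dim, method="general"):
        super().__init__()
        # Fail early instead of letting score() silently return None
        if method not in ("dot", "general", "concat"):
            raise ValueError(f"Unknown scoring method: {method!r}")
        self.method = method

        if method == "general":
            self.W_a = nn.Linear(hidden_dim, hidden_dim, bias=False)

        elif method == "concat":
            self.W_a = nn.Linear(hidden_dim * 2, hidden_dim, bias=False)
            self.v_a = nn.Linear(hidden_dim, 1, bias=False)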

    def score(self, decoder_h, encoder_outputs):
        # decoder_h:       (batch, hidden)
        # encoder_outputs: (batch, seq_len, hidden)

        if self.method == "dot":
            # Simple dot product
            return torch.bmm(
                encoder_outputs,                  # (batch, seq, hidden)
                decoder_h.unsqueeze(-1)           # (batch, hidden, 1)
            ).squeeze(-1)                         # → (batch, seq)

        elif self.method == "general":
            # Learned transformation of encoder states
            transformed = self.W_a(encoder_outputs)  # (batch, seq, hidden)
            return torch.bmm(
                transformed,
                decoder_h.unsqueeze(-1)
            ).squeeze(-1)

        elif self.method == "concat":
            # Concatenate and project through MLP
            dec_exp = decoder_h.unsqueeze(1).expand_as(encoder_outputs)
            combined = torch.cat([dec_exp, encoder_outputs], dim=-1)
            return self.v_a(torch.tanh(self.W_a(combined))).squeeze(-1)

    def forward(self, decoder_h, encoder_outputs):
        # Step 1 + 2: Score → Softmax → Weights
        energy      = self.score(decoder_h, encoder_outputs) # (batch, seq)
        attn_weights = F.softmax(energy, dim=-1)             # (batch, seq)

        # Step 3: Context vector
        context = torch.bmm(
            attn_weights.unsqueeze(1),    # (batch, 1, seq)
            encoder_outputs                # (batch, seq, hidden)
        ).squeeze(1)                       # → (batch, hidden)

        return context, attn_weights

# ── Compare all three methods ───────────────────────────────
for method in ["dot", "general", "concat"]:
    attn   = LuongAttention(256, method=method)
    ctx, w = attn(decoder_state, encoder_outputs)
    print(f"[{method:8s}] context: {ctx.shape}, weights sum≈1: {w.sum(-1).allclose(torch.ones(4))}")
OUTPUT
[dot     ] context: torch.Size([4, 256]), weights sum≈1: True
[general ] context: torch.Size([4, 256]), weights sum≈1: True
[concat  ] context: torch.Size([4, 256]), weights sum≈1: True
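In Luong et al.'s formulation the context vector is not the end of the step. It is combined with the current decoder state into an attentional hidden state h̃_t = tanh(W_c [c_t; h_t]), and a softmax over W_s h̃_t gives the output distribution. The class above stops at c_t, so the NumPy sketch below adds that final step for the "general" score. The sizes, the toy vocabulary, and all random weights are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(1)
hidden, src_len, vocab = 6, 7, 10

h_t = rng.standard_normal(hidden)                # current decoder state (Luong computes it first)
H_bar = rng.standard_normal((src_len, hidden))   # encoder states h̄_s

W_a = rng.standard_normal((hidden, hidden))      # "general" score matrix
W_c = rng.standard_normal((hidden, 2 * hidden))  # combines [c_t; h_t]
W_s = rng.standard_normal((vocab, hidden))       # projection to vocabulary logits

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

scores = H_bar @ W_a.T @ h_t        # row s equals h_t^T W_a h̄_s (the "general" score)
alpha = softmax(scores)             # attention weights over source positions
c_t = alpha @ H_bar                 # context vector
h_tilde = np.tanh(W_c @ np.concatenate([c_t, h_t]))   # attentional hidden state
p_y = softmax(W_s @ h_tilde)        # distribution over the toy vocabulary

print(h_tilde.shape, p_y.shape, round(p_y.sum(), 6))   # (6,) (10,) 1.0
```

Luong et al. also feed h̃_t back into the next decoder step ("input feeding"), which this sketch leaves out.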

Section 05

Self-Attention — When a Sentence Reads Itself

The Pronoun Resolution Game
Consider this sentence: "The trophy didn't fit in the suitcase because it was too big."

What does it refer to? The trophy or the suitcase?
Answer: the trophy. But how do you know? Because "big" plus "didn't fit" points to the trophy being the obstacle.

Self-attention tackles this by letting every word in the sentence attend to every other word: ideally "it" attends strongly to "trophy" and weakly to "suitcase". Each token computes relevance scores against all tokens in the same sequence, including itself. There is no encoder/decoder split; the sequence attends to itself.

Self-attention (also called intra-attention) was the key architectural building block of the Transformer (Vaswani et al., 2017). Unlike Bahdanau/Luong attention, which connected decoder states to encoder states, self-attention lets every position in a sequence build a representation informed by every other position.

The Query–Key–Value Framework

Self-attention is built around three learned linear projections of every token embedding: Query (Q), Key (K), and Value (V).

Query (Q)
W_Q · x_i → q_i
The question a token is asking. "What information do I need to understand my role in this sentence?" Each token's query is used to score against all keys. Think of it as the search query in a database lookup.
🔑
Key (K)
W_K · x_j → k_j
The label a token advertises about itself. "Here's what kind of information I contain." Keys are matched against queries to compute relevance. Think of them as database index keys — they determine who gets retrieved.
📝
Value (V)
W_V · x_j → v_j
The actual content a token contributes when attended to. After scoring queries against keys, the context vector is built from a weighted sum of values. Think of the actual database record content that gets returned.
🔑
The Search Engine Analogy

Imagine a search engine: you type a Query ("best Italian restaurant"), the engine matches it against document Keys ("Italian cuisine near you"), and returns the Values (actual restaurant info). Self-attention does this simultaneously for every token, in parallel, using learnable linear projections for Q, K, V.
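The search-engine analogy can be made literal. A Python dictionary does a hard lookup: the query either matches a key exactly or returns nothing. Attention does a soft lookup: it scores the query against every key and returns a weighted blend of all values. The NumPy sketch below uses hand-picked 2-D keys and scalar values (purely illustrative, nothing is learned) to show that when the scores are sharp enough, the soft lookup approaches the hard one.

```python
import numpy as np

# Hand-picked toy keys and values (illustrative, not learned).
keys = np.array([[1.0, 0.0],     # key advertising "italian"
                 [0.0, 1.0],     # key advertising "sushi"
                 [0.7, 0.7]])    # key advertising "fusion"
values = np.array([10.0, 20.0, 30.0])   # scalar "records" keep the output readable

def soft_lookup(query, keys, values, sharpness=1.0):
    """Score the query against every key, softmax the scores, blend the values."""
    scores = sharpness * (keys @ query)
    w = np.exp(scores - scores.max())
    w /= w.sum()                  # weights are positive and sum to 1
    return w @ values, w

# Hard lookup: an exact key match or nothing at all.
table = {"italian": 10.0, "sushi": 20.0, "fusion": 30.0}
print(table["italian"])           # 10.0

q = np.array([1.0, 0.0])          # a query that "means" italian
blend, w = soft_lookup(q, keys, values)
print(w.round(3), blend.round(3))              # weight spread over all three records

blend_sharp, w_sharp = soft_lookup(q, keys, values, sharpness=20.0)
print(w_sharp.round(3), blend_sharp.round(3))  # almost all weight on "italian"
```

The `sharpness` factor is an illustrative knob: larger score magnitudes make softmax more peaked, which is the same effect the √d_k scaling is designed to control.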

The Scaled Dot-Product Attention Formula

Raw Score
QK^T
Matrix multiply queries by transposed keys. Element (i,j) = how much token i should attend to token j.
Scaling
QK^T / √d_k
Divide by √d_k (key dimension). Without this, dot products grow large as d_k increases, pushing softmax into near-zero gradient regions.
Attention Weights
softmax(QK^T / √d_k)
Softmax over the score matrix rows. Each row sums to 1. Row i = how token i distributes attention across all positions.
Output
Attention(Q,K,V) = softmax(QK^T / √d_k) · V
Weight the value matrix by the attention weights. Each token's new representation is a weighted blend of all values.
Why the √d_k Scaling Matters

If d_k = 64 and the components of q and k are independent with mean 0 and variance 1, their dot product has variance d_k = 64. Such large magnitudes push softmax into saturated regions where gradients are near zero, so the model barely learns. Dividing by √64 = 8 brings the variance back to about 1 and keeps gradients healthy. This one-line change matters a great deal for making Transformers trainable.
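The variance argument can be checked numerically. The NumPy sketch below draws query and key components i.i.d. with mean 0 and variance 1 (the same assumption as above), measures the variance of raw and scaled dot products, and then compares how peaked softmax becomes for one query scored against ten random keys.

```python
import numpy as np

rng = np.random.default_rng(0)
d_k, n_pairs = 64, 100_000

# Query/key components drawn i.i.d. with mean 0 and variance 1.
q = rng.standard_normal((n_pairs, d_k))
k = rng.standard_normal((n_pairs, d_k))

raw = (q * k).sum(axis=1)          # unscaled dot products
scaled = raw / np.sqrt(d_k)        # divided by sqrt(d_k)
print(f"var(raw)    ≈ {raw.var():.1f}")     # close to d_k = 64
print(f"var(scaled) ≈ {scaled.var():.2f}")  # close to 1

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# Score one query against 10 keys, with and without scaling.
q1 = rng.standard_normal(d_k)
K = rng.standard_normal((10, d_k))
w_raw = softmax(K @ q1)
w_scaled = softmax(K @ q1 / np.sqrt(d_k))
print("max weight, unscaled:", w_raw.max().round(3))     # typically near 1 (saturated)
print("max weight, scaled:  ", w_scaled.max().round(3))  # typically much flatter
```

Dividing scores by a constant larger than 1 can only lower the largest softmax weight, so the scaled distribution is never more saturated than the unscaled one.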

Python Implementation — Scaled Dot-Product Self-Attention

import math

class SelfAttention(nn.Module):
    """
    Single-head Scaled Dot-Product Self-Attention.
    Vaswani et al. 2017 — 'Attention Is All You Need'
    """

    def __init__(self, d_model, d_k=None):
        super().__init__()
        self.d_k = d_k if d_k else d_model
        # Three learned projection matrices
        self.W_Q = nn.Linear(d_model, self.d_k, bias=False)
        self.W_K = nn.Linear(d_model, self.d_k, bias=False)
        self.W_V = nn.Linear(d_model, self.d_k, bias=False)

    def forward(self, x, mask=None):
        # x: (batch, seq_len, d_model)

        # Step 1: Project to Q, K, V spaces
        Q = self.W_Q(x)  # (batch, seq, d_k)
        K = self.W_K(x)  # (batch, seq, d_k)
        V = self.W_V(x)  # (batch, seq, d_k)

        # Step 2: Compute scaled dot-product scores
        # scores[i,j] = how much token i attends to token j
        scale  = math.sqrt(self.d_k)
        scores = torch.bmm(Q, K.transpose(1, 2)) / scale  # (batch, seq, seq)

        # Step 3: Optional mask (causal / padding)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        # Step 4: Softmax → attention weights
        attn_weights = F.softmax(scores, dim=-1)  # (batch, seq, seq)

        # Step 5: Weighted sum of values
        output = torch.bmm(attn_weights, V)  # (batch, seq, d_k)

        return output, attn_weights

# ── Demo ─────────────────────────────────────────────────────
batch, seq, d_model = 2, 8, 64
sa = SelfAttention(d_model)

x_in = torch.randn(batch, seq, d_model)
out, weights = sa(x_in)

print(f"Input  shape: {x_in.shape}")      # (2, 8, 64)
print(f"Output shape: {out.shape}")       # (2, 8, 64)
print(f"Attn   shape: {weights.shape}")   # (2, 8, 8)
print(f"Row sums ≈ 1: {weights[0].sum(-1)}")
OUTPUT
Input  shape: torch.Size([2, 8, 64])
Output shape: torch.Size([2, 8, 64])
Attn   shape: torch.Size([2, 8, 8])
Row sums ≈ 1: tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000], grad_fn=<SumBackward1>)

Attention Map — Visualising "it" Resolution

🔎 Visualising Self-Attention (illustrative weights): "The trophy didn't fit in the suitcase because it was too big"
High
"it" → "trophy": weight ≈ 0.71, most attention goes to the correct referent
Med
"it" → "fit": weight ≈ 0.15 — syntactic verb dependency
Low
"it" → "suitcase": weight ≈ 0.08 — plausible referent but lower confidence
Near 0
"it" → "the", "in", "because": functional words receive near-zero attention

Section 06

Multi-Head Attention — Parallel Perspectives

The Team of Detectives
One detective studying a crime can only focus on one thing at a time. But send eight detectives simultaneously, each with a different speciality — one tracking financial motives, one tracking physical evidence, one tracking timelines — and their combined report is far richer than any single investigator's findings.

Multi-head attention runs h independent attention mechanisms in parallel, each looking at the same sentence but from a different learned representational subspace. One head might specialise in syntactic dependencies ("subject-verb agreement"). Another in coreference ("it" → "trophy"). Another in positional proximity. The final output concatenates all heads, giving the model multiple simultaneous perspectives on the same input.

Why Multiple Heads?

👥
Representational Diversity
Each head projects Q, K, V into a different subspace (d_k = d_model / h). They learn to encode different aspects of language without interfering with each other.
→ syntactic, semantic, positional heads emerge
📈
Parallel Computation
All h heads run simultaneously on the GPU. Unlike an RNN, which processes tokens one at a time, a multi-head attention layer needs only O(1) sequential steps regardless of sentence length (autoregressive generation at inference time is still token by token).
→ massive speed-up vs. recurrent models
🔴
Redundancy as Stability
If one head fails to learn a useful pattern, others can compensate. Michel et al. (2019) found that a substantial fraction of heads can be pruned after training with little loss in performance.
→ head pruning literature (Michel et al., 2019)

The Full Multi-Head Attention Formula

Per-Head Projections
Q_i = X·W_i^Q, K_i = X·W_i^K, V_i = X·W_i^V
Each head i has its own learned projection matrices W^Q, W^K, W^V of size (d_model × d_k) where d_k = d_model / h.
Per-Head Output
head_i = Attention(Q_i, K_i, V_i)
Standard scaled dot-product attention applied independently within each head's subspace.
Concatenate
MultiHead = Concat(head_1, …, head_h)
Stack all h head outputs along the last dimension. Result is (batch, seq, h × d_v) = (batch, seq, d_model).
Final Projection
MultiHead(Q,K,V) = Concat(heads) · W^O
A final learned output projection W^O (d_model × d_model) mixes information across heads.
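The implementation later in this section projects all heads at once with single d_model × d_model matrices and then splits the result. The NumPy sketch below checks that this is the same computation as the per-head formulas: column block i of a unified W^Q is exactly W_i^Q. Shapes are small and random for illustration, and the variable names are ours.

```python
import numpy as np

rng = np.random.default_rng(0)
seq, d_model, h = 5, 12, 3
d_k = d_model // h

X = rng.standard_normal((seq, d_model))
W_Q = rng.standard_normal((d_model, d_model))   # unified projection for all heads

# Route 1: project once, then split the last dimension into h heads.
Q_all = X @ W_Q                                          # (seq, d_model)
Q_split = Q_all.reshape(seq, h, d_k).transpose(1, 0, 2)  # (h, seq, d_k)

# Route 2: give each head its own matrix W_i^Q = column block i of W_Q.
Q_per_head = np.stack([X @ W_Q[:, i * d_k:(i + 1) * d_k] for i in range(h)])

print(np.allclose(Q_split, Q_per_head))   # True
```

PyTorch's nn.Linear stores its weight as (out_features, in_features), so in that code head i corresponds to a block of rows of `W_Q.weight` rather than columns; the equivalence is the same.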

Parameter Count — Where Does the Memory Go?

| Component | Shape | Parameters (d_model = 512, h = 8) |
|---|---|---|
| W_i^Q per head | (d_model, d_k) | 512 × 64 = 32,768 |
| W_i^K per head | (d_model, d_k) | 512 × 64 = 32,768 |
| W_i^V per head | (d_model, d_v) | 512 × 64 = 32,768 |
| All 8 heads (Q + K + V) | 3 × 8 matrices | 3 × 8 × 32,768 = 786,432 |
| W^O output projection | (d_model, d_model) | 512 × 512 = 262,144 |
| Total per layer | | 1,048,576 (≈ 1M params) |
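The table's arithmetic generalises to a one-line formula. The helper below is a sketch (the `bias` option is our addition; the table and the implementation in this section use no biases). Because h · d_k = d_model, the count is always 4 · d_model² without biases, whatever the number of heads.

```python
def mha_params(d_model, num_heads, bias=False):
    """Parameters in one multi-head attention layer (Q, K, V and output projections)."""
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    d_k = d_model // num_heads
    qkv = 3 * num_heads * d_model * d_k    # h heads × (W^Q, W^K, W^V), each d_model × d_k
    out = d_model * d_model                # W^O
    biases = 4 * d_model if bias else 0    # optional bias vectors on the four projections
    return qkv + out + biases

print(mha_params(512, 8))    # 1048576, matches the table
print(mha_params(512, 16))   # 1048576, the head count does not change the total
print(mha_params(768, 12))   # 2359296, the BERT-base-sized configuration
```

Splitting into more heads changes how the layer's capacity is divided, not how much capacity there is.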

Python Implementation — Multi-Head Attention

class MultiHeadAttention(nn.Module):
    """
    Multi-Head Scaled Dot-Product Attention.
    Vaswani et al. 2017 — 'Attention Is All You Need'
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model    = d_model
        self.num_heads  = num_heads
        self.d_k        = d_model // num_heads  # per-head dimension

        # Unified projection matrices — project all heads at once then split
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)  # output projection

    def split_heads(self, x):
        # x: (batch, seq, d_model) → (batch, num_heads, seq, d_k)
        batch, seq, _ = x.size()
        x = x.view(batch, seq, self.num_heads, self.d_k)
        return x.transpose(1, 2)  # → (batch, heads, seq, d_k)

    def scaled_dot_product(self, Q, K, V, mask=None):
        scale  = math.sqrt(self.d_k)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        weights = F.softmax(scores, dim=-1)
        return torch.matmul(weights, V), weights

    def forward(self, query, key, value, mask=None):
        # For self-attention: query = key = value = x
        # For cross-attention: query from decoder, key/value from encoder

        batch = query.size(0)

        # Step 1: Project and split into h heads
        Q = self.split_heads(self.W_Q(query))  # (batch, heads, seq_q, d_k)
        K = self.split_heads(self.W_K(key))    # (batch, heads, seq_k, d_k)
        V = self.split_heads(self.W_V(value))  # (batch, heads, seq_v, d_k)

        # Step 2: Per-head attention (runs in parallel on all heads)
        x, attn_weights = self.scaled_dot_product(Q, K, V, mask)
        # x: (batch, heads, seq, d_k)

        # Step 3: Concatenate heads
        x = x.transpose(1, 2).contiguous()  # (batch, seq, heads, d_k)
        x = x.view(batch, -1, self.d_model)   # (batch, seq, d_model)

        # Step 4: Final output projection
        output = self.W_O(x)  # (batch, seq, d_model)

        return output, attn_weights

# ── Demo: BERT-base style (12 heads, d_model=768) ───────────
d_model, num_heads = 768, 12
mha = MultiHeadAttention(d_model, num_heads)

x_bert = torch.randn(2, 16, d_model)  # batch=2, seq=16
out, w  = mha(x_bert, x_bert, x_bert)  # self-attention

total_params = sum(p.numel() for p in mha.parameters())
print(f"Output shape   : {out.shape}")         # (2, 16, 768)
print(f"Attn map shape : {w.shape}")           # (2, 12, 16, 16)
print(f"Total params   : {total_params:,}")    # 2,359,296
print(f"d_k per head   : {d_model // num_heads}")  # 64
OUTPUT
Output shape   : torch.Size([2, 16, 768])
Attn map shape : torch.Size([2, 12, 16, 16])
Total params   : 2,359,296
d_k per head   : 64
🌟
What Each Head Actually Learns

Interpretability research (Clark et al., 2019, "What Does BERT Look At?") found that individual attention heads in BERT specialise: some track syntactic relations, such as direct objects attending to their verbs; others link coreferent mentions; some attend almost exclusively to the previous or next token; and many heads in the middle layers put most of their attention on the [SEP] token, which the authors interpret as a kind of no-op. Many heads behave like specialists rather than interchangeable copies.


Section 07

Three Types of Attention in a Transformer

The full Transformer architecture uses multi-head attention in three distinct roles, each with a different configuration of Q, K, V sources.

📄
Encoder Self-Attention
Q = K = V = encoder input. Every source token attends to every other source token. No masking. Builds rich contextualised representations of the input sequence. Used in BERT-like models for understanding tasks.
→ bidirectional: all positions see all positions
🔒
Masked Decoder Self-Attention
Q = K = V = decoder input, but a causal mask prevents position t from attending to positions > t. Enforces autoregressive generation — the model can't "cheat" by looking at future tokens. Used in GPT-like models.
→ causal: position t sees only positions ≤ t
🔗
Encoder–Decoder Cross-Attention
Q = decoder states, K = V = encoder outputs. Each decoder position queries the full encoder output. This is the modern equivalent of Bahdanau attention — the bridge connecting source understanding to target generation.
→ Q from decoder, K/V from encoder
| Attention Type | Query Source | Key/Value Source | Masking? | Example Models |
|---|---|---|---|---|
| Encoder Self-Attention | Encoder layer input | Same | None | BERT, RoBERTa |
| Decoder Self-Attention | Decoder layer input | Same | Causal mask | GPT-2, GPT-4, LLaMA |
| Cross-Attention | Decoder states | Encoder outputs | None | T5, BART, mT5 |
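The causal mask of masked decoder self-attention is simply a lower-triangular boolean matrix. The NumPy sketch below keeps things minimal (a single head with Q = K = V = X and no learned projections, an assumption made for brevity) and checks the property that matters: changing a later token leaves the outputs at all earlier positions untouched.

```python
import numpy as np

def causal_self_attention(X):
    """Single-head self-attention with Q = K = V = X and a causal mask (no projections)."""
    n, d = X.shape
    scores = X @ X.T / np.sqrt(d)                  # (n, n)
    mask = np.tril(np.ones((n, n), dtype=bool))    # True where j <= t (allowed)
    scores = np.where(mask, scores, -np.inf)       # block attention to future positions
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return w @ X, w

rng = np.random.default_rng(0)
X = rng.standard_normal((6, 4))
out, w = causal_self_attention(X)
print(np.triu(w, k=1).max())       # 0.0: no weight on future positions

X_changed = X.copy()
X_changed[-1] = 100.0              # tamper with the last (future) token only
out_changed, _ = causal_self_attention(X_changed)
print(np.allclose(out[:-1], out_changed[:-1]))   # True: earlier outputs are unaffected
```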

Section 08

Attention Complexity — The Elephant in the Room

Multi-head attention is powerful but expensive. Understanding its complexity is essential for production systems and explains why Longformer and FlashAttention exist.

| Resource | Standard attention | Implication |
|---|---|---|
| Time complexity | O(n² · d) | Quadratic in sequence length n: doubling n quadruples compute |
| Memory complexity | O(n²) | The n × n attention matrix must fit in GPU memory, a bottleneck for long documents |
| Sequence length 512 | 262,144 entries per matrix | Manageable on modern hardware |
| Sequence length 4,096 | 16,777,216 entries per matrix | Calls for efficient attention (FlashAttention, etc.) |
| Sequence length 32,768 | Over 1 billion entries per matrix | Impractical to materialise with standard attention; needs memory-efficient kernels |
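The entry counts above are per attention matrix, that is, one head of one layer for one example. A back-of-the-envelope helper (a sketch that counts only the score matrices and ignores weights, other activations, and framework overhead) shows how fast the bytes add up once heads are included.

```python
def attn_matrix_bytes(seq_len, num_heads=1, batch=1, bytes_per_elem=2):
    """Bytes needed to materialise the n × n score matrices (default fp16: 2 bytes each)."""
    return batch * num_heads * seq_len * seq_len * bytes_per_elem

for n in (512, 4096, 32768):
    gib = attn_matrix_bytes(n, num_heads=12) / 2**30
    print(f"n={n:>6}: {gib:8.3f} GiB for 12 heads, one layer, one example")
```

At n = 32,768 this comes to 24 GiB for a single 12-head layer in fp16, before any other activations, which is why the last row of the table calls for memory-efficient kernels.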
🚀
FlashAttention — The Modern Solution

FlashAttention (Dao et al., 2022) reorders the attention computation to be IO-aware: it tiles the Q, K, V matrices and never materialises the full n×n attention matrix in the GPU's relatively slow high-bandwidth memory (HBM). The result is exact, not an approximation, but memory grows linearly rather than quadratically with sequence length and wall-clock time drops substantially. FlashAttention-style kernels are now widely used for training and serving large language models.
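In PyTorch 2.x the fused path is exposed as `torch.nn.functional.scaled_dot_product_attention`, which selects a FlashAttention or memory-efficient kernel when the hardware, dtype, and inputs allow it and otherwise falls back to a plain implementation. The sketch below, assuming PyTorch 2.0 or later, checks on CPU that it matches the explicit softmax(QKᵀ/√d_k)·V formula, with and without a causal mask.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq, d_k = 2, 4, 16, 32
Q, K, V = (torch.randn(batch, heads, seq, d_k) for _ in range(3))

def manual_attention(Q, K, V, causal=False):
    """Explicit softmax(QK^T / sqrt(d_k)) V that materialises the full score matrix."""
    n, d = Q.size(-2), Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d)
    if causal:
        # True above the diagonal marks future positions that must be hidden
        future = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(future, float("-inf"))
    return torch.softmax(scores, dim=-1) @ V

# Fused entry point; PyTorch picks the backend automatically.
fused = F.scaled_dot_product_attention(Q, K, V)
fused_causal = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

print(torch.allclose(fused, manual_attention(Q, K, V), atol=1e-5))                     # True
print(torch.allclose(fused_causal, manual_attention(Q, K, V, causal=True), atol=1e-5)) # True
```

Both versions compute the same function; the fused call simply avoids storing the n × n matrix when a FlashAttention-style backend is available.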


Section 09

Positional Encoding — Giving Attention a Sense of Order

The Scrambled Telegram
Self-attention treats a sentence like a bag of words: it computes relationships between all pairs of tokens but has no notion of word order. Without positional information, "Dog bites man" and "Man bites dog" give every word exactly the same representation, a disaster for understanding.

The Transformer solves this by adding a positional encoding to each token embedding before feeding it into the attention layers. Each position in the sequence gets a unique signal — like stamping each telegram with a timestamp — so the model can distinguish "word at position 3" from "word at position 7".
Sine Encoding (even dims)
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
Even-indexed dimensions receive a sine wave with frequency determined by the dimension index i.
Cosine Encoding (odd dims)
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
Odd-indexed dimensions receive a cosine wave. Together, sine+cosine create a unique fingerprint for each position.
import numpy as np

def positional_encoding(max_seq, d_model):
    """
    Compute sinusoidal positional encodings (d_model assumed even).
    Returns a NumPy array of shape (max_seq, d_model).
    """
    PE = np.zeros((max_seq, d_model))
    pos = np.arange(max_seq)[:, np.newaxis]            # (max_seq, 1)
    two_i = np.arange(0, d_model, 2)[np.newaxis, :]    # (1, d_model/2), values 0, 2, 4, ...

    # 10000^(2i/d_model): two_i already holds 2i, so it must not be doubled again
    div_term = np.power(10000, two_i / d_model)

    PE[:, 0::2] = np.sin(pos / div_term)   # even dims → sine
    PE[:, 1::2] = np.cos(pos / div_term)   # odd dims  → cosine

    return PE

PE = positional_encoding(100, 512)
print(f"PE shape: {PE.shape}")   # (100, 512)
print(f"Values range: [{PE.min():.3f}, {PE.max():.3f}]")   # [-1.000, 1.000]
OUTPUT
PE shape: (100, 512) Values range: [-1.000, 1.000]

Section 10

Full Transformer Block — Putting It All Together

class TransformerEncoderBlock(nn.Module):
    """
    A single Transformer encoder block:
    Multi-Head Self-Attention → Add&Norm → FFN → Add&Norm
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        # ── Multi-head self-attention ───────────────────────────
        self.mha    = MultiHeadAttention(d_model, num_heads)
        self.norm1  = nn.LayerNorm(d_model)
        self.drop1  = nn.Dropout(dropout)

        # ── Feed-forward network ────────────────────────────────
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )
        self.norm2 = nn.LayerNorm(d_model)
        self.drop2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # 1. Multi-head self-attention + residual
        attn_out, _ = self.mha(x, x, x, mask)
        x = self.norm1(x + self.drop1(attn_out))   # Add & Norm

        # 2. Position-wise feed-forward + residual
        ffn_out = self.ffn(x)
        x = self.norm2(x + self.drop2(ffn_out))    # Add & Norm

        return x

# ── Stack 6 blocks (original Transformer encoder) ───────────
encoder = nn.Sequential(
    *[TransformerEncoderBlock(d_model=512, num_heads=8, d_ff=2048)
      for _ in range(6)]
)

src = torch.randn(2, 32, 512)  # batch=2, seq=32, d_model=512
out = encoder(src)
total = sum(p.numel() for p in encoder.parameters())
print(f"Encoder output : {out.shape}")    # (2, 32, 512)
print(f"Encoder params : {total:,}")      # 18,902,016 (≈18.9M)
OUTPUT
Encoder output : torch.Size([2, 32, 512])
Encoder params : 18,902,016

Section 11

Comparison — Bahdanau vs Luong vs Self-Attention

| Property | Bahdanau (2014) | Luong (2015) | Self-Attention / MHA (2017) |
|---|---|---|---|
| Architecture role | Encoder → decoder bridge | Encoder → decoder bridge | Within-sequence or cross-sequence |
| Score function | Additive MLP: v·tanh(Ws + Uh) | Dot / General / Concat | Scaled dot-product: QKᵀ/√d_k |
| Parallelisable? | No: the RNN runs step by step | No: the RNN runs step by step | Yes: O(1) sequential steps per layer in training |
| Multiple perspectives? | Single head | Single head | h parallel heads |
| Sequence cost | O(n) sequential steps | O(n) sequential steps | O(n²) attention matrix |
| State of the art in 2025? | Superseded | Superseded | Foundation of modern LLMs |
| Still studied? | Historically important | Historically important | Active research frontier |

Section 12

Golden Rules — Attention in Practice

🌟 Attention Mechanisms — Non-Negotiable Rules for Practitioners
1
Always scale by √d_k. Without scaling, dot products grow with d_k, pushing softmax into saturated regions with near-zero gradients. A missing square root can noticeably slow or destabilise training, especially with large d_k.
2
Use causal masking for autoregressive generation. If position t can attend to position t+1, the model sees the token it is supposed to predict during training: training loss looks excellent, but the model fails at generation time, when future tokens are not available. Mask out the upper triangle of the score matrix (all positions j > t).
3
d_model must be divisible by num_heads. With d_model=768 and num_heads=12, each head gets d_k=64. If not evenly divisible, the split fails. Always assert this in your constructor.
4
For long sequences (several thousand tokens), standard attention becomes memory-prohibitive. Use torch.nn.functional.scaled_dot_product_attention (PyTorch 2.0+), which dispatches to a FlashAttention kernel when the hardware and inputs support it. Avoid hand-written attention that materialises the full n × n matrix in production systems.
5
Add positional information. Pure attention is permutation-equivariant: it cannot distinguish "dog bites man" from "man bites dog" without it. Sinusoidal (original Transformer) or learned embeddings (BERT) are added before the first attention layer; RoPE (Rotary Position Embeddings), a common default in recent LLMs, instead rotates Q and K inside every attention layer.
6
Visualise attention maps during debugging. They are your X-ray machine. If all tokens attend uniformly (flat attention), your model isn't learning meaningful alignment. If a head puts nearly all its attention on a special token such as [SEP] or [CLS], it may be acting as a no-op. Use BertViz or a simple heatmap with matplotlib.
7
For cross-attention (encoder-decoder): Q comes from the decoder, K and V come from the encoder. A common mistake is swapping this, making K/V come from the decoder — the model will "attend" to its own past outputs for context, leading to degenerate generations.
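Besides heatmaps, one number per head can flag the flat attention described in rule 6: the entropy of each attention row. Uniform attention over n positions has the maximum entropy, log n, while a head that puts almost everything on one token has entropy near 0. The NumPy sketch below is a simple diagnostic under that assumption; what counts as "too flat" for a real model is a judgment call, not a standard threshold.

```python
import numpy as np

def attention_entropy(weights, eps=1e-12):
    """Mean entropy (in nats) of attention rows; weights shape (..., n_queries, n_keys)."""
    return -(weights * np.log(weights + eps)).sum(axis=-1).mean(axis=-1)

n = 16
uniform = np.full((n, n), 1.0 / n)          # every token attends equally to all tokens
peaked = np.full((n, n), 0.001 / (n - 1))   # almost all mass on one key per row
np.fill_diagonal(peaked, 0.999)

print(round(attention_entropy(uniform), 3), round(np.log(n), 3))   # equal: maximally flat
print(round(attention_entropy(peaked), 3))                         # close to 0: sharp
```

Applied to a weights tensor of shape (batch, heads, seq, seq), the same function returns one mean entropy per batch element and head, which is easy to log during training.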

Section 13

Modern Variants and Where Attention Is Going

FlashAttention-3
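IO-aware attention that avoids materialising the full n×n matrix in HBM, cutting the extra memory from O(n²) to O(n) while keeping the result exact (compute is still O(n²)). FlashAttention-3 targets NVIDIA Hopper GPUs and reaches high GPU utilisation. Kernels of this family are widely used in LLM training and inference stacks.
→ Dao et al., 2022 (FlashAttention); Dao, 2023 (FlashAttention-2); Shah et al., 2024 (FlashAttention-3)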
🍃
RoPE Embeddings
Rotary Position Embeddings encode position by rotating the Q and K vectors, so attention scores depend on the relative offset between tokens rather than on absolute positions. RoPE is used in LLaMA, Mistral, and Qwen.
→ Su et al., 2021 (RoFormer)
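The relative-position property of RoPE fits in a few lines. Following the RoFormer formulation, each dimension pair (2i, 2i+1) is rotated by the angle pos·θ_i with θ_i = 10000^(−2i/d). Because rotations compose, the score between a rotated query at position m and a rotated key at position n depends only on m − n. This NumPy sketch uses the interleaved-pair layout; some implementations pair the first and second halves of the vector instead, which does not change the property.

```python
import numpy as np

def rope(x, pos, base=10000.0):
    """Rotate consecutive dimension pairs of x by angle pos * theta_i (RoFormer-style)."""
    d = x.shape[-1]
    theta = base ** (-np.arange(0, d, 2) / d)    # one frequency per pair
    angle = pos * theta
    cos, sin = np.cos(angle), np.sin(angle)
    x_even, x_odd = x[0::2], x[1::2]
    out = np.empty_like(x)
    out[0::2] = x_even * cos - x_odd * sin        # 2-D rotation of each (even, odd) pair
    out[1::2] = x_even * sin + x_odd * cos
    return out

rng = np.random.default_rng(0)
q, k = rng.standard_normal(16), rng.standard_normal(16)

s1 = rope(q, 5) @ rope(k, 2)       # positions (5, 2): offset 3
s2 = rope(q, 105) @ rope(k, 102)   # positions (105, 102): offset 3
s3 = rope(q, 6) @ rope(k, 2)       # positions (6, 2): offset 4
print(np.isclose(s1, s2))          # True: same offset, same score
print(np.isclose(s1, s3))          # different offset, generally a different score
```

Since each rotation preserves vector length, RoPE changes only the angles between queries and keys, never their norms.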
🗘
Grouped Query Attention
GQA shares each K/V head across a group of Q heads, shrinking the KV cache during inference by a factor of h/g, where g is the number of K/V heads. LLaMA-2 70B and Mistral 7B use GQA, which matters for long-context deployment on constrained hardware.
→ Ainslie et al., 2023 (GQA)
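A minimal NumPy sketch of the idea: with h query heads and g K/V heads (g divides h), each K/V head is shared by h/g consecutive query heads, and only the g K/V heads need to be cached. The shapes and the repeat-based sharing are illustrative choices; setting g = h recovers standard multi-head attention, and g = 1 gives multi-query attention.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def grouped_query_attention(Q, K, V):
    """Q: (h, n, d_k); K, V: (g, n, d_k) with h divisible by g."""
    h, g = Q.shape[0], K.shape[0]
    assert h % g == 0, "query heads must be divisible by K/V heads"
    K_rep = np.repeat(K, h // g, axis=0)    # each K/V head serves h/g query heads
    V_rep = np.repeat(V, h // g, axis=0)
    scores = Q @ K_rep.transpose(0, 2, 1) / np.sqrt(Q.shape[-1])
    return softmax(scores) @ V_rep          # (h, n, d_k)

rng = np.random.default_rng(0)
h, g, n, d_k = 8, 2, 10, 16
Q = rng.standard_normal((h, n, d_k))
K = rng.standard_normal((g, n, d_k))
V = rng.standard_normal((g, n, d_k))

out = grouped_query_attention(Q, K, V)
print(out.shape)                                # (8, 10, 16)
print(f"KV-cache reduction: {h // g}x")         # only g of the h K/V heads are stored
```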
🚀
Multi-Query Attention
The extreme form of GQA: a single K/V head is shared by all Q heads, shrinking the KV cache by a factor of h. Used in PaLM and Falcon. It sacrifices some quality for significant inference speed improvements.
→ Shazeer, 2019 ("Fast Transformer Decoding: One Write-Head is All You Need")
📋
Sliding Window Attention
Each token attends only to a local window of w neighbours, reducing attention cost to O(n·w). Used in Longformer and Mistral 7B. Longformer combines it with a few global tokens for document-level tasks, making long inputs affordable.
→ Beltagy et al., 2020 (Longformer)
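A sliding-window pattern is just a banded mask. In the NumPy sketch below, w counts neighbours on each side (conventions differ between papers), the optional causal flag gives a look-back-only variant, and Longformer's global tokens are omitted for brevity.

```python
import numpy as np

def sliding_window_mask(n, w, causal=False):
    """Boolean mask: True where token i may attend to token j, i.e. |i - j| <= w."""
    i = np.arange(n)[:, None]
    j = np.arange(n)[None, :]
    allowed = np.abs(i - j) <= w
    if causal:                    # decoder-style variant: only look back
        allowed &= j <= i
    return allowed

m = sliding_window_mask(8, 2)
print(m.astype(int))
print("allowed pairs:", m.sum(), "of", m.size)   # 34 of 64; grows like n·(2w+1), not n²
```

The mask plugs into the same `masked_fill` step used by the implementations earlier in this tutorial, though real sliding-window kernels avoid computing the masked-out scores at all.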
🧠
Linear Attention
Approximates softmax attention with kernel feature maps so the computation can be reordered, reducing complexity to O(n) in sequence length. Examples include the Linear Transformer and the Performer, which trade some accuracy for scalability to very long contexts.
→ Katharopoulos et al., 2020
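The trick behind kernel-based linear attention is associativity: with a positive feature map φ, the product (φ(Q)φ(K)ᵀ)V can be computed as φ(Q)(φ(K)ᵀV), which never forms the n × n matrix. The NumPy sketch below uses φ(x) = elu(x) + 1, the feature map proposed by Katharopoulos et al., covers only the non-causal case, and checks that both orderings agree.

```python
import numpy as np

def phi(x):
    # elu(x) + 1: equals x + 1 for x > 0 and exp(x) otherwise, so always positive
    return np.where(x > 0, x + 1.0, np.exp(x))

rng = np.random.default_rng(0)
n, d, d_v = 256, 16, 8
Q = rng.standard_normal((n, d))
K = rng.standard_normal((n, d))
V = rng.standard_normal((n, d_v))

# Quadratic ordering: build the n × n similarity matrix explicitly.
A = phi(Q) @ phi(K).T                                   # (n, n)
out_quadratic = (A @ V) / A.sum(axis=1, keepdims=True)

# Linear ordering: summarise keys and values first, O(n · d · d_v).
KV = phi(K).T @ V                                       # (d, d_v), independent of n
Z = phi(K).sum(axis=0)                                  # (d,) normaliser
out_linear = (phi(Q) @ KV) / (phi(Q) @ Z)[:, None]

print(np.allclose(out_quadratic, out_linear))   # True
print(A.shape, KV.shape)                        # (256, 256) (16, 8)
```

The result is exact for this kernel; the approximation lies in replacing the softmax similarity with φ(q)·φ(k) in the first place.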
🏆
The Practitioner's Guide to Choosing Attention

Use full multi-head attention for sequences under 4K tokens and fine-tuning pretrained models. Use FlashAttention (via PyTorch 2.0 SDPA) for all GPU training. Use GQA or MQA when deploying with long contexts and memory constraints. For research on >32K context, explore sliding window + global token hybrids. Bahdanau/Luong are useful for teaching and for custom seq2seq tasks where you want explicit control over the alignment model.