The Story That Explains Vision Transformers
The second method is a detective. Sherlock Holmes walks in, glances at the whole room in one sweep, then instantly decides: "That bloodstain relates to that broken window relates to that overturned chair." He sees relationships across the entire scene simultaneously, regardless of distance. That is exactly how Vision Transformers (ViTs) work.
The detective wins on complex scenes, and in 2020 a paper called "An Image is Worth 16×16 Words" showed the same for deep learning.
A Vision Transformer (ViT) applies the Transformer architecture — originally invented for text — directly to images. Instead of words, it splits an image into small fixed-size patches, treats each patch as a "token", and lets every patch attend to every other patch simultaneously through self-attention. The result is a model that understands global context from the very first layer, something CNNs only achieve deep in their stacks.
Language Transformers like BERT and GPT learn relationships between words in a sentence. Vision Transformers do the same for image patches in a picture. The math is identical — only the input changes. This elegant reuse is why ViTs dominated computer vision benchmarks within two years of their introduction.
A Brief History — From Pixels to Patches
To appreciate ViTs, you need to understand what came before — and why it wasn't enough.
CNN vs ViT — The Fundamental Difference
| Property | CNN (e.g. ResNet) | Vision Transformer (ViT) |
|---|---|---|
| How it reads an image | Small local filters, layer by layer | All patches at once via self-attention |
| Receptive field growth | Grows slowly through depth | Global from layer 1 |
| Inductive bias | Strong: locality, translation equivariance | Weak: must learn spatial structure from data |
| Data requirement | Can train well on small datasets | Needs large-scale data (or pretrain) |
| Long-range dependencies | Hard — requires many layers | Easy — direct attention between any patches |
| Scalability | Good but plateaus earlier | Excellent — scales with compute and data |
| Interpretability | Hard (CAM, Grad-CAM needed) | Attention maps show patch importance directly |
| Transfer learning | Excellent for similar domains | Excellent — especially with MAE/DINO pretraining |
A ViT trained from scratch on ImageNet alone (about 1.28M images) performs worse than a comparable ResNet. It needs far more data, such as JFT-300M (300 million images), to shine. This is the tradeoff: a weaker inductive bias means no free lunch, because ViTs must learn what CNNs assume. The solution is pretraining followed by fine-tuning, or modern self-supervised methods like MAE.
How ViT Works — Step by Step
Let us trace a single image through a Vision Transformer from pixel array to class prediction.
Patch size controls the tradeoff between sequence length and resolution. Smaller patches (8×8) give 784 tokens for a 224×224 image, which is extremely expensive for self-attention because its cost scales as O(N²) in the number of tokens N. Larger patches (32×32) give only 49 tokens, which is fast but coarse and misses fine detail. 16×16 patches (196 tokens) are the sweet spot: rich enough for most vision tasks and tractable for GPU memory.
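The token counts quoted here follow directly from the image and patch sizes. A minimal sketch of the arithmetic, assuming square images and non-overlapping patches (the helper name `num_tokens` is illustrative, not a library function):

```python
def num_tokens(img_size: int, patch_size: int) -> int:
    """Number of non-overlapping square patches in a square image."""
    if img_size % patch_size != 0:
        raise ValueError("img_size must be divisible by patch_size")
    return (img_size // patch_size) ** 2


for patch in (8, 16, 32):
    n = num_tokens(224, patch)
    # Global self-attention compares every token with every token: n * n pairs.
    print(f"patch {patch}x{patch}: {n} tokens, {n * n:,} attention pairs")
```

Halving the patch size quadruples the token count and multiplies the number of attention pairs by sixteen, which is why very small patches become impractical so quickly.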
Self-Attention — The Heart of ViT
Self-attention works the same way. Every patch in an image asks: "Which other patches should I pay most attention to?" A patch containing an eye might attend strongly to the nose patch and the other eye patch, ignoring the background sky. This selective focus is computed simultaneously for every patch — that's the power of self-attention.
Formally, self-attention computes three vectors for each patch token:
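The three vectors are the query, the key, and the value. In the standard scaled dot-product formulation they are written as follows, where X is the matrix of patch tokens, W_Q, W_K, W_V are learned projection matrices, and d_k is the per-head dimension (these are the conventional symbols, assumed here rather than taken from the article):

```latex
\begin{align}
Q &= X W_Q, \qquad K = X W_K, \qquad V = X W_V \\
\mathrm{Attention}(Q, K, V) &= \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
\end{align}
```

Each row of the softmax output sums to 1, so every output token is a weighted average of the value vectors, with weights given by how well that token's query matches every key.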
Multi-Head Attention
Instead of one set of Q, K, V projections, ViT uses multiple heads in parallel — typically 12 or 16. Each head learns to attend to a different aspect: one head may focus on local texture, another on long-range spatial relationships, another on colour continuity. Outputs from all heads are concatenated and projected back to dimension D.
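To make the head split concrete, here is a small sketch that works out the per-head dimension and the attention scale factor for the three original ViT sizes. The widths and head counts are the ones listed in the variants table; the dictionary layout is just an illustration.

```python
import math

# (embed_dim, num_heads) for the three sizes introduced in the original paper.
configs = {"ViT-Base": (768, 12), "ViT-Large": (1024, 16), "ViT-Huge": (1280, 16)}

head_dims = {}
for name, (embed_dim, num_heads) in configs.items():
    # The model width must divide evenly across the heads.
    assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
    head_dim = embed_dim // num_heads
    head_dims[name] = head_dim
    scale = 1 / math.sqrt(head_dim)  # factor applied to Q.K^T inside each head
    print(f"{name}: {num_heads} heads x {head_dim} dims, scale = {scale:.4f}")
```

Adding heads at a fixed width does not add parameters; it only partitions the same D-dimensional projection into more, narrower subspaces.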
ViT Architecture Variants — The Family Tree
The original paper introduced three model sizes. The community has since expanded this family dramatically.
| Model | Layers (L) | Hidden Dim (D) | Heads | Params | ImageNet Top-1 |
|---|---|---|---|---|---|
| ViT-Base/16 | 12 | 768 | 12 | 86M | 81.8% |
| ViT-Large/16 | 24 | 1024 | 16 | 307M | 85.2% |
| ViT-Huge/14 | 32 | 1280 | 16 | 632M | 88.6% |
| DeiT-Small | 12 | 384 | 6 | 22M | 79.8% (no large pretrain) |
| Swin-T | 4 stages | 96→768 | 3,6,12,24 | 28M | 81.3% |
| Swin-L | 4 stages | 192→1536 | 6,12,24,48 | 197M | 87.3% |
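The parameter counts for the plain ViT rows can be approximated from depth and width alone. The sketch below assumes the layout used in this article's from-scratch implementation (Conv2d patch embedding, fused QKV projection, MLP ratio 4, a [CLS] token, and a 1000-class linear head); published figures can differ by a few million parameters depending on the head and other details. The function name is hypothetical.

```python
def vit_param_count(depth, dim, patch=16, img=224, channels=3,
                    classes=1000, mlp_ratio=4):
    """Count learnable parameters of a plain ViT under the stated assumptions."""
    tokens = (img // patch) ** 2 + 1                 # patch tokens + [CLS]
    patch_embed = channels * patch * patch * dim + dim
    cls_and_pos = dim + tokens * dim                 # [CLS] token + positional table
    hidden = mlp_ratio * dim
    block = (
        2 * dim                                      # LayerNorm before attention
        + (dim * 3 * dim + 3 * dim)                  # fused Q, K, V projection
        + (dim * dim + dim)                          # attention output projection
        + 2 * dim                                    # LayerNorm before MLP
        + (dim * hidden + hidden)                    # MLP up-projection
        + (hidden * dim + dim)                       # MLP down-projection
    )
    final_norm = 2 * dim
    head = dim * classes + classes
    return patch_embed + cls_and_pos + depth * block + final_norm + head


print(f"ViT-Base/16:  {vit_param_count(12, 768):,}")
print(f"ViT-Large/16: {vit_param_count(24, 1024):,}")
```

Almost all of the parameters sit in the encoder blocks, roughly 12·D² per block, so doubling the width costs about four times as much as doubling the depth.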
Swin Transformer — The Windowed Attention Trick
Swin Transformer solves this the way a smart reader would: divide the newspaper into sections, read each section locally. Then, in the next layer, shift the section boundaries so the middle of one old section is now the centre of a new one — allowing cross-boundary communication to eventually propagate. Fast and effective.
| Property | Standard ViT (global attention) |
|---|---|
| Attention scope | All N patches ↔ all N patches |
| Complexity | O(N²) — quadratic |
| 224×224 / patch 16 | 196 × 196 = 38,416 attention pairs |
| 512×512 / patch 16 | 1,024 × 1,024 = 1,048,576 pairs |
| Dense prediction (detect/seg) | Expensive — feature maps always flat |
| Property | Swin Transformer (windowed attention) |
|---|---|
| Attention scope | Within local M×M windows (default M=7) |
| Complexity | O(N) — linear in image size |
| 224×224 / patch 4 | 3,136 tokens split into 64 windows of 49 tokens each (first stage) |
| Cross-window comms | Shifted windows in alternating layers |
| Dense prediction | Excellent — hierarchical feature maps like CNNs |
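The gap between the two tables can be checked with a few lines of arithmetic. This sketch counts query-key pairs for global attention and for non-overlapping M×M windows on the same token grid; the function names are illustrative, and it ignores the shifted-window step, which changes which tokens share a window but not the pair count.

```python
def global_pairs(tokens_per_side: int) -> int:
    """Every token attends to every token."""
    n = tokens_per_side ** 2
    return n * n


def windowed_pairs(tokens_per_side: int, window: int = 7) -> int:
    """Tokens attend only within their own window of window x window tokens."""
    if tokens_per_side % window != 0:
        raise ValueError("token grid must be divisible by the window size")
    num_windows = (tokens_per_side // window) ** 2
    tokens_per_window = window * window
    return num_windows * tokens_per_window ** 2


side = 224 // 4  # 56 tokens per side for a 224x224 image with 4x4 patches
print(f"global:   {global_pairs(side):,} pairs")
print(f"windowed: {windowed_pairs(side):,} pairs")
print(f"ratio:    {global_pairs(side) // windowed_pairs(side)}x fewer")
```

Because the window size is fixed, the windowed count grows in proportion to the number of tokens, while the global count grows with its square.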
Positional Encoding — Teaching ViT Where Things Are
Self-attention is permutation-equivariant by design: without positional encoding, shuffling the patches merely shuffles the outputs in the same way, so the model has no way to tell where any patch came from. For images, where a cat's ears are above its nose and the sky is above the ground, spatial position is crucial.
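The effect is easy to demonstrate with a toy example. The sketch below implements single-head self-attention in plain Python, using the tokens themselves as queries, keys, and values (a simplification; a real layer applies learned projections first), and checks that shuffling the input only shuffles the output.

```python
import math
import random


def self_attention(tokens):
    """Single-head self-attention with identity Q/K/V projections."""
    d = len(tokens[0])
    outputs = []
    for q in tokens:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in tokens]
        top = max(scores)                                # for numerical stability
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        weights = [w / total for w in weights]
        outputs.append([sum(w * v[i] for w, v in zip(weights, tokens))
                        for i in range(d)])
    return outputs


def mean_pool(rows):
    return [sum(col) / len(rows) for col in zip(*rows)]


random.seed(0)
tokens = [[random.gauss(0, 1) for _ in range(4)] for _ in range(5)]
perm = [3, 0, 4, 1, 2]
shuffled = [tokens[i] for i in perm]

out = self_attention(tokens)
out_shuffled = self_attention(shuffled)

# Output row i of the shuffled run equals output row perm[i] of the original run.
equivariant = all(math.isclose(a, b, abs_tol=1e-9)
                  for i, p in enumerate(perm)
                  for a, b in zip(out_shuffled[i], out[p]))
# Any order-independent summary (here, the mean) is therefore identical.
pooled_same = all(math.isclose(a, b, abs_tol=1e-9)
                  for a, b in zip(mean_pool(out), mean_pool(out_shuffled)))
print(equivariant, pooled_same)
```

Adding a distinct positional vector to each token before attention breaks this symmetry, which is exactly what the positional embedding is for.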
Code — Building a ViT from Scratch in PyTorch
Let us implement a minimal but complete Vision Transformer. Every line maps directly to the theory above.
```python
import torch
import torch.nn as nn


# ─────────────────────────────────────────────────────
# 1. PATCH EMBEDDING
#    Splits image into patches and linearly projects each
# ─────────────────────────────────────────────────────
class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2  # 196
        # Conv2d with kernel=patch_size, stride=patch_size = non-overlapping patches
        self.proj = nn.Conv2d(in_channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)        # (B, embed_dim, H/P, W/P)
        x = x.flatten(2)        # (B, embed_dim, num_patches)
        x = x.transpose(1, 2)   # (B, num_patches, embed_dim)
        return x


# ─────────────────────────────────────────────────────
# 2. MULTI-HEAD SELF-ATTENTION
# ─────────────────────────────────────────────────────
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim=768, num_heads=12, dropout=0.0):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads           # 64
        self.scale = self.head_dim ** -0.5               # 1/√64
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)   # project to Q, K, V at once
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.attn_drop = nn.Dropout(dropout)

    def forward(self, x):
        B, N, C = x.shape
        # Split into Q, K, V and reshape for multi-head
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)                 # (3, B, heads, N, head_dim)
        q, k, v = qkv.unbind(0)
        # Scaled dot-product attention
        attn = (q @ k.transpose(-2, -1)) * self.scale    # (B, heads, N, N)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        # Merge heads back into a single embed_dim-wide vector per token
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(x)


# ─────────────────────────────────────────────────────
# 3. TRANSFORMER ENCODER BLOCK
# ─────────────────────────────────────────────────────
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim=768, num_heads=12, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        mlp_hidden = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x))   # Pre-norm + residual
        x = x + self.mlp(self.norm2(x))    # Pre-norm + residual
        return x


# ─────────────────────────────────────────────────────
# 4. FULL VISION TRANSFORMER
# ─────────────────────────────────────────────────────
class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        num_patches = self.patch_embed.num_patches
        # Learnable [CLS] token and positional embeddings
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(dropout)
        # Stack of Transformer encoder blocks
        self.blocks = nn.Sequential(*[
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)
        # Initialise weights
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, x):
        B = x.size(0)
        x = self.patch_embed(x)                   # (B, 196, 768)
        cls = self.cls_token.expand(B, -1, -1)    # (B, 1, 768)
        x = torch.cat((cls, x), dim=1)            # (B, 197, 768)
        x = self.pos_drop(x + self.pos_embed)
        x = self.blocks(x)                        # 12 Transformer layers
        x = self.norm(x)
        cls_out = x[:, 0]                         # Extract [CLS] token
        return self.head(cls_out)                 # (B, num_classes)


# ─── Instantiate and verify ────────────────────────
model = VisionTransformer(
    img_size=224, patch_size=16, embed_dim=768, depth=12,
    num_heads=12, num_classes=1000
)
dummy = torch.randn(2, 3, 224, 224)   # batch of 2 RGB images
out = model(dummy)
print(f"Output shape: {out.shape}")
params = sum(p.numel() for p in model.parameters())
print(f"Parameters: {params:,}")
```
Code — Fine-Tuning a Pretrained ViT with Hugging Face
Training ViT from scratch requires massive datasets. In practice, you fine-tune a pretrained model.
Hugging Face transformers makes this three lines of model code.
```python
from transformers import ViTForImageClassification, ViTImageProcessor
from datasets import load_dataset
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

# ── 1. Load pretrained ViT-Base/16 (ImageNet-21k checkpoint) ──
MODEL_NAME = "google/vit-base-patch16-224-in21k"
num_classes = 10  # e.g. CIFAR-10
processor = ViTImageProcessor.from_pretrained(MODEL_NAME)
model = ViTForImageClassification.from_pretrained(
    MODEL_NAME,
    num_labels=num_classes,
    ignore_mismatched_sizes=True  # tolerate a head-size mismatch; the 10-class head starts untrained
)

# ── 2. Load and preprocess dataset ──────────────────────────
dataset = load_dataset("cifar10")


def transform_fn(examples):
    # Resize and normalise a batch of PIL images to the model's expected input
    inputs = processor(images=examples["img"], return_tensors="pt")
    inputs["labels"] = torch.tensor(examples["label"])
    return inputs


# with_transform applies transform_fn lazily, each time items are fetched
dataset = dataset.with_transform(transform_fn)
train_loader = DataLoader(dataset["train"], batch_size=32, shuffle=True)
val_loader = DataLoader(dataset["test"], batch_size=64)

# ── 3. Training loop ─────────────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = AdamW(model.parameters(), lr=2e-4, weight_decay=0.01)
scheduler = CosineAnnealingLR(optimizer, T_max=10)

for epoch in range(10):
    model.train()
    total_loss = 0
    for batch in train_loader:
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)
        outputs = model(pixel_values=pixel_values, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        total_loss += loss.item()

    # Validation accuracy
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for batch in val_loader:
            pv = batch["pixel_values"].to(device)
            lbl = batch["labels"].to(device)
            out = model(pixel_values=pv)
            pred = out.logits.argmax(-1)
            correct += (pred == lbl).sum().item()
            total += lbl.size(0)

    scheduler.step()  # one scheduler step per epoch (T_max=10 epochs)
    print(f"Epoch {epoch+1:2d} | Loss: {total_loss/len(train_loader):.4f} | "
          f"Val Acc: {correct/total*100:.2f}%")
```
Fine-tuning a pretrained ViT for 10 epochs in this way typically reaches very high CIFAR-10 accuracy. The same model trained from scratch would need large-scale data and far more compute to approach this. This is the power of pretraining followed by fine-tuning, one of the most important patterns in modern computer vision.
Visualising Attention Maps
One of ViT's advantages is that attention weights can be inspected directly. We can visualise which patches each token attended to, which gives a useful, if partial, view of what the model "sees".
```python
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from transformers import ViTModel, ViTImageProcessor

# Load model with output_attentions=True
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
model = ViTModel.from_pretrained("google/vit-base-patch16-224",
                                 output_attentions=True)
model.eval()

# Preprocess a test image
img = Image.open("test_image.jpg").convert("RGB")
inputs = processor(images=img, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# attentions: one (batch, heads, N+1, N+1) tensor per layer
# Take last layer, average across heads, focus on [CLS] → patch attention
last_layer_attn = outputs.attentions[-1]            # (1, 12, 197, 197)
avg_head_attn = last_layer_attn[0].mean(dim=0)      # (197, 197)
cls_attn = avg_head_attn[0, 1:]                     # [CLS] attending to 196 patches
attn_map = cls_attn.reshape(14, 14).numpy()

# Upscale attention map to the model's input resolution
attn_resized = np.array(Image.fromarray(attn_map).resize(
    (224, 224), Image.BILINEAR))
# Resize the image to the same 224×224 grid so the overlay lines up with it
img_224 = img.resize((224, 224))

# Plot original + attention overlay
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(img_224)
axes[0].set_title("Original Image")
axes[1].imshow(img_224)
axes[1].imshow(attn_resized, alpha=0.6, cmap="jet")
axes[1].set_title("Attention Map: [CLS] token")
for ax in axes:
    ax.axis("off")
plt.tight_layout()
plt.savefig("attention_map.png", dpi=150)
plt.show()
```
The attention heatmap often highlights the foreground object (the cat, the car, the person) and dims the background, even though the model never saw segmentation labels. The [CLS] token learns to aggregate the most discriminative patches purely from classification supervision. This is emergent object localisation behaviour, although raw attention from a single layer can be noisy and is better read as a rough guide than as a faithful explanation.
Masked Autoencoders (MAE) — Self-Supervised ViT Pretraining
That is Masked Autoencoder (MAE) pretraining. No labels. Just: mask 75% of image patches at random, reconstruct the raw pixel values of missing patches. The model is forced to learn rich, semantic representations of the world.
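The masking step itself is simple. A minimal sketch, assuming 196 patch tokens (224×224 image, 16×16 patches) and uniform random sampling without replacement; the function name and the seed are illustrative:

```python
import random


def random_mask(num_patches: int, mask_ratio: float, seed=None):
    """Return (visible, masked) patch indices for one image."""
    rng = random.Random(seed)
    num_keep = int(num_patches * (1 - mask_ratio))
    order = list(range(num_patches))
    rng.shuffle(order)                      # random permutation of patch indices
    visible = sorted(order[:num_keep])      # the encoder sees only these patches
    masked = sorted(order[num_keep:])       # the decoder must reconstruct these
    return visible, masked


visible, masked = random_mask(num_patches=196, mask_ratio=0.75, seed=0)
print(f"{len(visible)} visible, {len(masked)} masked")
```

Because the encoder processes only the visible quarter of the tokens, each pretraining step is much cheaper than a full forward pass over all patches.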
| Metric | ViT-Large + Supervised | ViT-Large + MAE pretrain |
|---|---|---|
| ImageNet Top-1 (fine-tuned) | 82.6% (trained from scratch) | 85.9% |
| Labels used for pretraining | 1.28M (ImageNet) | Zero |
| Pretraining data | ImageNet-1K | ImageNet-1K (same!) |
| Pretraining compute | High | Lower per step: the encoder sees only the 25% visible patches (the MAE paper reports a 3× or greater speed-up) |
| Transfer to ADE20K seg. | 49.9 mIoU | 53.6 mIoU |
ViT for Dense Prediction — Object Detection & Segmentation
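The original ViT classifies from a single [CLS] vector and keeps one low-resolution, single-scale grid of patch tokens. That is ideal for classification but a poor fit for detection or segmentation, which require per-pixel or per-region features, usually at several scales. Several approaches solve this.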
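A common first step in such approaches is to discard the [CLS] token and fold the remaining patch tokens back into a 2D grid, so that a convolutional or upsampling head can treat them as a feature map. The sketch below shows only that reshaping, with strings standing in for token vectors; the function name is hypothetical and the row-major order matches the flatten used in the patch embedding code in this article.

```python
def tokens_to_grid(tokens, grid_size):
    """Drop the leading [CLS] token and arrange patch tokens row by row."""
    patch_tokens = tokens[1:]
    if len(patch_tokens) != grid_size * grid_size:
        raise ValueError("token count does not match grid_size x grid_size")
    return [patch_tokens[r * grid_size:(r + 1) * grid_size]
            for r in range(grid_size)]


# Stand-in sequence: one [CLS] token followed by 196 patch tokens.
tokens = ["CLS"] + [f"p{i}" for i in range(196)]
grid = tokens_to_grid(tokens, grid_size=14)
print(len(grid), len(grid[0]), grid[0][0], grid[1][0], grid[13][13])
```

For a 224×224 input with 16×16 patches this gives a 14×14 map, which is a 16× downsampling of the image and explains why extra upsampling or multi-scale machinery is needed for pixel-level tasks.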
ViT vs CNN — When to Use Which
| Scenario | Best Choice | Why |
|---|---|---|
| Small dataset (<10K images), training from scratch | CNN (EfficientNet, ResNet) | Strong inductive bias compensates for limited data |
| Large dataset (>1M images), training from scratch | ViT | Scales better; surpasses CNNs at large scale |
| Fine-tuning pretrained model — classification | ViT (pretrained) | ImageNet-21k or MAE ViT transfers extremely well |
| Object detection, instance segmentation | Swin Transformer | Hierarchical features match detector needs; linear complexity |
| Zero-shot / universal segmentation | SAM (ViT-Huge) | Trained on over 1B masks; prompted with points, boxes, or masks |
| Mobile / edge deployment | CNN (MobileNet, EfficientNet-Lite) | ViT attention is heavy for mobile; CNNs have mature quantisation |
| Medical imaging (small, specialised dataset) | Hybrid or pretrained ViT with heavy augmentation | Limited data hurts pure ViT; pretrained features help enormously |
| Self-supervised feature learning | ViT + MAE or DINO | Masking/self-distillation works far better on Transformers than CNNs |
Training Tips & Tricks for ViTs
ViTs require different training recipes than CNNs. These are the non-negotiable adjustments.
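Optimiser: use AdamW with lr=1e-3, warm-up, and cosine decay. SGD tends to diverge or to converge 30–40% slower on ViTs.
Warm-up: start at lr=1e-6, ramp linearly to the peak LR, then cosine anneal. Skipping warm-up causes instability in the first few hundred steps.
Weight decay: use weight_decay=0.05. Explicitly exclude LayerNorm.weight, LayerNorm.bias, and all bias terms from the weight decay group. Decaying normalisation parameters destabilises training.
Resolution changes: when fine-tuning at a resolution different from the one used in pretraining, the positional embeddings must be interpolated to the new patch grid. Pass interpolate_pos_encoding=True in Hugging Face to handle this automatically.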
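The warm-up and cosine decay schedule described in these tips can be written as a single function of the step number. This is a generic sketch using the learning rates quoted here (1e-6 start, 1e-3 peak) plus an assumed floor of 1e-5; it is not the exact formula of any particular library scheduler, and the function name is illustrative.

```python
import math


def lr_at(step, total_steps, warmup_steps,
          peak_lr=1e-3, warmup_init=1e-6, min_lr=1e-5):
    """Linear warm-up to peak_lr, then cosine decay down to min_lr."""
    if step < warmup_steps:
        # Linear ramp from warmup_init to peak_lr
        return warmup_init + (peak_lr - warmup_init) * step / warmup_steps
    # Fraction of the decay phase completed, from 0.0 to 1.0
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))


for step in (0, 50, 100, 550, 1000):
    print(f"step {step:>4}: lr = {lr_at(step, total_steps=1000, warmup_steps=100):.2e}")
```

The same shape applies whether a step is an optimiser update or a whole epoch; only the units of total_steps and warmup_steps change.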
Complete Training Recipe — Code
```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from timm import create_model
from timm.data.mixup import Mixup
from timm.loss import SoftTargetCrossEntropy
from timm.scheduler import CosineLRScheduler
from torch.optim import AdamW

# ── 1. Model: ViT-Base/16 from timm ───────────────────────
model = create_model(
    "vit_base_patch16_224",
    pretrained=True,
    num_classes=10
)

# ── 2. Augmentation pipeline (DeiT-style) ─────────────────
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandAugment(num_ops=2, magnitude=9),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225]),
    transforms.RandomErasing(p=0.25),
])
# Assumption: CIFAR-10 (to match num_classes=10) and batch size 128.
train_set = datasets.CIFAR10(root="./data", train=True, download=True,
                             transform=train_transform)
# drop_last=True keeps every batch even-sized, which timm's Mixup requires
train_loader = DataLoader(train_set, batch_size=128, shuffle=True,
                          num_workers=4, drop_last=True)

# ── 3. Mixup + CutMix (timm's Mixup handles both) ─────────
mixup_fn = Mixup(
    mixup_alpha=0.8, cutmix_alpha=1.0,
    prob=1.0, switch_prob=0.5,
    mode="batch", label_smoothing=0.1, num_classes=10
)
criterion = SoftTargetCrossEntropy()  # works with soft Mixup labels

# ── 4. Optimiser: split weight decay groups ───────────────
# timm names its norm layers "norm1", "norm2", "norm", so matching on
# "LayerNorm" would miss them. All norm weights and biases are 1-D,
# so select them by shape instead.
decay, no_decay = [], []
for name, p in model.named_parameters():
    if not p.requires_grad:
        continue
    if p.ndim <= 1 or name.endswith(".bias"):
        no_decay.append(p)
    else:
        decay.append(p)
optim_groups = [
    {"params": decay, "weight_decay": 0.05},
    {"params": no_decay, "weight_decay": 0.0},
]
optimizer = AdamW(optim_groups, lr=1e-3)

# ── 5. Cosine LR schedule with warm-up ────────────────────
scheduler = CosineLRScheduler(
    optimizer, t_initial=100,   # 100 total epochs
    lr_min=1e-5,
    warmup_t=5,                 # 5-epoch warm-up
    warmup_lr_init=1e-6,
)

# ── 6. Training loop ───────────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
for epoch in range(100):
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        images, soft_labels = mixup_fn(images, labels)
        logits = model(images)
        loss = criterion(logits, soft_labels)
        loss.backward()
        # Gradient clipping: important for ViT stability
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()
    scheduler.step(epoch + 1)
```
Gradient clipping with clip_grad_norm_(model.parameters(), max_norm=1.0) is strongly recommended for ViTs. Self-attention can produce very large gradient norms, especially in early training. Clipping stabilises training at little or no cost to convergence speed. A common setting is max_norm=1.0 for ViT-Base and ViT-Large.
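To see what clipping by global norm does, here is a plain-Python sketch of the idea: compute one norm over all gradients together, and if it exceeds the limit, scale every gradient by the same factor. Gradients are plain lists of floats here, and the small eps term mirrors the usual guard against division by zero; this illustrates the technique rather than reproducing PyTorch's implementation.

```python
import math


def clip_by_global_norm(grads, max_norm=1.0, eps=1e-6):
    """grads: list of per-parameter gradient lists. Returns (clipped, total_norm)."""
    # One norm across every gradient value in the model
    total_norm = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    scale = max_norm / (total_norm + eps)
    if scale < 1.0:
        # Shrink all gradients by the same factor, preserving their direction
        grads = [[g * scale for g in tensor] for tensor in grads]
    return grads, total_norm


big, big_norm = clip_by_global_norm([[3.0], [4.0]], max_norm=1.0)
small, small_norm = clip_by_global_norm([[0.3], [0.4]], max_norm=1.0)
print(big_norm, big)
print(small_norm, small)
```

Because every gradient is scaled by the same factor, the update direction is unchanged; only its length is capped.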
ViT Performance on Major Benchmarks
| Task | Dataset | Model | Score | Method |
|---|---|---|---|---|
| Image Classification | ImageNet-1K | ViT-Huge/14 + MAE | 87.8% Top-1 | MAE pretrain → fine-tune at 448 px |
| Image Classification | ImageNet-1K | DINOv2 ViT-g/14 | 86.5% linear probe | Linear layer only, no fine-tune |
| Object Detection | COCO | Swin-L + HTC++ | 57.7 APbox | Multi-scale training |
| Semantic Segmentation | ADE20K | Swin-L + UperNet | 53.5 mIoU | 160K iterations |
| Semantic Segmentation | ADE20K | ViT-Large + MAE | 53.6 mIoU | MAE pretrain + UperNet |
| Depth Estimation | NYU Depth v2 | DINOv2 + DPT head | 0.206 absRel | Zero-shot transfer |
| Medical (Pathology) | CAMELYON16 | ViT-Base + DINOv2 | 0.989 AUC | Slide-level linear probe |
Compute & Memory Complexity — The Numbers
| Component | Time Complexity | Memory | Bottleneck? |
|---|---|---|---|
| Patch Embedding (Conv2d) | O(N·D) | Low | No |
| Self-Attention (global) | O(N²·D) | O(N²) attention matrix | Yes — quadratic in patches |
| Self-Attention (Swin window) | O(N·M²·D) | O(M²) per window | No — linear in image size |
| MLP Block | O(N·D²) | Low | Moderate — large D |
| Flash Attention (optimised) | O(N²·D) | O(N) — IO-aware | 3–5× faster in practice |
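Flash Attention (Dao et al. 2022) rewrites the attention kernel to be IO-aware. It fuses the QK^T, softmax, and V multiplication into one GPU kernel, so the full N×N attention matrix is never materialised in HBM. The result is attention that runs several times faster and, at long sequence lengths, uses 10–20× less memory than a standard implementation, while still computing exact attention (outputs match up to floating-point rounding). In PyTorch 2.0+, torch.nn.functional.scaled_dot_product_attention dispatches to a Flash Attention kernel when the hardware and dtypes support it.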
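The reason the full attention matrix never needs to exist is the online softmax trick: scores can be consumed block by block while a running maximum, a running denominator, and a running weighted sum are kept up to date. The sketch below shows this for a single query with scalar values, in plain Python. It illustrates the numerical idea only and is not the actual GPU kernel; the function names and sample numbers are made up for the example.

```python
import math


def naive_attention_row(scores, values):
    """Softmax over all scores at once, then a weighted sum of values."""
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)


def streaming_attention_row(scores, values, block_size=4):
    """Same result, but scores are processed one block at a time."""
    running_max = float("-inf")
    denom = 0.0   # running softmax denominator
    acc = 0.0     # running weighted sum of values
    for start in range(0, len(scores), block_size):
        s_blk = scores[start:start + block_size]
        v_blk = values[start:start + block_size]
        new_max = max(running_max, max(s_blk))
        # Rescale earlier partial results to the new maximum (0.0 on first block)
        rescale = math.exp(running_max - new_max)
        exp_blk = [math.exp(s - new_max) for s in s_blk]
        denom = denom * rescale + sum(exp_blk)
        acc = acc * rescale + sum(e * v for e, v in zip(exp_blk, v_blk))
        running_max = new_max
    return acc / denom


scores = [0.5, -1.2, 3.0, 0.1, 2.2, -0.7, 1.9, 0.0, -2.5, 4.1]
values = [1.0, 2.0, -1.0, 0.5, 3.0, -2.0, 0.0, 1.5, 2.5, -0.5]
naive = naive_attention_row(scores, values)
streamed = streaming_attention_row(scores, values, block_size=4)
print(naive, streamed)
```

Only one block of scores is held at a time, so memory stays proportional to the block size rather than to the sequence length, while the result agrees with the naive version up to rounding.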
Real-World Applications
The Future — What Comes After ViT?
Research is pushing in four directions simultaneously: efficiency (linear attention, state space models like Mamba-Vision), multimodality (unified vision-language-action models), scalability (ViT-22B with 22 billion parameters), and video understanding (TimeSformer, VideoMAE — treating video as a space-time volume of patches).
Standard self-attention scales as O(N²) in sequence length. For video (thousands of tokens), high-resolution images, or 3D medical scans, this remains a hard practical limit. Flash Attention, windowed attention (Swin), and SSMs each attack this differently. No consensus winner has emerged yet — this is the most active research frontier in vision architectures.
Quick Reference — Key ViT Parameters
| Parameter | ViT-Base | ViT-Large | What it controls |
|---|---|---|---|
| patch_size | 16 | 16 | Resolution of each patch token. Smaller = more tokens, finer detail, more compute. |
| embed_dim (D) | 768 | 1024 | Width of the model: the size of every token vector. |
| depth (L) | 12 | 24 | Number of Transformer encoder blocks stacked. |
| num_heads | 12 | 16 | Parallel attention heads. Each captures different relationships. |
| mlp_ratio | 4.0 | 4.0 | MLP hidden dim = embed_dim × mlp_ratio. Controls non-linear capacity. |
| dropout | 0.0 | 0.1 | Applied to attention weights and MLP. Use 0.0 with strong augmentation. |
| num_classes | 1000 | 1000 | Output dimension. Change this for your task. |
Use a pretrained ViT-Base/16 from Hugging Face or timm as your default starting point for any image classification or feature extraction task. Fine-tune for 10–30 epochs with AdamW + cosine schedule + RandAugment. For detection or segmentation, use Swin-Transformer. For features without fine-tuning, use DINOv2. Only go larger (ViT-Large, ViT-Huge) when Base saturates.