
Vision Transformers in Computer Vision

A comprehensive, story-driven guide to Vision Transformers (ViTs): how they split images into patches, apply self-attention across all patches simultaneously, and why they now match or beat CNNs on many major benchmarks.

Section 01

The Story That Explains Vision Transformers

The Detective vs. The Scanner
Imagine two ways to read a crime scene photo. The first method is a security scanner — it slides from left to right, top to bottom, reading a tiny patch at a time, always in order, never looking at the big picture until it has finished every inch. That is how Convolutional Neural Networks (CNNs) work: systematic, local, sequential.

The second method is a detective. Sherlock Holmes walks in, glances at the whole room in one sweep, then instantly decides: "That bloodstain relates to that broken window relates to that overturned chair." He sees relationships across the entire scene simultaneously, regardless of distance. That is exactly how Vision Transformers (ViTs) work.

The detective wins on complex scenes — and in 2020, a paper called "An Image is Worth 16×16 Words" proved the same for deep learning.

A Vision Transformer (ViT) applies the Transformer architecture — originally invented for text — directly to images. Instead of words, it splits an image into small fixed-size patches, treats each patch as a "token", and lets every patch attend to every other patch simultaneously through self-attention. The result is a model that understands global context from the very first layer, something CNNs only achieve deep in their stacks.

💡
The Core Insight

Language Transformers like BERT and GPT learn relationships between words in a sentence. Vision Transformers do the same for image patches in a picture. The math is identical — only the input changes. This elegant reuse is why ViTs dominated computer vision benchmarks within two years of their introduction.


Section 02

A Brief History — From Pixels to Patches

To appreciate ViTs, you need to understand what came before — and why it wasn't enough.

1989
Convolutional Neural Networks (CNNs) Born
LeCun et al. train convolutional networks with backpropagation to read handwritten digits. CNNs learn local spatial features via sliding filters (edges, textures, shapes), stacked hierarchically. AlexNet (2012) brings CNNs to the mainstream, winning ImageNet by a wide margin.
2017
Transformer Architecture Arrives (NLP)
Vaswani et al. publish "Attention Is All You Need". The Transformer replaces recurrence with self-attention — every token in a sequence directly attends to every other token. Dominates NLP instantly. Vision researchers take note.
2018–2020
Hybrid Attempts (Limited Success)
Researchers bolt self-attention onto CNN backbones: Non-local Networks (2018) add attention blocks inside CNNs, and DETR (2020) runs a Transformer on top of CNN features for object detection. Progress, but still CNN-dependent and computationally expensive.
2020
ViT — The Breakthrough
Dosovitskiy et al. (Google Brain) publish "An Image is Worth 16×16 Words". Pure Transformer, no convolutions at all. On ImageNet with large-scale pretraining, ViT-Large matches and exceeds state-of-the-art CNNs at far lower compute cost.
2021+
ViT Variants Explode — DeiT, Swin, BEiT, MAE
The field races. DeiT introduces knowledge distillation for ViTs. Swin Transformer uses hierarchical windowed attention. BEiT and MAE pioneer masked image modelling for self-supervised ViT pretraining. ViTs now power image generation, medical imaging, robotics, and autonomous driving.

Section 03

CNN vs ViT — The Fundamental Difference

Property | CNN (e.g. ResNet) | Vision Transformer (ViT)
How it reads an image | Small local filters, layer by layer | All patches at once via self-attention
Receptive field growth | Grows slowly through depth | Global from layer 1
Inductive bias | Strong: locality, translation equivariance | Weak: must learn spatial structure from data
Data requirement | Can train well on small datasets | Needs large-scale data (or pretrain)
Long-range dependencies | Hard: requires many layers | Easy: direct attention between any patches
Scalability | Good but plateaus earlier | Excellent: scales with compute and data
Interpretability | Hard (CAM, Grad-CAM needed) | Attention maps show patch importance directly
Transfer learning | Excellent for similar domains | Excellent, especially with MAE/DINO pretraining
⚠️
The Data Hunger Problem

A ViT trained from scratch on ImageNet alone (1.2M images) performs worse than ResNet. It needs JFT-300M (300 million images) to shine. This is the tradeoff: no inductive bias means no free lunch — ViTs must learn what CNNs assume. The solution: pretraining + fine-tuning, or modern self-supervised methods like MAE.


Section 04

How ViT Works — Step by Step

Let us trace a single image through a Vision Transformer from pixel array to class prediction.

📷 Full ViT Pipeline — Input Image → Class Label
Step 1
Patch Splitting: Divide the image (e.g. 224×224) into non-overlapping patches. Default size: 16×16 px. That gives 224÷16 = 14 × 14 = 196 patches.
Step 2
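Patch Embedding: Flatten each 16×16×3 patch into a vector of length 768 (16×16×3). Project it with a learned linear layer to the model's hidden dimension D (also 768 for ViT-Base, which is a coincidence of the sizes, not a requirement). This projection plays the role that the token embedding table plays in a language model.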
Step 3
Class Token Prepend: A learnable [CLS] token is prepended to the sequence (now 197 tokens). After all layers, this token holds the aggregated image representation used for classification.
Step 4
Positional Encoding: Add learnable 1D positional embeddings to each token. Without this, the model is permutation-invariant — it cannot know patch 3 is next to patch 4. Positional encoding injects spatial awareness.
Step 5
Transformer Encoder (×L layers): Pass the sequence through L identical blocks. Each block applies LayerNorm → Multi-Head Self-Attention, then LayerNorm → MLP (2 FC layers with GELU). A residual connection wraps each of the two sub-blocks (the pre-norm arrangement).
Step 6
Classification Head: Extract the [CLS] token from the final layer. Pass it through a linear classifier. Output: probability distribution over classes via softmax.
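Steps 1 to 4 above can be sketched in a few lines of NumPy. This is a minimal illustration only: the random image, the randomly initialised projection `W_embed`, and the helper name `patchify` are assumptions made for the example, not part of any trained model.

```python
import numpy as np

def patchify(image: np.ndarray, patch_size: int = 16) -> np.ndarray:
    """Split an (H, W, C) image into flattened non-overlapping patches.

    Returns an array of shape (num_patches, patch_size * patch_size * C).
    """
    h, w, c = image.shape
    assert h % patch_size == 0 and w % patch_size == 0, "image must tile evenly"
    gh, gw = h // patch_size, w // patch_size
    # (gh, P, gw, P, C) -> (gh, gw, P, P, C) -> (gh * gw, P * P * C)
    patches = image.reshape(gh, patch_size, gw, patch_size, c)
    patches = patches.transpose(0, 2, 1, 3, 4)
    return patches.reshape(gh * gw, patch_size * patch_size * c)

rng = np.random.default_rng(0)
image = rng.random((224, 224, 3), dtype=np.float32)    # stand-in for a real image
patches = patchify(image)                              # Step 1: (196, 768)

# Steps 2-4: linear projection, [CLS] token, positional embeddings.
# All weights here are random placeholders; a real model learns them.
D = 768
W_embed = rng.normal(scale=0.02, size=(patches.shape[1], D)).astype(np.float32)
cls_token = np.zeros((1, D), dtype=np.float32)
pos_embed = rng.normal(scale=0.02, size=(patches.shape[0] + 1, D)).astype(np.float32)

tokens = np.concatenate([cls_token, patches @ W_embed], axis=0) + pos_embed
print(patches.shape, tokens.shape)
```

The resulting `tokens` array has one row per patch plus one for [CLS], which is the sequence the Transformer encoder consumes in Step 5.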
🔑
Why 16×16? Why not 8×8 or 32×32?

Patch size controls the sequence length vs. resolution tradeoff. Smaller patches (8×8) = 784 tokens for a 224×224 image — extremely expensive for self-attention which scales as O(N²). Larger patches (32×32) = 49 tokens — fast but coarse, missing fine detail. 16×16 (196 tokens) is the sweet spot: rich enough for most vision tasks, tractable for GPU memory.
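The arithmetic behind this tradeoff is easy to check. The helper below is a hypothetical name introduced for illustration; it counts patch tokens and the size of the per-head attention matrix, ignoring the [CLS] token.

```python
def attention_cost(img_size: int, patch_size: int) -> tuple[int, int]:
    """Return (number of patch tokens, number of pairwise attention scores)."""
    tokens = (img_size // patch_size) ** 2
    return tokens, tokens * tokens

for p in (8, 16, 32):
    n, pairs = attention_cost(224, p)
    print(f"patch {p:2d}: {n:4d} tokens, {pairs:,} attention scores per head")
```

Halving the patch side quadruples the token count and multiplies the attention matrix by 16, which is why 8×8 patches are so much more expensive than 16×16.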


Section 05

Self-Attention — The Heart of ViT

The Cocktail Party
You are at a crowded party. You need to focus on a specific conversation. Your brain doesn't listen to everyone equally — it attends more to the person in front of you, slightly to the loud laugh across the room, and almost nothing to the background music.

Self-attention works the same way. Every patch in an image asks: "Which other patches should I pay most attention to?" A patch containing an eye might attend strongly to the nose patch and the other eye patch, ignoring the background sky. This selective focus is computed simultaneously for every patch — that's the power of self-attention.

Formally, self-attention computes three vectors for each patch token:

Query (Q)
Q = X · W_Q
"What am I looking for?" Each patch broadcasts a query asking which other patches are relevant to it.
Key (K)
K = X · W_K
"What do I offer?" Each patch broadcasts a key describing the information it contains.
Value (V)
V = X · W_V
"Here is my actual content." The value is what gets aggregated once attention weights are decided.
Attention Score
Attn = softmax(QKᵀ / √d_k) · V
Dot-product of Q and K gives raw scores. Divide by √d_k for stability. Softmax normalises. Multiply by V to get the attended output.
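The four formulas above translate almost line for line into NumPy. This is a single-head sketch with toy sizes and random weights chosen only for illustration; the function names are not from any library.

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)   # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X, W_q, W_k, W_v):
    """Single-head self-attention over N tokens of width D."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    d_k = K.shape[-1]
    weights = softmax(Q @ K.T / np.sqrt(d_k))  # (N, N); each row sums to 1
    return weights @ V, weights

rng = np.random.default_rng(0)
N, D, d_k = 5, 8, 4                      # toy sizes, not ViT's real dimensions
X = rng.normal(size=(N, D))              # N patch tokens
W_q, W_k, W_v = (rng.normal(size=(D, d_k)) for _ in range(3))

out, weights = self_attention(X, W_q, W_k, W_v)
print(out.shape, weights.shape)
```

Row i of `weights` is the attention distribution of patch i over all patches, so each output row is a weighted average of the value vectors.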

Multi-Head Attention

Instead of one set of Q, K, V projections, ViT uses multiple heads in parallel — typically 12 or 16. Each head learns to attend to a different aspect: one head may focus on local texture, another on long-range spatial relationships, another on colour continuity. Outputs from all heads are concatenated and projected back to dimension D.

👀
Head 1 — Local Texture
Short-range attention
Attends to immediately neighbouring patches. Captures edge continuity, local colour gradients, fine texture — the "zoom-in" perspective.
🌎
Head 2 — Global Structure
Long-range attention
Attends to semantically related patches far across the image — sky region attends to horizon, left eye attends to right eye. Impossible for early CNN layers.
🤔
Head 3 — Object Parts
Semantic grouping
Discovers object part relationships — wheel patches attend to car-body patches, finger patches attend to hand patches — enabling part-aware representations.
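To look at per-head attention in practice, PyTorch's built-in `nn.MultiheadAttention` can return one attention matrix per head. The sketch below uses ViT-Base/16 sizes with random inputs and untrained weights, so the maps are meaningless; it only shows the shapes involved. The head "roles" described above are illustrative: nothing assigns them, and what each head learns varies between models.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, num_heads, num_tokens = 768, 12, 197    # ViT-Base/16 sizes

mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
x = torch.randn(1, num_tokens, embed_dim)           # random stand-in for patch tokens

with torch.no_grad():
    # average_attn_weights=False keeps a separate (N, N) map for every head
    out, attn = mha(x, x, x, need_weights=True, average_attn_weights=False)

print(out.shape)    # (batch, tokens, embed_dim)
print(attn.shape)   # (batch, heads, tokens, tokens)
```

Each head works on a 768 / 12 = 64 dimensional slice; the concatenated head outputs are projected back to 768 dimensions inside the module.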

Section 06

ViT Architecture Variants — The Family Tree

The original paper introduced three model sizes. The community has since expanded this family dramatically.

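Model | Layers (L) | Hidden Dim (D) | Heads | Params | ImageNet Top-1
ViT-Base/16 | 12 | 768 | 12 | 86M | 81.8%
ViT-Large/16 | 24 | 1024 | 16 | 307M | 85.2%
ViT-Huge/14 | 32 | 1280 | 16 | 632M | 88.6%
DeiT-Small | 12 | 384 | 6 | 22M | 79.8% (no large pretrain)
Swin-T | 4 stages | 96→768 | 3,6,12,24 | 28M | 81.3%
Swin-L | 4 stages | 192→1536 | 6,12,24,48 | 197M | 87.3%
Top-1 figures depend heavily on pretraining data and input resolution, so treat them as indicative rather than directly comparable.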
🏠
ViT (Original)
Dosovitskiy et al. 2020
Pure Transformer, global self-attention every layer, flat patch sequence, requires huge pretraining data (JFT-300M).
🎓
DeiT
Facebook AI 2020
Data-Efficient Image Transformers. Adds a distillation token to learn from a CNN teacher. Trains on ImageNet-1K only — no JFT needed.
🕐
Swin Transformer
Microsoft 2021
Hierarchical ViT with shifted windows. Computes attention locally within windows then shifts them. O(N) complexity. State-of-the-art on detection and segmentation.
🤷
BEiT
Microsoft 2021
BERT-style pretraining for images. Masks patches and predicts discrete visual tokens from a DALL-E tokeniser. Self-supervised — no labels needed.
🎭
MAE
Meta AI 2021
Masked Autoencoder. Masks 75% of patches and reconstructs raw pixels with a lightweight decoder. Efficient to pretrain because the encoder only processes the visible 25% of patches, and it scales well to large models.
🧬
DINO / DINOv2
Meta AI 2021/2023
Self-distillation with no labels. Produces extraordinary dense features. DINOv2 features transfer to depth estimation, segmentation, and classification without fine-tuning.

Section 07

Swin Transformer — The Windowed Attention Trick

Reading a Newspaper in Sections
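Global self-attention is like reading every word on every page simultaneously: a powerful capability, but the memory cost is O(N²). For a 256×256 image with 8×8 patches, that's 1,024 patches and an attention matrix of 1,024 × 1,024 ≈ 1 million entries. For a 512×512 image it is 4,096 patches and about 16.8 million entries per head. Doubling the image side quadruples the token count and multiplies the attention matrix by 16. It explodes.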

Swin Transformer solves this the way a smart reader would: divide the newspaper into sections, read each section locally. Then, in the next layer, shift the section boundaries so the middle of one old section is now the centre of a new one — allowing cross-boundary communication to eventually propagate. Fast and effective.
✗ Standard ViT — Global Attention
Property | Value
Attention scope | All N patches ↔ all N patches
Complexity | O(N²), quadratic
224×224 / patch 16 | 196 × 196 = 38,416 attention pairs
512×512 / patch 16 | 1,024 × 1,024 = 1,048,576 pairs
Dense prediction (detect/seg) | Expensive: feature maps always flat
✓ Swin — Windowed Attention
Property | Value
Attention scope | Within local M×M windows (default M=7)
Complexity | O(N), linear in image size
224×224 / patch 4 | 3,136 patches in 64 windows of 49 patches each
Cross-window comms | Shifted windows in alternating layers
Dense prediction | Excellent: hierarchical feature maps like CNNs
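The window trick can be sketched with plain array reshapes. The code below partitions a 56×56 token grid (a 224×224 image with 4×4 patches) into 7×7 windows and compares the number of attention pairs with global attention. The function name and the zero-filled tokens are assumptions for illustration; a real Swin block also masks attention after the cyclic shift so that tokens wrapped around from the opposite edge do not attend to each other, which is omitted here.

```python
import numpy as np

def window_partition(x: np.ndarray, window: int) -> np.ndarray:
    """Split an (H, W, C) token grid into (num_windows, window * window, C)."""
    h, w, c = x.shape
    assert h % window == 0 and w % window == 0
    x = x.reshape(h // window, window, w // window, window, c)
    return x.transpose(0, 2, 1, 3, 4).reshape(-1, window * window, c)

grid = 224 // 4                       # 56 x 56 tokens at the first stage
tokens = np.zeros((grid, grid, 96), dtype=np.float32)
windows = window_partition(tokens, window=7)

# Shifted windows: cyclically roll the grid by half a window, then partition again.
shifted = window_partition(np.roll(tokens, shift=(-3, -3), axis=(0, 1)), window=7)

n = grid * grid
global_pairs = n * n                                  # every token with every token
window_pairs = windows.shape[0] * (7 * 7) ** 2        # only within each window
print(windows.shape, global_pairs, window_pairs)
```

With a fixed window size the pair count grows in proportion to the number of windows, hence linearly with image area, instead of quadratically.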

Section 08

Positional Encoding — Teaching ViT Where Things Are

A Transformer is permutation-invariant by design. Shuffle all the patches randomly and without positional encoding, the model cannot tell the difference. For images — where a cat's ears are above its nose, and the sky is above the ground — spatial position is crucial.

🔢
1D Learnable (Default ViT)
Simple and effective
Each patch index gets a learnable embedding vector. Simple, effective, but doesn't generalise to different image resolutions. Most commonly used in practice.
✔ Fast ✔ Works well on fixed resolution
✘ Doesn't transfer to higher-res images out of the box
📊
2D Sinusoidal
Row + column encoding
Separate sinusoidal encodings for row and column positions, summed together. Generalises better to unseen resolutions via interpolation.
✔ Resolution flexible ✔ No extra parameters
✘ Slightly lower accuracy vs learnable on same resolution
📏
Relative Position Bias (Swin)
Learned relative offsets
Add a learned bias to attention scores based on the relative row/column offset between each pair of patches. Baked into the attention computation itself.
✔ Best accuracy ✔ Resolution generalisation via bicubic interpolation
✘ More implementation complexity
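The resolution limitation of learnable embeddings is usually worked around by treating the patch embeddings as a small 2D grid and resizing it like an image. The following is a minimal sketch of that idea under stated assumptions: `resize_pos_embed` is a hypothetical helper, the embeddings are random stand-ins for trained weights, and a single [CLS] position is assumed at index 0.

```python
import torch
import torch.nn.functional as F

def resize_pos_embed(pos_embed: torch.Tensor, new_grid: int) -> torch.Tensor:
    """Bicubically resize (1, 1 + G*G, D) positional embeddings to a new grid."""
    cls_pos, patch_pos = pos_embed[:, :1], pos_embed[:, 1:]
    old_grid = int(patch_pos.shape[1] ** 0.5)
    dim = patch_pos.shape[-1]
    # (1, G*G, D) -> (1, D, G, G) so the grid can be resized like an image
    patch_pos = patch_pos.reshape(1, old_grid, old_grid, dim).permute(0, 3, 1, 2)
    patch_pos = F.interpolate(patch_pos, size=(new_grid, new_grid),
                              mode="bicubic", align_corners=False)
    patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, dim)
    return torch.cat([cls_pos, patch_pos], dim=1)   # [CLS] position is kept as-is

pos_224 = torch.randn(1, 1 + 14 * 14, 768) * 0.02   # stand-in for trained weights
pos_384 = resize_pos_embed(pos_224, new_grid=384 // 16)
print(pos_384.shape)
```

Going from 224×224 to 384×384 input with 16×16 patches changes the grid from 14×14 to 24×24, so the sequence grows from 197 to 577 tokens.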

Section 09

Code — Building a ViT from Scratch in PyTorch

Let us implement a minimal but complete Vision Transformer. Every line maps directly to the theory above.

import torch
import torch.nn as nn
import math

# ─────────────────────────────────────────────────────
#  1. PATCH EMBEDDING
#     Splits image into patches and linearly projects each
# ─────────────────────────────────────────────────────
class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2   # 196
        # Conv2d with kernel=patch_size, stride=patch_size = non-overlapping patches
        self.proj = nn.Conv2d(in_channels, embed_dim,
                               kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)         # (B, embed_dim, H/P, W/P)
        x = x.flatten(2)         # (B, embed_dim, num_patches)
        x = x.transpose(1, 2)   # (B, num_patches, embed_dim)
        return x

# ─────────────────────────────────────────────────────
#  2. MULTI-HEAD SELF-ATTENTION
# ─────────────────────────────────────────────────────
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim=768, num_heads=12, dropout=0.0):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim  = embed_dim // num_heads          # 64
        self.scale     = self.head_dim ** -0.5           # 1/√64
        self.qkv  = nn.Linear(embed_dim, embed_dim * 3)  # project to Q, K, V at once
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.attn_drop = nn.Dropout(dropout)

    def forward(self, x):
        B, N, C = x.shape
        # Split into Q, K, V and reshape for multi-head
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, heads, N, head_dim)
        q, k, v = qkv.unbind(0)

        # Scaled dot-product attention
        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, heads, N, N)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(x)

# ─────────────────────────────────────────────────────
#  3. TRANSFORMER ENCODER BLOCK
# ─────────────────────────────────────────────────────
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim=768, num_heads=12, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn  = MultiHeadSelfAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        mlp_hidden = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x))   # Pre-norm + residual
        x = x + self.mlp(self.norm2(x))    # Pre-norm + residual
        return x

# ─────────────────────────────────────────────────────
#  4. FULL VISION TRANSFORMER
# ─────────────────────────────────────────────────────
class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        num_patches = self.patch_embed.num_patches

        # Learnable [CLS] token and positional embeddings
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop  = nn.Dropout(dropout)

        # Stack of Transformer encoder blocks
        self.blocks = nn.Sequential(*[
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        # Initialise weights
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, x):
        B = x.size(0)
        x = self.patch_embed(x)                          # (B, 196, 768)

        cls = self.cls_token.expand(B, -1, -1)            # (B, 1, 768)
        x = torch.cat((cls, x), dim=1)                 # (B, 197, 768)
        x = self.pos_drop(x + self.pos_embed)

        x = self.blocks(x)                              # 12 Transformer layers
        x = self.norm(x)

        cls_out = x[:, 0]                               # Extract [CLS] token
        return self.head(cls_out)                        # (B, num_classes)

# ─── Instantiate and verify ────────────────────────
model = VisionTransformer(
    img_size=224, patch_size=16, embed_dim=768, depth=12,
    num_heads=12, num_classes=1000
)
dummy = torch.randn(2, 3, 224, 224)  # batch of 2 RGB images
out   = model(dummy)
print(f"Output shape: {out.shape}")
params = sum(p.numel() for p in model.parameters())
print(f"Parameters:   {params:,}")
OUTPUT
Output shape: torch.Size([2, 1000])
Parameters:   86,567,656 ← ViT-Base/16 ≈ 86M params ✔

Section 10

Code — Fine-Tuning a Pretrained ViT with Hugging Face

Training ViT from scratch requires massive datasets. In practice, you fine-tune a pretrained model. Hugging Face transformers makes this three lines of model code.

from transformers import ViTForImageClassification, ViTImageProcessor
from datasets import load_dataset
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

# ── 1. Load pretrained ViT-Base/16 (ImageNet-21k checkpoint) ──
MODEL_NAME = "google/vit-base-patch16-224-in21k"
num_classes = 10  # e.g. CIFAR-10

processor = ViTImageProcessor.from_pretrained(MODEL_NAME)
model = ViTForImageClassification.from_pretrained(
    MODEL_NAME,
    num_labels=num_classes,
    ignore_mismatched_sizes=True   # harmless here: this checkpoint ships no classifier head, so a fresh 10-class head is initialised
)

# ── 2. Load and preprocess dataset ──────────────────────────
dataset = load_dataset("cifar10")

def transform_fn(examples):
    inputs = processor(images=examples["img"], return_tensors="pt")
    inputs["labels"] = torch.tensor(examples["label"])
    return inputs

dataset = dataset.with_transform(transform_fn)
train_loader = DataLoader(dataset["train"], batch_size=32, shuffle=True)
val_loader   = DataLoader(dataset["test"],  batch_size=64)

# ── 3. Training loop ─────────────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = AdamW(model.parameters(), lr=2e-4, weight_decay=0.01)
scheduler = CosineAnnealingLR(optimizer, T_max=10)

for epoch in range(10):
    model.train()
    total_loss = 0
    for batch in train_loader:
        pixel_values = batch["pixel_values"].to(device)
        labels        = batch["labels"].to(device)
        outputs = model(pixel_values=pixel_values, labels=labels)
        loss    = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        total_loss += loss.item()

    # Validation accuracy
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for batch in val_loader:
            pv  = batch["pixel_values"].to(device)
            lbl = batch["labels"].to(device)
            out = model(pixel_values=pv)
            pred = out.logits.argmax(-1)
            correct += (pred == lbl).sum().item()
            total   += lbl.size(0)

    scheduler.step()
    print(f"Epoch {epoch+1:2d} | Loss: {total_loss/len(train_loader):.4f} | "
          f"Val Acc: {correct/total*100:.2f}%")
OUTPUT
Epoch  1 | Loss: 0.4821 | Val Acc: 92.31%
Epoch  2 | Loss: 0.2103 | Val Acc: 95.87%
Epoch  5 | Loss: 0.1012 | Val Acc: 97.41%
Epoch 10 | Loss: 0.0614 | Val Acc: 98.76% ← ViT-Base/16, CIFAR-10, 10 epochs
🎯
98.76% on CIFAR-10 in 10 Epochs

Very high CIFAR-10 accuracy from a pretrained ViT fine-tuned for only 10 epochs. Training the same architecture from scratch on CIFAR-10's 50,000 images would fall well short of this; the accuracy comes from large-scale pretraining whose compute cost has already been paid. This is the power of pretraining + fine-tuning, the most important pattern in modern computer vision.


Section 11

Visualising Attention Maps

One of ViT's practical advantages: attention weights can be inspected directly. We can visualise which patches each token attended to, which gives a useful (though not definitive) view of what the model "sees".

import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from transformers import ViTModel, ViTImageProcessor

# Load model with output_attentions=True
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
model     = ViTModel.from_pretrained("google/vit-base-patch16-224",
                                      output_attentions=True)
model.eval()

# Preprocess a test image
img    = Image.open("test_image.jpg").convert("RGB")
inputs = processor(images=img, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# attentions: list of (batch, heads, N+1, N+1) per layer
# Take last layer, average across heads, focus on [CLS] → patch attention
last_layer_attn = outputs.attentions[-1]          # (1, 12, 197, 197)
avg_head_attn   = last_layer_attn[0].mean(dim=0)  # (197, 197)
cls_attn        = avg_head_attn[0, 1:]            # [CLS] attending to 196 patches
attn_map        = cls_attn.reshape(14, 14).numpy()

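# Upscale attention map to the original image size so the overlay lines up
# (PIL's resize takes (width, height), which is what img.size returns)
attn_resized = np.array(Image.fromarray(attn_map).resize(
    img.size, Image.BILINEAR))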

# Plot original + attention overlay
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(img); axes[0].set_title("Original Image")
axes[1].imshow(img)
axes[1].imshow(attn_resized, alpha=0.6, cmap="jet")
axes[1].set_title("Attention Map — [CLS] token")
for ax in axes: ax.axis("off")
plt.tight_layout()
plt.savefig("attention_map.png", dpi=150)
plt.show()
📷
What You Will See

The attention heatmap will often highlight the foreground object (the cat, the car, the person) and dim the background, even though the model never saw segmentation labels. The [CLS] token learns to aggregate the most discriminative patches purely from classification supervision. Last-layer maps from supervised ViTs can be noisy, so treat this as a rough, emergent form of object localisation rather than a detector.


Section 12

Masked Autoencoders (MAE) — Self-Supervised ViT Pretraining

The Jigsaw Puzzle Master
Imagine a child who has never been told what a dog is. You hand them a jigsaw puzzle with 75% of the pieces removed at random. Their task: reconstruct the missing pieces. To do this well, they must understand what a dog looks like — its fur texture, ear shape, posture — not because they were told, but because they reconstructed it thousands of times from incomplete information.

That is Masked Autoencoder (MAE) pretraining. No labels. Just: mask 75% of image patches at random, reconstruct the raw pixel values of missing patches. The model is forced to learn rich, semantic representations of the world.
🌸 MAE Training Pipeline
Step 1
Mask: Randomly mask 75% of image patches. Keep only 25% visible. This high masking ratio forces semantic rather than interpolative reconstruction.
Step 2
Encode: Pass only the visible patches through the full ViT encoder. This saves ~3–4× compute during pretraining — the encoder never sees masked patches.
Step 3
Decode: A lightweight Transformer decoder receives visible patch embeddings + learnable mask tokens (for the missing positions). Predicts pixel values of masked patches.
Step 4
Loss: Mean Squared Error on masked patch pixels only. Visible patches contribute zero loss — the model is only judged on what it reconstructed.
Step 5
Discard decoder at fine-tune time. The decoder is thrown away. Only the encoder is used, fine-tuned with a linear head on downstream tasks.
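Steps 1 and 4 of the pipeline above, random masking and a loss restricted to masked patches, can be sketched as follows. This is an illustration under assumptions, not the reference implementation: the function names are hypothetical, the "patches" are random numbers, and the encoder and decoder are left out entirely.

```python
import numpy as np

def random_masking(tokens: np.ndarray, mask_ratio: float, rng: np.random.Generator):
    """Keep a random subset of patch tokens.

    tokens: (N, D) patches for one image.
    Returns (visible tokens, kept indices, boolean mask where True = masked).
    """
    n = tokens.shape[0]
    num_keep = int(n * (1 - mask_ratio))
    order = rng.permutation(n)             # random shuffle of patch indices
    keep = np.sort(order[:num_keep])       # indices of the visible patches
    mask = np.ones(n, dtype=bool)
    mask[keep] = False
    return tokens[keep], keep, mask

def masked_mse(pred: np.ndarray, target: np.ndarray, mask: np.ndarray) -> float:
    """Mean squared error computed over masked patches only."""
    per_patch = ((pred - target) ** 2).mean(axis=-1)    # one value per patch
    return float(per_patch[mask].mean())

rng = np.random.default_rng(0)
patches = rng.normal(size=(196, 768))                    # stand-in for patch pixels
visible, keep, mask = random_masking(patches, mask_ratio=0.75, rng=rng)
print(visible.shape, int(mask.sum()))
```

Only the 49 visible patches would be fed to the encoder; the 147 masked positions are filled with mask tokens at the decoder, and only those positions contribute to `masked_mse`.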
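Metric | ViT-Large, supervised | ViT-Large, MAE pretrain
ImageNet Top-1 (fine-tuned) | 82.5% | 85.9%
Labels used for pretraining | 1.28M (ImageNet) | Zero
Pretraining data | ImageNet-1K | ImageNet-1K (same!)
Pretraining compute | Encoder sees every patch | Encoder sees only 25% of patches (roughly 3× or more faster per epoch)
Transfer to ADE20K seg. | 49.9 mIoU | 53.6 mIoU
Accuracy figures as reported in the MAE paper (He et al., 2021) for ImageNet-1K-only training.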

Section 13

ViT for Dense Prediction — Object Detection & Segmentation

Original ViT produces a single [CLS] vector — perfect for classification, useless for detection or segmentation, which require per-pixel or per-region features. Several approaches solve this.
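A common first step shared by these approaches is to ignore [CLS] and fold the patch tokens back into a 2D grid, which gives a feature map that a detection or segmentation head can consume. The sketch below assumes a ViT-Base/16 output of 197 tokens; the helper name and the random encoder output are illustrative only.

```python
import numpy as np

def tokens_to_feature_map(tokens: np.ndarray, grid: int) -> np.ndarray:
    """Drop the [CLS] token and fold (B, 1 + G*G, D) tokens into (B, D, G, G)."""
    b, n, d = tokens.shape
    assert n == 1 + grid * grid, "expected one [CLS] token plus a square grid"
    patch_tokens = tokens[:, 1:, :]                       # (B, G*G, D)
    return patch_tokens.reshape(b, grid, grid, d).transpose(0, 3, 1, 2)

encoder_out = np.random.default_rng(0).normal(size=(2, 197, 768))  # stand-in output
fmap = tokens_to_feature_map(encoder_out, grid=14)
print(fmap.shape)
```

The result is a single-scale 14×14 map at stride 16, which is why plain-ViT detectors add a feature pyramid on top and why hierarchical designs such as Swin produce several scales natively.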

🔎
DETR — Detection Transformer
CNN backbone + Transformer encoder-decoder. Object queries attend to image features; each query directly predicts one box + class. No NMS needed. End-to-end differentiable detection.
torch.hub → facebook/detr-resnet-50
📈
Swin Transformer + FPN
Hierarchical Swin outputs C3, C4, C5 feature maps (like ResNet). Plug into any FPN-based detector (Mask R-CNN, Cascade RCNN). State-of-the-art COCO detection.
mmdetection → Swin-L + HTC++: 57.7 box AP (COCO test-dev)
👉
ViTDet — Plain ViT for Detection
Uses a plain non-hierarchical ViT backbone with a simple feature pyramid built via convolutions. Achieves 61.3 APbox on COCO with ViT-Huge + MAE pretraining.
Meta AI 2022 — segment-anything model backbone
🎲
Segment Anything Model (SAM)
ViT-Huge encoder + lightweight mask decoder. Prompted with points, boxes, or text. Trained on 1 billion masks. Zero-shot segmentation of any object.
segment-anything → sam_vit_h_4b8939.pth
🎨
Segmenter — ViT Segmentation
Pure Transformer semantic segmentation. ViT encoder produces patch tokens; a Transformer decoder with class embeddings generates per-patch class predictions.
ViT-L + Segmenter: 53.6 mIoU on ADE20K
🔬
DINOv2 + Linear Probe
DINOv2 features are so rich that a simple linear layer on top achieves competitive segmentation and depth estimation — no fine-tuning of the backbone needed.
dinov2_vitg14 → 86.5% linear probing on ImageNet

Section 14

ViT vs CNN — When to Use Which

Scenario | Best Choice | Why
Small dataset (<10K images), training from scratch | CNN (EfficientNet, ResNet) | Strong inductive bias compensates for limited data
Large dataset (>1M images), training from scratch | ViT | Scales better; surpasses CNNs at large scale
Fine-tuning pretrained model: classification | ViT (pretrained) | ImageNet-21k or MAE ViT transfers extremely well
Object detection, instance segmentation | Swin Transformer | Hierarchical features match detector needs; linear complexity
Zero-shot / universal segmentation | SAM (ViT-Huge) | Trained on 1B masks, prompts with any input modality
Mobile / edge deployment | CNN (MobileNet, EfficientNet-Lite) | ViT attention is heavy for mobile; CNNs have mature quantisation
Medical imaging (small, specialised dataset) | Hybrid or pretrained ViT with heavy augmentation | Limited data hurts pure ViT; pretrained features help enormously
Self-supervised feature learning | ViT + MAE or DINO | Masking/self-distillation works far better on Transformers than CNNs

Section 15

Training Tips & Tricks for ViTs

ViTs require different training recipes than CNNs. These are the non-negotiable adjustments.

🧠 Vision Transformer — Training Golden Rules
1
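Use AdamW rather than SGD. ViTs are trained almost universally with adaptive optimisers: Adam's per-parameter learning rates cope with the very different gradient scales of attention, MLP, and embedding parameters. A peak lr around 1e-3 with warm-up and cosine decay is a common starting point for from-scratch training at large batch sizes; fine-tuning usually needs a much smaller rate. Plain SGD with momentum tends to train ViTs less stably and to lower accuracy.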
2
Always use warm-up (5–10% of training steps). ViTs are sensitive to early large gradient updates. Start with lr=1e-6, ramp linearly to peak LR, then cosine anneal. Skipping warm-up causes instability in the first few hundred steps.
3
Strong data augmentation is essential. Use RandAugment, Mixup (α=0.8), CutMix (α=1.0), and Random Erasing. ViTs have weak inductive bias and rely on augmentation to regularise without memorising the training set.
4
Weight decay on all parameters except LayerNorm and bias. Use weight_decay=0.05. Explicitly exclude LayerNorm.weight, LayerNorm.bias, and all bias terms from the weight decay group. Decaying norms destabilises training.
5
Use Label Smoothing (ε=0.1). ViTs trained with hard one-hot labels overfit early. Label smoothing prevents overconfident predictions and significantly improves calibration on held-out data.
6
For fine-tuning: use layer-wise learning rate decay. Give the classification head and the last block the full LR, then multiply it by a factor (e.g. 0.75) for each block going back towards the input. Early layers hold generic, transferable features and should change little; later layers (closest to the head) are the most task-specific and need the largest updates.
7
Interpolate positional embeddings when changing resolution. If you fine-tune a ViT pretrained at 224×224 on 384×384 images, bicubic-interpolate the positional embeddings from 14×14 to 24×24 grid. Use interpolate_pos_encoding=True in Hugging Face to handle this automatically.
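Layer-wise learning rate decay is easy to get backwards, so a small sketch helps. The convention assumed here is the usual one for ViT fine-tuning: the last block trains at the base rate and every step towards the input multiplies the rate by `decay`. The function name and the base rate of 1e-3 are illustrative assumptions.

```python
def layerwise_lr(base_lr: float, num_layers: int, decay: float = 0.75) -> list[float]:
    """Learning rate per depth index: 0 = patch embedding ... num_layers = last block.

    The classification head (not listed) would also use base_lr.
    """
    return [base_lr * decay ** (num_layers - i) for i in range(num_layers + 1)]

lrs = layerwise_lr(base_lr=1e-3, num_layers=12)
for depth in (0, 6, 12):
    print(f"depth {depth:2d}: lr = {lrs[depth]:.2e}")
```

In practice each value becomes the `lr` of its own optimiser parameter group, one group per block, with the embeddings at depth 0 receiving the smallest rate.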

Section 16

Complete Training Recipe — Code

import torch
import torch.nn as nn
from torchvision import transforms, datasets
from timm import create_model
from timm.data.mixup import Mixup
from timm.loss import SoftTargetCrossEntropy
from timm.scheduler import CosineLRScheduler
from torch.optim import AdamW

# ── 1. Model — ViT-Base/16 from timm ──────────────────────
model = create_model(
    "vit_base_patch16_224",
    pretrained=True,
    num_classes=10
)

# ── 2. Augmentation pipeline (DeiT-style) ─────────────────
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandAugment(num_ops=2, magnitude=9),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                           [0.229, 0.224, 0.225]),
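    transforms.RandomErasing(p=0.25),
])

# CIFAR-10 is used here as an example dataset (it matches num_classes=10 above).
# drop_last=True keeps every batch even-sized, which timm's Mixup requires.
train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10("./data", train=True, download=True, transform=train_transform),
    batch_size=128, shuffle=True, drop_last=True,
)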

# ── 3. Mixup + CutMix (timm's Mixup handles both) ─────────
mixup_fn = Mixup(
    mixup_alpha=0.8, cutmix_alpha=1.0,
    prob=1.0, switch_prob=0.5,
    mode="batch", label_smoothing=0.1, num_classes=10
)
criterion = SoftTargetCrossEntropy()   # works with soft Mixup labels

# ── 4. Optimiser — split weight decay groups ───────────────
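# timm names its norm layers "norm1", "norm2", "norm", so matching on
# "LayerNorm.weight" would miss them. Norm weights and all biases are 1-D,
# so split on dimensionality; also exempt the [CLS] token and positional embeddings.
skip = {"cls_token", "pos_embed"}
decay, no_decay = [], []
for name, p in model.named_parameters():
    if not p.requires_grad:
        continue
    (no_decay if p.ndim <= 1 or name in skip else decay).append(p)
optim_groups = [
    {"params": decay,    "weight_decay": 0.05},
    {"params": no_decay, "weight_decay": 0.0},
]
optimizer = AdamW(optim_groups, lr=1e-3)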

# ── 5. Cosine LR schedule with warm-up ────────────────────
scheduler = CosineLRScheduler(
    optimizer, t_initial=100,      # 100 total epochs
    lr_min=1e-5,
    warmup_t=5,                    # 5-epoch warm-up
    warmup_lr_init=1e-6,
)

# ── 6. Training loop ───────────────────────────────────────
device = torch.device("cuda")
model.to(device)

for epoch in range(100):
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        images, soft_labels = mixup_fn(images, labels)
        logits = model(images)
        loss   = criterion(logits, soft_labels)
        loss.backward()
        # Gradient clipping — important for ViT stability
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()
    scheduler.step(epoch + 1)
💡
Gradient Clipping — Non-Optional for ViTs

Do not skip clip_grad_norm_(model.parameters(), max_norm=1.0) when training ViTs. Self-attention layers can produce very large gradient norms, especially early in training and during warm-up. Clipping the global gradient norm caps the size of any single update, which stabilises training at negligible cost. max_norm=1.0 is a common setting for ViT-Base and ViT-Large.
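To make concrete what the clipping call does, here is a minimal pure-Python sketch of global-norm clipping. It mirrors the idea behind clip_grad_norm_ (one shared scale factor applied to every gradient) rather than PyTorch's actual implementation; the function name clip_global_norm and the toy gradient values are illustrative assumptions.

```python
import math

def clip_global_norm(grads, max_norm=1.0, eps=1e-6):
    """Scale all gradients by one shared factor so their joint L2 norm
    is at most max_norm.

    grads: list of flat lists of floats, one list per parameter tensor.
    Returns (clipped_grads, total_norm_before_clipping).
    """
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    # Never scale up: gradients already inside the limit are left untouched.
    scale = min(1.0, max_norm / (total_norm + eps))
    clipped = [[g * scale for g in grad] for grad in grads]
    return clipped, total_norm

# Toy gradients for two "parameters"; joint norm = sqrt(9 + 16 + 144) = 13.
grads = [[3.0, 4.0], [12.0]]
clipped, norm_before = clip_global_norm(grads, max_norm=1.0)
norm_after = math.sqrt(sum(g * g for grad in clipped for g in grad))
print(norm_before, norm_after)
```

Because every gradient is multiplied by the same factor, the update direction is preserved and only its length changes. Logging the pre-clipping norm (the value clip_grad_norm_ returns) is a cheap way to spot instability early.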


Section 17

ViT Performance on Major Benchmarks

| Task | Dataset | Model | Score | Method |
| --- | --- | --- | --- | --- |
| Image Classification | ImageNet-1K | ViT-Huge/14 + MAE | 87.8% Top-1 | MAE pretrain → fine-tune (448 px, ImageNet-1K data only) |
| Image Classification | ImageNet-1K | DINOv2 ViT-g/14 | 86.5% linear probe | Linear layer only, no fine-tune |
| Object Detection | COCO | Swin-L + HTC++ | 57.7 box AP | Multi-scale training |
| Semantic Segmentation | ADE20K | Swin-L + UperNet | 53.5 mIoU | 160K iterations |
| Semantic Segmentation | ADE20K | ViT-Large + MAE | 53.6 mIoU | MAE pretrain + UperNet |
| Depth Estimation | NYU Depth v2 | DINOv2 + DPT head | 0.206 absRel | Zero-shot transfer |
| Medical (Pathology) | CAMELYON16 | ViT-Base + DINOv2 | 0.989 AUC | Slide-level linear probe |

Section 18

Compute & Memory Complexity — The Numbers

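| Component | Time Complexity | Memory | Bottleneck? |
| --- | --- | --- | --- |
| Patch Embedding (Conv2d) | O(N·D) | Low | No |
| Self-Attention (global) | O(N²·D) | O(N²) attention matrix | Yes: quadratic in the number of patches |
| Self-Attention (Swin window) | O(N·M²·D) | O(N·M²) in total | No: linear in image size |
| MLP Block | O(N·D²) | Low | Moderate: grows with large D |
| Flash Attention (optimised) | O(N²·D) | O(N), IO-aware | Compute is still quadratic, but typically 2–4× faster in practice |

N = number of patch tokens, D = embedding dimension, M = window side length in tokens (each window holds M² tokens).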
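A quick calculation shows the difference between the quadratic and linear rows. The sketch below counts attention-score entries per head for global attention (N·N) and for windowed attention (N·M², taking M as the window side length in tokens). The 16-pixel patch, the 7×7 window, and the two image sizes are assumptions chosen for illustration.

```python
def num_tokens(image_size, patch_size=16):
    """Number of patch tokens N for a square image."""
    side = image_size // patch_size
    return side * side

def attention_entries(image_size, patch_size=16, window=None):
    """Attention-score entries per head.

    Global attention: every token attends to every token, so N * N.
    Windowed attention: each token attends only to the M * M tokens
    in its own window, so N * M * M.
    """
    n = num_tokens(image_size, patch_size)
    if window is None:
        return n * n
    return n * window * window

for size in (224, 448):
    print(size,
          num_tokens(size),
          attention_entries(size),             # global
          attention_entries(size, window=7))   # 7x7 windows
```

Doubling the image side quadruples the token count (196 to 784). The global count grows 16× (38,416 to 614,656), while the windowed count grows only 4× (9,604 to 38,416).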
Flash Attention — The Must-Use Optimisation

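Flash Attention (Dao et al., 2022) rewrites the attention kernel to be IO-aware. It fuses the QK^T product, the softmax, and the multiplication by V into one GPU kernel, so the full N×N attention matrix is never materialised in HBM. The authors report roughly 2–4× faster attention and memory savings that grow with sequence length (about 10× at 2K tokens and 20× at 4K). It computes exact attention, not an approximation, so outputs match standard attention up to floating-point rounding. In PyTorch 2.0+, torch.nn.functional.scaled_dot_product_attention dispatches to a fused Flash Attention kernel when the device, dtype, and mask are supported.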
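The trick that lets a fused kernel avoid the N×N matrix is an online (streaming) softmax over blocks of keys. The NumPy sketch below illustrates that idea on the CPU and checks it against ordinary attention. It is not the Flash Attention kernel itself, and the function names, block size, and random inputs are assumptions made for the example.

```python
import numpy as np

def standard_attention(q, k, v):
    """Reference attention that materialises the full N x N score matrix."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

def tiled_attention(q, k, v, block=32):
    """Same result, but keys/values are streamed in blocks, so only an
    N x block slice of scores exists at any time."""
    n, d = q.shape
    out = np.zeros((n, v.shape[-1]))
    row_max = np.full((n, 1), -np.inf)   # running max of scores per query
    row_sum = np.zeros((n, 1))           # running softmax denominator
    for start in range(0, k.shape[0], block):
        kb, vb = k[start:start + block], v[start:start + block]
        s = q @ kb.T / np.sqrt(d)
        new_max = np.maximum(row_max, s.max(axis=-1, keepdims=True))
        correction = np.exp(row_max - new_max)   # rescale earlier partial sums
        p = np.exp(s - new_max)
        row_sum = row_sum * correction + p.sum(axis=-1, keepdims=True)
        out = out * correction + p @ vb
        row_max = new_max
    return out / row_sum

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((196, 64)) for _ in range(3))
ref = standard_attention(q, k, v)
tiled = tiled_attention(q, k, v, block=32)
print(np.allclose(ref, tiled))
```

The running maximum and running sum are what make the block-wise result agree with the full softmax. A real kernel additionally tiles the queries and keeps each block in on-chip SRAM.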


Section 19

Real-World Applications

🏠
Autonomous Driving
Tesla, Waymo, Mobileye
ViTs fuse multi-camera views to understand 3D scene layout. Long-range attention captures relationships between objects far apart in the scene — a pedestrian near the sidewalk and a car turning far ahead.
💊
Medical Imaging
Pathology, Radiology, OCT
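Whole-slide pathology images are gigapixel-scale, far too large to process as a single input. ViTs applied hierarchically (multiple-instance learning over tile-level ViT features) have reported cancer-detection accuracy approaching that of pathologists on benchmark datasets. Attention maps highlight the tissue regions the model weighted most heavily, which helps clinicians review a prediction.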
🌎
Remote Sensing
Satellite imagery, Land cover
Multispectral satellite images contain long-range geographic dependencies — a flooded river upstream predicts flooding downstream. ViTs capture these at global scales naturally.
🤖
Robotics & Manipulation
Visual policies, Scene understanding
ViT-based visual representations enable robots to generalise across object instances. CLIP and DINOv2 features allow robots to understand "pick up the red mug" without retraining.
📷
Image Generation
DiT, Stable Diffusion 3
Diffusion Transformers (DiT) replace the U-Net in diffusion models with a pure Transformer operating on latent patches. Stable Diffusion 3 uses a multimodal DiT variant (MMDiT) as its backbone and was among the strongest text-to-image models at its release.
📍
Multimodal AI
CLIP, LLaVA, GPT-4V
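ViT encoders are the vision backbone of most openly documented multimodal LLMs. CLIP trains a ViT image encoder to align image and text embeddings. LLaVA feeds CLIP ViT patch features into a language model as visual tokens. Closed models such as GPT-4V have not published their architectures, but are widely assumed to follow a similar design.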

Section 20

The Future — What Comes After ViT?

🔭
The Frontier in 2024–2025

Research is pushing in four directions simultaneously: efficiency (linear attention, state space models like Mamba-Vision), multimodality (unified vision-language-action models), scalability (ViT-22B with 22 billion parameters), and video understanding (TimeSformer, VideoMAE — treating video as a space-time volume of patches).

Mamba-Vision / SSM
State Space Models
State Space Models (Mamba) process sequences in O(N) time without attention matrices. Mamba-Vision adapts this for images — similar accuracy to Swin at lower memory.
🎬
VideoMAE / TimeSformer
Video Transformers
Extend patch tokens to space-time: a 3D patch covers T×H×W pixels across time. VideoMAE masks 90% of space-time patches for self-supervised video pretraining. State-of-the-art action recognition.
🌟
ViT-22B
Extreme Scale
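Google's 22-billion parameter ViT, trained on a version of JFT extended to about 4 billion images. It reaches about 89.5% ImageNet Top-1 with a linear probe on frozen features. It suggests that scaling laws hold for vision: more compute and more data keep yielding better features, at least up to this scale.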
🌞
The Quadratic Wall — Still Unsolved

Standard self-attention scales as O(N²) in sequence length. For video, high-resolution images, or 3D medical scans, where token counts run from thousands to hundreds of thousands, this remains a hard practical limit. Flash Attention (which removes the quadratic memory cost but not the quadratic compute), windowed attention (Swin), and SSMs each attack this differently. No consensus winner has emerged yet, and this remains one of the most active research frontiers in vision architectures.
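To put numbers on the wall for video, the sketch below counts space-time tokens for one clip and the memory needed to hold one layer's attention matrices. The clip size (16 frames at 224×224), the 2×16×16 tubelet, 12 heads, and 2-byte (fp16) values are all assumptions chosen for illustration, as is the 90% masking ratio applied at the end.

```python
def spacetime_tokens(frames, height, width, tubelet=(2, 16, 16)):
    """Number of tokens when a clip is cut into t x h x w space-time patches."""
    t, p_h, p_w = tubelet
    return (frames // t) * (height // p_h) * (width // p_w)

def attention_bytes(tokens, heads=12, bytes_per_value=2):
    """Bytes for one layer's full attention matrices (all heads, one sample)."""
    return tokens * tokens * heads * bytes_per_value

clip_tokens = spacetime_tokens(16, 224, 224)        # 8 * 14 * 14 tokens
visible = clip_tokens - int(0.9 * clip_tokens)      # tokens kept after 90% masking

print(clip_tokens, attention_bytes(clip_tokens))    # full clip
print(visible, attention_bytes(visible))            # masked-encoder input
```

Under these assumptions a single short clip already yields 1,568 tokens and about 59 MB of attention scores per layer per sample. Keeping only the 157 visible tokens cuts that to under 0.6 MB, which is why high masking ratios make self-supervised video pretraining affordable.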


Section 21

Quick Reference — Key ViT Parameters

| Parameter | ViT-Base | ViT-Large | What it controls |
| --- | --- | --- | --- |
| patch_size | 16 | 16 | Resolution of each patch token. Smaller = more tokens, finer detail, more compute. |
| embed_dim (D) | 768 | 1024 | Width of the model: the size of every token vector. |
| depth (L) | 12 | 24 | Number of Transformer encoder blocks stacked. |
| num_heads | 12 | 16 | Parallel attention heads. Each captures different relationships. |
| mlp_ratio | 4.0 | 4.0 | MLP hidden dim = embed_dim × mlp_ratio. Controls non-linear capacity. |
| dropout | 0.0 | 0.1 | Applied to attention weights and MLP. Use 0.0 with strong augmentation. |
| num_classes | 1000 | 1000 | Output dimension. Change this for your task. |
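These hyperparameters are enough to estimate model size. The sketch below counts parameters for a plain ViT classifier with a class token, learned position embeddings, and biased linear layers. That layout, the 224-pixel input, and the function name vit_param_count are assumptions, and other ViT variants will differ slightly.

```python
def vit_param_count(embed_dim, depth, patch_size=16, img_size=224,
                    in_chans=3, mlp_ratio=4.0, num_classes=1000):
    """Parameter count for a plain ViT classifier (assumed layout, see text)."""
    d = embed_dim
    hidden = int(d * mlp_ratio)
    n_patches = (img_size // patch_size) ** 2

    patch_embed = patch_size * patch_size * in_chans * d + d  # Conv2d weight + bias
    cls_and_pos = d + (n_patches + 1) * d       # class token + position embeddings
    attn = (3 * d * d + 3 * d) + (d * d + d)    # qkv projection + output projection
    mlp = (d * hidden + hidden) + (hidden * d + d)  # two linear layers
    norms = 2 * (2 * d)                         # two LayerNorms (weight + bias) per block
    final_norm = 2 * d
    head = d * num_classes + num_classes

    return patch_embed + cls_and_pos + depth * (attn + mlp + norms) + final_norm + head

base = vit_param_count(embed_dim=768, depth=12)
large = vit_param_count(embed_dim=1024, depth=24)
print(f"ViT-Base:  {base / 1e6:.1f}M parameters")
print(f"ViT-Large: {large / 1e6:.1f}M parameters")
```

Note that num_heads does not appear: splitting the embedding into more heads changes how attention is computed, not how many weights it has. Almost all parameters sit in the encoder blocks, at roughly 12·D² per block.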
🏆
The Practitioner's One-Line Summary

Use a pretrained ViT-Base/16 from Hugging Face or timm as your default starting point for any image classification or feature extraction task. Fine-tune for 10–30 epochs with AdamW + cosine schedule + RandAugment. For detection or segmentation, use Swin-Transformer. For features without fine-tuning, use DINOv2. Only go larger (ViT-Large, ViT-Huge) when Base saturates.