The Story That Makes Grad-CAM Click
You hand them a highlighter pen and say: "Circle the part of the image that drove your decision." If they highlight the right lung region, you trust the model. If they highlight the corner watermark, you have a shortcut learner — accurate on training data, catastrophic in the real world.
Grad-CAM is that highlighter pen for neural networks. It asks the network: "Which spatial regions of this image were most responsible for your classification?" — and paints a heatmap answer directly onto the image.
Gradient-weighted Class Activation Mapping (Grad-CAM) is an explainability technique introduced by Selvaraju et al. (2017) that produces a coarse localisation heatmap highlighting the important regions in an image used by a CNN to make a classification decision. It works with any CNN architecture — no retraining, no architectural changes, no special layers required.
Deep learning models are often called black boxes. Grad-CAM cracks the box open — not by explaining every weight, but by answering the one question that matters most in practice: Where was the model looking? This turns an opaque prediction into a verifiable, auditable decision.
Before Grad-CAM — The Problem with Older Methods
Explainability for CNNs didn't begin with Grad-CAM. Several earlier approaches each had significant weaknesses that Grad-CAM was designed to solve.
| Method | Core Idea | Key Weakness | Requires Retraining? |
|---|---|---|---|
| CAM (Zhou et al., 2016) | Global Average Pooling + linear layer weights | Only works on specific architectures (GAP before softmax) | Yes — architecture must be modified |
| Vanilla Backprop / Saliency | Gradient of output w.r.t. input pixels | Noisy, hard to interpret, not class-discriminative | No |
| Guided Backprop | Backprop with ReLU masking | Not class-discriminative — same map for any class | No |
| Occlusion Sensitivity | Slide a patch, measure output change | Computationally expensive: one forward pass per patch position, O(N²) for an N × N grid of positions | No |
| Grad-CAM (Selvaraju, 2017) | Gradient-weighted feature map combination | Coarse spatial resolution (feature map size) | No — works with any CNN |
Original CAM required a Global Average Pooling layer immediately before the classifier — limiting it to models built with that specific design. Grad-CAM generalises this by using gradients instead of architectural constraints. It can explain any CNN — VGG, ResNet, DenseNet, EfficientNet — with a single backward pass. That universality is its superpower.
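One way to see the "generalises" claim concretely: for a head made of Global Average Pooling followed by a linear layer, the Grad-CAM channel weights come out proportional to the CAM weights. The following is a small NumPy sketch under made-up shapes and random values (`A`, `w`, and `score` are illustrative names, not from any library); it estimates the gradients numerically instead of using autograd.

```python
import numpy as np

rng = np.random.default_rng(0)
K, H, W = 3, 4, 4
A = rng.normal(size=(K, H, W))   # made-up feature maps A^k
w = rng.normal(size=K)           # made-up linear weights for one class
Z = H * W                        # number of spatial positions

def score(feature_maps):
    """Class score for a GAP + linear head: sum_k w_k * mean_ij(A^k_ij)."""
    return float(w @ feature_maps.mean(axis=(1, 2)))

# Numerical gradient of the score w.r.t. every feature map entry
# (central differences; exact up to rounding because the head is linear).
eps = 1e-6
grads = np.zeros_like(A)
for idx in np.ndindex(*A.shape):
    bump = np.zeros_like(A)
    bump[idx] = eps
    grads[idx] = (score(A + bump) - score(A - bump)) / (2 * eps)

alpha = grads.mean(axis=(1, 2))                 # Grad-CAM channel weights
grad_cam_raw = np.tensordot(alpha, A, axes=1)   # weighted sum before ReLU
cam_raw = np.tensordot(w, A, axes=1)            # CAM map for the same class

print(np.allclose(alpha, w / Z, atol=1e-6))
print(np.allclose(grad_cam_raw * Z, cam_raw, atol=1e-6))
```

In this setting the two maps differ only by the constant factor 1/Z, which disappears once the heatmap is normalised. Grad-CAM simply does not need the GAP head to obtain its weights.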
The Intuition — What Are We Actually Computing?
Think of each feature map in the last convolutional layer as a sketch pad on which the network draws where it found one particular pattern. When the network sees an image of a golden retriever and predicts "dog", some sketch pads light up intensely (the ones detecting fur, ears, snout shapes) while others stay blank. The question Grad-CAM answers is: which sketch pads mattered most for this specific prediction, and where on each pad was the relevant drawing?
It answers this by computing: "If I increased this sketch pad's activation by one unit, how much would the 'dog' score increase?" The more the score increases, the more important that sketch pad is. Then it combines all the important sketch pads into one final heatmap.
Mathematically, the last convolutional layer produces a set of feature maps
$A^k$ of spatial dimensions $H \times W$, where $k$
indexes the channels. The network then flattens and classifies these maps to
produce a class score $y^c$ for class $c$.
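For reference, these are the two equations that define Grad-CAM in Selvaraju et al. (2017), with $A^k$ the feature maps and $y^c$ the class score as defined here. The symbols $Z = H \times W$ (number of spatial positions) and $i, j$ (row and column indices) are added for this block.

```latex
\alpha_k^c = \frac{1}{Z} \sum_{i=1}^{H} \sum_{j=1}^{W}
             \frac{\partial y^c}{\partial A^k_{ij}},
\qquad
L^c_{\text{Grad-CAM}} = \mathrm{ReLU}\!\left( \sum_k \alpha_k^c \, A^k \right)
```

The first expression global-average-pools the gradients into one importance weight per channel; the second forms the weighted sum of feature maps and keeps only its positive part.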
The ReLU is critical. Without it, negative activations (regions that suppress the class score) would also appear in the heatmap, creating confusing visualisations that highlight anti-evidence regions. By zeroing negatives, Grad-CAM shows only the regions that positively contributed to the specific predicted class — making the map directly interpretable as "evidence in favour of class c".
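A tiny numeric sketch may make the effect of the ReLU easier to see. All numbers below are made up: two 2×2 feature maps, one whose gradient supports the class and one whose gradient suppresses it.

```python
import numpy as np

# Two channels, each a 2x2 feature map (illustrative values only).
A = np.array([
    [[1.0, 0.0],
     [0.0, 2.0]],   # channel 0
    [[0.0, 3.0],
     [1.0, 0.0]],   # channel 1
])
# Hypothetical gradients of the class score w.r.t. each activation.
grads = np.array([
    [[0.5, 0.5],
     [0.5, 0.5]],     # channel 0 pushes the score up
    [[-1.0, -1.0],
     [-1.0, -1.0]],   # channel 1 pushes the score down
])

alpha = grads.mean(axis=(1, 2))        # one weight per channel: [0.5, -1.0]
raw = np.tensordot(alpha, A, axes=1)   # weighted sum over channels, shape (2, 2)
cam = np.maximum(raw, 0.0)             # ReLU keeps only positive evidence

print(raw)   # [[ 0.5 -3. ] [-1.   1. ]]
print(cam)   # [[0.5 0. ] [0.  1. ]]
```

Before the ReLU, the strongest value by magnitude (-3.0) sits where channel 1 argues against the class. After the ReLU, only the two locations that support the class remain.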
Step-by-Step: The Grad-CAM Pipeline
SVG Diagram — The Grad-CAM Architecture
Legend: blue = forward data flow | red dashed = backward gradient flow | amber dashed = weight computation
Implementation from Scratch in PyTorch
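The cleanest way to implement Grad-CAM is with PyTorch hooks; no dedicated
Grad-CAM library is needed (OpenCV, Pillow, and Matplotlib are used only for
loading and visualisation). The following complete implementation works with
any torchvision CNN once you pick a convolutional target layer.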
import torch
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt
# ──────────────────────────────────────────────
# GradCAM Class — works with any CNN
# ──────────────────────────────────────────────
class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        self._handles = []
        self._register_hooks()

    def _register_hooks(self):
        # Forward hook: captures feature maps A^k
        def save_activation(module, inp, output):
            self.activations = output.detach()

        # Backward hook: captures gradients ∂y^c/∂A^k
        def save_gradient(module, grad_input, grad_output):
            self.gradients = grad_output[0].detach()

        # register_full_backward_hook replaces the deprecated register_backward_hook
        self._handles = [
            self.target_layer.register_forward_hook(save_activation),
            self.target_layer.register_full_backward_hook(save_gradient),
        ]

    def remove_hooks(self):
        # Call when finished so the hooks do not linger on the model
        for handle in self._handles:
            handle.remove()
        self._handles = []

    def generate(self, input_tensor, class_idx=None):
        self.model.eval()
        # Step 1: Forward pass
        logits = self.model(input_tensor)
        # Step 2: Select class score
        if class_idx is None:
            class_idx = logits.argmax(dim=1).item()
        score = logits[0, class_idx]
        # Step 3: Backward pass (zero grads first!)
        self.model.zero_grad()
        score.backward()
        # Step 4: Global Average Pool the gradients → α^c_k
        alpha = self.gradients.mean(dim=[2, 3], keepdim=True)  # (1, C, 1, 1)
        # Step 5: Weighted sum of feature maps
        cam = (self.activations * alpha).sum(dim=1, keepdim=True)  # (1, 1, H, W)
        # Step 6: ReLU, keep only positive influence
        cam = F.relu(cam)
        # Step 7: Normalise to [0, 1]
        cam -= cam.min()
        cam /= (cam.max() + 1e-8)
        # Step 8: Upsample to input size
        cam = F.interpolate(
            cam, size=(input_tensor.shape[2], input_tensor.shape[3]),
            mode='bilinear', align_corners=False
        )
        return cam.squeeze().cpu().numpy(), class_idx
# ──────────────────────────────────────────────
# Overlay helper — blends heatmap onto image
# ──────────────────────────────────────────────
def overlay_heatmap(img_rgb, cam, alpha=0.5, colormap=cv2.COLORMAP_JET):
    # img_rgb: float RGB image in [0, 1]; cam: float32 numpy array in [0, 1]
    heatmap = cv2.applyColorMap(
        (255 * cam).astype(np.uint8), colormap
    )
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    # Alpha blend
    img_uint8 = (img_rgb * 255).astype(np.uint8)
    overlay = (alpha * heatmap + (1 - alpha) * img_uint8).astype(np.uint8)
    return overlay
# ──────────────────────────────────────────────
# Full pipeline — load model, run, visualise
# ──────────────────────────────────────────────
# 1. Load pretrained ResNet-50
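# (weights= replaces the deprecated pretrained=True in torchvision >= 0.13)
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)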
model.eval()
# 2. Target = last conv layer (layer4[-1].conv3 in ResNet-50)
target_layer = model.layer4[-1].conv3
gradcam = GradCAM(model, target_layer)
# 3. ImageNet preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])
# 4. Load image and run
image = Image.open('dog.jpg').convert('RGB')
img_array = np.array(image.resize((224, 224))) / 255.0
tensor = transform(image).unsqueeze(0) # (1, 3, 224, 224)
cam, predicted_class = gradcam.generate(tensor)
result = overlay_heatmap(img_array, cam, alpha=0.5)
# 5. Visualise
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
axes[0].imshow(img_array); axes[0].set_title('Original')
axes[1].imshow(cam, cmap='jet'); axes[1].set_title('Grad-CAM Heatmap')
axes[2].imshow(result); axes[2].set_title(f'Overlay (class {predicted_class})')
for ax in axes: ax.axis('off')
plt.tight_layout(); plt.show()
Pass any class_idx to generate() to explain
any class — not just the predicted one. For example, pass the index
for "cat" when the model predicted "dog" to visualise what cat-like regions
exist in the image. This is a powerful debugging tool for understanding
model confusion.
Grad-CAM vs Guided Grad-CAM vs Grad-CAM++
The original Grad-CAM paper also introduced Guided Grad-CAM by combining Grad-CAM with Guided Backpropagation. The research community has since produced several variants:
| Property | Grad-CAM | Guided Grad-CAM | Grad-CAM++ | Score-CAM |
|---|---|---|---|---|
| Resolution | Coarse (7×7) | Fine (224×224) | Coarse (7×7) | Coarse (7×7) |
| Class-discriminative | Yes | Yes | Yes | Yes |
| Multiple objects | Partial | Partial | Better | Good |
| Gradient-free | No | No | No | Yes |
| Speed (relative) | Fast (1 backward) | Medium (2 backward) | Medium | Slow (N×forward) |
| Best use case | Quick debugging | Fine-grained analysis | Multi-instance images | Gradient-unstable models |
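The fusion step behind Guided Grad-CAM is an element-wise product of the Guided Backpropagation map with the upsampled Grad-CAM heatmap. The sketch below uses random stand-in data and a hypothetical helper name, `guided_grad_cam`. It also assumes nearest-neighbour upsampling via `np.kron` in place of bilinear interpolation, purely to stay dependency-free.

```python
import numpy as np

def guided_grad_cam(guided_backprop, cam):
    """Fuse an (H, W, 3) Guided Backprop map with a coarse (h, w) Grad-CAM map.

    Assumes H and W are integer multiples of h and w. Nearest-neighbour
    upsampling stands in for bilinear interpolation in this sketch.
    """
    H, W, _ = guided_backprop.shape
    h, w = cam.shape
    cam_up = np.kron(cam, np.ones((H // h, W // w)))   # (H, W)
    return guided_backprop * cam_up[..., None]         # broadcast over RGB

rng = np.random.default_rng(0)
gb = rng.normal(size=(224, 224, 3))   # stand-in for a Guided Backprop map
cam = np.zeros((7, 7))
cam[3, 3] = 1.0                       # Grad-CAM fires only in the centre cell
fused = guided_grad_cam(gb, cam)

print(fused.shape)   # (224, 224, 3)
```

The fine pixel-level detail comes from Guided Backprop, while the coarse Grad-CAM map acts as a class-specific gate: everything outside the centre cell is zeroed.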
The Weighted Sum — SVG Diagram
Each channel's feature map is weighted by its importance score α, summed, passed through ReLU, then upsampled to produce the final localisation heatmap.
Real-World Applications
Using the pytorch-grad-cam Library
For production use, the pytorch-grad-cam library (Jacob Gildenblat)
provides battle-tested implementations of Grad-CAM and all major variants with
a single, consistent API.
# Install
# pip install grad-cam
from pytorch_grad_cam import GradCAM, GradCAMPlusPlus, ScoreCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
# 1. Model
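# (weights= replaces the deprecated pretrained=True in torchvision >= 0.13)
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)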
model.eval()
# 2. Target layer — a list (supports multiple layers)
target_layers = [model.layer4[-1]]
# 3. Preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])
img = Image.open('cat.jpg').convert('RGB')
img_np = np.array(img.resize((224, 224))) / 255.0
tensor = transform(img).unsqueeze(0)
# 4. Run Grad-CAM (the context manager handles hook cleanup)
with GradCAM(model=model, target_layers=target_layers) as cam:
    # Explain class 281 (tabby cat in ImageNet)
    targets = [ClassifierOutputTarget(281)]
    grayscale_cam = cam(input_tensor=tensor, targets=targets)
    grayscale_cam = grayscale_cam[0]  # (H, W)
# 5. Overlay and display (use_rgb=True because img_np is RGB; the default assumes BGR)
visualisation = show_cam_on_image(img_np.astype(np.float32), grayscale_cam, use_rgb=True)
# 6. Switch to Grad-CAM++ (identical API)
with GradCAMPlusPlus(model=model, target_layers=target_layers) as cam:
    pp_cam = cam(input_tensor=tensor, targets=targets)[0]
Applying Grad-CAM to Non-ResNet Models
Grad-CAM is architecture-agnostic. The only thing that changes between models is which layer you target. Here is a reference guide:
| Model | Target Layer (PyTorch) | Feature Map Size | Notes |
|---|---|---|---|
| ResNet-50 / 101 / 152 | model.layer4[-1] | 7 × 7 | Best default choice — deepest semantic features |
| VGG-16 / VGG-19 | model.features[-1] | 7 × 7 | features[-1] is the final MaxPool (7×7 output); for a 14×14 map, target the last conv layer instead (model.features[28] in VGG-16) |
| EfficientNet-B0 | model.features[-1] | 7 × 7 | Works out of the box with pytorch-grad-cam |
| MobileNet-V3 | model.features[-1] | 7 × 7 | Lightweight models still produce informative maps |
| DenseNet-121 | model.features.denseblock4 | 7 × 7 | Dense skip connections make gradients rich |
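| Vision Transformer (ViT) | Last transformer block, or attention rollout | 14 × 14 patches | Needs special handling: there are no conv layers, so token outputs must be reshaped into a 2D grid (the reshape_transform argument in pytorch-grad-cam) before applying GradCAM to the last attention block. |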
| Custom CNN | Last conv layer before classifier head | Varies | Target deeper = more semantic; shallower = more spatial |
Deeper layers → more class-semantic but coarser spatial heatmaps. Shallower layers → finer spatial resolution but less class-discriminative. For most explainability tasks, target the last convolutional block before the classifier head — it provides the best trade-off. For spatial precision (e.g. lesion localisation), try middle-depth layers and compare.
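For a custom CNN, a quick way to find a candidate target layer is to walk the module tree and keep the last `nn.Conv2d`. This is a minimal sketch: `last_conv_layer` is a hypothetical helper and `toy` is a stand-in network, not a torchvision model.

```python
import torch.nn as nn

def last_conv_layer(model):
    """Return (name, module) of the last nn.Conv2d in registration order, or None."""
    last = None
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            last = (name, module)
    return last

# Tiny stand-in CNN used only to exercise the helper.
toy = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(8, 16, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(),
    nn.Linear(16, 2),
)
name, layer = last_conv_layer(toy)
print(name, layer.out_channels)   # 3 16
```

Registration order is not always execution order (branches and skip connections can be registered out of sequence), so treat the result as a starting point and confirm it against the model definition.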
Grad-CAM with Transfer Learning — Fine-Tuning Workflow
Grad-CAM is especially powerful during fine-tuning, where you need to verify that the model is learning domain-specific features rather than relying on ImageNet priors.
import torch
import torch.nn as nn
from torchvision import models, transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import numpy as np, matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
# ── 1. Fine-tune ResNet on custom dataset (e.g. chest X-rays) ──
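model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
num_classes = 2 # e.g. Normal vs Pneumonia
model.fc = nn.Linear(model.fc.in_features, num_classes)
# Fall back to CPU when no GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)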
# Standard fine-tuning loop (abbreviated)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
# ... training loop ...
# ── 2. After training: compare pre-train vs post-train Grad-CAM ──
model.eval()
target_layers = [model.layer4[-1]]
transform_val = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])
def inspect_batch(model, batch_tensors, batch_imgs_np, class_idx, n=4):
    """Visualise Grad-CAM for first n images in a batch (n >= 2)."""
    # batch_tensors must live on the same device as the model
    with GradCAM(model=model, target_layers=target_layers) as cam:
        targets = [ClassifierOutputTarget(class_idx)] * n
        cams = cam(input_tensor=batch_tensors[:n], targets=targets)
    fig, axes = plt.subplots(2, n, figsize=(4 * n, 8))
    for i in range(n):
        axes[0, i].imshow(batch_imgs_np[i])
        axes[0, i].set_title(f'Original {i+1}')
        axes[0, i].axis('off')
        # use_rgb=True because the input images are RGB; the default assumes BGR
        overlay = show_cam_on_image(batch_imgs_np[i].astype(np.float32), cams[i], use_rgb=True)
        axes[1, i].imshow(overlay)
        axes[1, i].set_title(f'Grad-CAM, class {class_idx}')
        axes[1, i].axis('off')
    plt.tight_layout()
    plt.show()
Run Grad-CAM on your validation set at three checkpoints: (1) before fine-tuning (random classifier head, ImageNet features), (2) after 1 epoch, and (3) after full convergence. Watching the heatmap shift from background noise → generic textures → correct domain-specific regions is one of the most satisfying validation steps in applied deep learning. If the map never moves to the right region, your fine-tuning strategy or data has a problem.
Limitations and Pitfalls
Grad-CAM shows where the model attends — not necessarily where the correct evidence is. A map that looks reasonable can still represent a spurious correlation. Always validate with quantitative metrics (pointing game accuracy, insertion/deletion AUC) in addition to visual inspection.
Quantitative Evaluation — Is Your Heatmap Good?
Visual inspection is insufficient for rigorous evaluation. The community uses four standard quantitative protocols:
| Metric | How It Works | Higher = Better? | Best For |
|---|---|---|---|
| Pointing Game | Does the max-activation pixel fall inside the ground-truth bounding box? | Yes (accuracy %) | Localisation sanity check |
| Insertion AUC | Progressively reveal pixels by importance; measure AUC of confidence curve | Yes | Sufficiency of highlighted region |
| Deletion AUC | Progressively mask pixels by importance; measure AUC of drop curve | No (lower = better explanation) | Necessity of highlighted region |
| IoU with GT mask | Threshold heatmap to binary mask; compute IoU with ground-truth segmentation | Yes | Pixel-accurate localisation tasks |
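The pointing game and the IoU row are simple enough to sketch directly in NumPy. The helper names, the inclusive `(x_min, y_min, x_max, y_max)` box convention, and the 0.5 threshold below are assumptions made for illustration; published protocols often add details such as a pixel tolerance around the maximum.

```python
import numpy as np

def pointing_game_hit(cam, box):
    """True if the heatmap maximum lies inside box = (x_min, y_min, x_max, y_max), inclusive."""
    row, col = np.unravel_index(np.argmax(cam), cam.shape)
    x_min, y_min, x_max, y_max = box
    return bool(x_min <= col <= x_max and y_min <= row <= y_max)

def heatmap_iou(cam, gt_mask, threshold=0.5):
    """IoU between the thresholded heatmap and a boolean ground-truth mask."""
    pred = cam >= threshold
    gt = gt_mask.astype(bool)
    union = np.logical_or(pred, gt).sum()
    if union == 0:
        return 0.0
    return float(np.logical_and(pred, gt).sum() / union)

# Toy data: a 2x2 hot region inside a 4x4 ground-truth object.
cam = np.zeros((8, 8))
cam[2:4, 2:4] = 1.0
gt = np.zeros((8, 8), dtype=bool)
gt[2:6, 2:6] = True

print(pointing_game_hit(cam, (2, 2, 5, 5)))   # True
print(heatmap_iou(cam, gt))                   # 0.25
```

Pointing-game accuracy is then the fraction of hits over a labelled dataset. The toy IoU of 0.25 shows how a heatmap can point at the right object and still cover only a quarter of it.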
# Insertion AUC: partial implementation
import numpy as np
import torch

def insertion_auc(model, img_tensor, cam_map, class_idx, steps=50):
    """
    Progressively reveal pixels in order of Grad-CAM importance.
    Measures how quickly model confidence recovers as we reveal pixels.
    img_tensor: (1, C, H, W) tensor; cam_map: (H, W) numpy array.
    """
    flat_cam = cam_map.flatten()
    sorted_idx = np.argsort(flat_cam)[::-1]  # most important first
    n_pixels = len(flat_cam)
    step_size = n_pixels // steps
    scores = []
    # Start from an all-zero baseline (a blurred copy of the image is a common alternative)
    baseline = torch.zeros_like(img_tensor)
    for i in range(steps):
        revealed = sorted_idx[:step_size * (i + 1)]
        masked = baseline.clone()
        # Reveal the top-k pixels from the original
        h_np, w_np = np.divmod(revealed, cam_map.shape[1])
        h_idx, w_idx = torch.from_numpy(h_np), torch.from_numpy(w_np)
        masked[0, :, h_idx, w_idx] = img_tensor[0, :, h_idx, w_idx]
        with torch.no_grad():
            score = torch.softmax(model(masked), dim=1)[0, class_idx].item()
        scores.append(score)
    # np.trapz was renamed np.trapezoid in NumPy 2.0; use whichever exists
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz
    return trapezoid(scores) / steps  # AUC under insertion curve
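The deletion metric from the table is the mirror image of insertion. Below is a model-agnostic sketch in NumPy: `score_fn` is a hypothetical callable that returns the class confidence for an image, standing in for a real network, and the zero `fill` value is an assumption.

```python
import numpy as np

def deletion_curve(score_fn, image, cam, steps=10, fill=0.0):
    """Mask pixels from most to least important, recording score_fn after each step.

    image: (H, W) or (H, W, C) array; cam: (H, W) importance map.
    """
    order = np.argsort(cam.ravel())[::-1]            # most important first
    rows, cols = np.unravel_index(order, cam.shape)
    masked = image.copy()
    scores = [score_fn(masked)]                      # score on the intact image
    for chunk in np.array_split(np.arange(order.size), steps):
        masked[rows[chunk], cols[chunk]] = fill
        scores.append(score_fn(masked))
    return np.asarray(scores)

def curve_auc(scores):
    """Trapezoidal area under the curve with the x-axis scaled to [0, 1]."""
    scores = np.asarray(scores, dtype=float)
    return float(((scores[:-1] + scores[1:]) / 2).mean())

# Toy "model": confidence is the mean brightness of the top-left 2x2 patch.
image = np.ones((4, 4))
score_fn = lambda img: float(img[:2, :2].mean())

good_cam = np.zeros((4, 4)); good_cam[:2, :2] = 1.0   # points at the evidence
bad_cam = np.zeros((4, 4)); bad_cam[2:, 2:] = 1.0     # points elsewhere

good_auc = curve_auc(deletion_curve(score_fn, image, good_cam, steps=4))
bad_auc = curve_auc(deletion_curve(score_fn, image, bad_cam, steps=4))
print(good_auc < bad_auc)   # True
```

The heatmap that points at the real evidence makes confidence collapse after the first masking step, giving the lower deletion AUC. That is why lower is better for this metric.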
Golden Rules
Call model.zero_grad() before the backward pass. The hooks capture fresh
activation gradients on every call, but parameter gradients accumulate in
.grad across backward passes, so stale values from previous batches will
silently contaminate anything that reads them later.
Call model.eval() before generating heatmaps.
BatchNorm layers use running statistics in eval mode and batch statistics in
training mode, which changes feature map values and therefore heatmaps.
Keep the handle returned by hook registration, e.g.
handle = layer.register_forward_hook(...), and always call
handle.remove(), or use the context manager pattern.
Pass an explicit target such as ClassifierOutputTarget(class_idx). Letting
the model pick the top-1 class may hide important evidence regions for
secondary classes that the model partially detects.
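The hook-handle rule can be wrapped in a small context manager so that cleanup happens even when an exception is raised. This is a sketch only: `capture_activations` is a hypothetical helper, and the single `Conv2d` stands in for a real target layer.

```python
from contextlib import contextmanager

import torch
import torch.nn as nn

@contextmanager
def capture_activations(layer):
    """Temporarily hook `layer`; the hook is removed even if the body raises."""
    store = {}
    handle = layer.register_forward_hook(
        lambda module, inputs, output: store.update(activations=output.detach())
    )
    try:
        yield store
    finally:
        handle.remove()

conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)   # stand-in target layer
with capture_activations(conv) as store:
    conv(torch.zeros(1, 3, 8, 8))

print(store["activations"].shape)   # torch.Size([1, 4, 8, 8])
```

After the `with` block exits, the layer carries no leftover hooks, so repeated heatmap runs do not stack callbacks on the model.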