The Story That Explains Segmentation
Segmentation does not just find objects; it draws their exact outlines, pixel by pixel. Classification says "this image contains a cat". Detection says "the cat is in this bounding box". Segmentation says "these exact 14,308 pixels are the cat."
From autonomous cars avoiding pedestrians to satellite imagery delineating flood zones, segmentation is the discipline that turns vague detection into surgical precision.
Image segmentation is the process of partitioning a digital image into multiple segments (sets of pixels), where each segment shares some characteristic — colour, texture, object class, or instance identity. The goal is to simplify or change the representation of an image into something more meaningful and easier to analyse.
Every pixel in an image is a data point. Segmentation assigns a label to every single pixel. A 1920×1080 image has 2,073,600 pixels — segmentation is essentially a per-pixel classification problem at massive scale. This is what makes it computationally harder — and more powerful — than simple classification.
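To make the per-pixel framing concrete, here is a minimal NumPy sketch. The class ids and the tiny 4×6 label map are invented for illustration; a real model would produce one such integer for every pixel of its input.

```python
import numpy as np

# Hypothetical class ids, for illustration only.
CLASSES = {0: "background", 1: "cat"}

# A segmentation output is a label map with the same height and width
# as the image: one integer class id per pixel.
label_map = np.array([
    [0, 0, 0, 0, 0, 0],
    [0, 1, 1, 1, 0, 0],
    [0, 1, 1, 1, 1, 0],
    [0, 0, 1, 1, 0, 0],
])

# "These exact N pixels are the cat" is a count over the label map.
cat_pixels = int((label_map == 1).sum())

# Number of per-pixel decisions needed for one Full HD frame.
full_hd_pixels = 1920 * 1080

print(cat_pixels)      # 9
print(full_hd_pixels)  # 2073600
```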
The Three Pillars — Types of Segmentation
There are three fundamental types of segmentation in computer vision, each progressively more demanding in both labelling effort and computational complexity.
| Property | Semantic | Instance | Panoptic |
|---|---|---|---|
| Assigns class to every pixel | ✔ Yes | Partial (only detected objects) | ✔ Yes |
| Distinguishes individual instances | ✘ No | ✔ Yes | ✔ Yes |
| Handles background "stuff" classes | ✔ Yes | ✘ No | ✔ Yes |
| Two people → same or different mask? | Same mask | Different masks | Different masks |
| Typical use case | Road scene parsing | Object counting, robotics | Autonomous driving (full scene) |
| Computational cost | Low | Medium–High | High |
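The differences in the table can be seen on a toy scene with a road and two people. This is a sketch only: the class ids and the `class * 1000 + instance` packing used for the panoptic map are assumptions made for the example, not a fixed standard.

```python
import numpy as np

# Toy 3×6 scene. Hypothetical class ids: 0 = road ("stuff"), 1 = person ("thing").
semantic = np.array([
    [0, 0, 0, 0, 0, 0],
    [0, 1, 1, 0, 1, 1],
    [0, 1, 1, 0, 1, 1],
])

# Instance ids for "thing" pixels only: 0 = no instance, 1 and 2 = two people.
instance = np.array([
    [0, 0, 0, 0, 0, 0],
    [0, 1, 1, 0, 2, 2],
    [0, 1, 1, 0, 2, 2],
])

# Panoptic: every pixel carries both a class and an instance id.
# Packing them as class * 1000 + instance is an assumption of this sketch.
panoptic = semantic * 1000 + instance

# Semantic view: both people share the single "person" label.
print(np.unique(semantic[semantic == 1]).tolist())  # [1]
# Instance view: the two people are distinct, but road pixels carry no label.
print(np.unique(instance[instance > 0]).tolist())   # [1, 2]
# Panoptic view: road (0) plus two distinct people (1001, 1002).
print(np.unique(panoptic).tolist())                 # [0, 1001, 1002]
```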
Classical Segmentation Methods — Before Deep Learning
Classical segmentation approaches rely on handcrafted rules about pixel colour, intensity, gradient, or texture. They are fast, interpretable, and require no training data — but they fail catastrophically on complex, high-variance real-world scenes.
Classical methods have no concept of semantics. They cannot distinguish a road from a patch of grey sky if both have the same pixel intensity. They cannot recognise that two differently-lit images of the same object are related. The moment you need to understand what something is — not just where its edges are — classical methods collapse. That gap is exactly what deep learning fills.
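The grey-road-versus-grey-sky problem can be reproduced in a few lines. The synthetic image and the threshold of 180 below are invented for illustration; the point is only that an intensity rule cannot separate two regions that share an intensity.

```python
import numpy as np

# Synthetic 6×8 grayscale scene (values invented for illustration):
# rows 0-2 are "sky", rows 3-5 are "road", both with intensity 128.
img = np.full((6, 8), 128, dtype=np.uint8)
img[1, 2:5] = 230   # a bright cloud in the sky
img[4, 3:6] = 230   # a bright lane marking on the road

# Intensity thresholding: the only rule is "brighter than 180".
mask = img > 180

# The threshold finds both bright regions...
print(int(mask.sum()))  # 6

# ...but grey sky and grey road receive the same label,
# because the rule sees intensity and nothing else.
sky_labels = np.unique(mask[:3][img[:3] == 128])
road_labels = np.unique(mask[3:][img[3:] == 128])
print(sky_labels.tolist(), road_labels.tolist())  # [False] [False]
```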
Classical Segmentation: Python Example (Watershed)
```python
import cv2
import numpy as np
from matplotlib import pyplot as plt

# Load image (cv2.imread returns None if the file cannot be read)
img = cv2.imread('coins.jpg')
if img is None:
    raise FileNotFoundError('coins.jpg could not be read')
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

# Step 1: Threshold → binary image
_, thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

# Step 2: Remove noise with morphological opening
kernel = np.ones((3, 3), np.uint8)
opening = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel, iterations=2)

# Step 3: Sure background via dilation
sure_bg = cv2.dilate(opening, kernel, iterations=3)

# Step 4: Sure foreground via distance transform
dist_transform = cv2.distanceTransform(opening, cv2.DIST_L2, 5)
_, sure_fg = cv2.threshold(dist_transform, 0.7 * dist_transform.max(), 255, 0)
sure_fg = np.uint8(sure_fg)

# Step 5: Unknown region (border between fg and bg)
unknown = cv2.subtract(sure_bg, sure_fg)

# Step 6: Marker labelling for watershed
_, markers = cv2.connectedComponents(sure_fg)
markers = markers + 1          # background becomes 1 instead of 0
markers[unknown == 255] = 0    # 0 marks the region watershed must decide

# Step 7: Apply watershed
markers = cv2.watershed(img, markers)
img[markers == -1] = [0, 0, 255]  # Mark boundaries red (OpenCV stores BGR)

plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
plt.title('Watershed Segmentation: Coins')
plt.axis('off')
plt.show()
```
The Architecture Revolution — Fully Convolutional Networks
The breakthrough that enabled modern segmentation was deceptively simple: replace the fully-connected layers at the end of a classification network with convolutional layers. This creates a network that accepts an image of any size and outputs a spatial map of predictions — one prediction per pixel instead of one prediction per image.
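A minimal PyTorch sketch of this idea follows. The tiny backbone, the 21-class head, and the bilinear upsampling step are all simplifying assumptions for illustration; they are not the layers of the original FCN.

```python
import torch
import torch.nn as nn

num_classes = 21  # assumed for this sketch

# A tiny convolutional backbone that downsamples by 4 (two stride-2 convs).
backbone = nn.Sequential(
    nn.Conv2d(3, 16, 3, stride=2, padding=1), nn.ReLU(),
    nn.Conv2d(16, 32, 3, stride=2, padding=1), nn.ReLU(),
)

# Instead of flatten + nn.Linear, a 1×1 convolution scores every location.
classifier = nn.Conv2d(32, num_classes, kernel_size=1)

def fcn_forward(x):
    h, w = x.shape[2:]
    scores = classifier(backbone(x))  # [B, num_classes, H/4, W/4]
    # Upsample the coarse score map back to the input resolution.
    return nn.functional.interpolate(
        scores, size=(h, w), mode="bilinear", align_corners=False)

# The same weights handle inputs of different sizes.
shapes = []
for size in [(64, 64), (96, 160)]:
    out = fcn_forward(torch.randn(1, 3, *size))
    shapes.append(tuple(out.shape))
    print(tuple(out.shape))
# (1, 21, 64, 64)
# (1, 21, 96, 160)
```

Because no layer depends on a fixed flattened size, nothing in the network has to change when the input resolution does.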
When downsampling, the network loses precise location information. Skip connections "skip" over the bottleneck and feed early, high-resolution feature maps directly to the decoder. This gives the decoder both what (semantic context from the bottleneck) and where (spatial precision from early layers). Without them, segmentation boundaries are blurry and inaccurate.
U-Net — The Architecture That Changed Medical Imaging
U-Net (Ronneberger, Fischer and Brox, 2015) uses a symmetric encoder-decoder structure with skip connections at every level: each encoder stage is wired directly to the corresponding decoder stage. The name comes from the architecture diagram, which looks like a U. It won the ISBI cell tracking challenge in 2015 by a wide margin and remains one of the dominant architectures in medical image segmentation a decade later.
U-Net Implementation with PyTorch
```python
import torch
import torch.nn as nn


class DoubleConv(nn.Module):
    """Two consecutive Conv2d → BatchNorm → ReLU blocks."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)


class UNet(nn.Module):
    def __init__(self, in_channels=1, num_classes=2, features=(64, 128, 256, 512)):
        super().__init__()
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pool = nn.MaxPool2d(2, 2)

        # Encoder
        ch = in_channels
        for f in features:
            self.downs.append(DoubleConv(ch, f))
            ch = f

        # Bottleneck
        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)

        # Decoder: alternating [upsample, DoubleConv] pairs
        for f in reversed(features):
            self.ups.append(nn.ConvTranspose2d(f * 2, f, 2, 2))
            self.ups.append(DoubleConv(f * 2, f))

        self.final = nn.Conv2d(features[0], num_classes, 1)

    def forward(self, x):
        skips = []
        for down in self.downs:
            x = down(x)
            skips.append(x)   # keep the high-resolution features
            x = self.pool(x)

        x = self.bottleneck(x)
        skips = skips[::-1]   # reverse for decoder

        for i in range(0, len(self.ups), 2):
            x = self.ups[i](x)
            skip = skips[i // 2]
            # Odd input sizes can leave x one pixel smaller than the skip
            if x.shape != skip.shape:
                x = nn.functional.interpolate(x, size=skip.shape[2:])
            x = torch.cat([skip, x], dim=1)
            x = self.ups[i + 1](x)

        return self.final(x)


# Verify architecture with a dummy forward pass
model = UNet(in_channels=1, num_classes=2)
dummy = torch.randn(2, 1, 256, 256)  # batch=2, C=1, H=256, W=256
output = model(dummy)
print(f"Input : {dummy.shape}")
print(f"Output: {output.shape}")
```
DeepLab and Atrous Convolutions — Seeing Without Downsampling
U-Net solves the spatial precision problem by restoring resolution in the decoder. DeepLab (Google, 2015–2018) takes a different approach: avoid throwing away resolution in the first place. It uses atrous convolution (also called dilated convolution), a convolution whose kernel has gaps (holes) between its weights, so a 3×3 kernel covers a larger area without requiring more parameters or reducing resolution.
| Property | Standard 3×3 convolution | Atrous 3×3 convolution (rate 2) |
|---|---|---|
| Receptive field | 3×3 pixels | 5×5 pixels (with gaps) |
| Dilation rate | 1 (no gaps) | 2 (one gap between each weight) |
| Stride | 1 or 2 | 1 (resolution preserved) |
| Spatial output | Halved if stride=2 | Same as input |
| Parameters | 3×3×C_in×C_out | 3×3×C_in×C_out (identical) |
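The comparison above can be checked directly in PyTorch. This is a small sketch with arbitrary channel counts and input size; the impulse probe at the end is just one simple way to measure the footprint of a dilated kernel.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 8, 32, 32)

standard = nn.Conv2d(8, 8, kernel_size=3, padding=1, dilation=1)
strided = nn.Conv2d(8, 8, kernel_size=3, padding=1, stride=2)
atrous = nn.Conv2d(8, 8, kernel_size=3, padding=2, dilation=2)

def n_params(module):
    return sum(p.numel() for p in module.parameters())

# Same parameter count everywhere (3*3*8*8 weights + 8 biases = 584);
# only the strided convolution shrinks the feature map.
print(tuple(standard(x).shape), n_params(standard))  # (1, 8, 32, 32) 584
print(tuple(strided(x).shape), n_params(strided))    # (1, 8, 16, 16) 584
print(tuple(atrous(x).shape), n_params(atrous))      # (1, 8, 32, 32) 584

# Footprint of the dilated kernel: feed a single-pixel impulse through an
# all-ones 3×3 kernel with dilation 2 and measure the extent of the response.
probe = nn.Conv2d(1, 1, kernel_size=3, padding=2, dilation=2, bias=False)
nn.init.ones_(probe.weight)
impulse = torch.zeros(1, 1, 9, 9)
impulse[0, 0, 4, 4] = 1.0
with torch.no_grad():
    response = probe(impulse)[0, 0]
rows = torch.nonzero(response.sum(dim=1)).flatten()
footprint = int(rows.max() - rows.min() + 1)
print(footprint)  # 5
```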
DeepLabV2 introduced Atrous Spatial Pyramid Pooling (ASPP), and DeepLabV3 refined it: multiple atrous convolutions run in parallel with different dilation rates (e.g. 6, 12, 18). Each rate captures context at a different scale: small objects need small receptive fields, large objects need large ones. ASPP concatenates the branch outputs, giving the network multi-scale context without any further loss of resolution.
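To get a feel for how much context each dilation rate covers, the span of a dilated kernel along one axis is k + (k − 1)(r − 1) for kernel size k and rate r. The helper below is a hypothetical convenience function written for this sketch, not part of any library.

```python
def effective_kernel_size(kernel_size: int, dilation: int) -> int:
    """Span in pixels covered by a dilated kernel along one axis."""
    return kernel_size + (kernel_size - 1) * (dilation - 1)

# A 3×3 kernel always has 9 weights per channel pair; only its span changes.
spans = {rate: effective_kernel_size(3, rate) for rate in (1, 6, 12, 18)}
for rate, span in spans.items():
    print(f"rate={rate:2d} -> covers {span}x{span} pixels")
# rate= 1 -> covers 3x3 pixels
# rate= 6 -> covers 13x13 pixels
# rate=12 -> covers 25x25 pixels
# rate=18 -> covers 37x37 pixels
```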
ASPP Module Implementation
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AtrousConvBNReLU(nn.Module):
    def __init__(self, in_ch, out_ch, rate):
        super().__init__()
        # padding=rate keeps the spatial size unchanged for a 3×3 kernel
        self.conv = nn.Conv2d(in_ch, out_ch, 3,
                              padding=rate, dilation=rate, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)

    def forward(self, x):
        return F.relu(self.bn(self.conv(x)))


class ASPP(nn.Module):
    def __init__(self, in_channels=2048, out_channels=256):
        super().__init__()
        # 1×1 conv branch (per-pixel features; global context comes
        # from the pooling branch below)
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels), nn.ReLU())
        # Atrous convolutions at 3 dilation rates
        self.atrous6 = AtrousConvBNReLU(in_channels, out_channels, rate=6)
        self.atrous12 = AtrousConvBNReLU(in_channels, out_channels, rate=12)
        self.atrous18 = AtrousConvBNReLU(in_channels, out_channels, rate=18)
        # Global average pooling branch (image-level context)
        self.gap = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels), nn.ReLU())
        # Projection after concat (5 branches × out_channels)
        self.project = nn.Sequential(
            nn.Conv2d(out_channels * 5, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels), nn.ReLU(),
            nn.Dropout(0.5))

    def forward(self, x):
        h, w = x.shape[2:]
        # Broadcast the 1×1 pooled features back to the feature-map size
        gap_out = F.interpolate(self.gap(x), size=(h, w),
                                mode='bilinear', align_corners=False)
        branches = [self.conv1(x), self.atrous6(x),
                    self.atrous12(x), self.atrous18(x), gap_out]
        return self.project(torch.cat(branches, dim=1))
```
Mask R-CNN — Instance Segmentation
Facebook AI Research (He et al., 2017) extended Faster R-CNN — a powerful object detector — by adding a third head: a small fully convolutional network that predicts a binary mask (foreground / background) for each detected bounding box. The detector tells you where things are; the mask head tells you exactly which pixels inside that box belong to the object. The result: state-of-the-art instance segmentation that is surprisingly fast for what it does.
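The step from "a small mask per box" to "which pixels in the image" can be sketched as follows. This is a simplified illustration, not the Mask R-CNN implementation: the 4×4 mask, the box, and the nearest-neighbour resize are assumptions made for the example (the real model predicts a larger mask per box and interpolates it more smoothly).

```python
import numpy as np

def paste_mask(mask_probs, box, image_hw, threshold=0.5):
    """Place a low-resolution per-box mask into a full-image binary mask.

    mask_probs: [m, m] foreground probabilities predicted for one box.
    box: (x0, y0, x1, y1) integer pixel coordinates, x1 and y1 exclusive.
    """
    x0, y0, x1, y1 = box
    box_h, box_w = y1 - y0, x1 - x0
    m = mask_probs.shape[0]
    # Nearest-neighbour resize of the m×m mask to the box size.
    rows = np.arange(box_h) * m // box_h
    cols = np.arange(box_w) * m // box_w
    resized = mask_probs[rows[:, None], cols[None, :]]
    # Threshold and write into an otherwise empty full-image mask.
    full = np.zeros(image_hw, dtype=bool)
    full[y0:y1, x0:x1] = resized > threshold
    return full

# Hypothetical 4×4 mask-head output: confident foreground in the centre.
mask_probs = np.full((4, 4), 0.1)
mask_probs[1:3, 1:3] = 0.9

full = paste_mask(mask_probs, box=(2, 1, 10, 9), image_hw=(12, 12))
print(int(full.sum()))          # 16
print(bool(full[:, :2].any()))  # False: nothing is marked outside the box
```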
Running Mask R-CNN with Detectron2
```python
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2 import model_zoo
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog
import cv2

# ── 1. Configure the model ───────────────────────────────────
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file(
    "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # confidence threshold
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(
    "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")

# ── 2. Build predictor ───────────────────────────────────────
predictor = DefaultPredictor(cfg)

# ── 3. Run inference ─────────────────────────────────────────
img = cv2.imread("street.jpg")  # BGR, as DefaultPredictor expects by default
outputs = predictor(img)

# ── 4. Extract instance masks ────────────────────────────────
instances = outputs["instances"].to("cpu")
masks = instances.pred_masks      # [N, H, W] bool tensor
boxes = instances.pred_boxes      # Boxes object wrapping an [N, 4] tensor
classes = instances.pred_classes  # [N] class indices
scores = instances.scores         # [N] confidence

metadata = MetadataCatalog.get(cfg.DATASETS.TRAIN[0])
print(f"Detected {len(instances)} instances")
for i in range(len(instances)):
    cls = metadata.thing_classes[classes[i].item()]
    score = scores[i].item()
    pixels = masks[i].sum().item()
    print(f" {cls:12s} | conf={score:.3f} | pixels={pixels:,}")

# ── 5. Visualise ─────────────────────────────────────────────
v = Visualizer(img[:, :, ::-1],  # Visualizer expects RGB
               metadata=metadata,
               scale=1.2)
out = v.draw_instance_predictions(instances)
cv2.imwrite("output_segmented.jpg", out.get_image()[:, :, ::-1])
```
Loss Functions for Segmentation
Standard cross-entropy loss treats every pixel equally. In segmentation, this creates a critical problem: background pixels vastly outnumber foreground pixels. If 95% of your image is background, a model can get 95% accuracy by predicting "background" for every pixel — while being completely useless. Purpose-built segmentation losses address this.
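The 95% trap is easy to reproduce. In this sketch the image size and foreground patch are made up so that the foreground covers exactly 5% of the pixels.

```python
import numpy as np

# Ground truth: a 100×100 image where a 20×25 patch is foreground.
gt = np.zeros((100, 100), dtype=bool)
gt[40:60, 30:55] = True  # 20 * 25 = 500 pixels = 5% of the image

# A "model" that predicts background everywhere.
pred = np.zeros_like(gt)

accuracy = (pred == gt).mean()
intersection = np.logical_and(pred, gt).sum()
union = np.logical_or(pred, gt).sum()
foreground_iou = intersection / union

print(f"pixel accuracy = {accuracy:.2f}")        # pixel accuracy = 0.95
print(f"foreground IoU = {foreground_iou:.2f}")  # foreground IoU = 0.00
```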
| Loss Function | Formula (simplified) | When to Use | Handles Imbalance? |
|---|---|---|---|
| Cross-Entropy | −Σ y·log(ŷ) | Balanced classes, baselines | ✘ No |
| Dice Loss | 1 − 2·\|A∩B\| / (\|A\|+\|B\|) | Medical imaging, rare foreground | ✔ Yes (normalised by set size) |
| IoU / Jaccard Loss | 1 − \|A∩B\| / \|A∪B\| | When IoU metric is the target | ✔ Yes |
| Focal Loss | −(1−ŷ)^γ · y·log(ŷ) | Extreme class imbalance, small objects | ✔ Focuses on hard examples |
| Tversky Loss | 1 − TP / (TP + α·FP + β·FN) | When false negatives more costly (medical) | ✔ Customisable FP/FN penalty |
| BCE + Dice (Combined) | BCE + λ·DiceLoss | General-purpose best practice | ✔ Best of both |
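Two of the losses in the table, focal and Tversky, can be sketched for the binary case as follows. These are minimal versions under stated assumptions: inputs are raw logits and float masks of the same shape, the focal loss is the symmetric binary form (it also down-weights easy background pixels, which the simplified formula in the table omits), and the default `gamma`, `alpha`, and `beta` values are common starting points rather than recommendations from this article.

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, gamma=2.0):
    """Binary focal loss: cross-entropy down-weighted on easy pixels."""
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p = torch.sigmoid(logits)
    p_t = p * targets + (1 - p) * (1 - targets)  # probability of the true class
    return ((1 - p_t) ** gamma * bce).mean()

def tversky_loss(logits, targets, alpha=0.3, beta=0.7, smooth=1.0):
    """Soft Tversky loss; beta > alpha penalises false negatives more."""
    p = torch.sigmoid(logits)
    tp = (p * targets).sum()
    fp = (p * (1 - targets)).sum()
    fn = ((1 - p) * targets).sum()
    return 1 - (tp + smooth) / (tp + alpha * fp + beta * fn + smooth)

torch.manual_seed(0)
logits = torch.randn(2, 1, 8, 8)
targets = (torch.rand(2, 1, 8, 8) > 0.9).float()  # roughly 10% foreground

print(f"focal   = {focal_loss(logits, targets).item():.4f}")
print(f"tversky = {tversky_loss(logits, targets).item():.4f}")
```

Two sanity checks follow from the definitions: with `gamma=0` the focal loss reduces to plain binary cross-entropy, and with `alpha = beta = 0.5` and `smooth=0` the Tversky loss equals the Dice loss.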
Dice Loss Implementation
```python
import torch
import torch.nn as nn


class DiceLoss(nn.Module):
    """Soft Dice Loss for binary (or one-hot, multi-label) segmentation."""
    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, preds, targets):
        # preds:   [B, C, H, W] raw logits (sigmoid is applied here)
        # targets: [B, C, H, W] binary / one-hot mask, same shape as preds
        preds = torch.sigmoid(preds)
        preds_f = preds.contiguous().view(-1)
        tgts_f = targets.contiguous().view(-1).float()
        intersection = (preds_f * tgts_f).sum()
        dice = (2.0 * intersection + self.smooth) \
            / (preds_f.sum() + tgts_f.sum() + self.smooth)
        return 1.0 - dice


class CombinedLoss(nn.Module):
    """BCE + Dice, a common combination for medical/binary segmentation."""
    def __init__(self, dice_weight=0.5):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()
        self.dice = DiceLoss()
        self.w = dice_weight

    def forward(self, logits, targets):
        # Both terms take raw logits; targets must be float masks
        return (1 - self.w) * self.bce(logits, targets) \
            + self.w * self.dice(logits, targets)


# Usage (model_output: raw logits; ground_truth_mask: float mask, same shape)
criterion = CombinedLoss(dice_weight=0.5)
loss = criterion(model_output, ground_truth_mask)
loss.backward()
```
Evaluation Metrics — How Do We Measure Good Segmentation?
Pixel accuracy ("what fraction of pixels are correctly labelled?") is often misleading in segmentation because background dominates. The benchmarks below instead report overlap-based metrics: mIoU, mask AP, panoptic quality (PQ), and Dice.
| Benchmark Dataset | Task | Primary Metric | State-of-Art Score |
|---|---|---|---|
| PASCAL VOC 2012 | Semantic (21 classes) | mIoU | ~90.5 mIoU |
| Cityscapes | Semantic (19 classes, urban) | mIoU | ~84.8 mIoU |
| COCO (val) | Instance segmentation | Mask AP | ~58.1 AP |
| COCO Panoptic | Panoptic segmentation | PQ | ~58.8 PQ |
| MedSeg / BraTS | Brain tumour (3 classes) | Dice | ~91.2 Dice |
| ADE20K | Semantic (150 classes) | mIoU | ~62.1 mIoU |
Computing IoU and Dice in Python
```python
import numpy as np


def compute_iou(pred_mask, gt_mask):
    """Compute IoU for binary masks (numpy boolean arrays)."""
    intersection = np.logical_and(pred_mask, gt_mask).sum()
    union = np.logical_or(pred_mask, gt_mask).sum()
    return intersection / (union + 1e-8)


def compute_dice(pred_mask, gt_mask):
    """Compute Dice coefficient for binary masks."""
    intersection = np.logical_and(pred_mask, gt_mask).sum()
    return 2 * intersection / (pred_mask.sum() + gt_mask.sum() + 1e-8)


def mean_iou(pred_labels, gt_labels, num_classes):
    """Compute mean IoU across all classes for semantic segmentation."""
    ious = []
    for cls in range(num_classes):
        pred = (pred_labels == cls)
        gt = (gt_labels == cls)
        if gt.sum() == 0 and pred.sum() == 0:
            continue  # skip absent classes
        ious.append(compute_iou(pred, gt))
    return np.mean(ious)


# Example
pred = np.array([[0, 0, 1, 1], [0, 1, 1, 1], [0, 0, 1, 0]])
gt = np.array([[0, 0, 1, 1], [0, 0, 1, 1], [0, 0, 0, 0]])
print(f"IoU = {compute_iou(pred, gt):.4f}")
print(f"Dice = {compute_dice(pred, gt):.4f}")
```
Modern Architecture Comparison
| Architecture | Year | Task | Key Innovation | Speed | Accuracy |
|---|---|---|---|---|---|
| FCN | 2015 | Semantic | First end-to-end conv segmentation network | Fast | Moderate |
| U-Net | 2015 | Semantic / Medical | Symmetric encoder-decoder + skip connections at every level | Fast | High (medical) |
| DeepLabV3+ | 2018 | Semantic | ASPP + decoder; atrous convolutions preserve resolution | Medium | Very High |
| Mask R-CNN | 2017 | Instance | Extends Faster R-CNN with per-instance mask head + RoI Align | Medium | High |
| Panoptic FPN | 2019 | Panoptic | Unified FPN backbone for both semantic and instance heads | Medium | High |
| SegFormer | 2021 | Semantic | Hierarchical Vision Transformer encoder; no positional encoding | Fast | SOTA |
| Segment Anything (SAM) | 2023 | Universal | Promptable segmentation; trained on 1B masks; zero-shot generalisation | Medium | SOTA (zero-shot) |
Segment Anything Model (SAM) — The Foundation Model for Segmentation
For the first time, segmentation became accessible to non-experts. Architects, archaeologists, surgeons, farmers — anyone who can click a mouse can now segment objects in images at professional quality. SAM democratised segmentation the way ChatGPT democratised NLP.
Using SAM with a Point Prompt
```python
import numpy as np
import cv2
from segment_anything import sam_model_registry, SamPredictor

# ── Load SAM model ───────────────────────────────────────────
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
sam.to(device="cuda")
predictor = SamPredictor(sam)

# ── Prepare image ────────────────────────────────────────────
image = cv2.cvtColor(cv2.imread("dog.jpg"), cv2.COLOR_BGR2RGB)
# Runs the heavy image encoder once; the embedding is cached,
# so every later prompt on this image is cheap
predictor.set_image(image)

# ── Point prompt: click on the dog ──────────────────────────
# Format: [[x, y]] coordinates in image pixels
input_point = np.array([[500, 375]])  # click on dog's body
input_label = np.array([1])           # 1 = foreground point
masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,  # returns 3 candidate masks
)

# ── Select best mask by confidence score ────────────────────
best_idx = np.argmax(scores)
best_mask = masks[best_idx]
print(f"Mask scores: {scores}")
print(f"Best mask pixels: {best_mask.sum():,}")

# ── Overlay on image ─────────────────────────────────────────
# Blend a blue tint (RGB) into the masked pixels at 50% opacity
overlay = image.copy()
overlay[best_mask] = (overlay[best_mask] * 0.5
                      + np.array([0, 120, 255]) * 0.5).astype(np.uint8)
```
Real-World Applications
Complete Training Pipeline — Binary Segmentation with U-Net
```python
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import numpy as np
from pathlib import Path

# UNet and CombinedLoss are the classes defined in the earlier sections.


# ── Dataset ───────────────────────────────────────────────────
class SegDataset(Dataset):
    def __init__(self, img_dir, mask_dir, transform=None):
        # Sorting pairs image i with mask i; filenames must match up
        self.imgs = sorted(Path(img_dir).glob("*.png"))
        self.masks = sorted(Path(mask_dir).glob("*.png"))
        self.tfm = transform

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, i):
        img = cv2.cvtColor(cv2.imread(str(self.imgs[i])), cv2.COLOR_BGR2RGB)
        mask = cv2.imread(str(self.masks[i]), cv2.IMREAD_GRAYSCALE)
        mask = (mask > 127).astype(np.float32)  # binarise to {0, 1}
        if self.tfm:
            aug = self.tfm(image=img, mask=mask)
            img, mask = aug["image"], aug["mask"].unsqueeze(0)  # [1, H, W]
        return img, mask


# ── Augmentation pipeline ─────────────────────────────────────
train_tfm = A.Compose([
    A.Resize(256, 256),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.3),
    A.ElasticTransform(p=0.2),  # crucial for medical imaging
    A.GaussNoise(p=0.2),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
])


# ── Training loop ─────────────────────────────────────────────
def train_one_epoch(model, loader, optimizer, criterion, device):
    model.train()
    total_loss, total_iou = 0, 0
    for imgs, masks in loader:
        imgs, masks = imgs.to(device), masks.to(device)
        optimizer.zero_grad()
        preds = model(imgs)
        loss = criterion(preds, masks)
        loss.backward()
        optimizer.step()

        # Compute batch IoU
        pred_bin = (torch.sigmoid(preds) > 0.5).float()
        inter = (pred_bin * masks).sum(dim=[1, 2, 3])
        union = (pred_bin + masks - pred_bin * masks).sum(dim=[1, 2, 3])
        total_iou += (inter / (union + 1e-8)).mean().item()
        total_loss += loss.item()
    n = len(loader)
    return total_loss / n, total_iou / n


# ── Main training script ──────────────────────────────────────
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Placeholder paths and batch size: adjust to your own data layout
train_ds = SegDataset("data/train/images", "data/train/masks", transform=train_tfm)
train_loader = DataLoader(train_ds, batch_size=8, shuffle=True)

model = UNet(in_channels=3, num_classes=1).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
criterion = CombinedLoss(dice_weight=0.5)

for epoch in range(1, 51):
    loss, iou = train_one_epoch(model, train_loader, optimizer, criterion, DEVICE)
    scheduler.step()
    if epoch % 10 == 0:
        print(f"Epoch {epoch:3d} | Loss: {loss:.4f} | Train IoU: {iou:.4f}")
```
Golden Rules — Segmentation Best Practices
Report mIoU or Dice, which treat each class equally, rather than pixel accuracy.
Use timm or segmentation-models-pytorch to access pretrained backbones easily.
Train with mixed precision (torch.cuda.amp). Segmentation networks process full-resolution feature maps, so memory is the bottleneck, not compute. FP16 mixed precision cuts memory use by ~40% and speeds training by 1.5–2×, typically with no accuracy loss when paired with a GradScaler.
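The mixed-precision advice above can be sketched as a single training step. This is an illustration under assumptions: the one-layer model and random tensors are stand-ins invented for the example, and AMP is switched on only when a CUDA device is present so the same code also runs (in full precision) on CPU.

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # this sketch enables AMP on GPU only

# Stand-in model and data, invented for the example.
model = nn.Conv2d(3, 1, kernel_size=3, padding=1).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = nn.BCEWithLogitsLoss()
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

imgs = torch.randn(2, 3, 64, 64, device=device)
masks = (torch.rand(2, 1, 64, 64, device=device) > 0.5).float()

optimizer.zero_grad()
# Forward pass and loss run in reduced precision where autocast deems it safe.
with torch.autocast(device_type=device, enabled=use_amp):
    loss = criterion(model(imgs), masks)

# Scale the loss so small FP16 gradients do not underflow, then step.
scaler.scale(loss).backward()
scaler.step(optimizer)  # unscales gradients; skips the step on inf/NaN
scaler.update()

print(f"loss = {loss.item():.4f}")
```

In a real loop, only the `autocast` context and the three `scaler` calls need to be added around the existing forward and backward pass.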