The Black Box Crisis
No reason. No breakdown. The model knew 47 things about her — credit history, cash flow, business age, postcode, transaction patterns — and it combined them in a way that no human could trace. Maria couldn't appeal because she didn't know which factor had hurt her. Was it the three missed payments in 2021? Was it her industry category? Was it where she lived?
This is the Black Box Problem. As ML models grew more powerful (ensemble trees, neural networks, gradient boosting), they also grew more opaque. A model that no one can explain is a model no one can fully trust, audit, or fix.
Enter SHAP (SHapley Additive exPlanations). SHAP is one of the most widely used methods for explaining the output of a machine learning model, and it can be applied to any model. It assigns every feature an attribution value for every single prediction, grounded in the Shapley value from cooperative game theory.
Methods like PDP (Partial Dependence Plots), LIME, and feature importances existed before SHAP. What sets SHAP apart is that it is the only additive feature-attribution method that satisfies three desirable properties at once: local accuracy, missingness, and consistency. That makes it both theoretically grounded and practically useful. It has become one of the most cited and most widely used explainability methods in research and in industry.
What Is SHAP? — The Core Idea
SHAP answers one simple but powerful question: "For this specific prediction, how much did each feature contribute?" Not on average across the dataset — for this exact data point.
The beautiful guarantee: the SHAP values always sum to the gap between the model's prediction and its average prediction. This property is called Efficiency (or local accuracy).
The contributions add up exactly: $207k + $89k + $51k + $12k − $5k − $12k = $342k. This is guaranteed by the Efficiency axiom.
The Game Theory Foundation — Where SHAP Comes From
Imagine several consultants sharing the fee for a joint project. A simple split (£100k each) is unfair if Alice alone could have earned £180k and Alice and Bob together could have earned £250k. Lloyd Shapley (Nobel Memorial Prize in Economics, 2012) solved this problem in 1953. His solution works in three steps. First, for each consultant, consider every possible order in which the team could be assembled. Second, calculate how much value that consultant added at the moment they joined. Third, average those marginal contributions across all orderings.
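To make the recipe concrete, here is a minimal sketch that enumerates every joining order. It assumes a three-person team: Alice, Bob, and a hypothetical third consultant, Carol. Only Alice's £180k and the Alice+Bob figure of £250k come from the example. The other coalition values, including the £300k total, are invented for illustration.

```python
from itertools import permutations

# Coalition values in £k. Alice alone (180) and Alice+Bob (250) come from the text;
# Carol and every other coalition value are hypothetical numbers for illustration.
V = {
    frozenset(): 0,
    frozenset({"Alice"}): 180,
    frozenset({"Bob"}): 60,
    frozenset({"Carol"}): 40,
    frozenset({"Alice", "Bob"}): 250,
    frozenset({"Alice", "Carol"}): 220,
    frozenset({"Bob", "Carol"}): 120,
    frozenset({"Alice", "Bob", "Carol"}): 300,
}

def shapley_by_permutations(players, value):
    """Average each player's marginal contribution over every joining order."""
    totals = {p: 0.0 for p in players}
    orders = list(permutations(players))
    for order in orders:
        coalition = frozenset()
        for p in order:
            joined = coalition | {p}
            totals[p] += value[joined] - value[coalition]  # value added on joining
            coalition = joined
    return {p: t / len(orders) for p, t in totals.items()}

phi = shapley_by_permutations(["Alice", "Bob", "Carol"], V)
for name, share in phi.items():
    print(f"{name:6s}: £{share:.2f}k")
# Alice : £181.67k
# Bob   : £71.67k
# Carol : £46.67k
print(f"Total : £{sum(phi.values()):.2f}k")  # £300.00k, the full project value
```

Alice ends up with slightly more than her solo £180k because she also adds the most when she joins partial teams. The three shares always add up to the full project value.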
That average is their Shapley value — the uniquely fair attribution. SHAP applies this exact idea to machine learning features. Each feature is a "player". The model's prediction is the "value". SHAP values are the Shapley values for features.
The Shapley Formula
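For feature i, its SHAP value is:

$$\phi_i = \sum_{S \subseteq F \setminus \{i\}} \frac{|S|!\,(n - |S| - 1)!}{n!}\,\big[f_x(S \cup \{i\}) - f_x(S)\big]$$

Here F is the set of all n features, and S runs over every subset that does not contain i. The term f_x(S) is the model's expected output for x when only the features in S are known. The fraction is the share of joining orders in which exactly the features in S arrive before i. The formula is therefore the consultant recipe written out over subsets.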
Visualising the Shapley Calculation
With 3 features — MedInc (M), HouseAge (H), AveRooms (R) — there are 2³ = 8 possible subsets. To find SHAP(MedInc), we look at every subset without M and measure MedInc's marginal contribution when added:
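The Shapley value for MedInc is the weighted average of these four marginal contributions: ≈ +0.82. The weights are 1/3 for the empty set, 1/6 for {H}, 1/6 for {R} and 1/3 for {H, R}. Every feature gets the same treatment, and the values are guaranteed to add up to the full prediction gap.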
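The same three-feature calculation can be written directly from the subset formula. This is a minimal sketch. The coalition values v(S) are hypothetical numbers, so the result does not reproduce the +0.82 quoted above. The function loops over every subset that excludes the feature and applies the |S|!(n − |S| − 1)!/n! weight.

```python
from itertools import combinations
from math import factorial

FEATURES = ["MedInc", "HouseAge", "AveRooms"]

# Hypothetical coalition values v(S): the model's expected output when only the
# features in S are known. These numbers are invented for illustration only.
v = {
    frozenset(): 2.07,
    frozenset({"MedInc"}): 2.90,
    frozenset({"HouseAge"}): 2.10,
    frozenset({"AveRooms"}): 2.00,
    frozenset({"MedInc", "HouseAge"}): 2.95,
    frozenset({"MedInc", "AveRooms"}): 2.85,
    frozenset({"HouseAge", "AveRooms"}): 2.05,
    frozenset(FEATURES): 3.00,
}

def shapley_value(i, features, value):
    """Exact Shapley value of feature i via the subset-weighted formula."""
    n = len(features)
    others = [f for f in features if f != i]
    phi = 0.0
    for size in range(n):  # |S| = 0 .. n-1
        weight = factorial(size) * factorial(n - size - 1) / factorial(n)
        for S in combinations(others, size):
            S = frozenset(S)
            phi += weight * (value[S | {i}] - value[S])  # marginal contribution of i
    return phi

phis = {f: shapley_value(f, FEATURES, v) for f in FEATURES}
for f, p in phis.items():
    print(f"SHAP({f}) = {p:+.4f}")
# SHAP(MedInc) = +0.8767
# SHAP(HouseAge) = +0.0767
# SHAP(AveRooms) = -0.0233
gap = v[frozenset(FEATURES)] - v[frozenset()]
print(f"sum = {sum(phis.values()):+.4f}, prediction gap = {gap:+.4f}")  # both +0.9300
```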
The Four Axioms — Why SHAP Is Uniquely Fair
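Shapley proved that his allocation method is the only one that satisfies four fairness axioms simultaneously. This is what makes SHAP's explanations principled rather than arbitrary.
Efficiency: the attributions add up to the prediction gap, φ₁ + φ₂ + ... + φₙ = f(x) − E[f(X)].
Symmetry: two features that contribute equally to every coalition receive equal values.
Dummy (null player): a feature that never changes the output receives zero.
Additivity: if a model is the sum of two models, each feature's value is the sum of its values for the two parts.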
SHAP values are that unique solution, so they are not a heuristic. Once you accept the four axioms, no other attribution satisfies them. In practice, most explainers compute approximations of these values rather than the exact sum.
SHAP Explainer Types — Choosing the Right One
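The exact Shapley computation requires evaluating the model on all 2ⁿ feature subsets, which is exponential in the number of features. For a 50-feature model that is 2⁵⁰ (about 10¹⁵) evaluations, which is computationally infeasible. SHAP explainers avoid this with shortcuts. Some exploit the model's structure (trees, linear models, neural networks). Model-agnostic ones sample coalitions instead of enumerating them.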
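To see why sampling helps, here is a minimal sketch of Monte Carlo permutation sampling. It is one generic way to approximate Shapley values without enumerating every subset, and it is not the algorithm of any particular shap explainer. The 50-feature value function is invented so that the exact answer is known: each feature has its own weight, and features 0 and 1 share one hypothetical interaction bonus.

```python
import random

N_FEATURES = 50
# Hypothetical value function: each feature adds its own weight, and features
# 0 and 1 earn an extra BONUS only when both are present (an interaction).
WEIGHTS = [((j * 37) % 11 - 5) / 10 for j in range(N_FEATURES)]
BONUS = 0.6

def value(coalition):
    total = sum(WEIGHTS[j] for j in coalition)
    if 0 in coalition and 1 in coalition:
        total += BONUS
    return total

def sampled_shapley(n, value, n_permutations, seed=0):
    """Monte Carlo Shapley estimate: average marginal contributions over random orders."""
    rng = random.Random(seed)
    phi = [0.0] * n
    for _ in range(n_permutations):
        order = list(range(n))
        rng.shuffle(order)
        coalition, prev = set(), value(set())
        for j in order:
            coalition.add(j)
            cur = value(coalition)
            phi[j] += cur - prev  # contribution of j in this random order
            prev = cur
    return [p / n_permutations for p in phi]

n_perm = 1000
est = sampled_shapley(N_FEATURES, value, n_perm)
# Exact answer for this particular game: the bonus is split equally between features 0 and 1.
exact = [w + (BONUS / 2 if j in (0, 1) else 0.0) for j, w in enumerate(WEIGHTS)]

print(f"exact enumeration: 2**{N_FEATURES} = {2**N_FEATURES:,} coalitions")
print(f"this estimate:     {n_perm * N_FEATURES:,} value calls")
print(f"feature 0: sampled {est[0]:+.3f}, exact {exact[0]:+.3f}")
print(f"largest error: {max(abs(e - x) for e, x in zip(est, exact)):.3f}")
```

The cost grows with the number of sampled orders times n, not with 2ⁿ. The price is a small sampling error on features that interact.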
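If you are unsure which explainer to choose, call shap.Explainer(model) without specifying a type. For supported model types such as tree ensembles and linear models, it selects a suitable algorithm automatically. Other models also need background data passed as a masker.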
| Explainer | Speed | Exactness | Model Types | Best For |
|---|---|---|---|---|
| TreeSHAP | ⚡ Very Fast | Exact | Trees, XGBoost, LightGBM | Production tabular ML |
| LinearSHAP | ⚡ Very Fast | Exact | Linear models | Regression, logistic |
| KernelSHAP | ⏰ Slow | Approximate | Any model | SVM, custom models |
| DeepSHAP | 🕑 Medium | Approximate | Neural networks | Keras, PyTorch |
| GradientSHAP | 🕑 Medium | Approximate | Differentiable models | Images, text (large NNs) |
| PartitionSHAP | 🕑 Medium | Approximate | Any (structured) | Text, image (hierarchical) |
Installing SHAP & Setup
```bash
# Install the shap library (and common companions)
pip install shap xgboost lightgbm scikit-learn pandas matplotlib
# For deep learning support
pip install shap tensorflow torch
```

```python
# Verify installation
import shap
print(shap.__version__)  # e.g. 0.44.1
```
```python
# Standard imports you will use in every SHAP project
import shap
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import xgboost as xgb
from sklearn.model_selection import train_test_split
from sklearn.datasets import fetch_california_housing

# Load SHAP's JavaScript so interactive force plots render in Jupyter
shap.initjs()  # only needed in Jupyter for force plots
```
shap.initjs() is only needed for the interactive JavaScript-based force plot inside Jupyter notebooks; all other plots work in any environment. For new code, including production scripts, prefer the shap.plots.* API over older functions such as shap.summary_plot().
Your First SHAP Explanation — TreeSHAP with XGBoost
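We will use the California Housing dataset throughout this tutorial. It predicts the median house value per district, in units of $100,000. It has 8 features and 20,640 samples, and it is available through scikit-learn's fetch_california_housing loader, which downloads it on first use.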
```python
import shap
import xgboost as xgb
import pandas as pd
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split

# ── Load data ───────────────────────────────────────────────
data = fetch_california_housing(as_frame=True)
X, y = data.data, data.target
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# ── Train model ─────────────────────────────────────────────
model = xgb.XGBRegressor(
    n_estimators=200,
    max_depth=6,
    learning_rate=0.1,
    random_state=42,
    verbosity=0,
)
model.fit(X_train, y_train)

# ── Create TreeSHAP explainer ────────────────────────────────
explainer = shap.TreeExplainer(model)

# Compute SHAP values for the test set (new API: returns a shap.Explanation object)
shap_values = explainer(X_test)

# ── The key numbers ─────────────────────────────────────────
base_value = float(shap_values.base_values[0])  # E[f(X)], the same for every row
prediction = float(model.predict(X_test.iloc[[0]])[0])
print(f"Base value E[f(X)] : {base_value:.4f}")
print(f"Prediction f(x₀)   : {prediction:.4f}")
print(f"SHAP sum + base    : {shap_values[0].values.sum() + base_value:.4f}")
print()
print("SHAP values for sample 0:")
for feat, val in zip(X_test.columns, shap_values[0].values):
    direction = "▲" if val > 0 else "▼"
    print(f"  {feat:12s}: {val:+.4f} {direction}")
```
The SHAP values for sample 0 plus the base value equal the model's prediction. In the run described here, the base value is 2.0684 and the prediction is 3.4821, so the eight SHAP values sum to 1.4137. This guarantee holds for every single sample, not just on average, up to floating-point rounding.
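Rather than eyeballing one row, you can check the guarantee for every row at once. This minimal NumPy sketch uses a hypothetical helper, check_local_accuracy, which is not part of the shap library. The default tolerance of 1e-3 is an assumption chosen to absorb float32 rounding in tree libraries. The demo runs on synthetic numbers rather than real model output.

```python
import numpy as np

def check_local_accuracy(shap_matrix, base_values, predictions, atol=1e-3):
    """Check that SHAP values + base value reproduce the predictions for every row.

    shap_matrix: (n_samples, n_features); base_values: scalar or (n_samples,);
    predictions: (n_samples,). Returns the largest absolute error and raises
    AssertionError if it exceeds atol.
    """
    reconstructed = (np.asarray(shap_matrix, dtype=float).sum(axis=1)
                     + np.asarray(base_values, dtype=float))
    max_err = float(np.max(np.abs(reconstructed - np.asarray(predictions, dtype=float))))
    assert max_err <= atol, f"local accuracy violated: max error {max_err:.2e}"
    return max_err

# Stand-alone demo on synthetic numbers (not real model output)
rng = np.random.default_rng(0)
phi = rng.normal(scale=0.3, size=(5, 8))  # 5 rows, 8 features
base = 2.0684                             # base value quoted in the text
preds = phi.sum(axis=1) + base
print(f"max error: {check_local_accuracy(phi, base, preds):.2e}")

# With the tutorial's objects this would be:
# check_local_accuracy(shap_values.values, shap_values.base_values, model.predict(X_test))
```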
The Waterfall Plot — Explaining One Prediction
The waterfall plot is the go-to visualization for understanding a single prediction. It shows each feature's SHAP contribution as a bar extending left (negative) or right (positive) from the cumulative running total.
Red bars push the prediction above the base value; blue bars pull it down. The sum of all bars = prediction − base value = 3.4821 − 2.0684 = 1.4137.
```python
# Generate the waterfall plot in Python
import shap
import matplotlib.pyplot as plt

# Explain sample index 0 (displays the plot immediately)
shap.plots.waterfall(shap_values[0])

# Or explain a specific index and save it to a file
sample_idx = 42
# show=False keeps the figure open so it can be saved instead of displayed
shap.plots.waterfall(shap_values[sample_idx], max_display=10, show=False)
plt.tight_layout()
plt.savefig("waterfall_sample42.png", dpi=150, bbox_inches="tight")
plt.close()
```
The grey base value at the bottom is E[f(X)] — the model's average prediction over training data. Each bar extends from the cumulative total so far. Red bars (positive SHAP) push the prediction right toward higher values; blue bars push left toward lower values. Features are sorted by absolute SHAP value — the most impactful features appear at the top.
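The bar layout can be reproduced in a few lines, which makes the "cumulative running total" concrete. This is a minimal text-only sketch that does not call shap. The contributions are hypothetical. The layout follows the description above: the bottom bar (smallest |SHAP|) starts at the base value, and each bar starts where the one below it ended. The top bar therefore ends at the prediction.

```python
def waterfall_rows(base_value, contributions):
    """Build waterfall bars as (feature, shap_value, start, end), listed top to bottom.

    contributions: non-empty dict feature -> SHAP value. Bars are stacked from the
    base value upward in order of increasing |SHAP|, then returned largest first.
    """
    bottom_up = sorted(contributions.items(), key=lambda kv: abs(kv[1]))
    rows, running = [], base_value
    for feature, phi in bottom_up:
        rows.append((feature, phi, running, running + phi))
        running += phi
    return rows[::-1]  # top to bottom, as the plot is read

def print_waterfall(base_value, contributions):
    rows = waterfall_rows(base_value, contributions)
    print(f"{'f(x)':10s}   = {rows[0][3]:.4f}")
    for feature, phi, start, end in rows:
        colour = "red " if phi > 0 else "blue"  # red pushes up, blue pulls down
        print(f"{feature:10s} {colour} {phi:+.4f}  ({start:.4f} -> {end:.4f})")
    print(f"{'E[f(X)]':10s}   = {base_value:.4f}")

# Hypothetical contributions for one house (illustration only)
example = {"MedInc": 0.9, "Latitude": -0.3, "AveOccup": 0.1, "HouseAge": 0.05}
print_waterfall(2.0, example)
```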
The Beeswarm Plot — Global Feature Importance
A single waterfall plot explains one prediction. The beeswarm plot shows SHAP values for all predictions simultaneously, giving global insight into which features matter most and how they affect the model.
Each dot is one sample. Dots are stacked vertically (hence "bee" swarm) when they would overlap. Color = feature value (blue = low, red = high). Features are sorted by mean |SHAP|.
```python
# Beeswarm plot: global view across all test samples
shap.plots.beeswarm(shap_values)

# Bar chart of mean |SHAP|: cleaner global importance
shap.plots.bar(shap_values)

# Legacy summary_plot API (still supported; draws with matplotlib)
shap.summary_plot(shap_values.values, X_test, plot_type="violin")
shap.summary_plot(shap_values.values, X_test, plot_type="bar")
```
Each row is a feature, sorted by mean absolute SHAP value (most important at top). Each dot is a single prediction. The x-axis is the SHAP value — how much this feature shifted this prediction up or down. Color encodes the feature value (red=high, blue=low). Pattern to look for: if the red dots cluster on the right and blue on the left, high feature values increase predictions — a positive relationship.
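Both readings can be computed numerically from a SHAP matrix. The row order is mean |SHAP|. For the "red on the right" pattern, this sketch uses the correlation between feature values and SHAP values as a simple direction hint. That correlation is a heuristic of this example, not something shap computes. The data is synthetic: it mimics the SHAP values of a hypothetical linear model.

```python
import numpy as np

def global_importance(shap_matrix, feature_matrix, feature_names):
    """Rank features by mean |SHAP| (the beeswarm/bar ordering) and add a direction hint.

    Returns (name, mean_abs_shap, corr) tuples, most important first. corr is the
    Pearson correlation between a feature's values and its SHAP values: positive
    means high values tend to push predictions up ("red dots on the right").
    """
    S = np.asarray(shap_matrix, dtype=float)
    X = np.asarray(feature_matrix, dtype=float)
    rows = []
    for j, name in enumerate(feature_names):
        mean_abs = float(np.mean(np.abs(S[:, j])))
        if X[:, j].std() > 0 and S[:, j].std() > 0:
            corr = float(np.corrcoef(X[:, j], S[:, j])[0, 1])
        else:
            corr = 0.0  # constant column: no direction to report
        rows.append((name, mean_abs, corr))
    return sorted(rows, key=lambda r: r[1], reverse=True)

# Synthetic demo: SHAP values of a hypothetical linear model on centred features
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(500, 3))
shap_demo = X_demo * np.array([0.05, -0.3, 0.8])
ranking = global_importance(shap_demo, X_demo, ["x0", "x1", "x2"])
for name, imp, corr in ranking:
    print(f"{name}: mean|SHAP| = {imp:.3f}, value/SHAP correlation = {corr:+.2f}")
```

A correlation near zero with a large mean |SHAP| usually means the effect is non-monotonic or driven by interactions. In that case the dependence plot is the better tool.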
The Force Plot — Push & Pull Visualization
The force plot shows the same information as a waterfall plot, laid out horizontally, a format many stakeholders find intuitive. Features literally push the prediction left or right from the base value.
Figure: force plots for House A (high-income area, large rooms, old building; prediction pushed up) and House B (low-income area, small rooms, new building; prediction pushed down).
```python
# Force plot for a single sample (interactive in Jupyter; call shap.initjs() first)
shap.plots.force(shap_values[0])

# Force plot saved as a standalone HTML file
p = shap.force_plot(
    explainer.expected_value,
    shap_values[0].values,
    X_test.iloc[0],
    feature_names=X_test.columns.tolist(),
)
shap.save_html("force_plot.html", p)

# Stacked force plot: many samples at once, rotated and placed side by side
p_multi = shap.force_plot(
    explainer.expected_value,
    shap_values[::10].values,  # every 10th sample
    X_test.iloc[::10],
)
shap.save_html("force_multi.html", p_multi)
```
The Dependence Plot — How One Feature Affects Predictions
The SHAP dependence plot shows how the SHAP value of one feature changes as that feature's value changes. It can also colour the points by the feature it interacts with most. A Partial Dependence Plot (PDP) shows only the average effect, whereas the dependence plot draws one point per sample. Vertical spread at a given feature value therefore reveals interactions.
Each dot = one sample. X-axis = raw feature value of MedInc. Y-axis = SHAP value for MedInc. The colour (AveRooms) reveals an interaction: small houses get less benefit from high income.
```python
# SHAP dependence plot (new API: scatter)
# Pass the full Explanation as `color` to let SHAP pick the most interacting feature
shap.plots.scatter(shap_values[:, "MedInc"], color=shap_values)

# Or manually specify the colour feature
shap.plots.scatter(shap_values[:, "MedInc"], color=shap_values[:, "AveRooms"])

# Legacy API (still works; picks the interaction feature automatically)
shap.dependence_plot("MedInc", shap_values.values, X_test)

# One dependence plot per feature
for col in X_test.columns:
    shap.plots.scatter(shap_values[:, col], color=shap_values)
```
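Why does colour reveal interactions? A small experiment makes it visible. This sketch computes exact SHAP values for two hypothetical two-feature models. Missing features are filled in from background rows, which assumes independent features. The sketch then measures the dependence-plot slope of x0 separately for high and low x1, the two "colours". For a purely linear model both slopes match. Adding an x0·x1 term makes them diverge, which is the fan-shaped spread you see in real dependence plots.

```python
import numpy as np

def interventional_shap_2d(f, x, background):
    """Exact SHAP values for a 2-feature model f, filling 'missing' features from background rows."""
    def v(known):
        z = background.copy()
        for j in known:
            z[:, j] = x[j]      # fix the known features to this sample's values
        return f(z).mean()      # average over the background for the rest
    phi0 = 0.5 * (v({0}) - v(set())) + 0.5 * (v({0, 1}) - v({1}))
    phi1 = 0.5 * (v({1}) - v(set())) + 0.5 * (v({0, 1}) - v({0}))
    return np.array([phi0, phi1])

def linear_model(X):
    return 1.0 + 2.0 * X[:, 0] - 1.0 * X[:, 1]

def interaction_model(X):
    return linear_model(X) + 1.5 * X[:, 0] * X[:, 1]  # hypothetical x0·x1 interaction

rng = np.random.default_rng(0)
background = rng.normal(size=(200, 2))
samples = rng.normal(size=(300, 2))
high_x1 = samples[:, 1] > 0  # the "colour" split: high vs low x1

results = {}
for name, f in [("linear", linear_model), ("interaction", interaction_model)]:
    phi = np.array([interventional_shap_2d(f, x, background) for x in samples])
    # Slope of the dependence plot (x0 vs SHAP(x0)) within each colour group
    slope_hi = np.polyfit(samples[high_x1, 0], phi[high_x1, 0], 1)[0]
    slope_lo = np.polyfit(samples[~high_x1, 0], phi[~high_x1, 0], 1)[0]
    results[name] = (slope_hi, slope_lo)
    print(f"{name:11s}: slope for high x1 = {slope_hi:+.2f}, for low x1 = {slope_lo:+.2f}")
```

For the linear model, SHAP(x0) is exactly the coefficient times (x0 minus its background mean). That is why its dependence plot is a single straight line.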
SHAP Interaction Values — When Features Combine
Standard SHAP values capture the total effect of a feature including its interactions. SHAP interaction values split this further: how much of a feature's effect is due to its relationship with every other feature?
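For each pair of features (i, j), the SHAP interaction value φᵢⱼ measures how much they jointly affect the prediction beyond their individual effects. The credit for each pair is split equally between φᵢⱼ and φⱼᵢ. The diagonal (φᵢᵢ) contains the "main" effect of each feature with interactions removed, and each row sums to that feature's ordinary SHAP value. In the shap library, interaction values are available only through TreeExplainer. They are also expensive: roughly O(T·L·D²·p) per sample, where T is the number of trees, L the maximum number of leaves, D the maximum depth, and p the number of features.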
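The pairwise definition can be checked by brute force on a tiny example. This is a minimal sketch of the Shapley interaction index that SHAP interaction values are based on. It uses a hypothetical three-feature value function with main effects for features 0 and 2 and one 0–1 interaction. TreeExplainer computes these values efficiently from the tree structure instead of enumerating subsets.

```python
from itertools import combinations
from math import factorial

def value(S):
    """Hypothetical value function: main effects for features 0 and 2, plus a 0-1 interaction."""
    return 0.5 * (0 in S) + 0.2 * (2 in S) + 0.4 * (0 in S and 1 in S)

def shapley_values(n, v):
    """Exact Shapley values for features 0..n-1."""
    phi = [0.0] * n
    for i in range(n):
        others = [p for p in range(n) if p != i]
        for size in range(n):
            w = factorial(size) * factorial(n - size - 1) / factorial(n)
            for S in combinations(others, size):
                S = frozenset(S)
                phi[i] += w * (v(S | {i}) - v(S))
    return phi

def shap_interaction_matrix(n, v):
    """Pairwise interaction values, split evenly between (i, j) and (j, i); diagonal = main effects."""
    phi = shapley_values(n, v)
    M = [[0.0] * n for _ in range(n)]
    for i, j in combinations(range(n), 2):
        others = [p for p in range(n) if p not in (i, j)]
        total = 0.0
        for size in range(n - 1):
            w = factorial(size) * factorial(n - size - 2) / (2 * factorial(n - 1))
            for S in combinations(others, size):
                S = frozenset(S)
                # joint effect of adding i and j together, beyond adding each alone
                total += w * (v(S | {i, j}) - v(S | {i}) - v(S | {j}) + v(S))
        M[i][j] = M[j][i] = total
    for i in range(n):
        M[i][i] = phi[i] - sum(M[i][j] for j in range(n) if j != i)  # main effect
    return phi, M

phi, inter = shap_interaction_matrix(3, value)
for row in inter:
    print("  ".join(f"{x:+.2f}" for x in row))
```

The 0.4 interaction bonus shows up as 0.2 in each of the two off-diagonal cells. Each row sums to the feature's SHAP value, and the whole matrix sums to the full prediction gap of 1.1.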
```python
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# SHAP interaction values (TreeExplainer only)
# Returns an [n_samples, n_features, n_features] array
shap_interaction_values = explainer.shap_interaction_values(X_test.iloc[:200])
print(f"Shape: {shap_interaction_values.shape}")
# → Shape: (200, 8, 8)

# Interaction between MedInc and Latitude for sample 0
feat_idx = {name: i for i, name in enumerate(X_test.columns)}
i_medinc = feat_idx["MedInc"]
i_lat = feat_idx["Latitude"]
phi_ij = shap_interaction_values[0, i_medinc, i_lat]
print(f"MedInc × Latitude interaction (sample 0): {phi_ij:.4f}")

# Visualise mean absolute interaction strength as a heatmap
mean_abs_int = np.abs(shap_interaction_values).mean(axis=0)
labels = X_test.columns.tolist()
plt.figure(figsize=(8, 6))
sns.heatmap(mean_abs_int, xticklabels=labels, yticklabels=labels,
            annot=True, fmt=".3f", cmap="YlOrRd")
plt.title("Mean Absolute SHAP Interaction Values")
plt.tight_layout()
plt.show()
```
KernelSHAP — Explaining Any Model
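KernelSHAP is model-agnostic: it only needs a function f(x) → prediction and some background data. It estimates Shapley values with a specially weighted linear regression on perturbed inputs, filling in "missing" features from the background data. It is slower, typically needing hundreds to thousands of model evaluations per explained sample. As the number of samples grows, it converges to the exact Shapley values under that background-data definition of a missing feature. These match TreeSHAP's values only when TreeSHAP uses the same definition, that is, its interventional mode with the same background data.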
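The weighted-regression idea can be verified on a toy problem. This sketch makes two simplifying assumptions. First, a hypothetical four-feature value function stands in for "model output with missing features filled from background data". Second, all coalitions are enumerated instead of sampled, so the regression should return the exact Shapley values; real KernelSHAP samples coalitions. The weight is the Shapley kernel, (M − 1) / (C(M, s) · s · (M − s)). The empty and full coalitions are imposed as constraints.

```python
import itertools
from math import comb

import numpy as np

M = 4  # number of features in this toy example

def value(S):
    """Hypothetical coalition value: model output when only features in S are known."""
    S = set(S)
    out = 1.0 + 0.5 * (0 in S) - 0.3 * (2 in S) + 0.2 * (3 in S)
    out += 0.8 * ({0, 1} <= S)     # pairwise interaction
    out -= 0.4 * ({1, 2, 3} <= S)  # three-way interaction
    return out

def shapley_kernel_weight(m, s):
    """Shapley kernel weight for a coalition of size s out of m features (0 < s < m)."""
    return (m - 1) / (comb(m, s) * s * (m - s))

def kernel_shap_all_coalitions(value, m):
    """Weighted least squares over every non-trivial coalition (needs m >= 2).

    The empty and full coalitions have infinite kernel weight, so they become
    constraints: phi_0 = v(empty) and sum(phi) = v(full) - v(empty). The sum
    constraint is enforced by eliminating the last coefficient.
    """
    v_empty = value(())
    gap = value(range(m)) - v_empty
    rows, targets, weights = [], [], []
    for s in range(1, m):
        for S in itertools.combinations(range(m), s):
            z = np.zeros(m)
            z[list(S)] = 1.0
            rows.append(z[:-1] - z[-1])
            targets.append(value(S) - v_empty - z[-1] * gap)
            weights.append(shapley_kernel_weight(m, s))
    sw = np.sqrt(np.array(weights))
    A = np.array(rows) * sw[:, None]
    b = np.array(targets) * sw
    head, *_ = np.linalg.lstsq(A, b, rcond=None)
    return np.append(head, gap - head.sum())

def exact_shapley(value, m):
    """Reference answer: average marginal contributions over all m! orderings."""
    phi = np.zeros(m)
    orders = list(itertools.permutations(range(m)))
    for order in orders:
        S = []
        for j in order:
            before = value(S)
            S.append(j)
            phi[j] += value(S) - before
    return phi / len(orders)

phi_kernel = kernel_shap_all_coalitions(value, M)
phi_exact = exact_shapley(value, M)
print("KernelSHAP regression:", np.round(phi_kernel, 4))
print("Exact Shapley values: ", np.round(phi_exact, 4))
```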
```python
import shap
from sklearn.ensemble import RandomForestRegressor

# ── Example with Random Forest ──────────────────────────────
# (A tree model is used only so you can compare against TreeSHAP;
#  KernelSHAP itself never looks inside the model.)
rf = RandomForestRegressor(n_estimators=100, random_state=42)
rf.fit(X_train, y_train)

# Summarise background data (KernelSHAP needs a reference)
# Using k-means keeps it tractable. 50–200 samples is typical.
background = shap.kmeans(X_train, 50)
explainer_kernel = shap.KernelExplainer(rf.predict, background)

# Explain 50 test samples (slow; run fewer for debugging)
# nsamples controls the accuracy vs speed tradeoff
shap_vals_kernel = explainer_kernel.shap_values(
    X_test.iloc[:50],
    nsamples=200,  # model re-evaluations (coalition samples) per explained row
    silent=True,
)
print(f"KernelSHAP shape: {shap_vals_kernel.shape}")
print(f"Sum check: {shap_vals_kernel[0].sum() + explainer_kernel.expected_value:.4f} "
      f"vs pred {rf.predict(X_test.iloc[[0]])[0]:.4f}")

# KernelSHAP works with ANY callable:
#   - API endpoints
#   - scikit-learn pipelines
#   - SVMs, KNNs, custom models
#   - Spark ML models (via UDF wrapper)
```
KernelSHAP can be 100–1000× slower than TreeSHAP on the same dataset. For 1,000 test samples with 50 features and nsamples=200, expect minutes, not seconds. Always profile with a small batch first. For production use, consider caching explanations or switching to a tree-based surrogate model that you can then explain with TreeSHAP. The explanations then describe the surrogate, so check how closely it tracks the original model.
DeepSHAP — Explaining Neural Networks
```python
import shap
import torch
import torch.nn as nn
import numpy as np

# ── Simple neural network example ──────────────────────────
class HousePriceNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(8, 64), nn.ReLU(),
            nn.Linear(64, 32), nn.ReLU(),
            nn.Linear(32, 1),
        )

    def forward(self, x):
        return self.net(x)

model_nn = HousePriceNet()  # train it first in real use; left untrained here for brevity
model_nn.eval()

# Convert data to tensors
X_train_t = torch.FloatTensor(X_train.values)
X_test_t = torch.FloatTensor(X_test.values)

# ── Create DeepSHAP explainer ───────────────────────────────
# Background: a random subset of 100 training rows
background = X_train_t[torch.randperm(len(X_train_t))[:100]]
explainer_deep = shap.DeepExplainer(model_nn, background)

# Explain test samples
shap_vals_deep = explainer_deep.shap_values(X_test_t[:50])
# The return type varies across shap versions (an array, or a list with one array per output)
print(f"DeepSHAP values shape: {np.array(shap_vals_deep).shape}")

# Works the same way for Keras/TensorFlow:
# explainer_tf = shap.DeepExplainer(keras_model, x_background)
# shap_vals_tf = explainer_tf.shap_values(x_test[:50])
```
```python
# ── SHAP for NLP: text classification ──────────────────────
import numpy as np
import shap
from transformers import pipeline

# Wrap a HuggingFace pipeline as a SHAP-compatible function
classifier = pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english")

def predict_proba(texts):
    # SHAP passes a batch of (partially masked) strings; return [P(neg), P(pos)] per text
    results = classifier(list(texts), truncation=True)
    pos = np.array([r["score"] if r["label"] == "POSITIVE" else 1 - r["score"] for r in results])
    return np.column_stack([1 - pos, pos])  # shape (n_texts, 2)

explainer_nlp = shap.Explainer(predict_proba, masker=shap.maskers.Text(),
                               output_names=["NEGATIVE", "POSITIVE"])
text = "The product quality is excellent but shipping was terrible"
shap_vals_nlp = explainer_nlp([text])

# Visualise token-level attributions for the POSITIVE class
shap.plots.text(shap_vals_nlp[:, :, "POSITIVE"])
```
SHAP vs LIME — Complete Comparison
LIME (Local Interpretable Model-agnostic Explanations) and SHAP are often mentioned together. Both explain individual predictions of black-box models. Here is a comprehensive comparison:
| Property | SHAP | LIME |
|---|---|---|
| Foundation | Cooperative game theory (Shapley, 1953). Mathematically proven unique. | Local linear regression with kernel weighting. No uniqueness guarantee. |
| Consistency | Consistent — if model changes to make feature more important, its SHAP value never decreases. | Not guaranteed — can assign lower importance to a feature that became more impactful. |
| Local Accuracy | Guaranteed — sum of SHAP values always = prediction − base value. | Approximate — surrogate model is fitted locally, residual can be non-zero. |
| Stability | Deterministic for TreeSHAP and LinearSHAP. KernelSHAP is sampling-based, so its results vary slightly between runs unless the seed is fixed. | Stochastic: results change every run due to random perturbations. |
| Global view | Beeswarm, bar, dependence plots aggregate local to global naturally. | Designed for local explanations only. Global view is awkward. |
| Speed (Tree models) | ⚡ Very fast (TreeSHAP polynomial) | ⏰ Slower (many model calls) |
| Model-agnostic | Yes (KernelSHAP) | Yes (always) |
| Handles correlations | Handles well for independent features; can be misleading with high correlations. | Poorly — correlated features can receive arbitrary splits of joint effect. |
| Text / Image support | Yes — PartitionSHAP, TextMasker, ImageMasker | Yes — superpixels for images, word removal for text |
| Production readiness | Widely used, well-maintained, many integrations | Less maintained, less tooling |
| Best when | You want the gold standard. Regulatory compliance. Research. | Quick prototype. Very custom feature encoding needed. |
In most cases, prefer SHAP. It is more principled, more stable, faster for tree models, and provides both local and global views. LIME predates SHAP: LIME was published in 2016, SHAP in 2017. It remains a reasonable choice for quick prototypes or when you need a very custom interpretable representation, such as special tokenisation for a text classifier.
Real-World Applications — Where SHAP Runs in Production
"Your application was declined primarily because of: (1) high existing debt-to-income ratio [+2.3 risk score], (2) recent payment history [+1.8 risk score], (3) short credit history [+1.1 risk score]."
The bank passed the EU AI Act audit. The model team could also now spot when the model was using features it shouldn't (e.g., postcode as a proxy for ethnicity) — and fix it.
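A decline letter like the one quoted above can be generated directly from SHAP values. This is a minimal sketch under stated assumptions. The helper decline_reasons is hypothetical. Each feature has already been mapped to a plain-language reason. A positive SHAP value is assumed to increase the risk score, so only positive contributions count as reasons. The three contributions come from the example letter, and the "stable employment" entry is invented.

```python
def decline_reasons(contributions, top_n=3, unit="risk score"):
    """Format the top risk-increasing SHAP contributions as a plain-language decline message.

    contributions: dict of reason text -> SHAP value, where a positive value is
    assumed to increase the applicant's risk score.
    """
    risk_up = sorted(
        ((reason, phi) for reason, phi in contributions.items() if phi > 0),
        key=lambda item: item[1],
        reverse=True,
    )[:top_n]
    parts = [f"({k}) {reason} [+{phi:.1f} {unit}]"
             for k, (reason, phi) in enumerate(risk_up, start=1)]
    return "Your application was declined primarily because of: " + ", ".join(parts) + "."

# Contributions from the example letter, plus one hypothetical risk-reducing factor
example = {
    "short credit history": 1.1,
    "high existing debt-to-income ratio": 2.3,
    "stable employment": -0.7,  # hypothetical; never listed as a decline reason
    "recent payment history": 1.8,
}
print(decline_reasons(example))
```

Unlike a ranking by absolute value, this keeps only the factors that pushed toward the decline. That is usually what an applicant-facing explanation needs.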
| Industry | Use Case | What SHAP Provides | Explainer |
|---|---|---|---|
| 🏦 Finance | Credit scoring, fraud detection | Regulatory explanations (GDPR Art.22, EU AI Act) | TreeSHAP |
| 🏥 Healthcare | Readmission risk, diagnosis support | Doctor-facing feature attributions. Model auditing. | TreeSHAP / DeepSHAP |
| 💲 Insurance | Premium pricing, claim assessment | Why this premium? Fairness auditing across demographics. | TreeSHAP |
| 🎓 EdTech | Student dropout prediction | Early intervention — which student needs help and why. | KernelSHAP |
| 🛒 E-commerce | Recommendation engines | "Recommended because you liked X and bought Y" | KernelSHAP / PartitionSHAP |
| 🏭 Manufacturing | Predictive maintenance | Which sensor reading is driving the failure alert? | TreeSHAP |
| 📋 HR / Recruiting | Attrition prediction | Why this employee is flagged + bias detection. | TreeSHAP |
SHAP in a Production Pipeline
```python
# ── Complete production-ready SHAP pipeline ─────────────────
import json

import pandas as pd
import shap


class SHAPProductionExplainer:
    """Wraps a trained model + TreeExplainer for production serving."""

    def __init__(self, model, feature_names, top_n=3):
        self.model = model
        self.feature_names = feature_names
        self.top_n = top_n
        self.explainer = shap.TreeExplainer(model)  # build once, reuse for every request

    def explain_single(self, x: pd.DataFrame) -> dict:
        """Returns prediction + top N SHAP attributions for one row."""
        prediction = self.model.predict(x)[0]
        explanation = self.explainer(x)
        base_value = explanation.base_values[0]

        # Sort features by absolute SHAP value, largest first
        pairs = sorted(
            zip(self.feature_names, explanation[0].values),
            key=lambda pair: abs(pair[1]),
            reverse=True,
        )
        return {
            "prediction": round(float(prediction), 4),
            "base_value": round(float(base_value), 4),
            "top_reasons": [
                {"feature": feat, "shap_value": round(float(val), 4),
                 "direction": "increases" if val > 0 else "decreases"}
                for feat, val in pairs[: self.top_n]
            ],
        }


# ── Usage (e.g. called from a FastAPI or other web endpoint handler) ──
explainer_service = SHAPProductionExplainer(
    model=model,
    feature_names=X_test.columns.tolist(),
    top_n=3,
)
result = explainer_service.explain_single(X_test.iloc[[0]])
print(json.dumps(result, indent=2))
```
Common Pitfalls
Using the wrong explainer. Running shap.KernelExplainer on an XGBoost model "just to be safe" gives slower, approximate values, while shap.TreeExplainer gives exact values in a fraction of the time. Always use the structure-aware explainer when one is available.
Expecting KernelSHAP to be deterministic. KernelSHAP samples coalitions, so repeated calls to shap.KernelExplainer.shap_values(x) may return slightly different values. For reproducible results, fix the random seed (for example with np.random.seed) before each call.
Mixing the legacy array API and the new Explanation API. shap.TreeExplainer(model)(X) returns an Explanation object, whose raw array is available as shap_values.values. shap.TreeExplainer(model).shap_values(X) returns a plain ndarray. Mixing old and new API calls on the same data leads to shape mismatches. Pick one.
Golden Rules
Always verify local accuracy: assert abs((shap_values[i].values.sum() + base_value) - prediction) < 1e-4. If it fails, you are likely using the wrong explainer or mixing APIs. Fix it before proceeding.
Use shap.TreeExplainer for any tree-based model. It is exact, fast, and supports interaction values. For any other model, use the appropriate specialist: LinearExplainer for linear models, DeepExplainer for neural nets, and KernelExplainer for everything else.
SHAP Cheat Sheet
| Task | Code | Notes |
|---|---|---|
| Explain XGBoost / LightGBM | shap.TreeExplainer(model)(X) | Fastest. Exact. |
| Explain any sklearn model | shap.KernelExplainer(model.predict, bg) | Slow. Use <200 bg samples. |
| Explain linear model | shap.LinearExplainer(model, X_train) | Fast. Exact; assumes independent features by default. |
| Explain PyTorch / Keras | shap.DeepExplainer(model, background) | Approximate. Fast for NNs. |
| Single prediction plot | shap.plots.waterfall(sv[i]) | Shows each feature's contribution. |
| Global importance | shap.plots.beeswarm(sv) | All samples, all features. |
| Push-pull view | shap.plots.force(sv[i]) | Great for stakeholder demos. |
| Feature interaction | shap.plots.scatter(sv[:, "feat"], color=sv) | color=sv picks the most interacting feature. |
| Interaction values | explainer.shap_interaction_values(X) | TreeExplainer only. Returns [n, p, p] array. |
| Summarise background | shap.kmeans(X_train, 50) | Speeds up KernelSHAP. |
| Save HTML report | shap.save_html("out.html", plot) | For force_plot interactive widget. |
You understand where SHAP comes from (Shapley's 1953 game theory), why it is uniquely fair (four axioms), which explainer to use for which model, all five major plot types, how to deploy SHAP in production, and the mistakes to avoid. The next step is to open a notebook and run SHAP on a model you already have. The moment you see your first beeswarm plot, you will never want to ship a model without SHAP again.