Why Do We Need to Explain LLMs?
The attending physician stares at the screen. The model is right 94% of the time on test data. But why did it flag this patient? Was it the "chest pain" token? The combination of all three symptoms? Or did it latch onto something unrelated — perhaps the patient's age mentioned earlier in the note?
Without an explanation, the physician can't verify the reasoning, can't catch a failure mode, and can't defend the decision in court. The model is a black box. This is precisely the problem that Explainable AI (XAI) for LLMs tries to solve.
Large Language Models (LLMs) like GPT-4, Claude, and LLaMA achieve remarkable performance across tasks — yet their internal computations involve billions of parameters interacting through attention mechanisms, feed-forward networks, and residual streams that no human can read directly. XAI for LLMs is the discipline of building tools and methods that make these opaque systems interpretable, transparent, and auditable.
Interpretability asks: what does the model actually do internally? Explainability asks: can I generate a human-readable justification for an output? Interpretability is a property of the model; explainability is a property of the explanation. XAI typically pursues explainability using interpretability tools as its engine.
Token Attributions — Tracing Credit Back to Input
When an LLM generates a token, every input token contributed some amount — positive or negative — to that prediction. Token attribution methods assign a numerical score to each input token reflecting how much it influenced the output.
The Main Methods
Visualising Token Attribution — Sentiment Classification
[Figure: animated diagram of Integrated Gradients attributions for a sentiment prediction. Green tokens pushed the prediction toward Positive, red tokens toward Negative; token opacity encodes magnitude.]
Many practitioners use attention weights as attributions because they are easy to extract. But Jain & Wallace (2019) found that attention weights often correlate only weakly with gradient-based importance scores, and that permuting or adversarially replacing the attention distribution frequently leaves the prediction nearly unchanged. Attention tells you where the model looks, not what changes the prediction. Always prefer gradient-based or Shapley methods when faithfulness matters.
Integrated Gradients — Python Code
```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# ── Load a pre-trained sentiment model ──────────────────────
model_name = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.eval()

text = "The film was absolutely brilliant and deeply moving."
inputs = tokenizer(text, return_tensors="pt")
input_ids = inputs["input_ids"]

# ── Integrated Gradients over token embeddings ──────────────
def integrated_gradients(input_ids, target_class=1, steps=50):
    """IG w.r.t. the word embeddings, summed over the hidden dim (one score per token).

    target_class=1 is POSITIVE for this SST-2 model. The baseline is an all-zero
    word embedding; the model still adds position embeddings internally.
    """
    with torch.no_grad():
        embeddings = model.get_input_embeddings()(input_ids)  # [1, seq, dim]
    baseline = torch.zeros_like(embeddings)

    grads = []
    for i in range(1, steps + 1):  # right Riemann sum along the straight-line path
        alpha = i / steps
        # Fresh leaf tensor at each step so autograd can return its gradient
        point = (baseline + alpha * (embeddings - baseline)).requires_grad_(True)
        logit = model(inputs_embeds=point,
                      attention_mask=inputs["attention_mask"]).logits[0, target_class]
        (grad,) = torch.autograd.grad(logit, point)
        grads.append(grad)

    avg_grads = torch.stack(grads).mean(dim=0)                    # [1, seq, dim]
    ig_attrs = ((embeddings - baseline) * avg_grads).sum(dim=-1)  # [1, seq]
    return ig_attrs[0].detach().numpy()

attrs = integrated_gradients(input_ids)
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
print(f"{'Token':<15} {'Attribution':>12}")
print("-" * 28)
for tok, score in zip(tokens, attrs):
    bar = "█" * int(abs(score) * 20)
    sign = "+" if score > 0 else "-"
    print(f"{tok:<15} {sign}{abs(score):.4f} {bar}")
```
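The code above needs a model download. To see what the Riemann sum is doing, here is a self-contained sketch that applies the same computation to a hypothetical toy model, f(x) = sigmoid(w · x), whose gradient can be written by hand. The weights and inputs are invented for illustration; the point is the completeness property, under which the attributions should add up to f(x) - f(baseline), up to discretisation error.

```python
import math

# Hypothetical toy "model": f(x) = sigmoid(w . x), one weight per token feature.
WEIGHTS = [1.5, -2.0, 0.5]

def f(x):
    z = sum(w * xi for w, xi in zip(WEIGHTS, x))
    return 1.0 / (1.0 + math.exp(-z))

def grad_f(x):
    # d sigmoid(z) / d x_i = sigmoid(z) * (1 - sigmoid(z)) * w_i
    s = f(x)
    return [s * (1.0 - s) * w for w in WEIGHTS]

def integrated_gradients(x, baseline, steps=200):
    """Riemann-sum IG: (x_i - x'_i) times the mean gradient along the straight path."""
    avg = [0.0] * len(x)
    for k in range(1, steps + 1):
        alpha = k / steps
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(grad_f(point)):
            avg[i] += g / steps
    return [(xi - b) * g for xi, b, g in zip(x, baseline, avg)]

x = [1.0, 0.8, 2.0]
baseline = [0.0, 0.0, 0.0]
attrs = integrated_gradients(x, baseline)
print("attributions:      ", [round(a, 4) for a in attrs])
print("sum of attributions:", round(sum(attrs), 4))
print("f(x) - f(baseline): ", round(f(x) - f(baseline), 4))
```

With 200 steps the two printed totals should agree to about three decimal places. Moving the baseline away from zero changes the individual attributions, but their sum still tracks the new output difference.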
SHAP for LLMs — Token-Level Shapley Values
SHAP (SHapley Additive exPlanations) originates from cooperative game theory. The Shapley value divides the difference between the model's output and a baseline output (for text, typically the output with tokens masked) fairly among the features, here the tokens, by averaging each token's marginal contribution over every possible coalition of the other tokens. Each token is a "player" and the output score is the "prize to share."
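The coalition averaging can be made concrete with a brute-force sketch. The value function below is a hypothetical hand-written scorer over three tokens, where a token missing from a coalition is treated as masked. The shap library does not enumerate like this, since exact enumeration is exponential in the number of tokens; the sketch only shows what its approximations estimate.

```python
from itertools import combinations
from math import factorial

TOKENS = ["not", "terrible", "film"]

def value(coalition):
    """Hypothetical sentiment score when only the tokens in `coalition` are visible."""
    present = set(coalition)
    if "terrible" in present:
        return 1.0 if "not" in present else -2.0  # "not terrible" flips the sign
    return 0.0

def shapley_values(tokens, v):
    """Exact Shapley values by enumerating every coalition of the other tokens."""
    n = len(tokens)
    phi = {}
    for t in tokens:
        others = [o for o in tokens if o != t]
        total = 0.0
        for size in range(n):
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                total += weight * (v(subset + (t,)) - v(subset))
        phi[t] = total
    return phi

phi = shapley_values(TOKENS, value)
for tok, val in phi.items():
    print(f"{tok:10s} {val:+.3f}")
print("sum:", round(sum(phi.values()), 3),
      "| v(all) - v(empty):", value(TOKENS) - value(()))
```

Here "not" ends up with +1.5 and "terrible" with -0.5, even though "not" on its own changes nothing: all of its credit comes from coalitions that also contain "terrible". The values sum to v(all) - v(empty), the efficiency property.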
SHAP for Text Classification — Practical Code
```python
import shap
import torch
import transformers

# ── Pipeline wrapper ──────────────────────────────────────────
model_name = "distilbert-base-uncased-finetuned-sst-2-english"
pipe = transformers.pipeline(
    "text-classification",
    model=model_name,
    top_k=None,  # scores for every class (older transformers: return_all_scores=True)
    device=0 if torch.cuda.is_available() else -1,
)

# ── SHAP Text Explainer ───────────────────────────────────────
explainer = shap.Explainer(pipe)
samples = [
    "The movie was absolutely brilliant and deeply moving.",
    "Terrible plot, bad acting — a complete waste of time.",
    "Despite some flaws, it was not terrible at all.",
]
shap_values = explainer(samples)

# ── Visualise (in Jupyter / HTML export) ─────────────────────
shap.plots.text(shap_values[0])       # coloured token heatmap
shap.plots.bar(shap_values[0, :, 1])  # bar chart for the POSITIVE class (index 1)

# ── Programmatic access ───────────────────────────────────────
for i, sample in enumerate(samples):
    values = shap_values[i, :, 1].values  # POSITIVE-class attributions
    tokens = shap_values[i].data
    print(f"\nSample {i + 1}: {sample}")
    for tok, val in sorted(zip(tokens, values), key=lambda x: -abs(x[1])):
        if tok.strip() in ("", "[CLS]", "[SEP]"):  # skip special/empty tokens
            continue
        direction = "▲ POS" if val > 0 else "▼ NEG"
        print(f"  {tok:15s} {val:+.4f}  {direction}")
```
Notice in Sample 3: "not" has a positive attribution and "terrible" a negative one, so the model has learned that "not terrible" is a positive signal. Shapley values can capture this interaction because each token's credit is averaged over coalitions, including ones where "not" appears with and without "terrible."
How Tokens Flow Through a Transformer — Visual Walkthrough
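Before we can attribute meaning to tokens, it helps to follow their journey through the transformer. Each token's representation is updated at every layer and, through multi-head attention, mixes in information from other tokens: every position in an encoder such as BERT, only earlier positions in a causal decoder such as GPT.
Each token position has a residual stream: a vector to which every attention and MLP sublayer adds its output, so it accumulates information layer by layer. IG and SHAP skip over this internal state and attribute the output directly to the input tokens (IG by integrating gradients that flow back through it). Mechanistic interpretability work, such as Anthropic's research on superposition and dictionary learning, instead tries to read the residual stream directly by decomposing it into interpretable features.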
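Because every sublayer adds its output into the residual stream, the final vector is the embedding plus a sum of per-sublayer contributions. The sketch below uses made-up 2-D vectors and a linear readout (a hypothetical stand-in for one row of the unembedding) to show how this additivity lets you split a logit into per-layer pieces, the idea behind direct logit attribution. Real models apply a final LayerNorm before the readout and compute each sublayer output from the stream so far; this toy ignores both.

```python
# Hypothetical toy residual stream for one token position (2-D for readability).
embedding = [0.5, -0.2]
layer_outputs = {  # what each sublayer writes into the stream (fixed numbers here)
    "L0.attn": [0.1, 0.4],
    "L0.mlp": [-0.3, 0.2],
    "L1.attn": [0.6, 0.0],
    "L1.mlp": [0.2, -0.5],
}
readout = [1.0, 2.0]  # hypothetical unembedding row for one vocabulary token

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

# Forward pass: every sublayer adds into the running residual stream.
resid = list(embedding)
for out in layer_outputs.values():
    resid = [r + o for r, o in zip(resid, out)]
final_logit = dot(resid, readout)

# Because the readout is linear, the logit splits into per-component terms.
contributions = {"embed": dot(embedding, readout)}
contributions.update({name: dot(out, readout) for name, out in layer_outputs.items()})

for name, c in contributions.items():
    print(f"{name:8s} {c:+.2f}")
print(f"total    {sum(contributions.values()):+.2f}  (final logit {final_logit:+.2f})")
```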
Chain-of-Thought Prompting as Explanation
Chain-of-Thought (CoT) prompting asks an LLM to "show its work" before answering. Instead of jumping directly to an answer, the model writes out a sequence of intermediate reasoning steps. This makes the model's stated reasoning visible, checkable, and correctable, a form of self-explanation.
Standard Prompting vs. Zero-Shot CoT
Standard prompt (direct answer):

| Role | Content |
|---|---|
| User | If a train travels at 60 mph for 2.5 hours, then slows to 40 mph for 1 hour, what is the total distance? |
| Model | 190 miles |
| XAI Verdict | No reasoning visible. The answer is correct but cannot be verified. |

Zero-shot CoT prompt (the same question followed by a reasoning trigger):

| Role | Content |
|---|---|
| User | …Let's think step by step. |
| Model | Step 1: Distance₁ = 60 × 2.5 = 150 miles. Step 2: Distance₂ = 40 × 1 = 40 miles. Step 3: Total = 150 + 40 = 190 miles. |
| XAI Verdict | Each step is auditable, and errors can be localised. |
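The tables above contrast a direct prompt with zero-shot CoT, which only appends a trigger phrase. Few-shot CoT instead prepends worked examples whose answers include the reasoning. Below is a minimal sketch of both prompt builders; the function names and the cyclist exemplar are hypothetical, and the resulting string can be sent to any chat or completion API.

```python
# Hypothetical worked exemplar for few-shot CoT: (question, reasoning, answer).
EXEMPLARS = [
    (
        "A cyclist rides at 12 mph for 2 hours, then at 8 mph for 0.5 hours. Total distance?",
        "Step 1: 12 × 2 = 24 miles. Step 2: 8 × 0.5 = 4 miles. Step 3: 24 + 4 = 28 miles.",
        "28 miles",
    ),
]

def zero_shot_cot(question: str) -> str:
    """Zero-shot CoT: append a reasoning trigger, no examples."""
    return f"Q: {question}\nA: Let's think step by step."

def few_shot_cot(question: str, exemplars=EXEMPLARS) -> str:
    """Few-shot CoT: show worked reasoning before posing the new question."""
    blocks = [f"Q: {q}\nA: {steps} The answer is {ans}." for q, steps, ans in exemplars]
    blocks.append(f"Q: {question}\nA:")
    return "\n\n".join(blocks)

question = ("If a train travels at 60 mph for 2.5 hours, then slows to 40 mph "
            "for 1 hour, what is the total distance?")
print(zero_shot_cot(question))
print("-" * 40)
print(few_shot_cot(question))
```

The few-shot version costs more prompt tokens but fixes the format of the reasoning, which makes individual steps easier to audit and compare.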
Faithful vs. Unfaithful CoT — A Critical Distinction
A critical question in XAI is: does the chain-of-thought actually cause the answer, or is it a post-hoc rationalisation? Research by Turpin et al. (2023) showed that CoT can be unfaithful — the model produces a persuasive-looking reasoning chain, yet the actual computation driving the answer is different.
Testing CoT Faithfulness — Intervention Method
```python
import re
import openai

client = openai.OpenAI()  # or use Anthropic / a local LLM

def cot_with_intervention(question, wrong_intermediate, model="gpt-4"):
    """
    Tests CoT faithfulness by injecting a wrong intermediate step and checking
    whether the final answer follows it (the answer depends on the stated chain)
    or ignores it (the answer was reached some other way).
    """
    # Faithful: continues from the injected wrong step -> answer changes
    # Unfaithful: silently ignores the step -> returns its pre-computed answer
    prompt = f"""Solve step by step.
{question}
Let me start: Step 1: {wrong_intermediate}
Continue from Step 2 onward:"""
    resp = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=0,
    )
    continuation = resp.choices[0].message.content
    # Crude answer extraction: take the last number in the continuation
    nums = re.findall(r"\b\d+(?:\.\d+)?\b", continuation)
    final = nums[-1] if nums else "?"
    return continuation, final

question = "Train travels 60 mph for 2.5h, then 40 mph for 1h. Total distance?"
# Deliberately wrong (correct value is 150); not labelled as wrong in the prompt
wrong_step = "Distance₁ = 60 × 2.5 = 120 miles"

chain, answer = cot_with_intervention(question, wrong_step)
print("Chain continuation:\n", chain)
print("\nFinal answer extracted:", answer)
print("Expected if the answer follows the chain: 160 (120 + 40)")
print("190 with no explicit correction of Step 1 -> UNFAITHFUL (ignored your step)")
```
CoT provides a textual explanation at the prompt level. Token attribution methods like IG and SHAP instead measure how the model's output responds to its inputs (IG through internal gradients, SHAP through input perturbations). Both are needed: CoT tells you what reasoning the model claims to follow; attribution methods tell you which inputs the output actually depends on. In high-stakes settings, verify CoT faithfulness with causal interventions such as the one above.
Attention Maps — Visualising What the Model "Looks At"
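Every transformer layer has multiple attention heads. For each token, a head computes a weighted average of value vectors over the positions that token can attend to (all positions in an encoder, only earlier positions in a causal decoder). Visualising these weights as a heatmap is one of the most popular (and most abused) interpretability tools.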
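The weights such a heatmap shows come from a softmax over scaled query-key dot products. Here is a self-contained sketch for a single head, with made-up 2-D query and key vectors and a causal mask as in GPT-style decoders; each printed row is one token's attention distribution and sums to 1.

```python
import math

tokens = ["The", "cat", "sat"]
# Hypothetical query and key vectors for one head (d_k = 2).
Q = [[1.0, 0.0], [0.5, 1.0], [0.0, 2.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
d_k = len(Q[0])

def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

weights = []
for i, q in enumerate(Q):
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K]
    # Causal mask: token i may only attend to positions 0..i.
    masked = [s if j <= i else float("-inf") for j, s in enumerate(scores)]
    weights.append(softmax(masked))

print(" " * 6 + "".join(f"{t:>8s}" for t in tokens))
for t, row in zip(tokens, weights):
    print(f"{t:<6s}" + "".join(f"{w:8.3f}" for w in row))
```

These numbers only say how value vectors are mixed; as discussed earlier for Jain & Wallace (2019), they do not by themselves show what changes the prediction.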
XAI Taxonomy for LLMs — Which Method for Which Goal?
No single XAI method dominates all settings. The choice depends on your access to model internals, the granularity of explanation needed, and whether you prioritise faithfulness or human readability.
| Method | Type | Access Required | Faithfulness | Human Readability | Speed | Best For |
|---|---|---|---|---|---|---|
| Integrated Gradients | Gradient | White-box | High | Medium | Fast | Open-source LLMs, token-level attribution |
| SHAP (Text) | Shapley | Black-box / API | Very High | High | Slow | Production APIs, stakeholder reports |
| LIME | Perturbation | Black-box / API | Medium | High | Medium | Quick local explanations, non-technical audiences |
| Attention Attribution | Attention | White-box | Low–Medium | Very High | Very Fast | Debugging, visualisation dashboards |
| Chain-of-Thought | Generative | Prompt-level | Variable | Very High | Fast | End-user explanations, reasoning traces |
| Mechanistic Interp. | Circuit | Full internals | Very High | Low | Very Slow | Research, safety auditing, capability probing |
| Probing Classifiers | Linear Probe | Activations | Medium | Medium | Fast | Checking what concepts are represented in layers |
Mechanistic Interpretability — Reading the Circuit
Mechanistic interpretability goes beyond attributing tokens — it tries to reverse-engineer the actual computational circuits (sub-graphs of attention heads and MLP neurons) that implement a specific behaviour.
Activation Patching — Python Sketch
```python
import torch
from transformer_lens import HookedTransformer, patching

torch.set_grad_enabled(False)  # inference only

# Load a small model (TransformerLens re-implements GPT-2 with HuggingFace weights)
model = HookedTransformer.from_pretrained("gpt2-small")

# Clean IOI prompt (answer " Mary") and corrupted prompt with the subject swapped
# (answer " John"). Both tokenise to the same length, which patching requires.
clean_prompt = "When Mary and John went to the store, John gave a bottle to"
corrupted_prompt = "When Mary and John went to the store, Mary gave a bottle to"
clean_tokens = model.to_tokens(clean_prompt)
corrupted_tokens = model.to_tokens(corrupted_prompt)

# Clean run with activation cache; corrupted run needs logits only
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
corrupted_logits = model(corrupted_tokens)

mary_idx = model.to_single_token(" Mary")
john_idx = model.to_single_token(" John")

# Metric: logit difference (correct answer " Mary" minus " John") at the last position
def ioi_metric(logits):
    last = logits[0, -1]
    return last[mary_idx] - last[john_idx]

clean_metric = ioi_metric(clean_logits)
corrupted_metric = ioi_metric(corrupted_logits)
print(f"Clean metric (Mary-John logit diff): {clean_metric.item():.3f}")
print(f"Corrupted metric:                    {corrupted_metric.item():.3f}")

# Normalised: 0 = corrupted behaviour, 1 = clean behaviour fully restored
def patching_metric(logits):
    return (ioi_metric(logits) - corrupted_metric) / (clean_metric - corrupted_metric)

# Patch the clean residual stream into the corrupted run at each (layer, position)
results = patching.get_act_patch_resid_pre(
    model, corrupted_tokens, clean_cache, patching_metric
)
print("Patching results shape:", tuple(results.shape), "[layers × positions]")
print("Layer with max impact at each position:", results.argmax(dim=0))
```
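Stripped of the library, activation patching has three steps: cache activations from the clean run, splice one of them into the corrupted run, and measure how much of the clean behaviour returns. The sketch below runs those steps on a hypothetical three-unit toy function rather than GPT-2; the metric is normalised so that 0 means corrupted behaviour and 1 means fully restored.

```python
# Hypothetical toy "model": one hidden layer with three units and a linear readout.
# In a real transformer the units would be residual-stream positions, heads or neurons.
READOUT = [2.0, 0.5, 0.0]  # unit 2 is computed but ignored by the readout

def forward(x, patch=None):
    """Return (metric, hidden). `patch` maps unit index -> value to splice in."""
    patch = patch or {}
    hidden = [x[0] - x[1], x[0] + x[1], 3.0 * x[0]]
    hidden = [patch.get(i, h) for i, h in enumerate(hidden)]
    metric = sum(w * h for w, h in zip(READOUT, hidden))
    return metric, hidden

clean_x, corrupt_x = [2.0, 1.0], [1.0, 2.0]  # the two inputs are "swapped"
clean_metric, clean_cache = forward(clean_x)
corrupt_metric, _ = forward(corrupt_x)

# Patch one unit at a time from the clean cache into the corrupted run.
results = {}
for unit, clean_value in enumerate(clean_cache):
    patched_metric, _ = forward(corrupt_x, patch={unit: clean_value})
    results[unit] = (patched_metric - corrupt_metric) / (clean_metric - corrupt_metric)
    print(f"unit {unit}: restored {results[unit]:.0%} of the clean behaviour")
```

Unit 2 differs between the two runs, yet patching it restores nothing because the readout ignores it: an activation can change with the input without being causally used, the same caveat raised for probing classifiers below.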
Probing Classifiers — What Does Each Layer Know?
A probing classifier is a simple linear (or shallow) model trained on the internal representations (activations) of an LLM layer to predict some property: POS tags, sentiment, syntactic role, factual attributes. If the probe achieves high accuracy on held-out data, the LLM's representations encode that property at that layer.
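A minimal probing sketch on synthetic data: the "activations" below are random 4-D vectors in which, by construction, only dimension 0 carries a binary property (hypothetical data standing in for one layer's hidden states). A logistic-regression probe trained with plain gradient descent should recover that dimension, and its accuracy is measured on held-out examples.

```python
import math
import random

rng = random.Random(0)

def make_example():
    """Hypothetical activation: dim 0 encodes the label, dims 1-3 are pure noise."""
    label = rng.randint(0, 1)
    signal = 1.0 if label else -1.0
    vec = [signal + rng.gauss(0, 0.5)] + [rng.gauss(0, 1.0) for _ in range(3)]
    return vec, label

train = [make_example() for _ in range(400)]
held_out = [make_example() for _ in range(200)]

def predict_prob(w, b, x):
    z = sum(wi * xi for wi, xi in zip(w, x)) + b
    return 1.0 / (1.0 + math.exp(-z))

# Logistic-regression probe trained with full-batch gradient descent.
w, b, lr = [0.0] * 4, 0.0, 0.5
for _ in range(200):
    gw, gb = [0.0] * 4, 0.0
    for x, y in train:
        err = predict_prob(w, b, x) - y  # derivative of log-loss w.r.t. z
        gw = [g + err * xi for g, xi in zip(gw, x)]
        gb += err
    w = [wi - lr * g / len(train) for wi, g in zip(w, gw)]
    b -= lr * gb / len(train)

accuracy = sum((predict_prob(w, b, x) > 0.5) == bool(y) for x, y in held_out) / len(held_out)
print("probe weights:", [round(wi, 2) for wi in w])
print(f"held-out accuracy: {accuracy:.1%}")
```

The large weight on dimension 0 shows where the property is linearly decodable; whether the model uses that direction is a separate, causal question.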
[Figure: layer-by-layer probe accuracy, animated bar chart]
A high probing accuracy means the information is linearly decodable from the representation — not that the model uses it for its output. A representation can encode POS information that is completely ignored by downstream layers. Combine probing with activation patching to test causal relevance.
Self-Consistency & CoT Faithfulness Testing
Self-Consistency (Wang et al. 2022) improves CoT reliability by sampling multiple reasoning paths and taking a majority vote on the final answer. From an XAI perspective, disagreement between chains flags answers the model is uncertain about, and comparing the chains that diverge helps locate the reasoning steps where they split.
```python
import re
from collections import Counter
import openai

client = openai.OpenAI()

def normalise(answer):
    """Light normalisation so '4 m/h.' and '4 M/H' count as the same vote."""
    return answer.strip().rstrip(".").lower()

def self_consistent_cot(question, n_samples=5, temperature=0.7, model="gpt-4"):
    """Generate n CoT chains and return majority answer + diversity stats."""
    prompt = f"""{question}
Think step by step. At the end write: ANSWER: [your final answer]"""
    chains, answers = [], []
    for _ in range(n_samples):
        resp = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=400,  # chains cut off before "ANSWER:" are counted as "?"
        )
        text = resp.choices[0].message.content
        chains.append(text)
        match = re.search(r"ANSWER:\s*(.+)", text)
        answers.append(normalise(match.group(1)) if match else "?")

    vote = Counter(answers)
    majority, count = vote.most_common(1)[0]
    confidence = count / n_samples
    print(f"\n{'=' * 50}")
    print(f"Majority answer : {majority}")
    print(f"Confidence      : {confidence:.0%} ({count}/{n_samples} chains agree)")
    print(f"Answer spread   : {dict(vote)}")
    # XAI insight: if fewer than 60% of chains agree, flag for human review
    if confidence < 0.6:
        print("⚠️ LOW CONFIDENCE: reasoning chains disagree significantly.")
        print("   Recommend: human review or additional context.")
    return majority, chains

ans, chains = self_consistent_cot(
    "A snail travels 3 metres per hour uphill and 6 metres per hour downhill. "
    "It travels 12 km uphill and 12 km downhill. What is its average speed?"
)
```
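To judge the sampled chains, it helps to know the expected answer. Average speed is total distance over total time, which for equal distances is the harmonic mean of the two speeds, not their arithmetic mean of 4.5 m/h (a tempting wrong answer that disagreeing chains may expose). Converting 12 km to 12 000 m:

$$
\bar{v} \;=\; \frac{d_{\text{up}} + d_{\text{down}}}{t_{\text{up}} + t_{\text{down}}}
\;=\; \frac{12\,000\ \text{m} + 12\,000\ \text{m}}{\frac{12\,000}{3}\ \text{h} + \frac{12\,000}{6}\ \text{h}}
\;=\; \frac{24\,000\ \text{m}}{6\,000\ \text{h}}
\;=\; 4\ \text{m/h}
$$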
INSEQ — Unified Attribution for Generative LLMs
INSEQ (Sarti et al. 2023) is a Python library that wraps Hugging Face causal and encoder-decoder models and provides 15+ attribution methods through a single unified API. It is the closest thing to a "standard toolkit" for LLM interpretability.
```python
# pip install inseq
import inseq
import torch

PROMPT = "The capital of France is"
GEN_ARGS = {"max_new_tokens": 1}  # attribute exactly one generated token

# ── Load model with attribution method ───────────────────────
model = inseq.load_model("gpt2", attribution_method="integrated_gradients")

# ── Attribute a generation step ──────────────────────────────
out = model.attribute(
    PROMPT,
    generation_args=GEN_ARGS,
    n_steps=50,  # IG integration steps
    show_progress=False,
)

# ── Inspect attributions ────────────────────────────────────
out.show()  # HTML heatmap in Jupyter

def prompt_scores(output):
    """Return (generated_token, [(prompt_token, score), ...]).

    Assumption: for decoder-only models such as GPT-2 the prompt is part of the
    target sequence, so scores live in `target_attributions`, a 2-D tensor of
    [target tokens x generated steps] after inseq's default aggregation.
    Field names and shapes may differ across inseq versions.
    """
    seq = output.sequence_attributions[0]
    tokens = [t.token for t in seq.target]
    column = torch.nan_to_num(seq.target_attributions[:, -1])  # last generated step
    return tokens[-1], list(zip(tokens[:-1], column[:-1].tolist()))

generated, scores = prompt_scores(out)
print("Generated token:", generated)
print("\nInput attributions:")
for token, attr in scores:
    print(f"  {token:15s} {attr:+.4f}")

# ── Compare methods side by side ────────────────────────────
for method in ["integrated_gradients", "input_x_gradient", "attention"]:
    m = inseq.load_model("gpt2", attribution_method=method)
    o = m.attribute(PROMPT, generation_args=GEN_ARGS, show_progress=False)
    _, sc = prompt_scores(o)
    top = sorted(sc, key=lambda x: -abs(x[1]))[:2]
    print(f"{method:25s} top tokens: {[t for t, _ in top]}")
```
If the model generates "Paris", you would expect "France" to receive the highest attribution, since it directly specifies the country, with "capital" second because it signals the relationship type. GPT-2 small may produce a different continuation, so check the printed token first. Attention-based scores often rank function words such as "is" highly where gradient methods do not, which illustrates why gradient methods are preferred for faithful attribution.
Golden Rules — XAI for LLMs in Production
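Choose the Integrated Gradients baseline deliberately: for masked models such as BERT, use the [MASK] token; for causal models, use a pad token or blank-text embedding. The completeness axiom guarantees that attributions sum to the difference between the model's output on the input and its output on your chosen baseline.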
XAI Methods at a Glance — Full Reference Table
| Method | Family | Key Axioms Met | Pros | Cons | Python Library |
|---|---|---|---|---|---|
| Integrated Gradients | Gradient | Completeness, Sensitivity | Fast, axiomatically grounded | Requires white-box access; results depend on the baseline | inseq, captum |
| Gradient × Input | Gradient | None formally | Fastest gradient method | No completeness or sensitivity guarantee | inseq, captum |
| SmoothGrad | Gradient | None formally | Reduces gradient noise | Slower: many noisy forward/backward passes | captum |
| SHAP (Partition) | Shapley | Efficiency, Symmetry, Dummy | All axioms satisfied | Exact computation is exponential in the number of tokens | shap |
| LIME | Surrogate | Local fidelity | Black-box, human-friendly | Unstable, no axioms | lime |
| Attention Rollout | Attention | None formally | No gradient needed | Unfaithful to prediction | bertviz |
| CoT Prompting | Generative | None formally | Natural language, auditable | May be post-hoc rationalisation | Any LLM API |
| Activation Patching | Causal | Causal sufficiency, necessity | Truly causal, not correlational | Very slow, needs clean/corrupt pairs | transformer_lens |
| Probing Classifiers | Representational | None formally | Layer-wise concept tracking | Linear decodability ≠ causal use | sklearn + HF |