Neural Network Pruning: Reducing Model Complexity
A ResNet-50 model has over 25 million parameters. GPT-3 has 175 billion. Yet systematic research shows that the majority of these parameters are redundant: trained neural networks can lose over 90% of their weights without significant accuracy degradation. Pruning — the technique of systematically removing superfluous parameters — is one of the most powerful tools for reducing the computational complexity of deep learning models.
Unlike quantization, which reduces the numerical precision of parameters, pruning eliminates them entirely. The result can be a smaller, faster, and cheaper model to run — especially when applying structured pruning, which removes entire neurons, filters, or attention heads, producing real speedups on hardware without requiring sparsity support.
In this guide we explore pruning in depth: from the theory of the Lottery Ticket Hypothesis to practical implementations with PyTorch, from magnitude pruning to movement pruning for Transformers, through iterative workflows and combinations with quantization.
What You Will Learn
- Difference between structured and unstructured pruning, and when to use each
- Magnitude pruning: the simplest and most effective method
- Movement pruning for modern Transformers and LLMs
- Lottery Ticket Hypothesis: the theory explaining why pruning works
- PyTorch pruning API with complete examples
- Iterative pruning workflow with retraining
- Torch-Pruning for advanced structured pruning
- Combining pruning and quantization for maximum compression
- Real benchmarks on accuracy, memory, and speed
- Best practices and common anti-patterns
Why Pruning? The Redundancy Problem
Modern neural networks are notoriously over-parameterized. This redundancy is partly intentional: larger networks train more easily and generalize better, but at deployment time they carry unnecessary computational weight. Three fundamental empirical observations motivate pruning:
- Weight redundancy: Studies on trained networks show that weight distributions are heavily concentrated around zero. Removing weights with the smallest magnitude has minimal impact on predictions.
- Lottery Ticket Hypothesis (Frankle & Carbin, 2019): every trained dense neural network contains a "winning ticket" subnetwork that, when reset to its original initialization and trained in isolation, reaches performance comparable to the full network.
- Over-parameterization as a tool: Extra parameters serve the training process (smoother loss landscape, easier escape from local minima) but are not needed for inference.
Pruning Impact: Real Data
Research on ResNet and BERT shows that models can lose 70-90% of parameters with less than 1-2% accuracy loss. Structured pruning of BERT-base Transformers at 50% sparsity produces a 2x FLOPs reduction and a 1.5x inference speedup while maintaining over 99% of original accuracy. In the LLM context, block pruning techniques for Transformers have shown speedups of up to 2.4x on SQuAD with only 1% F1 loss.
Structured vs Unstructured Pruning
The fundamental distinction in pruning is between structured and unstructured approaches. The choice depends on the target hardware and deployment goals:
| Aspect | Unstructured | Structured |
|---|---|---|
| What it removes | Individual weights (arbitrary) | Neurons, filters, channels, attention heads, layers |
| Resulting sparsity | Irregular (sparse matrix) | Regular (reduced dimensions) |
| Real speedup on standard CPU/GPU | None (without sparse ops) | Yes, immediate with dense ops |
| Speedup on sparse hardware | Yes | Yes |
| Memory reduction | Only with explicit sparse format | Always (reduced dimensions) |
| Accuracy at equal sparsity | Better | Slightly lower |
| Implementation complexity | Simple | More complex (dependency recalculation) |
Unstructured pruning is more flexible: it can remove any weight regardless of position. The problem is that the resulting matrix remains dense in memory (explicit zeros), and modern hardware gains no benefit from irregular sparsity without specific support. NVIDIA introduced 2:4 sparsity support with Ampere GPUs, but this requires specific patterns. Structured pruning, by removing complete structures, produces measurably smaller models: a Linear layer with 512 neurons pruned to 256 simply becomes a Linear(in, 256) run with standard dense operations.
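The contrast is easy to verify directly in PyTorch: masking weights leaves the dense tensor (and its stored parameter count) untouched, while shrinking a layer reduces actual storage. A minimal sketch:

```python
import torch
import torch.nn as nn

# Unstructured: zero out half the weights -- the dense tensor is unchanged
layer = nn.Linear(512, 512)
with torch.no_grad():
    mask = (torch.rand_like(layer.weight) > 0.5).float()
    layer.weight.mul_(mask)
masked_params = layer.weight.numel()        # still 512 * 512 stored values

# Structured: physically halve the output dimension
pruned_layer = nn.Linear(512, 256)
structured_params = pruned_layer.weight.numel()

print(masked_params, structured_params)
```

The masked layer still stores (and multiplies through) every zero; only the structurally pruned layer actually gets smaller.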
Magnitude Pruning: The Fundamental Method
Magnitude pruning is the simplest and surprisingly effective approach: remove weights whose absolute value is below a threshold. The intuition is that small weights contribute little to the signal transmitted by the network. Despite its simplicity, when combined with iterative retraining it produces results competitive with far more sophisticated methods.
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import numpy as np
# ===================================================================
# MAGNITUDE PRUNING WITH PYTORCH NATIVE API
# ===================================================================
class ConvNet(nn.Module):
"""Simple CNN model to demonstrate pruning."""
def __init__(self, num_classes=10):
super().__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 64, 3, padding=1),
nn.ReLU(),
nn.Conv2d(64, 128, 3, padding=1),
nn.ReLU(),
nn.AdaptiveAvgPool2d(4)
)
self.classifier = nn.Sequential(
nn.Linear(128 * 4 * 4, 256),
nn.ReLU(),
nn.Linear(256, num_classes)
)
def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), -1)
return self.classifier(x)
model = ConvNet()
# --- L1 unstructured pruning (magnitude-based) ---
# Removes 30% of weights with smallest absolute value
prune.l1_unstructured(
model.features[0], # Layer to prune
name='weight', # Parameter to prune
amount=0.30 # Percentage to remove (30%)
)
# --- Random pruning (comparison baseline) ---
prune.random_unstructured(
model.features[2],
name='weight',
amount=0.30
)
# --- Sparsity analysis ---
def compute_sparsity(module):
    """Calculates the effective sparsity of a module's weight.
    Uses module.weight (the masked tensor): after pruning, iterating over
    parameters() would only see the unmasked weight_orig and report ~0%."""
    weight = module.weight
    return (weight == 0).float().mean().item()
print("Conv1 sparsity:", f"{compute_sparsity(model.features[0]):.1%}")
print("Conv2 sparsity:", f"{compute_sparsity(model.features[2]):.1%}")
# --- Internal pruning structure ---
# PyTorch creates weight_orig (original) + weight_mask (0/1)
print("\nmodel.features[0] parameters after pruning:")
for name, param in model.features[0].named_parameters():
print(f" {name}: shape={param.shape}")
for name, buf in model.features[0].named_buffers():
print(f" buffer {name}: shape={buf.shape}")
# --- Remove mask (make permanent) ---
# After retraining, consolidate: model reverts to using 'weight'
prune.remove(model.features[0], 'weight')
# --- Global Pruning: across the entire model ---
# More effective than per-layer: uses a global threshold
parameters_to_prune = (
(model.features[0], 'weight'),
(model.features[2], 'weight'),
(model.classifier[0], 'weight'),
(model.classifier[2], 'weight'),
)
prune.global_unstructured(
parameters_to_prune,
pruning_method=prune.L1Unstructured,
amount=0.40, # Removes 40% globally (not per layer)
)
# Final sparsity per layer
for module_name, module in model.named_modules():
if isinstance(module, (nn.Conv2d, nn.Linear)):
if hasattr(module, 'weight_mask'):
sparsity = (module.weight_mask == 0).float().mean().item()
print(f"{module_name}: sparsity {sparsity:.1%}")
Important: PyTorch Native Pruning Does Not Speed Up Inference
The torch.nn.utils.prune API applies a binary mask on weights, zeroing
selected ones but maintaining the original dense structure. The resulting model occupies the
same memory and takes the same time for the forward pass. For real speedups you need:
structured pruning (with physical removal of structures), or specific libraries for sparse
operations. PyTorch native pruning is excellent for experimentation and for QAT with sparsity,
but not for direct deployment speedups.
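The overhead is easy to measure: after masking, the module's state_dict holds both weight_orig and weight_mask, so a 90%-sparse layer actually stores more bytes than the dense original. A quick check:

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

def state_dict_bytes(module: nn.Module) -> int:
    """Total bytes of all tensors in the module's state_dict."""
    return sum(t.numel() * t.element_size() for t in module.state_dict().values())

layer = nn.Linear(256, 256)
dense_bytes = state_dict_bytes(layer)

prune.l1_unstructured(layer, name='weight', amount=0.9)
# state_dict now contains weight_orig AND weight_mask (both float32)
masked_bytes = state_dict_bytes(layer)

print(f"dense: {dense_bytes} B | 90%-pruned (masked): {masked_bytes} B")
```

Calling prune.remove consolidates the mask and returns to a single dense weight tensor, but even then the zeros occupy the same memory as before.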
Structured Pruning with Torch-Pruning
The Torch-Pruning library (Fang et al., CVPR 2023) solves the real structured pruning problem: removing a filter from a Conv2D layer requires updating the next layer too (which expects N input channels, not N-k). Torch-Pruning automatically handles these dependencies through a dependency graph (DepGraph), supporting complex architectures including ViT, LLMs, YOLO, and models with skip connections.
# pip install torch-pruning
import torch
import torch.nn as nn
import torch_pruning as tp
# ===================================================================
# STRUCTURED PRUNING WITH TORCH-PRUNING
# ===================================================================
class ResidualBlock(nn.Module):
"""Residual block: Torch-Pruning handles skip connections automatically."""
def __init__(self, channels=64):
super().__init__()
self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(channels)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(channels)
def forward(self, x):
residual = x
out = self.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
return self.relu(out + residual) # Skip connection
class SimpleResNet(nn.Module):
def __init__(self, num_classes=10):
super().__init__()
self.stem = nn.Sequential(
nn.Conv2d(3, 64, 3, padding=1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True)
)
self.layer1 = ResidualBlock(64)
self.layer2 = ResidualBlock(64)
self.pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(64, num_classes)
def forward(self, x):
x = self.stem(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.pool(x).view(x.size(0), -1)
return self.fc(x)
model = SimpleResNet()
model.eval()
# Example input to trace dependencies
example_input = torch.randn(1, 3, 32, 32)
# --- Build dependency graph ---
DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_input)
# --- Model analysis BEFORE pruning ---
macs_before, params_before = tp.utils.count_ops_and_params(model, example_input)
print(f"Parameters BEFORE: {params_before / 1e6:.2f}M")
print(f"MACs BEFORE: {macs_before / 1e9:.3f}G")
# --- Define pruning strategy ---
# L1 magnitude pruning of filters (use MagnitudeImportance(p=2) for L2)
pruner = tp.pruner.MagnitudePruner(
model,
example_inputs=example_input,
importance=tp.importance.MagnitudeImportance(p=1), # L1 norm
iterative_steps=5, # Iterative pruning in 5 steps
ch_sparsity=0.5, # Remove 50% of channels
ignored_layers=[model.fc], # Don't prune the final classifier
)
# --- Execute pruning (run all 5 iterative steps; fine-tuning can be
#     interleaved between steps to recover accuracy) ---
for _ in range(5):
    pruner.step()
# --- Model analysis AFTER pruning ---
macs_after, params_after = tp.utils.count_ops_and_params(model, example_input)
print(f"\nParameters AFTER: {params_after / 1e6:.2f}M")
print(f"MACs AFTER: {macs_after / 1e9:.3f}G")
print(f"Parameter reduction: {(1 - params_after/params_before):.1%}")
print(f"MACs reduction: {(1 - macs_after/macs_before):.1%}")
# --- Post-pruning architecture check ---
print("\nPost-pruning architecture:")
for name, module in model.named_modules():
if isinstance(module, nn.Conv2d):
print(f" {name}: Conv2d({module.in_channels}, {module.out_channels}, ...)")
# Typical output:
# Parameters BEFORE: 0.15M | MACs BEFORE: 0.009G
# Parameters AFTER: 0.04M | MACs AFTER: 0.003G
# Parameter reduction: 75% | MACs reduction: 72%
# layer1.conv1: Conv2d(32, 32, ...) <- from 64 down to 32 channels
Movement Pruning for Transformers
Magnitude pruning works well for CNNs, but Transformers present a different challenge: attention weights may have small magnitudes yet be critical to the model's behavior. Movement pruning (Sanh et al., 2020) takes a radically different approach: instead of removing small weights, it removes those that are moving toward zero during fine-tuning. In other words, the criterion is a score accumulated from the product of each weight and its gradient during training, not the weight's current value.
Movement pruning has shown significant advantages for pruning BERT models: at high sparsity (80-97%), movement pruning outperforms magnitude pruning by 10-20 percentage points on NLP benchmarks like MNLI and SQuAD.
import torch
import torch.nn as nn
# ===================================================================
# MOVEMENT PRUNING (conceptual implementation)
# ===================================================================
class MovementPruningLinear(nn.Module):
"""
Linear layer with movement pruning.
Maintains a score for each weight: the score is optimized
during training. Weights with low scores are pruned.
"""
def __init__(self, in_features, out_features, pruning_ratio=0.5):
super().__init__()
self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.01)
self.bias = nn.Parameter(torch.zeros(out_features))
# Scores initialized to zero: rise during training for important weights
self.scores = nn.Parameter(torch.zeros_like(self.weight))
self.pruning_ratio = pruning_ratio
self.mask = None
    def update_mask(self):
        """Updates mask based on current scores (keep the top-k)."""
        k = max(1, int(self.scores.numel() * (1 - self.pruning_ratio)))
        # Top-k scores: keep the k weights with the highest scores
        threshold = torch.topk(self.scores.flatten(), k).values[-1]
        self.mask = (self.scores >= threshold).float().detach()
def forward(self, x):
# Apply mask during forward pass
if self.mask is None:
self.update_mask()
masked_weight = self.weight * self.mask
return nn.functional.linear(x, masked_weight, self.bias)
# ===================================================================
# ATTENTION HEAD IMPORTANCE SCORING
# ===================================================================
def compute_head_importance(model, dataloader, device):
"""
Computes the importance of each attention head using Taylor expansion.
A head is important if removing it greatly increases loss.
"""
model.eval()
head_importance = torch.zeros(
model.config.num_hidden_layers,
model.config.num_attention_heads
).to(device)
for batch in dataloader:
batch = {k: v.to(device) for k, v in batch.items()}
outputs = model(**batch, output_attentions=True)
        loss = outputs.loss
        model.zero_grad()  # clear the previous batch's gradients before backprop
        loss.backward()
# Accumulate gradients to estimate importance
for layer_idx, layer in enumerate(model.bert.encoder.layer):
attn = layer.attention.self
grad = attn.value.weight.grad
weight = attn.value.weight
if grad is not None:
importance = (grad * weight).abs().view(
model.config.num_attention_heads, -1
).sum(dim=-1)
head_importance[layer_idx] += importance
return head_importance
print("Movement pruning and head importance scoring: conceptual implementations shown above.")
print("Typical results on BERT-base with 40% attention pruning:")
print(" - Inference speedup: 1.3-1.5x")
print(" - Model size: -35%")
print(" - GLUE accuracy: -0.5 to -1.5 points")
Lottery Ticket Hypothesis: The Theory of the Winning Subnetwork
The Lottery Ticket Hypothesis (LTH, Frankle & Carbin, ICLR 2019) is one of the most influential theoretical results in pruning: every dense trained neural network contains one or more sparse subnetworks ("winning tickets") that, if extracted and reset to their original initialization, can be trained in isolation to performance comparable or superior to the full network, in a comparable number of training iterations.
The LTH has important practical implications: it suggests that the large model serves primarily to find the right structure, not for the intrinsic capabilities of its parameters. The standard process for finding a winning ticket is Iterative Magnitude Pruning (IMP).
import torch
import torch.nn as nn
import copy
from typing import Dict, List
# ===================================================================
# ITERATIVE MAGNITUDE PRUNING (Lottery Ticket Hypothesis)
# ===================================================================
def save_initial_weights(model: nn.Module) -> Dict[str, torch.Tensor]:
"""Saves the model's initial weights (before training)."""
return {
name: param.data.clone()
for name, param in model.named_parameters()
if 'weight' in name
}
def apply_mask_and_reinit(
model: nn.Module,
initial_weights: Dict[str, torch.Tensor],
masks: Dict[str, torch.Tensor]
) -> nn.Module:
"""
Resets weights to initial values with pruning masks applied.
This is the critical LTH step: reinitialize (not random, but to original values).
"""
with torch.no_grad():
for name, param in model.named_parameters():
if name in initial_weights and name in masks:
param.data = initial_weights[name] * masks[name]
return model
def compute_pruning_masks(
model: nn.Module,
pruning_ratio: float
) -> Dict[str, torch.Tensor]:
"""Computes pruning masks for magnitude (L1)."""
masks = {}
for name, param in model.named_parameters():
if 'weight' in name and param.dim() > 1:
threshold = torch.quantile(param.abs(), pruning_ratio)
masks[name] = (param.abs() >= threshold).float()
return masks
def iterative_magnitude_pruning(
model: nn.Module,
train_fn,
eval_fn,
n_rounds: int = 5,
prune_per_round: float = 0.20,
epochs_per_round: int = 10
):
"""
Implementation of Iterative Magnitude Pruning (LTH).
Algorithm:
1. Save initial weights (w0)
2. Train for N epochs
3. Prune P% of weights with smallest magnitude
4. Reinitialize surviving weights to w0
5. Repeat from step 2
"""
initial_weights = save_initial_weights(model)
masks = {name: torch.ones_like(param)
for name, param in model.named_parameters()
if 'weight' in name}
results = []
for round_idx in range(n_rounds):
print(f"\n--- IMP Round {round_idx + 1}/{n_rounds} ---")
# Train with current masks applied
train_fn(model, epochs=epochs_per_round, masks=masks)
# Compute new pruning masks
effective_prune = 1 - (1 - prune_per_round) ** (round_idx + 1)
new_masks = compute_pruning_masks(model, effective_prune)
# Reinitialize with initial weights and new masks
model = apply_mask_and_reinit(model, initial_weights, new_masks)
masks = new_masks
# Evaluate
accuracy = eval_fn(model)
total_sparsity = sum(
(m == 0).float().mean().item()
for m in masks.values()
) / len(masks)
results.append({
'round': round_idx + 1,
'accuracy': accuracy,
'sparsity': total_sparsity
})
print(f"Accuracy: {accuracy:.2%} | Sparsity: {total_sparsity:.1%}")
return model, results
# Typical IMP results on ResNet-20 / CIFAR-10:
# Round 1 (20% pruned): 91.8% accuracy (baseline: 91.9%)
# Round 2 (36% pruned): 91.7% accuracy
# Round 3 (49% pruned): 91.5% accuracy
# Round 5 (67% pruned): 90.8% accuracy <- "winning ticket"
# Round 8 (83% pruned): 89.1% accuracy <- accuracy starts degrading
# Round 10 (89% pruned): 87.3% accuracy <- typical end of usefulness
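The rewind step is the easiest part of IMP to get wrong. Stripped of the training loops, one prune-and-rewind round on a single weight tensor looks like this (toy random tensors stand in for a trained model):

```python
import torch

torch.manual_seed(0)
w0 = torch.randn(64, 64)              # initialization, saved before training
w = w0 + 0.1 * torch.randn(64, 64)    # stand-in for the trained weights

# Prune the 20% smallest-magnitude weights of the TRAINED tensor...
threshold = torch.quantile(w.abs(), 0.20)
mask = (w.abs() >= threshold).float()

# ...but rewind the survivors to their ORIGINAL values: this is the
# LTH-specific step (plain iterative pruning would keep the trained w)
ticket = w0 * mask

sparsity = (ticket == 0).float().mean().item()
print(f"ticket sparsity: {sparsity:.1%}")
```

Selection is based on the trained magnitudes; values are taken from the initialization. Conflating the two (selecting on w0, or keeping trained values) gives a different algorithm.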
LTH in Practice: Limitations
- Computational cost: IMP requires many train-prune-reinit cycles, making it expensive for large models. For LLMs, more efficient variants like GMP (Gradual Magnitude Pruning) are used, which don't require reinitialization.
- Scalability: The original LTH works well on small models. For BERT and GPT, reinitialization to initial weights does not produce clear benefits; pruning + fine-tuning on current weights is used instead.
- Transfer learning: Research from 2020 (Chen et al.) shows that "winning tickets" from pre-trained models like BERT are transferable to downstream tasks, opening interesting applications.
Iterative Pruning Workflow with Retraining
The most effective production workflow is not one-shot pruning (remove 50% of weights at once) but iterative pruning with retraining: prune gradually, giving the network time to "recover" at each step. This produces significantly more accurate models at the same target sparsity.
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
from torch.optim.lr_scheduler import CosineAnnealingLR
# ===================================================================
# COMPLETE ITERATIVE PRUNING WORKFLOW
# ===================================================================
def get_global_sparsity(model: nn.Module) -> float:
    """Computes global model sparsity on the effective (masked) weights.

    Note: while pruning masks are active, named_parameters() exposes the
    unmasked 'weight_orig'; the masked tensor is the module's 'weight'."""
    total_params = 0
    zero_params = 0
    for module in model.modules():
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            weight = module.weight
            total_params += weight.numel()
            zero_params += (weight == 0).sum().item()
    return zero_params / total_params if total_params > 0 else 0.0
def iterative_pruning_with_finetuning(
model: nn.Module,
train_loader,
val_loader,
target_sparsity: float = 0.70,
n_pruning_steps: int = 7,
finetune_epochs_per_step: int = 3,
lr: float = 1e-4,
device: str = 'cuda'
):
"""
Iterative pruning with post-pruning fine-tuning.
Strategy: increase sparsity gradually using a cubic schedule
(more aggressive at the start, more conservative at the end).
"""
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
criterion = nn.CrossEntropyLoss()
    history = []
    # Cubic sparsity schedule (Zhu & Gupta, 2017): aggressive early, gentle late
    sparsity_schedule = [
        target_sparsity * (1 - (1 - step / n_pruning_steps) ** 3)
        for step in range(1, n_pruning_steps + 1)
    ]
    print(f"Sparsity schedule: {[f'{s:.1%}' for s in sparsity_schedule]}")
    prev_sparsity = 0.0
    for step_idx, target_sparsity_step in enumerate(sparsity_schedule):
        print(f"\n=== Step {step_idx + 1}/{n_pruning_steps} | Target sparsity: {target_sparsity_step:.1%} ===")
        parameters_to_prune = [
            (module, 'weight')
            for name, module in model.named_modules()
            if isinstance(module, (nn.Linear, nn.Conv2d))
        ]
        # Global L1 pruning. Successive calls compound: each 'amount' applies
        # to the still-unpruned weights, so convert the cumulative target
        # into an incremental amount.
        incremental = (target_sparsity_step - prev_sparsity) / (1 - prev_sparsity)
        prune.global_unstructured(
            parameters_to_prune,
            pruning_method=prune.L1Unstructured,
            amount=incremental
        )
        prev_sparsity = target_sparsity_step
        actual_sparsity = get_global_sparsity(model)
        print(f"Actual sparsity: {actual_sparsity:.1%}")
# Post-pruning fine-tuning
scheduler = CosineAnnealingLR(optimizer, T_max=finetune_epochs_per_step)
for epoch in range(finetune_epochs_per_step):
model.train()
for batch_x, batch_y in train_loader:
batch_x, batch_y = batch_x.to(device), batch_y.to(device)
optimizer.zero_grad()
loss = criterion(model(batch_x), batch_y)
loss.backward()
optimizer.step()
scheduler.step()
# Evaluation
model.eval()
correct = total = 0
with torch.no_grad():
for batch_x, batch_y in val_loader:
batch_x, batch_y = batch_x.to(device), batch_y.to(device)
pred = model(batch_x).argmax(dim=1)
correct += (pred == batch_y).sum().item()
total += batch_y.size(0)
val_acc = correct / total
history.append({'step': step_idx+1, 'sparsity': actual_sparsity, 'val_acc': val_acc})
print(f"Val accuracy: {val_acc:.2%}")
# Consolidate masks (make pruning permanent)
for module, param_name in parameters_to_prune:
try:
prune.remove(module, param_name)
except ValueError:
pass
return model, history
Pruning + Quantization: Maximum Compression
Pruning and quantization are complementary techniques that combine effectively. Pruning reduces the number of parameters; quantization reduces the precision of each remaining parameter. Applied together, they produce extremely compact models. This combination is known as "sparse quantization" or "quantized sparse models".
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
# ===================================================================
# COMBINING PRUNING + QUANTIZATION
# ===================================================================
def prune_then_quantize(model: nn.Module, prune_ratio: float = 0.30):
    """
    Pipeline: global unstructured L1 pruning -> INT8 dynamic quantization.
    """
print(f"Parameters before: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")
# Step 1: Global L1 unstructured pruning
parameters_to_prune = [
(module, 'weight')
for name, module in model.named_modules()
if isinstance(module, nn.Linear) and 'classifier' not in name
]
prune.global_unstructured(
parameters_to_prune,
pruning_method=prune.L1Unstructured,
amount=prune_ratio
)
# Consolidate masks
for module, param_name in parameters_to_prune:
prune.remove(module, param_name)
zero_params = sum(
(param == 0).sum().item()
for name, param in model.named_parameters()
if 'weight' in name
)
total_params = sum(
param.numel()
for name, param in model.named_parameters()
if 'weight' in name
)
print(f"Sparsity after pruning: {zero_params/total_params:.1%}")
# Step 2: Dynamic INT8 quantization on pruned model
model_quantized = torch.quantization.quantize_dynamic(
model,
{nn.Linear},
dtype=torch.qint8
)
return model_quantized
# --- Compression benchmark ---
compression_results = [
{"method": "FP32 (baseline)", "sparsity": "0%", "precision": "FP32", "size_mb": 440},
{"method": "Pruning 50%", "sparsity": "50%", "precision": "FP32", "size_mb": 220},
{"method": "INT8 Quantization", "sparsity": "0%", "precision": "INT8", "size_mb": 110},
{"method": "Pruning 50% + INT8", "sparsity": "50%", "precision": "INT8", "size_mb": 55},
{"method": "Pruning 70% + INT4", "sparsity": "70%", "precision": "INT4", "size_mb": 33},
]
print(f"\n{'Method':<28} {'Sparsity':>10} {'Precision':>12} {'Size':>10}")
print("-" * 65)
for r in compression_results:
print(f"{r['method']:<28} {r['sparsity']:>10} {r['precision']:>12} {r['size_mb']:>8} MB")
# Output (BERT-base ~440MB in FP32; the FP32 pruned sizes assume the zeros
# are stored in a compressed sparse format, not in the dense tensor):
# Method Sparsity Precision Size
# FP32 (baseline) 0% FP32 440 MB
# Pruning 50% 50% FP32 220 MB
# INT8 Quantization 0% INT8 110 MB
# Pruning 50% + INT8 50% INT8 55 MB
# Pruning 70% + INT4 70% INT4 33 MB
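A quick way to sanity-check the combined pipeline is to serialize a model at each stage, here with a toy MLP rather than BERT. Note that zeros in a dense FP32 tensor do not shrink the serialized file; only quantization does, unless a sparse storage format is used:

```python
import io
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

def serialized_kb(model: nn.Module) -> float:
    """Size of the serialized state_dict in KB."""
    buf = io.BytesIO()
    torch.save(model.state_dict(), buf)
    return buf.getbuffer().nbytes / 1024

model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10))
fp32_kb = serialized_kb(model)

# 50% unstructured pruning, consolidated into the weights
for mod in model:
    if isinstance(mod, nn.Linear):
        prune.l1_unstructured(mod, 'weight', amount=0.5)
        prune.remove(mod, 'weight')
pruned_kb = serialized_kb(model)      # ~unchanged: zeros stored densely

# Dynamic INT8 quantization on the pruned model
int8_model = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
int8_kb = serialized_kb(int8_model)   # roughly 4x smaller than FP32

print(f"FP32: {fp32_kb:.0f} KB | pruned FP32: {pruned_kb:.0f} KB | pruned INT8: {int8_kb:.0f} KB")
```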
Benchmarks: Accuracy, Speedup, and Memory
Pruning results vary significantly based on the model, task, and method. The following table reports indicative benchmarks for BERT-base and ResNet-50, based on literature results and practical experiments:
| Model | Method | Sparsity | Accuracy | Speedup | Memory |
|---|---|---|---|---|---|
| BERT-base (MNLI) | Baseline FP16 | 0% | 84.6% | 1.0x | 440 MB |
| BERT-base (MNLI) | Magnitude unstr. | 50% | 84.1% | 1.0x* | 440 MB* |
| BERT-base (MNLI) | Movement pruning | 70% | 83.5% | 1.0x* | 440 MB* |
| BERT-base (MNLI) | Head pruning 30% | 30% heads | 84.0% | 1.3x | 310 MB |
| BERT-base (SQuAD) | Block pruning str. | 50% | F1 -1% | 2.4x | 220 MB |
| ResNet-50 (ImageNet) | L1 filter pruning | 40% | Top-1 -0.5% | 1.5x | -40% |
| ResNet-50 (ImageNet) | Iterative pruning | 70% | Top-1 -1.2% | 2.1x | -65% |
* Unstructured pruning: no speedup on standard hardware without dedicated sparse ops.
Recommendations by Target Hardware
- Standard NVIDIA GPU: Prefer structured pruning (Torch-Pruning, head pruning). Unstructured pruning brings no benefit without sparse support, unless using NVIDIA Ampere's 2:4 sparsity format (50% sparsity in specific 2 non-zero per 4 patterns).
- CPU (inference deployment): Unstructured pruning at high sparsity (>80%) can bring speedups with libraries like Intel oneDNN or by converting to CSR/CSC format. But structured pruning remains more predictable.
- Edge devices (Jetson, Raspberry Pi): Structured pruning + INT8 quantization or GGUF. Model size reduction is critical: even 2x fewer parameters can make the difference between runnable and not runnable.
- Mobile (ARM): Use libraries like XNNPACK or CoreML with INT8 quantization and structured pruning for real hardware acceleration.
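For reference, the 2:4 pattern can be enforced in plain PyTorch: in every group of four consecutive weights along the input dimension, keep the two with the largest magnitude. This sketch only produces the pattern; the actual Ampere speedup additionally requires semi-structured sparse kernels (e.g., torch.sparse or TensorRT):

```python
import torch

def apply_2_4_sparsity(weight: torch.Tensor) -> torch.Tensor:
    """Zero the 2 smallest-magnitude weights in every group of 4."""
    out_f, in_f = weight.shape
    assert in_f % 4 == 0, "input dim must be divisible by 4"
    groups = weight.reshape(out_f, in_f // 4, 4)
    # Indices of the top-2 magnitudes within each group of 4
    topk = groups.abs().topk(2, dim=-1).indices
    mask = torch.zeros_like(groups).scatter_(-1, topk, 1.0)
    return (groups * mask).reshape(out_f, in_f)

w = torch.randn(8, 16)
w_sparse = apply_2_4_sparsity(w)
print((w_sparse == 0).float().mean().item())  # exactly 0.5
```

By construction the result has exactly 50% sparsity in the hardware-friendly pattern, which is why 2:4 trades flexibility for a guaranteed dense-like execution path.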
Best Practices and Anti-Patterns
Best Practices for Pruning
- Use iterative pruning, not one-shot: Prune 10-20% per step with intermediate retraining. A single aggressive removal of 70% almost always irreversibly degrades accuracy.
- Apply retraining after each step: Even 1-3 fine-tuning epochs after each pruning round recovers most of the lost accuracy. The learning rate should be low (10-100x lower than original training).
- Choose method based on target hardware: Structured pruning for real speedups on standard hardware; unstructured only if you have access to sparse-capable hardware.
- Do not prune critical layers: The first and last layers (embedding, classifier) are most sensitive. Exclude or significantly reduce pruning on these layers.
- Monitor weight distribution during pruning: If too many weights from the same layer get pruned (>80%), the layer may collapse. Set a minimum threshold per layer.
- Evaluate on task metrics, not just loss: Training loss may not capture degradation on edge cases. Use domain-specific metrics (F1, BLEU, accuracy on test set).
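The layer-collapse check from the list above is easy to automate: after each pruning step, flag layers whose sparsity exceeds a cap. A sketch (the 80% cap is the heuristic mentioned here, not a universal constant):

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

def layers_over_cap(model: nn.Module, cap: float = 0.80):
    """Return (name, sparsity) for prunable layers above the sparsity cap."""
    flagged = []
    for name, module in model.named_modules():
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            weight = module.weight  # masked view if the module is pruned
            sparsity = (weight == 0).float().mean().item()
            if sparsity > cap:
                flagged.append((name, sparsity))
    return flagged

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 10))
prune.l1_unstructured(model[0], 'weight', amount=0.9)   # over the cap
prune.l1_unstructured(model[2], 'weight', amount=0.3)   # fine
for name, s in layers_over_cap(model):
    print(f"WARNING: layer {name} at {s:.0%} sparsity")
```

Running this after every pruning round makes it immediately visible when global thresholding has concentrated removals in a single layer.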
Anti-Patterns to Avoid
- Do not expect speedups from unstructured pruning on standard GPUs: the torch.nn.utils.prune API zeroes weights but doesn't physically remove them. Inference time does not decrease without dedicated sparse ops.
- Do not mix masks and weights without consolidating: before exporting or distributing the model, always call prune.remove(module, 'weight') to consolidate the mask into the parameter. Otherwise the model carries memory overhead and non-portable pruning hooks.
- Do not use too small a validation dataset: aggressive pruning can overfit the validation set used to monitor accuracy. Use a held-out test set for final evaluation.
- Do not ignore normalization layers: BatchNorm and LayerNorm maintain statistics tied to the dimensions of preceding layers. After structured pruning, normalization statistics must be recalibrated (re-run on calibration dataset).
- Do not apply pruning to non-converged models: Pruning works best on well-trained models. Applying it to a model that has not yet converged produces unpredictable results.
Structured Pruning for Vision Transformers
Vision Transformers introduce new pruning opportunities not available in CNNs: attention head pruning, patch token pruning (removing uninformative image patches mid-inference), and layer dropping. Research from 2023-2025 shows that ViTs are surprisingly amenable to structured pruning — up to 50% of attention heads can be removed from ViT-B/16 with less than 1% accuracy drop on ImageNet, because many heads learn redundant patterns.
import torch
import torch.nn as nn
import timm
import numpy as np
from typing import Dict, List
# ===================================================================
# ATTENTION HEAD PRUNING FOR VIT
# ===================================================================
class ViTPruner:
"""
Structured pruner for Vision Transformers.
Supports attention head pruning and FFN neuron pruning.
Uses Taylor first-order approximation for importance scoring.
"""
def __init__(self, model: nn.Module):
self.model = model
self.head_importances: Dict[int, torch.Tensor] = {}
self.hooks = []
def register_hooks(self):
"""Register forward hooks to capture activation statistics."""
for layer_idx, block in enumerate(self.model.blocks):
hook = block.attn.register_forward_hook(
self._make_attn_hook(layer_idx)
)
self.hooks.append(hook)
def _make_attn_hook(self, layer_idx: int):
        def hook(module, input, output):
            # Capture attention weights [B, heads, N, N]. Note: timm's Attention
            # does not store them by default; this assumes a patched module that
            # saves its softmax output to `attn_weights`. Attention entropy is
            # then used as a proxy for head importance.
            if hasattr(module, 'attn_weights') and module.attn_weights is not None:
attn = module.attn_weights # [B, heads, N, N]
# Entropy-based importance: lower entropy = more focused attention
attn_flat = attn.view(attn.shape[0], attn.shape[1], -1)
entropy = -(attn_flat * (attn_flat + 1e-9).log()).sum(dim=-1).mean(0)
# Low entropy = sharp, informative attention = high importance
importance = 1.0 / (entropy + 1e-9)
if layer_idx not in self.head_importances:
self.head_importances[layer_idx] = importance.detach().cpu()
else:
self.head_importances[layer_idx] += importance.detach().cpu()
return hook
def compute_head_importance_taylor(
self,
model,
dataloader,
n_batches: int = 50,
device: str = "cuda"
) -> Dict[int, torch.Tensor]:
"""
Taylor-expansion importance: |grad * weight| for each head's value projection.
More accurate than entropy-based, requires backprop.
"""
model.train()
layer_head_importance = {}
for batch_idx, (imgs, labels) in enumerate(dataloader):
if batch_idx >= n_batches:
break
imgs, labels = imgs.to(device), labels.to(device)
logits = model(imgs)
loss = nn.CrossEntropyLoss()(logits, labels)
loss.backward()
            for layer_idx, block in enumerate(model.blocks):
                attn = block.attn
                n_heads = attn.num_heads
                head_dim = attn.head_dim
                d_model = n_heads * head_dim
                # qkv packs [q; k; v] along dim 0: the value slice is the last third.
                # Slice the grad of the full parameter: a sliced view of a
                # Parameter is not a leaf and has no .grad of its own.
                if attn.qkv.weight.grad is not None:
                    v_weight = attn.qkv.weight[2 * d_model:]
                    v_grad = attn.qkv.weight.grad[2 * d_model:]
                    # Taylor: |g * w| summed over each head's slice
                    importance = (v_grad * v_weight).abs()
                    importance = importance.view(n_heads, head_dim, -1).sum(dim=(1, 2))
if layer_idx not in layer_head_importance:
layer_head_importance[layer_idx] = importance.detach().cpu()
else:
layer_head_importance[layer_idx] += importance.detach().cpu()
model.zero_grad()
return layer_head_importance
def prune_heads(
self,
head_importance: Dict[int, torch.Tensor],
prune_ratio: float = 0.3
) -> Dict[int, List[int]]:
"""
Prune the least important attention heads.
Returns dict: layer_idx -> list of pruned head indices.
"""
all_scores = []
for layer_idx, scores in head_importance.items():
for head_idx, score in enumerate(scores):
all_scores.append((score.item(), layer_idx, head_idx))
all_scores.sort(key=lambda x: x[0]) # Sort by importance (ascending)
n_to_prune = int(len(all_scores) * prune_ratio)
pruned = all_scores[:n_to_prune]
pruned_heads_per_layer: Dict[int, List[int]] = {}
for _, layer_idx, head_idx in pruned:
if layer_idx not in pruned_heads_per_layer:
pruned_heads_per_layer[layer_idx] = []
pruned_heads_per_layer[layer_idx].append(head_idx)
print(f"Pruned {n_to_prune}/{len(all_scores)} heads ({prune_ratio:.0%})")
for layer_idx in sorted(pruned_heads_per_layer):
print(f" Layer {layer_idx}: heads {pruned_heads_per_layer[layer_idx]}")
return pruned_heads_per_layer
# ===================================================================
# TOKEN PRUNING (Dynamic Token Reduction)
# ===================================================================
class DynamicTokenPruningViT(nn.Module):
"""
ViT with Dynamic Token Pruning (DToP, 2022).
Progressively removes uninformative patch tokens after each block.
The CLS token is never pruned. Reduces FLOPs without retraining.
"""
def __init__(self, base_model_name: str = "vit_base_patch16_224",
keep_ratio: float = 0.7):
super().__init__()
self.vit = timm.create_model(base_model_name, pretrained=True)
self.keep_ratio = keep_ratio
self.num_blocks = len(self.vit.blocks)
# Fraction of tokens kept at each block, decaying linearly from 100%
# down to keep_ratio (e.g., 0.7) at the last block
self.keep_schedule = np.linspace(1.0, keep_ratio, self.num_blocks)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Standard patch embedding + CLS token
x = self.vit.patch_embed(x)
B, N, D = x.shape
cls = self.vit.cls_token.expand(B, -1, -1)
x = torch.cat([cls, x], dim=1) # [B, N+1, D]
x = x + self.vit.pos_embed
x = self.vit.pos_drop(x)
for block_idx, block in enumerate(self.vit.blocks):
x = block(x)  # timm ViT blocks return a single tensor
# Token selection based on CLS-token attention scores
# (simplified: use L2 norm of token embeddings as proxy)
if block_idx < self.num_blocks - 1:
n_keep = max(1, int(x.shape[1] * self.keep_schedule[block_idx]))
cls_token = x[:, :1] # Always keep CLS
patch_tokens = x[:, 1:] # [B, N, D]
# Score patches by L2 norm (informative = large norm)
scores = patch_tokens.norm(dim=-1) # [B, N]
top_k = min(n_keep - 1, patch_tokens.shape[1])
_, top_idx = scores.topk(top_k, dim=-1) # [B, top_k]
# Gather selected patches
selected = patch_tokens.gather(
1, top_idx.unsqueeze(-1).expand(-1, -1, D)
)
x = torch.cat([cls_token, selected], dim=1)
x = self.vit.norm(x)
return self.vit.head(x[:, 0])
# Demo: progressive token reduction. The per-block keep ratios compound,
# so with keep_ratio=0.7 over 12 blocks the final block sees only about
# 20% of the original 196 patch tokens
model_dtp = DynamicTokenPruningViT(keep_ratio=0.7)
img = torch.randn(2, 3, 224, 224)
out = model_dtp(img)
print(f"Dynamic Token Pruning output: {out.shape}")  # torch.Size([2, 1000])
SparseGPT and Wanda: One-Shot LLM Pruning
The emergence of large language models has driven a new class of pruning methods that work in a single pass, without any retraining. SparseGPT (Frantar & Alistarh, 2023) and Wanda (Sun et al., 2023) can prune GPT-scale models (1B-175B parameters) to 50-60% sparsity in a matter of hours, typically losing only 1-2% accuracy on standard benchmarks. This makes them fundamentally different from traditional iterative pruning, which interleaves many pruning and retraining cycles.
import torch
import torch.nn as nn
# ===================================================================
# WANDA PRUNING (Simple but Effective)
# ===================================================================
# "A Simple and Effective Pruning Approach for Large Language Models"
# (Sun et al. 2023) - arXiv: 2306.11695
#
# Core insight: weight importance = |weight| * ||input_activation||
# This is faster than SparseGPT (no Hessian inversion) and matches
# SparseGPT quality on most benchmarks.
def wanda_score(
weight: torch.Tensor, # [out_features, in_features]
input_activations: torch.Tensor # [n_samples, in_features]
) -> torch.Tensor:
"""
Wanda importance score for each weight element.
Score = |W_ij| * ||X_j||_2 (weight magnitude * activation norm)
Critical insight: weights connected to high-magnitude activations
are more important regardless of their own magnitude.
"""
# Activation norm per input feature: [in_features]
activation_norms = input_activations.norm(p=2, dim=0)
# Broadcasting: score[i, j] = |W[i,j]| * ||X[j]||
scores = weight.abs() * activation_norms.unsqueeze(0)
return scores
def apply_wanda_pruning(
model: nn.Module,
calibration_data: torch.Tensor,
sparsity_ratio: float = 0.5,
n2m_sparsity: bool = False # NVIDIA 2:4 structured sparsity
) -> nn.Module:
"""
Apply Wanda pruning to all Linear layers.
Calibration data: [n_samples, seq_len, hidden_size] or [n_samples, hidden_size]
"""
hooks = {}
activations_cache = {}
# Register forward hooks to capture activations
def make_hook(name):
def hook(module, input, output):
# input[0]: [batch, seq, hidden] or [batch, hidden]
inp = input[0].detach()
if inp.dim() == 3:
inp = inp.view(-1, inp.shape[-1]) # Flatten batch*seq
# NOTE: caching raw activations grows linearly with calibration size;
# for large models, accumulate per-feature squared norms instead
if name not in activations_cache:
    activations_cache[name] = inp
else:
    activations_cache[name] = torch.cat([activations_cache[name], inp])
return hook
# Register hooks
for name, module in model.named_modules():
if isinstance(module, nn.Linear):
hooks[name] = module.register_forward_hook(make_hook(name))
# Forward pass to collect activations
model.eval()
with torch.no_grad():
model(calibration_data)
# Remove hooks
for h in hooks.values():
h.remove()
# Apply Wanda scoring and pruning
total_pruned = 0
total_params = 0
with torch.no_grad():
for name, module in model.named_modules():
if isinstance(module, nn.Linear) and name in activations_cache:
W = module.weight.data # [out, in]
X = activations_cache[name] # [n_samples, in]
scores = wanda_score(W, X)
if n2m_sparsity:
# NVIDIA 2:4 structured sparsity: keep 2 largest per group of 4
W_flat = W.view(-1, 4)
s_flat = scores.view(-1, 4)
_, keep_idx = s_flat.topk(2, dim=-1)
mask = torch.zeros_like(W_flat, dtype=torch.bool)
mask.scatter_(1, keep_idx, True)
mask = mask.view_as(W)
else:
    # Unstructured sparsity with a per-output-row threshold: Wanda
    # compares weights within each output neuron, not globally
    # (this also avoids torch.quantile's element limit on large layers)
    threshold = torch.quantile(scores, sparsity_ratio, dim=1, keepdim=True)
    mask = scores >= threshold
module.weight.data *= mask.float()
total_pruned += (~mask).sum().item()
total_params += W.numel()
actual_sparsity = total_pruned / total_params
print(f"Wanda pruning complete: {actual_sparsity:.1%} sparsity")
print(f"Pruned {total_pruned:,} / {total_params:,} weights")
return model
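# Quick sanity check of the Wanda criterion on a toy Linear layer
# (toy sizes, illustrative only). The threshold here is computed per
# output row, the comparison group used in the paper, so every row
# keeps exactly half of its weights.
torch.manual_seed(0)
toy = nn.Linear(64, 32)
X_calib = torch.randn(256, 64)
toy_scores = toy.weight.detach().abs() * X_calib.norm(p=2, dim=0)  # |W| * ||X||_2
row_thr = torch.quantile(toy_scores, 0.5, dim=1, keepdim=True)
with torch.no_grad():
    toy.weight *= (toy_scores >= row_thr).float()
print(f"Toy layer sparsity: {(toy.weight == 0).float().mean():.1%}")  # 50.0%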
# ===================================================================
# SPARSEGPT OVERVIEW (Simplified)
# ===================================================================
# SparseGPT is more complex: it uses the inverse of the Hessian matrix
# (second-order information) to compensate pruning error column-by-column.
# After removing weight W[i,j], it adjusts remaining weights in the same
# row to minimize the layer output reconstruction error.
#
# Pseudocode:
# for each row i in W:
# for j in pruning_order:
# if j should be pruned:
# err = W[i,j] / H_inv[j,j] (H_inv = inverse Hessian)
# W[i, j+1:] -= err * H_inv[j, j+1:] (compensate)
# W[i,j] = 0
#
# Key advantages over Wanda:
# - Better at very high sparsity (70-80%+)
# - Can produce exact unstructured or N:M sparsity
# - Handles severe outlier distributions better
# Reference: Frantar & Alistarh, 2023 (arXiv:2301.00774)
print("Wanda/SparseGPT comparison:")
print("  Wanda: O(n) per layer, no Hessian, ~1% accuracy gap vs SparseGPT")
print("  SparseGPT: O(n^2.4) per layer, uses Hessian, best quality at 60-80% sparsity")
print("  For 50% sparsity: both are nearly equivalent")
print("  For 70%+ sparsity: SparseGPT outperforms Wanda by 2-5 perplexity points")
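The pseudocode can be turned into a minimal runnable sketch on a single synthetic layer. This is a simplification, not SparseGPT itself: the pruning order is fixed per row, the dense inverse Hessian is used directly, and compensation is restricted to surviving weights; the real algorithm processes weights in column blocks with Cholesky-based Hessian updates for efficiency. Compensation typically lowers the layer's reconstruction error versus naive zeroing.

```python
import torch

torch.manual_seed(0)
d_in, d_out, n = 32, 16, 512
X = torch.randn(n, d_in)       # synthetic calibration inputs
W = torch.randn(d_out, d_in)   # synthetic layer weight [out, in]
k = d_in // 2                  # prune 50% of each row

# Layer-wise Hessian of the reconstruction loss, with damping
H = X.T @ X / n + 1e-2 * torch.eye(d_in)
Hinv = torch.linalg.inv(H)

# OBS-style pruning with error compensation
W_obs = W.clone()
for i in range(d_out):
    # Saliency w_j^2 / [H^-1]_jj; prune the k least salient weights
    order = (W_obs[i] ** 2 / Hinv.diag()).argsort()[:k]
    alive = torch.ones(d_in, dtype=torch.bool)
    for j in order.tolist():
        alive[j] = False
        err = W_obs[i, j] / Hinv[j, j]
        W_obs[i] -= err * Hinv[j] * alive  # adjust surviving weights
        W_obs[i, j] = 0.0

# Baseline: same saliency criterion, zeroed with no compensation
W_naive = W.clone()
for i in range(d_out):
    order = (W_naive[i] ** 2 / Hinv.diag()).argsort()[:k]
    W_naive[i, order] = 0.0

def mse(A):
    """Layer output reconstruction error vs the dense weights."""
    return (X @ A.T - X @ W.T).pow(2).mean().item()

print(f"Reconstruction MSE, naive: {mse(W_naive):.4f}, compensated: {mse(W_obs):.4f}")
```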
Pruning in 2025-2026: State of the Art
The field of pruning has evolved significantly with the rise of LLMs. The main trends in 2025-2026 include:
- SparseGPT and Wanda: One-shot pruning methods for LLMs that don't require retraining. SparseGPT (Frantar & Alistarh, 2023) uses the approximate inverse of the Hessian matrix to update remaining weights, compensating for pruning error. Wanda (Sun et al., 2023) uses the product of weight magnitude and input activation norms as the pruning criterion.
- 2:4 Sparsity (NVIDIA): Hardware-supported structured sparsity pattern on Ampere and Hopper GPUs: exactly 2 non-zero values every 4 elements. Produces ~1.5-2x speedups in sparse operations on A100/H100 with accuracy nearly identical to the dense model.
- CORP (2025): Closed-Form One-shot Representation-Preserving Structured Pruning for Vision Transformers — scales from DeiT-Tiny to DeiT-Huge with real hardware speedups and minimal accuracy loss.
- Pruning + Distillation: Combining pruning with knowledge distillation produces the best results: the pruned model is trained with supervision from the original teacher model, recovering much of the lost accuracy.
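The pruning-plus-distillation recipe fits in a few lines of PyTorch. Everything below is illustrative: the tiny MLPs, sizes, and hyperparameters (temperature T, mixing weight alpha) are placeholders, not a reference implementation. The key detail is that `prune.l1_unstructured` reparametrizes the weight as `weight_orig * weight_mask`, so the zeros stay fixed while the surviving weights keep receiving gradients from the distillation loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune

torch.manual_seed(0)

# Hypothetical tiny MLPs standing in for teacher and pruned student
teacher = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 10))
student = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 10))
student.load_state_dict(teacher.state_dict())  # start from the teacher's weights

# Prune the student 50%; the mask keeps zeros fixed during retraining
for m in student:
    if isinstance(m, nn.Linear):
        prune.l1_unstructured(m, name="weight", amount=0.5)

def kd_loss(s_logits, t_logits, labels, T=2.0, alpha=0.7):
    """Soft-target KL against the teacher + hard cross-entropy on labels."""
    soft = F.kl_div(F.log_softmax(s_logits / T, dim=-1),
                    F.softmax(t_logits / T, dim=-1),
                    reduction="batchmean") * (T * T)
    hard = F.cross_entropy(s_logits, labels)
    return alpha * soft + (1 - alpha) * hard

opt = torch.optim.Adam(student.parameters(), lr=1e-3)
x, y = torch.randn(8, 16), torch.randint(0, 10, (8,))
with torch.no_grad():
    t_logits = teacher(x)
loss = kd_loss(student(x), t_logits, y)
loss.backward()
opt.step()

student(x)  # forward recomputes weight = weight_orig * weight_mask
print(f"Student sparsity after step: {(student[0].weight == 0).float().mean():.0%}")
```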
Conclusions
Neural network pruning is one of the most mature and versatile compression techniques in deep learning. Understanding the distinction between structured and unstructured pruning is fundamental: the former produces real speedups on standard hardware, the latter requires specific sparsity support but offers greater flexibility.
Iterative pruning with retraining remains the gold standard for result quality. The Lottery Ticket Hypothesis provides a fundamental theoretical view of why pruning works, though it has practical limitations for very large models. For modern LLMs, methods like SparseGPT and Wanda offer practical one-shot alternatives.
The combination of pruning + quantization is the main path to maximum compression: reducing the number of parameters and their numerical precision in a complementary way allows achieving models with 10-15x smaller footprint than the starting point, while maintaining acceptable accuracy for most production use cases.
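A minimal sketch of the combination (illustrative sizes): magnitude-prune with `torch.nn.utils.prune`, make the masks permanent, then apply PyTorch's post-training dynamic INT8 quantization.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10))

# Step 1: magnitude pruning, then bake the masks into the weights
for m in model.modules():
    if isinstance(m, nn.Linear):
        prune.l1_unstructured(m, name="weight", amount=0.6)
        prune.remove(m, "weight")  # make the sparsity permanent

# Step 2: dynamic INT8 quantization of the pruned model (CPU inference)
quantized = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

out = quantized(torch.randn(1, 128))
print(out.shape)  # torch.Size([1, 10])
```

Note that INT8 storage alone does not exploit the zeros: the sparsity translates into memory savings only when the weights are kept in a sparse format or the runtime has sparse kernels; the two techniques compress along independent axes.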
Next Steps
- Next article: Ollama: Running Local LLMs on Laptop and Raspberry Pi
- Previous article: Knowledge Distillation: Compress Models Efficiently
- Related: Model Quantization: GPTQ, AWQ, INT8
- Related: Fine-Tuning with LoRA and QLoRA
- MLOps series: Model Serving and Deployment