Knowledge Distillation: Compressing Complex Models
GPT-4 is too large to run on your laptop. A ResNet-152 is too slow for your mobile device. Yet the knowledge acquired by these massive models through weeks of training on GPU clusters can be transferred to much smaller and faster models with surprisingly modest accuracy loss. This is the promise of Knowledge Distillation.
Proposed by Hinton, Vinyals, and Dean in 2015, distillation has become one of the most powerful and versatile techniques in modern deep learning. The principle is elegant: a large model (teacher) guides the training of a smaller model (student) not just with hard labels (0/1) but with the teacher's soft probabilities — rich probability distributions containing information about class similarities. This implicit information — the "dark knowledge" — is what makes distillation so effective.
In this guide we explore distillation in depth: from the original theory to modern variants (feature distillation, attention transfer, self-distillation, LLM distillation), with complete PyTorch implementations, production best practices, and a real-world case study.
What You'll Learn
- Distillation theory: soft labels, temperature, and dark knowledge
- Complete standard distillation implementation with PyTorch
- Feature Distillation: transferring intermediate representations
- Attention Transfer: distilling Transformer attention maps
- Self-Distillation and Born Again Networks
- Offline vs online vs self distillation
- LLM distillation: from large models to edge-ready models
- Combining distillation with quantization and pruning
- Best practices, common mistakes, and evaluation metrics
- Case study: DistilBERT step by step
The Distillation Principle: Dark Knowledge
A standard classifier trained on hard labels (one-hot) uses very sparse information: the correct class has probability 1, all others 0. But a well-trained model knows much more: it knows that a cat is more similar to a dog than to a car. This information is contained in the teacher's probability distribution, even when the correct answer has probability close to 1.
The trick of distillation is using a temperature T to "soften" the teacher's probabilities, amplifying the differences between improbable but informative classes. With T=1 you get the original distribution; with high T (e.g., T=4) the probabilities become more uniform, revealing implicit class similarity relationships. This mechanism is called dark knowledge — the hidden knowledge in model logits that a binary label cannot capture.
# Visualization of the temperature effect on dark knowledge
import torch
import torch.nn.functional as F

# Suppose the teacher produces these logits for a sample
# with true class = 0 (cat)
teacher_logits = torch.tensor([8.2, 2.1, 1.8, 0.5, 0.3, -0.2, -0.5, -0.8, -1.1, -1.5])
# Hypothetical classes: [cat, dog, feline, lion, fox, car, plane, ship, train, boat]
classes = ["cat", "dog", "feline", "lion", "fox", "car", "plane", "ship", "train", "boat"]

print("Temperature effect on soft probabilities:")
print("-" * 75)
for T in [1, 2, 4, 8, 20]:
    probs = F.softmax(teacher_logits / T, dim=0)
    entropy = -(probs * probs.log()).sum().item()
    print(f"T={T:2d}: p(cat)={probs[0]:.4f}, "
          f"p(dog)={probs[1]:.4f}, p(feline)={probs[2]:.4f}, "
          f"entropy={entropy:.3f}")

# Output:
# T= 1: p(cat)=0.9946, p(dog)=0.0022, p(feline)=0.0017, entropy=0.042
# T= 2: p(cat)=0.8438, p(dog)=0.0400, p(feline)=0.0344, entropy=0.747
# T= 4: p(cat)=0.4481, p(dog)=0.0975, p(feline)=0.0905, entropy=1.875
# T= 8: p(cat)=0.2330, p(dog)=0.1087, p(feline)=0.1047, entropy=2.220
# T=20: p(cat)=0.1428, p(dog)=0.1053, p(feline)=0.1037, entropy=2.292
#
# With high T, the teacher reveals that "cat" is far more similar to "dog"
# and "feline" than to "car" or "plane".
# This is the DARK KNOWLEDGE the student learns!

print("\nDark knowledge analysis:")
probs_t1 = F.softmax(teacher_logits / 1, dim=0)
probs_t4 = F.softmax(teacher_logits / 4, dim=0)
for i, cls in enumerate(classes):
    print(f"  {cls:8s}: T=1: {probs_t1[i]:.4f}, T=4: {probs_t4[i]:.4f}")
The Mathematics of Distillation
The distillation loss combines two terms with a balancing hyperparameter alpha:
- L_distill: KL divergence between the teacher's and the student's soft probabilities at temperature T, in the direction KL(teacher_soft || student_soft) (the direction nn.KLDivLoss computes), multiplied by T² to compensate for the gradient magnitude reduction
- L_student: standard cross-entropy between student predictions and hard labels
The final formula is: L = alpha * T² * KL(teacher_soft || student_soft) + (1-alpha) * CE(student, labels)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from torch.utils.data import DataLoader

# ============================================================
# DISTILLATION LOSS - Complete implementation
# ============================================================
class DistillationLoss(nn.Module):
    """
    Loss for Knowledge Distillation (Hinton et al., 2015).

    L = alpha * T^2 * KL(teacher_soft || student_soft) + (1-alpha) * CE(student, labels)

    The T^2 factor is critical: when T > 1, gradients of the soft loss are
    scaled by 1/T^2. Multiplying by T^2 compensates for this, keeping the
    gradient scales of the KD loss and the CE loss consistent.
    """
    def __init__(self, temperature: float = 4.0, alpha: float = 0.7):
        """
        temperature: scales soft probabilities (typical: 2-8)
        alpha: weight of distillation loss (typical: 0.5-0.9)
        """
        super().__init__()
        self.T = temperature
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss()
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')

    def forward(self, student_logits: torch.Tensor,
                teacher_logits: torch.Tensor,
                labels: torch.Tensor) -> dict:
        # Soft probabilities at temperature T
        # NOTE: log_softmax for the student (KLDivLoss expects log-probs as input)
        student_soft = F.log_softmax(student_logits / self.T, dim=1)
        teacher_soft = F.softmax(teacher_logits / self.T, dim=1)
        # KL divergence * T^2 to compensate for the gradient reduction
        loss_distill = self.kl_loss(student_soft, teacher_soft) * (self.T ** 2)
        # Standard cross-entropy with hard labels
        loss_student = self.ce_loss(student_logits, labels)
        # Weighted combination
        total_loss = self.alpha * loss_distill + (1 - self.alpha) * loss_student
        return {
            'total': total_loss,
            'distill': loss_distill.detach(),
            'student': loss_student.detach()
        }

# ============================================================
# TEACHER AND STUDENT MODELS
# ============================================================
def create_teacher_student(n_classes: int = 100):
    """
    Teacher: ResNet-50 (~25M parameters) - pre-trained on ImageNet
    Student: MobileNetV3-Small (~2.5M parameters) - 10x smaller
    """
    teacher = models.resnet50(pretrained=True)
    teacher.fc = nn.Linear(teacher.fc.in_features, n_classes)

    student = models.mobilenet_v3_small(pretrained=False)
    student.classifier[3] = nn.Linear(
        student.classifier[3].in_features, n_classes
    )

    total_teacher = sum(p.numel() for p in teacher.parameters())
    total_student = sum(p.numel() for p in student.parameters())
    flops_teacher = 4.1e9    # Approx. FLOPs, ResNet-50 @ 224x224
    flops_student = 0.056e9  # Approx. FLOPs, MobileNetV3-Small @ 224x224
    print(f"Teacher (ResNet-50): {total_teacher:,} params, {flops_teacher/1e9:.1f}G FLOPs")
    print(f"Student (MobileNetV3): {total_student:,} params, {flops_student/1e6:.0f}M FLOPs")
    print(f"Compression factor: {total_teacher/total_student:.1f}x params, "
          f"{flops_teacher/flops_student:.0f}x FLOPs")
    return teacher, student

# ============================================================
# TRAINING LOOP WITH DISTILLATION
# ============================================================
def train_with_distillation(
    teacher: nn.Module,
    student: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    n_epochs: int = 50,
    temperature: float = 4.0,
    alpha: float = 0.7,
    lr: float = 1e-3,
    device: str = "cuda"
):
    teacher = teacher.to(device).eval()  # Teacher: ONLY inference, no backprop!
    student = student.to(device)

    criterion = DistillationLoss(temperature=temperature, alpha=alpha)
    optimizer = torch.optim.AdamW(student.parameters(), lr=lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)

    best_acc = 0.0
    history = {'train_loss': [], 'val_acc': [], 'distill_loss': [], 'student_loss': []}

    for epoch in range(n_epochs):
        student.train()
        total_loss = distill_loss_sum = student_loss_sum = 0.0
        n_batches = 0
        for imgs, labels in train_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            # Teacher forward pass (CRITICAL: no gradients, saves memory!)
            with torch.no_grad():
                teacher_logits = teacher(imgs)
            # Student forward pass (with gradients)
            student_logits = student(imgs)
            # Combined loss
            losses = criterion(student_logits, teacher_logits, labels)

            optimizer.zero_grad()
            losses['total'].backward()
            torch.nn.utils.clip_grad_norm_(student.parameters(), max_norm=1.0)
            optimizer.step()

            total_loss += losses['total'].item()
            distill_loss_sum += losses['distill'].item()
            student_loss_sum += losses['student'].item()
            n_batches += 1
        scheduler.step()

        # Validation
        student.eval()
        correct = total = 0
        with torch.no_grad():
            for imgs, labels in val_loader:
                imgs, labels = imgs.to(device), labels.to(device)
                preds = student(imgs).argmax(1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        val_acc = correct / total

        avg_total = total_loss / n_batches
        history['train_loss'].append(avg_total)
        history['val_acc'].append(val_acc)
        history['distill_loss'].append(distill_loss_sum / n_batches)
        history['student_loss'].append(student_loss_sum / n_batches)

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(student.state_dict(), "best_student.pth")
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1:3d} | Loss: {avg_total:.4f} "
                  f"| Val Acc: {val_acc:.4f} | Best: {best_acc:.4f}")

    print(f"\nBest student accuracy: {best_acc:.4f}")
    return history, best_acc

# Typical CIFAR-100 results:
#   ResNet-50 teacher:          78.2% Top-1
#   MobileNetV3-S without KD:   67.1% Top-1
#   MobileNetV3-S with KD:      71.4% Top-1 (+4.3%)
#   MobileNetV3-S with KD+feat: 73.2% Top-1 (+6.1%)
# Compression: 10x params, 73x FLOPs
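Before wiring the loss into a full training run, it is worth sanity-checking one invariant. The sketch below (a condensed re-implementation of the `DistillationLoss` above, assuming only `torch` is installed) verifies that when student and teacher logits are identical, the KL term vanishes and the total loss collapses to `(1 - alpha)` times the plain cross-entropy:

```python
import torch
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.7):
    """Condensed form of the DistillationLoss module above."""
    soft = F.kl_div(F.log_softmax(student_logits / T, dim=1),
                    F.softmax(teacher_logits / T, dim=1),
                    reduction="batchmean") * T ** 2
    hard = F.cross_entropy(student_logits, labels)
    return alpha * soft + (1 - alpha) * hard

torch.manual_seed(0)
logits = torch.randn(8, 100)
labels = torch.randint(0, 100, (8,))
# Identical student/teacher logits: the KL term vanishes,
# so the loss collapses to (1 - alpha) * cross-entropy
loss_same = kd_loss(logits, logits.clone(), labels)
ce_only = (1 - 0.7) * F.cross_entropy(logits, labels)
print(torch.allclose(loss_same, ce_only, atol=1e-5))  # True
```

Checks like this catch the two most common wiring mistakes: passing raw probabilities instead of log-probabilities to `kl_div`, and forgetting the `T ** 2` factor.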
Feature Distillation: Transferring Internal Representations
Soft-label distillation transfers only the final output of the teacher. Feature Distillation goes further: it forces the student to also replicate the teacher's intermediate representations — feature maps at different network levels. This is particularly effective when teacher and student have very different architectures (e.g., CNN teacher, ViT student) and when the task requires rich spatial features.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

# ============================================================
# FEATURE EXTRACTOR via Forward Hooks
# ============================================================
class FeatureExtractor:
    """Captures features from specific layers via forward hooks."""
    def __init__(self, model: nn.Module, layer_names: list):
        self.features = {}
        self.hooks = []
        for name, module in model.named_modules():
            if name in layer_names:
                hook = module.register_forward_hook(
                    lambda m, inp, out, n=name: self.features.update({n: out})
                )
                self.hooks.append(hook)

    def get_features(self) -> list:
        return list(self.features.values())

    def clear(self):
        self.features.clear()

    def remove(self):
        """Remove hooks to avoid memory leaks."""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()

# ============================================================
# FEATURE DISTILLATION LOSS
# ============================================================
class FeatureDistillationLoss(nn.Module):
    """
    Loss combining three terms:
    1. Standard KD loss (soft labels on the output), weight alpha
    2. Feature matching loss (MSE between normalized intermediate
       features), weight beta
    3. Standard CE loss on hard labels, weight gamma
    """
    def __init__(self, student_channels: list, teacher_channels: list,
                 temperature: float = 4.0, alpha: float = 0.4,
                 beta: float = 0.4, gamma: float = 0.2):
        """
        alpha: KD loss weight
        beta: feature matching loss weight
        gamma: standard CE loss weight
        (alpha + beta + gamma must equal 1.0)
        """
        super().__init__()
        assert abs(alpha + beta + gamma - 1.0) < 1e-6, "Weights must sum to 1"
        self.T = temperature
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        # Adapters to align teacher->student channel dimensions.
        # Example: teacher stage has 2048 channels, student 576 -> 1x1 conv.
        # NOTE: the adapters are trainable - add them to the optimizer.
        self.adapters = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(t_ch, s_ch, 1, bias=False),
                nn.BatchNorm2d(s_ch),
                nn.ReLU(inplace=True)
            )
            for t_ch, s_ch in zip(teacher_channels, student_channels)
        ])

    def forward(self, student_logits, teacher_logits, labels,
                student_features: list, teacher_features: list):
        # 1. KD loss (soft labels)
        kl = nn.KLDivLoss(reduction='batchmean')
        loss_kd = kl(
            F.log_softmax(student_logits / self.T, dim=1),
            F.softmax(teacher_logits / self.T, dim=1)
        ) * self.T ** 2

        # 2. Standard CE loss
        loss_ce = F.cross_entropy(student_logits, labels)

        # 3. Feature matching loss
        loss_feat = torch.tensor(0.0, device=student_logits.device)
        for i, (s_feat, t_feat) in enumerate(zip(student_features, teacher_features)):
            # Adapt teacher channels to student channels
            t_adapted = self.adapters[i](t_feat.detach())
            # Align spatial resolution if necessary
            if s_feat.shape[2:] != t_adapted.shape[2:]:
                t_adapted = F.interpolate(
                    t_adapted, size=s_feat.shape[2:],
                    mode='bilinear', align_corners=False
                )
            # Normalize the flattened features so the loss behaves like a
            # cosine distance rather than a raw MSE
            s_norm = F.normalize(s_feat.view(s_feat.size(0), -1), dim=1)
            t_norm = F.normalize(t_adapted.view(t_adapted.size(0), -1), dim=1)
            loss_feat = loss_feat + F.mse_loss(s_norm, t_norm)
        loss_feat = loss_feat / max(len(student_features), 1)

        total = self.alpha * loss_kd + self.beta * loss_feat + self.gamma * loss_ce
        return {
            'total': total,
            'kd': loss_kd.detach(),
            'ce': loss_ce.detach(),
            'feat': loss_feat.detach()
        }

# Configuration for ResNet-50 teacher -> MobileNetV3-S student
# Teacher layers: [layer2, layer3, layer4] -> channels [512, 1024, 2048]
# Student layers: [features.4, features.9, features.12] -> channels [40, 96, 576]
teacher = models.resnet50(pretrained=True)
student = models.mobilenet_v3_small(pretrained=False)

teacher_layers = ['layer2', 'layer3', 'layer4']
student_layers = ['features.4', 'features.9', 'features.12']
teacher_channels = [512, 1024, 2048]
student_channels = [40, 96, 576]

teacher_extractor = FeatureExtractor(teacher, teacher_layers)
student_extractor = FeatureExtractor(student, student_layers)

feat_criterion = FeatureDistillationLoss(
    student_channels=student_channels,
    teacher_channels=teacher_channels,
    temperature=4.0, alpha=0.4, beta=0.4, gamma=0.2
)
print("Feature Distillation setup complete!")
print(f"Teacher layers: {teacher_layers}")
print(f"Student layers: {student_layers}")
Attention Transfer for Transformers and Vision Transformers
Vision Transformers produce explicit attention maps that can be directly distilled. DeiT (Data-efficient Image Transformer) uses this approach with a special distillation token. Attention Transfer (Zagoruyko & Komodakis, 2017) extends the concept to CNNs by constructing attention maps from convolutional layer activations.
import torch
import torch.nn as nn
import torch.nn.functional as F

# ============================================================
# ATTENTION TRANSFER (AT) for CNNs
# ============================================================
class AttentionTransferLoss(nn.Module):
    """
    Attention Transfer (Zagoruyko & Komodakis, 2017).
    Forces the student to replicate the teacher's attention maps.
    Effective for transfer between different architectures (CNN <-> ViT).
    """
    def __init__(self, beta: float = 1000.0):
        super().__init__()
        self.beta = beta

    def attention_map(self, features: torch.Tensor) -> torch.Tensor:
        """
        Computes the attention map as the channel-wise sum of squared
        activations.
        Input features: [B, C, H, W]
        Output: [B, H*W], L2-normalized (flattened attention map)
        """
        # Sum of squares over channels -> [B, H, W]
        attention = features.pow(2).sum(dim=1)
        # Flatten -> [B, H*W]
        attention = attention.view(attention.size(0), -1)
        # L2-normalize each sample in the batch
        return F.normalize(attention, p=2, dim=1)

    def forward(self, student_features: list, teacher_features: list) -> torch.Tensor:
        """Compute the AT loss across multiple levels."""
        total_loss = torch.tensor(0.0, device=student_features[0].device)
        for s_feat, t_feat in zip(student_features, teacher_features):
            s_attn = self.attention_map(s_feat)
            t_attn = self.attention_map(t_feat).detach()
            # Align spatial dimensions if necessary
            if s_attn.shape != t_attn.shape:
                s_2d = s_feat.pow(2).mean(1, keepdim=True)
                t_2d = t_feat.pow(2).mean(1, keepdim=True)
                t_2d = F.interpolate(t_2d, size=s_feat.shape[2:],
                                     mode='bilinear', align_corners=False)
                s_attn = F.normalize(s_2d.view(s_2d.size(0), -1), p=2, dim=1)
                t_attn = F.normalize(t_2d.view(t_2d.size(0), -1), p=2, dim=1).detach()
            total_loss = total_loss + (s_attn - t_attn).pow(2).mean()
        return self.beta * total_loss / max(len(student_features), 1)

# ============================================================
# DeiT-STYLE: Distillation Token for Vision Transformers
# ============================================================
class ViTWithDistillationToken(nn.Module):
    """
    DeiT-style wrapper around a ViT backbone.
    As in DeiT, the distillation token learns to replicate the
    predictions of a CNN teacher (e.g., RegNet, ResNet).
    During training: loss on both tokens.
    During inference: average of the CLS head and the distillation head.

    NOTE: this simplified wrapper assumes the backbone already prepends
    two prefix tokens (CLS at position 0 and a distillation token at
    position 1, as in timm's `deit_*_distilled` models). For a plain ViT,
    the extra learnable token must be injected into the patch sequence
    inside the backbone's forward pass.
    """
    def __init__(self, vit_model: nn.Module, n_classes: int, d_model: int = 384):
        super().__init__()
        self.vit = vit_model
        # Distillation head (separate from the CLS head)
        self.dist_head = nn.Linear(d_model, n_classes)

    def forward(self, x: torch.Tensor):
        # Token-level features from the ViT backbone
        features = self.vit.forward_features(x)
        # CLS prediction (main prediction)
        cls_pred = self.vit.head(features[:, 0])
        # Distillation-token prediction (teacher guidance)
        dist_pred = self.dist_head(features[:, 1])
        if self.training:
            return cls_pred, dist_pred
        # Inference: average both predictions
        return (cls_pred + dist_pred) / 2.0

def deit_distillation_loss(cls_pred, dist_pred, teacher_pred, labels,
                           alpha: float = 0.5, temperature: float = 3.0):
    """
    DeiT loss: hard-label CE on the CLS token + soft KD from the CNN
    teacher on the distillation token.
    """
    # Hard-label loss on the CLS token
    loss_cls = F.cross_entropy(cls_pred, labels)
    # Soft-label loss on the distillation token
    loss_dist = F.kl_div(
        F.log_softmax(dist_pred / temperature, dim=1),
        F.softmax(teacher_pred / temperature, dim=1),
        reduction='batchmean'
    ) * temperature ** 2
    return alpha * loss_dist + (1 - alpha) * loss_cls
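A useful property of attention transfer is that teacher and student channel counts never need to match, because the maps are reduced over the channel dimension before comparison. This can be checked with random feature maps (a standalone sketch of the `attention_map` computation above, assuming only `torch`):

```python
import torch
import torch.nn.functional as F

def attn_map(features: torch.Tensor) -> torch.Tensor:
    # Channel-wise sum of squares, flattened and L2-normalized per sample
    a = features.pow(2).sum(dim=1).flatten(1)
    return F.normalize(a, p=2, dim=1)

torch.manual_seed(0)
s_feat = torch.randn(4, 32, 8, 8, requires_grad=True)  # student: 32 channels
t_feat = torch.randn(4, 256, 8, 8)                     # teacher: 256 channels
# Both maps are [4, 64]: only the spatial grids must agree
loss = (attn_map(s_feat) - attn_map(t_feat).detach()).pow(2).mean()
loss.backward()
print(f"AT loss: {loss.item():.4f}, map shape: {tuple(attn_map(t_feat).shape)}")
```

Detaching the teacher map is essential: gradients should flow only into the student's activations.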
LLM Distillation: From Large to Small Models
LLM distillation follows the same principles but with some important specifics. The vocabulary is huge (32K-128K tokens), teacher and student must share the same tokenizer, and the loss operates at each token in the sequence. DistilBERT, DistilGPT2, and Microsoft's Phi family are successful examples.
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F

# ============================================================
# LLM DISTILLATION: causal language modeling
# ============================================================
def distill_llm_batch(
    teacher_model,
    student_model,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    temperature: float = 2.0,
    alpha: float = 0.7
) -> dict:
    """
    LLM distillation for next-token prediction (causal LM).

    teacher_model: large model (e.g., Llama-3.1-8B)
    student_model: small model (e.g., Llama-3.2-1B)
    alpha: KD loss weight (1 - alpha = standard CE loss weight)
    """
    device = next(student_model.parameters()).device
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)

    # Teacher inference (no gradients; may live on different devices)
    with torch.no_grad():
        teacher_outputs = teacher_model(input_ids, attention_mask=attention_mask)
        teacher_logits = teacher_outputs.logits  # [B, seq_len, vocab_size]

    # Student forward pass (with gradients)
    student_outputs = student_model(input_ids, attention_mask=attention_mask)
    student_logits = student_outputs.logits

    # Shift for next-token prediction: position i predicts token i+1
    shift_student = student_logits[:, :-1, :].contiguous()
    shift_teacher = teacher_logits[:, :-1, :].contiguous()
    shift_labels = input_ids[:, 1:].contiguous().clone()
    shift_mask = attention_mask[:, 1:].contiguous().bool()
    shift_labels[~shift_mask] = -100  # exclude padding from the CE loss

    # Reshape for per-token losses
    B, S, V = shift_student.shape
    shift_student_flat = shift_student.view(B * S, V)
    shift_teacher_flat = shift_teacher.view(B * S, V)
    shift_labels_flat = shift_labels.view(B * S)
    valid = shift_mask.view(B * S)

    # 1. KD loss: per-token KL divergence (padding positions excluded)
    student_log_probs = F.log_softmax(shift_student_flat[valid] / temperature, dim=-1)
    teacher_probs = F.softmax(shift_teacher_flat[valid] / temperature, dim=-1)
    loss_kd = F.kl_div(student_log_probs, teacher_probs,
                       reduction='batchmean') * temperature ** 2

    # 2. Standard CE loss (padding positions labeled -100 are ignored)
    loss_ce = F.cross_entropy(
        shift_student_flat, shift_labels_flat,
        ignore_index=-100
    )

    total = alpha * loss_kd + (1 - alpha) * loss_ce
    return {
        'total': total,
        'kd': loss_kd.detach(),
        'ce': loss_ce.detach(),
        'perplexity': torch.exp(loss_ce).detach()
    }

# ============================================================
# FULL LLM DISTILLATION PIPELINE
# ============================================================
def setup_llm_distillation(
    teacher_name: str = "meta-llama/Llama-3.1-8B",
    student_name: str = "meta-llama/Llama-3.2-1B",
    device: str = "cuda"
):
    """
    Setup for distilling a large LLM into a smaller one.
    IMPORTANT: teacher and student must share the same tokenizer so
    that their distributions are defined over the same vocabulary.
    """
    tokenizer = AutoTokenizer.from_pretrained(teacher_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Teacher: load in FP16 to save memory
    teacher = AutoModelForCausalLM.from_pretrained(
        teacher_name,
        torch_dtype=torch.float16,
        device_map="auto"  # Distributes across multiple GPUs if available
    )
    teacher.eval()

    # Student: load in FP32 for stable training
    student = AutoModelForCausalLM.from_pretrained(
        student_name,
        torch_dtype=torch.float32
    ).to(device)

    teacher_params = sum(p.numel() for p in teacher.parameters())
    student_params = sum(p.numel() for p in student.parameters())
    print(f"Teacher: {teacher_params/1e9:.1f}B parameters")
    print(f"Student: {student_params/1e9:.1f}B parameters")
    print(f"Compression: {teacher_params/student_params:.1f}x")
    return teacher, student, tokenizer

print("LLM distillation setup complete!")
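To see the per-token bookkeeping (shift, padding mask, KL) in isolation, the same computation can be run with random tensors standing in for model logits. This is a toy sketch, assuming only `torch`, with arbitrary shapes chosen for illustration:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, S, V = 2, 6, 50  # batch, sequence length, vocab size (toy values)
input_ids = torch.randint(0, V, (B, S))
attention_mask = torch.ones(B, S, dtype=torch.long)
attention_mask[1, 4:] = 0  # simulate padding on the second sequence

teacher_logits = torch.randn(B, S, V)
student_logits = torch.randn(B, S, V, requires_grad=True)

T, alpha = 2.0, 0.7
# Shift: the logits at position i predict token i+1
s = student_logits[:, :-1, :].reshape(-1, V)
t = teacher_logits[:, :-1, :].reshape(-1, V)
labels = input_ids[:, 1:].reshape(-1).clone()
mask = attention_mask[:, 1:].reshape(-1).bool()
labels[~mask] = -100  # padded positions excluded from the CE loss

loss_kd = F.kl_div(F.log_softmax(s[mask] / T, dim=-1),
                   F.softmax(t[mask] / T, dim=-1),
                   reduction="batchmean") * T ** 2
loss_ce = F.cross_entropy(s, labels, ignore_index=-100)
total = alpha * loss_kd + (1 - alpha) * loss_ce
total.backward()
print(f"kd={loss_kd.item():.3f}, ce={loss_ce.item():.3f}")
```

With real models, `s` and `t` come from the teacher and student forward passes; everything else stays identical.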
Self-Distillation and Born Again Networks
Self-distillation is a surprising variant: the model acts as its own teacher. In Born Again Networks (BANs, Furlanello et al. 2018), successive generations of models are trained with the same architecture: each generation uses the previous as teacher. The result is systematic improvement (+1-2% Top-1) without increasing model size.
import torch
import torch.nn as nn
import copy

# ============================================================
# BORN AGAIN NETWORKS (BANs)
# ============================================================
def born_again_training(model_factory, train_loader, val_loader,
                        n_generations: int = 3,
                        temperature: float = 4.0,
                        n_epochs: int = 30,
                        device: str = "cuda"):
    """
    Trains N generations with the same architecture.
    Gen 1: standard training with CE loss.
    Gen 2+: distillation from the previous generation.

    Typical CIFAR-100 results:
      Gen 1: 67.1% (baseline)
      Gen 2: 70.8% (+3.7%)
      Gen 3: 72.1% (+5.0%)
      Ensemble 1+2+3: 74.8% (+7.7%)
    """
    # DistillationLoss is the class defined in the standard KD section
    criterion_kd = DistillationLoss(temperature=temperature, alpha=0.7)
    criterion_ce = nn.CrossEntropyLoss()
    all_models = []
    results = []

    # === Generation 1: standard training ===
    print("Gen 1: standard training...")
    gen1 = model_factory().to(device)
    opt1 = torch.optim.SGD(gen1.parameters(), lr=0.1,
                           momentum=0.9, weight_decay=5e-4)
    sch1 = torch.optim.lr_scheduler.MultiStepLR(opt1, milestones=[15, 25], gamma=0.1)
    for epoch in range(n_epochs):
        gen1.train()
        for imgs, labels in train_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            opt1.zero_grad()
            criterion_ce(gen1(imgs), labels).backward()
            opt1.step()
        sch1.step()
    acc1 = _evaluate(gen1, val_loader, device)
    results.append(acc1)
    all_models.append(gen1)
    print(f"  Gen 1 val acc: {acc1:.4f}")
    teacher = copy.deepcopy(gen1).eval()

    # === Subsequent generations with distillation ===
    for gen_idx in range(2, n_generations + 1):
        print(f"Gen {gen_idx}: KD from gen {gen_idx-1}...")
        student = model_factory().to(device)
        opt = torch.optim.SGD(student.parameters(), lr=0.1,
                              momentum=0.9, weight_decay=5e-4)
        sch = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[15, 25], gamma=0.1)
        for epoch in range(n_epochs):
            student.train()
            for imgs, labels in train_loader:
                imgs, labels = imgs.to(device), labels.to(device)
                with torch.no_grad():
                    t_logits = teacher(imgs)
                s_logits = student(imgs)
                losses = criterion_kd(s_logits, t_logits, labels)
                opt.zero_grad()
                losses['total'].backward()
                opt.step()
            sch.step()
        acc = _evaluate(student, val_loader, device)
        results.append(acc)
        all_models.append(student)
        print(f"  Gen {gen_idx} val acc: {acc:.4f}")
        teacher = copy.deepcopy(student).eval()

    # Ensemble of all generations (upper bound)
    ensemble_acc = _ensemble_evaluate(all_models, val_loader, device)
    print(f"\nEnsemble of {n_generations} gens: {ensemble_acc:.4f}")
    return results, all_models

def _evaluate(model, loader, device):
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            correct += (model(imgs).argmax(1) == labels).sum().item()
            total += labels.size(0)
    return correct / total

def _ensemble_evaluate(models, loader, device):
    """Accuracy of the averaged-logits ensemble."""
    for m in models:
        m.eval()
    correct = total = 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            avg_logits = torch.stack([m(imgs) for m in models]).mean(0)
            correct += (avg_logits.argmax(1) == labels).sum().item()
            total += labels.size(0)
    return correct / total
Typical Distillation Results (2024-2025 Benchmarks)
| Task | Teacher | Student | Student (no KD) | Student (KD) | Teacher score | Compression |
|---|---|---|---|---|---|---|
| CIFAR-100 | ResNet-50 | MobileNetV3-S | 67.1% | 73.2% | 78.2% | 10x params |
| ImageNet | ViT-L/16 | DeiT-S | 79.8% | 83.1% | 87.1% | 5x params |
| GLUE (NLP) | BERT-Large | DistilBERT | 83.2% | 86.4% | 89.2% | 2x params, 2x speed |
| SQuAD (QA) | RoBERTa-L | DistilRoBERTa | 82.1% | 85.8% | 90.4% | 2x params |
| LLM (perplexity) | Llama 3.1 8B | Llama 3.2 1B | 8.24 PPL | 7.81 PPL | 6.12 PPL | 8x params |
In the classification rows above, KD recovers roughly 45-55% of the accuracy gap between student and teacher while using 2-10x fewer parameters (the LLM row reports perplexity, where lower is better, and closes about 20% of the gap).
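The recovered-gap figure can be checked directly against the table rows; `recovered_gap` is a small helper introduced here for that purpose:

```python
def recovered_gap(student_base: float, student_kd: float, teacher: float) -> float:
    """Fraction of the teacher-student gap closed by distillation."""
    return (student_kd - student_base) / (teacher - student_base)

# Accuracy rows from the table above (in %)
rows = {
    "CIFAR-100": (67.1, 73.2, 78.2),
    "ImageNet":  (79.8, 83.1, 87.1),
    "GLUE":      (83.2, 86.4, 89.2),
    "SQuAD":     (82.1, 85.8, 90.4),
}
for task, (base, kd, teacher) in rows.items():
    print(f"{task:10s}: {recovered_gap(base, kd, teacher):.0%} of the gap recovered")
# Prints 55%, 45%, 53%, 45% respectively
```

This is a more honest way to compare KD runs than raw accuracy deltas, because it normalizes for how strong the teacher is.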
Production Pipeline: Distillation + Quantization
The most powerful workflow for edge deployment combines distillation and quantization sequentially: first create the student with KD (preserving accuracy), then quantize the student (reducing size and increasing speed). The combination can shrink a model by up to ~100x relative to the original teacher with only 5-10% accuracy loss.
import torch
import torch.nn as nn
from torchvision import models
import os

# ============================================================
# FULL PIPELINE: Distillation -> Quantization -> ONNX
# ============================================================
def full_compression_pipeline(output_dir: str = "./compressed"):
    """
    Complete pipeline for compressing a model for edge deployment.
    Step 1: Load pre-trained teacher
    Step 2: Distill into a smaller student
    Step 3: Quantize the student (dynamic INT8 PTQ)
    Step 4: Export to ONNX for cross-platform deployment
    """
    os.makedirs(output_dir, exist_ok=True)

    # STEP 1: Teacher
    print("Step 1: Loading teacher...")
    teacher = models.resnet50(pretrained=True)
    teacher.fc = nn.Linear(2048, 10)
    teacher.eval()
    teacher_size_mb = sum(p.numel() * p.element_size()
                          for p in teacher.parameters()) / (1024**2)
    print(f"  Teacher: {teacher_size_mb:.1f} MB, "
          f"{sum(p.numel() for p in teacher.parameters())/1e6:.1f}M params")

    # STEP 2: Student (after distillation)
    print("Step 2: Student with distillation (simulated with MobileNetV3)...")
    student = models.mobilenet_v3_small(pretrained=False)
    student.classifier[3] = nn.Linear(
        student.classifier[3].in_features, 10
    )
    student_size_mb = sum(p.numel() * p.element_size()
                          for p in student.parameters()) / (1024**2)
    print(f"  Student: {student_size_mb:.1f} MB, "
          f"{sum(p.numel() for p in student.parameters())/1e6:.1f}M params")
    print(f"  Reduction vs teacher: {teacher_size_mb/student_size_mb:.1f}x")

    # STEP 3: INT8 quantization (dynamic PTQ)
    # NOTE: dynamic quantization only covers nn.Linear layers; convolutions
    # stay FP32, so the 4x size estimate below is optimistic for a
    # conv-heavy model. Static PTQ or QAT is needed to quantize the convs.
    print("Step 3: INT8 quantization...")
    student.eval()
    student_quant = torch.quantization.quantize_dynamic(
        student,
        {nn.Linear},
        dtype=torch.qint8
    )
    print(f"  Student INT8: ~{student_size_mb/4:.1f} MB (best-case estimate)")
    print(f"  Total reduction: ~{teacher_size_mb/(student_size_mb/4):.0f}x vs teacher")

    # STEP 4: ONNX export (of the FP32 student; dynamically quantized
    # modules do not export cleanly - re-quantize on the target runtime)
    print("Step 4: ONNX export...")
    dummy = torch.randn(1, 3, 224, 224)
    onnx_path = f"{output_dir}/student_compressed.onnx"
    torch.onnx.export(
        student,
        dummy,
        onnx_path,
        opset_version=13,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={"input": {0: "batch_size"}},
        export_params=True
    )

    print("\n=== PIPELINE SUMMARY ===")
    print(f"Teacher (ResNet-50):   {teacher_size_mb:.1f} MB")
    print(f"Student KD (MobNetV3): {student_size_mb:.1f} MB ({teacher_size_mb/student_size_mb:.1f}x reduction)")
    print(f"Student INT8 (est.):   {student_size_mb/4:.1f} MB ({teacher_size_mb/(student_size_mb/4):.0f}x reduction)")
    print(f"ONNX saved: {onnx_path}")
    return student_quant, onnx_path

full_compression_pipeline()
Anti-Patterns in Distillation: Common Mistakes
- Temperature too high or too low: with T=1 you distill against the teacher's unsoftened distribution, which for a confident teacher is nearly one-hot and carries little dark knowledge. With very high T (>20), soft labels become nearly uniform and the signal washes out. Always run an ablation with T ∈ {2, 4, 6, 8} on your specific dataset.
- Teacher and student too different in capacity: if the gap is enormous (e.g., a GPT-4-scale teacher distilled directly into a small model), direct distillation is ineffective. Use cascade distillation (e.g., GPT-4 -> 13B -> 7B -> 3B), keeping each step below roughly a 4-5x reduction.
- Ignoring distillation dataset quality: the data used for distillation strongly affects the result. Use diverse data that is representative of the target task; out-of-distribution data damages the transfer.
- Poorly calibrated alpha: with alpha=1 (KD only) the student ignores the ground-truth labels and inherits the teacher's mistakes unchecked. Values of 0.5-0.8 are typically optimal.
- Not freezing the teacher: the teacher must stay in eval() mode during student training. If the teacher keeps changing (e.g., BatchNorm left in train mode), the distillation targets are inconsistent and training can diverge.
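The temperature and alpha ablations recommended above are cheap to script as a grid search. In this sketch, `run_kd_trial` is a placeholder meant to be wired to `train_with_distillation` from earlier in this guide; the stub returns 0.0 so the loop runs as-is:

```python
from itertools import product

def run_kd_trial(temperature: float, alpha: float) -> float:
    """Placeholder: train the student with these hyperparameters and
    return validation accuracy (wire this to train_with_distillation)."""
    return 0.0  # stub so the sketch runs; replace with a real training run

grid = product([2, 4, 6, 8], [0.5, 0.7, 0.9])  # T x alpha
best = max(((T, a, run_kd_trial(T, a)) for T, a in grid), key=lambda x: x[2])
print(f"Best config: T={best[0]}, alpha={best[1]}, val_acc={best[2]:.4f}")
```

Twelve short runs on a subset of the training data are usually enough to rank configurations before committing to a full-length run.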
LLM Distillation: Response Distillation and Speculative Decoding
For Large Language Models, distillation takes new and powerful forms. Two techniques particularly relevant in 2025-2026 are Response Distillation (used to train models such as Llama 3.2 and Microsoft's Phi family) and Speculative Decoding, which uses a small draft model to accelerate inference of the large target model with no change to the output distribution.
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

# ============================================================
# SPECULATIVE DECODING: Draft Model + Target Model
# ============================================================
# Principle: a small draft model generates K tokens ahead.
# The large target model verifies all K tokens in ONE forward pass.
# If the draft is correct, we save K-1 forward passes of the large model.
# Typical speedup: 2-4x with no loss in output quality.

class SpeculativeDecoder:
    """
    Basic speculative decoding implementation.
    Draft model: small model (e.g., Llama-3.2-1B)
    Target model: large model (e.g., Llama-3.1-8B)

    Reference: "Fast Inference from Transformers via Speculative Decoding"
    (Leviathan et al., 2022) - the original Google paper.
    """
    def __init__(
        self,
        draft_model_name: str,
        target_model_name: str,
        device: str = "cuda",
        lookahead_k: int = 5  # Tokens generated by the draft per step
    ):
        self.device = device
        self.lookahead_k = lookahead_k
        print(f"Loading draft model: {draft_model_name}...")
        self.draft_model = AutoModelForCausalLM.from_pretrained(
            draft_model_name, torch_dtype=torch.float16
        ).to(device).eval()
        print(f"Loading target model: {target_model_name}...")
        self.target_model = AutoModelForCausalLM.from_pretrained(
            target_model_name, torch_dtype=torch.float16
        ).to(device).eval()
        self.tokenizer = AutoTokenizer.from_pretrained(target_model_name)

    def draft_generate(self, input_ids: torch.Tensor) -> tuple:
        """
        The draft model generates K tokens and returns the probability
        distributions used for acceptance/rejection sampling.
        """
        draft_ids = input_ids.clone()
        draft_probs = []
        with torch.no_grad():
            for _ in range(self.lookahead_k):
                out = self.draft_model(draft_ids)
                next_logits = out.logits[:, -1, :]
                next_probs = F.softmax(next_logits, dim=-1)
                next_token = torch.multinomial(next_probs, num_samples=1)
                draft_probs.append(next_probs)
                draft_ids = torch.cat([draft_ids, next_token], dim=1)
        draft_tokens = draft_ids[:, input_ids.shape[1]:]
        return draft_tokens, draft_probs

    def speculative_generate(
        self,
        prompt: str,
        max_new_tokens: int = 100,
        temperature: float = 0.7
    ) -> str:
        """Generate text with speculative decoding."""
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        generated = input_ids.clone()
        total_accepted = 0
        total_draft = 0
        with torch.no_grad():
            while generated.shape[1] - input_ids.shape[1] < max_new_tokens:
                draft_tokens, draft_probs = self.draft_generate(generated)
                total_draft += self.lookahead_k
                full_seq = torch.cat([generated, draft_tokens], dim=1)
                target_out = self.target_model(full_seq)
                # Logits at the positions that predict each draft token
                target_logits = target_out.logits[:, generated.shape[1]-1:-1, :]
                n_accepted = 0
                for i in range(draft_tokens.shape[1]):
                    draft_tok = draft_tokens[0, i].item()
                    target_probs = F.softmax(target_logits[0, i] / temperature, dim=-1)
                    draft_p = draft_probs[i][0, draft_tok].item()
                    target_p = target_probs[draft_tok].item()
                    # Accept the draft token with probability min(1, p_target/p_draft)
                    r = torch.rand(1).item()
                    if r < min(1.0, target_p / (draft_p + 1e-10)):
                        n_accepted += 1
else:
corrected = torch.multinomial(
F.relu(target_probs - draft_probs[i][0]), num_samples=1
)
generated = torch.cat([
generated, draft_tokens[:, :i], corrected.unsqueeze(0)
], dim=1)
break
else:
generated = torch.cat([generated, draft_tokens], dim=1)
bonus_logits = target_out.logits[:, -1, :]
bonus_tok = torch.multinomial(
F.softmax(bonus_logits / temperature, dim=-1), 1
)
generated = torch.cat([generated, bonus_tok], dim=1)
total_accepted += n_accepted
acceptance_rate = total_accepted / max(total_draft, 1)
print(f"Acceptance rate: {acceptance_rate:.1%} (expected 60-80% with similar draft)")
new_tokens = generated[0, input_ids.shape[1]:]
return self.tokenizer.decode(new_tokens, skip_special_tokens=True)
# ============================================================
# RESPONSE DISTILLATION for LLMs
# ============================================================
# Technique behind the small Llama-3.2 variants and the Phi models:
# 1. Large teacher LLM (GPT-4, Llama-3.1-70B) generates responses
# 2. Small student LLM learns to imitate those responses
# Different from classic distillation: distills text responses (output),
# not internal probability distributions.
def response_distillation_dataset(
teacher_model_name: str,
prompts: list,
output_file: str = "distillation_dataset.jsonl"
) -> list:
"""
Generate distillation dataset with teacher responses.
In production: use GPT-4 API, Llama-3.1-70B, or Claude.
"""
import json
tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
teacher = AutoModelForCausalLM.from_pretrained(
teacher_model_name,
torch_dtype=torch.float16,
device_map="auto"
).eval()
dataset = []
with open(output_file, 'w') as f:
for prompt in prompts:
inputs = tokenizer(prompt, return_tensors="pt").to(teacher.device)
with torch.no_grad():
outputs = teacher.generate(
**inputs,
max_new_tokens=512,
temperature=0.7,
do_sample=True,
top_p=0.9
)
response_ids = outputs[0, inputs['input_ids'].shape[1]:]
response = tokenizer.decode(response_ids, skip_special_tokens=True)
entry = {"prompt": prompt, "response": response}
dataset.append(entry)
f.write(json.dumps(entry, ensure_ascii=False) + "\n")
print(f"Dataset generated: {len(dataset)} examples -> {output_file}")
return dataset
# Note: in practice, use a commercial API (OpenAI, Anthropic) to generate
# teacher responses, then train the student on them.
# This is the technique behind most instruction-following models:
# Alpaca, Vicuna, Orca, and Microsoft's Phi models.
print("LLM distillation patterns ready")
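Once the (prompt, response) pairs exist, the student is fine-tuned with plain next-token cross-entropy, computing the loss only on the response tokens. A toy sketch of that masking step, using an illustrative embedding-plus-linear stand-in for the student (in practice you would load a small Hugging Face causal LM and let its training loop handle the shift):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

IGNORE_INDEX = -100  # positions excluded from the loss

def build_labels(input_ids, prompt_len):
    """Mask prompt tokens so the student only learns the response."""
    labels = input_ids.clone()
    labels[:, :prompt_len] = IGNORE_INDEX
    return labels

# Toy causal "LM": embedding + linear head (stand-in for the
# real student model).
vocab, dim = 100, 32
student = nn.Sequential(nn.Embedding(vocab, dim), nn.Linear(dim, vocab))
opt = torch.optim.AdamW(student.parameters(), lr=1e-3)

# One (prompt + teacher response) example, already tokenized.
input_ids = torch.randint(0, vocab, (1, 12))
labels = build_labels(input_ids, prompt_len=5)

# Standard causal-LM shift: predict token t+1 from tokens <= t;
# ignore_index drops the masked prompt positions from the loss.
logits = student(input_ids)
loss = F.cross_entropy(
    logits[:, :-1].reshape(-1, vocab),
    labels[:, 1:].reshape(-1),
    ignore_index=IGNORE_INDEX,
)
loss.backward()
opt.step()
print(f"response-only loss: {loss.item():.4f}")
```

Masking the prompt matters: without it the student spends capacity memorizing prompts instead of imitating the teacher's responses.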
Distillation Variants: Technical Comparison (2024-2025)
| Variant | What Is Distilled | Pros | Cons | Typical Use |
|---|---|---|---|---|
| Soft Labels (Hinton 2015) | Probability distributions | Rich information, standard | Requires access to teacher logits | Vision, classification |
| Feature Distillation | Intermediate representations | Deep feature transfer | Teacher and student must have compatible architecture | Detection, segmentation |
| Response Distillation | Text output of teacher | No internal access needed | Loses uncertainty information | LLM instruction-following |
| Born-Again Networks | Soft labels, iterated across generations of the same architecture | No separate teacher required | Limited gain, high compute cost | Ensembling, accuracy improvement |
| Speculative Decoding | Nothing (draft proposes, target verifies) | 2-4x speedup, zero quality loss | Requires two models in memory | LLM inference acceleration |
Conclusions
Knowledge Distillation is one of the most powerful and versatile compression techniques available in 2026. It combines naturally with quantization and pruning: distill first to create the optimal student, then quantize the student for edge deployment. The result is often a model 10-100x smaller than the teacher with only 5-15% accuracy loss.
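The distill-then-quantize pipeline can be sketched with PyTorch's dynamic INT8 quantization; the toy `nn.Sequential` below stands in for an already-distilled student:

```python
import torch
import torch.nn as nn

# Stage 2 of the pipeline: quantize the already-distilled student.
# Dynamic INT8 quantization targets nn.Linear layers and needs no
# calibration data -- a reasonable default for CPU/edge inference.
student = nn.Sequential(
    nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10)
)  # stand-in for the distilled student

quantized = torch.ao.quantization.quantize_dynamic(
    student, {nn.Linear}, dtype=torch.qint8
)

x = torch.randn(1, 128)
with torch.no_grad():
    out = quantized(x)
print(out.shape)  # torch.Size([1, 10])
```

Ordering matters: quantizing first would force the student to learn from an already-degraded starting point, while distilling first gives quantization a well-trained, compact model to compress.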
For LLMs, distillation has produced the entire "Distil*" family (DistilBERT, DistilGPT2) and underpins Microsoft's Phi models (2.7B parameters with quality comparable to 7B models). The 2026 trend of Small Language Models (SLMs) surpassing cloud LLMs in usage frequency, as reported by Gartner, is made possible precisely by distillation, which transfers knowledge from giants to models that run on a Raspberry Pi or a smartphone.
The next article shows how to deploy these compressed models on edge devices: Raspberry Pi, NVIDIA Jetson, and embedded hardware, with all the optimizations needed for real production environments.
Next Steps
- Next article: Deep Learning on Edge Devices: From Cloud to Edge
- Related: Model Quantization: GPTQ, AWQ, INT8
- Related: Neural Network Pruning: Reducing Parameters
- Related: Vision Transformer: distillation with DeiT
- MLOps series: Serving Compressed Models in Production