知識の蒸留: 複雑なモデルの圧縮
GPT-4 はラップトップで実行するには大きすぎます。 ResNet-152 はあなたのものには遅すぎます モバイルデバイス。しかし、これらの巨大なモデルから得られた知識は GPU クラスターでの数週間のトレーニングを通じて、 転送された モデルに はるかに小さくて高速ですが、精度の低下は驚くほどわずかです。 これが約束です 知識の蒸留.
2015年にヒントン、ビニャルズ、ディーンによって提案され、蒸留は技術の1つとなった 最新の深層学習ツールキットの中で最も強力で多用途です。原則はエレガントです: 大きなモデル(教師) より小さなモデル (学生) 単にハードラベル (0/1) を使用するのではなく、 ソフト確率 教師の確率分布 - クラス間の類似性に関する情報が豊富な確率分布。 この暗黙の情報、つまり「闇の知識」が蒸留の原因となります。 とても効果的です。
このガイドでは、元の理論からバリエーションまで、蒸留について詳しく説明します。 モダン (特徴蒸留、注意喚起、自己蒸留、LLM 蒸留)、 PyTorch での完全な実装、本番環境のベスト プラクティス、および実際のケース スタディ。
何を学ぶか
- 蒸留理論: ソフトラベル、温度、そして闇の知識
- PyTorch を使用した標準蒸留の実装
- 特徴の抽出: 中間表現の転送
- 注意転移: トランスフォーマーの注意マップを抽出する
- 自己蒸留とボーン・アゲイン・ネットワーク
- オフライン vs オンライン vs 自己蒸留
- LLM の蒸留: 大型モデルからエッジ モデルまで
- 蒸留と量子化および枝刈りを組み合わせる
- ベスト プラクティス、一般的なエラー、評価指標
- ケーススタディ: DitilBERT のステップバイステップ
蒸留の原理: 闇の知識
ハード ラベル (ワンホット) でトレーニングされた標準の分類子は、ほとんど情報を使用しません。 正しいクラスの確率は 1、他のクラスはすべて 0 です。しかし、よく訓練されたモデルは多くのことを知っています。 詳細: 彼は猫が車よりも犬に似ていることを知っています。この情報は、 たとえ答えが正しい場合でも、教師の確率分布に含まれる 確率はほぼ1です。
蒸留のコツは、次のものを使用することです 温度T 「柔らかくする」 教師の確率によって、可能性は低いが有益なクラス間の差異が増幅されます。 T=1 の場合、元の分布が得られます。 T が高い (例: T=4) と、確率が大きくなります。 均一であり、暗黙の類似関係を明らかにします。このメカニズムはと呼ばれます 闇の知識 — モデルのロジットに隠された知識。 単純なバイナリラベルはキャプチャできません。
# Visualizzazione effetto della temperatura sulla dark knowledge
import torch
import torch.nn.functional as F
import numpy as np
# Supponiamo che il teacher produca questi logits per un campione
# con classe vera = 0 (gatto)
teacher_logits = torch.tensor([8.2, 2.1, 1.8, 0.5, 0.3, -0.2, -0.5, -0.8, -1.1, -1.5])
# Classi ipotetiche: [gatto, cane, felino, leone, volpe, auto, aereo, nave, treno, barca]
classi = ["gatto", "cane", "felino", "leone", "volpe", "auto", "aereo", "nave", "treno", "barca"]
print("Effetto della temperatura sulle soft probabilities:")
print("-" * 75)
for T in [1, 2, 4, 8, 20]:
probs = F.softmax(teacher_logits / T, dim=0)
entropia = -(probs * probs.log()).sum().item()
print(f"T={T:2d}: p(gatto)={probs[0]:.4f}, "
f"p(cane)={probs[1]:.4f}, p(felino)={probs[2]:.4f}, "
f"entropia={entropia:.3f}")
# Output:
# T= 1: p(gatto)=0.9833, p(cane)=0.0106, p(felino)=0.0079, entropia=0.123
# T= 2: p(gatto)=0.9175, p(cane)=0.0440, p(felino)=0.0297, entropia=0.424
# T= 4: p(gatto)=0.7562, p(cane)=0.1168, p(felino)=0.0913, entropia=0.895
# T= 8: p(gatto)=0.5756, p(cane)=0.1572, p(felino)=0.1343, entropia=1.387
# T=20: p(gatto)=0.3520, p(cane)=0.1668, p(felino)=0.1600, entropia=1.944
#
# Con T alta, il teacher rivela che "gatto" e molto simile a "cane" e "felino"
# ma niente a che fare con "auto" o "aereo".
# Questa e la DARK KNOWLEDGE che lo student impara!
print("\nAnalisi della dark knowledge:")
probs_t1 = F.softmax(teacher_logits / 1, dim=0)
probs_t4 = F.softmax(teacher_logits / 4, dim=0)
for i, classe in enumerate(classi):
print(f" {classe:8s}: T=1: {probs_t1[i]:.4f}, T=4: {probs_t4[i]:.4f}")
蒸留の数学
蒸留損失は、アルファ バランシング ハイパーパラメーターを使用して 2 つの項を組み合わせます。
- L_蒸留: 生徒と教師のソフト確率 (温度 T) 間の KL 発散に、勾配サイズの減少を補償するために T² を乗じたもの
- L_学生: 学生の予測とハードラベルの間の標準クロスエントロピー
最終的な式は次のとおりです。 L = アルファ * T² * KL(学生ソフト || 教師ソフト) + (1-アルファ) * CE(学生、ラベル)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, datasets, transforms
from torch.utils.data import DataLoader
# ============================================================
# DISTILLATION LOSS - Implementazione completa
# ============================================================
class DistillationLoss(nn.Module):
"""
Loss per Knowledge Distillation (Hinton et al., 2015).
L = alpha * T^2 * KL(student_soft || teacher_soft) + (1-alpha) * CE(student, labels)
Il fattore T^2 e fondamentale: quando si usa T > 1,
i gradienti si riducono di 1/T^2. Moltiplicando per T^2
si compensa questo effetto, mantenendo scale dei gradienti
coerenti tra la KD loss e la CE loss.
"""
def __init__(self, temperature: float = 4.0, alpha: float = 0.7):
"""
temperature: scala le probabilità soft (tipico: 2-8)
alpha: peso della distillation loss (tipico: 0.5-0.9)
"""
super().__init__()
self.T = temperature
self.alpha = alpha
self.ce_loss = nn.CrossEntropyLoss()
self.kl_loss = nn.KLDivLoss(reduction='batchmean')
def forward(self, student_logits: torch.Tensor,
teacher_logits: torch.Tensor,
labels: torch.Tensor) -> dict:
# Soft probabilities a temperatura T
# NB: log_softmax per student (richiesto da KLDivLoss)
student_soft = F.log_softmax(student_logits / self.T, dim=1)
teacher_soft = F.softmax(teacher_logits / self.T, dim=1)
# KL divergence * T^2 per compensare la riduzione del gradiente
loss_distill = self.kl_loss(student_soft, teacher_soft) * (self.T ** 2)
# Cross-entropy standard con hard labels
loss_student = self.ce_loss(student_logits, labels)
# Combinazione pesata
total_loss = self.alpha * loss_distill + (1 - self.alpha) * loss_student
return {
'total': total_loss,
'distill': loss_distill.detach(),
'student': loss_student.detach()
}
# ============================================================
# MODELLI TEACHER e STUDENT
# ============================================================
def create_teacher_student(n_classes: int = 100):
"""
Teacher: ResNet-50 (~25M parametri) - pre-trained su ImageNet
Student: MobileNetV3-Small (~2.5M parametri) - 10x più piccolo
"""
teacher = models.resnet50(pretrained=True)
teacher.fc = nn.Linear(teacher.fc.in_features, n_classes)
student = models.mobilenet_v3_small(pretrained=False)
student.classifier[3] = nn.Linear(
student.classifier[3].in_features, n_classes
)
total_teacher = sum(p.numel() for p in teacher.parameters())
total_student = sum(p.numel() for p in student.parameters())
flops_teacher = 4.1e9 # Approx FLOPs ResNet-50
flops_student = 0.056e9 # Approx FLOPs MobileNetV3-S
print(f"Teacher (ResNet-50): {total_teacher:,} param, {flops_teacher/1e9:.1f}G FLOPs")
print(f"Student (MobileNetV3): {total_student:,} param, {flops_student*1000:.0f}M FLOPs")
print(f"Fattore compressione: {total_teacher/total_student:.1f}x param, "
f"{flops_teacher/flops_student:.0f}x FLOPs")
return teacher, student
# ============================================================
# TRAINING LOOP CON DISTILLAZIONE
# ============================================================
def train_with_distillation(
teacher: nn.Module,
student: nn.Module,
train_loader: DataLoader,
val_loader: DataLoader,
n_epochs: int = 50,
temperature: float = 4.0,
alpha: float = 0.7,
lr: float = 1e-3,
device: str = "cuda"
):
teacher = teacher.to(device).eval() # Teacher: SOLO inference, no backprop!
student = student.to(device)
criterion = DistillationLoss(temperature=temperature, alpha=alpha)
optimizer = torch.optim.AdamW(student.parameters(), lr=lr, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
best_acc = 0.0
history = {'train_loss': [], 'val_acc': [], 'distill_loss': [], 'student_loss': []}
for epoch in range(n_epochs):
student.train()
total_loss = distill_loss_sum = student_loss_sum = 0.0
n_batches = 0
for imgs, labels in train_loader:
imgs, labels = imgs.to(device), labels.to(device)
# Forward pass teacher (CRITICO: no gradients, risparmia memoria!)
with torch.no_grad():
teacher_logits = teacher(imgs)
# Forward pass student (con gradients)
student_logits = student(imgs)
# Loss combinata
losses = criterion(student_logits, teacher_logits, labels)
optimizer.zero_grad()
losses['total'].backward()
torch.nn.utils.clip_grad_norm_(student.parameters(), max_norm=1.0)
optimizer.step()
total_loss += losses['total'].item()
distill_loss_sum += losses['distill'].item()
student_loss_sum += losses['student'].item()
n_batches += 1
scheduler.step()
# Validation
student.eval()
correct = total = 0
with torch.no_grad():
for imgs, labels in val_loader:
imgs, labels = imgs.to(device), labels.to(device)
preds = student(imgs).argmax(1)
correct += (preds == labels).sum().item()
total += labels.size(0)
val_acc = correct / total
avg_total = total_loss / n_batches
history['train_loss'].append(avg_total)
history['val_acc'].append(val_acc)
history['distill_loss'].append(distill_loss_sum / n_batches)
history['student_loss'].append(student_loss_sum / n_batches)
if val_acc > best_acc:
best_acc = val_acc
torch.save(student.state_dict(), "best_student.pth")
if (epoch + 1) % 10 == 0:
print(f"Epoch {epoch+1:3d} | Loss: {avg_total:.4f} "
f"| Val Acc: {val_acc:.4f} | Best: {best_acc:.4f}")
print(f"\nMiglior accuracy student: {best_acc:.4f}")
return history, best_acc
# Risultati tipici CIFAR-100:
# ResNet-50 teacher: 78.2% Top-1
# MobileNetV3-S senza KD: 67.1% Top-1
# MobileNetV3-S con KD: 71.4% Top-1 (+4.3%)
# MobileNetV3-S con KD+feat: 73.2% Top-1 (+6.1%)
# Compression: 10x param, 73x FLOPs
機能の抽出: 内部表現の転送
ソフトラベルでの蒸留では、教師の最終出力のみが転送されます。そこには 特徴の蒸留 さらに言えば、学生にも同じことを繰り返すよう強制します。 中間表現 教師の特徴 — ネットワークのさまざまなレベルでの特徴マップ。 これは、教師と生徒のアーキテクチャが大きく異なる場合に特に効果的です。 (例: CNN 教師、ViT 学生)、タスクに豊富な空間特徴が必要な場合。
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
# ============================================================
# FEATURE EXTRACTOR via Forward Hooks
# ============================================================
class FeatureExtractor:
"""Cattura le feature di layer specifici tramite forward hooks."""
def __init__(self, model: nn.Module, layer_names: list):
self.features = {}
self.hooks = []
for name, module in model.named_modules():
if name in layer_names:
hook = module.register_forward_hook(
lambda m, inp, out, n=name: self.features.update({n: out})
)
self.hooks.append(hook)
def get_features(self) -> list:
return list(self.features.values())
def clear(self):
self.features.clear()
def remove(self):
"""Rimuovi hooks per evitare memory leaks."""
for hook in self.hooks:
hook.remove()
self.hooks.clear()
# ============================================================
# FEATURE DISTILLATION LOSS
# ============================================================
class FeatureDistillationLoss(nn.Module):
"""
Loss che combina:
1. KD loss standard (soft labels output)
2. Feature Matching Loss (MSE tra feature intermedie normalizzate)
3. Relation-Based Loss (distanze relative tra sample nel batch)
"""
def __init__(self, student_channels: list, teacher_channels: list,
temperature: float = 4.0, alpha: float = 0.4,
beta: float = 0.4, gamma: float = 0.2):
"""
alpha: peso KD loss
beta: peso feature matching loss
gamma: peso CE loss standard
(alpha + beta + gamma deve essere 1.0)
"""
super().__init__()
assert abs(alpha + beta + gamma - 1.0) < 1e-6, "Pesi devono sommare a 1"
self.T = temperature
self.alpha = alpha
self.beta = beta
self.gamma = gamma
# Adattatori per allineare dimensioni teacher->student
# Esempio: teacher ha 2048 canali, student 96 -> adapter 1x1 conv
self.adapters = nn.ModuleList([
nn.Sequential(
nn.Conv2d(t_ch, s_ch, 1, bias=False),
nn.BatchNorm2d(s_ch),
nn.ReLU(inplace=True)
)
for t_ch, s_ch in zip(teacher_channels, student_channels)
])
def forward(self, student_logits, teacher_logits, labels,
student_features: list, teacher_features: list):
# 1. KD Loss (soft labels)
kl = nn.KLDivLoss(reduction='batchmean')
loss_kd = kl(
F.log_softmax(student_logits / self.T, dim=1),
F.softmax(teacher_logits / self.T, dim=1)
) * self.T ** 2
# 2. CE Loss standard
loss_ce = F.cross_entropy(student_logits, labels)
# 3. Feature Matching Loss
loss_feat = torch.tensor(0.0, device=student_logits.device)
for i, (s_feat, t_feat) in enumerate(zip(student_features, teacher_features)):
# Adatta canali del teacher allo student
t_adapted = self.adapters[i](t_feat.detach())
# Allinea risoluzione spaziale se necessario
if s_feat.shape[2:] != t_adapted.shape[2:]:
t_adapted = F.interpolate(
t_adapted, size=s_feat.shape[2:],
mode='bilinear', align_corners=False
)
# Normalizza le feature (cosine similarity invece di MSE)
s_norm = F.normalize(s_feat.view(s_feat.size(0), -1), dim=1)
t_norm = F.normalize(t_adapted.view(t_adapted.size(0), -1), dim=1)
# MSE tra feature normalizzate
loss_feat = loss_feat + F.mse_loss(s_norm, t_norm)
loss_feat = loss_feat / max(len(student_features), 1)
total = self.alpha * loss_kd + self.beta * loss_feat + self.gamma * loss_ce
return {
'total': total,
'kd': loss_kd.detach(),
'ce': loss_ce.detach(),
'feat': loss_feat.detach()
}
# Configurazione per ResNet-50 teacher -> MobileNetV3-S student
# Layer teacher: [layer2, layer3, layer4] -> Canali: [512, 1024, 2048]
# Layer student: [features.4, features.9, features.12] -> Canali: [40, 96, 576]
teacher = models.resnet50(pretrained=True)
student = models.mobilenet_v3_small(pretrained=False)
teacher_layers = ['layer2', 'layer3', 'layer4']
student_layers = ['features.4', 'features.9', 'features.12']
teacher_channels = [512, 1024, 2048]
student_channels = [40, 96, 576]
teacher_extractor = FeatureExtractor(teacher, teacher_layers)
student_extractor = FeatureExtractor(student, student_layers)
feat_criterion = FeatureDistillationLoss(
student_channels=student_channels,
teacher_channels=teacher_channels,
temperature=4.0, alpha=0.4, beta=0.4, gamma=0.2
)
print("Feature Distillation setup completato!")
print(f"Teacher layers: {teacher_layers}")
print(f"Student layers: {student_layers}")
Transformer および Vision Transformer の注意転送
ビジョン トランスフォーマーは、明示的なアテンション マップを生成します。 蒸留した。 ディート (Data-efficient Image Transformer) はこのアプローチを使用します 特別な蒸留トークン付き。ザ」注意力の伝達 (ザゴルイコ & Komodakis、2017) この概念を CNN にも拡張し、アテンション マップを構築します 畳み込み層の活性化から。
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
# ============================================================
# ATTENTION TRANSFER (AT) per CNN
# ============================================================
class AttentionTransferLoss(nn.Module):
"""
Attention Transfer (Zagoruyko & Komodakis, 2017).
Forza lo student a replicare le attention maps del teacher.
Efficace per transfer tra architetture diverse (CNN <-> ViT).
"""
def __init__(self, beta: float = 1000.0):
super().__init__()
self.beta = beta
def attention_map(self, features: torch.Tensor) -> torch.Tensor:
"""
Calcola mappa di attention come norma L2 quadrata delle attivazioni.
Input features: [B, C, H, W]
Output: [B, H*W] normalizzato (attention map piatta)
"""
# Somma sui canali -> [B, H, W]
attention = features.pow(2).sum(dim=1)
# Appiattisci -> [B, H*W]
attention = attention.view(attention.size(0), -1)
# Normalizza L2 per ogni sample nel batch
return F.normalize(attention, p=2, dim=1)
def forward(self, student_features: list, teacher_features: list) -> torch.Tensor:
"""Calcola AT loss su più livelli."""
total_loss = torch.tensor(0.0)
for s_feat, t_feat in zip(student_features, teacher_features):
s_attn = self.attention_map(s_feat)
t_attn = self.attention_map(t_feat).detach()
# Allinea dimensioni spaziali se necessario
if s_attn.shape != t_attn.shape:
s_h = int(s_feat.shape[2] * s_feat.shape[3])
t_h = int(t_feat.shape[2] * t_feat.shape[3])
# Usa interpolazione sull'attention map 2D
s_2d = s_feat.pow(2).mean(1, keepdim=True)
t_2d = t_feat.pow(2).mean(1, keepdim=True)
t_2d = F.interpolate(t_2d, size=s_feat.shape[2:], mode='bilinear')
s_attn = F.normalize(s_2d.view(s_2d.size(0), -1), p=2, dim=1)
t_attn = F.normalize(t_2d.view(t_2d.size(0), -1), p=2, dim=1).detach()
total_loss = total_loss + (s_attn - t_attn).pow(2).mean()
return self.beta * total_loss / max(len(student_features), 1)
# ============================================================
# DeiT-STYLE: Distillation Token per Vision Transformer
# ============================================================
class ViTWithDistillationToken(nn.Module):
"""
Aggiunge un distillation token a un ViT standard.
Come in DeiT: il token impara a replicare le predizioni
di un teacher CNN (es. RegNet, ResNet).
Durante inference: media CLS token + dist token.
Durante training: loss su entrambi i token.
"""
def __init__(self, vit_model: nn.Module, n_classes: int, d_model: int = 384):
super().__init__()
self.vit = vit_model
# Token di distillazione learnable
self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model))
nn.init.trunc_normal_(self.dist_token, std=0.02)
# Distillation head (separato dal CLS head)
self.dist_head = nn.Linear(d_model, n_classes)
def forward(self, x: torch.Tensor, return_dist: bool = False):
# Ottieni feature dal ViT
features = self.vit.forward_features(x)
# CLS prediction (predizione principale)
cls_pred = self.vit.head(features[:, 0])
# Dist token prediction (guida del teacher)
dist_pred = self.dist_head(features[:, 1]) # Assumendo dist_token al pos 1
if self.training:
return cls_pred, dist_pred
else:
# Inference: media delle due predizioni
return (cls_pred + dist_pred) / 2.0
def deit_distillation_loss(cls_pred, dist_pred, teacher_pred, labels,
alpha: float = 0.5, temperature: float = 3.0):
"""
Loss DeiT: combina CE hard labels + KD dal teacher CNN.
"""
# Hard label loss sul CLS token
loss_cls = F.cross_entropy(cls_pred, labels)
# Soft label loss sul distillation token
loss_dist = F.kl_div(
F.log_softmax(dist_pred / temperature, dim=1),
F.softmax(teacher_pred / temperature, dim=1),
reduction='batchmean'
) * temperature ** 2
return alpha * loss_dist + (1 - alpha) * loss_cls
LLMの蒸留:大型モデルから小型モデルまで
LLM の蒸留は同じ原理に従いますが、いくつかの重要な特徴があります。 語彙は膨大 (32,000 ~ 128,000 トークン) であり、教師と生徒のモデルは次のとおりである必要があります。 トークナイザー レベルで互換性があり、損失はシーケンス内の各トークンのレベルで発生します。 DistilBERT、DistilGPT2、および Microsoft の Phi ファミリは成功例です。
from transformers import (
AutoModelForCausalLM, AutoTokenizer,
AutoModelForSequenceClassification
)
import torch
import torch.nn.functional as F
# ============================================================
# DISTILLAZIONE LLM: causal language modeling
# ============================================================
def distill_llm_batch(
teacher_model,
student_model,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
temperature: float = 2.0,
alpha: float = 0.7
) -> dict:
"""
Distillazione LLM per next-token prediction.
Funziona per GPT-style (causal) e BERT-style (masked).
teacher_model: modello grande (es. Llama-3-8B)
student_model: modello piccolo (es. Llama-3-1B)
alpha: peso KD loss (1-alpha = peso CE loss standard)
"""
device = next(student_model.parameters()).device
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
# Teacher inference (no gradients, può essere su device diverso)
with torch.no_grad():
teacher_outputs = teacher_model(
input_ids, attention_mask=attention_mask
)
teacher_logits = teacher_outputs.logits # [B, seq_len, vocab_size]
# Student inference (con gradients)
student_outputs = student_model(
input_ids, attention_mask=attention_mask
)
student_logits = student_outputs.logits
# Shift per next-token prediction (esclude l'ultimo token come input)
shift_student = student_logits[:, :-1, :].contiguous()
shift_teacher = teacher_logits[:, :-1, :].contiguous()
shift_labels = input_ids[:, 1:].contiguous()
# Ridimensiona per calcolo per-token loss
B, S, V = shift_student.shape
shift_student_flat = shift_student.view(B * S, V)
shift_teacher_flat = shift_teacher.view(B * S, V)
shift_labels_flat = shift_labels.view(B * S)
# 1. KD Loss: KL divergence per ogni token
student_log_probs = F.log_softmax(shift_student_flat / temperature, dim=-1)
teacher_probs = F.softmax(shift_teacher_flat / temperature, dim=-1)
loss_kd = F.kl_div(student_log_probs, teacher_probs,
reduction='batchmean') * temperature ** 2
# 2. CE Loss standard (con label -100 per token da ignorare)
loss_ce = F.cross_entropy(
shift_student_flat, shift_labels_flat,
ignore_index=-100 # Padding tokens
)
total = alpha * loss_kd + (1 - alpha) * loss_ce
return {
'total': total,
'kd': loss_kd.detach(),
'ce': loss_ce.detach(),
'perplexity': torch.exp(loss_ce).detach()
}
# ============================================================
# PIPELINE DISTILLAZIONE LLM COMPLETA
# ============================================================
def setup_llm_distillation(
teacher_name: str = "meta-llama/Llama-3.1-8B",
student_name: str = "meta-llama/Llama-3.2-1B",
device: str = "cuda"
):
"""
Setup per distillare un LLM grande in uno più piccolo.
IMPORTANTE: teacher e student devono condividere il tokenizer
per avere distribuzioni compatibili sullo stesso vocabolario.
"""
tokenizer = AutoTokenizer.from_pretrained(teacher_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Teacher: carica in FP16 per risparmiare memoria
teacher = AutoModelForCausalLM.from_pretrained(
teacher_name,
torch_dtype=torch.float16,
device_map="auto" # Distribuisce su più GPU se disponibili
)
teacher.eval()
# Student: carica in FP32 per training stabile
student = AutoModelForCausalLM.from_pretrained(
student_name,
torch_dtype=torch.float32
).to(device)
teacher_params = sum(p.numel() for p in teacher.parameters())
student_params = sum(p.numel() for p in student.parameters())
print(f"Teacher: {teacher_params/1e9:.1f}B parametri")
print(f"Student: {student_params/1e9:.1f}B parametri")
print(f"Compressione: {teacher_params/student_params:.1f}x")
return teacher, student, tokenizer
print("Setup distillazione LLM completato!")
自己蒸留とボーン・アゲイン・ネットワーク
La 自己蒸留 そして驚くべきバリエーション: モデルが教師として機能する 自分自身に。で ボーン・アゲイン・ネットワークス (BANs、Furlanello et al. 2018)、彼らはトレーニングします 同じアーキテクチャを持つモデルの連続世代: 各世代では、 以前は教師として。その結果、体系的な改善 (+1-2% トップ 1) が実現しました。 モデルのサイズを大きくします。
import torch
import torch.nn as nn
import copy
# ============================================================
# BORN AGAIN NETWORKS (BANs)
# ============================================================
def born_again_training(model_factory, train_loader, val_loader,
n_generations: int = 3,
temperature: float = 4.0,
n_epochs: int = 30,
device: str = "cuda"):
"""
Allena N generazioni con la stessa architettura.
Gen 1: training standard con CE loss.
Gen 2+: distillazione dalla generazione precedente.
Risultati tipici CIFAR-100:
Gen 1: 67.1% (baseline)
Gen 2: 70.8% (+3.7%)
Gen 3: 72.1% (+5.0%)
Ensemble 1+2+3: 74.8% (+7.7%)
"""
criterion_kd = DistillationLoss(temperature=temperature, alpha=0.7)
criterion_ce = nn.CrossEntropyLoss()
all_models = []
results = []
# === Generazione 1: training standard ===
print("Gen 1: training standard...")
gen1 = model_factory().to(device)
opt1 = torch.optim.SGD(gen1.parameters(), lr=0.1,
momentum=0.9, weight_decay=5e-4)
sch1 = torch.optim.lr_scheduler.MultiStepLR(opt1, milestones=[15, 25], gamma=0.1)
for epoch in range(n_epochs):
gen1.train()
for imgs, labels in train_loader:
imgs, labels = imgs.to(device), labels.to(device)
opt1.zero_grad()
criterion_ce(gen1(imgs), labels).backward()
opt1.step()
sch1.step()
acc1 = _evaluate(gen1, val_loader, device)
results.append(acc1)
all_models.append(gen1)
print(f" Gen 1 val acc: {acc1:.4f}")
teacher = copy.deepcopy(gen1).eval()
# === Generazioni successive con distillazione ===
for gen_idx in range(2, n_generations + 1):
print(f"Gen {gen_idx}: KD da gen {gen_idx-1}...")
student = model_factory().to(device)
opt = torch.optim.SGD(student.parameters(), lr=0.1,
momentum=0.9, weight_decay=5e-4)
sch = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[15, 25], gamma=0.1)
for epoch in range(n_epochs):
student.train()
for imgs, labels in train_loader:
imgs, labels = imgs.to(device), labels.to(device)
with torch.no_grad():
t_logits = teacher(imgs)
s_logits = student(imgs)
losses = criterion_kd(s_logits, t_logits, labels)
opt.zero_grad()
losses['total'].backward()
opt.step()
sch.step()
acc = _evaluate(student, val_loader, device)
results.append(acc)
all_models.append(student)
print(f" Gen {gen_idx} val acc: {acc:.4f}")
teacher = copy.deepcopy(student).eval()
# Ensemble di tutti i modelli (upper bound)
ensemble_acc = _ensemble_evaluate(all_models, val_loader, device)
print(f"\nEnsemble {n_generations} gen: {ensemble_acc:.4f}")
return results, all_models
def _evaluate(model, loader, device):
model.eval()
correct = total = 0
with torch.no_grad():
for imgs, labels in loader:
imgs, labels = imgs.to(device), labels.to(device)
correct += (model(imgs).argmax(1) == labels).sum().item()
total += labels.size(0)
return correct / total
def _ensemble_evaluate(models, loader, device):
"""Ensemble averaging delle predizioni."""
for m in models:
m.eval()
correct = total = 0
with torch.no_grad():
for imgs, labels in loader:
imgs, labels = imgs.to(device), labels.to(device)
logits_sum = torch.stack([m(imgs) for m in models]).mean(0)
correct += (logits_sum.argmax(1) == labels).sum().item()
total += labels.size(0)
return correct / total
典型的な蒸留結果 (ベンチマーク 2024 ~ 2025)
| タスク | 教師 | 学生 | KDなし | KDあり | 教師 | 圧縮 |
|---|---|---|---|---|---|---|
| CIFAR-100 | レスネット-50 | モバイルネットV3-S | 67.1% | 73.2% | 78.2% | 10xパラメータ |
| イメージネット | ヴィット-L/16 | DeiT-S | 79.8% | 83.1% | 87.1% | 5xパラメータ |
| 接着剤 (NLP) | BERT-ラージ | 蒸留BERT | 83.2% | 86.4% | 89.2% | 2倍のパラメータ、2倍の速度 |
| 分隊 (QA) | ロベルタ-L | ディスティルロベルタ | 82.1% | 85.8% | 90.4% | 2xパラメータ |
| LLM (複雑さ) | ラマ 3.1 8B | ラマ 3.2 1B | 8.24PPL | 7.81PPL | 6.12 PPL | 8xパラメータ |
KD は通常、2 ~ 10 倍少ないパラメータで生徒と教師の間のギャップの 70 ~ 85% を回復します。
生産パイプライン: 蒸留 + 定量
エッジ展開のための最も強力なワークフローは、蒸留と量子化を組み合わせたものです。 シーケンス: 最初に KD を使用してスチューデントを作成し (高精度を維持)、次に量子化します。 学生(サイズを縮小し、速度を向上させます)。組み合わせによりモデルを削減できる 元の教師と比較して 100 倍から 40 倍に向上し、精度の低下はわずか 5 ~ 10% です。
import torch
import torch.nn as nn
from torchvision import models
# ============================================================
# PIPELINE COMPLETA: Distillazione -> Quantizzazione -> ONNX
# ============================================================
def full_compression_pipeline(teacher_path: str, output_dir: str = "./compressed"):
"""
Pipeline completa per comprimere un modello per edge deployment.
Step 1: Carica teacher pre-trainato
Step 2: Distilla in student più piccolo
Step 3: Quantizza lo student (PTQ INT8)
Step 4: Esporta in ONNX per deployment cross-platform
"""
import os
os.makedirs(output_dir, exist_ok=True)
# STEP 1: Teacher
print("Step 1: Carico teacher...")
teacher = models.resnet50(pretrained=True)
teacher.fc = nn.Linear(2048, 10) # 10 classi
# In produzione: teacher.load_state_dict(torch.load(teacher_path))
teacher.eval()
teacher_size_mb = sum(p.numel() * p.element_size()
for p in teacher.parameters()) / (1024**2)
print(f" Teacher: {teacher_size_mb:.1f} MB, "
f"{sum(p.numel() for p in teacher.parameters())/1e6:.1f}M param")
# STEP 2: Student (dopo distillazione)
print("Step 2: Student con distillazione (simulato con MobileNetV3)...")
student = models.mobilenet_v3_small(pretrained=False)
student.classifier[3] = nn.Linear(
student.classifier[3].in_features, 10
)
# In produzione: train_with_distillation(teacher, student, ...)
# student.load_state_dict(torch.load("best_student.pth"))
student_size_mb = sum(p.numel() * p.element_size()
for p in student.parameters()) / (1024**2)
print(f" Student: {student_size_mb:.1f} MB, "
f"{sum(p.numel() for p in student.parameters())/1e6:.1f}M param")
print(f" Riduzione rispetto teacher: {teacher_size_mb/student_size_mb:.1f}x")
# STEP 3: Quantizzazione INT8 (PTQ)
print("Step 3: Quantizzazione INT8...")
student.eval()
# Quantizzazione dinamica (più semplice, leggermente meno efficiente)
student_quant = torch.quantization.quantize_dynamic(
student,
{nn.Linear}, # Quantizza solo Linear (Conv2d richiede calibrazione)
dtype=torch.qint8
)
# Per quantizzazione statica (più efficiente):
# student.qconfig = torch.quantization.get_default_qconfig('fbgemm')
# torch.quantization.prepare(student, inplace=True)
# calibrate(student, calib_loader) # Forward pass su dati di calibrazione
# torch.quantization.convert(student, inplace=True)
quant_size_mb = sum(p.numel() * p.element_size()
for p in student_quant.parameters()) / (1024**2)
print(f" Student INT8: ~{student_size_mb/4:.1f} MB (stima)")
print(f" Riduzione totale: ~{teacher_size_mb/(student_size_mb/4):.0f}x vs teacher")
# STEP 4: Export ONNX
print("Step 4: Export ONNX...")
dummy = torch.randn(1, 3, 224, 224)
onnx_path = f"{output_dir}/student_compressed.onnx"
torch.onnx.export(
student, # Usa FP32 per ONNX (compatibilità più ampia)
dummy,
onnx_path,
opset_version=13,
input_names=["input"],
output_names=["output"],
dynamic_axes={"input": {0: "batch_size"}},
export_params=True
)
print(f"\n=== RIEPILOGO PIPELINE ===")
print(f"Teacher (ResNet-50): {teacher_size_mb:.1f} MB")
print(f"Student KD (MobNetV3): {student_size_mb:.1f} MB ({teacher_size_mb/student_size_mb:.1f}x riduzione)")
print(f"Student INT8 (stimato): {student_size_mb/4:.1f} MB ({teacher_size_mb/(student_size_mb/4):.0f}x riduzione)")
print(f"ONNX salvato: {onnx_path}")
return student_quant, onnx_path
full_compression_pipeline("teacher_weights.pth")
蒸留におけるアンチパターン: よくある間違い
- 温度が高すぎる、または低すぎる: T=1 はハードラベルと同等です。 T が高すぎる (>20) と、ソフトラベルがほぼ均一になり、信号が失われます。 特定のデータセットに対して、T ∈ {2, 4, 6, 8} を使用してアブレーション スタディを常に実行します。
- 教師と生徒の能力が違いすぎる: ギャップが大きい場合 (GPT-4~7B)、直接蒸留は効果がありません。ウォーターフォール蒸留を使用します。 GPT-4 -> 13B -> 7B -> 3B。各ステップは 4 ~ 5 倍の縮小を超えてはなりません。
- 蒸留データセットの品質を無視する場合: の品質 蒸留を実行するデータセットは大きな影響を与えます。多様なデータを活用し、 対象タスクの代表者。配布外のデータにより転送が破損します。
- アルファのキャリブレーションが不十分: alpha=1 (KD のみ) の場合、学生は無視します。 真実のラベルが付けられ、教師が間違いを犯した場合には不安定な予測が生成される可能性があります。 alpha=0 では、標準的なトレーニングになります。通常、値 0.5 ~ 0.8 が最適です。
- 先生を凍らせないでください。 教師は eval() モードでなければなりません 学生の研修中。教師が変わり続ける場合 (例: BatchNorm がある場合) トレイン モードでは)、蒸留ターゲットに一貫性がなく、トレーニングが分岐します。
LLM の蒸留: 投機的デコードと応答蒸留
大規模言語モデルのコンテキストでは、蒸留は新しく強力な形式をとります。 2025 年から 2026 年に特に関連する 2 つの手法は次のとおりです。 応答蒸留 (Llama-3.2 および Microsoft の Phi モデルのトレーニングに使用) 投機的デコード、小規模なモデルを活用して推論を高速化します。 品質を損なうことなく大型モデルを再現できます。
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Optional
# ============================================================
# SPECULATIVE DECODING: Draft Model + Target Model
# ============================================================
# Principio: un modello piccolo (draft) genera K token in anticipo.
# Il modello grande (target) verifica tutti i K token in un solo forward pass.
# Se il draft ha ragione, si risparmiano K-1 forward pass del modello grande.
# Speedup tipico: 2-4x senza perdita di qualità.
class SpeculativeDecoder:
"""
Implementazione base di speculative decoding.
Draft model: modello piccolo (es. Llama-3.2-1B)
Target model: modello grande (es. Llama-3.1-8B)
Riferimento: "Fast Inference from Transformers via Speculative Decoding"
(Leviathan et al., 2022) - il paper originale di Google.
"""
def __init__(
self,
draft_model_name: str,
target_model_name: str,
device: str = "cuda",
lookahead_k: int = 5 # Token generati dal draft per ogni step
):
self.device = device
self.lookahead_k = lookahead_k
print(f"Carico draft model: {draft_model_name}...")
self.draft_model = AutoModelForCausalLM.from_pretrained(
draft_model_name, torch_dtype=torch.float16
).to(device).eval()
print(f"Carico target model: {target_model_name}...")
self.target_model = AutoModelForCausalLM.from_pretrained(
target_model_name, torch_dtype=torch.float16
).to(device).eval()
self.tokenizer = AutoTokenizer.from_pretrained(target_model_name)
def draft_generate(self, input_ids: torch.Tensor) -> tuple:
"""
Il draft model genera K token e restituisce
le distribuzioni di probabilità per acceptance/rejection.
"""
draft_ids = input_ids.clone()
draft_probs = []
with torch.no_grad():
for _ in range(self.lookahead_k):
out = self.draft_model(draft_ids)
next_logits = out.logits[:, -1, :]
next_probs = F.softmax(next_logits, dim=-1)
# Campiona dal draft
next_token = torch.multinomial(next_probs, num_samples=1)
draft_probs.append(next_probs)
draft_ids = torch.cat([draft_ids, next_token], dim=1)
# draft_ids ora include K token aggiuntivi
draft_tokens = draft_ids[:, input_ids.shape[1]:]
return draft_tokens, draft_probs
def speculative_generate(
self,
prompt: str,
max_new_tokens: int = 100,
temperature: float = 0.7
) -> str:
"""Genera testo con speculative decoding."""
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
generated = input_ids.clone()
total_accepted = 0
total_draft = 0
with torch.no_grad():
while generated.shape[1] - input_ids.shape[1] < max_new_tokens:
# 1. Draft genera K token
draft_tokens, draft_probs = self.draft_generate(generated)
total_draft += self.lookahead_k
# 2. Target verifica tutti i K+1 token in un forward pass
full_seq = torch.cat([generated, draft_tokens], dim=1)
target_out = self.target_model(full_seq)
target_logits = target_out.logits[:, generated.shape[1]-1:-1, :]
# 3. Acceptance-rejection sampling
n_accepted = 0
for i in range(draft_tokens.shape[1]):
draft_tok = draft_tokens[0, i].item()
target_probs = F.softmax(target_logits[0, i] / temperature, dim=-1)
draft_p = draft_probs[i][0, draft_tok].item()
target_p = target_probs[draft_tok].item()
# Accetta se target d'accordo con draft
r = torch.rand(1).item()
if r < min(1.0, target_p / (draft_p + 1e-10)):
n_accepted += 1
else:
# Rifiuta: campiona dal target corretto
corrected = torch.multinomial(
F.relu(target_probs - draft_probs[i][0]),
num_samples=1
)
generated = torch.cat([
generated,
draft_tokens[:, :i],
corrected.unsqueeze(0)
], dim=1)
break
else:
# Tutti accettati: aggiungi bonus token dal target
generated = torch.cat([generated, draft_tokens], dim=1)
bonus_logits = target_out.logits[:, -1, :]
bonus_tok = torch.multinomial(
F.softmax(bonus_logits / temperature, dim=-1), 1
)
generated = torch.cat([generated, bonus_tok], dim=1)
total_accepted += n_accepted
acceptance_rate = total_accepted / max(total_draft, 1)
print(f"Acceptance rate: {acceptance_rate:.1%} (atteso 60-80% con draft simile)")
new_tokens = generated[0, input_ids.shape[1]:]
return self.tokenizer.decode(new_tokens, skip_special_tokens=True)
# ============================================================
# RESPONSE DISTILLATION per LLM (semplificata)
# ============================================================
# Tecnica usata da Llama-3.2, Phi-3, Mistral-7B-Instruct:
# 1. Teacher LLM grande (es. GPT-4, Llama-3.1-70B) genera risposte
# 2. Student LLM piccolo impara a imitare quelle risposte
# Diverso dalla distillazione classica: distilla risposte (output testo),
# non distribuzioni di probabilità interne.
def response_distillation_dataset(
teacher_model_name: str,
prompts: list,
output_file: str = "distillation_dataset.jsonl"
) -> list:
"""
Genera dataset di distillazione con risposte del teacher.
In produzione: usa GPT-4 API, Llama-3.1-70B, o Claude.
"""
import json
tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
teacher = AutoModelForCausalLM.from_pretrained(
teacher_model_name,
torch_dtype=torch.float16,
device_map="auto"
).eval()
dataset = []
with open(output_file, 'w') as f:
for prompt in prompts:
inputs = tokenizer(prompt, return_tensors="pt").to(teacher.device)
with torch.no_grad():
outputs = teacher.generate(
**inputs,
max_new_tokens=512,
temperature=0.7,
do_sample=True,
top_p=0.9
)
response_ids = outputs[0, inputs['input_ids'].shape[1]:]
response = tokenizer.decode(response_ids, skip_special_tokens=True)
entry = {"prompt": prompt, "response": response}
dataset.append(entry)
f.write(json.dumps(entry, ensure_ascii=False) + "\n")
print(f"Dataset generato: {len(dataset)} esempi -> {output_file}")
return dataset
# Note: in pratica, usa l'API di un servizio commerciale (OpenAI, Anthropic)
# per generare le risposte del "teacher", poi addestra lo student su di esse.
# Questa e la tecnica dietro la maggior parte dei modelli instruction-following
# come Alpaca, Vicuna, Orca, e i modelli Phi di Microsoft.
print("LLM distillation patterns pronti")
蒸留: バリアント間の技術的比較 (2024-2025)
| 変異体 | 蒸留するもの | プロ | に対して | 一般的な使用方法 |
|---|---|---|---|---|
| ソフトラベル (ヒントン 2015) | 確率分布 | 豊富な標準情報 | 教師のロジットへのアクセスが必要です | ビジョン、分類 |
| 特徴の蒸留 | 中間表現 | 深い機能の転送 | 教師と生徒は互換性のあるアーキテクチャを持っている必要があります | 検出、セグメンテーション |
| 応答蒸留 | 教師のテキスト出力 | 内部アクセスは必要ありません | 不確実性に関する情報が失われる | LLM 命令追従 |
| 生まれ変わったネットワーク | 反復自己蒸留 | 個別の教師は必要ありません | ゲインが限られており、計算コストが高い | アンサンブル、改良 |
| 投機的デコード | 蒸留ではなくドラフトを使用します | 損失なしで 2 ~ 4 倍のスピードアップ | メモリ内に 2 つのモデルが必要 | LLM 推論の高速化 |
結論
知識の蒸留は、最も強力で汎用性の高い圧縮技術の 1 つです 2026 年に利用可能になります。量子化と枝刈りを自然に組み合わせる: 最初に蒸留 最適なスチューデントを作成し、エッジ展開用にスチューデントを量子化します。結果 そして多くの場合、モデルは教師よりも 10 ~ 100 分の 1 小さく、精度は 5 ~ 15% 低下するだけです。
LLM の場合、蒸留により「Distil*」モデルのファミリー全体が可能になりました: DistilBERT、 DistilGPT2、および Microsoft の Phi モデル (2.7B、7B モデル品質)。 2026年のトレンド — 使用頻度において LLM を上回る小型言語モデル (SLM) Gartner — 蒸留によって正確に可能となり、次の知識が伝達されます。 巨大なものから、Raspberry Pi やスマートフォンで動作するモデルまで。
次の記事では、これらの圧縮モデルを エッジデバイス: Raspberry Pi、NVIDIA Jetson、組み込みハードウェア、必要なすべての最適化が施されています 実際の本番環境向け。
次のステップ
- 次の記事: エッジデバイス上のディープラーニング: クラウドからエッジまで
- 関連している: モデル量子化: GPTQ、AWQ、INT8
- 関連している: ニューラル ネットワークの枝刈り: パラメーターの削減
- 関連している: Vision Transformer: DeiT による蒸留
- MLOps シリーズ: 本番環境での圧縮モデルの提供







