지식 증류: 복잡한 모델 압축
GPT-4는 노트북에서 실행하기에는 너무 큽니다. ResNet-152는 너무 느립니다. 모바일 장치. 하지만 이 거대한 모델에서 얻은 지식은 GPU 클러스터에 대한 몇 주간의 교육을 통해 이전됨 모델에게 훨씬 더 작고 빠르며 놀라울 정도로 정확도가 떨어집니다. 이것이 약속이다 지식 증류.
2015년 Hinton, Vinyals 및 Dean이 제안한 증류는 기술 중 하나가 되었습니다. 최신 딥 러닝 툴킷 중 가장 강력하고 다재다능합니다. 원칙은 우아합니다. 대형 모델(선생님)는 더 작은 모델의 훈련을 안내합니다(학생) 단순히 하드 라벨(0/1)이 아닌 연성 확률 교사의 — 수업 간 유사성에 대한 정보가 풍부한 확률 분포입니다. 이 암묵적인 정보, 즉 "어두운 지식"이 증류를 만드는 것입니다. 매우 효과적입니다.
이 가이드에서는 원래 이론부터 변형까지 증류에 대해 심도 있게 탐구합니다. 현대적(특성 증류, 주의 이전, 자가 증류, LLM 증류) PyTorch의 완벽한 구현, 프로덕션 모범 사례 및 실제 사례 연구입니다.
무엇을 배울 것인가
- 증류 이론: 소프트 라벨, 온도 및 어두운 지식
- PyTorch를 사용하여 표준 증류 구현
- 기능 증류: 중간 표현 전송
- 주의 전달: Transformers의 주의 지도 증류하기
- 자기 증류 및 Born Again Networks
- 오프라인 vs 온라인 vs 자체 증류
- LLM을 위한 증류: 대형 모델에서 엣지 모델까지
- 증류와 양자화 및 가지치기 결합
- 모범 사례, 일반적인 오류 및 평가 지표
- 사례 연구: DitilBERT 단계별
증류의 원리: 어두운 지식
하드 라벨(원-핫)에 대해 훈련된 표준 분류자는 매우 적은 정보를 사용합니다. 올바른 클래스의 확률은 1이고 나머지는 모두 0입니다. 하지만 잘 훈련된 모델은 많은 것을 알고 있습니다. 더 보기: 그는 고양이가 자동차보다 개와 더 유사하다는 것을 알고 있습니다. 이 정보는 정답인 경우에도 교사의 확률 분포에 포함됨 확률은 거의 1입니다.
증류의 비결은 하나를 사용하는 것입니다 온도 T "부드럽게 하다" 교사의 확률은 가능성이 낮지만 유익한 수업 간의 차이를 증폭시킵니다. T=1이면 원래 분포를 갖습니다. T가 높으면(예: T=4) 확률이 더 커집니다. 균일하고 암묵적인 유사 관계를 드러냅니다. 이 메커니즘을 어두운 지식 — 모델의 로짓에 숨겨진 지식 단순 바이너리 라벨은 캡처할 수 없습니다.
# Visualizzazione effetto della temperatura sulla dark knowledge
import torch
import torch.nn.functional as F
import numpy as np
# Supponiamo che il teacher produca questi logits per un campione
# con classe vera = 0 (gatto)
teacher_logits = torch.tensor([8.2, 2.1, 1.8, 0.5, 0.3, -0.2, -0.5, -0.8, -1.1, -1.5])
# Classi ipotetiche: [gatto, cane, felino, leone, volpe, auto, aereo, nave, treno, barca]
classi = ["gatto", "cane", "felino", "leone", "volpe", "auto", "aereo", "nave", "treno", "barca"]
print("Effetto della temperatura sulle soft probabilities:")
print("-" * 75)
for T in [1, 2, 4, 8, 20]:
probs = F.softmax(teacher_logits / T, dim=0)
entropia = -(probs * probs.log()).sum().item()
print(f"T={T:2d}: p(gatto)={probs[0]:.4f}, "
f"p(cane)={probs[1]:.4f}, p(felino)={probs[2]:.4f}, "
f"entropia={entropia:.3f}")
# Output:
# T= 1: p(gatto)=0.9833, p(cane)=0.0106, p(felino)=0.0079, entropia=0.123
# T= 2: p(gatto)=0.9175, p(cane)=0.0440, p(felino)=0.0297, entropia=0.424
# T= 4: p(gatto)=0.7562, p(cane)=0.1168, p(felino)=0.0913, entropia=0.895
# T= 8: p(gatto)=0.5756, p(cane)=0.1572, p(felino)=0.1343, entropia=1.387
# T=20: p(gatto)=0.3520, p(cane)=0.1668, p(felino)=0.1600, entropia=1.944
#
# Con T alta, il teacher rivela che "gatto" e molto simile a "cane" e "felino"
# ma niente a che fare con "auto" o "aereo".
# Questa e la DARK KNOWLEDGE che lo student impara!
print("\nAnalisi della dark knowledge:")
probs_t1 = F.softmax(teacher_logits / 1, dim=0)
probs_t4 = F.softmax(teacher_logits / 4, dim=0)
for i, classe in enumerate(classi):
print(f" {classe:8s}: T=1: {probs_t1[i]:.4f}, T=4: {probs_t4[i]:.4f}")
증류의 수학
증류 손실은 두 가지 항을 알파 균형 하이퍼파라미터와 결합합니다.
- L_증류: 온도 T에서 학생과 교사의 소프트 확률 간 KL 발산에 T²를 곱하여 기울기 크기 감소를 보상합니다.
- L_학생: 학생 예측과 하드 라벨 간의 표준 교차 엔트로피
최종 공식은 다음과 같습니다. L = 알파 * T² * KL(student_soft || Teacher_soft) + (1-alpha) * CE(학생, 라벨)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, datasets, transforms
from torch.utils.data import DataLoader
# ============================================================
# DISTILLATION LOSS - Implementazione completa
# ============================================================
class DistillationLoss(nn.Module):
"""
Loss per Knowledge Distillation (Hinton et al., 2015).
L = alpha * T^2 * KL(student_soft || teacher_soft) + (1-alpha) * CE(student, labels)
Il fattore T^2 e fondamentale: quando si usa T > 1,
i gradienti si riducono di 1/T^2. Moltiplicando per T^2
si compensa questo effetto, mantenendo scale dei gradienti
coerenti tra la KD loss e la CE loss.
"""
def __init__(self, temperature: float = 4.0, alpha: float = 0.7):
"""
temperature: scala le probabilità soft (tipico: 2-8)
alpha: peso della distillation loss (tipico: 0.5-0.9)
"""
super().__init__()
self.T = temperature
self.alpha = alpha
self.ce_loss = nn.CrossEntropyLoss()
self.kl_loss = nn.KLDivLoss(reduction='batchmean')
def forward(self, student_logits: torch.Tensor,
teacher_logits: torch.Tensor,
labels: torch.Tensor) -> dict:
# Soft probabilities a temperatura T
# NB: log_softmax per student (richiesto da KLDivLoss)
student_soft = F.log_softmax(student_logits / self.T, dim=1)
teacher_soft = F.softmax(teacher_logits / self.T, dim=1)
# KL divergence * T^2 per compensare la riduzione del gradiente
loss_distill = self.kl_loss(student_soft, teacher_soft) * (self.T ** 2)
# Cross-entropy standard con hard labels
loss_student = self.ce_loss(student_logits, labels)
# Combinazione pesata
total_loss = self.alpha * loss_distill + (1 - self.alpha) * loss_student
return {
'total': total_loss,
'distill': loss_distill.detach(),
'student': loss_student.detach()
}
# ============================================================
# MODELLI TEACHER e STUDENT
# ============================================================
def create_teacher_student(n_classes: int = 100):
"""
Teacher: ResNet-50 (~25M parametri) - pre-trained su ImageNet
Student: MobileNetV3-Small (~2.5M parametri) - 10x più piccolo
"""
teacher = models.resnet50(pretrained=True)
teacher.fc = nn.Linear(teacher.fc.in_features, n_classes)
student = models.mobilenet_v3_small(pretrained=False)
student.classifier[3] = nn.Linear(
student.classifier[3].in_features, n_classes
)
total_teacher = sum(p.numel() for p in teacher.parameters())
total_student = sum(p.numel() for p in student.parameters())
flops_teacher = 4.1e9 # Approx FLOPs ResNet-50
flops_student = 0.056e9 # Approx FLOPs MobileNetV3-S
print(f"Teacher (ResNet-50): {total_teacher:,} param, {flops_teacher/1e9:.1f}G FLOPs")
print(f"Student (MobileNetV3): {total_student:,} param, {flops_student*1000:.0f}M FLOPs")
print(f"Fattore compressione: {total_teacher/total_student:.1f}x param, "
f"{flops_teacher/flops_student:.0f}x FLOPs")
return teacher, student
# ============================================================
# TRAINING LOOP CON DISTILLAZIONE
# ============================================================
def train_with_distillation(
teacher: nn.Module,
student: nn.Module,
train_loader: DataLoader,
val_loader: DataLoader,
n_epochs: int = 50,
temperature: float = 4.0,
alpha: float = 0.7,
lr: float = 1e-3,
device: str = "cuda"
):
teacher = teacher.to(device).eval() # Teacher: SOLO inference, no backprop!
student = student.to(device)
criterion = DistillationLoss(temperature=temperature, alpha=alpha)
optimizer = torch.optim.AdamW(student.parameters(), lr=lr, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
best_acc = 0.0
history = {'train_loss': [], 'val_acc': [], 'distill_loss': [], 'student_loss': []}
for epoch in range(n_epochs):
student.train()
total_loss = distill_loss_sum = student_loss_sum = 0.0
n_batches = 0
for imgs, labels in train_loader:
imgs, labels = imgs.to(device), labels.to(device)
# Forward pass teacher (CRITICO: no gradients, risparmia memoria!)
with torch.no_grad():
teacher_logits = teacher(imgs)
# Forward pass student (con gradients)
student_logits = student(imgs)
# Loss combinata
losses = criterion(student_logits, teacher_logits, labels)
optimizer.zero_grad()
losses['total'].backward()
torch.nn.utils.clip_grad_norm_(student.parameters(), max_norm=1.0)
optimizer.step()
total_loss += losses['total'].item()
distill_loss_sum += losses['distill'].item()
student_loss_sum += losses['student'].item()
n_batches += 1
scheduler.step()
# Validation
student.eval()
correct = total = 0
with torch.no_grad():
for imgs, labels in val_loader:
imgs, labels = imgs.to(device), labels.to(device)
preds = student(imgs).argmax(1)
correct += (preds == labels).sum().item()
total += labels.size(0)
val_acc = correct / total
avg_total = total_loss / n_batches
history['train_loss'].append(avg_total)
history['val_acc'].append(val_acc)
history['distill_loss'].append(distill_loss_sum / n_batches)
history['student_loss'].append(student_loss_sum / n_batches)
if val_acc > best_acc:
best_acc = val_acc
torch.save(student.state_dict(), "best_student.pth")
if (epoch + 1) % 10 == 0:
print(f"Epoch {epoch+1:3d} | Loss: {avg_total:.4f} "
f"| Val Acc: {val_acc:.4f} | Best: {best_acc:.4f}")
print(f"\nMiglior accuracy student: {best_acc:.4f}")
return history, best_acc
# Risultati tipici CIFAR-100:
# ResNet-50 teacher: 78.2% Top-1
# MobileNetV3-S senza KD: 67.1% Top-1
# MobileNetV3-S con KD: 71.4% Top-1 (+4.3%)
# MobileNetV3-S con KD+feat: 73.2% Top-1 (+6.1%)
# Compression: 10x param, 73x FLOPs
기능 추출: 내부 표현 전송
소프트 라벨의 증류는 교사의 최종 출력만 전송합니다. 거기 특징 증류 더 나아가서 학생이 다음을 복제하도록 강요합니다. 중간 표현 교사의 기능은 네트워크의 다양한 수준에 매핑됩니다. 이는 교사와 학생의 아키텍처가 매우 다를 때 특히 효과적입니다. (예: CNN 교사, ViT 학생) 및 작업에 풍부한 공간 기능이 필요한 경우.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
# ============================================================
# FEATURE EXTRACTOR via Forward Hooks
# ============================================================
class FeatureExtractor:
"""Cattura le feature di layer specifici tramite forward hooks."""
def __init__(self, model: nn.Module, layer_names: list):
self.features = {}
self.hooks = []
for name, module in model.named_modules():
if name in layer_names:
hook = module.register_forward_hook(
lambda m, inp, out, n=name: self.features.update({n: out})
)
self.hooks.append(hook)
def get_features(self) -> list:
return list(self.features.values())
def clear(self):
self.features.clear()
def remove(self):
"""Rimuovi hooks per evitare memory leaks."""
for hook in self.hooks:
hook.remove()
self.hooks.clear()
# ============================================================
# FEATURE DISTILLATION LOSS
# ============================================================
class FeatureDistillationLoss(nn.Module):
"""
Loss che combina:
1. KD loss standard (soft labels output)
2. Feature Matching Loss (MSE tra feature intermedie normalizzate)
3. Relation-Based Loss (distanze relative tra sample nel batch)
"""
def __init__(self, student_channels: list, teacher_channels: list,
temperature: float = 4.0, alpha: float = 0.4,
beta: float = 0.4, gamma: float = 0.2):
"""
alpha: peso KD loss
beta: peso feature matching loss
gamma: peso CE loss standard
(alpha + beta + gamma deve essere 1.0)
"""
super().__init__()
assert abs(alpha + beta + gamma - 1.0) < 1e-6, "Pesi devono sommare a 1"
self.T = temperature
self.alpha = alpha
self.beta = beta
self.gamma = gamma
# Adattatori per allineare dimensioni teacher->student
# Esempio: teacher ha 2048 canali, student 96 -> adapter 1x1 conv
self.adapters = nn.ModuleList([
nn.Sequential(
nn.Conv2d(t_ch, s_ch, 1, bias=False),
nn.BatchNorm2d(s_ch),
nn.ReLU(inplace=True)
)
for t_ch, s_ch in zip(teacher_channels, student_channels)
])
def forward(self, student_logits, teacher_logits, labels,
student_features: list, teacher_features: list):
# 1. KD Loss (soft labels)
kl = nn.KLDivLoss(reduction='batchmean')
loss_kd = kl(
F.log_softmax(student_logits / self.T, dim=1),
F.softmax(teacher_logits / self.T, dim=1)
) * self.T ** 2
# 2. CE Loss standard
loss_ce = F.cross_entropy(student_logits, labels)
# 3. Feature Matching Loss
loss_feat = torch.tensor(0.0, device=student_logits.device)
for i, (s_feat, t_feat) in enumerate(zip(student_features, teacher_features)):
# Adatta canali del teacher allo student
t_adapted = self.adapters[i](t_feat.detach())
# Allinea risoluzione spaziale se necessario
if s_feat.shape[2:] != t_adapted.shape[2:]:
t_adapted = F.interpolate(
t_adapted, size=s_feat.shape[2:],
mode='bilinear', align_corners=False
)
# Normalizza le feature (cosine similarity invece di MSE)
s_norm = F.normalize(s_feat.view(s_feat.size(0), -1), dim=1)
t_norm = F.normalize(t_adapted.view(t_adapted.size(0), -1), dim=1)
# MSE tra feature normalizzate
loss_feat = loss_feat + F.mse_loss(s_norm, t_norm)
loss_feat = loss_feat / max(len(student_features), 1)
total = self.alpha * loss_kd + self.beta * loss_feat + self.gamma * loss_ce
return {
'total': total,
'kd': loss_kd.detach(),
'ce': loss_ce.detach(),
'feat': loss_feat.detach()
}
# Configurazione per ResNet-50 teacher -> MobileNetV3-S student
# Layer teacher: [layer2, layer3, layer4] -> Canali: [512, 1024, 2048]
# Layer student: [features.4, features.9, features.12] -> Canali: [40, 96, 576]
teacher = models.resnet50(pretrained=True)
student = models.mobilenet_v3_small(pretrained=False)
teacher_layers = ['layer2', 'layer3', 'layer4']
student_layers = ['features.4', 'features.9', 'features.12']
teacher_channels = [512, 1024, 2048]
student_channels = [40, 96, 576]
teacher_extractor = FeatureExtractor(teacher, teacher_layers)
student_extractor = FeatureExtractor(student, student_layers)
feat_criterion = FeatureDistillationLoss(
student_channels=student_channels,
teacher_channels=teacher_channels,
temperature=4.0, alpha=0.4, beta=0.4, gamma=0.2
)
print("Feature Distillation setup completato!")
print(f"Teacher layers: {teacher_layers}")
print(f"Student layers: {student_layers}")
Transformer 및 Vision Transformer의 주의 전달
Vision Transformers는 직접적으로 확인할 수 있는 명시적인 주의 지도를 생성합니다. 증류. 디잇 (Data-efficient Image Transformer)는 이 접근 방식을 사용합니다. 특별한 증류 토큰으로. 그만큼'주의 전달 (자고루이코 & Komodakis, 2017) 이 개념을 CNN에도 확장하여 주의 지도를 구축합니다. 컨볼루션 레이어의 활성화로 인해 발생합니다.
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
# ============================================================
# ATTENTION TRANSFER (AT) per CNN
# ============================================================
class AttentionTransferLoss(nn.Module):
"""
Attention Transfer (Zagoruyko & Komodakis, 2017).
Forza lo student a replicare le attention maps del teacher.
Efficace per transfer tra architetture diverse (CNN <-> ViT).
"""
def __init__(self, beta: float = 1000.0):
super().__init__()
self.beta = beta
def attention_map(self, features: torch.Tensor) -> torch.Tensor:
"""
Calcola mappa di attention come norma L2 quadrata delle attivazioni.
Input features: [B, C, H, W]
Output: [B, H*W] normalizzato (attention map piatta)
"""
# Somma sui canali -> [B, H, W]
attention = features.pow(2).sum(dim=1)
# Appiattisci -> [B, H*W]
attention = attention.view(attention.size(0), -1)
# Normalizza L2 per ogni sample nel batch
return F.normalize(attention, p=2, dim=1)
def forward(self, student_features: list, teacher_features: list) -> torch.Tensor:
"""Calcola AT loss su più livelli."""
total_loss = torch.tensor(0.0)
for s_feat, t_feat in zip(student_features, teacher_features):
s_attn = self.attention_map(s_feat)
t_attn = self.attention_map(t_feat).detach()
# Allinea dimensioni spaziali se necessario
if s_attn.shape != t_attn.shape:
s_h = int(s_feat.shape[2] * s_feat.shape[3])
t_h = int(t_feat.shape[2] * t_feat.shape[3])
# Usa interpolazione sull'attention map 2D
s_2d = s_feat.pow(2).mean(1, keepdim=True)
t_2d = t_feat.pow(2).mean(1, keepdim=True)
t_2d = F.interpolate(t_2d, size=s_feat.shape[2:], mode='bilinear')
s_attn = F.normalize(s_2d.view(s_2d.size(0), -1), p=2, dim=1)
t_attn = F.normalize(t_2d.view(t_2d.size(0), -1), p=2, dim=1).detach()
total_loss = total_loss + (s_attn - t_attn).pow(2).mean()
return self.beta * total_loss / max(len(student_features), 1)
# ============================================================
# DeiT-STYLE: Distillation Token per Vision Transformer
# ============================================================
class ViTWithDistillationToken(nn.Module):
"""
Aggiunge un distillation token a un ViT standard.
Come in DeiT: il token impara a replicare le predizioni
di un teacher CNN (es. RegNet, ResNet).
Durante inference: media CLS token + dist token.
Durante training: loss su entrambi i token.
"""
def __init__(self, vit_model: nn.Module, n_classes: int, d_model: int = 384):
super().__init__()
self.vit = vit_model
# Token di distillazione learnable
self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model))
nn.init.trunc_normal_(self.dist_token, std=0.02)
# Distillation head (separato dal CLS head)
self.dist_head = nn.Linear(d_model, n_classes)
def forward(self, x: torch.Tensor, return_dist: bool = False):
# Ottieni feature dal ViT
features = self.vit.forward_features(x)
# CLS prediction (predizione principale)
cls_pred = self.vit.head(features[:, 0])
# Dist token prediction (guida del teacher)
dist_pred = self.dist_head(features[:, 1]) # Assumendo dist_token al pos 1
if self.training:
return cls_pred, dist_pred
else:
# Inference: media delle due predizioni
return (cls_pred + dist_pred) / 2.0
def deit_distillation_loss(cls_pred, dist_pred, teacher_pred, labels,
alpha: float = 0.5, temperature: float = 3.0):
"""
Loss DeiT: combina CE hard labels + KD dal teacher CNN.
"""
# Hard label loss sul CLS token
loss_cls = F.cross_entropy(cls_pred, labels)
# Soft label loss sul distillation token
loss_dist = F.kl_div(
F.log_softmax(dist_pred / temperature, dim=1),
F.softmax(teacher_pred / temperature, dim=1),
reduction='batchmean'
) * temperature ** 2
return alpha * loss_dist + (1 - alpha) * loss_cls
LLM을 위한 증류: 대형 모델부터 소형 모델까지
LLM의 증류는 동일한 원칙을 따르지만 몇 가지 중요한 특징이 있습니다. 어휘는 엄청납니다(32K-128K 토큰). 교사와 학생 모델은 다음과 같아야 합니다. 토크나이저 수준에서 호환되며 손실은 시퀀스의 각 토큰 수준에서 작동합니다. DistilBERT, DistilGPT2 및 Microsoft의 Phi 제품군이 성공적인 예입니다.
from transformers import (
AutoModelForCausalLM, AutoTokenizer,
AutoModelForSequenceClassification
)
import torch
import torch.nn.functional as F
# ============================================================
# DISTILLAZIONE LLM: causal language modeling
# ============================================================
def distill_llm_batch(
teacher_model,
student_model,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
temperature: float = 2.0,
alpha: float = 0.7
) -> dict:
"""
Distillazione LLM per next-token prediction.
Funziona per GPT-style (causal) e BERT-style (masked).
teacher_model: modello grande (es. Llama-3-8B)
student_model: modello piccolo (es. Llama-3-1B)
alpha: peso KD loss (1-alpha = peso CE loss standard)
"""
device = next(student_model.parameters()).device
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
# Teacher inference (no gradients, può essere su device diverso)
with torch.no_grad():
teacher_outputs = teacher_model(
input_ids, attention_mask=attention_mask
)
teacher_logits = teacher_outputs.logits # [B, seq_len, vocab_size]
# Student inference (con gradients)
student_outputs = student_model(
input_ids, attention_mask=attention_mask
)
student_logits = student_outputs.logits
# Shift per next-token prediction (esclude l'ultimo token come input)
shift_student = student_logits[:, :-1, :].contiguous()
shift_teacher = teacher_logits[:, :-1, :].contiguous()
shift_labels = input_ids[:, 1:].contiguous()
# Ridimensiona per calcolo per-token loss
B, S, V = shift_student.shape
shift_student_flat = shift_student.view(B * S, V)
shift_teacher_flat = shift_teacher.view(B * S, V)
shift_labels_flat = shift_labels.view(B * S)
# 1. KD Loss: KL divergence per ogni token
student_log_probs = F.log_softmax(shift_student_flat / temperature, dim=-1)
teacher_probs = F.softmax(shift_teacher_flat / temperature, dim=-1)
loss_kd = F.kl_div(student_log_probs, teacher_probs,
reduction='batchmean') * temperature ** 2
# 2. CE Loss standard (con label -100 per token da ignorare)
loss_ce = F.cross_entropy(
shift_student_flat, shift_labels_flat,
ignore_index=-100 # Padding tokens
)
total = alpha * loss_kd + (1 - alpha) * loss_ce
return {
'total': total,
'kd': loss_kd.detach(),
'ce': loss_ce.detach(),
'perplexity': torch.exp(loss_ce).detach()
}
# ============================================================
# PIPELINE DISTILLAZIONE LLM COMPLETA
# ============================================================
def setup_llm_distillation(
teacher_name: str = "meta-llama/Llama-3.1-8B",
student_name: str = "meta-llama/Llama-3.2-1B",
device: str = "cuda"
):
"""
Setup per distillare un LLM grande in uno più piccolo.
IMPORTANTE: teacher e student devono condividere il tokenizer
per avere distribuzioni compatibili sullo stesso vocabolario.
"""
tokenizer = AutoTokenizer.from_pretrained(teacher_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Teacher: carica in FP16 per risparmiare memoria
teacher = AutoModelForCausalLM.from_pretrained(
teacher_name,
torch_dtype=torch.float16,
device_map="auto" # Distribuisce su più GPU se disponibili
)
teacher.eval()
# Student: carica in FP32 per training stabile
student = AutoModelForCausalLM.from_pretrained(
student_name,
torch_dtype=torch.float32
).to(device)
teacher_params = sum(p.numel() for p in teacher.parameters())
student_params = sum(p.numel() for p in student.parameters())
print(f"Teacher: {teacher_params/1e9:.1f}B parametri")
print(f"Student: {student_params/1e9:.1f}B parametri")
print(f"Compressione: {teacher_params/student_params:.1f}x")
return teacher, student, tokenizer
print("Setup distillazione LLM completato!")
자기 증류 및 Born Again Networks
La 자가 증류 그리고 놀라운 변형: 모델이 교사 역할을 합니다. 자신에게. ~ 안에 다시 태어난 네트워크 (BAN, Furlanello et al. 2018), 그들은 훈련합니다 동일한 아키텍처를 사용하는 연속적인 세대의 모델: 각 세대는 전직 교사. 그 결과 체계적 개선(+1-2% Top-1)이 이루어졌습니다. 모델의 크기를 늘립니다.
import torch
import torch.nn as nn
import copy
# ============================================================
# BORN AGAIN NETWORKS (BANs)
# ============================================================
def born_again_training(model_factory, train_loader, val_loader,
n_generations: int = 3,
temperature: float = 4.0,
n_epochs: int = 30,
device: str = "cuda"):
"""
Allena N generazioni con la stessa architettura.
Gen 1: training standard con CE loss.
Gen 2+: distillazione dalla generazione precedente.
Risultati tipici CIFAR-100:
Gen 1: 67.1% (baseline)
Gen 2: 70.8% (+3.7%)
Gen 3: 72.1% (+5.0%)
Ensemble 1+2+3: 74.8% (+7.7%)
"""
criterion_kd = DistillationLoss(temperature=temperature, alpha=0.7)
criterion_ce = nn.CrossEntropyLoss()
all_models = []
results = []
# === Generazione 1: training standard ===
print("Gen 1: training standard...")
gen1 = model_factory().to(device)
opt1 = torch.optim.SGD(gen1.parameters(), lr=0.1,
momentum=0.9, weight_decay=5e-4)
sch1 = torch.optim.lr_scheduler.MultiStepLR(opt1, milestones=[15, 25], gamma=0.1)
for epoch in range(n_epochs):
gen1.train()
for imgs, labels in train_loader:
imgs, labels = imgs.to(device), labels.to(device)
opt1.zero_grad()
criterion_ce(gen1(imgs), labels).backward()
opt1.step()
sch1.step()
acc1 = _evaluate(gen1, val_loader, device)
results.append(acc1)
all_models.append(gen1)
print(f" Gen 1 val acc: {acc1:.4f}")
teacher = copy.deepcopy(gen1).eval()
# === Generazioni successive con distillazione ===
for gen_idx in range(2, n_generations + 1):
print(f"Gen {gen_idx}: KD da gen {gen_idx-1}...")
student = model_factory().to(device)
opt = torch.optim.SGD(student.parameters(), lr=0.1,
momentum=0.9, weight_decay=5e-4)
sch = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[15, 25], gamma=0.1)
for epoch in range(n_epochs):
student.train()
for imgs, labels in train_loader:
imgs, labels = imgs.to(device), labels.to(device)
with torch.no_grad():
t_logits = teacher(imgs)
s_logits = student(imgs)
losses = criterion_kd(s_logits, t_logits, labels)
opt.zero_grad()
losses['total'].backward()
opt.step()
sch.step()
acc = _evaluate(student, val_loader, device)
results.append(acc)
all_models.append(student)
print(f" Gen {gen_idx} val acc: {acc:.4f}")
teacher = copy.deepcopy(student).eval()
# Ensemble di tutti i modelli (upper bound)
ensemble_acc = _ensemble_evaluate(all_models, val_loader, device)
print(f"\nEnsemble {n_generations} gen: {ensemble_acc:.4f}")
return results, all_models
def _evaluate(model, loader, device):
model.eval()
correct = total = 0
with torch.no_grad():
for imgs, labels in loader:
imgs, labels = imgs.to(device), labels.to(device)
correct += (model(imgs).argmax(1) == labels).sum().item()
total += labels.size(0)
return correct / total
def _ensemble_evaluate(models, loader, device):
"""Ensemble averaging delle predizioni."""
for m in models:
m.eval()
correct = total = 0
with torch.no_grad():
for imgs, labels in loader:
imgs, labels = imgs.to(device), labels.to(device)
logits_sum = torch.stack([m(imgs) for m in models]).mean(0)
correct += (logits_sum.argmax(1) == labels).sum().item()
total += labels.size(0)
return correct / total
일반적인 증류 결과(벤치마크 2024-2025)
| 작업 | 선생님 | 학생 | KD 없이 | KD와 함께 | 선생님 | 압축 |
|---|---|---|---|---|---|---|
| CIFAR-100 | ResNet-50 | MobileNetV3-S | 67.1% | 73.2% | 78.2% | 10x 매개변수 |
| 이미지넷 | ViT-L/16 | 데이티에스(DeiT-S) | 79.8% | 83.1% | 87.1% | 5x 매개변수 |
| 글루(NLP) | BERT-대형 | 디스틸버트 | 83.2% | 86.4% | 89.2% | 2x 매개변수, 2x 속도 |
| 스쿼드(QA) | 로버타-L | 디스틸로버타 | 82.1% | 85.8% | 90.4% | 2x 매개변수 |
| LLM (곤란함) | 라마 3.1 8B | 라마 3.2 1B | 8.24 PPL | 7.81 PPL | 6.12 PPL | 8x 매개변수 |
KD는 일반적으로 2~10배 더 적은 매개변수를 사용하여 학생과 교사 간의 격차의 70~85%를 복구합니다.
생산 파이프라인: 증류 + 정량
엣지 배포를 위한 가장 강력한 워크플로우는 증류와 양자화를 결합합니다. 순서: 먼저 KD로 학생을 생성하고(높은 정확도 유지) 양자화합니다. 학생(크기를 줄이고 속도를 높입니다). 조합으로 모델을 줄일 수 있습니다. 원래 교사에 비해 100배에서 40배까지 향상되었으며 정확도는 5~10%만 손실되었습니다.
import torch
import torch.nn as nn
from torchvision import models
# ============================================================
# PIPELINE COMPLETA: Distillazione -> Quantizzazione -> ONNX
# ============================================================
def full_compression_pipeline(teacher_path: str, output_dir: str = "./compressed"):
"""
Pipeline completa per comprimere un modello per edge deployment.
Step 1: Carica teacher pre-trainato
Step 2: Distilla in student più piccolo
Step 3: Quantizza lo student (PTQ INT8)
Step 4: Esporta in ONNX per deployment cross-platform
"""
import os
os.makedirs(output_dir, exist_ok=True)
# STEP 1: Teacher
print("Step 1: Carico teacher...")
teacher = models.resnet50(pretrained=True)
teacher.fc = nn.Linear(2048, 10) # 10 classi
# In produzione: teacher.load_state_dict(torch.load(teacher_path))
teacher.eval()
teacher_size_mb = sum(p.numel() * p.element_size()
for p in teacher.parameters()) / (1024**2)
print(f" Teacher: {teacher_size_mb:.1f} MB, "
f"{sum(p.numel() for p in teacher.parameters())/1e6:.1f}M param")
# STEP 2: Student (dopo distillazione)
print("Step 2: Student con distillazione (simulato con MobileNetV3)...")
student = models.mobilenet_v3_small(pretrained=False)
student.classifier[3] = nn.Linear(
student.classifier[3].in_features, 10
)
# In produzione: train_with_distillation(teacher, student, ...)
# student.load_state_dict(torch.load("best_student.pth"))
student_size_mb = sum(p.numel() * p.element_size()
for p in student.parameters()) / (1024**2)
print(f" Student: {student_size_mb:.1f} MB, "
f"{sum(p.numel() for p in student.parameters())/1e6:.1f}M param")
print(f" Riduzione rispetto teacher: {teacher_size_mb/student_size_mb:.1f}x")
# STEP 3: Quantizzazione INT8 (PTQ)
print("Step 3: Quantizzazione INT8...")
student.eval()
# Quantizzazione dinamica (più semplice, leggermente meno efficiente)
student_quant = torch.quantization.quantize_dynamic(
student,
{nn.Linear}, # Quantizza solo Linear (Conv2d richiede calibrazione)
dtype=torch.qint8
)
# Per quantizzazione statica (più efficiente):
# student.qconfig = torch.quantization.get_default_qconfig('fbgemm')
# torch.quantization.prepare(student, inplace=True)
# calibrate(student, calib_loader) # Forward pass su dati di calibrazione
# torch.quantization.convert(student, inplace=True)
quant_size_mb = sum(p.numel() * p.element_size()
for p in student_quant.parameters()) / (1024**2)
print(f" Student INT8: ~{student_size_mb/4:.1f} MB (stima)")
print(f" Riduzione totale: ~{teacher_size_mb/(student_size_mb/4):.0f}x vs teacher")
# STEP 4: Export ONNX
print("Step 4: Export ONNX...")
dummy = torch.randn(1, 3, 224, 224)
onnx_path = f"{output_dir}/student_compressed.onnx"
torch.onnx.export(
student, # Usa FP32 per ONNX (compatibilità più ampia)
dummy,
onnx_path,
opset_version=13,
input_names=["input"],
output_names=["output"],
dynamic_axes={"input": {0: "batch_size"}},
export_params=True
)
print(f"\n=== RIEPILOGO PIPELINE ===")
print(f"Teacher (ResNet-50): {teacher_size_mb:.1f} MB")
print(f"Student KD (MobNetV3): {student_size_mb:.1f} MB ({teacher_size_mb/student_size_mb:.1f}x riduzione)")
print(f"Student INT8 (stimato): {student_size_mb/4:.1f} MB ({teacher_size_mb/(student_size_mb/4):.0f}x riduzione)")
print(f"ONNX salvato: {onnx_path}")
return student_quant, onnx_path
full_compression_pipeline("teacher_weights.pth")
증류의 안티 패턴: 일반적인 실수
- 온도가 너무 높거나 너무 낮음: T=1은 하드 라벨과 같습니다. T가 너무 높으면(>20) 소프트 라벨이 거의 균일해져서 신호가 손실됩니다. 특정 데이터 세트에 대해 항상 T ∈ {2, 4, 6, 8}을 사용하여 절제 연구를 수행하십시오.
- 교사와 학생의 능력이 너무 다릅니다. 격차가 크다면 (GPT-4 ~ 7B), 직접 증류는 효과적이지 않습니다. 폭포 증류를 사용하십시오: GPT-4 -> 13B -> 7B -> 3B. 각 단계는 4~5배 감소를 초과해서는 안 됩니다.
- 증류 데이터 세트의 품질을 무시합니다. 품질 증류를 수행하는 데이터 세트는 큰 영향을 미칩니다. 다양한 데이터를 활용하고 목표 작업의 대표자. 배포되지 않은 데이터로 인해 전송이 손상되었습니다.
- 잘못 보정된 알파: alpha=1(KD만 해당)이면 학생은 다음을 무시합니다. 교사가 실수를 하면 지상 진실 라벨을 붙이고 불안정한 예측을 생성할 수 있습니다. alpha=0이면 표준 훈련으로 축소됩니다. 일반적으로 0.5-0.8 값이 최적입니다.
- 선생님을 얼리지 마세요: 교사는 eval() 모드에 있어야 합니다. 학생 훈련 중. 교사가 계속 바뀌는 경우(예: BatchNorm이 있음) 훈련 모드에서) 증류 목표는 일관되지 않으며 훈련은 다양합니다.
LLM을 위한 증류: 추측적 디코딩 및 응답 증류
대규모 언어 모델의 맥락에서 증류는 새롭고 강력한 형태를 취합니다. 2025~2026년에 특히 관련성이 높은 두 가지 기술은 반응 증류 (Llama-3.2 및 Microsoft의 Phi 모델을 훈련하는 데 사용됨) 및 lo 추측적 디코딩, 추론 속도를 높이기 위해 작은 모델을 활용합니다. 품질 손실 없이 대형 모델을 제작할 수 있습니다.
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Optional
# ============================================================
# SPECULATIVE DECODING: Draft Model + Target Model
# ============================================================
# Principio: un modello piccolo (draft) genera K token in anticipo.
# Il modello grande (target) verifica tutti i K token in un solo forward pass.
# Se il draft ha ragione, si risparmiano K-1 forward pass del modello grande.
# Speedup tipico: 2-4x senza perdita di qualità.
class SpeculativeDecoder:
"""
Implementazione base di speculative decoding.
Draft model: modello piccolo (es. Llama-3.2-1B)
Target model: modello grande (es. Llama-3.1-8B)
Riferimento: "Fast Inference from Transformers via Speculative Decoding"
(Leviathan et al., 2022) - il paper originale di Google.
"""
def __init__(
self,
draft_model_name: str,
target_model_name: str,
device: str = "cuda",
lookahead_k: int = 5 # Token generati dal draft per ogni step
):
self.device = device
self.lookahead_k = lookahead_k
print(f"Carico draft model: {draft_model_name}...")
self.draft_model = AutoModelForCausalLM.from_pretrained(
draft_model_name, torch_dtype=torch.float16
).to(device).eval()
print(f"Carico target model: {target_model_name}...")
self.target_model = AutoModelForCausalLM.from_pretrained(
target_model_name, torch_dtype=torch.float16
).to(device).eval()
self.tokenizer = AutoTokenizer.from_pretrained(target_model_name)
def draft_generate(self, input_ids: torch.Tensor) -> tuple:
"""
Il draft model genera K token e restituisce
le distribuzioni di probabilità per acceptance/rejection.
"""
draft_ids = input_ids.clone()
draft_probs = []
with torch.no_grad():
for _ in range(self.lookahead_k):
out = self.draft_model(draft_ids)
next_logits = out.logits[:, -1, :]
next_probs = F.softmax(next_logits, dim=-1)
# Campiona dal draft
next_token = torch.multinomial(next_probs, num_samples=1)
draft_probs.append(next_probs)
draft_ids = torch.cat([draft_ids, next_token], dim=1)
# draft_ids ora include K token aggiuntivi
draft_tokens = draft_ids[:, input_ids.shape[1]:]
return draft_tokens, draft_probs
def speculative_generate(
self,
prompt: str,
max_new_tokens: int = 100,
temperature: float = 0.7
) -> str:
"""Genera testo con speculative decoding."""
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
generated = input_ids.clone()
total_accepted = 0
total_draft = 0
with torch.no_grad():
while generated.shape[1] - input_ids.shape[1] < max_new_tokens:
# 1. Draft genera K token
draft_tokens, draft_probs = self.draft_generate(generated)
total_draft += self.lookahead_k
# 2. Target verifica tutti i K+1 token in un forward pass
full_seq = torch.cat([generated, draft_tokens], dim=1)
target_out = self.target_model(full_seq)
target_logits = target_out.logits[:, generated.shape[1]-1:-1, :]
# 3. Acceptance-rejection sampling
n_accepted = 0
for i in range(draft_tokens.shape[1]):
draft_tok = draft_tokens[0, i].item()
target_probs = F.softmax(target_logits[0, i] / temperature, dim=-1)
draft_p = draft_probs[i][0, draft_tok].item()
target_p = target_probs[draft_tok].item()
# Accetta se target d'accordo con draft
r = torch.rand(1).item()
if r < min(1.0, target_p / (draft_p + 1e-10)):
n_accepted += 1
else:
# Rifiuta: campiona dal target corretto
corrected = torch.multinomial(
F.relu(target_probs - draft_probs[i][0]),
num_samples=1
)
generated = torch.cat([
generated,
draft_tokens[:, :i],
corrected.unsqueeze(0)
], dim=1)
break
else:
# Tutti accettati: aggiungi bonus token dal target
generated = torch.cat([generated, draft_tokens], dim=1)
bonus_logits = target_out.logits[:, -1, :]
bonus_tok = torch.multinomial(
F.softmax(bonus_logits / temperature, dim=-1), 1
)
generated = torch.cat([generated, bonus_tok], dim=1)
total_accepted += n_accepted
acceptance_rate = total_accepted / max(total_draft, 1)
print(f"Acceptance rate: {acceptance_rate:.1%} (atteso 60-80% con draft simile)")
new_tokens = generated[0, input_ids.shape[1]:]
return self.tokenizer.decode(new_tokens, skip_special_tokens=True)
# ============================================================
# RESPONSE DISTILLATION per LLM (semplificata)
# ============================================================
# Tecnica usata da Llama-3.2, Phi-3, Mistral-7B-Instruct:
# 1. Teacher LLM grande (es. GPT-4, Llama-3.1-70B) genera risposte
# 2. Student LLM piccolo impara a imitare quelle risposte
# Diverso dalla distillazione classica: distilla risposte (output testo),
# non distribuzioni di probabilità interne.
def response_distillation_dataset(
teacher_model_name: str,
prompts: list,
output_file: str = "distillation_dataset.jsonl"
) -> list:
"""
Genera dataset di distillazione con risposte del teacher.
In produzione: usa GPT-4 API, Llama-3.1-70B, o Claude.
"""
import json
tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
teacher = AutoModelForCausalLM.from_pretrained(
teacher_model_name,
torch_dtype=torch.float16,
device_map="auto"
).eval()
dataset = []
with open(output_file, 'w') as f:
for prompt in prompts:
inputs = tokenizer(prompt, return_tensors="pt").to(teacher.device)
with torch.no_grad():
outputs = teacher.generate(
**inputs,
max_new_tokens=512,
temperature=0.7,
do_sample=True,
top_p=0.9
)
response_ids = outputs[0, inputs['input_ids'].shape[1]:]
response = tokenizer.decode(response_ids, skip_special_tokens=True)
entry = {"prompt": prompt, "response": response}
dataset.append(entry)
f.write(json.dumps(entry, ensure_ascii=False) + "\n")
print(f"Dataset generato: {len(dataset)} esempi -> {output_file}")
return dataset
# Note: in pratica, usa l'API di un servizio commerciale (OpenAI, Anthropic)
# per generare le risposte del "teacher", poi addestra lo student su di esse.
# Questa e la tecnica dietro la maggior parte dei modelli instruction-following
# come Alpaca, Vicuna, Orca, e i modelli Phi di Microsoft.
print("LLM distillation patterns pronti")
증류: 변형 간 기술 비교(2024-2025)
| 변종 | 그것이 증류하는 것 | 찬성 | 에 맞서 | 일반적인 사용 |
|---|---|---|---|---|
| 소프트 라벨(Hinton 2015) | 확률 분포 | 풍부한 표준 정보 | 교사의 로짓에 대한 액세스가 필요합니다. | 비전, 분류 |
| 특징 증류 | 중간 표현 | 심층 기능 전송 | 교사와 학생은 호환 가능한 아키텍처를 가지고 있어야 합니다. | 탐지, 세분화 |
| 반응 증류 | 교사의 텍스트 출력 | 내부 액세스가 필요하지 않습니다. | 불확실성에 대한 정보를 잃음 | LLM 지시 따르기 |
| 다시 태어난 네트워크 | 반복적인 자기 증류 | 별도의 선생님이 필요하지 않습니다. | 제한된 이득, 높은 계산 비용 | 앙상블, 개선 |
| 추측적 디코딩 | 증류는 아니지만 드래프트를 사용한다. | 손실 없이 2~4배 속도 향상 | 메모리에 두 가지 모델이 필요합니다. | LLM 추론 가속 |
결론
Knowledge Distillation은 가장 강력하고 다양한 압축 기술 중 하나입니다. 2026년 출시. 양자화 및 가지치기와 자연스럽게 결합: 먼저 증류 최적의 학생을 생성한 다음 에지 배포를 위해 학생을 양자화합니다. 결과 종종 모델은 교사보다 10-100배 더 작고 정확도는 5-15%만 손실됩니다.
LLM의 경우 증류를 통해 전체 "Distil*" 모델 제품군인 DistilBERT, DistilGPT2 및 Microsoft의 Phi 모델(7B 모델 품질의 2.7B) 2026년의 트렌드 — SLM(Small Language Model)은 다음과 같은 사용 빈도에서 LLM보다 성능이 뛰어납니다. Gartner는 다음과 같은 지식을 전달하는 증류를 통해 정확하게 가능해졌습니다. Raspberry Pi와 스마트폰에서 실행되는 모델까지 다양합니다.
다음 문서에서는 이러한 압축 모델을 배포하는 방법을 보여줍니다. 엣지 디바이스: 필요한 모든 최적화 기능을 갖춘 Raspberry Pi, NVIDIA Jetson 및 임베디드 하드웨어 실제 제작 환경을 위해
다음 단계
- 다음 기사: 엣지 장치의 딥 러닝: 클라우드에서 엣지까지
- 관련된: 모델 양자화: GPTQ, AWQ, INT8
- 관련된: 신경망 가지치기: 매개변수 줄이기
- 관련된: Vision Transformer: DeiT를 이용한 증류
- MLOps 시리즈: 프로덕션에서 압축 모델 제공







