ViT(Vision Transformer): 아키텍처 및 실제 애플리케이션
2020년 Google 연구 논문은 컴퓨터 비전을 근본적으로 변화시켰습니다. "이미지는 가치가 있습니다. 16x16 단어". 직관은 단순하지만 혁신적이었습니다. Transformer 아키텍처를 적용하고, NLP에서 지배적이며 i를 처리하여 이미지에 직접 적용 패치 토큰으로 시각적입니다. 결과 그것은 비전 트랜스포머(ViT)몇 년 만에 CNN을 추월했습니다. ImageNet 및 기타 수십 개의 벤치마크에서 최고 수준을 기록하며 새로운 시대의 길을 열었습니다. 시각적 모델의 생성.
ViT의 약속은 단지 정확성이 아닙니다. 다재. 동일한 백본 텍스트에 사용되는 변환기를 이미지와 공유하여 템플릿을 활성화할 수 있습니다. CLIP, DALL-E 및 GPT-4V와 같은 다중 모드. ViT는 데이터와 컴퓨팅 측면에서 CNN보다 확장성이 뛰어납니다. 추가 및 변형 스윈 트랜스포머 e 디잇 그들이 만들었어 이러한 모델은 수백 개의 사전 교육 없이도 중간 규모의 데이터 세트에서도 효율적입니다. 수백만 개의 이미지.
이 가이드에서는 PyTorch에서 처음부터 ViT를 구축하고 아키텍처 변형을 탐색합니다. 가장 중요하며 특정 생산 작업에 맞게 미세 조정하는 방법을 보여줍니다.
무엇을 배울 것인가
- ViT 아키텍처: 패치 임베딩, 위치 인코딩, 시각적 주의
- PyTorch를 사용하여 처음부터 완전한 구현
- ViT-B/16, ViT-L/32, DeiT, Swin Transformer의 차이점
- 맞춤형 데이터 세트에 대해 사전 훈련된 ViT의 미세 조정
- ViT용 데이터 증대 기술(MixUp, CutMix, RandAugment)
- 주의 맵의 주의 롤아웃 및 해석 가능성
- 최적화된 배포: ONNX, TorchScript, 에지 장치
- 실제 데이터 세트에 대한 ViT 및 CNN 벤치마크
ViT 아키텍처: 작동 방식
Vision Transformer는 이미지를 입력으로 받아 이를 겹치지 않는 패치로 나눕니다.
고정 크기(일반적으로 16x16 또는 32x32 픽셀). 패치가 나올 때마다 단조롭게 하는
(평면화) 및 차원 벡터에 선형으로 투영됩니다. d_model (임베딩).
이러한 임베딩은 패치 임베딩, Transformer 토큰이 됩니다.
특별한 토큰 [CLS] (클래스 토큰)은 마찬가지로 시퀀스 앞에 붙습니다.
NLP에서 BERT로. 인코딩이 완료되면 CLS 토큰 표현이 다음으로 전달됩니다.
최종 예측을 생성하는 분류 헤드. 위치 인코딩 — 형태
사인 대신 학습됨 — 정보를 보존하기 위해 임베딩에 추가됨
그것 없이는 잃어버릴 공간.
# Diagramma architettura ViT
#
# Input Image (224x224x3)
# |
# v
# Patch Extraction: divide in 196 patch di 16x16
# (224/16 = 14 patch per lato -> 14*14 = 196 patch)
# |
# v
# Patch Embedding: ogni patch [768] via Linear projection
# + [CLS] token -> sequenza di 197 token
# |
# v
# + Positional Embedding (learnable, 197x768)
# |
# v
# Transformer Encoder (L strati):
# - LayerNorm
# - Multi-Head Self-Attention (h heads)
# - Residual connection
# - LayerNorm
# - MLP (d_model -> 4*d_model -> d_model)
# - Residual connection
# |
# v
# [CLS] token representation
# |
# v
# MLP Head -> num_classes output
# Varianti standard:
# ViT-B/16: d_model=768, L=12, h=12 | ~86M param
# ViT-L/16: d_model=1024, L=24, h=16 | ~307M param
# ViT-H/14: d_model=1280, L=32, h=16 | ~632M param
처음부터 ViT 구현
PyTorch에서 완전한 ViT를 구축해 보겠습니다. 기본 구성요소부터 시작해 보겠습니다. 패치 임베딩.
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import math
# ============================================================
# 1. PATCH EMBEDDING
# ============================================================
class PatchEmbedding(nn.Module):
"""
Converte un'immagine in una sequenza di patch embedding.
Metodo 1: Convolution (efficiente, equivalente a patch+linear)
"""
def __init__(self, img_size=224, patch_size=16, in_channels=3, d_model=768):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.n_patches = (img_size // patch_size) ** 2
# Equivalente a: flatten ogni patch + proiezione lineare
# Ma implementato come Conv2d per efficienza
self.projection = nn.Sequential(
# Divide in patch e proietta
nn.Conv2d(in_channels, d_model,
kernel_size=patch_size, stride=patch_size),
# [B, d_model, H/P, W/P] -> [B, n_patches, d_model]
Rearrange('b d h w -> b (h w) d')
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: [B, C, H, W]
return self.projection(x) # [B, n_patches, d_model]
# ============================================================
# 2. MULTI-HEAD SELF ATTENTION per ViT
# ============================================================
class ViTAttention(nn.Module):
"""Multi-head self-attention con dropout."""
def __init__(self, d_model=768, n_heads=12, attn_dropout=0.0):
super().__init__()
assert d_model % n_heads == 0
self.n_heads = n_heads
self.head_dim = d_model // n_heads
self.scale = self.head_dim ** -0.5
# QKV projection
self.qkv = nn.Linear(d_model, d_model * 3, bias=True)
self.attn_drop = nn.Dropout(attn_dropout)
self.proj = nn.Linear(d_model, d_model)
def forward(self, x: torch.Tensor):
B, N, C = x.shape
# Calcola Q, K, V
qkv = self.qkv(x).reshape(B, N, 3, self.n_heads, self.head_dim)
qkv = qkv.permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0) # Ognuno: [B, heads, N, head_dim]
# Scaled dot-product attention
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = F.softmax(attn, dim=-1)
attn_weights = attn # Salva per attention rollout
attn = self.attn_drop(attn)
# Weighted sum + proiezione output
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
return x, attn_weights
# ============================================================
# 3. TRANSFORMER ENCODER BLOCK
# ============================================================
class ViTBlock(nn.Module):
"""Singolo blocco Transformer per ViT."""
def __init__(self, d_model=768, n_heads=12, mlp_ratio=4.0,
dropout=0.0, attn_dropout=0.0):
super().__init__()
self.norm1 = nn.LayerNorm(d_model)
self.attn = ViTAttention(d_model, n_heads, attn_dropout)
self.norm2 = nn.LayerNorm(d_model)
mlp_dim = int(d_model * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(d_model, mlp_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(mlp_dim, d_model),
nn.Dropout(dropout)
)
def forward(self, x: torch.Tensor):
# Pre-norm + residual connection
attn_out, attn_weights = self.attn(self.norm1(x))
x = x + attn_out
x = x + self.mlp(self.norm2(x))
return x, attn_weights
# ============================================================
# 4. VISION TRANSFORMER COMPLETO
# ============================================================
class VisionTransformer(nn.Module):
"""
Vision Transformer (ViT) completo.
Paper: "An Image is Worth 16x16 Words" (Dosovitskiy et al., 2020)
"""
def __init__(
self,
img_size: int = 224,
patch_size: int = 16,
in_channels: int = 3,
num_classes: int = 1000,
d_model: int = 768,
depth: int = 12,
n_heads: int = 12,
mlp_ratio: float = 4.0,
dropout: float = 0.1,
attn_dropout: float = 0.0,
representation_size: int = None # Pre-logit layer (opzionale)
):
super().__init__()
self.num_classes = num_classes
self.d_model = d_model
# Patch + Position Embedding
self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, d_model)
n_patches = self.patch_embed.n_patches
# Token CLS e positional embedding
self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, d_model))
self.pos_drop = nn.Dropout(dropout)
# Transformer blocks
self.blocks = nn.ModuleList([
ViTBlock(d_model, n_heads, mlp_ratio, dropout, attn_dropout)
for _ in range(depth)
])
self.norm = nn.LayerNorm(d_model)
# Classification head
if representation_size is not None:
self.pre_logits = nn.Sequential(
nn.Linear(d_model, representation_size),
nn.Tanh()
)
else:
self.pre_logits = nn.Identity()
self.head = nn.Linear(
representation_size if representation_size else d_model,
num_classes
)
# Inizializzazione pesi
self._init_weights()
def _init_weights(self):
"""Inizializzazione standard per ViT."""
nn.init.trunc_normal_(self.pos_embed, std=0.02)
nn.init.trunc_normal_(self.cls_token, std=0.02)
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
def forward(self, x: torch.Tensor, return_attn: bool = False):
B = x.shape[0]
# 1. Patch embedding
x = self.patch_embed(x) # [B, n_patches, d_model]
# 2. Prepend CLS token
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=B)
x = torch.cat([cls_tokens, x], dim=1) # [B, n_patches+1, d_model]
# 3. Add positional embedding
x = x + self.pos_embed
x = self.pos_drop(x)
# 4. Transformer blocks
attn_weights_list = []
for block in self.blocks:
x, attn_weights = block(x)
attn_weights_list.append(attn_weights)
# 5. Layer norm finale
x = self.norm(x)
# 6. Usa solo il CLS token per classificazione
cls_output = x[:, 0]
cls_output = self.pre_logits(cls_output)
logits = self.head(cls_output)
if return_attn:
return logits, attn_weights_list
return logits
# ============================================================
# 5. CREAZIONE VARIANTI STANDARD
# ============================================================
def vit_base_16(num_classes=1000, **kwargs):
"""ViT-B/16: 86M parametri, input 224x224."""
return VisionTransformer(
img_size=224, patch_size=16, d_model=768,
depth=12, n_heads=12, mlp_ratio=4.0,
num_classes=num_classes, **kwargs
)
def vit_large_16(num_classes=1000, **kwargs):
"""ViT-L/16: 307M parametri, input 224x224."""
return VisionTransformer(
img_size=224, patch_size=16, d_model=1024,
depth=24, n_heads=16, mlp_ratio=4.0,
num_classes=num_classes, **kwargs
)
def vit_tiny_16(num_classes=1000, **kwargs):
"""ViT-Ti/16: ~6M parametri, per edge/mobile."""
return VisionTransformer(
img_size=224, patch_size=16, d_model=192,
depth=12, n_heads=3, mlp_ratio=4.0,
num_classes=num_classes, **kwargs
)
# Test
model = vit_base_16(num_classes=100)
x = torch.randn(2, 3, 224, 224)
out = model(x)
print(f"Input: {x.shape}")
print(f"Output: {out.shape}") # [2, 100]
print(f"Parametri: {sum(p.numel() for p in model.parameters()):,}")
# Parametri: 85,880,164
아키텍처 변형: DeiT, Swin 및 BEiT
원래 ViT는 막대한 양의 데이터(JFT-300M, 3억 개의 이미지)가 필요했습니다. CNN을 능가합니다. 이러한 제한으로 인해 보다 데이터 효율적인 변형이 개발되었습니다.
| 모델 | 년도 | 주요 혁신 | ImageNet 상위 1위 | 매개변수 |
|---|---|---|---|---|
| ViT-B/16 | 2020 | 첫 번째 ViT, JFT-300M 필요 | 81.8% | 86M |
| 데이티비(DeiT-B) | 2021 | CNN 교사의 증류, ImageNet에만 해당 | 83.1% | 87M |
| 스윈-B | 2021 | Shifted Window Attention, 계층적 | 85.2% | 88M |
| BEIT-L | 2022년 | 마스크된 이미지 모델링(비전용 BERT) | 87.4% | 307M |
| 데이트 III-H | 2022년 | 고급 훈련 레시피 | 87.7% | 632M |
| ViT-G (EVA) | 2023년 | 1B 매개변수로 확장, CLIP 사전 훈련 | 89.6% | 1.0B |
DeiT(데이터 효율적인 이미지 변환기) Facebook AI 및 아마도 변형 연습에 가장 중요한 것: 증류 토큰 배울 수 있게 해주는 것 RegNet 또는 ConvNext와 같은 CNN 교사로부터 ImageNet-1K만으로 탁월한 성능을 얻습니다.
스윈 트랜스포머 주의력의 2차 복잡도 문제를 해결합니다. 소개하는 ShiftedWindows: 관심은 로컬 창 내에서 계산됩니다. 전체 이미지에 대한 것이 아니라 이미지에 대한 선형 계산 비용이 발생합니다. 스윈 CNN과 같은 계층적 표현을 생성하며 탐지를 위해 선호되는 백본입니다. 그리고 세분화.
ViT 사전 훈련 미세 조정
프로덕션에서 ViT를 사용하는 가장 실용적인 방법은 사전 훈련된 모델에서 시작하는 것입니다. ImageNet-21K를 사용하고 데이터세트를 미세 조정하세요. Hugging Face Transformers는 다음과 같은 모든 기능을 제공합니다. 통일된 API를 갖춘 핵심 ViT 모델.
# pip install transformers timm torch torchvision datasets
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import (
ViTForImageClassification, ViTImageProcessor,
AutoImageProcessor
)
from torchvision import datasets, transforms
import os
# ============================================================
# FINE-TUNING ViT-B/16 su Dataset Custom
# ============================================================
class ViTFineTuner(nn.Module):
"""
ViT pre-addestrato con classification head custom.
Supporta fine-tuning parziale o completo.
"""
def __init__(self, num_classes: int, model_name: str = "google/vit-base-patch16-224",
freeze_backbone: bool = False):
super().__init__()
# Carica ViT pre-addestrato da HuggingFace
self.vit = ViTForImageClassification.from_pretrained(
model_name,
num_labels=num_classes,
ignore_mismatched_sizes=True # Permette cambio num_classes
)
if freeze_backbone:
# Congela tutto tranne il classification head
for param in self.vit.vit.parameters():
param.requires_grad = False
# Solo il classifier rimane trainable
print(f"Parametri trainabili: {sum(p.numel() for p in self.parameters() if p.requires_grad):,}")
def forward(self, x):
outputs = self.vit(pixel_values=x)
return outputs.logits
# ============================================================
# DATA AUGMENTATION per ViT
# ============================================================
def get_vit_transforms(img_size: int = 224, mode: str = "train"):
"""
Augmentation pipeline ottimizzata per ViT.
ViT beneficia molto da augmentation aggressiva.
"""
if mode == "train":
return transforms.Compose([
transforms.RandomResizedCrop(img_size, scale=(0.08, 1.0)),
transforms.RandomHorizontalFlip(),
transforms.RandAugment(num_ops=2, magnitude=9), # RandAugment
transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
),
transforms.RandomErasing(p=0.25) # CutOut/Erasing
])
else:
# Resize + center crop per validation/test
return transforms.Compose([
transforms.Resize(int(img_size * 1.143)), # 256 per 224
transforms.CenterCrop(img_size),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
# ============================================================
# TRAINING LOOP CON WARMUP + COSINE DECAY
# ============================================================
import math
from torch.optim.lr_scheduler import LambdaLR
def get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps):
"""LR schedule: linear warmup + cosine decay (standard per ViT)."""
def lr_lambda(step):
if step < warmup_steps:
return float(step) / float(max(1, warmup_steps))
progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
return LambdaLR(optimizer, lr_lambda)
def train_vit(
model, train_loader, val_loader,
num_epochs=30, base_lr=3e-5, weight_decay=0.05,
device="cuda", label_smoothing=0.1
):
model = model.to(device)
# AdamW con weight decay (standard per ViT)
# Escludi bias e LayerNorm dal weight decay
no_decay_params = []
decay_params = []
for name, param in model.named_parameters():
if param.requires_grad:
if 'bias' in name or 'norm' in name or 'cls_token' in name or 'pos_embed' in name:
no_decay_params.append(param)
else:
decay_params.append(param)
optimizer = torch.optim.AdamW([
{'params': decay_params, 'weight_decay': weight_decay},
{'params': no_decay_params, 'weight_decay': 0.0}
], lr=base_lr)
total_steps = len(train_loader) * num_epochs
warmup_steps = len(train_loader) * 5 # 5 epoch di warmup
scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)
# Label smoothing loss
criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
best_acc = 0.0
for epoch in range(num_epochs):
# Training
model.train()
train_loss = 0.0
for batch_idx, (imgs, labels) in enumerate(train_loader):
imgs, labels = imgs.to(device), labels.to(device)
optimizer.zero_grad()
logits = model(imgs)
loss = criterion(logits, labels)
loss.backward()
# Gradient clipping (importante per ViT)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
scheduler.step()
train_loss += loss.item()
# Validation
model.eval()
correct = 0
total = 0
with torch.no_grad():
for imgs, labels in val_loader:
imgs, labels = imgs.to(device), labels.to(device)
logits = model(imgs)
preds = logits.argmax(dim=1)
correct += (preds == labels).sum().item()
total += labels.size(0)
val_acc = correct / total
avg_loss = train_loss / len(train_loader)
current_lr = scheduler.get_last_lr()[0]
print(f"Epoch {epoch+1}/{num_epochs} | "
f"Loss: {avg_loss:.4f} | "
f"Val Acc: {val_acc:.4f} | "
f"LR: {current_lr:.2e}")
if val_acc > best_acc:
best_acc = val_acc
torch.save(model.state_dict(), "best_vit.pth")
print(f" -> Nuovo best: {best_acc:.4f}")
return best_acc
MixUp 및 CutMix: ViT를 위한 고급 강화
ViT는 특히 기술의 이점을 얻습니다. 혼합 확대. 믹스업 이미지와 해당 레이블 쌍을 선형적으로 결합합니다. CutMix는 일부를 대체합니다. 한 이미지의 직사각형 부분과 다른 이미지의 해당 부분. 두 기술 모두 모델의 일반화 및 보정을 개선합니다.
import numpy as np
class MixUpCutMix:
"""
Combinazione di MixUp e CutMix come in DeiT e timm.
Applica randomicamente uno dei due metodi per ogni batch.
"""
def __init__(self, mixup_alpha=0.8, cutmix_alpha=1.0,
prob=0.5, num_classes=1000):
self.mixup_alpha = mixup_alpha
self.cutmix_alpha = cutmix_alpha
self.prob = prob
self.num_classes = num_classes
def one_hot(self, labels: torch.Tensor) -> torch.Tensor:
return F.one_hot(labels, self.num_classes).float()
def mixup(self, imgs, labels_oh):
"""MixUp: interpolazione lineare."""
lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
B = imgs.size(0)
idx = torch.randperm(B)
mixed_imgs = lam * imgs + (1 - lam) * imgs[idx]
mixed_labels = lam * labels_oh + (1 - lam) * labels_oh[idx]
return mixed_imgs, mixed_labels
def cutmix(self, imgs, labels_oh):
"""CutMix: ritaglia e incolla patch."""
lam = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
B, C, H, W = imgs.shape
idx = torch.randperm(B)
# Calcola dimensioni bounding box
cut_ratio = math.sqrt(1.0 - lam)
cut_h = int(H * cut_ratio)
cut_w = int(W * cut_ratio)
# Centro casuale
cy = np.random.randint(H)
cx = np.random.randint(W)
y1 = max(0, cy - cut_h // 2)
y2 = min(H, cy + cut_h // 2)
x1 = max(0, cx - cut_w // 2)
x2 = min(W, cx + cut_w // 2)
# Applica CutMix
mixed_imgs = imgs.clone()
mixed_imgs[:, :, y1:y2, x1:x2] = imgs[idx, :, y1:y2, x1:x2]
# Ricalcola lambda effettivo
lam_actual = 1.0 - (y2 - y1) * (x2 - x1) / (H * W)
mixed_labels = lam_actual * labels_oh + (1 - lam_actual) * labels_oh[idx]
return mixed_imgs, mixed_labels
def __call__(self, imgs, labels):
labels_oh = self.one_hot(labels).to(imgs.device)
if np.random.random() < self.prob:
if np.random.random() < 0.5:
return self.mixup(imgs, labels_oh)
else:
return self.cutmix(imgs, labels_oh)
return imgs, labels_oh
# Uso nel training loop
mixup_cutmix = MixUpCutMix(num_classes=100)
# Nel training loop:
# imgs, soft_labels = mixup_cutmix(imgs, labels)
# loss = F.cross_entropy(logits, soft_labels) # Soft labels
주의 롤아웃: ViT가 보는 것을 시각화합니다.
ViT의 가장 흥미로운 기능 중 하나는 주의 지도 모델이 고려하는 이미지의 영역을 이해하기 위해 관련. 기술 주의 롤아웃 모든 계층에 주의를 전파합니다. 글로벌 관련성 지도를 얻기 위해.
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
def compute_attention_rollout(attn_weights_list: list,
discard_ratio: float = 0.9) -> np.ndarray:
"""
Attention Rollout (Abnar & Zuidema, 2020).
Propaga le attention maps attraverso tutti i layer.
attn_weights_list: lista di tensori [B, heads, N, N]
discard_ratio: percentuale di attention da azzerare (focus sui top)
"""
n_layers = len(attn_weights_list)
# Media su tutte le teste
rollout = None
for attn in attn_weights_list:
# attn: [B, heads, N, N] -> media teste -> [B, N, N]
attn_mean = attn.mean(dim=1) # [B, N, N]
# Aggiunge identità (residual connection)
eye = torch.eye(attn_mean.size(-1), device=attn_mean.device)
attn_mean = attn_mean + eye
attn_mean = attn_mean / attn_mean.sum(dim=-1, keepdim=True)
if rollout is None:
rollout = attn_mean
else:
rollout = torch.bmm(attn_mean, rollout)
return rollout
def visualize_vit_attention(model, image_tensor: torch.Tensor,
patch_size: int = 16):
"""
Visualizza l'attention del CLS token sull'immagine.
"""
model.eval()
with torch.no_grad():
_, attn_list = model(image_tensor.unsqueeze(0), return_attn=True)
# Calcola rollout
rollout = compute_attention_rollout(attn_list) # [1, N+1, N+1]
# Attenzione del CLS verso tutti i patch
cls_attn = rollout[0, 0, 1:] # Escludi il CLS token stesso
# Ridimensiona alla griglia dei patch
H = W = int(math.sqrt(cls_attn.size(0)))
attn_map = cls_attn.reshape(H, W).cpu().numpy()
# Normalizza e upscale alla dimensione immagine
attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min())
attn_map_upscaled = np.kron(attn_map, np.ones((patch_size, patch_size)))
return attn_map_upscaled
# Esempio di utilizzo e visualizzazione
# model = vit_base_16(num_classes=1000)
# img_tensor = get_vit_transforms(mode="val")(Image.open("dog.jpg"))
# attn_map = visualize_vit_attention(model, img_tensor)
#
# plt.figure(figsize=(12, 4))
# plt.subplot(1, 2, 1)
# plt.imshow(img_tensor.permute(1,2,0).numpy())
# plt.title("Immagine originale")
# plt.subplot(1, 2, 2)
# plt.imshow(attn_map, cmap='inferno')
# plt.title("Attention Rollout (CLS token)")
# plt.colorbar()
# plt.tight_layout()
# plt.savefig("vit_attention.png", dpi=150)
Swin Transformer: 계층적 창에 대한 주의
Il 스윈 트랜스포머 표준 ViT의 두 가지 기본 제한 사항을 해결합니다. 주의의 2차 복잡성(처리 가능한 해상도를 제한함)과 부재 계층적 표현(탐지 및 분할에 필요)
Swin은 이미지를 겹치지 않는 창으로 나누고 내부에서만 주의를 계산합니다. 각 창의 (선형 복잡도) 한 층과 다른 층 사이에 창문이 옵니다. 옮기다 인접한 창 사이의 통신을 허용합니다. 계층 구조 점차적으로 공간 해상도를 줄여 다음과 같은 4단계 특징 맵을 생성합니다. 전통적인 CNN.
# Uso di Swin Transformer tramite timm (più semplice che implementare da zero)
# pip install timm
import timm
import torch
# Crea Swin-T (Tiny): 28M param, 81.3% ImageNet Top-1
swin_tiny = timm.create_model(
'swin_tiny_patch4_window7_224',
pretrained=True,
num_classes=0 # 0 = rimuovi classifier (backbone solo)
)
# Swin-B (Base): 88M param, 85.2% ImageNet Top-1
swin_base = timm.create_model(
'swin_base_patch4_window7_224',
pretrained=True,
num_classes=100 # Custom classifier
)
# Swin-V2-L per alta risoluzione (resolution scaling)
swin_v2 = timm.create_model(
'swinv2_large_window12to16_192to256_22kft1k',
pretrained=True,
num_classes=10
)
# Verifica feature maps gerarchiche (per detection/segmentation)
swin_backbone = timm.create_model(
'swin_base_patch4_window7_224',
pretrained=True,
features_only=True, # Restituisce feature a 4 scale
out_indices=(0, 1, 2, 3)
)
x = torch.randn(2, 3, 224, 224)
features = swin_backbone(x)
for i, feat in enumerate(features):
print(f"Stage {i}: {feat.shape}")
# Stage 0: torch.Size([2, 192, 56, 56])
# Stage 1: torch.Size([2, 384, 28, 28])
# Stage 2: torch.Size([2, 768, 14, 14])
# Stage 3: torch.Size([2, 1536, 7, 7])
# Fine-tuning completo con timm
from timm.loss import SoftTargetCrossEntropy
from timm.data.mixup import Mixup
from timm.optim import create_optimizer_v2
from timm.scheduler import create_scheduler_v2
# Parametri ottimali per fine-tuning Swin
model = timm.create_model('swin_base_patch4_window7_224',
pretrained=True, num_classes=10)
# Optimizer con parametri specifici per Swin
optimizer = create_optimizer_v2(
model,
opt='adamw',
lr=5e-5,
weight_decay=0.05,
layer_decay=0.9 # Layer-wise LR decay: layer più profondi = LR più bassa
)
x = torch.randn(2, 3, 224, 224)
out = model(x)
print(f"Swin output: {out.shape}") # [2, 10]
최적화된 배포: ONNX 및 TorchScript
프로덕션 배포의 경우 모델을 최적화된 형식으로 내보내는 것이 중요합니다. ONNX 프레임워크와 하드웨어별 최적화 간의 상호 운용성을 허용합니다. 토치스크립트 추론을 위한 Python 오버헤드를 제거합니다.
import torch
import torch.onnx
import onnx
import onnxruntime as ort
import numpy as np
import timm
# Modello ViT/Swin pre-addestrato e fine-tuned
model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=10)
model.load_state_dict(torch.load('best_vit.pth'))
model.eval()
# ============================================================
# EXPORT ONNX
# ============================================================
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
model,
dummy_input,
"vit_model.onnx",
export_params=True,
opset_version=17, # ONNX opset 17 per operatori recenti
do_constant_folding=True, # Ottimizzazione grafo
input_names=['pixel_values'],
output_names=['logits'],
dynamic_axes={
'pixel_values': {0: 'batch_size'}, # Batch size dinamico
'logits': {0: 'batch_size'}
}
)
# Verifica modello ONNX
onnx_model = onnx.load("vit_model.onnx")
onnx.checker.check_model(onnx_model)
print("Modello ONNX valido!")
# Inferenza con ONNX Runtime (CPU o GPU)
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
ort_session = ort.InferenceSession("vit_model.onnx", providers=providers)
# Test inferenza ONNX
test_input = np.random.randn(1, 3, 224, 224).astype(np.float32)
outputs = ort_session.run(None, {'pixel_values': test_input})
print(f"ONNX output shape: {outputs[0].shape}")
# ============================================================
# TORCHSCRIPT (alternativa per deployment PyTorch)
# ============================================================
model_scripted = torch.jit.script(model)
model_scripted.save("vit_scripted.pt")
# Ricarica e usa
loaded = torch.jit.load("vit_scripted.pt")
with torch.no_grad():
out = loaded(dummy_input)
print(f"TorchScript output: {out.shape}")
# ============================================================
# BENCHMARK ONNX vs PyTorch
# ============================================================
import time
def benchmark(fn, n_runs=50, warmup=10):
for _ in range(warmup):
fn()
torch.cuda.synchronize() if torch.cuda.is_available() else None
t0 = time.perf_counter()
for _ in range(n_runs):
fn()
torch.cuda.synchronize() if torch.cuda.is_available() else None
elapsed = (time.perf_counter() - t0) / n_runs * 1000
return elapsed
# PyTorch
def pt_inference():
with torch.no_grad():
model(dummy_input)
# ONNX Runtime
def onnx_inference():
ort_session.run(None, {'pixel_values': test_input})
pt_ms = benchmark(pt_inference)
onnx_ms = benchmark(onnx_inference)
print(f"PyTorch: {pt_ms:.1f} ms/inference")
print(f"ONNX RT: {onnx_ms:.1f} ms/inference")
print(f"Speedup ONNX: {pt_ms/onnx_ms:.2f}x")
전문 작업을 위한 ViT: 의료, 위성 및 복합 모드
ViT는 다른 도메인과 매우 다른 도메인에서 탁월한 전송 용량을 보여주었습니다. ImageNet. 세 가지 특히 중요한 응용 분야는 다음과 같습니다. 컴퓨터 의료 시력 (방사선과, 수지병리과, 피부과), 원격 감지 (위성 이미지, 드론 이미지) 및 모델 다중 모드 (CLIP, SigLIP, LLaVA).
import timm
import torch
import torch.nn as nn
from transformers import CLIPModel, CLIPProcessor
# ============================================================
# ViT PER IMAGING MEDICO (classificazione CXR)
# ============================================================
# Chest X-Ray classification con DeiT fine-tuned
class MedicalViT(nn.Module):
"""
ViT per classificazione immagini mediche.
Usa un backbone pre-addestrato su ImageNet + fine-tuning su CXR.
Considera: le immagini mediche sono spesso grayscale (convertite a 3ch)
e richiedono risoluzione maggiore (384px).
"""
def __init__(self, n_classes: int, model_name: str = "deit3_base_patch16_384",
dropout: float = 0.2):
super().__init__()
# DeiT3 a 384px: più accurato per dettagli fini nelle immagini mediche
self.backbone = timm.create_model(
model_name,
pretrained=True,
num_classes=0, # Rimuovi head originale
img_size=384
)
d_model = self.backbone.embed_dim
# Head medica con dropout aggressivo (evita overfit su dataset piccoli)
self.head = nn.Sequential(
nn.LayerNorm(d_model),
nn.Dropout(dropout),
nn.Linear(d_model, d_model // 2),
nn.GELU(),
nn.Dropout(dropout / 2),
nn.Linear(d_model // 2, n_classes)
)
# Congela i primi 6 layer (feature basiche = ImageNet features)
# Fine-tuna solo i layer superiori (feature ad alto livello)
total_blocks = len(self.backbone.blocks)
freeze_until = total_blocks // 2
for i, block in enumerate(self.backbone.blocks):
if i < freeze_until:
for p in block.parameters():
p.requires_grad = False
trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
total = sum(p.numel() for p in self.parameters())
print(f"Parametri trainabili: {trainable:,} / {total:,} ({100*trainable/total:.1f}%)")
def forward(self, x: torch.Tensor) -> torch.Tensor:
features = self.backbone(x) # [B, d_model] - CLS token
return self.head(features)
# Uso per NIH Chest X-Ray Dataset (14 classi, multi-label)
medical_vit = MedicalViT(n_classes=14, dropout=0.3)
x = torch.randn(4, 3, 384, 384) # 384px per CXR
out = medical_vit(x)
print(f"CXR prediction shape: {out.shape}") # [4, 14]
# ============================================================
# CLIP: VISION-LANGUAGE PRETRAINING
# ============================================================
# CLIP usa un ViT come encoder visuale accoppiato a un Transformer testuale.
# L'addestramento contrasto allinea rappresentazioni visive e testuali.
def clip_zero_shot_classification(
images: torch.Tensor,
class_descriptions: list, # ["a photo of a cat", "a photo of a dog", ...]
model_name: str = "openai/clip-vit-base-patch32"
):
"""
Zero-shot image classification con CLIP.
Non richiede esempi di training: usa descrizioni testuali delle classi.
"""
model = CLIPModel.from_pretrained(model_name)
processor = CLIPProcessor.from_pretrained(model_name)
model.eval()
# Codifica testi e immagini nello stesso spazio embedding
with torch.no_grad():
# Text embeddings
text_inputs = processor(text=class_descriptions, return_tensors="pt",
padding=True, truncation=True)
text_features = model.get_text_features(**text_inputs)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
# Image embeddings (usa ViT internamente)
image_inputs = processor(images=images, return_tensors="pt")
image_features = model.get_image_features(**image_inputs)
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
# Similarità coseno: matrice [n_images, n_classes]
similarity = (100 * image_features @ text_features.T).softmax(dim=-1)
return similarity
# Esempio: classificazione zero-shot senza training
class_names = [
"a chest X-ray showing pneumonia",
"a normal chest X-ray",
"a chest X-ray showing cardiomegaly",
"a chest X-ray with pleural effusion"
]
# similarity = clip_zero_shot_classification(images, class_names)
# print(f"Predicted class: {class_names[similarity.argmax()]}")
print("ViT multimodale (CLIP) pronto per zero-shot classification")
엣지 장치를 위한 ViT 최적화
엣지 하드웨어에 ViT를 배포하려면 구체적인 전략이 필요합니다. 표준 ViT(86M+ 매개변수) Raspberry Pi나 마이크로컨트롤러에는 너무 무겁습니다. 다음과 같은 더 가벼운 변형 ViT-Ti/16 (6M 매개변수) e 모바일비트 (5M 매개변수)는 이 사용 사례를 위해 설계되었으며 주의력의 표현력과 컨볼루션의 효율성.
import timm
import torch
import torch.onnx
import time
import numpy as np
# ============================================================
# VARIANTI ViT LEGGERE PER EDGE
# ============================================================
edge_models = {
"vit_tiny_patch16_224": "ViT-Ti (6M, ~4ms GPU)",
"deit_tiny_patch16_224": "DeiT-Ti (5.7M, ~3ms GPU)",
"mobilevit_s": "MobileViT-S (5.6M, 4ms, ottimo CPU)",
"efficientvit_m0": "EfficientViT-M0 (2.4M, ultra-light)",
"fastvit_t8": "FastViT-T8 (4M, 3x più veloce di DeiT)",
}
def benchmark_edge_models(input_size=(1, 3, 224, 224), device="cpu", n_runs=50):
"""
Benchmark dei modelli ViT leggeri su CPU (simula edge device).
CPU benchmark e più rappresentativo di deployment su RPi/Jetson Nano.
"""
results = []
x = torch.randn(*input_size).to(device)
for model_name, description in edge_models.items():
try:
model = timm.create_model(model_name, pretrained=False, num_classes=10)
model = model.to(device).eval()
n_params = sum(p.numel() for p in model.parameters())
model_size_mb = n_params * 4 / (1024**2)
# Warmup
with torch.no_grad():
for _ in range(5):
model(x)
# Benchmark
t0 = time.perf_counter()
with torch.no_grad():
for _ in range(n_runs):
model(x)
latency_ms = (time.perf_counter() - t0) / n_runs * 1000
results.append({
"model": model_name,
"desc": description,
"params_M": n_params / 1e6,
"size_mb": model_size_mb,
"latency_ms": latency_ms
})
print(f"{model_name:<35} {n_params/1e6:>5.1f}M "
f"{model_size_mb:>6.1f}MB {latency_ms:>8.1f}ms")
except Exception as e:
print(f"{model_name}: Errore - {e}")
return results
# ============================================================
# EXPORT OTTIMIZZATO PER EDGE
# ============================================================
def export_vit_for_edge(model_name: str = "vit_tiny_patch16_224",
n_classes: int = 10):
"""
Pipeline completa: carica ViT-Ti, quantizza e esporta per edge.
"""
model = timm.create_model(model_name, pretrained=False, num_classes=n_classes)
model.eval()
dummy_input = torch.randn(1, 3, 224, 224)
# 1. Export ONNX con opset 17
torch.onnx.export(
model, dummy_input, f"{model_name}_edge.onnx",
opset_version=17,
do_constant_folding=True,
input_names=["input"],
output_names=["logits"],
dynamic_axes={"input": {0: "batch"}, "logits": {0: "batch"}}
)
# 2. Quantizzazione dinamica INT8 (per CPU edge)
import torch.quantization
model_quantized = torch.quantization.quantize_dynamic(
model, {nn.Linear, nn.MultiheadAttention}, dtype=torch.qint8
)
# Confronto dimensioni
original_size = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024**2)
print(f"Modello originale FP32: {original_size:.1f} MB")
# Salva versione quantizzata
torch.save(model_quantized.state_dict(), f"{model_name}_int8.pt")
# Test latenza quantizzata su CPU
with torch.no_grad():
for _ in range(5): model_quantized(dummy_input) # warmup
t0 = time.perf_counter()
for _ in range(50): model_quantized(dummy_input)
lat_quant = (time.perf_counter() - t0) / 50 * 1000
print(f"Latenza INT8 CPU: {lat_quant:.1f}ms")
return model_quantized
print("ViT edge export pipeline pronto")
일반 작업에 대한 ViT 및 CNN 벤치마크(2025)
| 모델 | ImageNet 상위 1위 | 지연 시간(밀리초) | 처리량(img/초) | 매개변수 |
|---|---|---|---|---|
| ResNet-50 | 76.1% | 4.1ms | 1,240 | 25M |
| ConvNeXt-T | 82.1% | 5.5ms | 960 | 29M |
| 데이티비(DeiT-B) | 83.1% | 9.2ms | 570 | 87M |
| 스윈-T | 81.3% | 6.8ms | 740 | 28M |
| ViT-B/16 (timm) | 85.5% | 11.4ms | 460 | 86M |
| EfficientNet-B4 | 83.0% | 7.3ms | 690 | 19M |
RTX 4090, 배치 크기 32, FP16에서 측정되었습니다. 지연 시간 = 단일 이미지, 배치=1.
경고: ViT가 항상 최선의 선택은 아닙니다.
- 소규모 데이터세트(이미지 10,000개 미만): CNN 또는 EfficientNet은 대규모 사전 훈련 없이도 더 나은 성능을 발휘합니다. ViT가 올바르게 수렴하려면 많은 데이터가 필요합니다.
- 엣지에서의 실시간 작업: ViT-Ti/16은 GPU에서는 ~4ms의 지연 시간을 갖지만 CPU에서는 100ms를 초과합니다. CPU 배포에는 MobileNet 또는 EfficientNet-Lite가 선호됩니다.
- CPU에서 개체 감지: 스윈과 훌륭한 백본이지만 무겁습니다. 경량 백본을 갖춘 YOLO는 지연 시간에서 CPU의 Swin을 능가합니다.
- 극단적인 도메인 이동 데이터 세트를 사용한 미세 조정: 사전 훈련된 CNN은 대상 데이터 세트가 ImageNet과 매우 다른 경우 더 잘 일반화할 수 있습니다.
프로덕션에서 ViT를 위한 모범 사례
ViT 배포를 위한 체크리스트
- 올바른 변형을 선택하십시오: 제한된 자원을 위한 ViT-Ti/S, 표준 품질을 위한 ViT-B, 탐지/분할을 위한 Swin-T/S, ImageNet 규모에서 처음부터 훈련을 위한 DeiT-B.
- ImageNet-21K 사전 훈련: 항상 ImageNet-1K가 아닌 ImageNet-21K 가중치에서 시작됩니다. 특히 작은 데이터 세트의 경우 정확도가 크게 향상됩니다.
- 미세 조정을 위한 낮은 학습률: ViT-B에는 기본 LR 3e-5를 사용하고, 최소 5세대의 워밍업을 사용합니다. LR이 너무 높으면 사전 훈련된 표현이 파괴됩니다.
- 입력 해상도: 224px로 사전 훈련된 ViT가 가장 잘 작동합니다. 224px 입력으로. 384px로 미세 조정하면 정확도가 향상되지만 메모리 비용이 2.3배 늘어납니다.
- 배치 크기 및 기울기 누적: ViT는 대규모 배치 크기로 인한 이점을 얻습니다. (256-2048). VRAM이 충분하지 않은 경우 그라데이션 누적을 사용하십시오.
-
혼합 정밀도(BF16/FP16): 항상 활성화
torch.autocast. ViT는 정확도 손실 없이 2배의 속도 향상을 얻습니다. -
플래시 주의: 미국
torch.nn.functional.scaled_dot_product_attention(PyTorch 2.0+) 또는flash-attn주의 기억력을 40% 감소시킵니다.
결론
Vision Transformers는 컴퓨터 비전 환경을 재정의했습니다. 2026년에는 이분법이 ViT 대 CNN 및 대부분 구식: 하이브리드 아키텍처(ConvNeXt, CoAtNet, FastViT) 결합 EVA 및 SigLIP과 같은 순수 ViT가 대규모 벤치마크를 지배하는 반면 두 세계의 최고입니다.
실습을 위한 최적의 명확한 워크플로우: 대규모 데이터세트에 대해 사전 훈련된 백본 선택 (ImageNet-21K, LAION), 공격적인 강화로 미세 조정(MixUp, CutMix, RandAugment) LR 워밍업을 수행한 다음 최적화된 배포를 위해 ONNX로 내보냅니다. CNN과의 차이점 이는 단지 정량적인 것이 아닙니다. 글로벌 관심 기능을 통해 ViT는 복잡한 시각적 이해 작업에 중요한 이미지의 장거리 관계.
시리즈의 다음 단계는 신경망 아키텍처 검색(NAS): 어떻게 주어진 작업 및 계산 예산에 대한 최적의 아키텍처 선택을 자동화합니다. ViT, CNN 및 하이브리드 변형 간의 수동 선택을 넘어선 것입니다.
다음 단계
- 다음 기사: 신경망 아키텍처 검색: 딥 러닝을 위한 AutoML
- 관련된: LoRA 및 QLoRA를 통한 미세 조정
- 컴퓨터 비전 시리즈: Swin Transformer를 사용한 객체 감지
- MLOps 시리즈: 프로덕션에서 비전 모델 제공







