ビジョン トランスフォーマー (ViT): アーキテクチャと実用的なアプリケーション
2020 年、Google Research の論文がコンピューター ビジョンを根本的に変えました。「画像には価値がある」 16x16 ワード」。直感はシンプルですが革命的でした — Transformer アーキテクチャを適用し、 NLP で主流であり、i を処理することで画像に直接反映されます。 パッチ トークンとしての視覚的。結果 それは ビジョントランスフォーマー (ViT)、数年以内に CNN を追い越しました ImageNet やその他多数のベンチマークに関する最先端の技術を活用し、新しいベンチマークへの道を切り開く ビジュアルモデルの生成。
ViT が約束するのは正確さだけではありません。 多用途性。同じバックボーン テキストに使用されるトランスフォーマーは画像と共有でき、テンプレートが可能になります CLIP、DALL-E、GPT-4V などのマルチモード。 ViT はデータとコンピューティングにより CNN よりも拡張性が優れています 追加および次のようなバリエーション スイングトランス e ディート 彼らが作った これらのモデルは、数百ものデータセットで事前トレーニングを行わなくても、中規模のデータセットでも効率的です。 何百万もの画像の中から。
このガイドでは、PyTorch で ViT をゼロから構築し、アーキテクチャのバリエーションを探索します。 最も重要であり、特定の運用タスクに合わせて微調整する方法を示します。
何を学ぶか
- ViT アーキテクチャ: パッチの埋め込み、位置エンコーディング、視覚的注意
- PyTorch を使用してゼロから実装を完了する
- ViT-B/16、ViT-L/32、DeiT、Swin Transformerの違い
- カスタム データセットで事前トレーニングされた ViT の微調整
- ViT のデータ拡張技術 (MixUp、CutMix、RandAugment)
- アテンションのロールアウトとアテンション マップの解釈可能性
- 最適化された展開: ONNX、TorchScript、エッジ デバイス
- 実際のデータセットでの ViT と CNN のベンチマーク
ViT アーキテクチャ: その仕組み
Vision Transformer は画像を入力として受け取り、それを重複しないパッチに分割します。
固定サイズ (通常は 16x16 または 32x32 ピクセル)。すべてのパッチが来る 平らになった
(平坦化)、次元ベクトルに線形投影されます。 d_model (埋め込み)。
これらの埋め込みは、 パッチの埋め込み、トランスフォーマートークンになります。
特別なトークン [CLS] (クラストークン) は同様にシーケンスの先頭に付けられます
NLP の BERT に。エンコードが完了すると、CLS トークン表現が に渡されます。
最終的な予測を生成する分類ヘッド。位置エンコーディング - 形状内
サインの代わりに学習 — 情報を保存するために埋め込みに追加されます
それがなければ失われるであろう空間。
# Diagramma architettura ViT
#
# Input Image (224x224x3)
# |
# v
# Patch Extraction: divide in 196 patch di 16x16
# (224/16 = 14 patch per lato -> 14*14 = 196 patch)
# |
# v
# Patch Embedding: ogni patch [768] via Linear projection
# + [CLS] token -> sequenza di 197 token
# |
# v
# + Positional Embedding (learnable, 197x768)
# |
# v
# Transformer Encoder (L strati):
# - LayerNorm
# - Multi-Head Self-Attention (h heads)
# - Residual connection
# - LayerNorm
# - MLP (d_model -> 4*d_model -> d_model)
# - Residual connection
# |
# v
# [CLS] token representation
# |
# v
# MLP Head -> num_classes output
# Varianti standard:
# ViT-B/16: d_model=768, L=12, h=12 | ~86M param
# ViT-L/16: d_model=1024, L=24, h=16 | ~307M param
# ViT-H/14: d_model=1280, L=32, h=16 | ~632M param
ViT をゼロから実装
PyTorch で完全な ViT を構築しましょう。基本的なコンポーネントから始めましょう。 パッチの埋め込み.
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import math
# ============================================================
# 1. PATCH EMBEDDING
# ============================================================
class PatchEmbedding(nn.Module):
"""
Converte un'immagine in una sequenza di patch embedding.
Metodo 1: Convolution (efficiente, equivalente a patch+linear)
"""
def __init__(self, img_size=224, patch_size=16, in_channels=3, d_model=768):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.n_patches = (img_size // patch_size) ** 2
# Equivalente a: flatten ogni patch + proiezione lineare
# Ma implementato come Conv2d per efficienza
self.projection = nn.Sequential(
# Divide in patch e proietta
nn.Conv2d(in_channels, d_model,
kernel_size=patch_size, stride=patch_size),
# [B, d_model, H/P, W/P] -> [B, n_patches, d_model]
Rearrange('b d h w -> b (h w) d')
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: [B, C, H, W]
return self.projection(x) # [B, n_patches, d_model]
# ============================================================
# 2. MULTI-HEAD SELF ATTENTION per ViT
# ============================================================
class ViTAttention(nn.Module):
"""Multi-head self-attention con dropout."""
def __init__(self, d_model=768, n_heads=12, attn_dropout=0.0):
super().__init__()
assert d_model % n_heads == 0
self.n_heads = n_heads
self.head_dim = d_model // n_heads
self.scale = self.head_dim ** -0.5
# QKV projection
self.qkv = nn.Linear(d_model, d_model * 3, bias=True)
self.attn_drop = nn.Dropout(attn_dropout)
self.proj = nn.Linear(d_model, d_model)
def forward(self, x: torch.Tensor):
B, N, C = x.shape
# Calcola Q, K, V
qkv = self.qkv(x).reshape(B, N, 3, self.n_heads, self.head_dim)
qkv = qkv.permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0) # Ognuno: [B, heads, N, head_dim]
# Scaled dot-product attention
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = F.softmax(attn, dim=-1)
attn_weights = attn # Salva per attention rollout
attn = self.attn_drop(attn)
# Weighted sum + proiezione output
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
return x, attn_weights
# ============================================================
# 3. TRANSFORMER ENCODER BLOCK
# ============================================================
class ViTBlock(nn.Module):
"""Singolo blocco Transformer per ViT."""
def __init__(self, d_model=768, n_heads=12, mlp_ratio=4.0,
dropout=0.0, attn_dropout=0.0):
super().__init__()
self.norm1 = nn.LayerNorm(d_model)
self.attn = ViTAttention(d_model, n_heads, attn_dropout)
self.norm2 = nn.LayerNorm(d_model)
mlp_dim = int(d_model * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(d_model, mlp_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(mlp_dim, d_model),
nn.Dropout(dropout)
)
def forward(self, x: torch.Tensor):
# Pre-norm + residual connection
attn_out, attn_weights = self.attn(self.norm1(x))
x = x + attn_out
x = x + self.mlp(self.norm2(x))
return x, attn_weights
# ============================================================
# 4. VISION TRANSFORMER COMPLETO
# ============================================================
class VisionTransformer(nn.Module):
"""
Vision Transformer (ViT) completo.
Paper: "An Image is Worth 16x16 Words" (Dosovitskiy et al., 2020)
"""
def __init__(
self,
img_size: int = 224,
patch_size: int = 16,
in_channels: int = 3,
num_classes: int = 1000,
d_model: int = 768,
depth: int = 12,
n_heads: int = 12,
mlp_ratio: float = 4.0,
dropout: float = 0.1,
attn_dropout: float = 0.0,
representation_size: int = None # Pre-logit layer (opzionale)
):
super().__init__()
self.num_classes = num_classes
self.d_model = d_model
# Patch + Position Embedding
self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, d_model)
n_patches = self.patch_embed.n_patches
# Token CLS e positional embedding
self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, d_model))
self.pos_drop = nn.Dropout(dropout)
# Transformer blocks
self.blocks = nn.ModuleList([
ViTBlock(d_model, n_heads, mlp_ratio, dropout, attn_dropout)
for _ in range(depth)
])
self.norm = nn.LayerNorm(d_model)
# Classification head
if representation_size is not None:
self.pre_logits = nn.Sequential(
nn.Linear(d_model, representation_size),
nn.Tanh()
)
else:
self.pre_logits = nn.Identity()
self.head = nn.Linear(
representation_size if representation_size else d_model,
num_classes
)
# Inizializzazione pesi
self._init_weights()
def _init_weights(self):
"""Inizializzazione standard per ViT."""
nn.init.trunc_normal_(self.pos_embed, std=0.02)
nn.init.trunc_normal_(self.cls_token, std=0.02)
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
def forward(self, x: torch.Tensor, return_attn: bool = False):
B = x.shape[0]
# 1. Patch embedding
x = self.patch_embed(x) # [B, n_patches, d_model]
# 2. Prepend CLS token
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=B)
x = torch.cat([cls_tokens, x], dim=1) # [B, n_patches+1, d_model]
# 3. Add positional embedding
x = x + self.pos_embed
x = self.pos_drop(x)
# 4. Transformer blocks
attn_weights_list = []
for block in self.blocks:
x, attn_weights = block(x)
attn_weights_list.append(attn_weights)
# 5. Layer norm finale
x = self.norm(x)
# 6. Usa solo il CLS token per classificazione
cls_output = x[:, 0]
cls_output = self.pre_logits(cls_output)
logits = self.head(cls_output)
if return_attn:
return logits, attn_weights_list
return logits
# ============================================================
# 5. CREAZIONE VARIANTI STANDARD
# ============================================================
def vit_base_16(num_classes=1000, **kwargs):
"""ViT-B/16: 86M parametri, input 224x224."""
return VisionTransformer(
img_size=224, patch_size=16, d_model=768,
depth=12, n_heads=12, mlp_ratio=4.0,
num_classes=num_classes, **kwargs
)
def vit_large_16(num_classes=1000, **kwargs):
"""ViT-L/16: 307M parametri, input 224x224."""
return VisionTransformer(
img_size=224, patch_size=16, d_model=1024,
depth=24, n_heads=16, mlp_ratio=4.0,
num_classes=num_classes, **kwargs
)
def vit_tiny_16(num_classes=1000, **kwargs):
"""ViT-Ti/16: ~6M parametri, per edge/mobile."""
return VisionTransformer(
img_size=224, patch_size=16, d_model=192,
depth=12, n_heads=3, mlp_ratio=4.0,
num_classes=num_classes, **kwargs
)
# Test
model = vit_base_16(num_classes=100)
x = torch.randn(2, 3, 224, 224)
out = model(x)
print(f"Input: {x.shape}")
print(f"Output: {out.shape}") # [2, 100]
print(f"Parametri: {sum(p.numel() for p in model.parameters()):,}")
# Parametri: 85,880,164
アーキテクチャのバリエーション: DeiT、Swin、BEiT
オリジナルの ViT では、大量のデータ (JFT-300M、3 億枚の画像) が必要でした。 CNNを超える。この制限により、よりデータ効率の高いバリアントの開発が促進されました。
| モデル | Anno | 主要なイノベーション | イメージネットトップ1 | パラメータ |
|---|---|---|---|---|
| ViT-B/16 | 2020年 | 初めての ViT、JFT-300M が必要 | 81.8% | 86M |
| DeiT-B | 2021年 | CNN 教師からの抽出、ImageNet のみ | 83.1% | 87M |
| スイングB | 2021年 | シフト ウィンドウ アテンション、階層型 | 85.2% | 88M |
| BEiT-L | 2022年 | マスクされた画像モデリング (視覚用 BERT) | 87.4% | 307M |
| ディートIII-H | 2022年 | 上級トレーニングレシピ | 87.7% | 632M |
| ViT-G(エヴァ) | 2023年 | 1B パラメータにスケール、CLIP 事前トレーニング | 89.6% | 1.0B |
DeiT (データ効率の高い画像変換器) Facebook AI とおそらくその亜種 実践にとって最も重要: を紹介します。 蒸留トークン それはあなたが学ぶことを可能にします CNN 教師 (RegNet や ConvNext など) から提供され、ImageNet-1K のみで優れたパフォーマンスが得られます。
スイングトランス 注意の二次複雑性問題を解決します を紹介する シフトウィンドウ: アテンションはローカル ウィンドウ内で計算されます 画像全体ではなく、画像に関して線形の計算コストがかかります。スウィン 階層表現 (CNN など) を生成し、検出に推奨されるバックボーンです そしてセグメンテーション。
事前トレーニング済みの ViT の微調整
実稼働環境で ViT を使用する最も実用的な方法は、事前トレーニングされたモデルから開始することです。 ImageNet-21K を使用して、データセットを微調整します。ハグフェイストランスフォーマーは、すべての機能を提供します。 統一 API を備えたコア ViT モデル。
# pip install transformers timm torch torchvision datasets
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import (
ViTForImageClassification, ViTImageProcessor,
AutoImageProcessor
)
from torchvision import datasets, transforms
import os
# ============================================================
# FINE-TUNING ViT-B/16 su Dataset Custom
# ============================================================
class ViTFineTuner(nn.Module):
"""
ViT pre-addestrato con classification head custom.
Supporta fine-tuning parziale o completo.
"""
def __init__(self, num_classes: int, model_name: str = "google/vit-base-patch16-224",
freeze_backbone: bool = False):
super().__init__()
# Carica ViT pre-addestrato da HuggingFace
self.vit = ViTForImageClassification.from_pretrained(
model_name,
num_labels=num_classes,
ignore_mismatched_sizes=True # Permette cambio num_classes
)
if freeze_backbone:
# Congela tutto tranne il classification head
for param in self.vit.vit.parameters():
param.requires_grad = False
# Solo il classifier rimane trainable
print(f"Parametri trainabili: {sum(p.numel() for p in self.parameters() if p.requires_grad):,}")
def forward(self, x):
outputs = self.vit(pixel_values=x)
return outputs.logits
# ============================================================
# DATA AUGMENTATION per ViT
# ============================================================
def get_vit_transforms(img_size: int = 224, mode: str = "train"):
"""
Augmentation pipeline ottimizzata per ViT.
ViT beneficia molto da augmentation aggressiva.
"""
if mode == "train":
return transforms.Compose([
transforms.RandomResizedCrop(img_size, scale=(0.08, 1.0)),
transforms.RandomHorizontalFlip(),
transforms.RandAugment(num_ops=2, magnitude=9), # RandAugment
transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
),
transforms.RandomErasing(p=0.25) # CutOut/Erasing
])
else:
# Resize + center crop per validation/test
return transforms.Compose([
transforms.Resize(int(img_size * 1.143)), # 256 per 224
transforms.CenterCrop(img_size),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
# ============================================================
# TRAINING LOOP CON WARMUP + COSINE DECAY
# ============================================================
import math
from torch.optim.lr_scheduler import LambdaLR
def get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps):
"""LR schedule: linear warmup + cosine decay (standard per ViT)."""
def lr_lambda(step):
if step < warmup_steps:
return float(step) / float(max(1, warmup_steps))
progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
return LambdaLR(optimizer, lr_lambda)
def train_vit(
model, train_loader, val_loader,
num_epochs=30, base_lr=3e-5, weight_decay=0.05,
device="cuda", label_smoothing=0.1
):
model = model.to(device)
# AdamW con weight decay (standard per ViT)
# Escludi bias e LayerNorm dal weight decay
no_decay_params = []
decay_params = []
for name, param in model.named_parameters():
if param.requires_grad:
if 'bias' in name or 'norm' in name or 'cls_token' in name or 'pos_embed' in name:
no_decay_params.append(param)
else:
decay_params.append(param)
optimizer = torch.optim.AdamW([
{'params': decay_params, 'weight_decay': weight_decay},
{'params': no_decay_params, 'weight_decay': 0.0}
], lr=base_lr)
total_steps = len(train_loader) * num_epochs
warmup_steps = len(train_loader) * 5 # 5 epoch di warmup
scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)
# Label smoothing loss
criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
best_acc = 0.0
for epoch in range(num_epochs):
# Training
model.train()
train_loss = 0.0
for batch_idx, (imgs, labels) in enumerate(train_loader):
imgs, labels = imgs.to(device), labels.to(device)
optimizer.zero_grad()
logits = model(imgs)
loss = criterion(logits, labels)
loss.backward()
# Gradient clipping (importante per ViT)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
scheduler.step()
train_loss += loss.item()
# Validation
model.eval()
correct = 0
total = 0
with torch.no_grad():
for imgs, labels in val_loader:
imgs, labels = imgs.to(device), labels.to(device)
logits = model(imgs)
preds = logits.argmax(dim=1)
correct += (preds == labels).sum().item()
total += labels.size(0)
val_acc = correct / total
avg_loss = train_loss / len(train_loader)
current_lr = scheduler.get_last_lr()[0]
print(f"Epoch {epoch+1}/{num_epochs} | "
f"Loss: {avg_loss:.4f} | "
f"Val Acc: {val_acc:.4f} | "
f"LR: {current_lr:.2e}")
if val_acc > best_acc:
best_acc = val_acc
torch.save(model.state_dict(), "best_vit.pth")
print(f" -> Nuovo best: {best_acc:.4f}")
return best_acc
MixUp と CutMix: ViT の高度な拡張機能
ViT は特にテクニックの恩恵を受けます ミックスアップオーグメンテーション。ミックスアップ 画像のペアとそのラベルを線形的に結合します。 CutMix は一部を置き換えます ある画像の長方形部分と別の画像の対応する部分。どちらのテクニックも モデルの一般化と調整を改善します。
import numpy as np
class MixUpCutMix:
"""
Combinazione di MixUp e CutMix come in DeiT e timm.
Applica randomicamente uno dei due metodi per ogni batch.
"""
def __init__(self, mixup_alpha=0.8, cutmix_alpha=1.0,
prob=0.5, num_classes=1000):
self.mixup_alpha = mixup_alpha
self.cutmix_alpha = cutmix_alpha
self.prob = prob
self.num_classes = num_classes
def one_hot(self, labels: torch.Tensor) -> torch.Tensor:
return F.one_hot(labels, self.num_classes).float()
def mixup(self, imgs, labels_oh):
"""MixUp: interpolazione lineare."""
lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
B = imgs.size(0)
idx = torch.randperm(B)
mixed_imgs = lam * imgs + (1 - lam) * imgs[idx]
mixed_labels = lam * labels_oh + (1 - lam) * labels_oh[idx]
return mixed_imgs, mixed_labels
def cutmix(self, imgs, labels_oh):
"""CutMix: ritaglia e incolla patch."""
lam = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
B, C, H, W = imgs.shape
idx = torch.randperm(B)
# Calcola dimensioni bounding box
cut_ratio = math.sqrt(1.0 - lam)
cut_h = int(H * cut_ratio)
cut_w = int(W * cut_ratio)
# Centro casuale
cy = np.random.randint(H)
cx = np.random.randint(W)
y1 = max(0, cy - cut_h // 2)
y2 = min(H, cy + cut_h // 2)
x1 = max(0, cx - cut_w // 2)
x2 = min(W, cx + cut_w // 2)
# Applica CutMix
mixed_imgs = imgs.clone()
mixed_imgs[:, :, y1:y2, x1:x2] = imgs[idx, :, y1:y2, x1:x2]
# Ricalcola lambda effettivo
lam_actual = 1.0 - (y2 - y1) * (x2 - x1) / (H * W)
mixed_labels = lam_actual * labels_oh + (1 - lam_actual) * labels_oh[idx]
return mixed_imgs, mixed_labels
def __call__(self, imgs, labels):
labels_oh = self.one_hot(labels).to(imgs.device)
if np.random.random() < self.prob:
if np.random.random() < 0.5:
return self.mixup(imgs, labels_oh)
else:
return self.cutmix(imgs, labels_oh)
return imgs, labels_oh
# Uso nel training loop
mixup_cutmix = MixUpCutMix(num_classes=100)
# Nel training loop:
# imgs, soft_labels = mixup_cutmix(imgs, labels)
# loss = F.cross_entropy(logits, soft_labels) # Soft labels
アテンション ロールアウト: ViT が見ているものを視覚化する
ViT の最も興味深い機能の 1 つは、 アテンションマップ モデルが画像のどの領域を考慮しているかを理解するため 関連性のある。テクニック アテンションのロールアウト すべての層に注意を伝播します 世界的な関連性のある地図を取得します。
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
def compute_attention_rollout(attn_weights_list: list,
discard_ratio: float = 0.9) -> np.ndarray:
"""
Attention Rollout (Abnar & Zuidema, 2020).
Propaga le attention maps attraverso tutti i layer.
attn_weights_list: lista di tensori [B, heads, N, N]
discard_ratio: percentuale di attention da azzerare (focus sui top)
"""
n_layers = len(attn_weights_list)
# Media su tutte le teste
rollout = None
for attn in attn_weights_list:
# attn: [B, heads, N, N] -> media teste -> [B, N, N]
attn_mean = attn.mean(dim=1) # [B, N, N]
# Aggiunge identità (residual connection)
eye = torch.eye(attn_mean.size(-1), device=attn_mean.device)
attn_mean = attn_mean + eye
attn_mean = attn_mean / attn_mean.sum(dim=-1, keepdim=True)
if rollout is None:
rollout = attn_mean
else:
rollout = torch.bmm(attn_mean, rollout)
return rollout
def visualize_vit_attention(model, image_tensor: torch.Tensor,
patch_size: int = 16):
"""
Visualizza l'attention del CLS token sull'immagine.
"""
model.eval()
with torch.no_grad():
_, attn_list = model(image_tensor.unsqueeze(0), return_attn=True)
# Calcola rollout
rollout = compute_attention_rollout(attn_list) # [1, N+1, N+1]
# Attenzione del CLS verso tutti i patch
cls_attn = rollout[0, 0, 1:] # Escludi il CLS token stesso
# Ridimensiona alla griglia dei patch
H = W = int(math.sqrt(cls_attn.size(0)))
attn_map = cls_attn.reshape(H, W).cpu().numpy()
# Normalizza e upscale alla dimensione immagine
attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min())
attn_map_upscaled = np.kron(attn_map, np.ones((patch_size, patch_size)))
return attn_map_upscaled
# Esempio di utilizzo e visualizzazione
# model = vit_base_16(num_classes=1000)
# img_tensor = get_vit_transforms(mode="val")(Image.open("dog.jpg"))
# attn_map = visualize_vit_attention(model, img_tensor)
#
# plt.figure(figsize=(12, 4))
# plt.subplot(1, 2, 1)
# plt.imshow(img_tensor.permute(1,2,0).numpy())
# plt.title("Immagine originale")
# plt.subplot(1, 2, 2)
# plt.imshow(attn_map, cmap='inferno')
# plt.title("Attention Rollout (CLS token)")
# plt.colorbar()
# plt.tight_layout()
# plt.savefig("vit_attention.png", dpi=150)
Swin Transformer: 階層ウィンドウへの注意
Il スイングトランス 標準 ViT の 2 つの基本的な制限に対処します。 注意の二次複雑さ(処理可能な解像度が制限される)と不在 階層表現(検出とセグメンテーションに必要)。
Swin は画像を重ならないウィンドウに分割し、その中でのみ注意力を計算します。 各ウィンドウの (線形複雑さ)。ある層と別の層の間に窓が現れる シフト 隣接するウィンドウ間の通信を可能にします。階層構造 空間解像度を段階的に下げて、次のような 4 スケールの特徴マップを生成します。 従来の CNN。
# Uso di Swin Transformer tramite timm (più semplice che implementare da zero)
# pip install timm
import timm
import torch
# Crea Swin-T (Tiny): 28M param, 81.3% ImageNet Top-1
swin_tiny = timm.create_model(
'swin_tiny_patch4_window7_224',
pretrained=True,
num_classes=0 # 0 = rimuovi classifier (backbone solo)
)
# Swin-B (Base): 88M param, 85.2% ImageNet Top-1
swin_base = timm.create_model(
'swin_base_patch4_window7_224',
pretrained=True,
num_classes=100 # Custom classifier
)
# Swin-V2-L per alta risoluzione (resolution scaling)
swin_v2 = timm.create_model(
'swinv2_large_window12to16_192to256_22kft1k',
pretrained=True,
num_classes=10
)
# Verifica feature maps gerarchiche (per detection/segmentation)
swin_backbone = timm.create_model(
'swin_base_patch4_window7_224',
pretrained=True,
features_only=True, # Restituisce feature a 4 scale
out_indices=(0, 1, 2, 3)
)
x = torch.randn(2, 3, 224, 224)
features = swin_backbone(x)
for i, feat in enumerate(features):
print(f"Stage {i}: {feat.shape}")
# Stage 0: torch.Size([2, 192, 56, 56])
# Stage 1: torch.Size([2, 384, 28, 28])
# Stage 2: torch.Size([2, 768, 14, 14])
# Stage 3: torch.Size([2, 1536, 7, 7])
# Fine-tuning completo con timm
from timm.loss import SoftTargetCrossEntropy
from timm.data.mixup import Mixup
from timm.optim import create_optimizer_v2
from timm.scheduler import create_scheduler_v2
# Parametri ottimali per fine-tuning Swin
model = timm.create_model('swin_base_patch4_window7_224',
pretrained=True, num_classes=10)
# Optimizer con parametri specifici per Swin
optimizer = create_optimizer_v2(
model,
opt='adamw',
lr=5e-5,
weight_decay=0.05,
layer_decay=0.9 # Layer-wise LR decay: layer più profondi = LR più bassa
)
x = torch.randn(2, 3, 224, 224)
out = model(x)
print(f"Swin output: {out.shape}") # [2, 10]
最適化された展開: ONNX と TorchScript
運用環境のデプロイメントでは、モデルを最適化された形式でエクスポートすることが不可欠です。 ONNX フレームワーク間の相互運用性とハードウェア固有の最適化が可能になります。 トーチスクリプト 推論のための Python のオーバーヘッドを排除します。
import torch
import torch.onnx
import onnx
import onnxruntime as ort
import numpy as np
import timm
# Modello ViT/Swin pre-addestrato e fine-tuned
model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=10)
model.load_state_dict(torch.load('best_vit.pth'))
model.eval()
# ============================================================
# EXPORT ONNX
# ============================================================
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
model,
dummy_input,
"vit_model.onnx",
export_params=True,
opset_version=17, # ONNX opset 17 per operatori recenti
do_constant_folding=True, # Ottimizzazione grafo
input_names=['pixel_values'],
output_names=['logits'],
dynamic_axes={
'pixel_values': {0: 'batch_size'}, # Batch size dinamico
'logits': {0: 'batch_size'}
}
)
# Verifica modello ONNX
onnx_model = onnx.load("vit_model.onnx")
onnx.checker.check_model(onnx_model)
print("Modello ONNX valido!")
# Inferenza con ONNX Runtime (CPU o GPU)
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
ort_session = ort.InferenceSession("vit_model.onnx", providers=providers)
# Test inferenza ONNX
test_input = np.random.randn(1, 3, 224, 224).astype(np.float32)
outputs = ort_session.run(None, {'pixel_values': test_input})
print(f"ONNX output shape: {outputs[0].shape}")
# ============================================================
# TORCHSCRIPT (alternativa per deployment PyTorch)
# ============================================================
model_scripted = torch.jit.script(model)
model_scripted.save("vit_scripted.pt")
# Ricarica e usa
loaded = torch.jit.load("vit_scripted.pt")
with torch.no_grad():
out = loaded(dummy_input)
print(f"TorchScript output: {out.shape}")
# ============================================================
# BENCHMARK ONNX vs PyTorch
# ============================================================
import time
def benchmark(fn, n_runs=50, warmup=10):
for _ in range(warmup):
fn()
torch.cuda.synchronize() if torch.cuda.is_available() else None
t0 = time.perf_counter()
for _ in range(n_runs):
fn()
torch.cuda.synchronize() if torch.cuda.is_available() else None
elapsed = (time.perf_counter() - t0) / n_runs * 1000
return elapsed
# PyTorch
def pt_inference():
with torch.no_grad():
model(dummy_input)
# ONNX Runtime
def onnx_inference():
ort_session.run(None, {'pixel_values': test_input})
pt_ms = benchmark(pt_inference)
onnx_ms = benchmark(onnx_inference)
print(f"PyTorch: {pt_ms:.1f} ms/inference")
print(f"ONNX RT: {onnx_ms:.1f} ms/inference")
print(f"Speedup ONNX: {pt_ms/onnx_ms:.2f}x")
特殊なタスクのための ViT: 医療、衛星、マルチモーダル
ViT は、従来とはまったく異なるドメイン上で優れた転送能力を実証しました。 イメージネット。特に重要な 3 つの応用分野は次のとおりです。 コンピュータ 医療ビジョン (放射線学、デジタル病理学、皮膚科)、 リモートセンシング (衛星画像、ドローン画像) とモデル マルチモーダル (CLIP、SigLIP、LLaVA)。
import timm
import torch
import torch.nn as nn
from transformers import CLIPModel, CLIPProcessor
# ============================================================
# ViT PER IMAGING MEDICO (classificazione CXR)
# ============================================================
# Chest X-Ray classification con DeiT fine-tuned
class MedicalViT(nn.Module):
"""
ViT per classificazione immagini mediche.
Usa un backbone pre-addestrato su ImageNet + fine-tuning su CXR.
Considera: le immagini mediche sono spesso grayscale (convertite a 3ch)
e richiedono risoluzione maggiore (384px).
"""
def __init__(self, n_classes: int, model_name: str = "deit3_base_patch16_384",
dropout: float = 0.2):
super().__init__()
# DeiT3 a 384px: più accurato per dettagli fini nelle immagini mediche
self.backbone = timm.create_model(
model_name,
pretrained=True,
num_classes=0, # Rimuovi head originale
img_size=384
)
d_model = self.backbone.embed_dim
# Head medica con dropout aggressivo (evita overfit su dataset piccoli)
self.head = nn.Sequential(
nn.LayerNorm(d_model),
nn.Dropout(dropout),
nn.Linear(d_model, d_model // 2),
nn.GELU(),
nn.Dropout(dropout / 2),
nn.Linear(d_model // 2, n_classes)
)
# Congela i primi 6 layer (feature basiche = ImageNet features)
# Fine-tuna solo i layer superiori (feature ad alto livello)
total_blocks = len(self.backbone.blocks)
freeze_until = total_blocks // 2
for i, block in enumerate(self.backbone.blocks):
if i < freeze_until:
for p in block.parameters():
p.requires_grad = False
trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
total = sum(p.numel() for p in self.parameters())
print(f"Parametri trainabili: {trainable:,} / {total:,} ({100*trainable/total:.1f}%)")
def forward(self, x: torch.Tensor) -> torch.Tensor:
features = self.backbone(x) # [B, d_model] - CLS token
return self.head(features)
# Uso per NIH Chest X-Ray Dataset (14 classi, multi-label)
medical_vit = MedicalViT(n_classes=14, dropout=0.3)
x = torch.randn(4, 3, 384, 384) # 384px per CXR
out = medical_vit(x)
print(f"CXR prediction shape: {out.shape}") # [4, 14]
# ============================================================
# CLIP: VISION-LANGUAGE PRETRAINING
# ============================================================
# CLIP usa un ViT come encoder visuale accoppiato a un Transformer testuale.
# L'addestramento contrasto allinea rappresentazioni visive e testuali.
def clip_zero_shot_classification(
images: torch.Tensor,
class_descriptions: list, # ["a photo of a cat", "a photo of a dog", ...]
model_name: str = "openai/clip-vit-base-patch32"
):
"""
Zero-shot image classification con CLIP.
Non richiede esempi di training: usa descrizioni testuali delle classi.
"""
model = CLIPModel.from_pretrained(model_name)
processor = CLIPProcessor.from_pretrained(model_name)
model.eval()
# Codifica testi e immagini nello stesso spazio embedding
with torch.no_grad():
# Text embeddings
text_inputs = processor(text=class_descriptions, return_tensors="pt",
padding=True, truncation=True)
text_features = model.get_text_features(**text_inputs)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
# Image embeddings (usa ViT internamente)
image_inputs = processor(images=images, return_tensors="pt")
image_features = model.get_image_features(**image_inputs)
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
# Similarità coseno: matrice [n_images, n_classes]
similarity = (100 * image_features @ text_features.T).softmax(dim=-1)
return similarity
# Esempio: classificazione zero-shot senza training
class_names = [
"a chest X-ray showing pneumonia",
"a normal chest X-ray",
"a chest X-ray showing cardiomegaly",
"a chest X-ray with pleural effusion"
]
# similarity = clip_zero_shot_classification(images, class_names)
# print(f"Predicted class: {class_names[similarity.argmax()]}")
print("ViT multimodale (CLIP) pronto per zero-shot classification")
エッジデバイス向けの ViT 最適化
ViT をエッジ ハードウェアに導入するには、特定の戦略が必要です。標準 ViT (8,600 万以上のパラメータ) Raspberry Pi やマイクロコントローラーには重すぎます。などの軽量バージョン ViT-Ti/16 (6Mパラメータ) モバイルヴィット (5Mパラメータ)は このユースケース向けに設計されており、注意の表現力と 畳み込みの効率。
import timm
import torch
import torch.onnx
import time
import numpy as np
# ============================================================
# VARIANTI ViT LEGGERE PER EDGE
# ============================================================
edge_models = {
"vit_tiny_patch16_224": "ViT-Ti (6M, ~4ms GPU)",
"deit_tiny_patch16_224": "DeiT-Ti (5.7M, ~3ms GPU)",
"mobilevit_s": "MobileViT-S (5.6M, 4ms, ottimo CPU)",
"efficientvit_m0": "EfficientViT-M0 (2.4M, ultra-light)",
"fastvit_t8": "FastViT-T8 (4M, 3x più veloce di DeiT)",
}
def benchmark_edge_models(input_size=(1, 3, 224, 224), device="cpu", n_runs=50):
"""
Benchmark dei modelli ViT leggeri su CPU (simula edge device).
CPU benchmark e più rappresentativo di deployment su RPi/Jetson Nano.
"""
results = []
x = torch.randn(*input_size).to(device)
for model_name, description in edge_models.items():
try:
model = timm.create_model(model_name, pretrained=False, num_classes=10)
model = model.to(device).eval()
n_params = sum(p.numel() for p in model.parameters())
model_size_mb = n_params * 4 / (1024**2)
# Warmup
with torch.no_grad():
for _ in range(5):
model(x)
# Benchmark
t0 = time.perf_counter()
with torch.no_grad():
for _ in range(n_runs):
model(x)
latency_ms = (time.perf_counter() - t0) / n_runs * 1000
results.append({
"model": model_name,
"desc": description,
"params_M": n_params / 1e6,
"size_mb": model_size_mb,
"latency_ms": latency_ms
})
print(f"{model_name:<35} {n_params/1e6:>5.1f}M "
f"{model_size_mb:>6.1f}MB {latency_ms:>8.1f}ms")
except Exception as e:
print(f"{model_name}: Errore - {e}")
return results
# ============================================================
# EXPORT OTTIMIZZATO PER EDGE
# ============================================================
def export_vit_for_edge(model_name: str = "vit_tiny_patch16_224",
n_classes: int = 10):
"""
Pipeline completa: carica ViT-Ti, quantizza e esporta per edge.
"""
model = timm.create_model(model_name, pretrained=False, num_classes=n_classes)
model.eval()
dummy_input = torch.randn(1, 3, 224, 224)
# 1. Export ONNX con opset 17
torch.onnx.export(
model, dummy_input, f"{model_name}_edge.onnx",
opset_version=17,
do_constant_folding=True,
input_names=["input"],
output_names=["logits"],
dynamic_axes={"input": {0: "batch"}, "logits": {0: "batch"}}
)
# 2. Quantizzazione dinamica INT8 (per CPU edge)
import torch.quantization
model_quantized = torch.quantization.quantize_dynamic(
model, {nn.Linear, nn.MultiheadAttention}, dtype=torch.qint8
)
# Confronto dimensioni
original_size = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024**2)
print(f"Modello originale FP32: {original_size:.1f} MB")
# Salva versione quantizzata
torch.save(model_quantized.state_dict(), f"{model_name}_int8.pt")
# Test latenza quantizzata su CPU
with torch.no_grad():
for _ in range(5): model_quantized(dummy_input) # warmup
t0 = time.perf_counter()
for _ in range(50): model_quantized(dummy_input)
lat_quant = (time.perf_counter() - t0) / 50 * 1000
print(f"Latenza INT8 CPU: {lat_quant:.1f}ms")
return model_quantized
print("ViT edge export pipeline pronto")
一般的なタスクに関する ViT と CNN のベンチマーク (2025 年)
| モデル | イメージネットトップ1 | レイテンシー (ミリ秒) | スループット (img/s) | パラメータ |
|---|---|---|---|---|
| レスネット-50 | 76.1% | 4.1ミリ秒 | 1,240 | 25M |
| ConvNext-T | 82.1% | 5.5ミリ秒 | 960 | 29M |
| DeiT-B | 83.1% | 9.2ミリ秒 | 570 | 87M |
| Swin-T | 81.3% | 6.8ミリ秒 | 740 | 28M |
| ViT-B/16 (ティム) | 85.5% | 11.4ミリ秒 | 460 | 86M |
| EfficientNet-B4 | 83.0% | 7.3ミリ秒 | 690 | 19M |
RTX 4090、バッチサイズ 32、FP16 で測定。レイテンシー = 単一イメージ、バッチ = 1。
警告: ViT が常に最良の選択であるとは限りません
- 小規模なデータセット (10K 未満の画像): CNN または EfficientNet は、大規模な事前トレーニングを行わなくてもパフォーマンスが向上します。 ViT が正しく収束するには大量のデータが必要です。
- エッジでのリアルタイムタスク: ViT-Ti/16 の遅延は GPU では約 4ms ですが、CPU では >100ms です。 CPU の導入には MobileNet または EfficientNet-Lite が推奨されます。
- CPU での物体検出: スイングと素晴らしいバックボーンですが、重いです。軽量バックボーンを備えた YOLO は、CPU 上の Swin よりも遅延が優れています。
- 極端なドメインシフト データセットによる微調整: ターゲット データセットが ImageNet と大きく異なる場合、事前トレーニングされた CNN はより適切に一般化できます。
本番環境における ViT のベスト プラクティス
導入のチェックリスト ViT
- 適切なバリエーションを選択してください: 限られたリソースには ViT-Ti/S、標準品質には ViT-B、 検出/セグメンテーションには Swin-T/S、ImageNet スケールでのゼロからのトレーニングには DeiT-B。
- ImageNet-21K での事前トレーニング: ImageNet-1K ではなく、常に ImageNet-21K の重みから始まります。 特に小規模なデータセットの場合、精度が大幅に向上します。
- 微調整のための低い学習率: ViT-B にはベース LR 3e-5 を使用し、 少なくとも 5 エポックのウォームアップを伴う。 LR が高すぎると、事前トレーニングされた表現が破壊されます。
- 入力解像度: 224px で事前トレーニングされた ViT が最適に機能します 224px入力の場合。 384px に微調整すると精度が向上しますが、メモリが 2.3 倍かかります。
- バッチサイズとグラジエントの累積: ViT は大きなバッチサイズからメリットを得る (256-2048)。 VRAM が十分でない場合は、勾配累積を使用します。
-
混合精度 (BF16/FP16): 常に有効にする
torch.autocast。 ViT は精度を損なうことなく 2 倍のスピードアップを実現します。 -
フラッシュ注意: アメリカ合衆国
torch.nn.functional.scaled_dot_product_attention(PyTorch 2.0+) またはflash-attn注意記憶を 40% 削減します。
結論
ビジョン トランスフォーマーは、コンピューター ビジョンの世界を再定義しました。 2026 年、二項対立 ViT と CNN、そしてほとんど時代遅れ: ハイブリッド アーキテクチャ (ConvNeXt、CoAtNet、FastViT) の組み合わせ EVA や SigLIP のような純粋な ViT が大規模なベンチマークを独占する一方で、両方の長所を兼ね備えています。
実践用の最適かつ明確なワークフロー: 大規模なデータセットで事前トレーニングされたバックボーンを選択します (ImageNet-21K、LAION)、積極的な拡張による微調整 (MixUp、CutMix、RandAugment) LR ウォームアップを行った後、ONNX にエクスポートして展開を最適化します。 CNNとの違い それは定量的なものだけではありません。ViT はグローバルな注目機能により、 画像内の長距離関係は、複雑な視覚的理解タスクに不可欠です。
シリーズの次のステップは、 ニューラル アーキテクチャ検索 (NAS):どうやって 特定のタスクと計算予算に最適なアーキテクチャの選択を自動化します。 ViT、CNN、ハイブリッド バリアント間の手動選択を超えています。
次のステップ
- 次の記事: ニューラル アーキテクチャの検索: 深層学習のための AutoML
- 関連している: LoRA と QLoRA による微調整
- コンピュータービジョンシリーズ: Swin Transformer による物体検出
- MLOps シリーズ: 本番環境でのビジョンモデルの提供







