Vision Transformer (ViT): Architecture and Practical Applications
In 2020, a Google Research paper fundamentally changed computer vision: "An Image is Worth 16x16 Words." The intuition was simple but revolutionary — apply the Transformer architecture, dominant in NLP, directly to images by treating visual patches as tokens. The result was the Vision Transformer (ViT), which within a few years surpassed state-of-the-art CNNs on ImageNet and dozens of other benchmarks, opening the door to a new generation of visual models.
The promise of ViTs is not just accuracy: it's versatility. The same Transformer backbone used for text can be shared with images, enabling multimodal models like CLIP, DALL-E, and GPT-4V. ViTs scale better than CNNs with additional data and compute, and variants like Swin Transformer and DeiT have made these models efficient even on medium-sized datasets without pre-training on hundreds of millions of images.
In this guide we build a ViT from scratch in PyTorch, explore the most important architectural variants, and show how to fine-tune for specific production tasks.
What You'll Learn
- ViT architecture: patch embedding, positional encoding, visual self-attention
- Complete from-scratch implementation with PyTorch
- Differences between ViT-B/16, ViT-L/32, DeiT, Swin Transformer
- Fine-tuning a pre-trained ViT on a custom dataset
- Data augmentation techniques for ViT (MixUp, CutMix, RandAugment)
- Attention rollout and interpretability of attention maps
- Optimized deployment: ONNX, TorchScript, edge devices
- ViT vs CNN benchmarks on real datasets
The ViT Architecture: How It Works
The Vision Transformer takes an image as input and divides it into non-overlapping patches
of fixed size (typically 16x16 or 32x32 pixels). Each patch is flattened and
linearly projected into a vector of dimension d_model (the embedding). These
embeddings, called patch embeddings, become the tokens of the Transformer.
A special [CLS] token (class token) is prepended to the sequence, as in BERT.
After encoding, the [CLS] representation is passed to a classification head to
produce the final prediction. A positional encoding (learned rather than sinusoidal)
is added to the embeddings to preserve spatial information that would otherwise be lost.
# ViT architecture diagram
#
# Input Image (224x224x3)
# |
# v
# Patch Extraction: split into 196 patches of 16x16
# (224/16 = 14 patches per side -> 14*14 = 196 patches)
# |
# v
# Patch Embedding: each patch [768] via Linear projection
# + [CLS] token -> sequence of 197 tokens
# |
# v
# + Positional Embedding (learnable, 197x768)
# |
# v
# Transformer Encoder (L layers):
# - LayerNorm
# - Multi-Head Self-Attention (h heads)
# - Residual connection
# - LayerNorm
# - MLP (d_model -> 4*d_model -> d_model)
# - Residual connection
# |
# v
# [CLS] token representation
# |
# v
# MLP Head -> num_classes output
# Standard variants:
# ViT-B/16: d_model=768, L=12, h=12 | ~86M params
# ViT-L/16: d_model=1024, L=24, h=16 | ~307M params
# ViT-H/14: d_model=1280, L=32, h=16 | ~632M params
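Those parameter counts can be sanity-checked from the hyper-parameters alone: ignoring biases, each encoder block contributes about 4·d² weights for attention (QKV plus output projection) and 8·d² for the MLP, so roughly 12·d²·L for the whole encoder. A quick back-of-the-envelope sketch (this helper is illustrative only, not part of the implementation that follows):

```python
def approx_vit_params(d_model: int, depth: int, patch_size: int = 16,
                      img_size: int = 224, num_classes: int = 1000) -> int:
    """Rough ViT parameter estimate (weight matrices only, biases ignored)."""
    n_patches = (img_size // patch_size) ** 2
    patch_embed = (patch_size ** 2 * 3) * d_model   # linear patch projection
    pos_embed = (n_patches + 1) * d_model           # learned positions + CLS
    per_block = 4 * d_model ** 2 + 8 * d_model ** 2 # attention + MLP per layer
    head = d_model * num_classes
    return patch_embed + pos_embed + per_block * depth + head

for name, d, L in [("ViT-B/16", 768, 12), ("ViT-L/16", 1024, 24)]:
    print(f"{name}: ~{approx_vit_params(d, L) / 1e6:.0f}M params")
# ViT-B/16: ~86M params, ViT-L/16: ~304M params
```

The estimate lands within a few percent of the official 86M / 307M figures, which is a good sign the table above is internally consistent.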
ViT Implementation from Scratch
We build a complete ViT in PyTorch, starting with the fundamental component: the patch embedding.
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import math
# ============================================================
# 1. PATCH EMBEDDING
# ============================================================
class PatchEmbedding(nn.Module):
"""
Converts an image into a sequence of patch embeddings.
Method 1: Convolution (efficient, equivalent to patch+linear)
"""
def __init__(self, img_size=224, patch_size=16, in_channels=3, d_model=768):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.n_patches = (img_size // patch_size) ** 2
# Equivalent to: flatten each patch + linear projection
# But implemented as Conv2d for efficiency
self.projection = nn.Sequential(
# Split into patches and project
nn.Conv2d(in_channels, d_model,
kernel_size=patch_size, stride=patch_size),
# [B, d_model, H/P, W/P] -> [B, n_patches, d_model]
Rearrange('b d h w -> b (h w) d')
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: [B, C, H, W]
return self.projection(x) # [B, n_patches, d_model]
# ============================================================
# 2. MULTI-HEAD SELF ATTENTION for ViT
# ============================================================
class ViTAttention(nn.Module):
"""Multi-head self-attention with dropout."""
def __init__(self, d_model=768, n_heads=12, attn_dropout=0.0):
super().__init__()
assert d_model % n_heads == 0
self.n_heads = n_heads
self.head_dim = d_model // n_heads
self.scale = self.head_dim ** -0.5
# QKV projection
self.qkv = nn.Linear(d_model, d_model * 3, bias=True)
self.attn_drop = nn.Dropout(attn_dropout)
self.proj = nn.Linear(d_model, d_model)
def forward(self, x: torch.Tensor):
B, N, C = x.shape
# Compute Q, K, V
qkv = self.qkv(x).reshape(B, N, 3, self.n_heads, self.head_dim)
qkv = qkv.permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0) # Each: [B, heads, N, head_dim]
# Scaled dot-product attention
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = F.softmax(attn, dim=-1)
attn_weights = attn # Save for attention rollout
attn = self.attn_drop(attn)
# Weighted sum + output projection
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
return x, attn_weights
# ============================================================
# 3. TRANSFORMER ENCODER BLOCK
# ============================================================
class ViTBlock(nn.Module):
"""Single Transformer block for ViT."""
def __init__(self, d_model=768, n_heads=12, mlp_ratio=4.0,
dropout=0.0, attn_dropout=0.0):
super().__init__()
self.norm1 = nn.LayerNorm(d_model)
self.attn = ViTAttention(d_model, n_heads, attn_dropout)
self.norm2 = nn.LayerNorm(d_model)
mlp_dim = int(d_model * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(d_model, mlp_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(mlp_dim, d_model),
nn.Dropout(dropout)
)
def forward(self, x: torch.Tensor):
# Pre-norm + residual connection
attn_out, attn_weights = self.attn(self.norm1(x))
x = x + attn_out
x = x + self.mlp(self.norm2(x))
return x, attn_weights
# ============================================================
# 4. FULL VISION TRANSFORMER
# ============================================================
class VisionTransformer(nn.Module):
"""
Full Vision Transformer (ViT).
Paper: "An Image is Worth 16x16 Words" (Dosovitskiy et al., 2020)
"""
def __init__(
self,
img_size: int = 224,
patch_size: int = 16,
in_channels: int = 3,
num_classes: int = 1000,
d_model: int = 768,
depth: int = 12,
n_heads: int = 12,
mlp_ratio: float = 4.0,
dropout: float = 0.1,
attn_dropout: float = 0.0,
        representation_size: int | None = None
):
super().__init__()
self.num_classes = num_classes
self.d_model = d_model
# Patch + Position Embedding
self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, d_model)
n_patches = self.patch_embed.n_patches
# CLS token and positional embedding
self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, d_model))
self.pos_drop = nn.Dropout(dropout)
# Transformer blocks
self.blocks = nn.ModuleList([
ViTBlock(d_model, n_heads, mlp_ratio, dropout, attn_dropout)
for _ in range(depth)
])
self.norm = nn.LayerNorm(d_model)
# Classification head
if representation_size is not None:
self.pre_logits = nn.Sequential(
nn.Linear(d_model, representation_size),
nn.Tanh()
)
else:
self.pre_logits = nn.Identity()
self.head = nn.Linear(
representation_size if representation_size else d_model,
num_classes
)
self._init_weights()
def _init_weights(self):
"""Standard ViT weight initialization."""
nn.init.trunc_normal_(self.pos_embed, std=0.02)
nn.init.trunc_normal_(self.cls_token, std=0.02)
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
def forward(self, x: torch.Tensor, return_attn: bool = False):
B = x.shape[0]
# 1. Patch embedding
x = self.patch_embed(x) # [B, n_patches, d_model]
# 2. Prepend CLS token
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=B)
x = torch.cat([cls_tokens, x], dim=1) # [B, n_patches+1, d_model]
# 3. Add positional embedding
x = x + self.pos_embed
x = self.pos_drop(x)
# 4. Transformer blocks
attn_weights_list = []
for block in self.blocks:
x, attn_weights = block(x)
attn_weights_list.append(attn_weights)
# 5. Final layer norm
x = self.norm(x)
# 6. Use only CLS token for classification
cls_output = x[:, 0]
cls_output = self.pre_logits(cls_output)
logits = self.head(cls_output)
if return_attn:
return logits, attn_weights_list
return logits
# ============================================================
# 5. STANDARD VARIANTS
# ============================================================
def vit_base_16(num_classes=1000, **kwargs):
"""ViT-B/16: 86M parameters, 224x224 input."""
return VisionTransformer(
img_size=224, patch_size=16, d_model=768,
depth=12, n_heads=12, mlp_ratio=4.0,
num_classes=num_classes, **kwargs
)
def vit_large_16(num_classes=1000, **kwargs):
"""ViT-L/16: 307M parameters, 224x224 input."""
return VisionTransformer(
img_size=224, patch_size=16, d_model=1024,
depth=24, n_heads=16, mlp_ratio=4.0,
num_classes=num_classes, **kwargs
)
def vit_tiny_16(num_classes=1000, **kwargs):
"""ViT-Ti/16: ~6M parameters, for edge/mobile."""
return VisionTransformer(
img_size=224, patch_size=16, d_model=192,
depth=12, n_heads=3, mlp_ratio=4.0,
num_classes=num_classes, **kwargs
)
# Test
model = vit_base_16(num_classes=100)
x = torch.randn(2, 3, 224, 224)
out = model(x)
print(f"Input: {x.shape}")
print(f"Output: {out.shape}") # [2, 100]
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
# Parameters: ~85.9M (ViT-B backbone + 100-class head)
Architectural Variants: DeiT, Swin, and BEiT
The original ViT required enormous amounts of data (JFT-300M, 300 million images) to outperform CNNs. This limitation drove the development of more data-efficient variants:
| Model | Year | Key Innovation | ImageNet Top-1 | Parameters |
|---|---|---|---|---|
| ViT-B/16 | 2020 | First ViT, requires JFT-300M | 81.8% | 86M |
| DeiT-B | 2021 | CNN teacher distillation, ImageNet only | 83.1% | 87M |
| Swin-B | 2021 | Shifted Window Attention, hierarchical | 85.2% | 88M |
| BEiT-L | 2022 | Masked Image Modeling (BERT for vision) | 87.4% | 307M |
| DeiT III-H | 2022 | Advanced training recipe | 87.7% | 632M |
| ViT-G (EVA) | 2023 | Scales to 1B params, CLIP pre-training | 89.6% | 1.0B |
DeiT (Data-efficient Image Transformers) from Facebook AI is probably the most practically important variant: it introduces a distillation token that lets the model learn from a CNN teacher (such as RegNet or ConvNeXt), achieving excellent performance with only ImageNet-1K.
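DeiT's hard-label distillation objective is simple enough to sketch in a few lines. This is a minimal illustration, not the full DeiT recipe (which also covers soft distillation and a heavy augmentation pipeline): the class-token logits follow the ground truth, while the distillation-token logits are trained against the teacher's argmax prediction.

```python
import torch
import torch.nn.functional as F

def deit_hard_distillation_loss(cls_logits: torch.Tensor,
                                dist_logits: torch.Tensor,
                                teacher_logits: torch.Tensor,
                                labels: torch.Tensor) -> torch.Tensor:
    """Hard distillation as in DeiT: the average of two cross-entropies.

    cls_logits:     student logits from the class token        [B, C]
    dist_logits:    student logits from the distillation token [B, C]
    teacher_logits: frozen CNN teacher logits                  [B, C]
    """
    teacher_labels = teacher_logits.argmax(dim=-1)            # hard pseudo-labels
    loss_cls = F.cross_entropy(cls_logits, labels)            # vs ground truth
    loss_dist = F.cross_entropy(dist_logits, teacher_labels)  # vs teacher
    return 0.5 * loss_cls + 0.5 * loss_dist

# Toy usage with random tensors (stand-ins for real student/teacher outputs)
B, C = 4, 10
loss = deit_hard_distillation_loss(torch.randn(B, C), torch.randn(B, C),
                                   torch.randn(B, C), torch.randint(0, C, (B,)))
print(loss.item())
```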
Swin Transformer solves the quadratic attention complexity problem by introducing Shifted Windows: attention is computed within local windows rather than over the entire image, with linear computational cost relative to image size. Swin produces hierarchical representations (like CNNs) and is the preferred backbone for detection and segmentation.
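The complexity claim is easy to quantify. Global self-attention over N tokens costs O(N²·d) for the score matrix, while window attention over M×M-token windows costs O(N·M²·d), linear in N. A back-of-the-envelope comparison (counting only the multiply-accumulates for QKᵀ and attention·V, ignoring projections):

```python
def global_attn_macs(n_tokens: int, d: int) -> int:
    """QK^T plus attn@V over the full token sequence."""
    return 2 * n_tokens * n_tokens * d

def window_attn_macs(n_tokens: int, d: int, window: int = 7) -> int:
    """Same cost, but attention restricted to window*window-token groups."""
    tokens_per_window = window * window
    n_windows = n_tokens // tokens_per_window
    return n_windows * 2 * tokens_per_window ** 2 * d

# 56x56 = 3136 tokens (Swin stage 1 at 224px input, stride 4), d=96
n, d = 56 * 56, 96
print(f"global: {global_attn_macs(n, d) / 1e9:.2f} GMACs")   # 1.89 GMACs
print(f"window: {window_attn_macs(n, d) / 1e9:.3f} GMACs")   # 0.030 GMACs
# Windowed attention is exactly (56/7)^2 = 64x cheaper at this resolution
```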
Fine-tuning a Pre-trained ViT
The most practical way to use ViTs in production is to start from a model pre-trained on ImageNet-21K and fine-tune on your own dataset. Hugging Face Transformers provides all major ViT models with a uniform API.
# pip install transformers timm torch torchvision datasets
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import (
ViTForImageClassification, ViTImageProcessor,
AutoImageProcessor
)
from torchvision import datasets, transforms
import os
# ============================================================
# FINE-TUNING ViT-B/16 on a Custom Dataset
# ============================================================
class ViTFineTuner(nn.Module):
"""
Pre-trained ViT with custom classification head.
Supports partial or full fine-tuning.
"""
def __init__(self, num_classes: int, model_name: str = "google/vit-base-patch16-224",
freeze_backbone: bool = False):
super().__init__()
# Load pre-trained ViT from HuggingFace
self.vit = ViTForImageClassification.from_pretrained(
model_name,
num_labels=num_classes,
ignore_mismatched_sizes=True # Allows changing num_classes
)
if freeze_backbone:
# Freeze everything except the classification head
for param in self.vit.vit.parameters():
param.requires_grad = False
print(f"Trainable parameters: {sum(p.numel() for p in self.parameters() if p.requires_grad):,}")
def forward(self, x):
outputs = self.vit(pixel_values=x)
return outputs.logits
# ============================================================
# DATA AUGMENTATION for ViT
# ============================================================
def get_vit_transforms(img_size: int = 224, mode: str = "train"):
"""
Augmentation pipeline optimized for ViT.
ViT benefits greatly from aggressive augmentation.
"""
if mode == "train":
return transforms.Compose([
transforms.RandomResizedCrop(img_size, scale=(0.08, 1.0)),
transforms.RandomHorizontalFlip(),
transforms.RandAugment(num_ops=2, magnitude=9), # RandAugment
transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
),
transforms.RandomErasing(p=0.25) # CutOut/Erasing
])
else:
# Resize + center crop for validation/test
return transforms.Compose([
transforms.Resize(int(img_size * 1.143)), # 256 for 224
transforms.CenterCrop(img_size),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
# ============================================================
# TRAINING LOOP WITH WARMUP + COSINE DECAY
# ============================================================
import math
from torch.optim.lr_scheduler import LambdaLR
def get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps):
"""LR schedule: linear warmup + cosine decay (standard for ViT)."""
def lr_lambda(step):
if step < warmup_steps:
return float(step) / float(max(1, warmup_steps))
progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
return LambdaLR(optimizer, lr_lambda)
def train_vit(
model, train_loader, val_loader,
num_epochs=30, base_lr=3e-5, weight_decay=0.05,
device="cuda", label_smoothing=0.1
):
model = model.to(device)
# AdamW with weight decay (standard for ViT)
# Exclude bias and LayerNorm from weight decay
no_decay_params = []
decay_params = []
for name, param in model.named_parameters():
if param.requires_grad:
if 'bias' in name or 'norm' in name or 'cls_token' in name or 'pos_embed' in name:
no_decay_params.append(param)
else:
decay_params.append(param)
optimizer = torch.optim.AdamW([
{'params': decay_params, 'weight_decay': weight_decay},
{'params': no_decay_params, 'weight_decay': 0.0}
], lr=base_lr)
total_steps = len(train_loader) * num_epochs
warmup_steps = len(train_loader) * 5 # 5 epochs of warmup
scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)
# Label smoothing loss
criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
best_acc = 0.0
for epoch in range(num_epochs):
# Training
model.train()
train_loss = 0.0
for batch_idx, (imgs, labels) in enumerate(train_loader):
imgs, labels = imgs.to(device), labels.to(device)
optimizer.zero_grad()
logits = model(imgs)
loss = criterion(logits, labels)
loss.backward()
# Gradient clipping (important for ViT)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
scheduler.step()
train_loss += loss.item()
# Validation
model.eval()
correct = 0
total = 0
with torch.no_grad():
for imgs, labels in val_loader:
imgs, labels = imgs.to(device), labels.to(device)
logits = model(imgs)
preds = logits.argmax(dim=1)
correct += (preds == labels).sum().item()
total += labels.size(0)
val_acc = correct / total
avg_loss = train_loss / len(train_loader)
current_lr = scheduler.get_last_lr()[0]
print(f"Epoch {epoch+1}/{num_epochs} | "
f"Loss: {avg_loss:.4f} | "
f"Val Acc: {val_acc:.4f} | "
f"LR: {current_lr:.2e}")
if val_acc > best_acc:
best_acc = val_acc
torch.save(model.state_dict(), "best_vit.pth")
print(f" -> New best: {best_acc:.4f}")
return best_acc
MixUp and CutMix: Advanced Augmentation for ViT
ViTs benefit particularly from mixup augmentation techniques. MixUp linearly combines pairs of images and their labels; CutMix replaces a rectangular region of one image with the corresponding region from another. Both techniques improve generalization and model calibration.
import numpy as np
class MixUpCutMix:
"""
Combination of MixUp and CutMix as used in DeiT and timm.
Randomly applies one of the two methods for each batch.
"""
def __init__(self, mixup_alpha=0.8, cutmix_alpha=1.0,
prob=0.5, num_classes=1000):
self.mixup_alpha = mixup_alpha
self.cutmix_alpha = cutmix_alpha
self.prob = prob
self.num_classes = num_classes
def one_hot(self, labels: torch.Tensor) -> torch.Tensor:
return F.one_hot(labels, self.num_classes).float()
def mixup(self, imgs, labels_oh):
"""MixUp: linear interpolation."""
lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
B = imgs.size(0)
idx = torch.randperm(B)
mixed_imgs = lam * imgs + (1 - lam) * imgs[idx]
mixed_labels = lam * labels_oh + (1 - lam) * labels_oh[idx]
return mixed_imgs, mixed_labels
def cutmix(self, imgs, labels_oh):
"""CutMix: cut and paste patches."""
lam = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
B, C, H, W = imgs.shape
idx = torch.randperm(B)
# Compute bounding box dimensions
cut_ratio = math.sqrt(1.0 - lam)
cut_h = int(H * cut_ratio)
cut_w = int(W * cut_ratio)
# Random center
cy = np.random.randint(H)
cx = np.random.randint(W)
y1 = max(0, cy - cut_h // 2)
y2 = min(H, cy + cut_h // 2)
x1 = max(0, cx - cut_w // 2)
x2 = min(W, cx + cut_w // 2)
# Apply CutMix
mixed_imgs = imgs.clone()
mixed_imgs[:, :, y1:y2, x1:x2] = imgs[idx, :, y1:y2, x1:x2]
# Recompute actual lambda
lam_actual = 1.0 - (y2 - y1) * (x2 - x1) / (H * W)
mixed_labels = lam_actual * labels_oh + (1 - lam_actual) * labels_oh[idx]
return mixed_imgs, mixed_labels
def __call__(self, imgs, labels):
labels_oh = self.one_hot(labels).to(imgs.device)
if np.random.random() < self.prob:
if np.random.random() < 0.5:
return self.mixup(imgs, labels_oh)
else:
return self.cutmix(imgs, labels_oh)
return imgs, labels_oh
# Usage in training loop
mixup_cutmix = MixUpCutMix(num_classes=100)
# In the training loop:
# imgs, soft_labels = mixup_cutmix(imgs, labels)
# loss = F.cross_entropy(logits, soft_labels) # Soft labels
Attention Rollout: Visualizing What the ViT Sees
One of the most interesting features of ViTs is the ability to visualize attention maps to understand which regions of the image the model considers relevant. The Attention Rollout technique propagates attention through all layers to obtain a global relevance map.
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
def compute_attention_rollout(attn_weights_list: list) -> torch.Tensor:
    """
    Attention Rollout (Abnar & Zuidema, 2020).
    Propagates attention maps through all layers.
    attn_weights_list: list of tensors [B, heads, N, N]
    Returns the accumulated attention [B, N, N].
    """
rollout = None
for attn in attn_weights_list:
# attn: [B, heads, N, N] -> mean over heads -> [B, N, N]
attn_mean = attn.mean(dim=1) # [B, N, N]
# Add identity (residual connection)
eye = torch.eye(attn_mean.size(-1), device=attn_mean.device)
attn_mean = attn_mean + eye
attn_mean = attn_mean / attn_mean.sum(dim=-1, keepdim=True)
if rollout is None:
rollout = attn_mean
else:
rollout = torch.bmm(attn_mean, rollout)
return rollout
def visualize_vit_attention(model, image_tensor: torch.Tensor,
patch_size: int = 16):
"""
Visualizes the CLS token attention over the image.
"""
model.eval()
with torch.no_grad():
_, attn_list = model(image_tensor.unsqueeze(0), return_attn=True)
# Compute rollout
rollout = compute_attention_rollout(attn_list) # [1, N+1, N+1]
# CLS attention toward all patches
cls_attn = rollout[0, 0, 1:] # Exclude the CLS token itself
# Reshape to patch grid
H = W = int(math.sqrt(cls_attn.size(0)))
attn_map = cls_attn.reshape(H, W).cpu().numpy()
# Normalize and upscale to image size
attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min())
attn_map_upscaled = np.kron(attn_map, np.ones((patch_size, patch_size)))
return attn_map_upscaled
# Example usage and visualization
# model = vit_base_16(num_classes=1000)
# img_tensor = get_vit_transforms(mode="val")(Image.open("dog.jpg"))
# attn_map = visualize_vit_attention(model, img_tensor)
#
# plt.figure(figsize=(12, 4))
# plt.subplot(1, 2, 1)
# plt.imshow(img_tensor.permute(1,2,0).numpy())
# plt.title("Original image")
# plt.subplot(1, 2, 2)
# plt.imshow(attn_map, cmap='inferno')
# plt.title("Attention Rollout (CLS token)")
# plt.colorbar()
# plt.tight_layout()
# plt.savefig("vit_attention.png", dpi=150)
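A useful property check on the rollout: after the identity-plus-renormalization step each per-layer matrix is row-stochastic, and a product of row-stochastic matrices stays row-stochastic, so every row of the final rollout should still sum to 1. A self-contained check on random attention maps (re-implementing the same accumulation, with hypothetical shapes):

```python
import torch

def rollout_product(attn_layers):
    """Accumulate (A + I, rows renormalized) across layers, as in rollout."""
    result = None
    for attn in attn_layers:                 # each: [B, heads, N, N]
        a = attn.mean(dim=1)                 # average heads -> [B, N, N]
        a = a + torch.eye(a.size(-1))        # model the residual branch
        a = a / a.sum(dim=-1, keepdim=True)  # renormalize rows
        result = a if result is None else torch.bmm(a, result)
    return result

B, H, N = 2, 4, 17
layers = [torch.softmax(torch.randn(B, H, N, N), dim=-1) for _ in range(6)]
rollout = rollout_product(layers)
row_sums = rollout.sum(dim=-1)
print(torch.allclose(row_sums, torch.ones(B, N), atol=1e-5))  # True
```

If this invariant ever breaks in your implementation, the usual culprit is renormalizing before adding the identity rather than after.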
Swin Transformer: Hierarchical Window Attention
The Swin Transformer solves two fundamental limitations of standard ViT: the quadratic attention complexity (which limits the processable resolution) and the lack of hierarchical representations (needed for detection and segmentation).
Swin divides the image into non-overlapping windows and computes attention only within each window (linear complexity). Between layers, windows are shifted to allow communication between adjacent windows. The hierarchical structure progressively reduces spatial resolution, producing feature maps at 4 scales like traditional CNNs.
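The window partition and cyclic shift at the heart of Swin can be sketched with plain tensor ops. This is a minimal illustration of the data movement only, not the full Swin layer (which also needs attention masks for the wrapped-around regions and relative position bias):

```python
import torch

def window_partition(x: torch.Tensor, window: int) -> torch.Tensor:
    """[B, H, W, C] -> [B * n_windows, window*window, C] token groups."""
    B, H, W, C = x.shape
    x = x.view(B, H // window, window, W // window, window, C)
    x = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window * window, C)
    return x

B, H, W, C, win = 1, 56, 56, 96, 7
feat = torch.randn(B, H, W, C)

# Regular windows: attention runs independently inside each 7x7 group
windows = window_partition(feat, win)
print(windows.shape)  # torch.Size([64, 49, 96])

# Shifted windows (next layer): cyclic shift by win//2 before partitioning,
# so tokens near old window borders end up grouped together
shifted = torch.roll(feat, shifts=(-(win // 2), -(win // 2)), dims=(1, 2))
shifted_windows = window_partition(shifted, win)
print(shifted_windows.shape)  # torch.Size([64, 49, 96])
```

Alternating plain and shifted layers is what lets information flow across window boundaries while keeping every attention call local.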
# Using Swin Transformer via timm (simpler than implementing from scratch)
# pip install timm
import timm
import torch
# Create Swin-T (Tiny): 28M params, 81.3% ImageNet Top-1
swin_tiny = timm.create_model(
'swin_tiny_patch4_window7_224',
pretrained=True,
num_classes=0 # 0 = remove classifier (backbone only)
)
# Swin-B (Base): 88M params, 85.2% ImageNet Top-1
swin_base = timm.create_model(
'swin_base_patch4_window7_224',
pretrained=True,
num_classes=100 # Custom classifier
)
# Swin-V2-L for high resolution (resolution scaling)
swin_v2 = timm.create_model(
'swinv2_large_window12to16_192to256_22kft1k',
pretrained=True,
num_classes=10
)
# Check hierarchical feature maps (for detection/segmentation)
swin_backbone = timm.create_model(
'swin_base_patch4_window7_224',
pretrained=True,
features_only=True, # Returns features at 4 scales
out_indices=(0, 1, 2, 3)
)
x = torch.randn(2, 3, 224, 224)
features = swin_backbone(x)
for i, feat in enumerate(features):
print(f"Stage {i}: {feat.shape}")
# Swin-B channel dims: 128, 256, 512, 1024 at strides 4, 8, 16, 32
# (the exact printed layout, NCHW vs NHWC, depends on the timm version)
# Full fine-tuning with timm
from timm.loss import SoftTargetCrossEntropy
from timm.data.mixup import Mixup
from timm.optim import create_optimizer_v2
from timm.scheduler import create_scheduler_v2
# Optimal parameters for Swin fine-tuning
model = timm.create_model('swin_base_patch4_window7_224',
pretrained=True, num_classes=10)
# Optimizer with Swin-specific parameters
optimizer = create_optimizer_v2(
model,
opt='adamw',
lr=5e-5,
weight_decay=0.05,
layer_decay=0.9 # Layer-wise LR decay: deeper layers = lower LR
)
x = torch.randn(2, 3, 224, 224)
out = model(x)
print(f"Swin output: {out.shape}") # [2, 10]
Optimized Deployment: ONNX and TorchScript
For production deployment, it is essential to export the model in an optimized format. ONNX enables interoperability between frameworks and hardware-specific optimizations; TorchScript eliminates Python overhead for inference.
import torch
import torch.onnx
import onnx
import onnxruntime as ort
import numpy as np
import timm
# Pre-trained and fine-tuned ViT/Swin model
model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=10)
model.load_state_dict(torch.load('best_vit.pth'))
model.eval()
# ============================================================
# ONNX EXPORT
# ============================================================
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
model,
dummy_input,
"vit_model.onnx",
export_params=True,
opset_version=17, # ONNX opset 17 for recent operators
do_constant_folding=True, # Graph optimization
input_names=['pixel_values'],
output_names=['logits'],
dynamic_axes={
'pixel_values': {0: 'batch_size'}, # Dynamic batch size
'logits': {0: 'batch_size'}
}
)
# Verify ONNX model
onnx_model = onnx.load("vit_model.onnx")
onnx.checker.check_model(onnx_model)
print("Valid ONNX model!")
# Inference with ONNX Runtime (CPU or GPU)
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
ort_session = ort.InferenceSession("vit_model.onnx", providers=providers)
# Test ONNX inference
test_input = np.random.randn(1, 3, 224, 224).astype(np.float32)
outputs = ort_session.run(None, {'pixel_values': test_input})
print(f"ONNX output shape: {outputs[0].shape}")
# ============================================================
# TORCHSCRIPT (alternative for PyTorch deployment)
# ============================================================
# jit.trace is generally more reliable than jit.script for timm/ViT models
model_scripted = torch.jit.trace(model, dummy_input)
model_scripted.save("vit_scripted.pt")
# Reload and use
loaded = torch.jit.load("vit_scripted.pt")
with torch.no_grad():
out = loaded(dummy_input)
print(f"TorchScript output: {out.shape}")
# ============================================================
# ONNX vs PyTorch BENCHMARK
# ============================================================
import time
def benchmark(fn, n_runs=50, warmup=10):
    for _ in range(warmup):
        fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(n_runs):
        fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.perf_counter() - t0) / n_runs * 1000  # ms per call
# PyTorch
def pt_inference():
with torch.no_grad():
model(dummy_input)
# ONNX Runtime
def onnx_inference():
ort_session.run(None, {'pixel_values': test_input})
pt_ms = benchmark(pt_inference)
onnx_ms = benchmark(onnx_inference)
print(f"PyTorch: {pt_ms:.1f} ms/inference")
print(f"ONNX RT: {onnx_ms:.1f} ms/inference")
print(f"ONNX Speedup: {pt_ms/onnx_ms:.2f}x")
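Beyond speed, any export pipeline should verify that the exported graph is numerically equivalent to eager execution. A self-contained sketch with a small stand-in model (swap in your fine-tuned ViT; the same pattern applies to the ONNX outputs via np.allclose):

```python
import torch
import torch.nn as nn

# Small stand-in model (replace with your fine-tuned ViT)
model = nn.Sequential(nn.Conv2d(3, 8, 16, stride=16), nn.Flatten(),
                      nn.Linear(8 * 14 * 14, 10)).eval()
dummy = torch.randn(1, 3, 224, 224)

traced = torch.jit.trace(model, dummy)

with torch.no_grad():
    eager_out = model(dummy)
    traced_out = traced(dummy)

# Tolerance loose enough to absorb fp32 kernel differences
assert torch.allclose(eager_out, traced_out, atol=1e-5), "export drifted!"
print("TorchScript output matches eager execution")
```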
ViT for Specialized Tasks: Medical, Satellite, and Multimodal
ViTs have demonstrated exceptional transfer capability to domains very different from ImageNet. Three particularly important application areas are medical computer vision (radiology, digital pathology, dermatology), remote sensing (satellite imagery, drone imagery), and multimodal models (CLIP, SigLIP, LLaVA). The global attention mechanism makes ViTs especially powerful for detecting subtle anomalies across entire images — a task where CNNs with limited receptive fields struggle.
import timm
import torch
import torch.nn as nn
from transformers import CLIPModel, CLIPProcessor
# ============================================================
# ViT FOR MEDICAL IMAGING (Chest X-Ray Classification)
# ============================================================
class MedicalViT(nn.Module):
"""
ViT for medical image classification.
Uses a pre-trained backbone on ImageNet + fine-tuning on CXR.
Notes:
- Medical images are often grayscale (converted to 3ch)
- Higher resolution needed (384px) for fine anatomical details
- Aggressive dropout to prevent overfitting on small datasets
"""
def __init__(self, n_classes: int, model_name: str = "deit3_base_patch16_384",
dropout: float = 0.2):
super().__init__()
# DeiT3 at 384px: more accurate for fine-grained medical details
self.backbone = timm.create_model(
model_name,
pretrained=True,
num_classes=0, # Remove original head
img_size=384
)
d_model = self.backbone.embed_dim
# Medical head with aggressive dropout (prevents overfit on small datasets)
self.head = nn.Sequential(
nn.LayerNorm(d_model),
nn.Dropout(dropout),
nn.Linear(d_model, d_model // 2),
nn.GELU(),
nn.Dropout(dropout / 2),
nn.Linear(d_model // 2, n_classes)
)
# Freeze the first half of blocks (basic features = ImageNet features)
# Fine-tune only upper layers (high-level features)
total_blocks = len(self.backbone.blocks)
freeze_until = total_blocks // 2
for i, block in enumerate(self.backbone.blocks):
if i < freeze_until:
for p in block.parameters():
p.requires_grad = False
trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
total = sum(p.numel() for p in self.parameters())
print(f"Trainable params: {trainable:,} / {total:,} ({100*trainable/total:.1f}%)")
def forward(self, x: torch.Tensor) -> torch.Tensor:
features = self.backbone(x) # [B, d_model] - CLS token
return self.head(features)
# NIH Chest X-Ray Dataset (14 classes, multi-label)
medical_vit = MedicalViT(n_classes=14, dropout=0.3)
x = torch.randn(4, 3, 384, 384) # 384px for CXR
out = medical_vit(x)
print(f"CXR prediction shape: {out.shape}") # [4, 14]
# ============================================================
# CLIP: VISION-LANGUAGE PRETRAINING
# ============================================================
# CLIP uses a ViT as the visual encoder paired with a text Transformer.
# Contrastive training aligns visual and textual representations in the
# same embedding space — enabling zero-shot transfer with text prompts.
def clip_zero_shot_classification(
images,
class_descriptions: list, # ["a photo of a cat", "a photo of a dog", ...]
model_name: str = "openai/clip-vit-base-patch32"
):
"""
Zero-shot image classification with CLIP.
No training examples needed: uses text descriptions of the classes.
Works for medical, satellite, or any domain with descriptive labels.
"""
model = CLIPModel.from_pretrained(model_name)
processor = CLIPProcessor.from_pretrained(model_name)
model.eval()
# Encode texts and images in the same embedding space
with torch.no_grad():
# Text embeddings
text_inputs = processor(text=class_descriptions, return_tensors="pt",
padding=True, truncation=True)
text_features = model.get_text_features(**text_inputs)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
# Image embeddings (uses ViT internally)
image_inputs = processor(images=images, return_tensors="pt")
image_features = model.get_image_features(**image_inputs)
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
# Cosine similarity matrix: [n_images, n_classes]
similarity = (100 * image_features @ text_features.T).softmax(dim=-1)
return similarity
# Example: zero-shot classification without any training
class_names = [
"a chest X-ray showing pneumonia",
"a normal chest X-ray",
"a chest X-ray showing cardiomegaly",
"a chest X-ray with pleural effusion"
]
# similarity = clip_zero_shot_classification(images, class_names)
# print(f"Predicted class: {class_names[similarity.argmax()]}")
print("Multimodal ViT (CLIP) ready for zero-shot classification")
Optimizing ViT for Edge Devices
Deploying ViTs on edge hardware requires specific strategies. Standard ViTs (86M+ parameters) are too heavy for Raspberry Pi or microcontrollers. Lighter variants like ViT-Ti/16 (6M params) and MobileViT (5M params) are designed for this use case, combining the expressive power of attention with the efficiency of convolutions. INT8 dynamic quantization can further reduce latency by 30-50% on ARM CPUs with no retraining required.
import timm
import torch
import torch.nn as nn
import torch.onnx
import time
# ============================================================
# LIGHTWEIGHT ViT VARIANTS FOR EDGE
# ============================================================
edge_models = {
"vit_tiny_patch16_224": "ViT-Ti (6M, ~4ms GPU)",
"deit_tiny_patch16_224": "DeiT-Ti (5.7M, ~3ms GPU)",
"mobilevit_s": "MobileViT-S (5.6M, great CPU)",
"efficientvit_m0": "EfficientViT-M0 (2.4M, ultra-light)",
"fastvit_t8": "FastViT-T8 (4M, 3x faster than DeiT)",
}
def benchmark_edge_models(input_size=(1, 3, 224, 224), device="cpu", n_runs=50):
    """
    Benchmark lightweight ViT models on CPU (simulates edge device).
    A CPU benchmark is more representative of deployment on a RPi/Jetson Nano
    than a GPU benchmark.
    """
    results = []
    x = torch.randn(*input_size).to(device)
    for model_name, description in edge_models.items():
        try:
            model = timm.create_model(model_name, pretrained=False, num_classes=10)
            model = model.to(device).eval()
            n_params = sum(p.numel() for p in model.parameters())
            model_size_mb = n_params * 4 / (1024**2)
            # Warmup (critical for accurate CPU benchmarks)
            with torch.no_grad():
                for _ in range(5):
                    model(x)
            # Benchmark
            t0 = time.perf_counter()
            with torch.no_grad():
                for _ in range(n_runs):
                    model(x)
            latency_ms = (time.perf_counter() - t0) / n_runs * 1000
            results.append({
                "model": model_name,
                "desc": description,
                "params_M": n_params / 1e6,
                "size_mb": model_size_mb,
                "latency_ms": latency_ms,
            })
            print(f"{model_name:<35} {n_params/1e6:>5.1f}M "
                  f"{model_size_mb:>6.1f}MB {latency_ms:>8.1f}ms")
        except Exception as e:
            print(f"{model_name}: Error - {e}")
    return results
# ============================================================
# OPTIMIZED EDGE EXPORT PIPELINE
# ============================================================
def export_vit_for_edge(model_name: str = "vit_tiny_patch16_224",
                        n_classes: int = 10):
    """
    Complete pipeline: load ViT-Ti, quantize INT8, export for edge.
    Combines ONNX export (for runtime flexibility) and PyTorch INT8
    dynamic quantization (for ARM CPU acceleration).
    """
    model = timm.create_model(model_name, pretrained=False, num_classes=n_classes)
    model.eval()
    dummy_input = torch.randn(1, 3, 224, 224)

    # Step 1: Export to ONNX with opset 17
    torch.onnx.export(
        model, dummy_input, f"{model_name}_edge.onnx",
        opset_version=17,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["logits"],
        dynamic_axes={"input": {0: "batch"}, "logits": {0: "batch"}},
    )

    # Step 2: Dynamic INT8 quantization for CPU edge deployment.
    # timm's ViT attention is built from nn.Linear layers, so quantizing
    # nn.Linear also covers the attention projections.
    # No calibration dataset required (dynamic = runtime quantization).
    model_quantized = torch.quantization.quantize_dynamic(
        model, {nn.Linear}, dtype=torch.qint8
    )

    # Size comparison
    original_size = sum(
        p.numel() * p.element_size()
        for p in model.parameters()
    ) / (1024**2)
    print(f"Original FP32 model: {original_size:.1f} MB")

    # Save quantized version
    torch.save(model_quantized.state_dict(), f"{model_name}_int8.pt")

    # Measure INT8 latency on CPU
    with torch.no_grad():
        for _ in range(5):  # warmup
            model_quantized(dummy_input)
        t0 = time.perf_counter()
        for _ in range(50):
            model_quantized(dummy_input)
    lat_quant = (time.perf_counter() - t0) / 50 * 1000
    print(f"INT8 CPU latency: {lat_quant:.1f}ms")

    # Expected results for ViT-Ti on x86 CPU:
    #   FP32: ~85ms, INT8: ~48ms -> ~1.8x speedup
    #   On Raspberry Pi 5 (ARM): FP32 ~320ms, INT8 ~175ms
    return model_quantized

print("ViT edge export pipeline ready")
ViT vs CNN Benchmarks on Common Tasks (2025)
| Model | ImageNet Top-1 | Latency (ms) | Throughput (img/s) | Params |
|---|---|---|---|---|
| ResNet-50 | 76.1% | 4.1 ms | 1,240 | 25M |
| ConvNeXt-T | 82.1% | 5.5 ms | 960 | 29M |
| DeiT-B | 83.1% | 9.2 ms | 570 | 87M |
| Swin-T | 81.3% | 6.8 ms | 740 | 28M |
| ViT-B/16 (timm) | 85.5% | 11.4 ms | 460 | 86M |
| EfficientNet-B4 | 83.0% | 7.3 ms | 690 | 19M |
Measured on an RTX 4090 in FP16. Throughput uses batch size 32; latency is per single image (batch size 1).
Warning: ViT is Not Always the Best Choice
- Small datasets (<10K images): CNNs or EfficientNet perform better without large-scale pre-training. ViT requires a lot of data to converge correctly.
- Real-time tasks on edge: ViT-Ti/16 has ~4ms latency on GPU but >100ms on CPU. MobileNet or EfficientNet-Lite are preferable for CPU deployment.
- Object detection on CPU: Swin is an excellent backbone but heavy. YOLO with a lightweight backbone outperforms Swin on CPU in latency.
- Fine-tuning with extreme domain-shift: Pre-trained CNNs can generalize better if the target dataset is very different from ImageNet.
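The trade-offs above can be condensed into a rough selection heuristic. This is an illustrative sketch, not a rule; the thresholds and model names are the ones quoted in this section:

```python
def pick_backbone(n_train_images: int, target: str = "gpu",
                  large_scale_pretraining: bool = True) -> str:
    """Rough backbone choice encoding the trade-offs discussed above."""
    if target == "cpu":
        # ViT latency on CPU is poor (>100ms even for ViT-Ti);
        # lightweight CNNs win for real-time CPU deployment
        return "MobileNet / EfficientNet-Lite"
    if n_train_images < 10_000 and not large_scale_pretraining:
        # ViT needs large-scale data or pre-training to converge well
        return "EfficientNet / ResNet (CNN)"
    return "ViT (fine-tuned from ImageNet-21K weights)"

print(pick_backbone(5_000, "gpu", large_scale_pretraining=False))
print(pick_backbone(500_000, "gpu"))
print(pick_backbone(100_000, "cpu"))
```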
Best Practices for ViT in Production
ViT Deployment Checklist
- Choose the right variant: ViT-Ti/S for limited resources, ViT-B for standard quality, Swin-T/S for detection/segmentation, DeiT-B for training from scratch on ImageNet-scale.
- Pre-training on ImageNet-21K: always start from ImageNet-21K weights, not ImageNet-1K. Offers a significant accuracy jump, especially with small datasets.
- Low learning rate for fine-tuning: use base LR 3e-5 for ViT-B, with at least 5 epochs of warmup. Too high an LR destroys pre-trained representations.
- Input resolution: ViTs pre-trained at 224px work best with 224px input. Fine-tuning at 384px improves accuracy but costs 2.3x more memory.
- Batch size and gradient accumulation: ViT benefits from large batch sizes (256-2048). Use gradient accumulation if VRAM is insufficient.
- Mixed precision (BF16/FP16): always enable torch.autocast. ViT gains a 2x speedup without accuracy loss.
- Flash Attention: use torch.nn.functional.scaled_dot_product_attention (PyTorch 2.0+) or flash-attn to reduce attention memory by 40%.
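Several checklist items (low base LR, warmup, then decay) come together in the learning-rate schedule. A dependency-free sketch of linear warmup followed by cosine decay, using the 3e-5 base LR suggested above; the 500-step warmup and 10,000 total steps are illustrative values, not recommendations from this guide:

```python
import math

def warmup_cosine_lr(step: int, base_lr: float = 3e-5,
                     warmup_steps: int = 500, total_steps: int = 10_000) -> float:
    """Linear warmup to base_lr, then cosine decay to zero."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps  # linear ramp-up
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return base_lr * 0.5 * (1 + math.cos(math.pi * progress))  # cosine decay

print(f"step 0:    {warmup_cosine_lr(0):.2e}")     # tiny LR at the start
print(f"step 499:  {warmup_cosine_lr(499):.2e}")   # full base LR after warmup
print(f"step 9999: {warmup_cosine_lr(9999):.2e}")  # near zero at the end
```

Feeding this per-step value to the optimizer (e.g. via `torch.optim.lr_scheduler.LambdaLR`) implements the warmup item of the checklist.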
Conclusions
Vision Transformers have redefined the computer vision landscape. Today the ViT-vs-CNN dichotomy has largely dissolved: architectures such as ConvNeXt, CoAtNet, and FastViT borrow ideas across both families, while pure ViTs like EVA and SigLIP dominate large-scale benchmarks.
For practice, the optimal workflow is clear: choose a backbone pre-trained on large datasets (ImageNet-21K, LAION), fine-tune with aggressive augmentation (MixUp, CutMix, RandAugment) and LR warmup, then export to ONNX for optimized deployment. The difference from CNNs is not just quantitative — the global attention capability lets ViT capture long-range relationships in images, crucial for complex visual understanding tasks.
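Of the augmentations mentioned, MixUp is the simplest to sketch: sample a mixing weight from a Beta(α, α) distribution, blend each batch with a shuffled copy of itself, and mix the one-hot targets with the same weight. A minimal sketch (α = 0.2 is a commonly used value, here an illustrative default rather than a recommendation from this guide):

```python
import torch

def mixup(x: torch.Tensor, y_onehot: torch.Tensor, alpha: float = 0.2):
    """MixUp: convex combination of a batch with a shuffled copy of itself."""
    lam = float(torch.distributions.Beta(alpha, alpha).sample())
    perm = torch.randperm(x.size(0))
    x_mixed = lam * x + (1 - lam) * x[perm]                 # blend images
    y_mixed = lam * y_onehot + (1 - lam) * y_onehot[perm]   # blend soft labels
    return x_mixed, y_mixed, lam

# Dummy batch: 8 images, 10 classes
x = torch.randn(8, 3, 224, 224)
y = torch.nn.functional.one_hot(torch.randint(0, 10, (8,)), 10).float()
x_m, y_m, lam = mixup(x, y)
print(x_m.shape, y_m.shape)  # shapes are unchanged; targets become soft labels
```

Training then uses the soft targets `y_m` with a cross-entropy loss that accepts probability targets; CutMix follows the same pattern but blends rectangular patches instead of whole images.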
The next step in the series is Neural Architecture Search (NAS): how to automate the choice of the optimal architecture for a given task and computational budget, going beyond manual selection among ViT, CNN, and hybrid variants.
Next Steps
- Next article: Neural Architecture Search: AutoML for Deep Learning
- Related: Fine-tuning with LoRA and QLoRA
- Computer Vision series: Object Detection with Swin Transformer
- MLOps series: Serving Vision Models in Production