An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale
Paper Information
- Published: ICLR 2021
- Authors: Alexey Dosovitskiy et al.
- Date: 2026-02-24
The Big Idea
Until this work, transformers — the models powering breakthroughs in natural language processing (NLP) — were rarely used by themselves for image tasks. Instead, researchers mixed them with convolutional neural networks (CNNs), the long-standing dominant models in computer vision.
The authors asked a bold question:
"Can we treat an image like a sequence of words and process it with a pure transformer, without convolutions?"
The answer was yes — and it worked extremely well.
How the Vision Transformer (ViT) Works
Instead of processing images pixel by pixel or through convolutional filters, the Vision Transformer (ViT) does this:
1. Split the image into patches. An image (e.g., 224×224 pixels) is divided into smaller 16×16 patches, and each patch is flattened into a vector, just as a word in NLP would be.
2. Embed patches as tokens. These patch vectors are linearly projected into an embedding space and treated as tokens, similar to words, that go into a transformer.
3. Add positional information. Transformers don't know the order of tokens unless you tell them, so position embeddings are added to let the model know where each patch came from in the image.
4. Pass through a standard Transformer. The sequence of patch tokens is processed by transformer layers (multi-head self-attention and feedforward networks) exactly as in NLP.
5. Classify via a special token. As with BERT in NLP, a learnable [CLS] token is prepended to the sequence, and its output representation is used for classification.
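These steps can be traced with plain tensor operations. The following is a minimal sketch (not the paper's code), using the ViT-Base defaults of a 224×224 RGB image, 16×16 patches, and a 768-dimensional embedding:

```python
import torch
import torch.nn as nn

img = torch.randn(1, 3, 224, 224)              # (B, C, H, W)
P, D = 16, 768                                 # patch size, embed dim

# 1. Split into 16x16 patches and flatten each to a 3*16*16 = 768-dim vector
patches = img.unfold(2, P, P).unfold(3, P, P)  # (1, 3, 14, 14, 16, 16)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(1, -1, 3 * P * P)

# 2. Linearly project each patch vector to a token embedding
tokens = nn.Linear(3 * P * P, D)(patches)      # (1, 196, 768)

# 3. Prepend a (here zero-initialized) [CLS] token and add position embeddings
cls = torch.zeros(1, 1, D)
pos = torch.zeros(1, 197, D)
tokens = torch.cat([cls, tokens], dim=1) + pos

print(tokens.shape)                            # torch.Size([1, 197, 768])
```

A 224×224 image with 16×16 patches yields (224/16)² = 196 tokens, plus one [CLS] token, giving the sequence length of 197.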
Why This Is Important
Before this, CNNs had built-in inductive biases such as translation equivariance (shifting an object in the input shifts its feature map correspondingly) and localized receptive fields — both helpful priors for vision tasks.
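Translation equivariance is easy to check directly: convolving a shifted image gives, up to boundary effects, a shifted version of the original feature map. A small sketch with a random filter, comparing interior pixels only:

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)
x = torch.randn(1, 1, 8, 8)
shifted = torch.roll(x, shifts=2, dims=-1)   # shift the input right by 2 columns

y1 = torch.roll(conv(x), shifts=2, dims=-1)  # shift after convolving
y2 = conv(shifted)                           # convolve after shifting

# Interior columns agree exactly; only pixels near the border (zero padding
# vs. values wrapped around by roll) can differ.
print(torch.allclose(y1[..., 3:-3], y2[..., 3:-3]))  # True
```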
ViT showed something surprising:
Key Insight
If you have enough data and compute, a pure transformer can learn visual patterns without those handcrafted biases.
This shifted the field toward foundation models for vision, just like language models for NLP.
Python Implementation
Below is a simple Vision Transformer (ViT) example using PyTorch.
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange


# ============================================
# 1. PATCH EMBEDDING
# ============================================
class PatchEmbedding(nn.Module):
    """
    Split image into patches and embed them.

    Args:
        img_size: Size of input image (assumed square)
        patch_size: Size of each patch (assumed square)
        in_channels: Number of input channels (3 for RGB)
        embed_dim: Dimension of embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2

        # Linear projection of flattened patches
        self.proj = nn.Sequential(
            # Rearrange image into patches: (B, C, H, W) -> (B, N, P*P*C)
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                      p1=patch_size, p2=patch_size),
            nn.Linear(patch_size * patch_size * in_channels, embed_dim)
        )

    def forward(self, x):
        """
        x: (batch_size, channels, height, width)
        returns: (batch_size, n_patches, embed_dim)
        """
        return self.proj(x)
# ============================================
# 2. MULTI-HEAD ATTENTION
# ============================================
class MultiHeadAttention(nn.Module):
    """
    Multi-head self-attention mechanism.

    Args:
        embed_dim: Dimension of embeddings
        n_heads: Number of attention heads
        dropout: Dropout probability
    """
    def __init__(self, embed_dim=768, n_heads=12, dropout=0.0):
        super().__init__()
        assert embed_dim % n_heads == 0, "embed_dim must be divisible by n_heads"
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        self.head_dim = embed_dim // n_heads
        self.scale = self.head_dim ** -0.5

        # Linear layers for Q, K, V
        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.attn_drop = nn.Dropout(dropout)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(dropout)

    def forward(self, x):
        """
        x: (batch_size, n_patches + 1, embed_dim)
        returns: (batch_size, n_patches + 1, embed_dim)
        """
        B, N, C = x.shape

        # Generate Q, K, V
        qkv = self.qkv(x).reshape(B, N, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, n_heads, N, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Attention: (Q @ K^T) / sqrt(d_k)
        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, n_heads, N, N)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Weighted sum of values
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)  # (B, N, embed_dim)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
# ============================================
# 3. MLP (FEEDFORWARD NETWORK)
# ============================================
class MLP(nn.Module):
    """
    Multi-layer perceptron (feedforward network).

    Args:
        in_features: Input dimension
        hidden_features: Hidden layer dimension
        out_features: Output dimension
        dropout: Dropout probability
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, dropout=0.0):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
# ============================================
# 4. TRANSFORMER ENCODER BLOCK
# ============================================
class TransformerBlock(nn.Module):
    """
    Single Transformer encoder block (pre-norm, as in ViT).

    Args:
        embed_dim: Dimension of embeddings
        n_heads: Number of attention heads
        mlp_ratio: Ratio of MLP hidden dim to embedding dim
        dropout: Dropout probability
    """
    def __init__(self, embed_dim=768, n_heads=12, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, n_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, int(embed_dim * mlp_ratio), dropout=dropout)

    def forward(self, x):
        # Multi-head attention with residual connection
        x = x + self.attn(self.norm1(x))
        # MLP with residual connection
        x = x + self.mlp(self.norm2(x))
        return x
# ============================================
# 5. VISION TRANSFORMER
# ============================================
class VisionTransformer(nn.Module):
    """
    Vision Transformer (ViT) implementation.

    Args:
        img_size: Input image size
        patch_size: Size of image patches
        in_channels: Number of input channels
        n_classes: Number of output classes
        embed_dim: Embedding dimension
        depth: Number of transformer blocks
        n_heads: Number of attention heads
        mlp_ratio: MLP hidden dimension ratio
        dropout: Dropout probability
        emb_dropout: Embedding dropout probability
    """
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_channels=3,
        n_classes=1000,
        embed_dim=768,
        depth=12,
        n_heads=12,
        mlp_ratio=4.0,
        dropout=0.1,
        emb_dropout=0.1
    ):
        super().__init__()
        # Patch embedding
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        n_patches = self.patch_embed.n_patches

        # Class token (learnable)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Positional embeddings (learnable)
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(emb_dropout)

        # Transformer encoder blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, n_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])

        # Classification head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, n_classes)

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        # Initialize positional embeddings and class token
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        # Initialize linear layers and layernorms
        self.apply(self._init_weights_module)

    def _init_weights_module(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        """
        x: (batch_size, channels, height, width)
        returns: (batch_size, n_classes)
        """
        B = x.shape[0]

        # Patch embedding
        x = self.patch_embed(x)  # (B, n_patches, embed_dim)

        # Prepend class token
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=B)
        x = torch.cat([cls_tokens, x], dim=1)  # (B, n_patches + 1, embed_dim)

        # Add positional embedding
        x = x + self.pos_embed
        x = self.pos_drop(x)

        # Apply transformer blocks
        for block in self.blocks:
            x = block(x)

        # Final layer norm
        x = self.norm(x)

        # Extract class token and classify
        cls_token_final = x[:, 0]          # (B, embed_dim)
        x = self.head(cls_token_final)     # (B, n_classes)
        return x
# ============================================
# 6. MODEL VARIANTS
# ============================================
def vit_tiny(num_classes=1000):
    """ViT-Tiny: ~5M parameters"""
    return VisionTransformer(
        img_size=224,
        patch_size=16,
        n_classes=num_classes,
        embed_dim=192,
        depth=12,
        n_heads=3,
        mlp_ratio=4.0
    )

def vit_small(num_classes=1000):
    """ViT-Small: ~22M parameters"""
    return VisionTransformer(
        img_size=224,
        patch_size=16,
        n_classes=num_classes,
        embed_dim=384,
        depth=12,
        n_heads=6,
        mlp_ratio=4.0
    )

def vit_base(num_classes=1000):
    """ViT-Base: ~86M parameters"""
    return VisionTransformer(
        img_size=224,
        patch_size=16,
        n_classes=num_classes,
        embed_dim=768,
        depth=12,
        n_heads=12,
        mlp_ratio=4.0
    )

def vit_large(num_classes=1000):
    """ViT-Large: ~307M parameters"""
    return VisionTransformer(
        img_size=224,
        patch_size=16,
        n_classes=num_classes,
        embed_dim=1024,
        depth=24,
        n_heads=16,
        mlp_ratio=4.0
    )

# ============================================
# 7. USAGE EXAMPLE
# ============================================
if __name__ == "__main__":
    # Create model
    model = vit_base(num_classes=10)  # CIFAR-10 has 10 classes

    # Dummy input
    x = torch.randn(4, 3, 224, 224)  # Batch of 4 RGB images

    # Forward pass
    output = model(x)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")  # (4, 10)

    # Count parameters
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Number of parameters: {n_params:,}")
Training Script
Here's a complete training loop for CIFAR-10:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

# ============================================
# TRAINING CONFIGURATION
# ============================================
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Hyperparameters
BATCH_SIZE = 64
EPOCHS = 50
LEARNING_RATE = 3e-4
NUM_CLASSES = 10
IMG_SIZE = 224

# ============================================
# DATA PREPARATION
# ============================================
transform_train = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(IMG_SIZE, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

transform_test = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load CIFAR-10
train_dataset = datasets.CIFAR10(root='./data', train=True,
                                 download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root='./data', train=False,
                                download=True, transform=transform_test)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                          shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
                         shuffle=False, num_workers=2)

# ============================================
# MODEL, LOSS, OPTIMIZER
# ============================================
model = vit_small(num_classes=NUM_CLASSES).to(device)  # vit_small from the listing above
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.05)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

# ============================================
# TRAINING LOOP
# ============================================
def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(loader, desc='Training')
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Statistics
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        # Update progress bar
        pbar.set_postfix({
            'loss': running_loss / (pbar.n + 1),
            'acc': 100. * correct / total
        })

    return running_loss / len(loader), 100. * correct / total

# ============================================
# EVALUATION LOOP
# ============================================
def evaluate(model, loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        pbar = tqdm(loader, desc='Evaluating')
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            pbar.set_postfix({
                'loss': running_loss / (pbar.n + 1),
                'acc': 100. * correct / total
            })

    return running_loss / len(loader), 100. * correct / total

# ============================================
# MAIN TRAINING
# ============================================
best_acc = 0.0

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    print("-" * 50)

    # Train
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)

    # Evaluate
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)

    # Update learning rate
    scheduler.step()

    # Print results
    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}%")

    # Save best model
    if test_acc > best_acc:
        best_acc = test_acc
        torch.save(model.state_dict(), 'vit_best.pth')
        print(f"✓ Saved best model with accuracy: {best_acc:.2f}%")

print(f"\n{'='*50}")
print(f"Training completed! Best accuracy: {best_acc:.2f}%")
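The CosineAnnealingLR scheduler above decays the learning rate smoothly from 3e-4 toward zero over the 50 epochs. A standalone sketch of the values it produces (using a dummy parameter so no real model is needed):

```python
import torch
import torch.optim as optim

opt = optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=3e-4)
sched = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=50)

lrs = []
for _ in range(50):
    lrs.append(opt.param_groups[0]['lr'])  # lr in effect for this epoch
    opt.step()
    sched.step()

# Starts at 3e-4, halves at the midpoint, approaches 0 at the end
print(f"{lrs[0]:.1e}  {lrs[25]:.1e}  {lrs[-1]:.1e}")
```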
Installation Requirements
The examples above require PyTorch, torchvision, einops, and tqdm:

pip install torch torchvision einops tqdm
Model Architecture Summary
| Component | Description |
|---|---|
| PatchEmbedding | Splits image into patches and linearly embeds them |
| MultiHeadAttention | Self-attention with multiple heads |
| MLP | Two-layer feedforward network with GELU |
| TransformerBlock | Attention + MLP with residual connections |
| VisionTransformer | Complete ViT with classification head |
Further Reading
- ViT - Schweizer KI Akademie
- Original Paper: arXiv:2010.11929