
An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale

Paper Information

Published: ICLR 2021
Authors: Alexey Dosovitskiy et al.
Date: 2026-02-24

The Big Idea

Until this work, transformers — the models powering breakthroughs in natural language processing (NLP) — were rarely used by themselves for image tasks. Instead, researchers mixed them with convolutional neural networks (CNNs), the long-standing dominant models in computer vision.

The authors asked a bold question: can an image be treated like a sequence of words and processed by a pure transformer, without convolutions?

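The answer was yes, with a caveat. When pre-trained on large datasets and then transferred to mid-sized or small image recognition benchmarks (ImageNet, CIFAR-100, VTAB), ViT matched or exceeded state-of-the-art CNNs while requiring substantially fewer computational resources to train.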


How the Vision Transformer (ViT) Works

Instead of processing images pixel by pixel or through convolutional filters, the Vision Transformer (ViT) does this:

  1. Split the image into patches

    An image (e.g., 224×224 pixels) is divided into non-overlapping 16×16 patches, giving (224/16)² = 14 × 14 = 196 patches. Each patch is flattened into a vector of 16 × 16 × 3 = 768 values for an RGB image, much as a sentence is split into words.

  2. Embed patches as tokens

    Each flattened patch is mapped by a shared, learned linear projection to a D-dimensional embedding (D = 768 for ViT-Base). These patch embeddings play the role of word tokens in the transformer's input sequence.

  3. Add positional information

    Self-attention by itself has no notion of token order: shuffling the input tokens simply shuffles the outputs. ViT therefore adds learnable 1D position embeddings to the patch embeddings so the model can tell where each patch came from in the image.

  4. Pass through a standard Transformer

    The token sequence is processed by a standard transformer encoder: alternating multi-head self-attention and MLP blocks, with LayerNorm applied before each block and a residual connection around it, essentially as in NLP.

  5. Classification via special token

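    As in BERT, a learnable [CLS] token is prepended to the sequence of patch embeddings. Its output from the final encoder layer serves as the image representation and is fed to a classification head (an MLP with one hidden layer during pre-training, a single linear layer during fine-tuning).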
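The five steps above can be traced as tensor shapes. The sketch below uses plain PyTorch (no einops) with ViT-Base sizes (224×224 input, 16×16 patches, D = 768). The random tensors stand in for learned parameters, so it illustrates shapes only, not a trained model. It also checks that a linear layer applied to flattened patches equals a `Conv2d` whose kernel size and stride both equal the patch size, a common alternative way to implement step 2.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, H, W = 2, 3, 224, 224      # batch of two RGB images
P, D = 16, 768                   # patch size and embedding dim (ViT-Base)
N = (H // P) * (W // P)          # 14 * 14 = 196 patches

x = torch.randn(B, C, H, W)

# Step 1: split into non-overlapping P x P patches and flatten each one.
# (B, C, H, W) -> (B, C, H/P, P, W/P, P) -> (B, H/P, W/P, C, P, P) -> (B, N, C*P*P)
patches = x.reshape(B, C, H // P, P, W // P, P).permute(0, 2, 4, 1, 3, 5)
patches = patches.reshape(B, N, C * P * P)           # (2, 196, 768)

# Step 2: one shared linear projection turns each flattened patch into a token.
proj = nn.Linear(C * P * P, D)
tokens = proj(patches)                                # (2, 196, 768)

# Steps 3 and 5: prepend a [CLS] token, then add position embeddings
# (zeros and random values here stand in for learned parameters).
cls_token = torch.zeros(1, 1, D)
pos_embed = torch.randn(1, N + 1, D) * 0.02
seq = torch.cat([cls_token.expand(B, -1, -1), tokens], dim=1) + pos_embed  # (2, 197, 768)

# The same projection written as a strided convolution, reusing proj's weights.
conv = nn.Conv2d(C, D, kernel_size=P, stride=P)
with torch.no_grad():
    conv.weight.copy_(proj.weight.reshape(D, C, P, P))
    conv.bias.copy_(proj.bias)
    conv_tokens = conv(x).flatten(2).transpose(1, 2)  # (B, D, 14, 14) -> (B, 196, D)

print(patches.shape, tokens.shape, seq.shape)
print(torch.allclose(tokens, conv_tokens, atol=1e-4))
```

The sequence entering the encoder has 197 tokens: 196 patches plus the [CLS] token.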


Why This Is Important

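Before this, CNNs had strong built-in inductive biases: translation equivariance (shifting an object in the image shifts its feature-map response by the same amount) and local receptive fields (each filter looks at only a small neighborhood of pixels). These priors are well suited to vision tasks.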
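Translation equivariance can be checked numerically. In this illustrative sketch (not from the paper), a convolution is applied to an image and to a shifted copy, and shifting the first output reproduces the second. Circular padding is an assumption made so the property holds exactly at the borders; with ordinary zero padding it holds only away from the edges.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, padding_mode="circular")

x = torch.randn(1, 3, 32, 32)
shift = (5, 7)  # move the image 5 pixels down and 7 pixels right (with wrap-around)

with torch.no_grad():
    out_of_shifted = conv(torch.roll(x, shifts=shift, dims=(2, 3)))
    shifted_out = torch.roll(conv(x), shifts=shift, dims=(2, 3))

# Equivariance: conv(shift(x)) == shift(conv(x))
print(torch.allclose(out_of_shifted, shifted_out, atol=1e-5))
```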

ViT showed something surprising:

Key Insight

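With enough pre-training data (ImageNet-21k with about 14M images, or JFT-300M with 303M images), a pure transformer can learn visual patterns without those handcrafted biases. Trained on ImageNet alone without strong regularization, it falls a few percentage points short of comparably sized ResNets.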

This helped shift computer vision toward large pre-trained transformer backbones, mirroring the path language models had taken in NLP.
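The flip side is that a transformer starts with almost no spatial prior. Without position embeddings, self-attention is permutation-equivariant: shuffling the patch tokens just shuffles the outputs the same way, so the model cannot tell where any patch came from. The sketch below checks this with PyTorch's built-in `nn.MultiheadAttention`, used here for brevity in place of the `MultiHeadAttention` class defined in the implementation. The sizes and the random position embedding are illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
attn = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True).eval()

tokens = torch.randn(1, 196, 64)   # 196 patch tokens, no position information
perm = torch.randperm(196)         # shuffle the patch order

with torch.no_grad():
    out, _ = attn(tokens, tokens, tokens)
    t = tokens[:, perm]
    out_perm, _ = attn(t, t, t)

# Same outputs, just reordered: attention alone ignores patch positions.
print(torch.allclose(out[:, perm], out_perm, atol=1e-5))

# Adding a (here random) position embedding ties each token to its slot,
# so shuffling the patches now changes what the model computes.
pos = torch.randn(1, 196, 64)
with torch.no_grad():
    a_in, b_in = tokens + pos, tokens[:, perm] + pos
    a, _ = attn(a_in, a_in, a_in)
    b, _ = attn(b_in, b_in, b_in)
print(torch.allclose(a[:, perm], b, atol=1e-5))
```

This is why ViT must learn spatial relationships from data, and why it needs large datasets to do so well.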


Python Implementation

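Below is a minimal, from-scratch Vision Transformer (ViT) implementation in PyTorch that follows the five steps above. It uses einops for the patch-splitting and token-repeating operations.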

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# ============================================
# 1. PATCH EMBEDDING
# ============================================
class PatchEmbedding(nn.Module):
    """
    Split image into patches and embed them.

    Args:
        img_size: Size of input image (assumed square)
        patch_size: Size of each patch (assumed square)
        in_channels: Number of input channels (3 for RGB)
        embed_dim: Dimension of embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2

        # Linear projection of flattened patches
        self.proj = nn.Sequential(
            # Rearrange image into patches: (B, C, H, W) -> (B, N, P*P*C)
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                     p1=patch_size, p2=patch_size),
            nn.Linear(patch_size * patch_size * in_channels, embed_dim)
        )

    def forward(self, x):
        """
        x: (batch_size, channels, height, width)
        returns: (batch_size, n_patches, embed_dim)
        """
        x = self.proj(x)
        return x

# ============================================
# 2. MULTI-HEAD ATTENTION
# ============================================
class MultiHeadAttention(nn.Module):
    """
    Multi-head self-attention mechanism.

    Args:
        embed_dim: Dimension of embeddings
        n_heads: Number of attention heads
        dropout: Dropout probability
    """
    def __init__(self, embed_dim=768, n_heads=12, dropout=0.0):
        super().__init__()
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        self.head_dim = embed_dim // n_heads
        self.scale = self.head_dim ** -0.5

        assert embed_dim % n_heads == 0, "embed_dim must be divisible by n_heads"

        # Linear layers for Q, K, V
        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.attn_drop = nn.Dropout(dropout)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(dropout)

    def forward(self, x):
        """
        x: (batch_size, n_patches + 1, embed_dim)
        returns: (batch_size, n_patches + 1, embed_dim)
        """
        B, N, C = x.shape

        # Generate Q, K, V
        qkv = self.qkv(x).reshape(B, N, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, n_heads, N, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Attention: (Q @ K^T) / sqrt(d_k)
        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, n_heads, N, N)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Weighted sum of values
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)  # (B, N, embed_dim)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x

# ============================================
# 3. MLP (FEEDFORWARD NETWORK)
# ============================================
class MLP(nn.Module):
    """
    Multi-layer perceptron (feedforward network).

    Args:
        in_features: Input dimension
        hidden_features: Hidden layer dimension
        out_features: Output dimension
        dropout: Dropout probability
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, dropout=0.0):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

# ============================================
# 4. TRANSFORMER ENCODER BLOCK
# ============================================
class TransformerBlock(nn.Module):
    """
    Single Transformer encoder block.

    Args:
        embed_dim: Dimension of embeddings
        n_heads: Number of attention heads
        mlp_ratio: Ratio of mlp hidden dim to embedding dim
        dropout: Dropout probability
    """
    def __init__(self, embed_dim=768, n_heads=12, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, n_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, int(embed_dim * mlp_ratio), dropout=dropout)

    def forward(self, x):
        # Multi-head attention with residual connection
        x = x + self.attn(self.norm1(x))
        # MLP with residual connection
        x = x + self.mlp(self.norm2(x))
        return x

# ============================================
# 5. VISION TRANSFORMER
# ============================================
class VisionTransformer(nn.Module):
    """
    Vision Transformer (ViT) implementation.

    Args:
        img_size: Input image size
        patch_size: Size of image patches
        in_channels: Number of input channels
        n_classes: Number of output classes
        embed_dim: Embedding dimension
        depth: Number of transformer blocks
        n_heads: Number of attention heads
        mlp_ratio: MLP hidden dimension ratio
        dropout: Dropout probability
        emb_dropout: Embedding dropout probability
    """
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_channels=3,
        n_classes=1000,
        embed_dim=768,
        depth=12,
        n_heads=12,
        mlp_ratio=4.0,
        dropout=0.1,
        emb_dropout=0.1
    ):
        super().__init__()

        # Patch embedding
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        n_patches = self.patch_embed.n_patches

        # Class token (learnable)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Positional embeddings (learnable)
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(emb_dropout)

        # Transformer encoder blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, n_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])

        # Classification head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, n_classes)

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        # Initialize positional embeddings
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

        # Initialize linear layers and layernorms
        self.apply(self._init_weights_module)

    def _init_weights_module(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        """
        x: (batch_size, channels, height, width)
        returns: (batch_size, n_classes)
        """
        B = x.shape[0]

        # Patch embedding
        x = self.patch_embed(x)  # (B, n_patches, embed_dim)

        # Add class token
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=B)
        x = torch.cat([cls_tokens, x], dim=1)  # (B, n_patches + 1, embed_dim)

        # Add positional embedding
        x = x + self.pos_embed
        x = self.pos_drop(x)

        # Apply transformer blocks
        for block in self.blocks:
            x = block(x)

        # Apply layer norm
        x = self.norm(x)

        # Extract class token and classify
        cls_token_final = x[:, 0]  # (B, embed_dim)
        x = self.head(cls_token_final)  # (B, n_classes)

        return x

# ============================================
# 6. MODEL VARIANTS
# ============================================
def vit_tiny(num_classes=1000):
    """ViT-Tiny: ~5.7M parameters (DeiT variant; not defined in the ViT paper)"""
    return VisionTransformer(
        img_size=224,
        patch_size=16,
        n_classes=num_classes,
        embed_dim=192,
        depth=12,
        n_heads=3,
        mlp_ratio=4.0
    )

def vit_small(num_classes=1000):
    """ViT-Small: ~22M parameters (DeiT variant; not defined in the ViT paper)"""
    return VisionTransformer(
        img_size=224,
        patch_size=16,
        n_classes=num_classes,
        embed_dim=384,
        depth=12,
        n_heads=6,
        mlp_ratio=4.0
    )

def vit_base(num_classes=1000):
    """ViT-Base: 86M parameters"""
    return VisionTransformer(
        img_size=224,
        patch_size=16,
        n_classes=num_classes,
        embed_dim=768,
        depth=12,
        n_heads=12,
        mlp_ratio=4.0
    )

def vit_large(num_classes=1000):
    """ViT-Large: ~304M parameters here (the paper's Table 1 lists 307M)"""
    return VisionTransformer(
        img_size=224,
        patch_size=16,
        n_classes=num_classes,
        embed_dim=1024,
        depth=24,
        n_heads=16,
        mlp_ratio=4.0
    )

# ============================================
# 7. USAGE EXAMPLE
# ============================================
if __name__ == "__main__":
    # Create model
    model = vit_base(num_classes=10)  # CIFAR-10 has 10 classes

    # Dummy input
    x = torch.randn(4, 3, 224, 224)  # Batch of 4 RGB images

    # Forward pass
    output = model(x)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")  # (4, 10)

    # Count parameters
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Number of parameters: {n_params:,}")
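The parameter counts in the docstrings can be checked by hand. The helper below (a hypothetical name, not part of any library) counts the parameters of the `VisionTransformer` class above from its configuration. It assumes the same layer layout: a linear patch projection, a [CLS] token, learnable position embeddings, `depth` pre-norm blocks whose QKV projection has no bias, a final LayerNorm, and a linear head.

```python
def vit_param_count(embed_dim, depth, mlp_ratio=4.0, img_size=224,
                    patch_size=16, in_channels=3, n_classes=1000):
    """Count parameters of the VisionTransformer class above (qkv has no bias)."""
    d = embed_dim
    n_patches = (img_size // patch_size) ** 2
    hidden = int(d * mlp_ratio)

    patch_embed = (patch_size * patch_size * in_channels) * d + d  # Linear weight + bias
    cls_token = d
    pos_embed = (n_patches + 1) * d

    per_block = (
        2 * d                    # norm1 (weight + bias)
        + d * 3 * d              # qkv projection, bias=False
        + d * d + d              # attention output projection
        + 2 * d                  # norm2
        + d * hidden + hidden    # MLP fc1
        + hidden * d + d         # MLP fc2
    )
    head = 2 * d + d * n_classes + n_classes  # final LayerNorm + classifier
    return patch_embed + cls_token + pos_embed + depth * per_block + head


for name, dim, depth in [("Tiny", 192, 12), ("Small", 384, 12),
                         ("Base", 768, 12), ("Large", 1024, 24)]:
    print(f"ViT-{name}: {vit_param_count(dim, depth):,}")
```

For a cross-check, `sum(p.numel() for p in vit_base().parameters())` from the usage example should agree with `vit_param_count(768, 12)`. Most of the parameters sit in the blocks' QKV and MLP matrices, which scale with D² (roughly 12·D² per block here).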

Training Script

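Here's a complete training loop that trains ViT-Small from scratch on CIFAR-10, with images upsampled from 32×32 to 224×224. Without large-scale pre-training, expect modest accuracy: as discussed above, ViT lacks the inductive biases that help CNNs on small datasets.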

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
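from tqdm import tqdm

# Assumes the model code above is saved as vit.py (hypothetical filename)
from vit import vit_small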

# ============================================
# TRAINING CONFIGURATION
# ============================================
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Hyperparameters
BATCH_SIZE = 64
EPOCHS = 50
LEARNING_RATE = 3e-4
NUM_CLASSES = 10
IMG_SIZE = 224

# ============================================
# DATA PREPARATION
# ============================================
transform_train = transforms.Compose([
    # Augment at CIFAR's native 32x32 resolution, then upsample to the ViT input size
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

transform_test = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load CIFAR-10
train_dataset = datasets.CIFAR10(root='./data', train=True,
                                 download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root='./data', train=False,
                                download=True, transform=transform_test)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                          shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
                         shuffle=False, num_workers=2)

# ============================================
# MODEL, LOSS, OPTIMIZER
# ============================================
model = vit_small(num_classes=NUM_CLASSES).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.05)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

# ============================================
# TRAINING LOOP
# ============================================
def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(loader, desc='Training')
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Statistics
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        # Update progress bar
        pbar.set_postfix({
            'loss': running_loss / (pbar.n + 1),
            'acc': 100. * correct / total
        })

    return running_loss / len(loader), 100. * correct / total

# ============================================
# EVALUATION LOOP
# ============================================
def evaluate(model, loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        pbar = tqdm(loader, desc='Evaluating')
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            pbar.set_postfix({
                'loss': running_loss / (pbar.n + 1),
                'acc': 100. * correct / total
            })

    return running_loss / len(loader), 100. * correct / total

# ============================================
# MAIN TRAINING
# ============================================
if __name__ == "__main__":
    # The guard keeps DataLoader worker processes (num_workers=2) from re-running
    # the training loop on platforms that start workers with "spawn" (Windows, macOS)
    best_acc = 0.0

    for epoch in range(EPOCHS):
        print(f"\nEpoch {epoch+1}/{EPOCHS}")
        print("-" * 50)

        # Train
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)

        # Evaluate
        test_loss, test_acc = evaluate(model, test_loader, criterion, device)

        # Update learning rate (once per epoch)
        scheduler.step()

        # Print results
        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}%")

        # Save best model
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), 'vit_best.pth')
            print(f"✓ Saved best model with accuracy: {best_acc:.2f}%")

    print(f"\n{'='*50}")
    print(f"Training completed! Best accuracy: {best_acc:.2f}%")
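The script above anneals the learning rate once per epoch with no warmup. The ViT paper instead pre-trains with a linear learning-rate warmup followed by decay, and warmup is widely used to stabilize transformer training. Below is a minimal sketch of per-step linear warmup plus cosine decay built on `torch.optim.lr_scheduler.LambdaLR`. The function name `warmup_cosine`, the stand-in model, and the step counts are illustrative assumptions, not values from the paper.

```python
import math

import torch.nn as nn
import torch.optim as optim


def warmup_cosine(step, warmup_steps, total_steps):
    """LR multiplier: linear ramp over warmup_steps, then cosine decay to 0."""
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))


# Hypothetical setup: a tiny stand-in model and illustrative step counts.
model = nn.Linear(10, 2)
optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.05)
steps_per_epoch, epochs, warmup_epochs = 100, 50, 5
total_steps = steps_per_epoch * epochs
scheduler = optim.lr_scheduler.LambdaLR(
    optimizer,
    lambda step: warmup_cosine(step, warmup_epochs * steps_per_epoch, total_steps),
)

lrs = []
for step in range(total_steps):
    optimizer.step()   # in real training: zero_grad, forward, backward, then step
    scheduler.step()   # advance the schedule once per batch, not once per epoch
    lrs.append(optimizer.param_groups[0]["lr"])

print(f"peak lr {max(lrs):.2e}, final lr {lrs[-1]:.2e}")
```

To use this in the training script, replace `CosineAnnealingLR` with the `LambdaLR` scheduler, set `steps_per_epoch = len(train_loader)`, and move `scheduler.step()` into `train_epoch` right after `optimizer.step()`.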

Installation Requirements

pip install torch torchvision einops tqdm

Model Architecture Summary

Component            Description
PatchEmbedding       Splits image into patches and linearly embeds them
MultiHeadAttention   Self-attention with multiple heads
MLP                  Two-layer feedforward network with GELU
TransformerBlock     Attention + MLP with residual connections
VisionTransformer    Complete ViT with classification head

Further Reading