Implementing Swin Transformer: Hierarchical Vision Transformer with Shifted Windows
Introduction
The Swin Transformer (Shifted Window Transformer) is a hierarchical vision transformer that computes self-attention inside local windows and shifts those windows between layers so that information can flow across window boundaries. The paper that introduced it was published in 2021 and has become foundational for modern vision models.
Paper Reference: Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
In this article, we'll explore the key concepts behind Swin Transformer and provide a clean, educational PyTorch implementation from scratch.
Key Innovations
The Swin Transformer addresses major challenges in applying transformers to vision tasks:
1. Hierarchical Architecture
Unlike standard Vision Transformers (ViT) that maintain fixed-scale features, Swin Transformer uses a hierarchical structure similar to CNNs. This allows it to:
- Handle images at multiple scales
- Serve as a general-purpose backbone for dense prediction tasks
- Efficiently process high-resolution images
2. Shifted Window Mechanism
Instead of computing global self-attention (which is computationally expensive), Swin Transformer:
- Computes attention within local windows
- Uses shifted windows across layers to enable cross-window connections
- Achieves linear computational complexity relative to image size
3. Efficiency
For a fixed window size, restricting attention to local windows reduces the cost of self-attention from O(n²) to O(n) in the number of patches n, making it practical for high-resolution images.
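To make the O(n²) versus O(n) claim concrete, the sketch below evaluates the per-layer cost formulas given in the Swin Transformer paper for global attention (MSA) and window attention (W-MSA). The function names `msa_flops` and `wmsa_flops` are made up for this illustration; `h` and `w` are the feature map size in patches, `C` is the channel dimension, and `M` is the window size.

```python
def msa_flops(h, w, C):
    """Global self-attention: 4*h*w*C^2 + 2*(h*w)^2*C (quadratic in h*w)."""
    return 4 * h * w * C**2 + 2 * (h * w) ** 2 * C


def wmsa_flops(h, w, C, M):
    """Window self-attention: 4*h*w*C^2 + 2*M^2*h*w*C (linear in h*w)."""
    return 4 * h * w * C**2 + 2 * M**2 * h * w * C


# Stage 1 of Swin-T on a 224x224 image: 56x56 patches, C=96, window size 7
for side in (56, 112):
    g = msa_flops(side, side, 96)
    l = wmsa_flops(side, side, 96, 7)
    print(f"{side}x{side}: global={g:,}  windowed={l:,}  ratio={g / l:.1f}")
```

Doubling the side length multiplies the window attention cost by 4, which is linear in the number of patches, while the quadratic term of global attention grows by a factor of 16.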
Architecture Overview
The Swin Transformer consists of several key components:
Input Image
↓
Patch Embedding (Split into patches)
↓
Stage 1: Swin Blocks (alternating regular and shifted window attention)
↓
Patch Merging (Downsample)
↓
Stage 2: Swin Blocks
↓
Patch Merging
↓
Stage 3: Swin Blocks
↓
Patch Merging
↓
Stage 4: Swin Blocks
↓
Output Features
Core Components:
- Patch Embedding: Splits the image into non-overlapping patches
- Window Partitioning: Divides feature maps into fixed-size windows
- Window Attention: Computes self-attention within each window
- Shifted Windows: Alternates between regular and shifted window partitions
- Patch Merging: Downsamples features between stages (like pooling in CNNs)
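The hierarchy described above can be traced with a few lines of arithmetic. This is a minimal sketch assuming the defaults used later in this article (224×224 input, 4×4 patches, embedding dimension 96, four stages); `stage_shapes` is a hypothetical helper name, not part of any library.

```python
def stage_shapes(img_size=224, patch_size=4, embed_dim=96, num_stages=4):
    """Return (height, width, channels) of the feature map inside each stage."""
    side = img_size // patch_size  # resolution after patch embedding
    shapes = []
    for i in range(num_stages):
        # Each patch merging step halves the resolution and doubles the channels
        shapes.append((side // 2**i, side // 2**i, embed_dim * 2**i))
    return shapes


for i, (h, w, c) in enumerate(stage_shapes(), start=1):
    print(f"Stage {i}: {h}x{w} tokens, {c} channels")
```

With a window size of 7, every one of these resolutions divides evenly into windows, which is what the window partitioning helpers below rely on.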
Implementation
Let's build the Swin Transformer from scratch in PyTorch. This is an educational implementation focusing on clarity and understanding.
Step 1: Window Utilities
First, we need helper functions to partition feature maps into windows and reverse the operation:
import torch


def window_partition(x, window_size):
    """
    Partition feature map into non-overlapping windows.

    Args:
        x: (B, H, W, C) tensor; H and W must be divisible by window_size
        window_size: Window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    # Split H and W into (number of windows, position inside window)
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # Bring the two window-index axes next to each other
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return windows.view(-1, window_size, window_size, C)


def window_reverse(windows, window_size, H, W):
    """
    Reverse window partition back to feature map.

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size: Window size
        H: Height of feature map
        W: Width of feature map
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return x.view(B, H, W, -1)
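A quick way to convince yourself that the two helpers are inverses is to run them on a tiny feature map whose values encode their position. The sketch below uses condensed copies of the helpers (named `partition` and `reverse` here only to keep the example self-contained) on a 4×4 map with 2×2 windows.

```python
import torch


def partition(x, ws):
    # Condensed copy of window_partition, for a self-contained demo
    B, H, W, C = x.shape
    x = x.view(B, H // ws, ws, W // ws, ws, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, ws, ws, C)


def reverse(windows, ws, H, W):
    # Condensed copy of window_reverse
    B = windows.shape[0] // ((H // ws) * (W // ws))
    x = windows.view(B, H // ws, W // ws, ws, ws, -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)


# Values 0..15 laid out row by row on a 4x4 map with one channel
x = torch.arange(16).view(1, 4, 4, 1)
windows = partition(x, 2)            # four 2x2 windows
restored = reverse(windows, 2, 4, 4)

print(windows.shape)                 # one entry per window
print(windows[0, :, :, 0])           # top-left window
print(torch.equal(restored, x))      # round trip recovers the input
```

The top-left window contains the values at rows 0-1 and columns 0-1 of the original map, so windows are spatially contiguous blocks rather than interleaved samples.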
Step 2: Window-Based Multi-Head Self-Attention
The core innovation is computing attention within local windows:
import torch
import torch.nn as nn


class WindowAttention(nn.Module):
    """
    Window-based multi-head self-attention module.
    """

    def __init__(self, dim, window_size, num_heads):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        # QKV projection
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, mask=None):
        """
        Args:
            x: Input features (num_windows*B, N, C), N = window_size**2
            mask: Optional additive attention mask for shifted windows;
                must be broadcastable to (num_windows*B, num_heads, N, N)
        """
        B_, N, C = x.shape

        # Generate Q, K, V
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B_, num_heads, N, head_dim)
        q, k, v = qkv.unbind(0)

        # Scaled dot-product attention
        attn = (q @ k.transpose(-2, -1)) * self.scale
        if mask is not None:
            attn = attn + mask
        attn = attn.softmax(dim=-1)

        # Apply attention to values
        out = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        return self.proj(out)
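The `mask` argument is additive: it is summed with the attention scores before the softmax, so a large negative entry effectively removes a token pair from the attention. The following minimal sketch, with made-up sizes and random tensors, shows that effect in isolation without the projection layers.

```python
import torch

torch.manual_seed(0)
N, head_dim = 4, 8  # 4 tokens in the window, 8 channels per head (illustrative)
q = torch.randn(1, 1, N, head_dim)
k = torch.randn(1, 1, N, head_dim)

scores = (q @ k.transpose(-2, -1)) * head_dim ** -0.5  # (1, 1, N, N)

# Tokens 0-1 and tokens 2-3 are treated as two groups that must not interact
mask = torch.zeros(N, N)
mask[:2, 2:] = -100.0
mask[2:, :2] = -100.0

weights = (scores + mask).softmax(dim=-1)
print(weights[0, 0])
```

Each row still sums to 1, but the weights between the two groups are numerically zero, so attention is confined to each group.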
Step 3: Swin Transformer Block
Each block performs windowed attention followed by an MLP:
import torch
import torch.nn as nn


class SwinTransformerBlock(nn.Module):
    """
    Swin Transformer Block with window-based or shifted window attention.
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution  # (H, W)
        self.window_size = window_size
        self.shift_size = shift_size
        self.num_heads = num_heads

        # If one window already covers the whole map, shifting adds nothing
        if min(input_resolution) <= window_size:
            self.shift_size = 0
            self.window_size = min(input_resolution)

        # Layer normalization and attention
        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, self.window_size, num_heads)

        # MLP (Feed-forward network)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )

    def forward(self, x):
        """
        Args:
            x: Input features (B, H*W, C)
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "Input feature size doesn't match"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # Cyclic shift for shifted window attention
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # Partition into windows
        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)

        # Window attention. This simplified version passes no mask, so after a
        # cyclic shift, tokens wrapped around from opposite borders of the
        # image can attend to each other.
        attn_windows = self.attn(x_windows)

        # Merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)

        # Reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # Residual connection
        x = shortcut + x

        # MLP with residual
        x = x + self.mlp(self.norm2(x))
        return x
Step 4: Patch Embedding
Convert the input image into patch embeddings:
import torch.nn as nn


class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding using convolution.
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = (img_size // patch_size, img_size // patch_size)
        self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]

        # kernel_size == stride, so each patch is projected independently
        self.proj = nn.Conv2d(in_chans, embed_dim,
                              kernel_size=patch_size,
                              stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """
        Args:
            x: Input image (B, C, H, W)
        Returns:
            Patch embeddings (B, num_patches, embed_dim)
        """
        B, C, H, W = x.shape
        x = self.proj(x)  # (B, embed_dim, H/patch_size, W/patch_size)
        x = x.flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)
        x = self.norm(x)
        return x
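Using a convolution whose kernel size equals its stride is just a compact way of applying one linear layer to every flattened patch. The sketch below checks that equivalence numerically with `torch.nn.functional.unfold`; the tensor sizes are small illustrative values, not the article's defaults.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
patch_size, in_chans, embed_dim = 4, 3, 8
proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
x = torch.randn(2, in_chans, 16, 16)

with torch.no_grad():
    # Path 1: strided convolution, as in PatchEmbed
    conv_tokens = proj(x).flatten(2).transpose(1, 2)      # (B, num_patches, embed_dim)

    # Path 2: cut out patches explicitly, then apply the same weights as a matmul
    patches = F.unfold(x, kernel_size=patch_size, stride=patch_size)  # (B, C*p*p, num_patches)
    patches = patches.transpose(1, 2)                     # (B, num_patches, C*p*p)
    weight = proj.weight.view(embed_dim, -1)              # (embed_dim, C*p*p)
    linear_tokens = patches @ weight.t() + proj.bias

print(conv_tokens.shape)
print(torch.allclose(conv_tokens, linear_tokens, atol=1e-5))
```

The convolution form is usually preferred in practice because it avoids materializing the unfolded patches.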
Step 5: Swin Stage
A stage consists of multiple Swin Transformer blocks with alternating window shifts:
import torch.nn as nn


class SwinStage(nn.Module):
    """
    A Swin Transformer stage with multiple blocks.
    """

    def __init__(self, dim, depth, num_heads, input_resolution, window_size=7):
        super().__init__()
        self.dim = dim
        self.depth = depth

        # Build blocks: even-indexed blocks use regular windows,
        # odd-indexed blocks shift the windows by half a window
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim,
                input_resolution=input_resolution,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else window_size // 2
            )
            for i in range(depth)
        ])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x
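The alternation is easiest to see by listing the shift sizes a stage produces and by looking at what the cyclic shift does to a small grid. This is an illustrative sketch: the depth of 6 matches the third stage of the default configuration, and the 4×4 grid with a shift of 1 is chosen only for readability.

```python
import torch

# Shift sizes generated by the list comprehension in SwinStage
depth, window_size = 6, 7
shift_sizes = [0 if i % 2 == 0 else window_size // 2 for i in range(depth)]
print(shift_sizes)

# Cyclic shift on a (B, H, W, C) grid whose values encode their position
grid = torch.arange(16).view(1, 4, 4, 1)
shifted = torch.roll(grid, shifts=(-1, -1), dims=(1, 2))
restored = torch.roll(shifted, shifts=(1, 1), dims=(1, 2))

print(shifted[0, :, :, 0])         # rows and columns wrap around
print(torch.equal(restored, grid)) # the reverse shift undoes it exactly
```

Because the shift wraps around, the first row and column of the original grid reappear at the bottom and right edges, which is why a shifted block ideally pairs the roll with an attention mask.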
Step 6: Patch Merging (Downsampling)
Downsample the feature maps between stages:
import torch
import torch.nn as nn


class PatchMerging(nn.Module):
    """
    Patch Merging Layer (downsampling by 2x).
    """

    def __init__(self, input_resolution, dim):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = nn.LayerNorm(4 * dim)

    def forward(self, x):
        """
        Args:
            x: Input features (B, H*W, C)
        Returns:
            Merged features (B, H/2*W/2, 2*C)
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "Input size doesn't match"
        assert H % 2 == 0 and W % 2 == 0, "H and W must be even"

        x = x.view(B, H, W, C)

        # Gather the four pixels of every 2x2 neighborhood
        x0 = x[:, 0::2, 0::2, :]  # Top-left
        x1 = x[:, 1::2, 0::2, :]  # Bottom-left
        x2 = x[:, 0::2, 1::2, :]  # Top-right
        x3 = x[:, 1::2, 1::2, :]  # Bottom-right

        # Stack them along the channel axis
        x = torch.cat([x0, x1, x2, x3], dim=-1)  # (B, H/2, W/2, 4*C)
        x = x.view(B, -1, 4 * C)

        x = self.norm(x)
        x = self.reduction(x)  # (B, H/2*W/2, 2*C)
        return x
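The strided slicing can look opaque, so here is a small sketch that applies only the gathering step (no normalization or linear reduction) to a 4×4 single-channel map whose values encode their position.

```python
import torch

# Values 0..15 laid out row by row: (B=1, H=4, W=4, C=1)
x = torch.arange(16).view(1, 4, 4, 1)

x0 = x[:, 0::2, 0::2, :]  # top-left pixel of each 2x2 block
x1 = x[:, 1::2, 0::2, :]  # bottom-left
x2 = x[:, 0::2, 1::2, :]  # top-right
x3 = x[:, 1::2, 1::2, :]  # bottom-right

merged = torch.cat([x0, x1, x2, x3], dim=-1)  # (1, 2, 2, 4)
print(merged.shape)
print(merged[0, 0, 0])  # the four pixels of the top-left 2x2 block
```

Each output position now holds all four pixels of one 2×2 block in its channel axis, so no information is discarded before the linear layer reduces 4*C channels to 2*C.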
Step 7: Complete Swin Transformer
Put it all together:
import torch
import torch.nn as nn


class SwinTransformer(nn.Module):
    """
    Swin Transformer for image classification.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=4,
        in_chans=3,
        num_classes=1000,
        embed_dim=96,
        depths=(2, 2, 6, 2),
        num_heads=(3, 6, 12, 24),
        window_size=7
    ):
        super().__init__()
        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        # Channels double at every patch merging step
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))

        # Patch embedding
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim
        )
        patches_resolution = self.patch_embed.patches_resolution

        # Build stages
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = SwinStage(
                dim=int(embed_dim * 2 ** i_layer),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                input_resolution=(
                    patches_resolution[0] // (2 ** i_layer),
                    patches_resolution[1] // (2 ** i_layer)
                ),
                window_size=window_size
            )
            self.layers.append(layer)

            # Add patch merging between stages (except last)
            if i_layer < self.num_layers - 1:
                self.layers.append(
                    PatchMerging(
                        input_resolution=(
                            patches_resolution[0] // (2 ** i_layer),
                            patches_resolution[1] // (2 ** i_layer)
                        ),
                        dim=int(embed_dim * 2 ** i_layer)
                    )
                )

        # Classification head
        self.norm = nn.LayerNorm(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes)

    def forward(self, x):
        """
        Args:
            x: Input images (B, 3, H, W)
        Returns:
            Class logits (B, num_classes)
        """
        x = self.patch_embed(x)

        # Stages and patch merging layers are interleaved in self.layers
        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)
        x = self.avgpool(x.transpose(1, 2))  # (B, C, 1)
        x = torch.flatten(x, 1)
        x = self.head(x)
        return x
Usage Example
Here's how to use the Swin Transformer:
import torch
# Assumes the classes above are saved in a local file named swin_transformer.py
from swin_transformer import SwinTransformer

# Create model
model = SwinTransformer(
    img_size=224,
    patch_size=4,
    in_chans=3,
    num_classes=1000,
    embed_dim=96,
    depths=[2, 2, 6, 2],  # Swin-Tiny configuration
    num_heads=[3, 6, 12, 24],
    window_size=7
)

# Random input
x = torch.randn(1, 3, 224, 224)

# Forward pass
output = model(x)
print(f"Output shape: {output.shape}")  # torch.Size([1, 1000])

# Count parameters
num_params = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {num_params:,}")
Model Variants
The paper introduces several model sizes:
| Model | Embed Dim | Depths | Heads | Params |
|---|---|---|---|---|
| Swin-T | 96 | [2,2,6,2] | [3,6,12,24] | 28M |
| Swin-S | 96 | [2,2,18,2] | [3,6,12,24] | 50M |
| Swin-B | 128 | [2,2,18,2] | [4,8,16,32] | 88M |
| Swin-L | 192 | [2,2,18,2] | [6,12,24,48] | 197M |
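The table translates directly into constructor arguments for the `SwinTransformer` class from Step 7. The sketch below stores the four rows in a dictionary (the name `SWIN_VARIANTS` is an assumption of this example) and derives two quantities from them: the channel width of the final stage and the number of channels per attention head.

```python
# Values copied from the table above
SWIN_VARIANTS = {
    "Swin-T": dict(embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24]),
    "Swin-S": dict(embed_dim=96, depths=[2, 2, 18, 2], num_heads=[3, 6, 12, 24]),
    "Swin-B": dict(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32]),
    "Swin-L": dict(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48]),
}

final_dims = {}
head_dims = {}
for name, cfg in SWIN_VARIANTS.items():
    num_stages = len(cfg["depths"])
    # Channels double at each of the (num_stages - 1) patch merging steps
    final_dims[name] = cfg["embed_dim"] * 2 ** (num_stages - 1)
    # Channels per head in the first stage
    head_dims[name] = cfg["embed_dim"] // cfg["num_heads"][0]
    print(f"{name}: final width {final_dims[name]}, {head_dims[name]} channels per head")

# Hypothetical usage with the class defined earlier:
# model = SwinTransformer(**SWIN_VARIANTS["Swin-B"])
```

Every variant keeps the same number of channels per head; the head count grows together with the embedding dimension.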
Key Takeaways
- Local + Global: Swin Transformer efficiently combines local window attention with cross-window connections via shifting.
- Hierarchical Design: Multi-scale features make it suitable for dense prediction tasks (detection, segmentation).
- Efficiency: Linear complexity w.r.t. image size enables processing of high-resolution images.
- Versatility: Can serve as a general-purpose backbone for various vision tasks.
Resources
- Paper: arXiv:2103.14030
- Official Implementation: microsoft/Swin-Transformer
- PyPI Package: swin-transformer-pytorch
Further Reading
To extend this implementation, consider:
- Adding relative position bias for better spatial awareness
- Implementing proper attention masking for shifted windows, which this implementation omits
- Adding stochastic depth for regularization
- Creating training scripts with ImageNet dataset
- Implementing downstream tasks (object detection with Swin Transformer backbone)
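As a starting point for the first item in the list, the sketch below builds the lookup index for a relative position bias, following the scheme used in the official implementation: every pair of tokens in a window is mapped to one of (2M-1)² possible relative offsets, and a learnable table with one row per offset supplies a bias per head. The function name `relative_position_index` and the zero-initialized table are assumptions of this example.

```python
import torch
import torch.nn as nn


def relative_position_index(window_size):
    """Return an (N, N) index into a table of (2*window_size - 1)**2 offsets."""
    coords = torch.stack(torch.meshgrid(
        torch.arange(window_size), torch.arange(window_size), indexing="ij"
    ))                                              # (2, M, M)
    coords = coords.flatten(1)                      # (2, N)
    rel = coords[:, :, None] - coords[:, None, :]   # (2, N, N), values in [-(M-1), M-1]
    rel = rel.permute(1, 2, 0).contiguous()         # (N, N, 2)
    rel[:, :, 0] += window_size - 1                 # shift both axes to start at 0
    rel[:, :, 1] += window_size - 1
    rel[:, :, 0] *= 2 * window_size - 1             # row-major flattening of the 2D offset
    return rel.sum(-1)                              # (N, N)


window_size, num_heads = 7, 3
N = window_size * window_size
index = relative_position_index(window_size)

# Learnable table: one bias per relative offset and head (zeros here for simplicity)
table = nn.Parameter(torch.zeros((2 * window_size - 1) ** 2, num_heads))
bias = table[index.view(-1)].view(N, N, num_heads).permute(2, 0, 1)  # (num_heads, N, N)
print(index.shape, bias.shape)
```

In a window attention module, `bias.unsqueeze(0)` would be added to the attention scores before the softmax, in the same place where the optional mask is added.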
Conclusion
The Swin Transformer represents a significant advancement in vision transformers, bridging the gap between CNNs and transformers through its hierarchical, efficient design. This implementation provides a foundation for understanding and experimenting with the architecture.
Whether you're building a classification model, an object detection system, or exploring new vision architectures, the Swin Transformer's design principles offer valuable insights into efficient attention mechanisms for computer vision.